├── .gitignore ├── env ├── assets │ ├── textures │ │ ├── can.png │ │ ├── bread.png │ │ ├── cereal.png │ │ ├── clay.png │ │ ├── glass.png │ │ ├── lemon.png │ │ ├── metal.png │ │ ├── ceramic.png │ │ ├── dark-wood.png │ │ └── light-wood.png │ ├── objects │ │ ├── meshes │ │ │ ├── bread.stl │ │ │ ├── can.stl │ │ │ ├── lemon.stl │ │ │ ├── milk.stl │ │ │ ├── bottle.stl │ │ │ ├── cereal.stl │ │ │ └── handles.stl │ │ ├── can-visual.xml │ │ ├── milk-visual.xml │ │ ├── cereal-visual.xml │ │ ├── bread-visual.xml │ │ ├── bottle.xml │ │ ├── can.xml │ │ ├── lemon.xml │ │ ├── milk.xml │ │ ├── plate-with-hole.xml │ │ ├── bread.xml │ │ ├── cereal.xml │ │ ├── square-nut.xml │ │ ├── pole.xml │ │ ├── spinning_pole.xml │ │ └── round-nut.xml │ ├── arena_v2_5.xml │ ├── arena_v2_5_vis.xml │ ├── arena_v2_9.xml │ └── arena_v2_7.xml └── light_env.py ├── README.md ├── scripts ├── planner.sh ├── trainF.sh ├── evalF.sh └── gendata.sh ├── trainF.py ├── collectdata.py ├── evalF.py ├── F_models.py └── learn_planner.py /.gitignore: -------------------------------------------------------------------------------- 1 | ppo_lighting_tb/ 2 | .ipynb_checkpoints/ 3 | *.pyc 4 | data/ 5 | figs* 6 | exp* 7 | -------------------------------------------------------------------------------- /env/assets/textures/can.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordVL/causal_induction/HEAD/env/assets/textures/can.png -------------------------------------------------------------------------------- /env/assets/textures/bread.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordVL/causal_induction/HEAD/env/assets/textures/bread.png -------------------------------------------------------------------------------- /env/assets/textures/cereal.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordVL/causal_induction/HEAD/env/assets/textures/cereal.png -------------------------------------------------------------------------------- /env/assets/textures/clay.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordVL/causal_induction/HEAD/env/assets/textures/clay.png -------------------------------------------------------------------------------- /env/assets/textures/glass.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordVL/causal_induction/HEAD/env/assets/textures/glass.png -------------------------------------------------------------------------------- /env/assets/textures/lemon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordVL/causal_induction/HEAD/env/assets/textures/lemon.png -------------------------------------------------------------------------------- /env/assets/textures/metal.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordVL/causal_induction/HEAD/env/assets/textures/metal.png -------------------------------------------------------------------------------- /env/assets/textures/ceramic.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordVL/causal_induction/HEAD/env/assets/textures/ceramic.png -------------------------------------------------------------------------------- /env/assets/objects/meshes/bread.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordVL/causal_induction/HEAD/env/assets/objects/meshes/bread.stl -------------------------------------------------------------------------------- /env/assets/objects/meshes/can.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordVL/causal_induction/HEAD/env/assets/objects/meshes/can.stl -------------------------------------------------------------------------------- /env/assets/objects/meshes/lemon.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordVL/causal_induction/HEAD/env/assets/objects/meshes/lemon.stl -------------------------------------------------------------------------------- /env/assets/objects/meshes/milk.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordVL/causal_induction/HEAD/env/assets/objects/meshes/milk.stl -------------------------------------------------------------------------------- /env/assets/textures/dark-wood.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordVL/causal_induction/HEAD/env/assets/textures/dark-wood.png -------------------------------------------------------------------------------- /env/assets/textures/light-wood.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordVL/causal_induction/HEAD/env/assets/textures/light-wood.png -------------------------------------------------------------------------------- /env/assets/objects/meshes/bottle.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordVL/causal_induction/HEAD/env/assets/objects/meshes/bottle.stl -------------------------------------------------------------------------------- /env/assets/objects/meshes/cereal.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordVL/causal_induction/HEAD/env/assets/objects/meshes/cereal.stl -------------------------------------------------------------------------------- /env/assets/objects/meshes/handles.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordVL/causal_induction/HEAD/env/assets/objects/meshes/handles.stl -------------------------------------------------------------------------------- /env/assets/objects/can-visual.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | -------------------------------------------------------------------------------- /env/assets/objects/milk-visual.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | -------------------------------------------------------------------------------- /env/assets/objects/cereal-visual.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | -------------------------------------------------------------------------------- /env/assets/objects/bread-visual.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | -------------------------------------------------------------------------------- /env/assets/objects/bottle.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | -------------------------------------------------------------------------------- /env/assets/objects/can.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | -------------------------------------------------------------------------------- /env/assets/objects/lemon.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | -------------------------------------------------------------------------------- /env/assets/objects/milk.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | -------------------------------------------------------------------------------- /env/assets/objects/plate-with-hole.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | -------------------------------------------------------------------------------- /env/assets/objects/bread.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | -------------------------------------------------------------------------------- /env/assets/objects/cereal.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Code and Environment for [Causal Induction From Visual Observations for Goal-Directed Tasks](https://arxiv.org/pdf/1910.01751.pdf). 2 | 3 | ## Environment 4 | 5 | Consists of the light switch environment for studying visual causal induction, where N switches control N lights, under various causal structures. Includes common cause, common effect, and causal chain relationships. Environment code resides under `env/light_env.py`. 6 | 7 | ## Induction Models 8 | 9 | The different induction models used are located under `F_models.py`, incuding our proposed iterative attention network, as well as baselines which do not use attention or use temporal convolutions. 10 | 11 | ## Reproducing Experiments 12 | 13 | Step 1: Generate Data 14 | 15 | `python3 collectdata.py --horizon 7 --num 7 --fixed-goal 0 --structure one_to_one --seen 10 --images 1 --data-dir output/` 16 | 17 | Step 2: Train Induction Model 18 | 19 | `python3 trainF.py --horizon 7 --num 7 --fixed-goal 0 --structure one_to_one --type iter --images 1 --seen 10 --data-dir output/` 20 | 21 | Step 3: Eval Induction Model 22 | 23 | `python3 evalF.py --horizon 7 --num 7 --fixed-goal 0 --structure one_to_one --method trajFi --images 1 --seen 10 --data-dir output/` 24 | 25 | Step 4: Train Policy via Imitation 26 | 27 | `python3 learn_planner.py --horizon 7 --num 7 --fixed-goal 0 --structure one_to_one --method trajFi --seen 10 --images 1 --data-dir output/` 28 | 29 | 30 | -------------------------------------------------------------------------------- /scripts/planner.sh: -------------------------------------------------------------------------------- 1 | ST=masterswitch 2 | SN=100 3 | 4 | # python3 learn_planner.py --horizon 5 --num 5 --fixed-goal 0 --structure $ST --method gt --seen $SN --images 1 5 | python3 learn_planner.py --horizon 5 --num 5 --fixed-goal 0 --structure $ST --method trajF --seen $SN --images 1 6 | python3 learn_planner.py --horizon 5 --num 5 --fixed-goal 0 --structure $ST --method trajFi --seen $SN --images 1 7 | python3 learn_planner.py --horizon 5 --num 5 --fixed-goal 0 --structure $ST --method trajFia --seen $SN --images 1 8 | 9 | # python3 learn_planner.py --horizon 5 --num 5 --fixed-goal 0 --structure $ST --method gt --seen 50 --images 1 10 | # python3 learn_planner.py --horizon 5 --num 5 --fixed-goal 0 --structure $ST --method trajF --seen 50 --images 1 11 | # python3 learn_planner.py --horizon 5 --num 5 --fixed-goal 0 --structure $ST --method trajFi --seen 50 --images 1 12 | # python3 learn_planner.py --horizon 5 --num 5 --fixed-goal 0 --structure $ST --method trajFia --seen 50 --images 1 13 | 14 | # python3 learn_planner.py --horizon 5 --num 5 --fixed-goal 0 --structure $ST --method gt --seen 100 --images 1 15 | # python3 learn_planner.py --horizon 5 --num 5 --fixed-goal 0 --structure $ST --method trajF --seen 100 --images 1 16 | # python3 learn_planner.py --horizon 5 --num 5 --fixed-goal 0 --structure $ST --method trajFi --seen 100 --images 1 17 | # python3 learn_planner.py --horizon 5 --num 5 --fixed-goal 0 --structure $ST --method trajFia --seen 100 --images 1 18 | 19 | 20 | -------------------------------------------------------------------------------- /scripts/trainF.sh: -------------------------------------------------------------------------------- 1 | MT=iter_attn 2 | ST=masterswitch 3 | 4 | python3 trainF.py --horizon 9 --num 9 --fixed-goal 0 --structure $ST --type $MT --images 1 --seen 10 5 | python3 trainF.py --horizon 9 --num 9 --fixed-goal 0 --structure $ST --type $MT --images 1 --seen 50 6 | 7 | ST=one_to_one 8 | 9 | python3 trainF.py --horizon 5 --num 5 --fixed-goal 0 --structure $ST --type $MT --images 1 --seen 10 10 | python3 trainF.py --horizon 5 --num 5 --fixed-goal 0 --structure $ST --type $MT --images 1 --seen 50 11 | 12 | 13 | 14 | # python3 trainF.py --horizon 5 --num 5 --fixed-goal 0 --structure $ST --type $MT --images 1 --seen 100 15 | # python3 trainF.py --horizon 6 --num 6 --fixed-goal 0 --structure $ST --type $MT --images 1 --seen 10 16 | # python3 trainF.py --horizon 6 --num 6 --fixed-goal 0 --structure $ST --type $MT --images 1 --seen 50 17 | # python3 trainF.py --horizon 6 --num 6 --fixed-goal 0 --structure $ST --type $MT --images 1 --seen 100 18 | # python3 trainF.py --horizon 6 --num 6 --fixed-goal 0 --structure $ST --type $MT --images 1 --seen 500 19 | # python3 trainF.py --horizon 7 --num 7 --fixed-goal 0 --structure $ST --type $MT --images 1 --seen 10 20 | # python3 trainF.py --horizon 7 --num 7 --fixed-goal 0 --structure $ST --type $MT --images 1 --seen 50 21 | # python3 trainF.py --horizon 7 --num 7 --fixed-goal 0 --structure $ST --type $MT --images 1 --seen 100 22 | # python3 trainF.py --horizon 7 --num 7 --fixed-goal 0 --structure $ST --type $MT --images 1 --seen 500 23 | -------------------------------------------------------------------------------- /env/assets/objects/square-nut.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | -------------------------------------------------------------------------------- /env/assets/objects/pole.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | -------------------------------------------------------------------------------- /env/assets/objects/spinning_pole.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 0 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | -------------------------------------------------------------------------------- /env/assets/objects/round-nut.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | -------------------------------------------------------------------------------- /scripts/evalF.sh: -------------------------------------------------------------------------------- 1 | ST=one_to_one 2 | SN=500 3 | # ST=one_to_one 4 | 5 | # python3 evalF.py --horizon 7 --num 7 --fixed-goal 0 --structure $ST --method trajF --images 1 --seen $SN 6 | # python3 evalF.py --horizon 7 --num 7 --fixed-goal 0 --structure $ST --method trajFi --images 1 --seen $SN 7 | # python3 evalF.py --horizon 7 --num 7 --fixed-goal 0 --structure $ST --method trajFia --images 1 --seen $SN 8 | 9 | # ST=one_to_many 10 | 11 | # python3 evalF.py --horizon 7 --num 7 --fixed-goal 0 --structure $ST --method trajF --images 1 --seen $SN 12 | # python3 evalF.py --horizon 7 --num 7 --fixed-goal 0 --structure $ST --method trajFi --images 1 --seen $SN 13 | # python3 evalF.py --horizon 7 --num 7 --fixed-goal 0 --structure $ST --method trajFia --images 1 --seen $SN 14 | 15 | # ST=many_to_one 16 | 17 | # python3 evalF.py --horizon 7 --num 7 --fixed-goal 0 --structure $ST --method trajF --images 1 --seen $SN 18 | # python3 evalF.py --horizon 7 --num 7 --fixed-goal 0 --structure $ST --method trajFi --images 1 --seen $SN 19 | # python3 evalF.py --horizon 7 --num 7 --fixed-goal 0 --structure $ST --method trajFia --images 1 --seen $SN 20 | 21 | ST=masterswitch 22 | 23 | python3 evalF.py --horizon 7 --num 7 --fixed-goal 0 --structure $ST --method trajF --images 1 --seen $SN 24 | python3 evalF.py --horizon 7 --num 7 --fixed-goal 0 --structure $ST --method trajFi --images 1 --seen $SN 25 | python3 evalF.py --horizon 7 --num 7 --fixed-goal 0 --structure $ST --method trajFia --images 1 --seen $SN 26 | 27 | 28 | # python3 evalF.py --horizon 5 --num 5 --fixed-goal 0 --structure $ST --method $MT --images 1 --seen 500 29 | 30 | # MT=trajFi 31 | 32 | # python3 evalF.py --horizon 5 --num 5 --fixed-goal 0 --structure $ST --method $MT --images 1 --seen 10 33 | # python3 evalF.py --horizon 5 --num 5 --fixed-goal 0 --structure $ST --method $MT --images 1 --seen 50 34 | # python3 evalF.py --horizon 5 --num 5 --fixed-goal 0 --structure $ST --method $MT --images 1 --seen 100 35 | # # python3 evalF.py --horizon 5 --num 5 --fixed-goal 0 --structure $ST --method $MT --images 1 --seen 500 36 | 37 | # MT=trajF 38 | 39 | # python3 evalF.py --horizon 5 --num 5 --fixed-goal 0 --structure $ST --method $MT --images 1 --seen 10 40 | # python3 evalF.py --horizon 5 --num 5 --fixed-goal 0 --structure $ST --method $MT --images 1 --seen 50 41 | # python3 evalF.py --horizon 5 --num 5 --fixed-goal 0 --structure $ST --method $MT --images 1 --seen 100 42 | # python3 evalF.py --horizon 5 --num 5 --fixed-goal 0 --structure $ST --method $MT --images 1 --seen 500 43 | 44 | 45 | # MT=trajFia 46 | # ST=one_to_one 47 | 48 | # python3 evalF.py --horizon 5 --num 5 --fixed-goal 0 --structure $ST --method $MT --images 1 --seen 10 49 | # python3 evalF.py --horizon 5 --num 5 --fixed-goal 0 --structure $ST --method $MT --images 1 --seen 50 50 | # python3 evalF.py --horizon 5 --num 5 --fixed-goal 0 --structure $ST --method $MT --images 1 --seen 100 51 | 52 | # python3 evalF.py --horizon 6 --num 6 --fixed-goal 0 --structure $ST --method $MT --images 1 --seen 10 53 | # python3 evalF.py --horizon 6 --num 6 --fixed-goal 0 --structure $ST --method $MT --images 1 --seen 50 54 | # python3 evalF.py --horizon 6 --num 6 --fixed-goal 0 --structure $ST --method $MT --images 1 --seen 100 55 | # python3 evalF.py --horizon 6 --num 6 --fixed-goal 0 --structure $ST --method $MT --images 1 --seen 500 56 | # python3 evalF.py --horizon 7 --num 7 --fixed-goal 0 --structure $ST --method $MT --images 1 --seen 10 57 | # python3 evalF.py --horizon 7 --num 7 --fixed-goal 0 --structure $ST --method $MT --images 1 --seen 50 58 | # python3 evalF.py --horizon 7 --num 7 --fixed-goal 0 --structure $ST --method $MT --images 1 --seen 100 59 | # python3 evalF.py --horizon 7 --num 7 --fixed-goal 0 --structure $ST --method $MT --images 1 --seen 500 60 | -------------------------------------------------------------------------------- /trainF.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch as th 4 | import argparse 5 | 6 | from F_models import SupervisedModelCNN, IterativeModel, IterativeModelAttention 7 | 8 | 9 | def train_supervised(F, buf, gtbuf, num, steps = 1, bs=32, images=False): 10 | buf = th.FloatTensor(buf).float() 11 | gtbuf = th.FloatTensor(gtbuf).float() 12 | optimizer = th.optim.Adam(F.parameters(), lr=0.0001) 13 | for step in range(steps): 14 | optimizer.zero_grad() 15 | perm = th.randperm(buf.size(0)-5000) 16 | testperm = th.randperm(5000) + 35000 17 | 18 | idx = perm[:bs] 19 | samples = buf[idx] 20 | gts= gtbuf[idx] 21 | testidx = testperm[:bs] 22 | testsamples = buf[testidx] 23 | testgts= gtbuf[testidx] 24 | 25 | if images: 26 | split = 32*32*3 27 | else: 28 | split = num 29 | 30 | states = samples[:, :, :split].contiguous().view(bs, -1).cuda() 31 | actions = samples[:, :, split:].contiguous().view(bs, -1).cuda() 32 | groundtruth = gts.cuda() 33 | pred = F(states, actions) 34 | 35 | teststates = testsamples[:, :, :split].contiguous().view(bs, -1).cuda() 36 | testactions = testsamples[:, :, split:].contiguous().view(bs, -1).cuda() 37 | testgroundtruth = testgts.cuda() 38 | testpred = F(teststates, testactions) 39 | 40 | loss = ((pred - groundtruth)**2).sum(1).mean() 41 | testloss = ((testpred - testgroundtruth)**2).sum(1).mean() 42 | 43 | loss.backward() 44 | if step % 1000 == 0: 45 | print((loss / num).cpu().detach().numpy()) 46 | print((testloss / num).cpu().detach().numpy()) 47 | print(pred[0], groundtruth[0]) 48 | print(step) 49 | print("_"*50) 50 | optimizer.step() 51 | 52 | if __name__ == '__main__': 53 | parser = argparse.ArgumentParser(description='Causal Meta-RL') 54 | parser.add_argument('--fixed-goal', type=int, default=0, help='fixed goal or no') 55 | parser.add_argument('--horizon', type=int, default=10, help='Env horizon') 56 | parser.add_argument('--num', type=int, default=1, help='Num lights') 57 | parser.add_argument('--structure', type=str, default="one_to_one", help='Causal Structure') 58 | parser.add_argument('--type', type=str, default="cnn", help='Model Type') 59 | parser.add_argument('--seen', type=int, default=10, help='Num seen') 60 | parser.add_argument('--images', type=int, default=0, help='Images or no') 61 | parser.add_argument('--data-dir', type=str, help='Data Dir') 62 | 63 | args = parser.parse_args() 64 | 65 | 66 | if args.type == "cnn": 67 | if args.structure == "masterswitch": 68 | msv = True 69 | F = SupervisedModelCNN(2*args.horizon -1,args.num, ms = msv, images=args.images).cuda() 70 | else: 71 | msv = False 72 | F = SupervisedModelCNN(args.horizon,args.num, ms = msv, images=args.images).cuda() 73 | elif args.type == "iter": 74 | if args.structure == "masterswitch": 75 | msv = True 76 | F = IterativeModel(2*args.horizon -1,args.num, ms = msv, images=args.images).cuda() 77 | else: 78 | msv = False 79 | F = IterativeModel(args.horizon, args.num, ms = msv, images=args.images).cuda() 80 | elif args.type == "iter_attn": 81 | if args.structure == "masterswitch": 82 | msv = True 83 | F = IterativeModelAttention(2*args.horizon -1,args.num, ms = msv, images=args.images).cuda() 84 | else: 85 | msv = False 86 | F = IterativeModelAttention(args.horizon, args.num, ms = msv, images=args.images).cuda() 87 | else: 88 | raise NotImplementedError 89 | 90 | if args.images: 91 | addonn = "_I1" 92 | else: 93 | addonn = "" 94 | 95 | a = np.load(args.data_dir+ "buf40K_S"+str(args.seen)+\ 96 | "_"+str(args.structure)+"_"+str(args.horizon) + addonn + ".npy") 97 | a2 = np.load(args.data_dir+ "gtbuf40K_S"+str(args.seen)+\ 98 | "_"+str(args.structure)+"_"+str(args.horizon) + addonn + ".npy") 99 | 100 | print(a.shape, a2.shape) 101 | train_supervised(F, a, a2, args.num, steps=2000, bs=512, images=args.images) 102 | th.save(F, args.data_dir+\ 103 | str(args.type)+"_Redo_L2_S"+str(args.seen)+"_h"+str(args.horizon)+"_"+str(args.structure) \ 104 | + "_I"+str(args.images)) -------------------------------------------------------------------------------- /collectdata.py: -------------------------------------------------------------------------------- 1 | from env.light_env import LightEnv 2 | import numpy as np 3 | from stable_baselines.common.vec_env import DummyVecEnv, SubprocVecEnv 4 | import argparse 5 | import torch as th 6 | import cv2 7 | import os 8 | 9 | if __name__ == '__main__': 10 | parser = argparse.ArgumentParser(description='Causal Meta-RL') 11 | parser.add_argument('--fixed-goal', type=int, default=0, help='fixed goal or no') 12 | parser.add_argument('--horizon', type=int, default=5, help='Env horizon') 13 | parser.add_argument('--num', type=int, default=5, help='Num Switches') 14 | parser.add_argument('--structure', type=str, default="one_to_one", help='Graph Structure') 15 | parser.add_argument('--seen', type=int, default=10, help='Number of seen environments') 16 | parser.add_argument('--images', type=int, default=0, help='Use Images') 17 | parser.add_argument('--data-dir', type=str, help='Directory to Store Data') 18 | args = parser.parse_args() 19 | 20 | ## Init Buffer 21 | gc = 1 - args.fixed_goal 22 | buffer = [] 23 | gtbuffer = [] 24 | num_episodes = 40000 25 | 26 | ## Set Horizon Based On Task 27 | if args.structure == "masterswitch": 28 | st = (args.horizon*(2*args.num+1) + (args.horizon-1)*(2*args.num+1)) 29 | else: 30 | st = (args.horizon*(2*args.num+1)) 31 | 32 | ## Init Env 33 | l = LightEnv(args.horizon, 34 | args.num, 35 | st, 36 | args.structure, 37 | gc, 38 | filename=str(gc)+"_traj", 39 | seen = args.seen) 40 | env = DummyVecEnv(1 * [lambda: l]) 41 | 42 | 43 | for q in range(num_episodes): 44 | ## Reset Causal Structure 45 | l.keep_struct = False 46 | obs = env.reset() 47 | l.keep_struct = True 48 | ##### INDUCTION ##### 49 | ##### OPTIMAL POLICY 1 50 | if args.structure == "masterswitch": 51 | it = None 52 | for i in range(args.num): 53 | p = l._get_obs() 54 | if args.images: 55 | pi = l._get_obs(images=True) 56 | p = p.reshape((1, -1)) 57 | a = np.zeros((1, args.num+1)) 58 | a[:,i] = 1 59 | 60 | if args.images: 61 | mem = np.concatenate([np.expand_dims(pi.flatten(), 0), a], 1) 62 | else: 63 | mem = np.concatenate([p[:,:args.num], a], 1) 64 | 65 | if i == 0: 66 | epbuf = mem 67 | else: 68 | epbuf = np.concatenate([epbuf, mem], 0) 69 | l.step(i, count = False) 70 | p2 = l._get_obs() 71 | if (p != p2).any(): 72 | it = i 73 | break 74 | for i in range(args.num): 75 | if i != it: 76 | p = l._get_obs() 77 | if args.images: 78 | pi = l._get_obs(images=True) 79 | p = p.reshape((1, -1)) 80 | a = np.zeros((1, args.num+1)) 81 | a[:,i] = 1 82 | 83 | if args.images: 84 | mem = np.concatenate([np.expand_dims(pi.flatten(), 0), a], 1) 85 | else: 86 | mem = np.concatenate([p[:,:args.num], a], 1) 87 | 88 | epbuf = np.concatenate([epbuf, mem], 0) 89 | l.step(i, count = False) 90 | ln = epbuf.shape[0] 91 | buf = np.zeros((2 * args.horizon - 1, epbuf.shape[1])) 92 | buf[:ln] = epbuf 93 | else: 94 | for i in range(args.num): 95 | p = l._get_obs() 96 | if args.images: 97 | pi = l._get_obs(images=True) 98 | p = p.reshape((1, -1)) 99 | a = np.zeros((1, args.num+1)) 100 | a[:,i] = 1 101 | 102 | if args.images: 103 | mem = np.concatenate([np.expand_dims(pi.flatten(), 0), a], 1) 104 | else: 105 | mem = np.concatenate([p[:,:args.num], a], 1) 106 | 107 | if i == 0: 108 | epbuf = mem 109 | else: 110 | epbuf = np.concatenate([epbuf, mem], 0) 111 | l.step(i, count = False) 112 | buf = epbuf 113 | 114 | buffer.append(buf) 115 | gtbuffer.append(l.gt) 116 | if q % 10000 == 0: 117 | print(q) 118 | buffer = np.stack(buffer, 0) 119 | gtbuffer = np.stack(gtbuffer, 0) 120 | print(buffer.shape) 121 | print(gtbuffer.shape) 122 | 123 | np.save(args.data_dir+"buf40K_S"+\ 124 | str(args.seen)+"_"+str(args.structure)+"_"+str(args.horizon)+\ 125 | "_I"+str(args.images), buffer) 126 | np.save(args.data_dir+"gtbuf40K_S"+\ 127 | str(args.seen)+"_"+str(args.structure)+"_"+str(args.horizon)+\ 128 | "_I"+str(args.images), gtbuffer) 129 | 130 | 131 | 132 | -------------------------------------------------------------------------------- /evalF.py: -------------------------------------------------------------------------------- 1 | from env.light_env import LightEnv 2 | import numpy as np 3 | from stable_baselines.common.vec_env import DummyVecEnv, SubprocVecEnv 4 | import argparse 5 | import torch as th 6 | import cv2 7 | 8 | def induction(structure, num, horizon, l, images=False): 9 | '''Heuristic Policy to collect interaction data''' 10 | ##### INDUCTION ##### 11 | ##### OPTIMAL POLICY 1 12 | if structure == "masterswitch": 13 | it = None 14 | for i in range(num): 15 | p = l._get_obs() 16 | if images: 17 | pi = l._get_obs(images=True) 18 | p = p.reshape((1, -1)) 19 | a = np.zeros((1, num+1)) 20 | a[:,i] = 1 21 | 22 | if images: 23 | mem = np.concatenate([np.expand_dims(pi.flatten(), 0), a], 1) 24 | else: 25 | mem = np.concatenate([p[:,:args.num], a], 1) 26 | 27 | if i == 0: 28 | epbuf = mem 29 | else: 30 | epbuf = np.concatenate([epbuf, mem], 0) 31 | l.step(i, count = False) 32 | p2 = l._get_obs() 33 | if (p != p2).any(): 34 | it = i 35 | break 36 | for i in range(num): 37 | if i != it: 38 | p = l._get_obs() 39 | if images: 40 | pi = l._get_obs(images=True) 41 | p = p.reshape((1, -1)) 42 | a = np.zeros((1, num+1)) 43 | a[:,i] = 1 44 | 45 | if images: 46 | mem = np.concatenate([np.expand_dims(pi.flatten(), 0), a], 1) 47 | else: 48 | mem = np.concatenate([p[:,:args.num], a], 1) 49 | 50 | epbuf = np.concatenate([epbuf, mem], 0) 51 | l.step(i, count = False) 52 | ln = epbuf.shape[0] 53 | buf = np.zeros((2 * args.horizon - 1, epbuf.shape[1])) 54 | buf[:ln] = epbuf 55 | else: 56 | for i in range(num): 57 | p = l._get_obs() 58 | if images: 59 | pi = l._get_obs(images=True) 60 | # im = l.sim.render(width=480,height=480,camera_name="birdview") 61 | # cv2.imwrite('o'+str(i)+'.png', cv2.cvtColor(im, cv2.COLOR_BGR2RGB)) 62 | p = p.reshape((1, -1)) 63 | a = np.zeros((1, num+1)) 64 | 65 | a[:,i] = 1 66 | 67 | if images: 68 | mem = np.concatenate([np.expand_dims(pi.flatten(), 0), a], 1) 69 | else: 70 | mem = np.concatenate([p[:,:args.num], a], 1) 71 | 72 | if i == 0: 73 | epbuf = mem 74 | else: 75 | epbuf = np.concatenate([epbuf, mem], 0) 76 | l.step(i, count = False) 77 | buf = epbuf 78 | return buf 79 | 80 | def predict(buf, F, structure, num): 81 | '''Predict Graph''' 82 | s = th.FloatTensor(buf[:,:-(num+1)]).float().cuda() 83 | a = th.FloatTensor(buf[:,-(1+num):]).float().cuda() 84 | predgt = F(s, a) 85 | return predgt.cpu().detach().numpy().flatten() 86 | 87 | 88 | def f1score(pred, gt): 89 | '''Compute F1 Score''' 90 | p = 1 * (pred > 0.5) 91 | 92 | if np.sum(p) == 0: 93 | prec = 0 94 | else: 95 | prec = np.sum(gt * p) / np.sum(p) 96 | rec =np.sum(gt*p) / np.sum(gt) 97 | if (prec == 0) and (rec==0): 98 | return 0 99 | return 2 * (prec * rec) / (prec+rec) 100 | 101 | 102 | 103 | 104 | if __name__ == '__main__': 105 | parser = argparse.ArgumentParser(description='Causal Meta-RL') 106 | parser.add_argument('--fixed-goal', type=int, default=0, help='fixed goal or no') 107 | parser.add_argument('--horizon', type=int, default=10, help='Env horizon') 108 | parser.add_argument('--num', type=int, default=1, help='num lights') 109 | parser.add_argument('--structure', type=str, default="one_to_one", help='causal structure') 110 | parser.add_argument('--method', type=str, default="traj", help='method') 111 | parser.add_argument('--seen', type=int, default=10, help='num seen') 112 | parser.add_argument('--images', type=int, default=0, help='images or no') 113 | parser.add_argument('--data-dir', type=str, help='data dir') 114 | args = parser.parse_args() 115 | 116 | 117 | gc = 1 - args.fixed_goal 118 | 119 | if args.structure == "masterswitch": 120 | st = (args.horizon*(2*args.num+1) + (args.horizon-1)*(2*args.num+1)) 121 | else: 122 | st = (args.horizon*(2*args.num+1)) 123 | tj = "gt" 124 | l = LightEnv(args.horizon, 125 | args.num, 126 | tj, 127 | args.structure, 128 | gc, 129 | filename="exp/"+str(gc)+"_"+args.method, 130 | seen = args.seen) 131 | env = DummyVecEnv(1 * [lambda: l]) 132 | 133 | if args.images: 134 | addonn = "_I1" 135 | else: 136 | addonn = "" 137 | 138 | if args.method == "trajF": 139 | F = th.load(args.data_dir+"cnn_Redo_L2_S"+str(args.seen)+"_h"+str(args.horizon)+\ 140 | "_"+str(args.structure)+addonn).cuda() 141 | elif args.method == "trajFia": 142 | F = th.load(args.data_dir+"iter_attn_Redo_L2_S"+str(args.seen)+"_h"+str(args.horizon)+\ 143 | "_"+str(args.structure)+addonn).cuda() 144 | else: 145 | F = th.load(args.data_dir+"iter_Redo_L2_S"+str(args.seen)+"_h"+str(args.horizon)+\ 146 | "_"+str(args.structure)+addonn).cuda() 147 | F = F.eval() 148 | trloss = [] 149 | tsloss = [] 150 | for mep in range(100): 151 | l.keep_struct = False 152 | obs = env.reset() 153 | l.keep_struct = True 154 | 155 | buf = induction(args.structure,args.num, args.horizon, l, images=args.images) 156 | pred = predict(buf, F,args.structure, args.num) 157 | f = f1score(pred, l.gt) 158 | 159 | trloss.append(f) 160 | 161 | #### TEST ON UNSEEN CS 162 | l.keep_struct = False 163 | l.train = False 164 | for i in range(1): 165 | obs= env.reset() 166 | buf = induction(args.structure,args.num, args.horizon, l, images=args.images) 167 | 168 | pred = predict(buf, F,args.structure, args.num) 169 | 170 | 171 | f = f1score(pred, l.gt) 172 | tsloss.append(f) 173 | 174 | l.keep_struct = True 175 | l.train = True 176 | 177 | print(np.mean(trloss), np.mean(tsloss)) 178 | 179 | 180 | -------------------------------------------------------------------------------- /scripts/gendata.sh: -------------------------------------------------------------------------------- 1 | # python3 collectdata.py --horizon 6 --num 6 --fixed-goal 0 --structure one_to_one --method traj --seen 100 --images 1 2 | # python3 collectdata.py --horizon 6 --num 6 --fixed-goal 0 --structure many_to_one --method traj --seen 100 --images 1 3 | # python3 collectdata.py --horizon 6 --num 6 --fixed-goal 0 --structure one_to_many --method traj --seen 100 --images 1 4 | # python3 collectdata.py --horizon 6 --num 6 --fixed-goal 0 --structure masterswitch --method traj --seen 100 --images 1 5 | 6 | # python3 collectdata.py --horizon 6 --num 6 --fixed-goal 0 --structure one_to_one --method traj --seen 50 --images 1 7 | # python3 collectdata.py --horizon 6 --num 6 --fixed-goal 0 --structure many_to_one --method traj --seen 50 --images 1 8 | # python3 collectdata.py --horizon 6 --num 6 --fixed-goal 0 --structure one_to_many --method traj --seen 50 --images 1 9 | # python3 collectdata.py --horizon 6 --num 6 --fixed-goal 0 --structure masterswitch --method traj --seen 50 --images 1 10 | 11 | # python3 collectdata.py --horizon 6 --num 6 --fixed-goal 0 --structure one_to_one --method traj --seen 500 --images 1 12 | # python3 collectdata.py --horizon 6 --num 6 --fixed-goal 0 --structure many_to_one --method traj --seen 500 --images 1 13 | # python3 collectdata.py --horizon 6 --num 6 --fixed-goal 0 --structure one_to_many --method traj --seen 500 --images 1 14 | # python3 collectdata.py --horizon 6 --num 6 --fixed-goal 0 --structure masterswitch --method traj --seen 500 --images 1 15 | 16 | # python3 collectdata.py --horizon 5 --num 5 --fixed-goal 0 --structure one_to_one --method traj --seen 100 --images 1 17 | # python3 collectdata.py --horizon 5 --num 5 --fixed-goal 0 --structure many_to_one --method traj --seen 100 --images 1 18 | # python3 collectdata.py --horizon 5 --num 5 --fixed-goal 0 --structure one_to_many --method traj --seen 100 --images 1 19 | # python3 collectdata.py --horizon 5 --num 5 --fixed-goal 0 --structure masterswitch --method traj --seen 100 --images 1 20 | 21 | # python3 collectdata.py --horizon 5 --num 5 --fixed-goal 0 --structure one_to_one --method traj --seen 50 --images 1 22 | # python3 collectdata.py --horizon 5 --num 5 --fixed-goal 0 --structure many_to_one --method traj --seen 50 --images 1 23 | # python3 collectdata.py --horizon 5 --num 5 --fixed-goal 0 --structure one_to_many --method traj --seen 50 --images 1 24 | # python3 collectdata.py --horizon 5 --num 5 --fixed-goal 0 --structure masterswitch --method traj --seen 50 --images 1 25 | 26 | # python3 collectdata.py --horizon 5 --num 5 --fixed-goal 0 --structure one_to_one --method traj --seen 10 --images 1 27 | # python3 collectdata.py --horizon 5 --num 5 --fixed-goal 0 --structure many_to_one --method traj --seen 10 --images 1 28 | # python3 collectdata.py --horizon 5 --num 5 --fixed-goal 0 --structure one_to_many --method traj --seen 10 --images 1 29 | # python3 collectdata.py --horizon 5 --num 5 --fixed-goal 0 --structure masterswitch --method traj --seen 10 --images 1 30 | 31 | # python3 collectdata.py --horizon 7 --num 7 --fixed-goal 0 --structure one_to_one --method traj --seen 10 --images 1 32 | # python3 collectdata.py --horizon 7 --num 7 --fixed-goal 0 --structure many_to_one --method traj --seen 10 --images 1 33 | # python3 collectdata.py --horizon 7 --num 7 --fixed-goal 0 --structure one_to_many --method traj --seen 10 --images 1 34 | # python3 collectdata.py --horizon 7 --num 7 --fixed-goal 0 --structure masterswitch --method traj --seen 10 --images 1 35 | 36 | # python3 collectdata.py --horizon 7 --num 7 --fixed-goal 0 --structure one_to_one --method traj --seen 100 --images 1 37 | # python3 collectdata.py --horizon 7 --num 7 --fixed-goal 0 --structure many_to_one --method traj --seen 100 --images 1 38 | # python3 collectdata.py --horizon 7 --num 7 --fixed-goal 0 --structure one_to_many --method traj --seen 100 --images 1 39 | # python3 collectdata.py --horizon 7 --num 7 --fixed-goal 0 --structure masterswitch --method traj --seen 100 --images 1 40 | 41 | # python3 collectdata.py --horizon 7 --num 7 --fixed-goal 0 --structure one_to_one --method traj --seen 50 --images 1 42 | # python3 collectdata.py --horizon 7 --num 7 --fixed-goal 0 --structure many_to_one --method traj --seen 50 --images 1 43 | # python3 collectdata.py --horizon 7 --num 7 --fixed-goal 0 --structure one_to_many --method traj --seen 50 --images 1 44 | # python3 collectdata.py --horizon 7 --num 7 --fixed-goal 0 --structure masterswitch --method traj --seen 50 --images 1 45 | 46 | # python3 collectdata.py --horizon 7 --num 7 --fixed-goal 0 --structure one_to_one --method traj --seen 500 --images 1 47 | # python3 collectdata.py --horizon 7 --num 7 --fixed-goal 0 --structure many_to_one --method traj --seen 500 --images 1 48 | # python3 collectdata.py --horizon 7 --num 7 --fixed-goal 0 --structure one_to_many --method traj --seen 500 --images 1 49 | # python3 collectdata.py --horizon 7 --num 7 --fixed-goal 0 --structure masterswitch --method traj --seen 500 --images 1 50 | 51 | 52 | python3 collectdata.py --horizon 7 --num 7 --fixed-goal 0 --structure one_to_one --method traj --seen 10 --images 1 53 | python3 collectdata.py --horizon 7 --num 7 --fixed-goal 0 --structure many_to_one --method traj --seen 10 --images 1 54 | python3 collectdata.py --horizon 7 --num 7 --fixed-goal 0 --structure one_to_many --method traj --seen 10 --images 1 55 | python3 collectdata.py --horizon 7 --num 7 --fixed-goal 0 --structure masterswitch --method traj --seen 10 --images 1 56 | 57 | python3 collectdata.py --horizon 7 --num 7 --fixed-goal 0 --structure one_to_one --method traj --seen 100 --images 1 58 | python3 collectdata.py --horizon 7 --num 7 --fixed-goal 0 --structure many_to_one --method traj --seen 100 --images 1 59 | python3 collectdata.py --horizon 7 --num 7 --fixed-goal 0 --structure one_to_many --method traj --seen 100 --images 1 60 | python3 collectdata.py --horizon 7 --num 7 --fixed-goal 0 --structure masterswitch --method traj --seen 100 --images 1 61 | 62 | python3 collectdata.py --horizon 7 --num 7 --fixed-goal 0 --structure one_to_one --method traj --seen 50 --images 1 63 | python3 collectdata.py --horizon 7 --num 7 --fixed-goal 0 --structure many_to_one --method traj --seen 50 --images 1 64 | python3 collectdata.py --horizon 7 --num 7 --fixed-goal 0 --structure one_to_many --method traj --seen 50 --images 1 65 | python3 collectdata.py --horizon 7 --num 7 --fixed-goal 0 --structure masterswitch --method traj --seen 50 --images 1 66 | 67 | python3 collectdata.py --horizon 7 --num 7 --fixed-goal 0 --structure one_to_one --method traj --seen 500 --images 1 68 | python3 collectdata.py --horizon 7 --num 7 --fixed-goal 0 --structure many_to_one --method traj --seen 500 --images 1 69 | python3 collectdata.py --horizon 7 --num 7 --fixed-goal 0 --structure one_to_many --method traj --seen 500 --images 1 70 | python3 collectdata.py --horizon 7 --num 7 --fixed-goal 0 --structure masterswitch --method traj --seen 500 --images 1 71 | -------------------------------------------------------------------------------- /F_models.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division, absolute_import 2 | 3 | import torch as th 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torchvision.models as models 7 | import numpy as np 8 | import math 9 | import cv2 10 | 11 | 12 | class ImageEncoder(nn.Module): 13 | """ 14 | IMage Encoder 15 | """ 16 | def __init__(self, num): 17 | super(ImageEncoder, self).__init__() 18 | self.encoder_conv = nn.Sequential( 19 | # 224x224xN_CHANNELS -> 112x112x64 20 | nn.Conv2d(3, 8, kernel_size=3, stride=1, padding=1), 21 | nn.MaxPool2d(kernel_size=2, stride=2), # 56x56x64 22 | nn.ReLU(inplace=True), 23 | ) 24 | self.encoder_conv2 = nn.Sequential( 25 | nn.Conv2d(8, 16, kernel_size=3, stride=1, padding=1), 26 | nn.MaxPool2d(kernel_size=2, stride=2), # 27x27x64 27 | nn.ReLU(inplace=True), 28 | ) 29 | self.encoder_conv3 = nn.Sequential( 30 | nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1), 31 | nn.MaxPool2d(kernel_size=2, stride=2), # 6x6x64 32 | nn.ReLU(inplace=True), 33 | ) 34 | 35 | self.fc = nn.Linear(4 * 4 * 32, num) 36 | 37 | self.sigmoid = nn.Sigmoid() 38 | 39 | def forward(self, x): 40 | e1 = self.encoder_conv(x) 41 | e2 = self.encoder_conv2(e1) 42 | e3 = self.encoder_conv3(e2) 43 | e3 = e3.view(e3.size(0), -1) 44 | encoding = self.fc(e3) 45 | return encoding 46 | 47 | 48 | class IterativeModelAttention(nn.Module): 49 | def __init__(self, horizon,num=5, ms=False, images=False): 50 | super(IterativeModelAttention, self).__init__() 51 | 52 | self.horizon = horizon 53 | self.ms = ms 54 | 55 | if self.ms: 56 | self.attnsize = num + 1 57 | self.outsize = num 58 | else: 59 | self.attnsize = num 60 | self.outsize = num 61 | 62 | 63 | self.images = images 64 | self.num = num 65 | 66 | self.ie = ImageEncoder(num) 67 | 68 | self.fc1 = nn.Linear(2*num+1, 1024) 69 | self.fc2 = nn.Linear(1024, 512) 70 | self.fc3 = nn.Linear(512, self.attnsize + num) 71 | 72 | self.fc4 = nn.Linear(self.attnsize*self.outsize, self.attnsize + num) 73 | 74 | self.dp = nn.Dropout(0.3) 75 | self.relu = nn.ReLU() 76 | self.softmax = nn.Softmax(dim=1) 77 | self.sigmoid = nn.Sigmoid() 78 | 79 | def forward(self, s, a): 80 | if self.images: 81 | s_im = s.view(-1, 32, 32, 3).permute(0,3,1,2) 82 | senc = self.ie(s_im) 83 | sp = senc.view(-1, self.horizon, self.num) 84 | else: 85 | sp = s.view(-1, self.horizon, self.num) 86 | sp[:,:-1] = sp[:,1:] - sp[:,:-1] 87 | a = a.view(-1, self.horizon, self.num+1) 88 | e = th.cat([sp, a], 2) 89 | 90 | 91 | p = th.zeros((sp.size(0), self.attnsize, self.outsize)).cuda() 92 | for i in range(self.horizon): 93 | inn = e[:,i,:] 94 | e1 = self.relu(self.dp(self.fc1(inn))) 95 | e2 = self.relu(self.dp(self.fc2(e1))) 96 | e3 = self.fc3(e2) 97 | 98 | atn = self.softmax(e3[:, :self.attnsize]).unsqueeze(-1) 99 | e3 = self.sigmoid(e3[:, self.attnsize:].unsqueeze(1).repeat(1, self.attnsize, 1)) 100 | r = atn * e3 101 | p = p + r 102 | 103 | e3 = self.fc4(p.view(-1, self.attnsize*self.num)) 104 | atn = self.softmax(e3[:, :self.attnsize]).unsqueeze(-1) 105 | e3 = self.sigmoid(e3[:, self.attnsize:].unsqueeze(1).repeat(1, self.attnsize, 1)) 106 | r = atn * e3 107 | p = p + r 108 | p = p.view(-1, self.attnsize*self.num) 109 | 110 | 111 | return p 112 | 113 | class IterativeModel(nn.Module): 114 | def __init__(self, horizon,num=5, ms=False, images=False): 115 | super(IterativeModel, self).__init__() 116 | 117 | self.images = images 118 | self.ie = ImageEncoder(num) 119 | 120 | self.horizon = horizon 121 | self.ms = ms 122 | if self.ms: 123 | self.outsize = num**2 + num 124 | else: 125 | self.outsize = num**2 126 | 127 | self.num = num 128 | 129 | self.fc1 = nn.Linear(2*num+1, 1024) 130 | self.fc2 = nn.Linear(1024, 512) 131 | self.fc3 = nn.Linear(512, self.outsize) 132 | 133 | self.fc4 = nn.Linear(self.outsize, self.outsize) 134 | 135 | self.cnn1 = nn.Conv1d(2*num+1, 256, kernel_size=3, padding=1) 136 | self.cnn2 = nn.Conv1d(256, 128, kernel_size=3, padding=1) 137 | self.cnn3 = nn.Conv1d(128, 128, kernel_size=3, padding=1) 138 | 139 | self.dp = nn.Dropout(0.3) 140 | self.relu = nn.ReLU() 141 | self.softmax = nn.Softmax() 142 | self.sigmoid = nn.Sigmoid() 143 | 144 | def forward(self, s, a): 145 | if self.images: 146 | s_im = s.view(-1, 32, 32, 3).permute(0,3,1,2) 147 | senc = self.ie(s_im) 148 | sp = senc.view(-1, self.horizon, self.num) 149 | else: 150 | sp = s.view(-1, self.horizon, self.num) 151 | sp[:,1:] = sp[:,1:] - sp[:,:-1] 152 | a = a.view(-1, self.horizon, self.num+1) 153 | e = th.cat([sp, a], 2) 154 | 155 | e = e.permute(0,2,1) 156 | c2 = e 157 | 158 | p = th.zeros((sp.size(0), self.outsize)).cuda() 159 | for i in range(self.horizon): 160 | e1 = self.relu(self.dp(self.fc1(c2[:,:,i]))) 161 | e2 = self.relu(self.dp(self.fc2(e1))) 162 | e3 = self.sigmoid(self.fc3(e2)) 163 | p = p + e3 164 | p = self.sigmoid(self.fc4(p)) 165 | 166 | return p 167 | 168 | class SupervisedModelCNN(nn.Module): 169 | def __init__(self, horizon,num=5, ms=False, images=False): 170 | super(SupervisedModelCNN, self).__init__() 171 | 172 | self.images = images 173 | self.ie = ImageEncoder(num) 174 | 175 | self.horizon = horizon 176 | self.ms = ms 177 | if self.ms: 178 | self.outsize = num**2 + num 179 | else: 180 | self.outsize = num**2 181 | 182 | self.num = num 183 | 184 | self.cnn1 = nn.Conv1d(2*num+1, 256, kernel_size=3, padding=1) 185 | self.cnn2 = nn.Conv1d(256, 128, kernel_size=3, padding=1) 186 | self.cnn3 = nn.Conv1d(128, 128, kernel_size=3, padding=1) 187 | 188 | self.fc1 = nn.Linear(self.horizon*128, 1024) 189 | self.fc2 = nn.Linear(1024, 512) 190 | self.fc3 = nn.Linear(512, self.outsize) 191 | 192 | self.dp = nn.Dropout(0.3) 193 | self.relu = nn.ReLU() 194 | self.softmax = nn.Softmax() 195 | self.sigmoid = nn.Sigmoid() 196 | 197 | def forward(self, s, a): 198 | if self.images: 199 | s_im = s.view(-1, 32, 32, 3).permute(0,3,1,2) 200 | senc = self.ie(s_im) 201 | sp = senc.view(-1, self.horizon, self.num) 202 | else: 203 | sp = s.view(-1, self.horizon, self.num) 204 | a = a.view(-1, self.horizon, self.num+1) 205 | e = th.cat([sp, a], 2) 206 | 207 | e = e.permute(0,2,1) 208 | c1 = self.relu(self.cnn1(e)) 209 | c2 = self.relu(self.cnn2(c1)) 210 | c2 = self.relu(self.cnn3(c2)) 211 | 212 | c2 = c2.view(-1, self.horizon*128) 213 | e1 = self.relu(self.dp(self.fc1(c2))) 214 | e2 = self.relu(self.dp(self.fc2(e1))) 215 | 216 | rec = self.sigmoid(self.fc3(e2)) 217 | return rec 218 | -------------------------------------------------------------------------------- /env/light_env.py: -------------------------------------------------------------------------------- 1 | from mujoco_py import load_model_from_path, MjSim, MjViewer 2 | from mujoco_py.modder import TextureModder 3 | import mujoco_py 4 | import os 5 | import numpy as np 6 | import gym 7 | from gym import error, spaces 8 | from gym.utils import seeding 9 | import copy 10 | from itertools import permutations 11 | import cv2 12 | DEFAULT_SIZE = 500 13 | 14 | class LightEnv(gym.GoalEnv): 15 | '''Light Switch Environment for Visual Causal Induction''' 16 | def __init__(self, horizon=5, num=5, cond="gt", structure="one_to_one", gc=True, filename=None, seen=10): 17 | ''' 18 | Creates Light Switch Environment 19 | 20 | Args: 21 | horizon: Length of episode 22 | num: Number of switches [5,7,9] 23 | cond: Whether use default size graph "GT" or custom size (Size). 24 | structure: Type of causal structure [one_to_one, one_to_many, many_to_one, masterswitch] 25 | gc: True/False goal conditioned or now 26 | filename: Path to log episode results 27 | seen: Number of seen causal structures 28 | ''' 29 | 30 | ## Load XML model 31 | fullpath = os.path.join(os.path.dirname(__file__), 'assets', "arena_v2_"+str(num)+".xml") 32 | if not os.path.exists(fullpath): 33 | raise IOError('File {} does not exist'.format(fullpath)) 34 | model = mujoco_py.load_model_from_path(fullpath) 35 | self.sim = mujoco_py.MjSim(model) 36 | 37 | self.filename = filename 38 | self.horizon = horizon 39 | self.cond = cond 40 | self.gc = gc 41 | self.structure = structure 42 | self.num = num 43 | self.metadata = {'render.modes': ['human', 'rgb_array']} 44 | 45 | ## Initialize GT/traj - the underlying causal structure 46 | if self.cond is None: 47 | pass 48 | elif self.cond == "gt": 49 | if self.structure == "masterswitch": 50 | a = self.num**2 + self.num 51 | else: 52 | a = self.num**2 53 | self.gt = np.zeros((a)) 54 | else: 55 | self.traj = np.zeros(self.cond) 56 | self.aj = np.zeros((self.num, self.num)) 57 | 58 | ## If goal conditioned, sample goal state 59 | if self.gc: 60 | self.goal = self._sample_goal() 61 | 62 | ## Set random seed so order of causal structures is preserved 63 | np.random.seed(1) 64 | 65 | 66 | self.state = np.zeros((self.num)) # Initial State 67 | self.eps = 0 # Num Episodes 68 | 69 | # Generate causal structure 70 | if (self.structure == "one_to_many") or (self.structure == "many_to_one"): 71 | if self.num == 9: 72 | self.all_perms = self.generate_cs_set1(self.num, True) 73 | else: 74 | self.all_perms = self.generate_cs_set1(self.num) 75 | else: 76 | self.all_perms = self.generate_cs_set2(self.num) 77 | 78 | ## Params to randomize strcut and train v test 79 | self.keep_struct = True 80 | self.train = True 81 | 82 | ## Shuffled causal structures 83 | np.random.shuffle(self.all_perms) 84 | ## Number of all structs 85 | self.pmsz = self.all_perms.shape[0] 86 | self.seen = seen 87 | 88 | obs = self._get_obs() 89 | self.action_space = spaces.Discrete(self.num+1) 90 | self.observation_space = spaces.Box(0, 1, shape=obs.shape, dtype='float32') 91 | 92 | 93 | # Env methods 94 | # ---------------------------- 95 | 96 | def step(self, action, count = True, log=False): 97 | '''Step in env. 98 | Args: 99 | action: which switch to toggle 100 | count: False if not counted toward episode (for collecting heuristic data) 101 | log: Log episode results 102 | ''' 103 | ## If "Do Nothing" Action 104 | if action == self.num: 105 | pass 106 | else: 107 | if self.structure == "masterswitch": 108 | ## Only once masterswitch is activated can others be activated 109 | if (action == self.ms) or (self.state[self.ms] == 1): 110 | change = np.zeros(self.num) 111 | change[action] = 1 112 | self.state = np.abs(self.state - change) 113 | else: 114 | change = np.zeros(self.num) 115 | change[action] = 1 116 | self.state = np.abs(self.state - change) 117 | 118 | obs = self._get_obs() 119 | 120 | done = False 121 | info = {'is_success': self._is_success(obs)} 122 | self.correct.append((info["is_success"])) 123 | reward = self.compute_reward(obs, info) 124 | 125 | if count: 126 | self.steps += 1 127 | self.eprew += reward 128 | if reward == 0: 129 | done = True 130 | if (self.steps >= self.horizon): 131 | done = True 132 | if done and log: 133 | with open(self.filename + \ 134 | "_S" + str(self.seen) + "_"+str(self.structure)+ \ 135 | "_H"+str(self.horizon)+"_N"+str(self.num)+ \ 136 | "_T"+str(self.current_cs)+".txt", "a") as f: 137 | f.write(str(self.eprew) + "\n") 138 | with open(self.filename + "_S" + str(self.seen) + \ 139 | "_"+str(self.structure)+"_H"+str(self.horizon)+ \ 140 | "_N"+str(self.num)+"_T"+str(self.current_cs)+\ 141 | "successrate.txt", "a") as f: 142 | f.write(str(int(info["is_success"])) + "\n") 143 | 144 | return obs, reward, done, info 145 | 146 | def reset(self): 147 | keep_struct = self.keep_struct 148 | train = self.train 149 | if train: 150 | self.current_cs = "train" 151 | else: 152 | self.current_cs = "test" 153 | 154 | ## Either reset causal structure or not 155 | if keep_struct: 156 | pass 157 | else: 158 | ## Select from seen causal structre or unseen causal structure 159 | if train: 160 | ind = np.random.randint(0, self.seen) 161 | else: 162 | ind = np.random.randint(self.seen, self.pmsz) 163 | perm = self.all_perms[ind] 164 | 165 | ## Set graph according to causal structure 166 | if self.structure == "one_to_one": 167 | aj = np.zeros((self.num,self.num)) 168 | for i in range(self.num): 169 | aj[i, perm[i]] = 1 170 | self.aj = aj 171 | self.gt = self.aj.flatten() 172 | elif self.structure == "one_to_many": 173 | aj = np.zeros((self.num,self.num)) 174 | for i in range(self.num): 175 | aj[i, perm[i]] = 1 176 | self.aj = aj.T 177 | self.gt = self.aj.flatten() 178 | elif self.structure == "many_to_one": 179 | aj = np.zeros((self.num,self.num)) 180 | for i in range(self.num): 181 | aj[i, perm[i]] = 1 182 | self.aj = aj 183 | self.gt = self.aj.flatten() 184 | elif self.structure == "masterswitch": 185 | aj = np.zeros((self.num,self.num)) 186 | for i in range(self.num): 187 | aj[i, perm[i]] = 1 188 | self.aj = aj 189 | self.ms = np.random.randint(self.num) 190 | m = np.zeros((self.num)) 191 | m[self.ms] = 1 192 | self.gt = self.aj.flatten() 193 | self.gt = np.concatenate([self.gt, m]) 194 | self.eprew = 0 195 | self.steps = 0 196 | self.correct = [] 197 | self.state = np.zeros((self.num)) 198 | self.eps += 1 199 | self.goal = self._sample_goal() 200 | 201 | obs = self._get_obs() 202 | return obs 203 | 204 | def compute_reward(self, obs, info=None): 205 | ## Distance to goal configuration 206 | rew =-1 * np.sqrt(((obs[:self.num] - self.goal)**2).sum()) 207 | return rew 208 | 209 | def _get_obs(self, images=False): 210 | """Returns the observation. 211 | """ 212 | ## Compute which lights are on based on underlying state and causal graph 213 | light = np.dot(self.state.T, self.aj) 214 | light = light % 2 215 | o = light 216 | 217 | ## Set corresponding lights and render image 218 | if images: 219 | self.sim.model.light_active[:] = light 220 | im = self.sim.render(width=32,height=32,camera_name="birdview") / 255.0 221 | return im 222 | 223 | ## Concatenate goal 224 | if self.gc: 225 | o = np.concatenate([o, self.goal]) 226 | 227 | ## Concatenate graph 228 | if self.cond is None: 229 | pass 230 | elif self.cond == "gt": 231 | o = np.concatenate([o, self.gt]) 232 | else: 233 | o = np.concatenate([o, self.traj]) 234 | 235 | return o 236 | 237 | def _is_success(self, obs): 238 | """Indicates whether or not the achieved goal successfully achieved the desired goal. 239 | """ 240 | return (obs[:self.num] == self.goal).all() 241 | 242 | def _sample_goal(self): 243 | """Samples a new goal and returns it. 244 | """ 245 | state = np.random.randint(0, 2, size=(self.num)) 246 | light = np.dot(state.T, self.aj) 247 | light = light % 2 248 | self.sim.model.light_active[:] = light 249 | self.goalim = self.sim.render(mode='offscreen', width=32,height=32,camera_name="birdview") / 255.0 250 | return light 251 | 252 | def generate_cs_set1(self, sz, cut=False): 253 | '''Generate Causal Structures for Many to One. For every light, 254 | every possible combination of switches which could control it. 255 | 256 | Args: 257 | sz: num switches 258 | cut: Reduce generated structures for efficiency 259 | ''' 260 | if sz == 1: 261 | lp = [] 262 | for i in range(self.num): 263 | lp.append([i]) 264 | return np.array(lp) 265 | else: 266 | gs = [] 267 | tm = self.generate_cs_set1(sz-1, cut) 268 | for t in tm: 269 | if cut and (np.random.uniform() < 0.4): 270 | continue 271 | for i in range(self.num): 272 | gs.append(np.concatenate([np.array([i]), t])) 273 | return np.array(gs) 274 | 275 | def generate_cs_set2(self, sz): 276 | '''Generate Causal Structures for One to One. All Permutations. 277 | 278 | Args: 279 | sz: num switches 280 | ''' 281 | t = np.arange(self.num) 282 | gs = [] 283 | for perm in permutations(t): 284 | gs.append(np.array(list(perm))) 285 | return np.array(gs) -------------------------------------------------------------------------------- /env/assets/arena_v2_5.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | 121 | 122 | 123 | 124 | 125 | 126 | 127 | 128 | 129 | 130 | 131 | 132 | 133 | 134 | 135 | 136 | 137 | 138 | 139 | 140 | 141 | 142 | 143 | 144 | 145 | 146 | 147 | 148 | 149 | 150 | 151 | 152 | 153 | 154 | 155 | 156 | 157 | 158 | 159 | 160 | 161 | 162 | 163 | 164 | -------------------------------------------------------------------------------- /env/assets/arena_v2_5_vis.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | 121 | 122 | 123 | 124 | 125 | 126 | 127 | 128 | 129 | 130 | 131 | 132 | 133 | 134 | 135 | 136 | 137 | 138 | 139 | 140 | 141 | 142 | 143 | 144 | 145 | 146 | 147 | 148 | 149 | 150 | 151 | 152 | 153 | 154 | 155 | 156 | 157 | 158 | 159 | 160 | 161 | 162 | 163 | 164 | -------------------------------------------------------------------------------- /env/assets/arena_v2_9.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | 121 | 122 | 123 | 124 | 125 | 126 | 127 | 128 | 129 | 130 | 131 | 132 | 133 | 134 | 135 | 136 | 137 | 138 | 139 | 140 | 141 | 142 | 143 | 144 | 145 | 146 | 147 | 148 | 149 | 150 | 151 | 152 | 153 | 154 | 155 | 156 | 157 | 158 | 159 | 160 | 161 | 162 | 163 | 164 | 165 | 166 | 167 | 168 | -------------------------------------------------------------------------------- /env/assets/arena_v2_7.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | 121 | 122 | 123 | 124 | 125 | 126 | 127 | 128 | 129 | 130 | 131 | 132 | 133 | 134 | 135 | 136 | 137 | 138 | 139 | 140 | 141 | 142 | 143 | 144 | 145 | 146 | 147 | 148 | 149 | 150 | 151 | 152 | 153 | 154 | 155 | 156 | 157 | 158 | 159 | 160 | 161 | 162 | 163 | 164 | 165 | 166 | 167 | 168 | -------------------------------------------------------------------------------- /learn_planner.py: -------------------------------------------------------------------------------- 1 | from env.light_env import LightEnv 2 | import numpy as np 3 | import argparse 4 | import torch as th 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import cv2 8 | 9 | 10 | 11 | class BCPolicy(nn.Module): 12 | """ 13 | Imitation Policy 14 | """ 15 | def __init__(self, num, structure, attention = False): 16 | super(BCPolicy, self).__init__() 17 | self.encoder_conv = nn.Sequential( 18 | # 224x224xN_CHANNELS -> 112x112x64 19 | nn.Conv2d(6, 8, kernel_size=3, stride=1, padding=1), 20 | nn.MaxPool2d(kernel_size=2, stride=2), # 56x56x64 21 | nn.ReLU(inplace=True), 22 | ) 23 | self.encoder_conv2 = nn.Sequential( 24 | nn.Conv2d(8, 16, kernel_size=3, stride=1, padding=1), 25 | nn.MaxPool2d(kernel_size=2, stride=2), # 27x27x64 26 | nn.ReLU(inplace=True), 27 | ) 28 | self.encoder_conv3 = nn.Sequential( 29 | nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1), 30 | nn.MaxPool2d(kernel_size=2, stride=2), # 6x6x64 31 | nn.ReLU(inplace=True), 32 | ) 33 | 34 | self.att = attention 35 | self.num = num 36 | if structure == "masterswitch": 37 | self.ins = self.num + 1 38 | else: 39 | self.ins = self.num 40 | 41 | self.attlayer = nn.Linear(128, num) 42 | self.structure = structure 43 | self.fc1 = nn.Linear(4 * 4 * 32, 128) 44 | 45 | if not self.att: 46 | if structure == "masterswitch": 47 | self.gfc1 = nn.Linear(num*num + num, 128) 48 | else: 49 | self.gfc1 = nn.Linear(num*num, 128) 50 | else: 51 | self.gfc1 = nn.Linear(self.num, 128) 52 | 53 | 54 | if self.structure == "masterswitch": 55 | self.fc2 = nn.Linear(256+args.num, 64) 56 | else: 57 | self.fc2 = nn.Linear(256, 64) 58 | self.fc5 = nn.Linear(64, num) 59 | 60 | self.softmax = nn.Softmax(dim=-1) 61 | self.relu = nn.ReLU() 62 | 63 | def forward(self, x, gr): 64 | 65 | x = x.permute(0, 3, 1, 2).contiguous() 66 | 67 | e1 = self.encoder_conv(x) 68 | e2 = self.encoder_conv2(e1) 69 | e3 = self.encoder_conv3(e2) 70 | e3 = e3.view(e3.size(0), -1) 71 | encoding = self.relu(self.fc1(e3)) 72 | if self.att: 73 | w = self.softmax(self.attlayer(encoding)) 74 | if self.structure == "masterswitch": 75 | ms = gr.view((-1, self.ins, self.num))[:, -1, :] 76 | gr = gr.view((-1, self.ins, self.num))[:, :-1, :] 77 | else: 78 | gr = gr.view((-1, self.ins, self.num)) 79 | gr_sel = th.bmm(gr, w.view(w.size(0), -1, 1)) 80 | gr_sel = gr_sel.squeeze(-1) 81 | g1 = self.relu(self.gfc1(gr_sel)) 82 | else: 83 | g1 = self.relu(self.gfc1(gr)) 84 | 85 | if self.structure == "masterswitch": 86 | eout = th.cat([g1, encoding, ms], 1) 87 | else: 88 | eout = th.cat([g1, encoding], 1) 89 | a = self.relu(self.fc2(eout)) 90 | a = self.fc5(a) 91 | return a 92 | 93 | 94 | class BCPolicyMemory(nn.Module): 95 | """ 96 | Imitation policy with memory 97 | """ 98 | def __init__(self, num, structure): 99 | super(BCPolicyMemory, self).__init__() 100 | self.encoder_conv = nn.Sequential( 101 | # 224x224xN_CHANNELS -> 112x112x64 102 | nn.Conv2d(6, 8, kernel_size=3, stride=1, padding=1), 103 | nn.MaxPool2d(kernel_size=2, stride=2), # 56x56x64 104 | nn.ReLU(inplace=True), 105 | ) 106 | self.encoder_conv2 = nn.Sequential( 107 | nn.Conv2d(8, 16, kernel_size=3, stride=1, padding=1), 108 | nn.MaxPool2d(kernel_size=2, stride=2), # 27x27x64 109 | nn.ReLU(inplace=True), 110 | ) 111 | self.encoder_conv3 = nn.Sequential( 112 | nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1), 113 | nn.MaxPool2d(kernel_size=2, stride=2), # 6x6x64 114 | nn.ReLU(inplace=True), 115 | ) 116 | 117 | self.fc1 = nn.Linear(4 * 4 * 32, 128) 118 | 119 | self.aenc = nn.Linear(num+1, 128) 120 | 121 | if structure == "masterswitch": 122 | self.gfc1 = nn.Linear(num*num + num, 128) 123 | else: 124 | self.gfc1 = nn.Linear(num*num, 128) 125 | 126 | self.lstm = nn.LSTMCell(256, 256) 127 | 128 | self.fc2 = nn.Linear(256, 64) 129 | self.fc5 = nn.Linear(64, num) 130 | 131 | self.softmax = nn.Softmax() 132 | self.relu = nn.ReLU() 133 | 134 | def forward(self, x, a, hidden): 135 | x = x.permute(0, 3, 1, 2).contiguous() 136 | e1 = self.encoder_conv(x) 137 | e2 = self.encoder_conv2(e1) 138 | e3 = self.encoder_conv3(e2) 139 | e3 = e3.view(e3.size(0), -1) 140 | encoding = self.relu(self.fc1(e3)) 141 | 142 | ae = self.relu(self.aenc(a)) 143 | eout = th.cat([ae, encoding], 1) 144 | if hidden is None: 145 | hidden = self.lstm(eout) 146 | else: 147 | hidden = self.lstm(eout, hidden) 148 | 149 | a = self.relu(self.fc2(hidden[0])) 150 | a = self.fc5(a) 151 | return a, hidden 152 | 153 | 154 | def induction(structure, num, horizon, l, images=False): 155 | '''Roll out heurisitc interaction policy''' 156 | ##### INDUCTION ##### 157 | ##### OPTIMAL POLICY 1 158 | if structure == "masterswitch": 159 | it = None 160 | for i in range(num): 161 | p = l._get_obs() 162 | if images: 163 | pi = l._get_obs(images=True) 164 | p = p.reshape((1, -1)) 165 | a = np.zeros((1, num+1)) 166 | a[:,i] = 1 167 | 168 | if images: 169 | mem = np.concatenate([np.expand_dims(pi.flatten(), 0), a], 1) 170 | else: 171 | mem = np.concatenate([p[:,:args.num], a], 1) 172 | 173 | if i == 0: 174 | epbuf = mem 175 | else: 176 | epbuf = np.concatenate([epbuf, mem], 0) 177 | l.step(i, count = False) 178 | p2 = l._get_obs() 179 | if (p != p2).any(): 180 | it = i 181 | break 182 | for i in range(num): 183 | if i != it: 184 | p = l._get_obs() 185 | if images: 186 | pi = l._get_obs(images=True) 187 | p = p.reshape((1, -1)) 188 | a = np.zeros((1, num+1)) 189 | a[:,i] = 1 190 | 191 | if images: 192 | mem = np.concatenate([np.expand_dims(pi.flatten(), 0), a], 1) 193 | else: 194 | mem = np.concatenate([p[:,:args.num], a], 1) 195 | 196 | epbuf = np.concatenate([epbuf, mem], 0) 197 | l.step(i, count = False) 198 | ln = epbuf.shape[0] 199 | buf = np.zeros((2 * args.horizon - 1, epbuf.shape[1])) 200 | buf[:ln] = epbuf 201 | else: 202 | for i in range(num): 203 | p = l._get_obs() 204 | if images: 205 | pi = l._get_obs(images=True) 206 | p = p.reshape((1, -1)) 207 | a = np.zeros((1, num+1)) 208 | a[:,i] = 1 209 | 210 | if images: 211 | mem = np.concatenate([np.expand_dims(pi.flatten(), 0), a], 1) 212 | else: 213 | mem = np.concatenate([p[:,:args.num], a], 1) 214 | 215 | if i == 0: 216 | epbuf = mem 217 | else: 218 | epbuf = np.concatenate([epbuf, mem], 0) 219 | l.step(i, count = False) 220 | buf = epbuf 221 | return buf 222 | 223 | 224 | def predict(buf, F, structure, num): 225 | '''Predict graph''' 226 | s = th.FloatTensor(buf[:,:-(num+1)]).float().cuda() 227 | a = th.FloatTensor(buf[:,-(1+num):]).float().cuda() 228 | predgt = th.clamp(F(s, a), 0, 1) 229 | return predgt.cpu().detach().numpy().flatten() 230 | 231 | def train_bc(memory, policy, opt): 232 | '''Train Imitation policy''' 233 | if len(memory['state']) < 50: 234 | return 235 | opt.zero_grad() 236 | choices = np.random.choice(len(memory['state']), 32).astype(np.int32).tolist() 237 | states = [memory['state'][c] for c in choices] 238 | graphs = [memory['graph'][c] for c in choices] 239 | actions = [memory['action'][c] for c in choices] 240 | 241 | states = th.FloatTensor(states).cuda() 242 | graphs = th.FloatTensor(graphs).cuda() 243 | actions = th.LongTensor(actions).cuda() 244 | 245 | pred_acts = policy(states, graphs) 246 | # loss = ((pred_acts - actions)**2).sum(1).mean() 247 | celoss = nn.CrossEntropyLoss() 248 | loss = celoss(pred_acts, actions) 249 | l = loss.cpu().detach().numpy() 250 | loss.backward() 251 | opt.step() 252 | return l 253 | 254 | def train_bclstm(trajs, policy, opt): 255 | '''Train imitation policy with memory''' 256 | if len(trajs) < 10: 257 | return 258 | celoss = nn.CrossEntropyLoss() 259 | opt.zero_grad() 260 | totalloss = 0 261 | choices = np.random.choice(len(trajs), 4).astype(np.int32).tolist() 262 | for t in choices: 263 | memory = trajs[t] 264 | hidden = None 265 | ## Feed interaction trajectory through policy with memory 266 | buf = memory['graph'][0] 267 | for w in range(buf.shape[0]): 268 | states = buf[w, :32*32*3].reshape(1, 32, 32, 3) 269 | sgg = np.zeros_like(states) 270 | states = np.concatenate([states, sgg], -1) 271 | actions = buf[w, 32*32*3:].reshape(1, -1) 272 | num_acts = actions.shape 273 | act, hidden = pol(th.FloatTensor(states).cuda(), th.FloatTensor(actions).cuda(), hidden) 274 | states = np.array(memory['state']) 275 | actions = np.array(memory['action']) 276 | preds = [] 277 | for w in range(states.shape[0]): 278 | a = np.zeros(num_acts) 279 | 280 | pred_acts, hidden = pol(th.FloatTensor(states[w:w+1]).cuda(), th.FloatTensor(a).cuda(), hidden) 281 | preds.append(pred_acts) 282 | preds = th.cat(preds, 0) 283 | loss = celoss(preds, th.LongTensor(actions).cuda()) 284 | totalloss += loss 285 | 286 | l = totalloss.cpu().detach().numpy() 287 | totalloss.backward() 288 | opt.step() 289 | return l 290 | 291 | def eval_bc(policy, l, train=True, f=None, args=None): 292 | '''Evaluate imation policy''' 293 | successes = [] 294 | l.keep_struct = False 295 | l.train = train 296 | # Eval over 100 trials 297 | for mep in range(100): 298 | obs = l.reset() 299 | imobs = np.expand_dims(l._get_obs(images=True), 0) 300 | goalim = np.expand_dims(l.goalim, 0) 301 | 302 | if f is None: 303 | graph = np.expand_dims(l.gt.flatten(), 0) 304 | else: 305 | buf = induction(args.structure,args.num, args.horizon, l, images=args.images) 306 | traj = buf.flatten() 307 | 308 | l.state = np.zeros((args.num)) 309 | l.traj = traj 310 | pred = predict(buf, f,args.structure, args.num) 311 | l.gt = pred 312 | graph = np.expand_dims(pred.flatten(), 0) 313 | 314 | for k in range(args.horizon * 2): 315 | st = np.concatenate([imobs, goalim], 3) 316 | act = policy(th.FloatTensor(st).cuda(), th.FloatTensor(graph).cuda()) 317 | action = act[0].argmax() 318 | 319 | obs, reward, done, info = l.step(action) 320 | if (mep == 50) and (train): 321 | print(action, obs[:5]) 322 | imobs = np.expand_dims(l._get_obs(images=True), 0) 323 | if done: 324 | break 325 | 326 | successes.append(l._is_success(obs)) 327 | return np.mean(successes) 328 | 329 | def eval_bclstm(policy, l, train=True, args=None): 330 | '''Evaluate imitation policy with memory''' 331 | successes = [] 332 | l.keep_struct = False 333 | l.train = train 334 | for mep in range(100): 335 | hidden = None 336 | obs = l.reset() 337 | imobs = np.expand_dims(l._get_obs(images=True), 0) 338 | goalim = np.expand_dims(l.goalim, 0) 339 | 340 | buf = induction(args.structure,args.num, args.horizon, l, images=args.images) 341 | 342 | l.state = np.zeros((args.num)) 343 | for w in range(buf.shape[0]): 344 | states = buf[w, :32*32*3].reshape(1, 32, 32, 3) 345 | sgg = np.zeros_like(states) 346 | states = np.concatenate([states, sgg], -1) 347 | actions = buf[w, 32*32*3:].reshape(1, -1) 348 | num_acts = actions.shape 349 | act, hidden = policy(th.FloatTensor(states).cuda(), th.FloatTensor(actions).cuda(), hidden) 350 | 351 | for k in range(args.horizon * 2): 352 | st = np.concatenate([imobs, goalim], 3) 353 | act, hidden = policy(th.FloatTensor(st).cuda(), th.FloatTensor(np.zeros(num_acts)).cuda(), hidden) 354 | action = act[0].argmax() 355 | 356 | obs, reward, done, info = l.step(action) 357 | if (mep == 50) and (train): 358 | print(action, obs[:5]) 359 | imobs = np.expand_dims(l._get_obs(images=True), 0) 360 | if done: 361 | break 362 | 363 | successes.append(l._is_success(obs)) 364 | return np.mean(successes) 365 | 366 | 367 | if __name__ == '__main__': 368 | parser = argparse.ArgumentParser(description='Causal Meta-RL') 369 | parser.add_argument('--fixed-goal', type=int, default=0, help='fixed goal or no') 370 | parser.add_argument('--horizon', type=int, default=10, help='Env horizon') 371 | parser.add_argument('--num', type=int, default=1, help='num lights') 372 | parser.add_argument('--structure', type=str, default="one_to_one", help='causal structure') 373 | parser.add_argument('--method', type=str, default="traj", help='Type of model') 374 | parser.add_argument('--seen', type=int, default=10, help='Num see envs') 375 | parser.add_argument('--images', type=int, default=0, help='Images or no') 376 | parser.add_argument('--data-dir', type=str, help='Model path') 377 | 378 | args = parser.parse_args() 379 | 380 | gc = 1 - args.fixed_goal 381 | fname = args.data_dir+"polattn_"+str(gc)+"_"+args.method 382 | 383 | memsize = 10000 384 | memory = {'state':[], 'graph':[], 'action':[]} 385 | if args.method == 'trajlstm': 386 | pol = BCPolicyMemory(args.num, args.structure).cuda() 387 | else: 388 | pol = BCPolicy(args.num, args.structure, True).cuda() 389 | optimizer = th.optim.Adam(pol.parameters(), lr=0.0001) 390 | 391 | ## Using ground truth graph 392 | if args.method == "gt": 393 | l = LightEnv(args.horizon*2, 394 | args.num, 395 | "gt", 396 | args.structure, 397 | gc, 398 | filename=fname, 399 | seen = args.seen) 400 | 401 | successes = [] 402 | l.keep_struct = False 403 | l.train = True 404 | ## Per episode 405 | for mep in range(100000): 406 | l.train = True 407 | obs = l.reset() 408 | 409 | curr = np.zeros((args.num)) 410 | obs = curr 411 | imobs = l._get_obs(images=True) 412 | goalim = l.goalim 413 | 414 | goal = l.goal 415 | ## Steps in episode 416 | for k in range(args.horizon*2): 417 | ## Use GT graph to plan 418 | g = np.abs(goal - obs[:args.num]) 419 | st = np.concatenate([imobs, goalim], 2) 420 | sss = 1.0*(np.dot(g, l.aj.T).T > 0.5) 421 | 422 | if args.structure == "masterswitch": 423 | sss[l.ms] = 0 424 | if sss.max() == 0: 425 | break 426 | 427 | action = np.argmax(sss) 428 | if args.structure == "masterswitch": 429 | if obs[:5].max() == 0: 430 | action = l.ms 431 | memory['state'].append(st) 432 | memory['graph'].append(l.gt.flatten()) 433 | memory['action'].append(action) 434 | 435 | ## Random noise to policy 436 | if np.random.uniform() < 0.3: 437 | action = np.random.randint(args.num) 438 | else: 439 | graph = np.expand_dims(l.gt.flatten(), 0) 440 | act = pol(th.FloatTensor(np.expand_dims(st, 0)).cuda(), th.FloatTensor(graph).cuda()) 441 | action = act[0].argmax() 442 | 443 | obs, reward, done, info = l.step(action) 444 | imobs = l._get_obs(images=True) 445 | if done: 446 | break 447 | 448 | g = np.abs(goal - obs[:args.num]) 449 | st = np.concatenate([imobs, goalim], 2) 450 | sss = 1.0*(np.dot(g, l.aj.T).T > 0.5) 451 | 452 | if args.structure == "masterswitch": 453 | if sss[l.ms]: 454 | st = np.concatenate([imobs, goalim], 2) 455 | memory['state'].append(st) 456 | memory['graph'].append(l.gt.flatten()) 457 | memory['action'].append(l.ms) 458 | obs, reward, done, info = l.step(l.ms) 459 | memory['state'] = memory['state'][-memsize:] 460 | memory['graph'] = memory['graph'][-memsize:] 461 | memory['action'] = memory['action'][-memsize:] 462 | for _ in range(1): 463 | loss = train_bc(memory, pol, optimizer) 464 | if mep % 1000 == 0: 465 | print("Episode", mep, "Loss:" , loss ) 466 | trainsc = eval_bc(pol, l, True, args=args) 467 | testsc = eval_bc(pol, l, False, args=args) 468 | with open(fname + "_S" + str(args.seen) + \ 469 | "_"+str(args.structure)+"_H"+str(args.horizon)+\ 470 | "_N"+str(args.num)+"_Ttrainsuccessrate.txt", "a") as f: 471 | f.write(str(float(trainsc)) + "\n") 472 | with open(fname + "_S" + str(args.seen) + \ 473 | "_"+str(args.structure)+"_H"+str(args.horizon)+\ 474 | "_N"+str(args.num)+"_Ttestsuccessrate.txt", "a") as f: 475 | f.write(str(float(testsc)) + "\n") 476 | 477 | print("Train Success Rate:", trainsc) 478 | print("Test Success Rate:", testsc) 479 | 480 | successes.append(l._is_success(obs)) 481 | print(np.mean(successes)) 482 | ## If using learning induction model 483 | elif (args.method == "trajF") or (args.method == "trajFi") or (args.method == "trajFia"): 484 | if args.structure == "masterswitch": 485 | st = (args.horizon*(2*args.num+1) + (args.horizon-1)*(2*args.num+1)) 486 | else: 487 | st = (args.horizon*(2*args.num+1)) 488 | tj = "gt" 489 | l = LightEnv(args.horizon*2, 490 | args.num, 491 | tj, 492 | args.structure, 493 | gc, 494 | filename=fname, 495 | seen = args.seen) 496 | 497 | if args.images: 498 | addonn = "_I1" 499 | else: 500 | addonn = "" 501 | 502 | if args.method == "trajF": 503 | FN = th.load(args.data_dir+"cnn_Redo_L2_S"+str(args.seen)+"_h"+str(args.horizon)+\ 504 | "_"+str(args.structure)+addonn).cuda() 505 | elif args.method == "trajFia": 506 | FN = th.load(args.data_dir+"iter_attn_Redo_L2_S"+str(args.seen)+"_h"+str(args.horizon)+\ 507 | "_"+str(args.structure)+addonn).cuda() 508 | else: 509 | FN = th.load(args.data_dir+"iter_Redo_L2_S"+str(args.seen)+"_h"+str(args.horizon)+\ 510 | "_"+str(args.structure)+addonn).cuda() 511 | FN = FN.eval() 512 | successes = [] 513 | l.keep_struct = False 514 | l.train = False 515 | for mep in range(100000): 516 | l.train = True 517 | obs = l.reset() 518 | goalim = l.goalim 519 | imobs = l._get_obs(images=True) 520 | 521 | ## Predict Graph 522 | buf = induction(args.structure,args.num, args.horizon, l, images=args.images) 523 | traj = buf.flatten() 524 | pred = predict(buf, FN,args.structure, args.num) 525 | l.state = np.zeros((args.num)) 526 | 527 | curr = np.zeros((args.num)) 528 | obs = curr 529 | 530 | goal = l.goal 531 | for k in range(args.horizon*2): 532 | ## Planning 533 | g = np.abs(goal - obs[:args.num]) 534 | st = np.concatenate([imobs, goalim], 2) 535 | sss = 1.0*(np.dot(g, l.aj.T).T > 0.5) 536 | 537 | if args.structure == "masterswitch": 538 | sss[l.ms] = 0 539 | if sss.max() == 0: 540 | break 541 | 542 | action = np.argmax(sss) 543 | if args.structure == "masterswitch": 544 | if obs[:5].max() == 0: 545 | action = l.ms 546 | memory['state'].append(st) 547 | memory['graph'].append(pred.flatten()) 548 | memory['action'].append(action) 549 | 550 | ## Random Noise 551 | if np.random.uniform() < 0.3: 552 | action = np.random.randint(args.num) 553 | else: 554 | pred = predict(buf, FN,args.structure, args.num) 555 | graph = np.expand_dims(pred.flatten(), 0) 556 | act = pol(th.FloatTensor(np.expand_dims(st, 0)).cuda(), th.FloatTensor(graph).cuda()) 557 | action = act[0].argmax() 558 | 559 | 560 | obs, reward, done, info = l.step(action) 561 | imobs = l._get_obs(images=True) 562 | if done: 563 | break 564 | 565 | g = np.abs(goal - obs[:args.num]) 566 | st = np.concatenate([imobs, goalim], 2) 567 | sss = 1.0*(np.dot(g, l.aj.T).T > 0.5) 568 | if args.structure == "masterswitch": 569 | if sss[l.ms]: 570 | st = np.concatenate([imobs, goalim], 2) 571 | memory['state'].append(st) 572 | memory['graph'].append(pred.flatten()) 573 | memory['action'].append(l.ms) 574 | obs, reward, done, info = l.step(l.ms) 575 | 576 | memory['state'] = memory['state'][-memsize:] 577 | memory['graph'] = memory['graph'][-memsize:] 578 | memory['action'] = memory['action'][-memsize:] 579 | for _ in range(1): 580 | loss = train_bc(memory, pol, optimizer) 581 | if mep % 1000 == 0: 582 | print("Episode", mep, "Loss:" , loss ) 583 | trainsc = eval_bc(pol, l, True, f=FN, args=args) 584 | testsc = eval_bc(pol, l, False, f=FN, args=args) 585 | with open(fname + "_S" + str(args.seen) + \ 586 | "_"+str(args.structure)+"_H"+str(args.horizon)+\ 587 | "_N"+str(args.num)+"_Ttrainsuccessrate.txt", "a") as f: 588 | f.write(str(float(trainsc)) + "\n") 589 | with open(fname + "_S" + str(args.seen) + \ 590 | "_"+str(args.structure)+"_H"+str(args.horizon)+\ 591 | "_N"+str(args.num)+"_Ttestsuccessrate.txt", "a") as f: 592 | f.write(str(float(testsc)) + "\n") 593 | 594 | print("Train Success Rate:", trainsc) 595 | print("Test Success Rate:", testsc) 596 | 597 | successes.append(l._is_success(obs)) 598 | print(np.mean(successes)) 599 | elif (args.method == "trajlstm"): 600 | if args.structure == "masterswitch": 601 | st = (args.horizon*(2*args.num+1) + (args.horizon-1)*(2*args.num+1)) 602 | else: 603 | st = (args.horizon*(2*args.num+1)) 604 | tj = "gt" 605 | l = LightEnv(args.horizon*2, 606 | args.num, 607 | tj, 608 | args.structure, 609 | gc, 610 | filename=fname, 611 | seen = args.seen) 612 | 613 | if args.images: 614 | addonn = "_I1" 615 | else: 616 | addonn = "" 617 | 618 | successes = [] 619 | l.keep_struct = False 620 | l.train = False 621 | memsize = 100 622 | trajs = [] 623 | 624 | for mep in range(100000): 625 | memory = {'state':[], 'graph':[], 'action':[]} 626 | hidden = None 627 | l.train = True 628 | obs = l.reset() 629 | goalim = l.goalim 630 | imobs = l._get_obs(images=True) 631 | 632 | ## Get interction trajectory 633 | buf = induction(args.structure,args.num, args.horizon, l, images=args.images) 634 | memory['graph'].append(buf) 635 | for w in range(buf.shape[0]): 636 | states = buf[w, :32*32*3].reshape(1, 32, 32, 3) 637 | sgg = np.zeros_like(states) 638 | states = np.concatenate([states, sgg], -1) 639 | actions = buf[w, 32*32*3:].reshape(1, -1) 640 | act, hidden = pol(th.FloatTensor(states).cuda(), th.FloatTensor(actions).cuda(), hidden) 641 | l.state = np.zeros((args.num)) 642 | 643 | curr = np.zeros((args.num)) 644 | obs = curr 645 | 646 | goal = l.goal 647 | for k in range(args.horizon*2): 648 | ## Planning 649 | g = np.abs(goal - obs[:args.num]) 650 | st = np.concatenate([imobs, goalim], 2) 651 | sss = 1.0*(np.dot(g, l.aj.T).T > 0.5) 652 | 653 | if args.structure == "masterswitch": 654 | sss[l.ms] = 0 655 | if sss.max() == 0: 656 | break 657 | 658 | action = np.argmax(sss) 659 | if args.structure == "masterswitch": 660 | if obs[:5].max() == 0: 661 | action = l.ms 662 | memory['state'].append(st) 663 | memory['action'].append(action) 664 | 665 | ## Policy Noise 666 | if np.random.uniform() < 0.3: 667 | action = np.random.randint(args.num) 668 | else: 669 | act, s_hidden = pol(th.FloatTensor(states).cuda(), th.FloatTensor(actions).cuda(), hidden) 670 | action = act[0].argmax() 671 | 672 | 673 | obs, reward, done, info = l.step(action) 674 | imobs = l._get_obs(images=True) 675 | if done: 676 | break 677 | 678 | g = np.abs(goal - obs[:args.num]) 679 | st = np.concatenate([imobs, goalim], 2) 680 | sss = 1.0*(np.dot(g, l.aj.T).T > 0.5) 681 | if args.structure == "masterswitch": 682 | if sss[l.ms]: 683 | st = np.concatenate([imobs, goalim], 2) 684 | memory['state'].append(st) 685 | memory['action'].append(l.ms) 686 | obs, reward, done, info = l.step(l.ms) 687 | 688 | 689 | 690 | if len(memory['state']) != 0: 691 | trajs.append(memory) 692 | trajs = trajs[-memsize:] 693 | for _ in range(1): 694 | loss = train_bclstm(trajs, pol, optimizer) 695 | if mep % 1000 == 0: 696 | print("Episode", mep, "Loss:" , loss ) 697 | trainsc = eval_bclstm(pol, l, True, args=args) 698 | testsc = eval_bclstm(pol, l, False, args=args) 699 | with open(fname + "_S" + str(args.seen) + \ 700 | "_"+str(args.structure)+"_H"+str(args.horizon)+\ 701 | "_N"+str(args.num)+"_Ttrainsuccessrate.txt", "a") as f: 702 | f.write(str(float(trainsc)) + "\n") 703 | with open(fname + "_S" + str(args.seen) + \ 704 | "_"+str(args.structure)+"_H"+str(args.horizon)+\ 705 | "_N"+str(args.num)+"_Ttestsuccessrate.txt", "a") as f: 706 | f.write(str(float(testsc)) + "\n") 707 | 708 | print("Train Success Rate:", trainsc) 709 | print("Test Success Rate:", testsc) 710 | 711 | successes.append(l._is_success(obs)) 712 | print(np.mean(successes)) --------------------------------------------------------------------------------