├── .gitignore
├── env
├── assets
│ ├── textures
│ │ ├── can.png
│ │ ├── bread.png
│ │ ├── cereal.png
│ │ ├── clay.png
│ │ ├── glass.png
│ │ ├── lemon.png
│ │ ├── metal.png
│ │ ├── ceramic.png
│ │ ├── dark-wood.png
│ │ └── light-wood.png
│ ├── objects
│ │ ├── meshes
│ │ │ ├── bread.stl
│ │ │ ├── can.stl
│ │ │ ├── lemon.stl
│ │ │ ├── milk.stl
│ │ │ ├── bottle.stl
│ │ │ ├── cereal.stl
│ │ │ └── handles.stl
│ │ ├── can-visual.xml
│ │ ├── milk-visual.xml
│ │ ├── cereal-visual.xml
│ │ ├── bread-visual.xml
│ │ ├── bottle.xml
│ │ ├── can.xml
│ │ ├── lemon.xml
│ │ ├── milk.xml
│ │ ├── plate-with-hole.xml
│ │ ├── bread.xml
│ │ ├── cereal.xml
│ │ ├── square-nut.xml
│ │ ├── pole.xml
│ │ ├── spinning_pole.xml
│ │ └── round-nut.xml
│ ├── arena_v2_5.xml
│ ├── arena_v2_5_vis.xml
│ ├── arena_v2_9.xml
│ └── arena_v2_7.xml
└── light_env.py
├── README.md
├── scripts
├── planner.sh
├── trainF.sh
├── evalF.sh
└── gendata.sh
├── trainF.py
├── collectdata.py
├── evalF.py
├── F_models.py
└── learn_planner.py
/.gitignore:
--------------------------------------------------------------------------------
1 | ppo_lighting_tb/
2 | .ipynb_checkpoints/
3 | *.pyc
4 | data/
5 | figs*
6 | exp*
7 |
--------------------------------------------------------------------------------
/env/assets/textures/can.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/StanfordVL/causal_induction/HEAD/env/assets/textures/can.png
--------------------------------------------------------------------------------
/env/assets/textures/bread.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/StanfordVL/causal_induction/HEAD/env/assets/textures/bread.png
--------------------------------------------------------------------------------
/env/assets/textures/cereal.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/StanfordVL/causal_induction/HEAD/env/assets/textures/cereal.png
--------------------------------------------------------------------------------
/env/assets/textures/clay.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/StanfordVL/causal_induction/HEAD/env/assets/textures/clay.png
--------------------------------------------------------------------------------
/env/assets/textures/glass.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/StanfordVL/causal_induction/HEAD/env/assets/textures/glass.png
--------------------------------------------------------------------------------
/env/assets/textures/lemon.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/StanfordVL/causal_induction/HEAD/env/assets/textures/lemon.png
--------------------------------------------------------------------------------
/env/assets/textures/metal.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/StanfordVL/causal_induction/HEAD/env/assets/textures/metal.png
--------------------------------------------------------------------------------
/env/assets/textures/ceramic.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/StanfordVL/causal_induction/HEAD/env/assets/textures/ceramic.png
--------------------------------------------------------------------------------
/env/assets/objects/meshes/bread.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/StanfordVL/causal_induction/HEAD/env/assets/objects/meshes/bread.stl
--------------------------------------------------------------------------------
/env/assets/objects/meshes/can.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/StanfordVL/causal_induction/HEAD/env/assets/objects/meshes/can.stl
--------------------------------------------------------------------------------
/env/assets/objects/meshes/lemon.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/StanfordVL/causal_induction/HEAD/env/assets/objects/meshes/lemon.stl
--------------------------------------------------------------------------------
/env/assets/objects/meshes/milk.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/StanfordVL/causal_induction/HEAD/env/assets/objects/meshes/milk.stl
--------------------------------------------------------------------------------
/env/assets/textures/dark-wood.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/StanfordVL/causal_induction/HEAD/env/assets/textures/dark-wood.png
--------------------------------------------------------------------------------
/env/assets/textures/light-wood.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/StanfordVL/causal_induction/HEAD/env/assets/textures/light-wood.png
--------------------------------------------------------------------------------
/env/assets/objects/meshes/bottle.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/StanfordVL/causal_induction/HEAD/env/assets/objects/meshes/bottle.stl
--------------------------------------------------------------------------------
/env/assets/objects/meshes/cereal.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/StanfordVL/causal_induction/HEAD/env/assets/objects/meshes/cereal.stl
--------------------------------------------------------------------------------
/env/assets/objects/meshes/handles.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/StanfordVL/causal_induction/HEAD/env/assets/objects/meshes/handles.stl
--------------------------------------------------------------------------------
/env/assets/objects/can-visual.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
--------------------------------------------------------------------------------
/env/assets/objects/milk-visual.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
--------------------------------------------------------------------------------
/env/assets/objects/cereal-visual.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
--------------------------------------------------------------------------------
/env/assets/objects/bread-visual.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
--------------------------------------------------------------------------------
/env/assets/objects/bottle.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
--------------------------------------------------------------------------------
/env/assets/objects/can.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
--------------------------------------------------------------------------------
/env/assets/objects/lemon.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
--------------------------------------------------------------------------------
/env/assets/objects/milk.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
--------------------------------------------------------------------------------
/env/assets/objects/plate-with-hole.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
--------------------------------------------------------------------------------
/env/assets/objects/bread.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
--------------------------------------------------------------------------------
/env/assets/objects/cereal.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Code and Environment for [Causal Induction From Visual Observations for Goal-Directed Tasks](https://arxiv.org/pdf/1910.01751.pdf).
2 |
3 | ## Environment
4 |
5 | Consists of the light switch environment for studying visual causal induction, where N switches control N lights, under various causal structures. Includes common cause, common effect, and causal chain relationships. Environment code resides under `env/light_env.py`.
6 |
7 | ## Induction Models
8 |
9 | The different induction models used are located under `F_models.py`, including our proposed iterative attention network, as well as baselines, which either do not use attention or use temporal convolutions.
10 |
11 | ## Reproducing Experiments
12 |
13 | Step 1: Generate Data
14 |
15 | `python3 collectdata.py --horizon 7 --num 7 --fixed-goal 0 --structure one_to_one --seen 10 --images 1 --data-dir output/`
16 |
17 | Step 2: Train Induction Model
18 |
19 | `python3 trainF.py --horizon 7 --num 7 --fixed-goal 0 --structure one_to_one --type iter --images 1 --seen 10 --data-dir output/`
20 |
21 | Step 3: Eval Induction Model
22 |
23 | `python3 evalF.py --horizon 7 --num 7 --fixed-goal 0 --structure one_to_one --method trajFi --images 1 --seen 10 --data-dir output/`
24 |
25 | Step 4: Train Policy via Imitation
26 |
27 | `python3 learn_planner.py --horizon 7 --num 7 --fixed-goal 0 --structure one_to_one --method trajFi --seen 10 --images 1 --data-dir output/`
28 |
29 |
30 |
--------------------------------------------------------------------------------
/scripts/planner.sh:
--------------------------------------------------------------------------------
# Runs learn_planner.py once per induction method (trajF, trajFi, trajFia)
# for a single causal structure and a single number of seen environments.
# ST is the value passed to --structure, SN the value passed to --seen.
ST=masterswitch
SN=100

# The ground-truth ("gt") variant is currently disabled.
# python3 learn_planner.py --horizon 5 --num 5 --fixed-goal 0 --structure $ST --method gt --seen $SN --images 1
python3 learn_planner.py --horizon 5 --num 5 --fixed-goal 0 --structure $ST --method trajF --seen $SN --images 1
python3 learn_planner.py --horizon 5 --num 5 --fixed-goal 0 --structure $ST --method trajFi --seen $SN --images 1
python3 learn_planner.py --horizon 5 --num 5 --fixed-goal 0 --structure $ST --method trajFia --seen $SN --images 1

# Disabled sweeps with a fixed --seen of 50.
# python3 learn_planner.py --horizon 5 --num 5 --fixed-goal 0 --structure $ST --method gt --seen 50 --images 1
# python3 learn_planner.py --horizon 5 --num 5 --fixed-goal 0 --structure $ST --method trajF --seen 50 --images 1
# python3 learn_planner.py --horizon 5 --num 5 --fixed-goal 0 --structure $ST --method trajFi --seen 50 --images 1
# python3 learn_planner.py --horizon 5 --num 5 --fixed-goal 0 --structure $ST --method trajFia --seen 50 --images 1

# Disabled sweeps with a fixed --seen of 100.
# python3 learn_planner.py --horizon 5 --num 5 --fixed-goal 0 --structure $ST --method gt --seen 100 --images 1
# python3 learn_planner.py --horizon 5 --num 5 --fixed-goal 0 --structure $ST --method trajF --seen 100 --images 1
# python3 learn_planner.py --horizon 5 --num 5 --fixed-goal 0 --structure $ST --method trajFi --seen 100 --images 1
# python3 learn_planner.py --horizon 5 --num 5 --fixed-goal 0 --structure $ST --method trajFia --seen 100 --images 1
18 |
19 |
20 |
--------------------------------------------------------------------------------
/scripts/trainF.sh:
--------------------------------------------------------------------------------
# Trains one induction model type (MT, passed to --type) on image data for
# two causal structures (ST, passed to --structure), each with 10 and 50
# seen environments.
# NOTE(review): no --data-dir is passed here; confirm trainF.py copes with
# that before relying on this script.
MT=iter_attn
ST=masterswitch

# masterswitch: 9 switches, horizon 9.
python3 trainF.py --horizon 9 --num 9 --fixed-goal 0 --structure $ST --type $MT --images 1 --seen 10
python3 trainF.py --horizon 9 --num 9 --fixed-goal 0 --structure $ST --type $MT --images 1 --seen 50

ST=one_to_one

# one_to_one: 5 switches, horizon 5.
python3 trainF.py --horizon 5 --num 5 --fixed-goal 0 --structure $ST --type $MT --images 1 --seen 10
python3 trainF.py --horizon 5 --num 5 --fixed-goal 0 --structure $ST --type $MT --images 1 --seen 50



# Disabled sweeps over larger horizons / numbers of seen environments.
# python3 trainF.py --horizon 5 --num 5 --fixed-goal 0 --structure $ST --type $MT --images 1 --seen 100
# python3 trainF.py --horizon 6 --num 6 --fixed-goal 0 --structure $ST --type $MT --images 1 --seen 10
# python3 trainF.py --horizon 6 --num 6 --fixed-goal 0 --structure $ST --type $MT --images 1 --seen 50
# python3 trainF.py --horizon 6 --num 6 --fixed-goal 0 --structure $ST --type $MT --images 1 --seen 100
# python3 trainF.py --horizon 6 --num 6 --fixed-goal 0 --structure $ST --type $MT --images 1 --seen 500
# python3 trainF.py --horizon 7 --num 7 --fixed-goal 0 --structure $ST --type $MT --images 1 --seen 10
# python3 trainF.py --horizon 7 --num 7 --fixed-goal 0 --structure $ST --type $MT --images 1 --seen 50
# python3 trainF.py --horizon 7 --num 7 --fixed-goal 0 --structure $ST --type $MT --images 1 --seen 100
# python3 trainF.py --horizon 7 --num 7 --fixed-goal 0 --structure $ST --type $MT --images 1 --seen 500
23 |
--------------------------------------------------------------------------------
/env/assets/objects/square-nut.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
--------------------------------------------------------------------------------
/env/assets/objects/pole.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
--------------------------------------------------------------------------------
/env/assets/objects/spinning_pole.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 | 0
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
--------------------------------------------------------------------------------
/env/assets/objects/round-nut.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
--------------------------------------------------------------------------------
/scripts/evalF.sh:
--------------------------------------------------------------------------------
# Evaluates the three induction methods (trajF, trajFi, trajFia) with
# evalF.py. Only the masterswitch block below is active; every other sweep
# is commented out.
# NOTE: this first ST assignment is overwritten by ST=masterswitch before
# any active command reads it.
ST=one_to_one
SN=500
# ST=one_to_one

# python3 evalF.py --horizon 7 --num 7 --fixed-goal 0 --structure $ST --method trajF --images 1 --seen $SN
# python3 evalF.py --horizon 7 --num 7 --fixed-goal 0 --structure $ST --method trajFi --images 1 --seen $SN
# python3 evalF.py --horizon 7 --num 7 --fixed-goal 0 --structure $ST --method trajFia --images 1 --seen $SN

# ST=one_to_many

# python3 evalF.py --horizon 7 --num 7 --fixed-goal 0 --structure $ST --method trajF --images 1 --seen $SN
# python3 evalF.py --horizon 7 --num 7 --fixed-goal 0 --structure $ST --method trajFi --images 1 --seen $SN
# python3 evalF.py --horizon 7 --num 7 --fixed-goal 0 --structure $ST --method trajFia --images 1 --seen $SN

# ST=many_to_one

# python3 evalF.py --horizon 7 --num 7 --fixed-goal 0 --structure $ST --method trajF --images 1 --seen $SN
# python3 evalF.py --horizon 7 --num 7 --fixed-goal 0 --structure $ST --method trajFi --images 1 --seen $SN
# python3 evalF.py --horizon 7 --num 7 --fixed-goal 0 --structure $ST --method trajFia --images 1 --seen $SN

# Active configuration: masterswitch, 7 switches, horizon 7, SN seen environments.
ST=masterswitch

python3 evalF.py --horizon 7 --num 7 --fixed-goal 0 --structure $ST --method trajF --images 1 --seen $SN
python3 evalF.py --horizon 7 --num 7 --fixed-goal 0 --structure $ST --method trajFi --images 1 --seen $SN
python3 evalF.py --horizon 7 --num 7 --fixed-goal 0 --structure $ST --method trajFia --images 1 --seen $SN


# The disabled commands below read $MT, which is only assigned in
# commented-out lines; re-enable the matching MT= line together with them.
# python3 evalF.py --horizon 5 --num 5 --fixed-goal 0 --structure $ST --method $MT --images 1 --seen 500

# MT=trajFi

# python3 evalF.py --horizon 5 --num 5 --fixed-goal 0 --structure $ST --method $MT --images 1 --seen 10
# python3 evalF.py --horizon 5 --num 5 --fixed-goal 0 --structure $ST --method $MT --images 1 --seen 50
# python3 evalF.py --horizon 5 --num 5 --fixed-goal 0 --structure $ST --method $MT --images 1 --seen 100
# # python3 evalF.py --horizon 5 --num 5 --fixed-goal 0 --structure $ST --method $MT --images 1 --seen 500

# MT=trajF

# python3 evalF.py --horizon 5 --num 5 --fixed-goal 0 --structure $ST --method $MT --images 1 --seen 10
# python3 evalF.py --horizon 5 --num 5 --fixed-goal 0 --structure $ST --method $MT --images 1 --seen 50
# python3 evalF.py --horizon 5 --num 5 --fixed-goal 0 --structure $ST --method $MT --images 1 --seen 100
# python3 evalF.py --horizon 5 --num 5 --fixed-goal 0 --structure $ST --method $MT --images 1 --seen 500


# MT=trajFia
# ST=one_to_one

# python3 evalF.py --horizon 5 --num 5 --fixed-goal 0 --structure $ST --method $MT --images 1 --seen 10
# python3 evalF.py --horizon 5 --num 5 --fixed-goal 0 --structure $ST --method $MT --images 1 --seen 50
# python3 evalF.py --horizon 5 --num 5 --fixed-goal 0 --structure $ST --method $MT --images 1 --seen 100

# python3 evalF.py --horizon 6 --num 6 --fixed-goal 0 --structure $ST --method $MT --images 1 --seen 10
# python3 evalF.py --horizon 6 --num 6 --fixed-goal 0 --structure $ST --method $MT --images 1 --seen 50
# python3 evalF.py --horizon 6 --num 6 --fixed-goal 0 --structure $ST --method $MT --images 1 --seen 100
# python3 evalF.py --horizon 6 --num 6 --fixed-goal 0 --structure $ST --method $MT --images 1 --seen 500
# python3 evalF.py --horizon 7 --num 7 --fixed-goal 0 --structure $ST --method $MT --images 1 --seen 10
# python3 evalF.py --horizon 7 --num 7 --fixed-goal 0 --structure $ST --method $MT --images 1 --seen 50
# python3 evalF.py --horizon 7 --num 7 --fixed-goal 0 --structure $ST --method $MT --images 1 --seen 100
# python3 evalF.py --horizon 7 --num 7 --fixed-goal 0 --structure $ST --method $MT --images 1 --seen 500
60 |
--------------------------------------------------------------------------------
/trainF.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import torch as th
4 | import argparse
5 |
6 | from F_models import SupervisedModelCNN, IterativeModel, IterativeModelAttention
7 |
8 |
def train_supervised(F, buf, gtbuf, num, steps=1, bs=32, images=False,
                     holdout=5000):
    """Fit the induction model ``F`` to predict ground-truth graphs.

    The last ``holdout`` entries of ``buf``/``gtbuf`` are never trained on;
    they are only used to report a held-out loss on logging steps.

    Args:
        F: ``torch.nn.Module`` called as ``F(states, actions)``. Batches are
            moved to the device of its parameters.
        buf: array-like indexed as ``[sample, step, feature]``. The first
            ``split`` features of every step are the state, the remainder
            the action encoding.
        gtbuf: array-like with one flat target row per sample.
        num: number of state features per step when ``images`` is false;
            also the divisor applied to the printed losses.
        steps: number of optimizer steps.
        bs: mini-batch size.
        images: when truthy the state part of each step is a flattened
            32x32x3 image instead of ``num`` scalars.
        holdout: number of trailing samples reserved for the held-out
            loss. The default reproduces the previous fixed 35000/5000
            split of a 40000-sample buffer.

    Raises:
        ValueError: if ``holdout`` is not strictly between 0 and the number
            of samples.
    """
    buf = th.as_tensor(buf, dtype=th.float32)
    gtbuf = th.as_tensor(gtbuf, dtype=th.float32)

    total = buf.size(0)
    if not 0 < holdout < total:
        raise ValueError(
            "holdout must be in (0, %d), got %d" % (total, holdout))
    train_size = total - holdout

    # Follow the model instead of assuming a GPU is present.
    device = next(F.parameters()).device
    optimizer = th.optim.Adam(F.parameters(), lr=0.0001)
    split = 32 * 32 * 3 if images else num

    def _flatten(batch):
        # Split every step into state/action features, then flatten the
        # time axis so each sample becomes a single row.
        rows = batch.size(0)
        states = batch[:, :, :split].contiguous().view(rows, -1).to(device)
        actions = batch[:, :, split:].contiguous().view(rows, -1).to(device)
        return states, actions

    for step in range(steps):
        optimizer.zero_grad()

        idx = th.randperm(train_size)[:bs]
        states, actions = _flatten(buf[idx])
        groundtruth = gtbuf[idx].to(device)
        pred = F(states, actions)

        loss = ((pred - groundtruth) ** 2).sum(1).mean()
        loss.backward()

        if step % 1000 == 0:
            # The held-out loss is only reported, never optimised, so it is
            # computed on logging steps only and without building a graph.
            testidx = th.randperm(holdout)[:bs] + train_size
            with th.no_grad():
                teststates, testactions = _flatten(buf[testidx])
                testgroundtruth = gtbuf[testidx].to(device)
                testpred = F(teststates, testactions)
                testloss = ((testpred - testgroundtruth) ** 2).sum(1).mean()
            print((loss / num).cpu().detach().numpy())
            print((testloss / num).cpu().detach().numpy())
            print(pred[0], groundtruth[0])
            print(step)
            print("_" * 50)
        optimizer.step()
51 |
if __name__ == '__main__':
    # Train one induction model on a buffer written by collectdata.py and
    # save the fitted module next to the data.
    parser = argparse.ArgumentParser(description='Causal Meta-RL')
    parser.add_argument('--fixed-goal', type=int, default=0, help='fixed goal or no')
    parser.add_argument('--horizon', type=int, default=10, help='Env horizon')
    parser.add_argument('--num', type=int, default=1, help='Num lights')
    parser.add_argument('--structure', type=str, default="one_to_one", help='Causal Structure')
    parser.add_argument('--type', type=str, default="cnn", help='Model Type')
    parser.add_argument('--seen', type=int, default=10, help='Num seen')
    parser.add_argument('--images', type=int, default=0, help='Images or no')
    # Default to the working directory: without a default, omitting the flag
    # made every path expression below fail with ``None + str``.
    parser.add_argument('--data-dir', type=str, default='', help='Data Dir')

    args = parser.parse_args()

    model_classes = {
        "cnn": SupervisedModelCNN,
        "iter": IterativeModel,
        "iter_attn": IterativeModelAttention,
    }
    if args.type not in model_classes:
        raise NotImplementedError

    # masterswitch buffers hold 2 * horizon - 1 steps per episode.
    msv = args.structure == "masterswitch"
    seq_len = 2 * args.horizon - 1 if msv else args.horizon
    F = model_classes[args.type](seq_len, args.num, ms=msv, images=args.images).cuda()

    # collectdata.py always appends "_I<images>" to the buffer names, also
    # for --images 0. Fall back to the suffix-less name only when such a
    # legacy file is actually present.
    stem = "_S" + str(args.seen) + "_" + str(args.structure) + "_" + str(args.horizon)
    addonn = "_I" + str(args.images)
    if not args.images:
        current = os.path.join(args.data_dir, "buf40K" + stem + addonn + ".npy")
        legacy = os.path.join(args.data_dir, "buf40K" + stem + ".npy")
        if not os.path.exists(current) and os.path.exists(legacy):
            addonn = ""

    a = np.load(os.path.join(args.data_dir, "buf40K" + stem + addonn + ".npy"))
    a2 = np.load(os.path.join(args.data_dir, "gtbuf40K" + stem + addonn + ".npy"))

    print(a.shape, a2.shape)
    train_supervised(F, a, a2, args.num, steps=2000, bs=512, images=args.images)
    th.save(F, os.path.join(
        args.data_dir,
        str(args.type) + "_Redo_L2_S" + str(args.seen) + "_h" + str(args.horizon)
        + "_" + str(args.structure) + "_I" + str(args.images)))
--------------------------------------------------------------------------------
/collectdata.py:
--------------------------------------------------------------------------------
1 | from env.light_env import LightEnv
2 | import numpy as np
3 | from stable_baselines.common.vec_env import DummyVecEnv, SubprocVecEnv
4 | import argparse
5 | import torch as th
6 | import cv2
7 | import os
8 |
def _transition_row(env, index, num, images):
    """Return ``(obs, row)`` for toggling switch ``index`` in ``env``.

    ``obs`` is the current non-image observation reshaped to one row and
    ``row`` is the buffer entry: the observation (flattened image when
    ``images`` is truthy, else the first ``num`` entries of ``obs``)
    followed by a one-hot action of width ``num + 1``. The environment is
    not stepped here.
    """
    obs = env._get_obs()
    if images:
        frame = env._get_obs(images=True)
    obs = obs.reshape((1, -1))
    action = np.zeros((1, num + 1))
    action[:, index] = 1
    if images:
        row = np.concatenate([np.expand_dims(frame.flatten(), 0), action], 1)
    else:
        row = np.concatenate([obs[:, :num], action], 1)
    return obs, row


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Causal Meta-RL')
    parser.add_argument('--fixed-goal', type=int, default=0, help='fixed goal or no')
    parser.add_argument('--horizon', type=int, default=5, help='Env horizon')
    parser.add_argument('--num', type=int, default=5, help='Num Switches')
    parser.add_argument('--structure', type=str, default="one_to_one", help='Graph Structure')
    parser.add_argument('--seen', type=int, default=10, help='Number of seen environments')
    parser.add_argument('--images', type=int, default=0, help='Use Images')
    # Default to the working directory; without a default, omitting the
    # flag made the final save fail with ``None + str``.
    parser.add_argument('--data-dir', type=str, default='', help='Directory to Store Data')
    args = parser.parse_args()

    # Create the output directory up front so a long collection run cannot
    # fail at the very end because it is missing.
    if args.data_dir:
        os.makedirs(args.data_dir, exist_ok=True)

    ## Init Buffer
    gc = 1 - args.fixed_goal
    buffer = []
    gtbuffer = []
    num_episodes = 40000
    masterswitch = args.structure == "masterswitch"

    ## Set Horizon Based On Task
    if masterswitch:
        st = (args.horizon * (2 * args.num + 1) + (args.horizon - 1) * (2 * args.num + 1))
    else:
        st = (args.horizon * (2 * args.num + 1))

    ## Init Env
    l = LightEnv(args.horizon,
                 args.num,
                 st,
                 args.structure,
                 gc,
                 filename=str(gc) + "_traj",
                 seen=args.seen)
    env = DummyVecEnv([lambda: l])

    for q in range(num_episodes):
        ## Reset Causal Structure
        l.keep_struct = False
        obs = env.reset()
        l.keep_struct = True

        ##### INDUCTION #####
        rows = []
        if masterswitch:
            # Phase 1: toggle switches in order until the observation
            # changes; remember which switch caused the change.
            it = None
            for i in range(args.num):
                p, mem = _transition_row(l, i, args.num, args.images)
                rows.append(mem)
                l.step(i, count=False)
                if (p != l._get_obs()).any():
                    it = i
                    break
            # Phase 2: toggle every switch except the one found above.
            for i in range(args.num):
                if i == it:
                    continue
                _, mem = _transition_row(l, i, args.num, args.images)
                rows.append(mem)
                l.step(i, count=False)
            epbuf = np.concatenate(rows, 0)
            # Zero-pad so every episode has 2 * horizon - 1 rows.
            buf = np.zeros((2 * args.horizon - 1, epbuf.shape[1]))
            buf[:epbuf.shape[0]] = epbuf
        else:
            # Toggle every switch once, in order.
            for i in range(args.num):
                _, mem = _transition_row(l, i, args.num, args.images)
                rows.append(mem)
                l.step(i, count=False)
            buf = np.concatenate(rows, 0)

        buffer.append(buf)
        gtbuffer.append(l.gt)
        if q % 10000 == 0:
            print(q)

    buffer = np.stack(buffer, 0)
    gtbuffer = np.stack(gtbuffer, 0)
    print(buffer.shape)
    print(gtbuffer.shape)

    # np.save appends ".npy" to these names.
    tag = ("_S" + str(args.seen) + "_" + str(args.structure) + "_" + str(args.horizon)
           + "_I" + str(args.images))
    np.save(os.path.join(args.data_dir, "buf40K" + tag), buffer)
    np.save(os.path.join(args.data_dir, "gtbuf40K" + tag), gtbuffer)
129 |
130 |
131 |
132 |
--------------------------------------------------------------------------------
/evalF.py:
--------------------------------------------------------------------------------
1 | from env.light_env import LightEnv
2 | import numpy as np
3 | from stable_baselines.common.vec_env import DummyVecEnv, SubprocVecEnv
4 | import argparse
5 | import torch as th
6 | import cv2
7 |
def _record_step(l, index, num, images):
    """Return ``(obs, row)`` for toggling switch ``index`` in ``l``.

    ``obs`` is the current non-image observation reshaped to one row and
    ``row`` is the buffer entry: the observation (flattened image when
    ``images`` is truthy, else the first ``num`` entries of ``obs``)
    followed by a one-hot action of width ``num + 1``. The environment is
    not stepped here.
    """
    obs = l._get_obs()
    if images:
        frame = l._get_obs(images=True)
    obs = obs.reshape((1, -1))
    action = np.zeros((1, num + 1))
    action[:, index] = 1
    if images:
        row = np.concatenate([np.expand_dims(frame.flatten(), 0), action], 1)
    else:
        row = np.concatenate([obs[:, :num], action], 1)
    return obs, row


def induction(structure, num, horizon, l, images=False):
    '''Heuristic policy to collect interaction data.

    Toggles the switches of environment ``l`` in a fixed order and records
    one ``[observation, one-hot action]`` row per toggle.

    Args:
        structure: causal structure name; ``"masterswitch"`` selects the
            two-phase probing policy, anything else a single sweep.
        num: number of switches to toggle.
        horizon: environment horizon; the masterswitch buffer is
            zero-padded to ``2 * horizon - 1`` rows.
        l: environment providing ``_get_obs(images=...)`` and
            ``step(index, count=False)``.
        images: record flattened image observations instead of the first
            ``num`` entries of the state observation.

    Returns:
        2-D ``numpy`` array with one row per recorded toggle (plus zero
        padding rows for masterswitch).
    '''
    # Only the ``num`` / ``horizon`` parameters are used here; the previous
    # version read the script-level ``args`` and so failed on import.
    rows = []
    if structure == "masterswitch":
        # Phase 1: toggle switches in order until the observation changes;
        # remember which switch caused the change.
        it = None
        for i in range(num):
            p, mem = _record_step(l, i, num, images)
            rows.append(mem)
            l.step(i, count=False)
            if (p != l._get_obs()).any():
                it = i
                break
        # Phase 2: toggle every switch except the one found above.
        for i in range(num):
            if i == it:
                continue
            _, mem = _record_step(l, i, num, images)
            rows.append(mem)
            l.step(i, count=False)
        epbuf = np.concatenate(rows, 0)
        buf = np.zeros((2 * horizon - 1, epbuf.shape[1]))
        buf[:epbuf.shape[0]] = epbuf
    else:
        for i in range(num):
            _, mem = _record_step(l, i, num, images)
            rows.append(mem)
            l.step(i, count=False)
        buf = np.concatenate(rows, 0)
    return buf
79 |
def predict(buf, F, structure, num):
    '''Predict the causal graph from one interaction buffer.

    Args:
        buf: 2-D array whose last ``num + 1`` columns are the one-hot
            actions and whose remaining columns are the observations.
        F: callable invoked as ``F(states, actions)`` on float tensors.
        structure: unused; kept for call compatibility.
        num: number of switches.

    Returns:
        Flat ``numpy`` array with the model output.
    '''
    # Follow the model's own device; fall back to CUDA when available for
    # callables that expose no parameters.
    try:
        device = next(F.parameters()).device
    except (AttributeError, StopIteration):
        device = th.device("cuda" if th.cuda.is_available() else "cpu")

    width = num + 1
    s = th.tensor(buf[:, :-width], dtype=th.float32).to(device)
    a = th.tensor(buf[:, -width:], dtype=th.float32).to(device)
    # Inference only: no autograd graph is needed.
    with th.no_grad():
        predgt = F(s, a)
    return predgt.cpu().detach().numpy().flatten()
86 |
87 |
def f1score(pred, gt):
    '''Compute the F1 score of thresholded predictions against ``gt``.

    Entries of ``pred`` above 0.5 count as predicted edges. Returns 0 when
    there is no true positive, which also covers the cases where nothing is
    predicted or ``gt`` has no positive entry (previously a 0/0 that
    produced NaN).
    '''
    p = 1 * (np.asarray(pred) > 0.5)
    gt = np.asarray(gt)

    hits = np.sum(gt * p)
    if hits == 0:
        # Precision and recall are both zero or undefined.
        return 0

    prec = hits / np.sum(p)
    rec = hits / np.sum(gt)
    return 2 * (prec * rec) / (prec + rec)
100 |
101 |
102 |
103 |
if __name__ == '__main__':
    # Evaluate a trained induction model: mean F1 of the predicted graph
    # over 100 episodes, with ``l.train`` left as-is and with it set False.
    import os

    parser = argparse.ArgumentParser(description='Causal Meta-RL')
    parser.add_argument('--fixed-goal', type=int, default=0, help='fixed goal or no')
    parser.add_argument('--horizon', type=int, default=10, help='Env horizon')
    parser.add_argument('--num', type=int, default=1, help='num lights')
    parser.add_argument('--structure', type=str, default="one_to_one", help='causal structure')
    parser.add_argument('--method', type=str, default="traj", help='method')
    parser.add_argument('--seen', type=int, default=10, help='num seen')
    parser.add_argument('--images', type=int, default=0, help='images or no')
    # Default to the working directory; without a default, omitting the
    # flag made the checkpoint path fail with ``None + str``.
    parser.add_argument('--data-dir', type=str, default='', help='data dir')
    args = parser.parse_args()

    gc = 1 - args.fixed_goal

    tj = "gt"
    l = LightEnv(args.horizon,
                 args.num,
                 tj,
                 args.structure,
                 gc,
                 filename="exp/" + str(gc) + "_" + args.method,
                 seen=args.seen)
    env = DummyVecEnv([lambda: l])

    # Map the evaluation method to the model type prefix used by trainF.py;
    # any other method falls back to the plain iterative model.
    prefix = {"trajF": "cnn", "trajFia": "iter_attn"}.get(args.method, "iter")
    stem = (prefix + "_Redo_L2_S" + str(args.seen) + "_h" + str(args.horizon)
            + "_" + str(args.structure))

    # trainF.py always appends "_I<images>" to the checkpoint name, also for
    # --images 0. Fall back to the suffix-less name only when such a legacy
    # file is actually present.
    model_path = os.path.join(args.data_dir, stem + "_I" + str(args.images))
    if not args.images and not os.path.exists(model_path):
        legacy_path = os.path.join(args.data_dir, stem)
        if os.path.exists(legacy_path):
            model_path = legacy_path

    # NOTE: th.load unpickles a whole module; only load checkpoints you trust.
    F = th.load(model_path).cuda()
    F = F.eval()

    trloss = []  # F1 scores with l.train left unchanged
    tsloss = []  # F1 scores with l.train = False
    for mep in range(100):
        l.keep_struct = False
        obs = env.reset()
        l.keep_struct = True

        buf = induction(args.structure, args.num, args.horizon, l, images=args.images)
        pred = predict(buf, F, args.structure, args.num)
        trloss.append(f1score(pred, l.gt))

        #### TEST ON UNSEEN CS
        l.keep_struct = False
        l.train = False
        obs = env.reset()
        buf = induction(args.structure, args.num, args.horizon, l, images=args.images)
        pred = predict(buf, F, args.structure, args.num)
        tsloss.append(f1score(pred, l.gt))

        l.keep_struct = True
        l.train = True

    print(np.mean(trloss), np.mean(tsloss))
178 |
179 |
180 |
--------------------------------------------------------------------------------
/scripts/gendata.sh:
--------------------------------------------------------------------------------
1 | # python3 collectdata.py --horizon 6 --num 6 --fixed-goal 0 --structure one_to_one --method traj --seen 100 --images 1
2 | # python3 collectdata.py --horizon 6 --num 6 --fixed-goal 0 --structure many_to_one --method traj --seen 100 --images 1
3 | # python3 collectdata.py --horizon 6 --num 6 --fixed-goal 0 --structure one_to_many --method traj --seen 100 --images 1
4 | # python3 collectdata.py --horizon 6 --num 6 --fixed-goal 0 --structure masterswitch --method traj --seen 100 --images 1
5 |
6 | # python3 collectdata.py --horizon 6 --num 6 --fixed-goal 0 --structure one_to_one --method traj --seen 50 --images 1
7 | # python3 collectdata.py --horizon 6 --num 6 --fixed-goal 0 --structure many_to_one --method traj --seen 50 --images 1
8 | # python3 collectdata.py --horizon 6 --num 6 --fixed-goal 0 --structure one_to_many --method traj --seen 50 --images 1
9 | # python3 collectdata.py --horizon 6 --num 6 --fixed-goal 0 --structure masterswitch --method traj --seen 50 --images 1
10 |
11 | # python3 collectdata.py --horizon 6 --num 6 --fixed-goal 0 --structure one_to_one --method traj --seen 500 --images 1
12 | # python3 collectdata.py --horizon 6 --num 6 --fixed-goal 0 --structure many_to_one --method traj --seen 500 --images 1
13 | # python3 collectdata.py --horizon 6 --num 6 --fixed-goal 0 --structure one_to_many --method traj --seen 500 --images 1
14 | # python3 collectdata.py --horizon 6 --num 6 --fixed-goal 0 --structure masterswitch --method traj --seen 500 --images 1
15 |
16 | # python3 collectdata.py --horizon 5 --num 5 --fixed-goal 0 --structure one_to_one --method traj --seen 100 --images 1
17 | # python3 collectdata.py --horizon 5 --num 5 --fixed-goal 0 --structure many_to_one --method traj --seen 100 --images 1
18 | # python3 collectdata.py --horizon 5 --num 5 --fixed-goal 0 --structure one_to_many --method traj --seen 100 --images 1
19 | # python3 collectdata.py --horizon 5 --num 5 --fixed-goal 0 --structure masterswitch --method traj --seen 100 --images 1
20 |
21 | # python3 collectdata.py --horizon 5 --num 5 --fixed-goal 0 --structure one_to_one --method traj --seen 50 --images 1
22 | # python3 collectdata.py --horizon 5 --num 5 --fixed-goal 0 --structure many_to_one --method traj --seen 50 --images 1
23 | # python3 collectdata.py --horizon 5 --num 5 --fixed-goal 0 --structure one_to_many --method traj --seen 50 --images 1
24 | # python3 collectdata.py --horizon 5 --num 5 --fixed-goal 0 --structure masterswitch --method traj --seen 50 --images 1
25 |
26 | # python3 collectdata.py --horizon 5 --num 5 --fixed-goal 0 --structure one_to_one --method traj --seen 10 --images 1
27 | # python3 collectdata.py --horizon 5 --num 5 --fixed-goal 0 --structure many_to_one --method traj --seen 10 --images 1
28 | # python3 collectdata.py --horizon 5 --num 5 --fixed-goal 0 --structure one_to_many --method traj --seen 10 --images 1
29 | # python3 collectdata.py --horizon 5 --num 5 --fixed-goal 0 --structure masterswitch --method traj --seen 10 --images 1
30 |
31 | # python3 collectdata.py --horizon 7 --num 7 --fixed-goal 0 --structure one_to_one --method traj --seen 10 --images 1
32 | # python3 collectdata.py --horizon 7 --num 7 --fixed-goal 0 --structure many_to_one --method traj --seen 10 --images 1
33 | # python3 collectdata.py --horizon 7 --num 7 --fixed-goal 0 --structure one_to_many --method traj --seen 10 --images 1
34 | # python3 collectdata.py --horizon 7 --num 7 --fixed-goal 0 --structure masterswitch --method traj --seen 10 --images 1
35 |
36 | # python3 collectdata.py --horizon 7 --num 7 --fixed-goal 0 --structure one_to_one --method traj --seen 100 --images 1
37 | # python3 collectdata.py --horizon 7 --num 7 --fixed-goal 0 --structure many_to_one --method traj --seen 100 --images 1
38 | # python3 collectdata.py --horizon 7 --num 7 --fixed-goal 0 --structure one_to_many --method traj --seen 100 --images 1
39 | # python3 collectdata.py --horizon 7 --num 7 --fixed-goal 0 --structure masterswitch --method traj --seen 100 --images 1
40 |
41 | # python3 collectdata.py --horizon 7 --num 7 --fixed-goal 0 --structure one_to_one --method traj --seen 50 --images 1
42 | # python3 collectdata.py --horizon 7 --num 7 --fixed-goal 0 --structure many_to_one --method traj --seen 50 --images 1
43 | # python3 collectdata.py --horizon 7 --num 7 --fixed-goal 0 --structure one_to_many --method traj --seen 50 --images 1
44 | # python3 collectdata.py --horizon 7 --num 7 --fixed-goal 0 --structure masterswitch --method traj --seen 50 --images 1
45 |
46 | # python3 collectdata.py --horizon 7 --num 7 --fixed-goal 0 --structure one_to_one --method traj --seen 500 --images 1
47 | # python3 collectdata.py --horizon 7 --num 7 --fixed-goal 0 --structure many_to_one --method traj --seen 500 --images 1
48 | # python3 collectdata.py --horizon 7 --num 7 --fixed-goal 0 --structure one_to_many --method traj --seen 500 --images 1
49 | # python3 collectdata.py --horizon 7 --num 7 --fixed-goal 0 --structure masterswitch --method traj --seen 500 --images 1
50 |
51 |
# Collect image trajectories for 7 switches / horizon 7, for every causal
# structure at each "seen" count (same order as the original flat list).
for seen in 10 100 50 500; do
    for structure in one_to_one many_to_one one_to_many masterswitch; do
        python3 collectdata.py --horizon 7 --num 7 --fixed-goal 0 --structure "$structure" --method traj --seen "$seen" --images 1
    done
done
71 |
--------------------------------------------------------------------------------
/F_models.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, division, absolute_import
2 |
3 | import torch as th
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | import torchvision.models as models
7 | import numpy as np
8 | import math
9 | import cv2
10 |
11 |
class ImageEncoder(nn.Module):
    """
    Convolutional image encoder.

    Three conv(3x3) -> max-pool(2x2) -> ReLU stages (3 -> 8 -> 16 -> 32
    channels), each halving the spatial size, followed by a linear layer that
    outputs `num` values per image. The linear layer takes 4 * 4 * 32 inputs,
    which matches 32x32 images (32 -> 16 -> 8 -> 4).
    """
    def __init__(self, num):
        super(ImageEncoder, self).__init__()

        def stage(c_in, c_out):
            # One encoder stage; halves height and width.
            return nn.Sequential(
                nn.Conv2d(c_in, c_out, kernel_size=3, stride=1, padding=1),
                nn.MaxPool2d(kernel_size=2, stride=2),
                nn.ReLU(inplace=True),
            )

        self.encoder_conv = stage(3, 8)     # 32x32x3  -> 16x16x8
        self.encoder_conv2 = stage(8, 16)   # 16x16x8  -> 8x8x16
        self.encoder_conv3 = stage(16, 32)  # 8x8x16   -> 4x4x32

        self.fc = nn.Linear(4 * 4 * 32, num)

        # Kept as an attribute; forward() returns the raw linear output.
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Encode a batch of channel-first images into `num` features each."""
        features = self.encoder_conv3(self.encoder_conv2(self.encoder_conv(x)))
        flat = features.view(features.size(0), -1)
        return self.fc(flat)
46 |
47 |
class IterativeModelAttention(nn.Module):
    """
    Iterative model with attention over the rows of the predicted structure.

    Each step of the trajectory is encoded separately into an attention
    vector (softmax over `attnsize` rows) and a sigmoid "edge" vector of
    length `num`; their outer product is accumulated into a
    (attnsize x outsize) estimate, which gets one final refinement from `fc4`.

    Args:
        horizon: number of steps per trajectory.
        num: number of state entries per step (actions have num + 1 entries).
        ms: if True, one extra attention row is used (attnsize = num + 1).
        images: if True, states are 32x32x3 images encoded by ImageEncoder.
    """
    def __init__(self, horizon,num=5, ms=False, images=False):
        super(IterativeModelAttention, self).__init__()

        self.horizon = horizon
        self.ms = ms

        if self.ms:
            self.attnsize = num + 1
            self.outsize = num
        else:
            self.attnsize = num
            self.outsize = num

        self.images = images
        self.num = num

        self.ie = ImageEncoder(num)

        self.fc1 = nn.Linear(2*num+1, 1024)
        self.fc2 = nn.Linear(1024, 512)
        self.fc3 = nn.Linear(512, self.attnsize + num)

        self.fc4 = nn.Linear(self.attnsize*self.outsize, self.attnsize + num)

        self.dp = nn.Dropout(0.3)
        self.relu = nn.ReLU()
        self.softmax = nn.Softmax(dim=1)
        self.sigmoid = nn.Sigmoid()

    def _attend(self, e3):
        """Split a (batch, attnsize + num) vector into attention x edges."""
        atn = self.softmax(e3[:, :self.attnsize]).unsqueeze(-1)
        edges = self.sigmoid(
            e3[:, self.attnsize:].unsqueeze(1).repeat(1, self.attnsize, 1))
        return atn * edges

    def forward(self, s, a):
        """
        Args:
            s: states, viewable as (batch, horizon, num), or as
               (batch * horizon, 32, 32, 3) when `images` is set.
            a: actions, viewable as (batch, horizon, num + 1).

        Returns:
            Tensor of shape (batch, attnsize * num).
        """
        if self.images:
            s_im = s.view(-1, 32, 32, 3).permute(0,3,1,2)
            senc = self.ie(s_im)
            sp = senc.view(-1, self.horizon, self.num)
        else:
            sp = s.view(-1, self.horizon, self.num)

        # Replace every step but the last with the change to the next step.
        # Built out of place so the caller's `s` (of which `sp` is a view) is
        # not overwritten; values match the former in-place assignment
        # `sp[:, :-1] = sp[:, 1:] - sp[:, :-1]`.
        sp = th.cat([sp[:, 1:] - sp[:, :-1], sp[:, -1:]], 1)

        a = a.view(-1, self.horizon, self.num+1)
        e = th.cat([sp, a], 2)

        # Allocate the accumulator on the inputs' device instead of forcing
        # CUDA, so the model also runs on CPU.
        p = th.zeros((sp.size(0), self.attnsize, self.outsize), device=e.device)
        for i in range(self.horizon):
            inn = e[:,i,:]
            e1 = self.relu(self.dp(self.fc1(inn)))
            e2 = self.relu(self.dp(self.fc2(e1)))
            p = p + self._attend(self.fc3(e2))

        # Final refinement step computed from the accumulated estimate itself.
        p = p + self._attend(self.fc4(p.view(-1, self.attnsize*self.num)))
        p = p.view(-1, self.attnsize*self.num)

        return p
112 |
class IterativeModel(nn.Module):
    """
    Iterative model: encodes each trajectory step separately with a shared
    MLP, sums the per-step sigmoid outputs and applies one final linear +
    sigmoid layer.

    Args:
        horizon: number of steps per trajectory.
        num: number of state entries per step (actions have num + 1 entries).
        ms: if True the output has num**2 + num entries, otherwise num**2.
        images: if True, states are 32x32x3 images encoded by ImageEncoder.
    """
    def __init__(self, horizon,num=5, ms=False, images=False):
        super(IterativeModel, self).__init__()

        self.images = images
        self.ie = ImageEncoder(num)

        self.horizon = horizon
        self.ms = ms
        if self.ms:
            self.outsize = num**2 + num
        else:
            self.outsize = num**2

        self.num = num

        self.fc1 = nn.Linear(2*num+1, 1024)
        self.fc2 = nn.Linear(1024, 512)
        self.fc3 = nn.Linear(512, self.outsize)

        self.fc4 = nn.Linear(self.outsize, self.outsize)

        # The convolution layers and softmax below are not used by forward();
        # they are kept so the module's attributes and parameters stay the same.
        self.cnn1 = nn.Conv1d(2*num+1, 256, kernel_size=3, padding=1)
        self.cnn2 = nn.Conv1d(256, 128, kernel_size=3, padding=1)
        self.cnn3 = nn.Conv1d(128, 128, kernel_size=3, padding=1)

        self.dp = nn.Dropout(0.3)
        self.relu = nn.ReLU()
        self.softmax = nn.Softmax()
        self.sigmoid = nn.Sigmoid()

    def forward(self, s, a):
        """
        Args:
            s: states, viewable as (batch, horizon, num), or as
               (batch * horizon, 32, 32, 3) when `images` is set.
            a: actions, viewable as (batch, horizon, num + 1).

        Returns:
            Tensor of shape (batch, outsize) with sigmoid outputs.
        """
        if self.images:
            s_im = s.view(-1, 32, 32, 3).permute(0,3,1,2)
            senc = self.ie(s_im)
            sp = senc.view(-1, self.horizon, self.num)
        else:
            sp = s.view(-1, self.horizon, self.num)

        # Replace every step but the first with the change from the previous
        # step. Built out of place so the caller's `s` (of which `sp` is a
        # view) is not overwritten; values match the former in-place
        # assignment `sp[:, 1:] = sp[:, 1:] - sp[:, :-1]`.
        sp = th.cat([sp[:, :1], sp[:, 1:] - sp[:, :-1]], 1)

        a = a.view(-1, self.horizon, self.num+1)
        e = th.cat([sp, a], 2)

        # (batch, features, horizon), indexed per step below.
        c2 = e.permute(0,2,1)

        # Allocate the accumulator on the inputs' device instead of forcing
        # CUDA, so the model also runs on CPU.
        p = th.zeros((sp.size(0), self.outsize), device=e.device)
        for i in range(self.horizon):
            e1 = self.relu(self.dp(self.fc1(c2[:,:,i])))
            e2 = self.relu(self.dp(self.fc2(e1)))
            e3 = self.sigmoid(self.fc3(e2))
            p = p + e3
        p = self.sigmoid(self.fc4(p))

        return p
167 |
class SupervisedModelCNN(nn.Module):
    """
    Temporal-convolution model: runs three 1-D convolutions along the time
    axis of the concatenated (state, action) sequence, flattens the result
    and maps it through an MLP to sigmoid outputs.

    Args:
        horizon: number of steps per trajectory.
        num: number of state entries per step (actions have num + 1 entries).
        ms: if True the output has num**2 + num entries, otherwise num**2.
        images: if True, states are 32x32x3 images encoded by ImageEncoder.
    """
    def __init__(self, horizon,num=5, ms=False, images=False):
        super(SupervisedModelCNN, self).__init__()

        self.images = images
        self.ie = ImageEncoder(num)

        self.horizon = horizon
        self.ms = ms
        self.outsize = num**2 + num if self.ms else num**2
        self.num = num

        # Convolutions over time; kernel 3 with padding 1 keeps the length.
        self.cnn1 = nn.Conv1d(2*num+1, 256, kernel_size=3, padding=1)
        self.cnn2 = nn.Conv1d(256, 128, kernel_size=3, padding=1)
        self.cnn3 = nn.Conv1d(128, 128, kernel_size=3, padding=1)

        # MLP head on the flattened (horizon * 128) convolution output.
        self.fc1 = nn.Linear(self.horizon*128, 1024)
        self.fc2 = nn.Linear(1024, 512)
        self.fc3 = nn.Linear(512, self.outsize)

        self.dp = nn.Dropout(0.3)
        self.relu = nn.ReLU()
        self.softmax = nn.Softmax()
        self.sigmoid = nn.Sigmoid()

    def forward(self, s, a):
        """
        Args:
            s: states, viewable as (batch, horizon, num), or as
               (batch * horizon, 32, 32, 3) when `images` is set.
            a: actions, viewable as (batch, horizon, num + 1).

        Returns:
            Tensor of shape (batch, outsize) with sigmoid outputs.
        """
        if self.images:
            frames = s.view(-1, 32, 32, 3).permute(0,3,1,2)
            states = self.ie(frames).view(-1, self.horizon, self.num)
        else:
            states = s.view(-1, self.horizon, self.num)
        actions = a.view(-1, self.horizon, self.num+1)

        # (batch, horizon, features) -> (batch, features, horizon) for Conv1d.
        seq = th.cat([states, actions], 2).permute(0,2,1)

        conv = self.relu(self.cnn1(seq))
        conv = self.relu(self.cnn2(conv))
        conv = self.relu(self.cnn3(conv))

        flat = conv.view(-1, self.horizon*128)
        hidden = self.relu(self.dp(self.fc1(flat)))
        hidden = self.relu(self.dp(self.fc2(hidden)))

        return self.sigmoid(self.fc3(hidden))
218 |
--------------------------------------------------------------------------------
/env/light_env.py:
--------------------------------------------------------------------------------
1 | from mujoco_py import load_model_from_path, MjSim, MjViewer
2 | from mujoco_py.modder import TextureModder
3 | import mujoco_py
4 | import os
5 | import numpy as np
6 | import gym
7 | from gym import error, spaces
8 | from gym.utils import seeding
9 | import copy
10 | from itertools import permutations
11 | import cv2
12 | DEFAULT_SIZE = 500
13 |
class LightEnv(gym.GoalEnv):
    '''Light Switch Environment for Visual Causal Induction'''
    def __init__(self, horizon=5, num=5, cond="gt", structure="one_to_one", gc=True, filename=None, seen=10):
        '''
        Creates Light Switch Environment

        Args:
            horizon: Length of episode
            num: Number of switches [5,7,9]
            cond: What is appended to the observation: "gt" for the ground
                  truth graph, None for nothing, otherwise the size of a
                  zero-initialised `traj` vector.
            structure: Type of causal structure
                       [one_to_one, one_to_many, many_to_one, masterswitch]
            gc: True/False goal conditioned or not
            filename: Path prefix for logging episode results
            seen: Number of seen causal structures
        '''
        ## Load XML model
        asset_dir = os.path.join(os.path.dirname(__file__), 'assets')
        fullpath = os.path.join(asset_dir, "arena_v2_" + str(num) + ".xml")
        if not os.path.exists(fullpath):
            raise IOError('File {} does not exist'.format(fullpath))
        self.sim = mujoco_py.MjSim(mujoco_py.load_model_from_path(fullpath))

        self.filename = filename
        self.horizon = horizon
        self.cond = cond
        self.gc = gc
        self.structure = structure
        self.num = num
        self.metadata = {'render.modes': ['human', 'rgb_array']}

        ## Initialize GT/traj - the underlying causal structure
        if self.cond is not None:
            if self.cond == "gt":
                # masterswitch appends `num` extra entries to the graph
                gt_size = self.num**2
                if self.structure == "masterswitch":
                    gt_size += self.num
                self.gt = np.zeros((gt_size))
            else:
                self.traj = np.zeros(self.cond)
        self.aj = np.zeros((self.num, self.num))

        ## If goal conditioned, sample goal state
        if self.gc:
            self.goal = self._sample_goal()

        ## Set random seed so order of causal structures is preserved
        np.random.seed(1)

        self.state = np.zeros((self.num))  # Initial State
        self.eps = 0  # Num Episodes

        # Generate causal structures (reduced set when num == 9)
        if self.structure in ("one_to_many", "many_to_one"):
            self.all_perms = self.generate_cs_set1(self.num, self.num == 9)
        else:
            self.all_perms = self.generate_cs_set2(self.num)

        ## Params to randomize struct and train v test
        self.keep_struct = True
        self.train = True

        ## Shuffled causal structures
        np.random.shuffle(self.all_perms)
        ## Number of all structs
        self.pmsz = self.all_perms.shape[0]
        self.seen = seen

        obs = self._get_obs()
        self.action_space = spaces.Discrete(self.num+1)
        self.observation_space = spaces.Box(0, 1, shape=obs.shape, dtype='float32')

    # Env methods
    # ----------------------------

    def step(self, action, count = True, log=False):
        '''Step in env.

        Args:
            action: which switch to toggle (`num` means "do nothing")
            count: False if not counted toward episode (for collecting heuristic data)
            log: Log episode results

        Returns:
            (obs, reward, done, info)
        '''
        if action != self.num:
            ## With a masterswitch, other switches only respond once the
            ## masterswitch is on (or when the masterswitch itself is toggled)
            gated = self.structure == "masterswitch"
            if (not gated) or (action == self.ms) or (self.state[self.ms] == 1):
                change = np.zeros(self.num)
                change[action] = 1
                self.state = np.abs(self.state - change)

        obs = self._get_obs()

        done = False
        info = {'is_success': self._is_success(obs)}
        self.correct.append((info["is_success"]))
        reward = self.compute_reward(obs, info)

        if count:
            self.steps += 1
            self.eprew += reward
            ## Episode ends on reaching the goal (zero distance) or the horizon
            if reward == 0:
                done = True
            if (self.steps >= self.horizon):
                done = True
            if done and log:
                stem = self.filename + \
                    "_S" + str(self.seen) + "_" + str(self.structure) + \
                    "_H" + str(self.horizon) + "_N" + str(self.num) + \
                    "_T" + str(self.current_cs)
                with open(stem + ".txt", "a") as f:
                    f.write(str(self.eprew) + "\n")
                with open(stem + "successrate.txt", "a") as f:
                    f.write(str(int(info["is_success"])) + "\n")

        return obs, reward, done, info

    def reset(self):
        '''Start a new episode, drawing a new causal structure unless
        `keep_struct` is set. Returns the first observation.'''
        keep_struct = self.keep_struct
        train = self.train
        self.current_cs = "train" if train else "test"

        ## Either reset causal structure or not
        if not keep_struct:
            ## Select from seen causal structure or unseen causal structure
            if train:
                ind = np.random.randint(0, self.seen)
            else:
                ind = np.random.randint(self.seen, self.pmsz)
            perm = self.all_perms[ind]

            ## Set graph according to causal structure
            if self.structure in ("one_to_one", "one_to_many",
                                  "many_to_one", "masterswitch"):
                aj = np.zeros((self.num,self.num))
                for i in range(self.num):
                    aj[i, perm[i]] = 1
                ## one_to_many uses the transposed graph
                self.aj = aj.T if self.structure == "one_to_many" else aj
                self.gt = self.aj.flatten()
                if self.structure == "masterswitch":
                    ## Pick the masterswitch and append its one-hot to gt
                    self.ms = np.random.randint(self.num)
                    m = np.zeros((self.num))
                    m[self.ms] = 1
                    self.gt = np.concatenate([self.gt, m])

        self.eprew = 0
        self.steps = 0
        self.correct = []
        self.state = np.zeros((self.num))
        self.eps += 1
        self.goal = self._sample_goal()

        return self._get_obs()

    def compute_reward(self, obs, info=None):
        '''Negative Euclidean distance between the lights and the goal.'''
        ## Distance to goal configuration
        gap = obs[:self.num] - self.goal
        return -1 * np.sqrt((gap**2).sum())

    def _get_obs(self, images=False):
        """Returns the observation.

        With images=True a rendered 32x32 frame scaled by 1/255 is returned;
        otherwise the light vector, followed by the goal (if goal
        conditioned) and the graph or traj vector (if `cond` is set).
        """
        ## Compute which lights are on based on underlying state and causal graph
        light = np.dot(self.state.T, self.aj) % 2

        ## Set corresponding lights and render image
        if images:
            self.sim.model.light_active[:] = light
            return self.sim.render(width=32,height=32,camera_name="birdview") / 255.0

        o = light

        ## Concatenate goal
        if self.gc:
            o = np.concatenate([o, self.goal])

        ## Concatenate graph
        if self.cond is not None:
            extra = self.gt if self.cond == "gt" else self.traj
            o = np.concatenate([o, extra])

        return o

    def _is_success(self, obs):
        """Indicates whether or not the achieved goal successfully achieved the desired goal.
        """
        achieved = obs[:self.num]
        return (achieved == self.goal).all()

    def _sample_goal(self):
        """Samples a new goal and returns it.

        Also renders the goal configuration and stores it in `goalim`.
        """
        state = np.random.randint(0, 2, size=(self.num))
        light = np.dot(state.T, self.aj) % 2
        self.sim.model.light_active[:] = light
        self.goalim = self.sim.render(mode='offscreen', width=32,height=32,camera_name="birdview") / 255.0
        return light

    def generate_cs_set1(self, sz, cut=False):
        '''Generate Causal Structures for Many to One.

        Recursively builds rows of length `sz` whose entries range over
        0..num-1, by prepending every value to every shorter row.

        Args:
            sz: row length (number of switches)
            cut: Reduce generated structures for efficiency by randomly
                 skipping shorter rows with probability 0.4
        '''
        if sz == 1:
            return np.array([[i] for i in range(self.num)])

        gs = []
        for t in self.generate_cs_set1(sz-1, cut):
            if cut and (np.random.uniform() < 0.4):
                continue
            for i in range(self.num):
                gs.append(np.concatenate([np.array([i]), t]))
        return np.array(gs)

    def generate_cs_set2(self, sz):
        '''Generate Causal Structures for One to One. All Permutations
        of 0..num-1.

        Args:
            sz: num switches (not read; `self.num` is used)
        '''
        return np.array([np.array(list(perm))
                         for perm in permutations(np.arange(self.num))])
--------------------------------------------------------------------------------
/env/assets/arena_v2_5.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 |
43 |
44 |
45 |
46 |
47 |
48 |
49 |
50 |
51 |
52 |
53 |
54 |
55 |
56 |
57 |
58 |
59 |
60 |
61 |
62 |
63 |
64 |
65 |
66 |
67 |
68 |
69 |
70 |
71 |
72 |
73 |
74 |
75 |
76 |
77 |
78 |
79 |
80 |
81 |
82 |
83 |
84 |
85 |
86 |
87 |
88 |
89 |
90 |
91 |
92 |
93 |
94 |
95 |
96 |
97 |
98 |
99 |
100 |
101 |
102 |
103 |
104 |
105 |
106 |
107 |
108 |
109 |
110 |
111 |
112 |
113 |
114 |
115 |
116 |
117 |
118 |
119 |
120 |
121 |
122 |
123 |
124 |
125 |
126 |
127 |
128 |
129 |
130 |
131 |
132 |
133 |
134 |
135 |
136 |
137 |
138 |
139 |
140 |
141 |
142 |
143 |
144 |
145 |
146 |
147 |
148 |
149 |
150 |
151 |
152 |
153 |
154 |
155 |
156 |
157 |
158 |
159 |
160 |
161 |
162 |
163 |
164 |
--------------------------------------------------------------------------------
/env/assets/arena_v2_5_vis.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 |
43 |
44 |
45 |
46 |
47 |
48 |
49 |
50 |
51 |
52 |
53 |
54 |
55 |
56 |
57 |
58 |
59 |
60 |
61 |
62 |
63 |
64 |
65 |
66 |
67 |
68 |
69 |
70 |
71 |
72 |
73 |
74 |
75 |
76 |
77 |
78 |
79 |
80 |
81 |
82 |
83 |
84 |
85 |
86 |
87 |
88 |
89 |
90 |
91 |
92 |
93 |
94 |
95 |
96 |
97 |
98 |
99 |
100 |
101 |
102 |
103 |
104 |
105 |
106 |
107 |
108 |
109 |
110 |
111 |
112 |
113 |
114 |
115 |
116 |
117 |
118 |
119 |
120 |
121 |
122 |
123 |
124 |
125 |
126 |
127 |
128 |
129 |
130 |
131 |
132 |
133 |
134 |
135 |
136 |
137 |
138 |
139 |
140 |
141 |
142 |
143 |
144 |
145 |
146 |
147 |
148 |
149 |
150 |
151 |
152 |
153 |
154 |
155 |
156 |
157 |
158 |
159 |
160 |
161 |
162 |
163 |
164 |
--------------------------------------------------------------------------------
/env/assets/arena_v2_9.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 |
43 |
44 |
45 |
46 |
47 |
48 |
49 |
50 |
51 |
52 |
53 |
54 |
55 |
56 |
57 |
58 |
59 |
60 |
61 |
62 |
63 |
64 |
65 |
66 |
67 |
68 |
69 |
70 |
71 |
72 |
73 |
74 |
75 |
76 |
77 |
78 |
79 |
80 |
81 |
82 |
83 |
84 |
85 |
86 |
87 |
88 |
89 |
90 |
91 |
92 |
93 |
94 |
95 |
96 |
97 |
98 |
99 |
100 |
101 |
102 |
103 |
104 |
105 |
106 |
107 |
108 |
109 |
110 |
111 |
112 |
113 |
114 |
115 |
116 |
117 |
118 |
119 |
120 |
121 |
122 |
123 |
124 |
125 |
126 |
127 |
128 |
129 |
130 |
131 |
132 |
133 |
134 |
135 |
136 |
137 |
138 |
139 |
140 |
141 |
142 |
143 |
144 |
145 |
146 |
147 |
148 |
149 |
150 |
151 |
152 |
153 |
154 |
155 |
156 |
157 |
158 |
159 |
160 |
161 |
162 |
163 |
164 |
165 |
166 |
167 |
168 |
--------------------------------------------------------------------------------
/env/assets/arena_v2_7.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 |
43 |
44 |
45 |
46 |
47 |
48 |
49 |
50 |
51 |
52 |
53 |
54 |
55 |
56 |
57 |
58 |
59 |
60 |
61 |
62 |
63 |
64 |
65 |
66 |
67 |
68 |
69 |
70 |
71 |
72 |
73 |
74 |
75 |
76 |
77 |
78 |
79 |
80 |
81 |
82 |
83 |
84 |
85 |
86 |
87 |
88 |
89 |
90 |
91 |
92 |
93 |
94 |
95 |
96 |
97 |
98 |
99 |
100 |
101 |
102 |
103 |
104 |
105 |
106 |
107 |
108 |
109 |
110 |
111 |
112 |
113 |
114 |
115 |
116 |
117 |
118 |
119 |
120 |
121 |
122 |
123 |
124 |
125 |
126 |
127 |
128 |
129 |
130 |
131 |
132 |
133 |
134 |
135 |
136 |
137 |
138 |
139 |
140 |
141 |
142 |
143 |
144 |
145 |
146 |
147 |
148 |
149 |
150 |
151 |
152 |
153 |
154 |
155 |
156 |
157 |
158 |
159 |
160 |
161 |
162 |
163 |
164 |
165 |
166 |
167 |
168 |
--------------------------------------------------------------------------------
/learn_planner.py:
--------------------------------------------------------------------------------
1 | from env.light_env import LightEnv
2 | import numpy as np
3 | import argparse
4 | import torch as th
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | import cv2
8 |
9 |
10 |
class BCPolicy(nn.Module):
    """
    Imitation Policy

    Encodes a 6-channel 32x32 image (channel-last input) with a small CNN,
    encodes the causal graph, and maps both to `num` action logits.

    Args:
        num: number of switches / size of the action output.
        structure: causal structure name; "masterswitch" graphs carry one
                   extra row of `num` entries that is fed straight into fc2.
        attention: if True, the image encoding produces softmax weights that
                   select a weighted combination of the graph before encoding.
    """
    def __init__(self, num, structure, attention = False):
        super(BCPolicy, self).__init__()
        self.encoder_conv = nn.Sequential(
            # 32x32x6 -> 16x16x8
            nn.Conv2d(6, 8, kernel_size=3, stride=1, padding=1),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.ReLU(inplace=True),
        )
        self.encoder_conv2 = nn.Sequential(
            # 16x16x8 -> 8x8x16
            nn.Conv2d(8, 16, kernel_size=3, stride=1, padding=1),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.ReLU(inplace=True),
        )
        self.encoder_conv3 = nn.Sequential(
            # 8x8x16 -> 4x4x32
            nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.ReLU(inplace=True),
        )

        self.att = attention
        self.num = num
        # Number of graph rows: masterswitch adds one row.
        if structure == "masterswitch":
            self.ins = self.num + 1
        else:
            self.ins = self.num

        self.attlayer = nn.Linear(128, num)
        self.structure = structure
        self.fc1 = nn.Linear(4 * 4 * 32, 128)

        if not self.att:
            if structure == "masterswitch":
                self.gfc1 = nn.Linear(num*num + num, 128)
            else:
                self.gfc1 = nn.Linear(num*num, 128)
        else:
            self.gfc1 = nn.Linear(self.num, 128)

        if self.structure == "masterswitch":
            # Sized from the constructor argument rather than the script's
            # global `args`, so the class works when imported and always
            # matches the `ms` vector concatenated in forward().
            self.fc2 = nn.Linear(256 + num, 64)
        else:
            self.fc2 = nn.Linear(256, 64)
        self.fc5 = nn.Linear(64, num)

        self.softmax = nn.Softmax(dim=-1)
        self.relu = nn.ReLU()

    def forward(self, x, gr):
        """
        Args:
            x: image batch, (batch, 32, 32, 6) channel-last.
            gr: flattened graph, viewable as (batch, ins, num).

        Returns:
            (batch, num) action logits (no softmax applied).
        """
        x = x.permute(0, 3, 1, 2).contiguous()

        e1 = self.encoder_conv(x)
        e2 = self.encoder_conv2(e1)
        e3 = self.encoder_conv3(e2)
        e3 = e3.view(e3.size(0), -1)
        encoding = self.relu(self.fc1(e3))

        is_ms = self.structure == "masterswitch"
        if is_ms:
            # The last graph row is the masterswitch vector. Extract it for
            # both the attention and non-attention paths; it used to be
            # defined only under attention, so the plain path failed.
            rows = gr.view((-1, self.ins, self.num))
            ms = rows[:, -1, :]

        if self.att:
            w = self.softmax(self.attlayer(encoding))
            if is_ms:
                graph = rows[:, :-1, :]
            else:
                graph = gr.view((-1, self.ins, self.num))
            # Attention-weighted combination of the graph's columns.
            gr_sel = th.bmm(graph, w.view(w.size(0), -1, 1))
            gr_sel = gr_sel.squeeze(-1)
            g1 = self.relu(self.gfc1(gr_sel))
        else:
            g1 = self.relu(self.gfc1(gr))

        if is_ms:
            eout = th.cat([g1, encoding, ms], 1)
        else:
            eout = th.cat([g1, encoding], 1)
        a = self.relu(self.fc2(eout))
        a = self.fc5(a)
        return a
92 |
93 |
class BCPolicyMemory(nn.Module):
    """
    Imitation policy with memory

    Encodes a 6-channel 32x32 image (channel-last input) and an action
    vector of length num + 1, feeds both through an LSTM cell and maps the
    cell's hidden state to `num` action logits.
    """
    def __init__(self, num, structure):
        super(BCPolicyMemory, self).__init__()

        def stage(c_in, c_out):
            # conv(3x3) -> max-pool(2x2) -> ReLU; halves height and width.
            return nn.Sequential(
                nn.Conv2d(c_in, c_out, kernel_size=3, stride=1, padding=1),
                nn.MaxPool2d(kernel_size=2, stride=2),
                nn.ReLU(inplace=True),
            )

        self.encoder_conv = stage(6, 8)     # 32x32x6 -> 16x16x8
        self.encoder_conv2 = stage(8, 16)   # 16x16x8 -> 8x8x16
        self.encoder_conv3 = stage(16, 32)  # 8x8x16  -> 4x4x32

        self.fc1 = nn.Linear(4 * 4 * 32, 128)

        self.aenc = nn.Linear(num+1, 128)

        # Graph encoder; created here but not called in forward().
        graph_size = num*num + num if structure == "masterswitch" else num*num
        self.gfc1 = nn.Linear(graph_size, 128)

        self.lstm = nn.LSTMCell(256, 256)

        self.fc2 = nn.Linear(256, 64)
        self.fc5 = nn.Linear(64, num)

        self.softmax = nn.Softmax()
        self.relu = nn.ReLU()

    def forward(self, x, a, hidden):
        """
        Args:
            x: image batch, (batch, 32, 32, 6) channel-last.
            a: action vectors, (batch, num + 1).
            hidden: LSTM cell state tuple from the previous call, or None
                    to start from the cell's default initial state.

        Returns:
            (logits, hidden): (batch, num) action logits and the new state.
        """
        frames = x.permute(0, 3, 1, 2).contiguous()
        feats = self.encoder_conv3(self.encoder_conv2(self.encoder_conv(frames)))
        encoding = self.relu(self.fc1(feats.view(feats.size(0), -1)))

        action_enc = self.relu(self.aenc(a))
        step_input = th.cat([action_enc, encoding], 1)

        if hidden is None:
            hidden = self.lstm(step_input)
        else:
            hidden = self.lstm(step_input, hidden)

        logits = self.fc5(self.relu(self.fc2(hidden[0])))
        return logits, hidden
152 |
153 |
def _probe_switch(l, i, num, images):
    '''Record the current observation with a one-hot action i, then press switch i.'''
    p = l._get_obs().reshape((1, -1))
    a = np.zeros((1, num + 1))
    a[:, i] = 1
    if images:
        pi = l._get_obs(images=True)
        mem = np.concatenate([np.expand_dims(pi.flatten(), 0), a], 1)
    else:
        mem = np.concatenate([p[:, :num], a], 1)
    l.step(i, count=False)
    return p, mem


def induction(structure, num, horizon, l, images=False):
    '''Roll out heuristic interaction policy.

    Presses switches on env ``l`` and records one row per press: the
    observation before the press (flattened image if ``images`` is truthy,
    else the first ``num`` observation entries) followed by a one-hot
    action of width num + 1.

    structure: "masterswitch" selects the two-phase probing below; any
               other value presses switches 0..num-1 once each.
    num:       number of switches.
    horizon:   only used for "masterswitch", where the result is zero-padded
               to 2 * horizon - 1 rows.
    Returns the 2-D buffer of recorded rows.
    '''
    rows = []
    if structure == "masterswitch":
        # Phase 1: press switches in order until one changes the observation.
        it = None
        for i in range(num):
            p, mem = _probe_switch(l, i, num, images)
            rows.append(mem)
            p2 = l._get_obs()
            if (p != p2).any():
                it = i
                break
        # Phase 2: press every switch except the one found in phase 1.
        for i in range(num):
            if i != it:
                _, mem = _probe_switch(l, i, num, images)
                rows.append(mem)
        epbuf = np.concatenate(rows, 0)
        # Fixed-size output: unused trailing rows stay zero.
        buf = np.zeros((2 * horizon - 1, epbuf.shape[1]))
        buf[:epbuf.shape[0]] = epbuf
        return buf

    for i in range(num):
        _, mem = _probe_switch(l, i, num, images)
        rows.append(mem)
    return np.concatenate(rows, 0)
222 |
223 |
def predict(buf, F, structure, num):
    '''Predict graph.

    buf:       2-D array; the last num + 1 columns of each row are the
               action, the remaining leading columns are the state.
    F:         callable taking (state, action) float tensors.
    structure: unused; kept for call compatibility.
    num:       number of switches (sets where each row is split).

    Returns F's output clamped to [0, 1], flattened to a 1-D numpy array.
    '''
    # Run on the device the model lives on; fall back to CUDA when available
    # (the previous hard-coded behaviour) and CPU otherwise.
    params = getattr(F, "parameters", None)
    first = next(params(), None) if callable(params) else None
    if first is not None:
        device = first.device
    else:
        device = th.device("cuda" if th.cuda.is_available() else "cpu")

    split = num + 1
    s = th.as_tensor(buf[:, :-split], dtype=th.float32, device=device)
    a = th.as_tensor(buf[:, -split:], dtype=th.float32, device=device)
    # Inference only: skip building an autograd graph.
    with th.no_grad():
        predgt = th.clamp(F(s, a), 0, 1)
    return predgt.cpu().numpy().flatten()
230 |
def train_bc(memory, policy, opt):
    '''Train imitation policy with one gradient step.

    memory: dict with parallel lists under 'state', 'graph' and 'action'.
    policy: module called as policy(states, graphs) returning class scores.
    opt:    optimizer over policy's parameters.

    Returns the cross-entropy loss as a 0-d numpy array, or None when
    fewer than 50 samples are stored (no update is made in that case).
    '''
    size = len(memory['state'])
    if size < 50:
        return None
    # Follow the policy's own device instead of assuming CUDA.
    device = next(policy.parameters()).device
    opt.zero_grad()
    choices = np.random.choice(size, 32)

    # Stack into one ndarray first: building a tensor directly from a list
    # of ndarrays goes element by element and is very slow.
    states = th.as_tensor(np.asarray([memory['state'][c] for c in choices]),
                          dtype=th.float32, device=device)
    graphs = th.as_tensor(np.asarray([memory['graph'][c] for c in choices]),
                          dtype=th.float32, device=device)
    actions = th.LongTensor([memory['action'][c] for c in choices]).to(device)

    pred_acts = policy(states, graphs)
    loss = nn.CrossEntropyLoss()(pred_acts, actions)
    l = loss.detach().cpu().numpy()
    loss.backward()
    opt.step()
    return l
253 |
def train_bclstm(trajs, policy, opt):
    '''Train imitation policy with memory on four sampled trajectories.

    trajs:  list of dicts; 'graph'[0] is the interaction buffer (each row a
            flattened 32x32x3 image followed by an action vector), 'state'
            holds the observations of the rollout and 'action' the target
            action indices.
    policy: module called as policy(image, action, hidden) returning
            (scores, hidden).
    opt:    optimizer over policy's parameters.

    Returns the summed cross-entropy loss as a 0-d numpy array, or None
    when fewer than 10 trajectories are stored.
    '''
    if len(trajs) < 10:
        return None
    # Use the policy that was passed in and its device, not a module global.
    device = next(policy.parameters()).device
    img_len = 32 * 32 * 3
    celoss = nn.CrossEntropyLoss()
    opt.zero_grad()
    totalloss = 0
    choices = np.random.choice(len(trajs), 4)
    for t in choices:
        memory = trajs[t]
        hidden = None
        ## Feed interaction trajectory through policy with memory
        buf = memory['graph'][0]
        for row in buf:
            frame = row[:img_len].reshape(1, 32, 32, 3)
            # The goal-image half of the input is zero during interaction.
            frame = np.concatenate([frame, np.zeros_like(frame)], -1)
            probe = row[img_len:].reshape(1, -1)
            _, hidden = policy(th.FloatTensor(frame).to(device),
                               th.FloatTensor(probe).to(device), hidden)

        states = np.array(memory['state'])
        actions = np.array(memory['action'])
        # All-zero action input for the rollout steps; the width comes from
        # the buffer so it is defined even if the buffer has no rows.
        blank = th.zeros((1, buf.shape[1] - img_len), device=device)
        preds = []
        for w in range(states.shape[0]):
            pred_acts, hidden = policy(th.FloatTensor(states[w:w+1]).to(device),
                                       blank, hidden)
            preds.append(pred_acts)
        preds = th.cat(preds, 0)
        totalloss += celoss(preds, th.LongTensor(actions).to(device))

    l = totalloss.detach().cpu().numpy()
    totalloss.backward()
    opt.step()
    return l
290 |
def eval_bc(policy, l, train=True, f=None, args=None):
    '''Evaluate imitation policy over 100 episodes.

    policy: module called as policy(image_batch, graph_batch).
    l:      environment; ``keep_struct`` is set to False and ``train`` to
            the given flag before the episodes run.
    train:  value assigned to ``l.train``; also enables a debug print on
            episode 50.
    f:      optional induction model.  When None the env's ``gt`` is used
            as the graph; otherwise an interaction trajectory is collected
            with ``induction`` and the graph is predicted with ``predict``
            (this also overwrites ``l.state``, ``l.traj`` and ``l.gt``).
    args:   namespace providing ``horizon`` (and ``structure``, ``num``,
            ``images`` when f is given); required despite the None default.

    Returns the mean of ``l._is_success`` over the episodes.
    '''
    device = next(policy.parameters()).device
    successes = []
    l.keep_struct = False
    l.train = train
    # Evaluation only: no autograd graph is needed for any forward pass.
    with th.no_grad():
        # Eval over 100 trials
        for mep in range(100):
            obs = l.reset()
            imobs = np.expand_dims(l._get_obs(images=True), 0)
            goalim = np.expand_dims(l.goalim, 0)

            if f is None:
                graph = np.expand_dims(l.gt.flatten(), 0)
            else:
                buf = induction(args.structure, args.num, args.horizon, l, images=args.images)
                traj = buf.flatten()

                l.state = np.zeros((args.num))
                l.traj = traj
                pred = predict(buf, f, args.structure, args.num)
                l.gt = pred
                graph = np.expand_dims(pred.flatten(), 0)

            # The graph does not change within an episode; convert it once.
            graph_t = th.FloatTensor(graph).to(device)
            for k in range(args.horizon * 2):
                # Current image and goal image stacked along the channel axis.
                st = np.concatenate([imobs, goalim], 3)
                act = policy(th.FloatTensor(st).to(device), graph_t)
                action = act[0].argmax()

                obs, reward, done, info = l.step(action)
                if (mep == 50) and (train):
                    print(action, obs[:5])
                imobs = np.expand_dims(l._get_obs(images=True), 0)
                if done:
                    break

            successes.append(l._is_success(obs))
    return np.mean(successes)
328 |
def eval_bclstm(policy, l, train=True, args=None):
    '''Evaluate imitation policy with memory over 100 episodes.

    policy: module called as policy(image, action, hidden) returning
            (scores, hidden).
    l:      environment; ``keep_struct`` is set to False and ``train`` to
            the given flag before the episodes run.
    train:  value assigned to ``l.train``; also enables a debug print on
            episode 50.
    args:   namespace providing ``structure``, ``num``, ``horizon`` and
            ``images``; required despite the None default.  The buffer from
            ``induction`` is sliced as flattened 32x32x3 images, so this
            presumably requires args.images to be truthy -- confirm.

    Returns the mean of ``l._is_success`` over the episodes.
    '''
    device = next(policy.parameters()).device
    img_len = 32 * 32 * 3
    successes = []
    l.keep_struct = False
    l.train = train
    # Evaluation only: no autograd graph is needed for any forward pass.
    with th.no_grad():
        for mep in range(100):
            hidden = None
            obs = l.reset()
            imobs = np.expand_dims(l._get_obs(images=True), 0)
            goalim = np.expand_dims(l.goalim, 0)

            buf = induction(args.structure, args.num, args.horizon, l, images=args.images)

            l.state = np.zeros((args.num))
            # Replay the interaction trajectory to build up the hidden state.
            for w in range(buf.shape[0]):
                states = buf[w, :img_len].reshape(1, 32, 32, 3)
                sgg = np.zeros_like(states)
                states = np.concatenate([states, sgg], -1)
                actions = buf[w, img_len:].reshape(1, -1)
                act, hidden = policy(th.FloatTensor(states).to(device),
                                     th.FloatTensor(actions).to(device), hidden)

            # All-zero action input for the control steps; the width comes
            # from the buffer so it is defined even if the buffer has no rows.
            blank = th.zeros((1, buf.shape[1] - img_len), device=device)
            for k in range(args.horizon * 2):
                st = np.concatenate([imobs, goalim], 3)
                act, hidden = policy(th.FloatTensor(st).to(device), blank, hidden)
                action = act[0].argmax()

                obs, reward, done, info = l.step(action)
                if (mep == 50) and (train):
                    print(action, obs[:5])
                imobs = np.expand_dims(l._get_obs(images=True), 0)
                if done:
                    break

            successes.append(l._is_success(obs))
    return np.mean(successes)
365 |
366 |
if __name__ == '__main__':
    # Script entry point: trains an imitation policy in LightEnv.
    # In every mode a scripted expert plans from the env adjacency matrix
    # (l.aj) and its action is stored as the label, while the action actually
    # executed is random with probability 0.3 and the policy's argmax otherwise.
    # NOTE(review): block nesting below was reconstructed from a listing with
    # collapsed whitespace; spots marked "nesting inferred" should be checked
    # against the upstream file.
    parser = argparse.ArgumentParser(description='Causal Meta-RL')
    parser.add_argument('--fixed-goal', type=int, default=0, help='fixed goal or no')
    parser.add_argument('--horizon', type=int, default=10, help='Env horizon')
    parser.add_argument('--num', type=int, default=1, help='num lights')
    parser.add_argument('--structure', type=str, default="one_to_one", help='causal structure')
    parser.add_argument('--method', type=str, default="traj", help='Type of model')
    parser.add_argument('--seen', type=int, default=10, help='Num see envs')
    parser.add_argument('--images', type=int, default=0, help='Images or no')
    # NOTE(review): no default, so omitting --data-dir makes the fname
    # concatenation below fail on None.
    parser.add_argument('--data-dir', type=str, help='Model path')

    args = parser.parse_args()

    # gc is passed to LightEnv as its fifth positional argument
    # (presumably "goal conditioned" -- confirm against LightEnv).
    gc = 1 - args.fixed_goal
    # Prefix for the success-rate log files (and LightEnv's filename).
    fname = args.data_dir+"polattn_"+str(gc)+"_"+args.method

    memsize = 10000
    memory = {'state':[], 'graph':[], 'action':[]}
    if args.method == 'trajlstm':
        pol = BCPolicyMemory(args.num, args.structure).cuda()
    else:
        pol = BCPolicy(args.num, args.structure, True).cuda()
    optimizer = th.optim.Adam(pol.parameters(), lr=0.0001)

    ## Using ground truth graph
    if args.method == "gt":
        l = LightEnv(args.horizon*2,
                     args.num,
                     "gt",
                     args.structure,
                     gc,
                     filename=fname,
                     seen = args.seen)

        successes = []
        l.keep_struct = False
        l.train = True
        ## Per episode
        for mep in range(100000):
            l.train = True
            obs = l.reset()

            # The planner starts from an all-zero state vector rather than
            # the observation returned by reset().
            curr = np.zeros((args.num))
            obs = curr
            imobs = l._get_obs(images=True)
            goalim = l.goalim

            goal = l.goal
            ## Steps in episode
            for k in range(args.horizon*2):
                ## Use GT graph to plan
                g = np.abs(goal - obs[:args.num])
                # Current image and goal image stacked along the last axis.
                st = np.concatenate([imobs, goalim], 2)
                sss = 1.0*(np.dot(g, l.aj.T).T > 0.5)

                if args.structure == "masterswitch":
                    # The master switch (l.ms) is handled separately below.
                    sss[l.ms] = 0
                # NOTE(review): nesting inferred -- this break may belong
                # inside the masterswitch branch above.
                if sss.max() == 0:
                    break

                action = np.argmax(sss)
                if args.structure == "masterswitch":
                    # NOTE(review): hard-coded 5; looks like it should be args.num.
                    if obs[:5].max() == 0:
                        action = l.ms
                # Store the expert label before the action is replaced below.
                memory['state'].append(st)
                memory['graph'].append(l.gt.flatten())
                memory['action'].append(action)

                ## Random noise to policy
                if np.random.uniform() < 0.3:
                    action = np.random.randint(args.num)
                else:
                    graph = np.expand_dims(l.gt.flatten(), 0)
                    act = pol(th.FloatTensor(np.expand_dims(st, 0)).cuda(), th.FloatTensor(graph).cuda())
                    action = act[0].argmax()

                obs, reward, done, info = l.step(action)
                imobs = l._get_obs(images=True)
                if done:
                    break

            # NOTE(review): nesting inferred -- treated as running once after
            # the step loop: if the master switch still needs toggling, record
            # that as a final expert step and execute it.
            g = np.abs(goal - obs[:args.num])
            st = np.concatenate([imobs, goalim], 2)
            sss = 1.0*(np.dot(g, l.aj.T).T > 0.5)

            if args.structure == "masterswitch":
                if sss[l.ms]:
                    st = np.concatenate([imobs, goalim], 2)
                    memory['state'].append(st)
                    memory['graph'].append(l.gt.flatten())
                    memory['action'].append(l.ms)
                    obs, reward, done, info = l.step(l.ms)
            # Keep only the most recent memsize samples.
            memory['state'] = memory['state'][-memsize:]
            memory['graph'] = memory['graph'][-memsize:]
            memory['action'] = memory['action'][-memsize:]
            for _ in range(1):
                loss = train_bc(memory, pol, optimizer)
            # Every 1000 episodes: evaluate and append the scores to the logs.
            if mep % 1000 == 0:
                print("Episode", mep, "Loss:" , loss )
                trainsc = eval_bc(pol, l, True, args=args)
                testsc = eval_bc(pol, l, False, args=args)
                with open(fname + "_S" + str(args.seen) + \
                          "_"+str(args.structure)+"_H"+str(args.horizon)+\
                          "_N"+str(args.num)+"_Ttrainsuccessrate.txt", "a") as f:
                    f.write(str(float(trainsc)) + "\n")
                with open(fname + "_S" + str(args.seen) + \
                          "_"+str(args.structure)+"_H"+str(args.horizon)+\
                          "_N"+str(args.num)+"_Ttestsuccessrate.txt", "a") as f:
                    f.write(str(float(testsc)) + "\n")

                print("Train Success Rate:", trainsc)
                print("Test Success Rate:", testsc)

            # NOTE(review): nesting inferred for the next two statements.
            successes.append(l._is_success(obs))
        print(np.mean(successes))
    ## If using learning induction model
    elif (args.method == "trajF") or (args.method == "trajFi") or (args.method == "trajFia"):
        # NOTE(review): this st is overwritten inside the step loop before
        # any use; it appears to be dead.
        if args.structure == "masterswitch":
            st = (args.horizon*(2*args.num+1) + (args.horizon-1)*(2*args.num+1))
        else:
            st = (args.horizon*(2*args.num+1))
        tj = "gt"
        l = LightEnv(args.horizon*2,
                     args.num,
                     tj,
                     args.structure,
                     gc,
                     filename=fname,
                     seen = args.seen)

        # Checkpoint name suffix used when the model was trained on images.
        if args.images:
            addonn = "_I1"
        else:
            addonn = ""

        # Load the trained induction model for the chosen method.
        # NOTE(review): th.load unpickles the file -- only load trusted checkpoints.
        if args.method == "trajF":
            FN = th.load(args.data_dir+"cnn_Redo_L2_S"+str(args.seen)+"_h"+str(args.horizon)+\
                         "_"+str(args.structure)+addonn).cuda()
        elif args.method == "trajFia":
            FN = th.load(args.data_dir+"iter_attn_Redo_L2_S"+str(args.seen)+"_h"+str(args.horizon)+\
                         "_"+str(args.structure)+addonn).cuda()
        else:
            FN = th.load(args.data_dir+"iter_Redo_L2_S"+str(args.seen)+"_h"+str(args.horizon)+\
                         "_"+str(args.structure)+addonn).cuda()
        FN = FN.eval()
        successes = []
        l.keep_struct = False
        l.train = False
        for mep in range(100000):
            l.train = True
            obs = l.reset()
            goalim = l.goalim
            imobs = l._get_obs(images=True)

            ## Predict Graph
            buf = induction(args.structure,args.num, args.horizon, l, images=args.images)
            traj = buf.flatten()
            pred = predict(buf, FN,args.structure, args.num)
            # Reset the env state after the interaction rollout.
            l.state = np.zeros((args.num))

            curr = np.zeros((args.num))
            obs = curr

            goal = l.goal
            for k in range(args.horizon*2):
                ## Planning
                # The expert still plans with the env's own adjacency (l.aj);
                # the predicted graph is only what the policy gets as input.
                g = np.abs(goal - obs[:args.num])
                st = np.concatenate([imobs, goalim], 2)
                sss = 1.0*(np.dot(g, l.aj.T).T > 0.5)

                if args.structure == "masterswitch":
                    sss[l.ms] = 0
                # NOTE(review): nesting inferred -- this break may belong
                # inside the masterswitch branch above.
                if sss.max() == 0:
                    break

                action = np.argmax(sss)
                if args.structure == "masterswitch":
                    # NOTE(review): hard-coded 5; looks like it should be args.num.
                    if obs[:5].max() == 0:
                        action = l.ms
                memory['state'].append(st)
                memory['graph'].append(pred.flatten())
                memory['action'].append(action)

                ## Random Noise
                if np.random.uniform() < 0.3:
                    action = np.random.randint(args.num)
                else:
                    # NOTE(review): recomputes the same prediction from the
                    # same buf on every policy step.
                    pred = predict(buf, FN,args.structure, args.num)
                    graph = np.expand_dims(pred.flatten(), 0)
                    act = pol(th.FloatTensor(np.expand_dims(st, 0)).cuda(), th.FloatTensor(graph).cuda())
                    action = act[0].argmax()


                obs, reward, done, info = l.step(action)
                imobs = l._get_obs(images=True)
                if done:
                    break

            # NOTE(review): nesting inferred -- treated as a single final
            # master-switch step after the step loop.
            g = np.abs(goal - obs[:args.num])
            st = np.concatenate([imobs, goalim], 2)
            sss = 1.0*(np.dot(g, l.aj.T).T > 0.5)
            if args.structure == "masterswitch":
                if sss[l.ms]:
                    st = np.concatenate([imobs, goalim], 2)
                    memory['state'].append(st)
                    memory['graph'].append(pred.flatten())
                    memory['action'].append(l.ms)
                    obs, reward, done, info = l.step(l.ms)

            # Keep only the most recent memsize samples.
            memory['state'] = memory['state'][-memsize:]
            memory['graph'] = memory['graph'][-memsize:]
            memory['action'] = memory['action'][-memsize:]
            for _ in range(1):
                loss = train_bc(memory, pol, optimizer)
            # Every 1000 episodes: evaluate and append the scores to the logs.
            if mep % 1000 == 0:
                print("Episode", mep, "Loss:" , loss )
                trainsc = eval_bc(pol, l, True, f=FN, args=args)
                testsc = eval_bc(pol, l, False, f=FN, args=args)
                with open(fname + "_S" + str(args.seen) + \
                          "_"+str(args.structure)+"_H"+str(args.horizon)+\
                          "_N"+str(args.num)+"_Ttrainsuccessrate.txt", "a") as f:
                    f.write(str(float(trainsc)) + "\n")
                with open(fname + "_S" + str(args.seen) + \
                          "_"+str(args.structure)+"_H"+str(args.horizon)+\
                          "_N"+str(args.num)+"_Ttestsuccessrate.txt", "a") as f:
                    f.write(str(float(testsc)) + "\n")

                print("Train Success Rate:", trainsc)
                print("Test Success Rate:", testsc)

            # NOTE(review): nesting inferred for the next two statements.
            successes.append(l._is_success(obs))
        print(np.mean(successes))
    elif (args.method == "trajlstm"):
        # Memory policy: no explicit graph; the interaction trajectory is fed
        # through the recurrent policy instead.
        # NOTE(review): this st is overwritten inside the step loop before
        # any use; it appears to be dead.
        if args.structure == "masterswitch":
            st = (args.horizon*(2*args.num+1) + (args.horizon-1)*(2*args.num+1))
        else:
            st = (args.horizon*(2*args.num+1))
        tj = "gt"
        l = LightEnv(args.horizon*2,
                     args.num,
                     tj,
                     args.structure,
                     gc,
                     filename=fname,
                     seen = args.seen)

        # NOTE(review): addonn is not used in this branch.
        if args.images:
            addonn = "_I1"
        else:
            addonn = ""

        successes = []
        l.keep_struct = False
        l.train = False
        # Here memsize bounds the number of stored trajectories.
        memsize = 100
        trajs = []

        for mep in range(100000):
            # One memory dict per episode; it becomes one training trajectory.
            memory = {'state':[], 'graph':[], 'action':[]}
            hidden = None
            l.train = True
            obs = l.reset()
            goalim = l.goalim
            imobs = l._get_obs(images=True)

            ## Get interaction trajectory
            buf = induction(args.structure,args.num, args.horizon, l, images=args.images)
            memory['graph'].append(buf)
            # Each row is split into a flattened 32x32x3 image and an action
            # vector; the image is padded with a zero goal-image half.
            for w in range(buf.shape[0]):
                states = buf[w, :32*32*3].reshape(1, 32, 32, 3)
                sgg = np.zeros_like(states)
                states = np.concatenate([states, sgg], -1)
                actions = buf[w, 32*32*3:].reshape(1, -1)
                act, hidden = pol(th.FloatTensor(states).cuda(), th.FloatTensor(actions).cuda(), hidden)
            # NOTE(review): nesting inferred -- placed after the replay loop.
            l.state = np.zeros((args.num))

            curr = np.zeros((args.num))
            obs = curr

            goal = l.goal
            for k in range(args.horizon*2):
                ## Planning
                g = np.abs(goal - obs[:args.num])
                st = np.concatenate([imobs, goalim], 2)
                sss = 1.0*(np.dot(g, l.aj.T).T > 0.5)

                if args.structure == "masterswitch":
                    sss[l.ms] = 0
                # NOTE(review): nesting inferred -- this break may belong
                # inside the masterswitch branch above.
                if sss.max() == 0:
                    break

                action = np.argmax(sss)
                if args.structure == "masterswitch":
                    # NOTE(review): hard-coded 5; looks like it should be args.num.
                    if obs[:5].max() == 0:
                        action = l.ms
                memory['state'].append(st)
                memory['action'].append(action)

                ## Policy Noise
                if np.random.uniform() < 0.3:
                    action = np.random.randint(args.num)
                else:
                    # NOTE(review): this feeds `states`/`actions` left over
                    # from the last interaction row rather than the current
                    # `st`, and the returned s_hidden is never used, so the
                    # hidden state does not advance here -- confirm intent.
                    act, s_hidden = pol(th.FloatTensor(states).cuda(), th.FloatTensor(actions).cuda(), hidden)
                    action = act[0].argmax()


                obs, reward, done, info = l.step(action)
                imobs = l._get_obs(images=True)
                if done:
                    break

            # NOTE(review): nesting inferred -- treated as a single final
            # master-switch step after the step loop.
            g = np.abs(goal - obs[:args.num])
            st = np.concatenate([imobs, goalim], 2)
            sss = 1.0*(np.dot(g, l.aj.T).T > 0.5)
            if args.structure == "masterswitch":
                if sss[l.ms]:
                    st = np.concatenate([imobs, goalim], 2)
                    memory['state'].append(st)
                    memory['action'].append(l.ms)
                    obs, reward, done, info = l.step(l.ms)



            # Only episodes with at least one recorded step are kept.
            if len(memory['state']) != 0:
                trajs.append(memory)
            trajs = trajs[-memsize:]
            for _ in range(1):
                loss = train_bclstm(trajs, pol, optimizer)
            # Every 1000 episodes: evaluate and append the scores to the logs.
            if mep % 1000 == 0:
                print("Episode", mep, "Loss:" , loss )
                trainsc = eval_bclstm(pol, l, True, args=args)
                testsc = eval_bclstm(pol, l, False, args=args)
                with open(fname + "_S" + str(args.seen) + \
                          "_"+str(args.structure)+"_H"+str(args.horizon)+\
                          "_N"+str(args.num)+"_Ttrainsuccessrate.txt", "a") as f:
                    f.write(str(float(trainsc)) + "\n")
                with open(fname + "_S" + str(args.seen) + \
                          "_"+str(args.structure)+"_H"+str(args.horizon)+\
                          "_N"+str(args.num)+"_Ttestsuccessrate.txt", "a") as f:
                    f.write(str(float(testsc)) + "\n")

                print("Train Success Rate:", trainsc)
                print("Test Success Rate:", testsc)

            # NOTE(review): nesting inferred for the next two statements.
            successes.append(l._is_success(obs))
        print(np.mean(successes))
--------------------------------------------------------------------------------