├── plan
│   ├── log
│   │   ├── readme.txt
│   │   └── whl.txt
│   ├── checkpoint
│   │   └── whl
│   │       └── readme.txt
│   ├── utils
│   │   ├── __init__.py
│   │   ├── one_hot.py
│   │   ├── accuracy.py
│   │   ├── eval.py
│   │   ├── training.py
│   │   └── args.py
│   ├── save_max
│   │   └── readme.txt
│   ├── action_dictionary.py
│   ├── model
│   │   ├── temporal.py
│   │   ├── helpers.py
│   │   └── diffusion.py
│   ├── dataloader
│   │   └── data_load.py
│   └── main_distributed.py
├── step
│   ├── log
│   │   └── readme.txt
│   ├── checkpoint
│   │   └── whl
│   │       └── readme.txt
│   ├── utils
│   │   ├── __init__.py
│   │   ├── accuracy.py
│   │   ├── eval.py
│   │   ├── training.py
│   │   └── args.py
│   ├── loading_data.py
│   ├── action_dictionary.py
│   ├── model
│   │   ├── temporal.py
│   │   ├── helpers.py
│   │   └── diffusion.py
│   ├── dataloader
│   │   ├── data_load.py
│   │   └── train_test_data_load.py
│   ├── main_distributed.py
│   └── inference.py
├── requirements.txt
├── dataset
│   ├── NIV
│   │   ├── download.sh
│   │   ├── test30.json
│   │   └── train70.json
│   ├── coin
│   │   └── download.sh
│   └── crosstask
│       ├── actions_one_hot.npy
│       └── download.sh
├── PKG
│   ├── graphs
│   │   ├── NIV
│   │   │   └── NIV_trained_graph.pkl
│   │   ├── COIN
│   │   │   └── coin_trained_graph.pkl
│   │   ├── CrossTaskHow
│   │   │   ├── weighted_graph
│   │   │   │   └── trained_graph.pkl
│   │   │   ├── min_max_N_graph
│   │   │   │   └── trained_graph.pkl
│   │   │   └── out_edge_N_graph
│   │   │       └── trained_graph.pkl
│   │   └── CrossTaskBase
│   │       └── trained_graph_CrossTask_base.pkl
│   ├── graph_visualize.py
│   └── action_dictionary.py
└── README.md
/plan/log/readme.txt:
--------------------------------------------------------------------------------
1 | save KEPP logs here
--------------------------------------------------------------------------------
/step/log/readme.txt:
--------------------------------------------------------------------------------
1 | save KEPP logs here
--------------------------------------------------------------------------------
/plan/checkpoint/whl/readme.txt:
--------------------------------------------------------------------------------
1 | save checkpoints here
--------------------------------------------------------------------------------
/step/checkpoint/whl/readme.txt:
--------------------------------------------------------------------------------
1 | save MLP checkpoints here
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | einops
2 | tensorboard
3 | tensorboardX
4 | torch
5 |
--------------------------------------------------------------------------------
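The dependencies above install in the usual way from the repository root:

pip install -r requirements.txt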
/plan/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .training import *
2 | from .eval import *
3 | from .args import *
4 |
--------------------------------------------------------------------------------
/step/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .training import *
2 | from .eval import *
3 | from .args import *
4 |
--------------------------------------------------------------------------------
/dataset/NIV/download.sh:
--------------------------------------------------------------------------------
1 | #! /bin/bash
2 |
3 | wget https://huggingface.co/nyanko7/LLaMA-7B/resolve/main/*
--------------------------------------------------------------------------------
/dataset/coin/download.sh:
--------------------------------------------------------------------------------
1 | #! /bin/bash
2 |
3 | wget https://vision.eecs.yorku.ca/WebShare/COIN_s3d.zip
4 |
5 | unzip COIN_s3d.zip
--------------------------------------------------------------------------------
/PKG/graphs/NIV/NIV_trained_graph.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ravindu-Yasas-Nagasinghe/KEPP/HEAD/PKG/graphs/NIV/NIV_trained_graph.pkl
--------------------------------------------------------------------------------
/PKG/graphs/COIN/coin_trained_graph.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ravindu-Yasas-Nagasinghe/KEPP/HEAD/PKG/graphs/COIN/coin_trained_graph.pkl
--------------------------------------------------------------------------------
/dataset/crosstask/actions_one_hot.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ravindu-Yasas-Nagasinghe/KEPP/HEAD/dataset/crosstask/actions_one_hot.npy
--------------------------------------------------------------------------------
/plan/save_max/readme.txt:
--------------------------------------------------------------------------------
1 | The checkpoint with the maximum validation accuracy will be saved here. Use the path of this checkpoint as the input to the inference script.
--------------------------------------------------------------------------------
/PKG/graphs/CrossTaskHow/weighted_graph/trained_graph.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ravindu-Yasas-Nagasinghe/KEPP/HEAD/PKG/graphs/CrossTaskHow/weighted_graph/trained_graph.pkl
--------------------------------------------------------------------------------
/PKG/graphs/CrossTaskBase/trained_graph_CrossTask_base.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ravindu-Yasas-Nagasinghe/KEPP/HEAD/PKG/graphs/CrossTaskBase/trained_graph_CrossTask_base.pkl
--------------------------------------------------------------------------------
/PKG/graphs/CrossTaskHow/min_max_N_graph/trained_graph.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ravindu-Yasas-Nagasinghe/KEPP/HEAD/PKG/graphs/CrossTaskHow/min_max_N_graph/trained_graph.pkl
--------------------------------------------------------------------------------
/PKG/graphs/CrossTaskHow/out_edge_N_graph/trained_graph.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ravindu-Yasas-Nagasinghe/KEPP/HEAD/PKG/graphs/CrossTaskHow/out_edge_N_graph/trained_graph.pkl
--------------------------------------------------------------------------------
/dataset/crosstask/download.sh:
--------------------------------------------------------------------------------
1 | #! /bin/bash
2 |
3 | wget https://www.di.ens.fr/~dzhukov/crosstask/crosstask_release.zip
4 |
5 | wget https://www.di.ens.fr/~dzhukov/crosstask/crosstask_features.zip
6 |
7 | wget https://vision.eecs.yorku.ca/WebShare/CrossTask_s3d.zip
8 |
9 | unzip '*.zip'
--------------------------------------------------------------------------------
/step/loading_data.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import os
3 | import random
4 | import time
5 | from collections import OrderedDict
6 |
7 | import torch.nn.parallel
8 | import torch.backends.cudnn as cudnn
9 | import torch.distributed as dist
10 | import torch.optim
11 | import torch.multiprocessing as mp
12 | import torch.utils.data
13 | import torch.utils.data.distributed
14 | from torch.distributed import ReduceOp
15 | import torch.nn.functional as F
16 | from dataloader.train_test_data_load import PlanningDataset
17 | from model.helpers import get_lr_schedule_with_warmup, Logger
18 | import torch.nn as nn
19 | from utils import *
20 | from logging import log
21 | from utils.args import get_args
22 | import numpy as np
23 |
24 | def main():
25 | args = get_args()
26 | os.environ['PYTHONHASHSEED'] = str(args.seed)
27 | if os.path.exists(args.json_path_val):
28 | pass
29 | else:
30 | train_dataset = PlanningDataset(
31 | args.root,
32 | args=args,
33 | is_val=False,
34 | model=None,
35 | )
36 | print('Train loaded')
37 | test_dataset = PlanningDataset(
38 | args.root,
39 | args=args,
40 | is_val=True,
41 | model=None,
42 | )
43 | print('test loaded')
44 |
45 | if __name__ == "__main__":
46 | main()
47 |
--------------------------------------------------------------------------------
/plan/utils/one_hot.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from utils.args import get_args
3 | args = get_args()
4 | class LLMLabelOnehot(torch.nn.Module):
5 | def __init__(self, batch_size, T, num_lists, probabilities):
6 | super(LLMLabelOnehot, self).__init__()
7 |
8 | self.batch_size = batch_size
9 | self.T = T
10 | self.num_lists = num_lists
11 | self.probabilities = probabilities
12 |
13 | self.onehot = torch.zeros((batch_size * T, args.class_dim_llama)).cuda()
14 |
15 | def forward(self, LLM_label):
16 | for batch_idx in range(self.batch_size):
17 | for t in range(self.T):
18 | list_prob = torch.zeros(args.class_dim_llama).cuda()
19 | for list_idx in range(self.num_lists):
20 | if self.probabilities==[1]:
21 | list_prob[LLM_label[batch_idx][t]] = self.probabilities[list_idx]
22 | else:
23 | list_prob[LLM_label[batch_idx][list_idx][t]] = self.probabilities[list_idx]
24 |
25 | self.onehot[batch_idx * self.T + t, :] = list_prob
26 |
27 | self.onehot = self.onehot.reshape(self.batch_size, self.T, -1).cuda()
28 | return self.onehot
29 |
30 |
31 | class PKGLabelOnehot(torch.nn.Module):
32 | def __init__(self, batch_size, T, num_lists, probabilities):
33 | super(PKGLabelOnehot, self).__init__()
34 |
35 | self.batch_size = batch_size
36 | self.T = T
37 | self.num_lists = num_lists
38 | self.probabilities = probabilities
39 |
40 | self.onehot = torch.zeros((batch_size * T, args.class_dim_graph)).cuda()
41 |
42 | def forward(self, PKG_label):
43 | for batch_idx in range(self.batch_size):
44 | for t in range(self.T):
45 | list_prob = torch.zeros(args.class_dim_graph).cuda()
46 | for list_idx in range(self.num_lists):
47 | if self.probabilities==[1]:
48 | list_prob[PKG_label[batch_idx][t]] = self.probabilities[list_idx]
49 | else:
50 | list_prob[PKG_label[batch_idx][list_idx][t]] = self.probabilities[list_idx]
51 |
52 | self.onehot[batch_idx * self.T + t, :] = list_prob
53 |
54 | self.onehot = self.onehot.reshape(self.batch_size, self.T, -1).cuda()
55 | return self.onehot
--------------------------------------------------------------------------------
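For reference, a minimal CPU-only sketch of the soft one-hot encoding that LLMLabelOnehot and PKGLabelOnehot build above (the sizes are illustrative stand-ins for args.class_dim_llama / args.class_dim_graph): each candidate action sequence writes its weight at its action index for every time step, and a later candidate overwrites an earlier one when both pick the same action.

import torch

T, class_dim = 3, 10                       # illustrative sizes
candidates = torch.tensor([[1, 4, 7],      # first candidate sequence  (weight 2/3)
                           [2, 4, 9]])     # second candidate sequence (weight 1/3)
probabilities = [2 / 3, 1 / 3]

onehot = torch.zeros(T, class_dim)
for t in range(T):
    for list_idx, p in enumerate(probabilities):
        onehot[t, candidates[list_idx, t]] = p   # assignment, not accumulation
print(onehot)  # row t=1 holds only 1/3 at index 4, since both candidates agree there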
/plan/utils/accuracy.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def accuracy(output, target, topk=(1,), max_traj_len=0):
5 | with torch.no_grad():
6 | maxk = max(topk)
7 | batch_size = target.size(0)
8 |
9 | _, pred = output.topk(maxk, 1, True, True)
10 | pred = pred.t() # [k, bs*T]
11 | correct = pred.eq(target.view(1, -1).expand_as(pred)) # [k, bs*T]
12 |
13 | correct_a = correct[:1].view(-1, max_traj_len) # [bs, T]
14 | correct_a0 = correct_a[:, 0].reshape(-1).float().mean().mul_(100.0)
15 | correct_aT = correct_a[:, -1].reshape(-1).float().mean().mul_(100.0)
16 |
17 | res = []
18 | for k in topk:
19 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
20 | res.append(correct_k.mul_(100.0 / batch_size))
21 |
22 | correct_1 = correct[:1] # (1, bs*T)
23 |
24 | # Success Rate
25 | trajectory_success = torch.all(correct_1.view(correct_1.shape[1] // max_traj_len, -1), dim=1)
26 | trajectory_success_rate = trajectory_success.sum() * 100.0 / trajectory_success.shape[0]
27 |
28 | # MIoU
29 | _, pred_token = output.topk(1, 1, True, True) # [bs*T, 1]
30 | pred_inst = pred_token.view(correct_1.shape[1], -1) # [bs*T, 1]
31 | pred_inst_set = set()
32 | target_inst = target.view(correct_1.shape[1], -1) # [bs*T, 1]
33 | target_inst_set = set()
34 | for i in range(pred_inst.shape[0]):
35 | # print(pred_inst[i], target_inst[i])
36 | pred_inst_set.add(tuple(pred_inst[i].tolist()))
37 | target_inst_set.add(tuple(target_inst[i].tolist()))
38 | MIoU1 = 100.0 * len(pred_inst_set.intersection(target_inst_set)) / len(pred_inst_set.union(target_inst_set))
39 |
40 | batch_size = batch_size // max_traj_len
41 | pred_inst = pred_token.view(batch_size, -1) # [bs, T]
42 | pred_inst_set = set()
43 | target_inst = target.view(batch_size, -1) # [bs, T]
44 | target_inst_set = set()
45 | MIoU_sum = 0
46 | for i in range(pred_inst.shape[0]):
47 | # print(pred_inst[i], target_inst[i])
48 | pred_inst_set.update(pred_inst[i].tolist())
49 | target_inst_set.update(target_inst[i].tolist())
50 | MIoU_current = 100.0 * len(pred_inst_set.intersection(target_inst_set)) / len(
51 | pred_inst_set.union(target_inst_set))
52 | MIoU_sum += MIoU_current
53 | pred_inst_set.clear()
54 | target_inst_set.clear()
55 |
56 | MIoU2 = MIoU_sum / batch_size
57 | return res, trajectory_success_rate, MIoU1, MIoU2, correct_a0, correct_aT
58 |
59 |
--------------------------------------------------------------------------------
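A quick sanity check for accuracy() above (logits are [bs*T, action_dim], targets are [bs*T]; this assumes the accuracy function has been imported or copied into the session, since importing the whole utils package also pulls in the argument parser):

import torch
# assumes: accuracy() from plan/utils/accuracy.py is available in scope

# two trajectories of horizon 3 over 4 possible actions, all predicted correctly
logits = torch.tensor([[0.1, 0.9, 0.0, 0.0],
                       [0.8, 0.1, 0.1, 0.0],
                       [0.0, 0.0, 1.0, 0.0],
                       [0.9, 0.0, 0.1, 0.0],
                       [0.0, 0.7, 0.2, 0.1],
                       [0.2, 0.1, 0.3, 0.4]])
target = torch.tensor([1, 0, 2, 0, 1, 3])

(top1, top4), success_rate, miou1, miou2, a0, aT = accuracy(
    logits, target, topk=(1, 4), max_traj_len=3)
print(top1.item(), success_rate.item())  # 100.0 100.0 for this fully correct batch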
/step/utils/accuracy.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def accuracy(output, target, topk=(1,), max_traj_len=0):
5 | with torch.no_grad():
6 | maxk = max(topk)
7 | batch_size = target.size(0)
8 |
9 | _, pred = output.topk(maxk, 1, True, True)
10 | pred = pred.t() # [k, bs*T]
11 | correct = pred.eq(target.view(1, -1).expand_as(pred)) # [k, bs*T]
12 |
13 | correct_a = correct[:1].view(-1, max_traj_len) # [bs, T]
14 | correct_a0 = correct_a[:, 0].reshape(-1).float().mean().mul_(100.0)
15 | correct_aT = correct_a[:, -1].reshape(-1).float().mean().mul_(100.0)
16 |
17 | res = []
18 | for k in topk:
19 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
20 | res.append(correct_k.mul_(100.0 / batch_size))
21 |
22 | correct_1 = correct[:1] # (1, bs*T)
23 |
24 | # Success Rate
25 | trajectory_success = torch.all(correct_1.view(correct_1.shape[1] // max_traj_len, -1), dim=1)
26 | trajectory_success_rate = trajectory_success.sum() * 100.0 / trajectory_success.shape[0]
27 |
28 | # MIoU
29 | _, pred_token = output.topk(1, 1, True, True) # [bs*T, 1]
30 | pred_inst = pred_token.view(correct_1.shape[1], -1) # [bs*T, 1]
31 | pred_inst_set = set()
32 | target_inst = target.view(correct_1.shape[1], -1) # [bs*T, 1]
33 | target_inst_set = set()
34 | for i in range(pred_inst.shape[0]):
35 | # print(pred_inst[i], target_inst[i])
36 | pred_inst_set.add(tuple(pred_inst[i].tolist()))
37 | target_inst_set.add(tuple(target_inst[i].tolist()))
38 | MIoU1 = 100.0 * len(pred_inst_set.intersection(target_inst_set)) / len(pred_inst_set.union(target_inst_set))
39 |
40 | batch_size = batch_size // max_traj_len
41 | pred_inst = pred_token.view(batch_size, -1) # [bs, T]
42 | pred_inst_set = set()
43 | target_inst = target.view(batch_size, -1) # [bs, T]
44 | target_inst_set = set()
45 | MIoU_sum = 0
46 | for i in range(pred_inst.shape[0]):
47 | # print(pred_inst[i], target_inst[i])
48 | pred_inst_set.update(pred_inst[i].tolist())
49 | target_inst_set.update(target_inst[i].tolist())
50 | MIoU_current = 100.0 * len(pred_inst_set.intersection(target_inst_set)) / len(
51 | pred_inst_set.union(target_inst_set))
52 | MIoU_sum += MIoU_current
53 | pred_inst_set.clear()
54 | target_inst_set.clear()
55 |
56 | MIoU2 = MIoU_sum / batch_size
57 | return res, trajectory_success_rate, MIoU1, MIoU2, correct_a0, correct_aT
58 |
59 |
--------------------------------------------------------------------------------
/PKG/graph_visualize.py:
--------------------------------------------------------------------------------
1 | import networkx as nx
2 | import cv2
3 | import numpy as np
4 |
5 | import os
6 | import pickle
7 | import networkx as nx
8 | import logging
9 | import json
10 | import torch
11 | from tabulate import tabulate
12 | from action_dictionary import action_dictionary
13 |
14 | # Function to find all simple paths starting from an input node within a fixed length
15 | def find_paths_within_length(graph, start_node, max_length):
16 |     def dfs_paths(node, path, length):
17 |         if length <= max_length:
18 |             path.append(node)
19 |             # yield every partial path (including the start node alone),
20 |             # then keep extending it through unvisited neighbours
21 |             yield list(path)
22 |             for neighbor in graph.neighbors(node):
23 |                 if neighbor not in path:
24 |                     yield from dfs_paths(neighbor, path, length + 1)
25 |             path.pop()
26 | 
27 |     paths = list(dfs_paths(start_node, [], 0))
28 |     return paths
29 |
30 |
31 | graph_save_path = '/home/ravindu.nagasinghe/GithubCodes/RaviPP/trained_graph.pkl'
32 |
33 | with open(graph_save_path, 'rb') as graph_file:
34 | graph = pickle.load(graph_file)
35 |
36 | # Input parameters
37 | input_node = 1 # Replace with the desired input node
38 | max_length = 4 # Replace with the desired fixed length
39 |
40 | # Find all paths within the fixed length from the input node
41 | all_paths = find_paths_within_length(graph, input_node, max_length)
42 |
43 | # Create an image for visualization
44 | node_positions = nx.circular_layout(graph)
45 | image_width = max(len(p) for p in all_paths) * 100  # Width of the image based on the longest path
46 | image_height = len(graph.nodes()) * 100 # Height of the image based on the number of nodes
47 |
48 | image = np.zeros((image_height, image_width, 3), dtype=np.uint8) # Initialize an empty image
49 | image.fill(255) # Set the background to white
50 |
51 | # Draw nodes
52 | for node, pos in node_positions.items():
53 | x, y = int(pos[0] * image_width), int(pos[1] * image_height)
54 | cv2.circle(image, (x, y), 10, (0, 0, 0), -1) # Draw black nodes
55 |
56 | # Draw paths with thickness based on weight
57 | for path in all_paths:
58 | for i in range(len(path) - 1):
59 | u, v = path[i], path[i + 1]
60 | weight = graph[u][v]['weight']
61 | thickness = int((weight / 5) * 3) # Adjust the scaling factor as needed
62 | u_pos, v_pos = node_positions[u], node_positions[v]
63 | u_x, u_y = int(u_pos[0] * image_width), int(u_pos[1] * image_height)
64 | v_x, v_y = int(v_pos[0] * image_width), int(v_pos[1] * image_height)
65 | cv2.line(image, (u_x, u_y), (v_x, v_y), (0, 0, 255), thickness)
66 |
67 | # Display the image (you can save it using cv2.imwrite)
68 | file_path = "graph_visualization.png"
69 | cv2.imwrite(file_path, image)
70 |
71 |
--------------------------------------------------------------------------------
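With the DFS above, a toy weighted graph shows what the helper returns (a sketch only: the real inputs are the pickled PKG graphs under PKG/graphs/, and find_paths_within_length must already be in scope, e.g. copied into the session, because importing graph_visualize.py runs its module-level code):

import networkx as nx
# assumes: find_paths_within_length() from graph_visualize.py is available in scope

g = nx.DiGraph()
g.add_weighted_edges_from([(1, 2, 3.0), (2, 3, 1.0), (1, 3, 2.0)])

# All simple paths of at most 3 hops starting at node 1:
# [[1], [1, 2], [1, 2, 3], [1, 3]]
print(find_paths_within_length(g, start_node=1, max_length=3))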
/PKG/action_dictionary.py:
--------------------------------------------------------------------------------
1 | action_dictionary = {1: 'pour water', 2: 'pour juice', 3: 'pour jello powder', 4: 'pour alcohol',
2 | 5: 'stir mixture', 6: 'pour mixture into cup', 7: 'cut shelve', 8: 'assemble shelve',
3 | 9: 'sand shelve', 10: 'paint shelve', 11: 'attach shelve', 12: 'add onion', 13: 'add taco',
4 | 14: 'add lettuce', 15: 'add meat', 16: 'add tomato', 17: 'add cheese', 18: 'stir',
5 | 19: 'add tortilla', 20: 'season steak', 21: 'put steak on grill', 22: 'close lid',
6 | 23: 'open lid', 24: 'move steak on grill', 25: 'flip steak', 26: 'check temperature',
7 | 27: 'take steak from grill', 28: 'top steak', 29: 'cut steak', 30: 'taste steak', 31: 'add rice',
8 | 32: 'add ham', 33: 'add kimchi', 34: 'pour sesame oil', 35: 'pour egg', 36: 'add sugar',
9 | 37: 'whisk mixture', 38: 'put mixture into bag', 39: 'spread mixture', 40: 'put meringue into oven',
10 | 41: 'add coffee', 42: 'press coffee', 43: 'pour espresso', 44: 'steam milk', 45: 'pour milk',
11 | 46: 'cut cucumber', 47: 'cut onion', 48: 'add salt', 49: 'pour vinegar', 50: 'add spices',
12 | 51: 'put vegetables in water', 52: 'pack cucumbers in jar', 53: 'seal jar',
13 | 54: 'put jar in water', 55: 'cut lemon', 56: 'squeeze lemon', 57: 'pour lemon juice',
14 | 58: 'add ice', 59: 'pour lemonade into glass', 60: 'dip bread in mixture', 61: 'melt butter',
15 | 62: 'put bread in pan', 63: 'add vanilla extract', 64: 'flip bread', 65: 'remove bread from pan',
16 | 66: 'top toast', 67: 'brake on', 68: 'raise jack', 69: 'lower jack', 70: 'add chili powder',
17 | 71: 'add mustard seeds', 72: 'add curry leaves', 73: 'add fish', 74: 'peel banana',
18 | 75: 'cut banana', 76: 'put bananas into blender', 77: 'mix ingredients', 78: 'remove cap',
19 | 79: 'put funnel', 80: 'pour oil', 81: 'remove funnel', 82: 'close cap', 83: 'pull out dipstick',
20 | 84: 'wipe off dipstick', 85: 'insert dipstick', 86: 'get things out', 87: 'start loose',
21 | 88: 'jack up', 89: 'unscrew wheel', 90: 'withdraw wheel', 91: 'put wheel', 92: 'screw wheel',
22 | 93: 'jack down', 94: 'tight wheel', 95: 'put things back', 96: 'add whipped cream',
23 | 97: 'add flour', 98: 'add butter', 99: 'put dough into form', 100: 'spread creme upon cake',
24 | 101: 'cut strawberries', 102: 'add strawberries to cake', 103: 'pour mixture into pan',
25 | 104: 'flip pancake', 105: 'take pancake from pan'}
--------------------------------------------------------------------------------
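The same dictionary is duplicated in plan/ and step/. A small lookup example that turns a predicted index sequence into readable steps (keys are 1-based; the index sequence below is illustrative):

from action_dictionary import action_dictionary

pred_plan = [86, 67, 88, 89, 90]
print([action_dictionary[a] for a in pred_plan])
# ['get things out', 'brake on', 'jack up', 'unscrew wheel', 'withdraw wheel']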
/plan/action_dictionary.py:
--------------------------------------------------------------------------------
1 | action_dictionary = {1: 'pour water', 2: 'pour juice', 3: 'pour jello powder', 4: 'pour alcohol',
2 | 5: 'stir mixture', 6: 'pour mixture into cup', 7: 'cut shelve', 8: 'assemble shelve',
3 | 9: 'sand shelve', 10: 'paint shelve', 11: 'attach shelve', 12: 'add onion', 13: 'add taco',
4 | 14: 'add lettuce', 15: 'add meat', 16: 'add tomato', 17: 'add cheese', 18: 'stir',
5 | 19: 'add tortilla', 20: 'season steak', 21: 'put steak on grill', 22: 'close lid',
6 | 23: 'open lid', 24: 'move steak on grill', 25: 'flip steak', 26: 'check temperature',
7 | 27: 'take steak from grill', 28: 'top steak', 29: 'cut steak', 30: 'taste steak', 31: 'add rice',
8 | 32: 'add ham', 33: 'add kimchi', 34: 'pour sesame oil', 35: 'pour egg', 36: 'add sugar',
9 | 37: 'whisk mixture', 38: 'put mixture into bag', 39: 'spread mixture', 40: 'put meringue into oven',
10 | 41: 'add coffee', 42: 'press coffee', 43: 'pour espresso', 44: 'steam milk', 45: 'pour milk',
11 | 46: 'cut cucumber', 47: 'cut onion', 48: 'add salt', 49: 'pour vinegar', 50: 'add spices',
12 | 51: 'put vegetables in water', 52: 'pack cucumbers in jar', 53: 'seal jar',
13 | 54: 'put jar in water', 55: 'cut lemon', 56: 'squeeze lemon', 57: 'pour lemon juice',
14 | 58: 'add ice', 59: 'pour lemonade into glass', 60: 'dip bread in mixture', 61: 'melt butter',
15 | 62: 'put bread in pan', 63: 'add vanilla extract', 64: 'flip bread', 65: 'remove bread from pan',
16 | 66: 'top toast', 67: 'brake on', 68: 'raise jack', 69: 'lower jack', 70: 'add chili powder',
17 | 71: 'add mustard seeds', 72: 'add curry leaves', 73: 'add fish', 74: 'peel banana',
18 | 75: 'cut banana', 76: 'put bananas into blender', 77: 'mix ingredients', 78: 'remove cap',
19 | 79: 'put funnel', 80: 'pour oil', 81: 'remove funnel', 82: 'close cap', 83: 'pull out dipstick',
20 | 84: 'wipe off dipstick', 85: 'insert dipstick', 86: 'get things out', 87: 'start loose',
21 | 88: 'jack up', 89: 'unscrew wheel', 90: 'withdraw wheel', 91: 'put wheel', 92: 'screw wheel',
22 | 93: 'jack down', 94: 'tight wheel', 95: 'put things back', 96: 'add whipped cream',
23 | 97: 'add flour', 98: 'add butter', 99: 'put dough into form', 100: 'spread creme upon cake',
24 | 101: 'cut strawberries', 102: 'add strawberries to cake', 103: 'pour mixture into pan',
25 | 104: 'flip pancake', 105: 'take pancake from pan'}
--------------------------------------------------------------------------------
/step/action_dictionary.py:
--------------------------------------------------------------------------------
1 | action_dictionary = {1: 'pour water', 2: 'pour juice', 3: 'pour jello powder', 4: 'pour alcohol',
2 | 5: 'stir mixture', 6: 'pour mixture into cup', 7: 'cut shelve', 8: 'assemble shelve',
3 | 9: 'sand shelve', 10: 'paint shelve', 11: 'attach shelve', 12: 'add onion', 13: 'add taco',
4 | 14: 'add lettuce', 15: 'add meat', 16: 'add tomato', 17: 'add cheese', 18: 'stir',
5 | 19: 'add tortilla', 20: 'season steak', 21: 'put steak on grill', 22: 'close lid',
6 | 23: 'open lid', 24: 'move steak on grill', 25: 'flip steak', 26: 'check temperature',
7 | 27: 'take steak from grill', 28: 'top steak', 29: 'cut steak', 30: 'taste steak', 31: 'add rice',
8 | 32: 'add ham', 33: 'add kimchi', 34: 'pour sesame oil', 35: 'pour egg', 36: 'add sugar',
9 | 37: 'whisk mixture', 38: 'put mixture into bag', 39: 'spread mixture', 40: 'put meringue into oven',
10 | 41: 'add coffee', 42: 'press coffee', 43: 'pour espresso', 44: 'steam milk', 45: 'pour milk',
11 | 46: 'cut cucumber', 47: 'cut onion', 48: 'add salt', 49: 'pour vinegar', 50: 'add spices',
12 | 51: 'put vegetables in water', 52: 'pack cucumbers in jar', 53: 'seal jar',
13 | 54: 'put jar in water', 55: 'cut lemon', 56: 'squeeze lemon', 57: 'pour lemon juice',
14 | 58: 'add ice', 59: 'pour lemonade into glass', 60: 'dip bread in mixture', 61: 'melt butter',
15 | 62: 'put bread in pan', 63: 'add vanilla extract', 64: 'flip bread', 65: 'remove bread from pan',
16 | 66: 'top toast', 67: 'brake on', 68: 'raise jack', 69: 'lower jack', 70: 'add chili powder',
17 | 71: 'add mustard seeds', 72: 'add curry leaves', 73: 'add fish', 74: 'peel banana',
18 | 75: 'cut banana', 76: 'put bananas into blender', 77: 'mix ingredients', 78: 'remove cap',
19 | 79: 'put funnel', 80: 'pour oil', 81: 'remove funnel', 82: 'close cap', 83: 'pull out dipstick',
20 | 84: 'wipe off dipstick', 85: 'insert dipstick', 86: 'get things out', 87: 'start loose',
21 | 88: 'jack up', 89: 'unscrew wheel', 90: 'withdraw wheel', 91: 'put wheel', 92: 'screw wheel',
22 | 93: 'jack down', 94: 'tight wheel', 95: 'put things back', 96: 'add whipped cream',
23 | 97: 'add flour', 98: 'add butter', 99: 'put dough into form', 100: 'spread creme upon cake',
24 | 101: 'cut strawberries', 102: 'add strawberries to cake', 103: 'pour mixture into pan',
25 | 104: 'flip pancake', 105: 'take pancake from pan'}
--------------------------------------------------------------------------------
/step/utils/eval.py:
--------------------------------------------------------------------------------
1 | from .accuracy import *
2 | from model.helpers import AverageMeter
3 |
4 |
5 | def validate(val_loader, model, args):
6 | model.eval()
7 | losses = AverageMeter()
8 | acc_top1 = AverageMeter()
9 | acc_top5 = AverageMeter()
10 | trajectory_success_rate_meter = AverageMeter()
11 | MIoU1_meter = AverageMeter()
12 | MIoU2_meter = AverageMeter()
13 |
14 | A0_acc = AverageMeter()
15 | AT_acc = AverageMeter()
16 |
17 | for i_batch, sample_batch in enumerate(val_loader):
18 | # compute output
19 | global_img_tensors = sample_batch[0].cuda().contiguous().float()
20 | video_label = sample_batch[1].cuda()
21 | batch_size_current, T = video_label.size()
22 | #task_class = sample_batch[2].view(-1).cuda()
23 | cond = {}
24 |
25 | with torch.no_grad():
26 | cond[0] = global_img_tensors[:, 0, :]
27 | cond[T - 1] = global_img_tensors[:, -1, :]
28 | '''
29 | task_onehot = torch.zeros((task_class.size(0), args.class_dim)) # [bs*T, ac_dim]
30 | ind = torch.arange(0, len(task_class))
31 | task_onehot[ind, task_class] = 1.
32 | task_onehot = task_onehot.cuda()
33 | temp = task_onehot.unsqueeze(1)
34 | task_class_ = temp.repeat(1, T, 1) # [bs, T, args.class_dim]
35 | cond['task'] = task_class_
36 | '''
37 | video_label_reshaped = video_label.view(-1)
38 |
39 | action_label_onehot = torch.zeros((video_label_reshaped.size(0), args.action_dim))
40 | ind = torch.arange(0, len(video_label_reshaped))
41 | action_label_onehot[ind, video_label_reshaped] = 1.
42 | action_label_onehot = action_label_onehot.reshape(batch_size_current, T, -1).cuda()
43 |
44 | x_start = torch.zeros((batch_size_current, T, args.action_dim + args.observation_dim))
45 | x_start[:, 0, args.action_dim:] = global_img_tensors[:, 0, :]
46 | x_start[:, -1, args.action_dim:] = global_img_tensors[:, -1, :]
47 | action_label_onehot[:,1:-1,:] = 0.
48 | x_start[:, :, :args.action_dim] = action_label_onehot
49 | #x_start[:, :, :args.class_dim] = task_class_
50 | output = model(cond)
51 | actions_pred = output.contiguous()
52 | loss = model.module.loss_fn(actions_pred, x_start.cuda())
53 |
54 | actions_pred = actions_pred[:, :, :args.action_dim].contiguous()
55 | actions_pred = actions_pred.view(-1, args.action_dim) # [bs*T, action_dim]
56 |
57 | (acc1, acc5), trajectory_success_rate, MIoU1, MIoU2, a0_acc, aT_acc = \
58 | accuracy(actions_pred.cpu(), video_label_reshaped.cpu(), topk=(1, 5), max_traj_len=args.horizon)
59 |
60 | losses.update(loss.item(), batch_size_current)
61 | acc_top1.update(acc1.item(), batch_size_current)
62 | acc_top5.update(acc5.item(), batch_size_current)
63 | trajectory_success_rate_meter.update(trajectory_success_rate.item(), batch_size_current)
64 | MIoU1_meter.update(MIoU1, batch_size_current)
65 | MIoU2_meter.update(MIoU2, batch_size_current)
66 | A0_acc.update(a0_acc, batch_size_current)
67 | AT_acc.update(aT_acc, batch_size_current)
68 |
69 | return torch.tensor(losses.avg), torch.tensor(acc_top1.avg), torch.tensor(acc_top5.avg), \
70 | torch.tensor(trajectory_success_rate_meter.avg), \
71 | torch.tensor(MIoU1_meter.avg), torch.tensor(MIoU2_meter.avg), \
72 | torch.tensor(A0_acc.avg), torch.tensor(AT_acc.avg)
73 |
--------------------------------------------------------------------------------
/plan/utils/eval.py:
--------------------------------------------------------------------------------
1 | from .accuracy import *
2 | from model.helpers import AverageMeter
3 | from .one_hot import PKGLabelOnehot
4 | from .one_hot import LLMLabelOnehot
5 |
6 |
7 | def validate(val_loader, model, args):
8 | model.eval()
9 | losses = AverageMeter()
10 | acc_top1 = AverageMeter()
11 | acc_top5 = AverageMeter()
12 | trajectory_success_rate_meter = AverageMeter()
13 | MIoU1_meter = AverageMeter()
14 | MIoU2_meter = AverageMeter()
15 |
16 | A0_acc = AverageMeter()
17 | AT_acc = AverageMeter()
18 | for i_batch, sample_batch in enumerate(val_loader):
19 | # compute output
20 | global_img_tensors = sample_batch[0].cuda().contiguous().float()
21 | video_label = sample_batch[1].cuda()
22 | batch_size_current, T = video_label.size()
23 | LLM_label = sample_batch[2].cuda()
24 |
25 | PKG_label = sample_batch[3].cuda()
26 | cond = {}
27 | llm_label_onehot = LLMLabelOnehot(batch_size_current, T, args.num_seq_LLM,[2/3, 1/3])
28 | pkg_label_onehot = PKGLabelOnehot(batch_size_current, T, args.num_seq_PKG,[2/3, 1/3])
29 |
30 |
31 | with torch.no_grad():
32 | cond[0] = global_img_tensors[:, 0, :]
33 | cond[T - 1] = global_img_tensors[:, -1, :]
34 |
35 | LLM_label_onehot = llm_label_onehot(LLM_label)
36 |
37 | PKG_label_onehot = pkg_label_onehot(PKG_label)
38 |
39 | cond['LLM_action_path'] = LLM_label_onehot
40 | cond['PKG_action_path'] = PKG_label_onehot
41 |
42 | video_label_reshaped = video_label.view(-1)
43 |
44 | action_label_onehot = torch.zeros((video_label_reshaped.size(0), args.action_dim))
45 | ind = torch.arange(0, len(video_label_reshaped))
46 | action_label_onehot[ind, video_label_reshaped] = 1.
47 | action_label_onehot = action_label_onehot.reshape(batch_size_current, T, -1).cuda()
48 |
49 | x_start = torch.zeros((batch_size_current, T, args.class_dim_graph + args.class_dim_llama + args.action_dim + args.observation_dim))
50 | x_start[:, 0, args.class_dim_graph + args.class_dim_llama + args.action_dim:] = global_img_tensors[:, 0, :]
51 | x_start[:, -1, args.class_dim_graph + args.class_dim_llama + args.action_dim:] = global_img_tensors[:, -1, :]
52 |
53 | x_start[:, :,args.class_dim_graph + args.class_dim_llama:args.class_dim_graph + args.class_dim_llama + args.action_dim] = action_label_onehot
54 | x_start[:, :, :args.class_dim_graph] = PKG_label_onehot
55 | x_start[:,:,args.class_dim_graph:args.class_dim_graph + args.class_dim_llama] = LLM_label_onehot
56 | output = model(cond)
57 | actions_pred = output.contiguous()
58 | loss = model.module.loss_fn(actions_pred, x_start.cuda())
59 |
60 | actions_pred = actions_pred[:, :, args.class_dim_graph + args.class_dim_llama:args.class_dim_graph + args.class_dim_llama + args.action_dim].contiguous()
61 | actions_pred = actions_pred.view(-1, args.action_dim) # [bs*T, action_dim]
62 |
63 | (acc1, acc5), trajectory_success_rate, MIoU1, MIoU2, a0_acc, aT_acc = \
64 | accuracy(actions_pred.cpu(), video_label_reshaped.cpu(), topk=(1, 5), max_traj_len=args.horizon)
65 |
66 | losses.update(loss.item(), batch_size_current)
67 | acc_top1.update(acc1.item(), batch_size_current)
68 | acc_top5.update(acc5.item(), batch_size_current)
69 | trajectory_success_rate_meter.update(trajectory_success_rate.item(), batch_size_current)
70 | MIoU1_meter.update(MIoU1, batch_size_current)
71 | MIoU2_meter.update(MIoU2, batch_size_current)
72 | A0_acc.update(a0_acc, batch_size_current)
73 | AT_acc.update(aT_acc, batch_size_current)
74 |
75 | return torch.tensor(losses.avg), torch.tensor(acc_top1.avg), torch.tensor(acc_top5.avg), \
76 | torch.tensor(trajectory_success_rate_meter.avg), \
77 | torch.tensor(MIoU1_meter.avg), torch.tensor(MIoU2_meter.avg), \
78 | torch.tensor(A0_acc.avg), torch.tensor(AT_acc.avg)
79 |
--------------------------------------------------------------------------------
/step/model/temporal.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import einops
4 | from einops.layers.torch import Rearrange
5 |
6 | from .helpers import (
7 | SinusoidalPosEmb,
8 | Downsample1d,
9 | Upsample1d,
10 | Conv1dBlock,
11 | )
12 |
13 |
14 | class ResidualTemporalBlock(nn.Module):
15 |
16 | def __init__(self, inp_channels, out_channels, embed_dim, kernel_size=3):
17 | super().__init__()
18 |
19 | self.blocks = nn.ModuleList([
20 | Conv1dBlock(inp_channels, out_channels, kernel_size),
21 | Conv1dBlock(out_channels, out_channels, kernel_size, if_zero=True)
22 | ])
23 | self.time_mlp = nn.Sequential( # should be removed for Noise and Deterministic Baselines
24 | nn.Mish(),
25 | nn.Linear(embed_dim, out_channels),
26 | Rearrange('batch t -> batch t 1'),
27 | )
28 | self.residual_conv = nn.Conv1d(inp_channels, out_channels, 1) \
29 | if inp_channels != out_channels else nn.Identity()
30 |
31 | def forward(self, x, t):
32 | out = self.blocks[0](x) + self.time_mlp(t) # for diffusion
33 | # out = self.blocks[0](x) # for Noise and Deterministic Baselines
34 | out = self.blocks[1](out)
35 | return out + self.residual_conv(x)
36 |
37 |
38 |
39 | class TemporalUnet(nn.Module):
40 | def __init__(
41 | self,
42 | transition_dim,
43 | dim=32,
44 | dim_mults=(1, 2, 4, 8),
45 | ):
46 | super().__init__()
47 |
48 | dims = [transition_dim, *map(lambda m: dim * m, dim_mults)]
49 | in_out = list(zip(dims[:-1], dims[1:]))
50 |
51 | time_dim = dim
52 | self.time_mlp = nn.Sequential( # should be removed for Noise and Deterministic Baselines
53 | SinusoidalPosEmb(dim),
54 | nn.Linear(dim, dim * 4),
55 | nn.Mish(),
56 | nn.Linear(dim * 4, dim),
57 | )
58 |
59 | self.downs = nn.ModuleList([])
60 | self.ups = nn.ModuleList([])
61 | num_resolutions = len(in_out)
62 |
63 | # print(in_out)
64 | for ind, (dim_in, dim_out) in enumerate(in_out):
65 | is_last = ind >= (num_resolutions - 1)
66 |
67 | self.downs.append(nn.ModuleList([
68 | ResidualTemporalBlock(dim_in, dim_out, embed_dim=time_dim),
69 | ResidualTemporalBlock(dim_out, dim_out, embed_dim=time_dim),
70 | Downsample1d(dim_out) if not is_last else nn.Identity()
71 | ]))
72 |
73 | mid_dim = dims[-1]
74 | self.mid_block1 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=time_dim)
75 | self.mid_block2 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=time_dim)
76 |
77 | for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
78 | is_last = ind >= (num_resolutions - 1)
79 |
80 | self.ups.append(nn.ModuleList([
81 | ResidualTemporalBlock(dim_out * 2, dim_in, embed_dim=time_dim),
82 | ResidualTemporalBlock(dim_in, dim_in, embed_dim=time_dim),
83 | Upsample1d(dim_in) if not is_last else nn.Identity()
84 | ]))
85 |
86 | self.final_conv = nn.Sequential(
87 | Conv1dBlock(dim, dim, kernel_size=3, if_zero=True),
88 | nn.Conv1d(dim, transition_dim, 1),
89 | )
90 |
91 | def forward(self, x, time):
92 | x = einops.rearrange(x, 'b h t -> b t h')
93 |
94 | # t = None # for Noise and Deterministic Baselines
95 | t = self.time_mlp(time) # for diffusion
96 | h = []
97 |
98 | for resnet, resnet2, downsample in self.downs:
99 | x = resnet(x, t)
100 | x = resnet2(x, t)
101 | h.append(x)
102 | x = downsample(x)
103 |
104 |
105 | x = self.mid_block1(x, t)
106 | x = self.mid_block2(x, t)
107 |
108 | for resnet, resnet2, upsample in self.ups:
109 | x = torch.cat((x, h.pop()), dim=1)
110 | x = resnet(x, t)
111 | x = resnet2(x, t)
112 | x = upsample(x)
113 |
114 | x = self.final_conv(x)
115 | x = einops.rearrange(x, 'b t h -> b h t')
116 | return x
117 |
--------------------------------------------------------------------------------
/plan/model/temporal.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import einops
4 | from einops.layers.torch import Rearrange
5 |
6 | from .helpers import (
7 | SinusoidalPosEmb,
8 | Downsample1d,
9 | Upsample1d,
10 | Conv1dBlock,
11 | )
12 |
13 |
14 | class ResidualTemporalBlock(nn.Module):
15 |
16 | def __init__(self, inp_channels, out_channels, embed_dim, kernel_size=3):
17 | super().__init__()
18 |
19 | self.blocks = nn.ModuleList([
20 | Conv1dBlock(inp_channels, out_channels, kernel_size),
21 | Conv1dBlock(out_channels, out_channels, kernel_size, if_zero=True)
22 | ])
23 | self.time_mlp = nn.Sequential( # should be removed for Noise and Deterministic Baselines
24 | nn.Mish(),
25 | nn.Linear(embed_dim, out_channels),
26 | Rearrange('batch t -> batch t 1'),
27 | )
28 | self.residual_conv = nn.Conv1d(inp_channels, out_channels, 1) \
29 | if inp_channels != out_channels else nn.Identity()
30 |
31 | def forward(self, x, t):
32 | out = self.blocks[0](x) + self.time_mlp(t) # for diffusion
33 | # out = self.blocks[0](x) # for Noise and Deterministic Baselines
34 | out = self.blocks[1](out)
35 | return out + self.residual_conv(x)
36 |
37 |
38 | class TemporalUnet(nn.Module):
39 | def __init__(
40 | self,
41 | transition_dim,
42 | dim=32,
43 | dim_mults=(1, 2, 4, 8),
44 | ):
45 | super().__init__()
46 |
47 | dims = [transition_dim, *map(lambda m: dim * m, dim_mults)]
48 | in_out = list(zip(dims[:-1], dims[1:]))
49 |
50 | time_dim = dim
51 | self.time_mlp = nn.Sequential( # should be removed for Noise and Deterministic Baselines
52 | SinusoidalPosEmb(dim),
53 | nn.Linear(dim, dim * 4),
54 | nn.Mish(),
55 | nn.Linear(dim * 4, dim),
56 | )
57 |
58 | self.downs = nn.ModuleList([])
59 | self.ups = nn.ModuleList([])
60 | num_resolutions = len(in_out)
61 |
62 | # print(in_out)
63 | for ind, (dim_in, dim_out) in enumerate(in_out):
64 | is_last = ind >= (num_resolutions - 1)
65 |
66 | self.downs.append(nn.ModuleList([
67 | ResidualTemporalBlock(dim_in, dim_out, embed_dim=time_dim),
68 | ResidualTemporalBlock(dim_out, dim_out, embed_dim=time_dim),
69 | Downsample1d(dim_out) if not is_last else nn.Identity()
70 | ]))
71 |
72 | mid_dim = dims[-1]
73 | self.mid_block1 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=time_dim)
74 | self.mid_block2 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=time_dim)
75 |
76 | for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
77 | is_last = ind >= (num_resolutions - 1)
78 |
79 | self.ups.append(nn.ModuleList([
80 | ResidualTemporalBlock(dim_out * 2, dim_in, embed_dim=time_dim),
81 | ResidualTemporalBlock(dim_in, dim_in, embed_dim=time_dim),
82 | Upsample1d(dim_in) if not is_last else nn.Identity()
83 | ]))
84 |
85 | self.final_conv = nn.Sequential(
86 | Conv1dBlock(dim, dim, kernel_size=3, if_zero=True),
87 | nn.Conv1d(dim, transition_dim, 1),
88 | )
89 |
90 | def forward(self, x, time):
91 | x = einops.rearrange(x, 'b h t -> b t h')
92 |
93 | # t = None # for Noise and Deterministic Baselines
94 | t = self.time_mlp(time) # for diffusion
95 | h = []
96 |
97 | for resnet, resnet2, downsample in self.downs:
98 | x = resnet(x, t)
99 |
100 | x = resnet2(x, t)
101 |
102 | h.append(x)
103 |
104 | x = downsample(x)
105 |
106 |
107 | x = self.mid_block1(x, t)
108 | x = self.mid_block2(x, t)
109 |
110 | for resnet, resnet2, upsample in self.ups:
111 | x = torch.cat((x, h.pop()), dim=1)
112 | x = resnet(x, t)
113 | x = resnet2(x, t)
114 | x = upsample(x)
115 |
116 | x = self.final_conv(x)
117 | x = einops.rearrange(x, 'b t h -> b h t')
118 | return x
119 |
--------------------------------------------------------------------------------
/dataset/NIV/test30.json:
--------------------------------------------------------------------------------
1 | [{"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/changing_tire_0001.npy", "task_id": 0}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/changing_tire_0004.npy", "task_id": 0}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/changing_tire_0005.npy", "task_id": 0}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/changing_tire_0007.npy", "task_id": 0}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/changing_tire_0012.npy", "task_id": 0}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/changing_tire_0020.npy", "task_id": 0}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/changing_tire_0021.npy", "task_id": 0}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/changing_tire_0022.npy", "task_id": 0}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/changing_tire_0025.npy", "task_id": 0}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/changing_tire_0030.npy", "task_id": 0}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/coffee_0001.npy", "task_id": 1}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/coffee_0007.npy", "task_id": 1}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/coffee_0014.npy", "task_id": 1}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/coffee_0018.npy", "task_id": 1}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/coffee_0019.npy", "task_id": 1}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/coffee_0023.npy", "task_id": 1}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/coffee_0025.npy", "task_id": 1}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/coffee_0027.npy", "task_id": 1}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/cpr_0002.npy", "task_id": 2}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/cpr_0006.npy", "task_id": 2}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/cpr_0012.npy", "task_id": 2}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/cpr_0013.npy", "task_id": 2}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/cpr_0016.npy", "task_id": 2}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/cpr_0017.npy", "task_id": 2}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/cpr_0019.npy", "task_id": 2}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/cpr_0021.npy", "task_id": 2}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/cpr_0023.npy", "task_id": 2}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/cpr_0026.npy", "task_id": 2}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/jump_car_0001.npy", "task_id": 3}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/jump_car_0009.npy", "task_id": 3}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/jump_car_0011.npy", "task_id": 3}, {"feature": 
"/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/jump_car_0017.npy", "task_id": 3}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/jump_car_0018.npy", "task_id": 3}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/jump_car_0020.npy", "task_id": 3}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/jump_car_0025.npy", "task_id": 3}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/jump_car_0028.npy", "task_id": 3}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/jump_car_0029.npy", "task_id": 3}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/repot_0004.npy", "task_id": 4}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/repot_0015.npy", "task_id": 4}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/repot_0018.npy", "task_id": 4}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/repot_0025.npy", "task_id": 4}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/repot_0029.npy", "task_id": 4}]
--------------------------------------------------------------------------------
/step/utils/training.py:
--------------------------------------------------------------------------------
1 | import copy
2 | from model.helpers import AverageMeter
3 | from .accuracy import *
4 |
5 |
6 | def cycle(dl):
7 | while True:
8 | for data in dl:
9 | yield data
10 |
11 |
12 | class EMA():
13 | """
14 | exponential moving average
15 | """
16 |
17 | def __init__(self, beta):
18 | super().__init__()
19 | self.beta = beta
20 |
21 | def update_model_average(self, ma_model, current_model):
22 | for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
23 | old_weight, up_weight = ma_params.data, current_params.data
24 | ma_params.data = self.update_average(old_weight, up_weight)
25 |
26 | def update_average(self, old, new):
27 | if old is None:
28 | return new
29 | return old * self.beta + (1 - self.beta) * new
30 |
31 |
32 | class Trainer(object):
33 | def __init__(
34 | self,
35 | diffusion_model,
36 | datasetloader,
37 | ema_decay=0.995,
38 | train_lr=1e-5,
39 | gradient_accumulate_every=1,
40 | step_start_ema=400,
41 | update_ema_every=10,
42 | log_freq=100,
43 | ):
44 | super().__init__()
45 | self.model = diffusion_model
46 | self.ema = EMA(ema_decay)
47 | self.ema_model = copy.deepcopy(self.model)
48 | self.update_ema_every = update_ema_every
49 |
50 | self.step_start_ema = step_start_ema
51 | self.log_freq = log_freq
52 | self.gradient_accumulate_every = gradient_accumulate_every
53 |
54 | self.dataloader = cycle(datasetloader)
55 | self.optimizer = torch.optim.AdamW(diffusion_model.parameters(), lr=train_lr, weight_decay=0.0)
56 | # self.optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, diffusion_model.parameters()), lr=train_lr, weight_decay=0.0)
57 |
58 | self.reset_parameters()
59 | self.step = 0
60 |
61 | def reset_parameters(self):
62 | self.ema_model.load_state_dict(self.model.state_dict())
63 |
64 | def step_ema(self):
65 | if self.step < self.step_start_ema:
66 | self.reset_parameters()
67 | return
68 | self.ema.update_model_average(self.ema_model, self.model)
69 |
70 | # -----------------------------------------------------------------------------#
71 | # ------------------------------------ api ------------------------------------#
72 | # -----------------------------------------------------------------------------#
73 |
74 | def train(self, n_train_steps, if_calculate_acc, args, scheduler):
75 | self.model.train()
76 | self.ema_model.train()
77 | losses = AverageMeter()
78 | self.optimizer.zero_grad()
79 |
80 | for step in range(n_train_steps):
81 | for i in range(self.gradient_accumulate_every):
82 | batch = next(self.dataloader)
83 | bs, T = batch[1].shape  # batch[1] holds the action labels: [bs, T]
84 | #print('shape: ', batch[1].shape , batch[0].shape)
85 | global_img_tensors = batch[0].cuda().contiguous().float()
86 | img_tensors = torch.zeros((bs, T, args.action_dim + args.observation_dim))
87 | img_tensors[:, 0, args.action_dim:] = global_img_tensors[:, 0, :]
88 | img_tensors[:, -1, args.action_dim:] = global_img_tensors[:, -1, :]
89 | img_tensors = img_tensors.cuda()
90 |
91 | video_label = batch[1].view(-1).cuda() # [bs*T]
92 | #task_class = batch[2].view(-1).cuda() # [bs]
93 |
94 | action_label_onehot = torch.zeros((video_label.size(0), self.model.module.action_dim))
95 | # [bs*T, ac_dim]
96 | ind = torch.arange(0, len(video_label))
97 | action_label_onehot[ind, video_label] = 1.
98 | action_label_onehot = action_label_onehot.reshape(bs, T, -1).cuda()
99 | action_label_onehot[:,1:-1,:] = 0.
100 | img_tensors[:, :, :args.action_dim] = action_label_onehot
101 |
102 | '''
103 | task_onehot = torch.zeros((task_class.size(0), args.class_dim))
104 | # [bs*T, ac_dim]
105 | ind = torch.arange(0, len(task_class))
106 | task_onehot[ind, task_class] = 1.
107 | task_onehot = task_onehot.cuda()
108 | temp = task_onehot.unsqueeze(1)
109 | task_class_ = temp.repeat(1, T, 1) # [bs, T, args.class_dim]
110 | img_tensors[:, :, :args.class_dim] = task_class_
111 | '''
112 | cond = {0: global_img_tensors[:, 0, :].float(), T - 1: global_img_tensors[:, -1, :].float()}
113 | x = img_tensors.float()
114 | loss = self.model.module.loss(x, cond)
115 | loss = loss / self.gradient_accumulate_every
116 | loss.backward()
117 | losses.update(loss.item(), bs)
118 |
119 | self.optimizer.step()
120 | self.optimizer.zero_grad()
121 | scheduler.step()
122 |
123 | if self.step % self.update_ema_every == 0:
124 | self.step_ema()
125 | self.step += 1
126 |
127 | if if_calculate_acc:
128 | with torch.no_grad():
129 | output = self.ema_model(cond)
130 | actions_pred = output[:, :, :self.model.module.action_dim]\
131 | .contiguous().view(-1, self.model.module.action_dim) # [bs*T, action_dim]
132 |
133 | (acc1, acc5), trajectory_success_rate, MIoU1, MIoU2, a0_acc, aT_acc = \
134 | accuracy(actions_pred.cpu(), video_label.cpu(), topk=(1, 5),
135 | max_traj_len=self.model.module.horizon)
136 |
137 | return torch.tensor(losses.avg), acc1, acc5, torch.tensor(trajectory_success_rate), \
138 | torch.tensor(MIoU1), torch.tensor(MIoU2), a0_acc, aT_acc
139 |
140 | else:
141 | return torch.tensor(losses.avg)
142 |
--------------------------------------------------------------------------------
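A shape-only sketch of the tensor that Trainer.train() above assembles for the step model: each time step concatenates an action one-hot with the observation features, and only the first and last steps carry non-zero entries because intermediate action one-hots are zeroed out (the sizes below are illustrative placeholders, not the real args values).

import torch

bs, T, action_dim, observation_dim = 2, 3, 105, 1536         # illustrative sizes
x = torch.zeros(bs, T, action_dim + observation_dim)
x[:, 0, action_dim:] = torch.randn(bs, observation_dim)       # start observation
x[:, -1, action_dim:] = torch.randn(bs, observation_dim)      # goal observation
x[:, 0, 7] = 1.0                                              # ground-truth first action (index 7)
x[:, -1, 11] = 1.0                                            # ground-truth last action (index 11)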
/plan/utils/training.py:
--------------------------------------------------------------------------------
1 | import copy
2 | from model.helpers import AverageMeter
3 | from .accuracy import *
4 | from .one_hot import PKGLabelOnehot
5 | from .one_hot import LLMLabelOnehot
6 |
7 |
8 | def cycle(dl):
9 | while True:
10 | for data in dl:
11 | yield data
12 |
13 |
14 | class EMA():
15 | """
16 | exponential moving average
17 | """
18 |
19 | def __init__(self, beta):
20 | super().__init__()
21 | self.beta = beta
22 |
23 | def update_model_average(self, ma_model, current_model):
24 | for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
25 | old_weight, up_weight = ma_params.data, current_params.data
26 | ma_params.data = self.update_average(old_weight, up_weight)
27 |
28 | def update_average(self, old, new):
29 | if old is None:
30 | return new
31 | return old * self.beta + (1 - self.beta) * new
32 |
33 |
34 | class Trainer(object):
35 | def __init__(
36 | self,
37 | diffusion_model,
38 | datasetloader,
39 | ema_decay=0.995,
40 | train_lr=1e-5,
41 | gradient_accumulate_every=1,
42 | step_start_ema=400,
43 | update_ema_every=10,
44 | log_freq=100,
45 | ):
46 | super().__init__()
47 | self.model = diffusion_model
48 | self.ema = EMA(ema_decay)
49 | self.ema_model = copy.deepcopy(self.model)
50 | self.update_ema_every = update_ema_every
51 |
52 | self.step_start_ema = step_start_ema
53 | self.log_freq = log_freq
54 | self.gradient_accumulate_every = gradient_accumulate_every
55 |
56 | self.dataloader = cycle(datasetloader)
57 | self.optimizer = torch.optim.AdamW(diffusion_model.parameters(), lr=train_lr, weight_decay=0.0)
58 | # self.optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, diffusion_model.parameters()), lr=train_lr, weight_decay=0.0)
59 |
60 | self.reset_parameters()
61 | self.step = 0
62 |
63 | def reset_parameters(self):
64 | self.ema_model.load_state_dict(self.model.state_dict())
65 |
66 | def step_ema(self):
67 | if self.step < self.step_start_ema:
68 | self.reset_parameters()
69 | return
70 | self.ema.update_model_average(self.ema_model, self.model)
71 |
72 | # -----------------------------------------------------------------------------#
73 | # ------------------------------------ api ------------------------------------#
74 | # -----------------------------------------------------------------------------#
75 |
76 | def train(self, n_train_steps, if_calculate_acc, args, scheduler):
77 | self.model.train()
78 | self.ema_model.train()
79 | losses = AverageMeter()
80 | self.optimizer.zero_grad()
81 |
82 | for step in range(n_train_steps):
83 | for i in range(self.gradient_accumulate_every):
84 | batch = next(self.dataloader)
85 | bs, T = batch[1].shape  # batch[1] holds the action labels: [bs, T]
86 | global_img_tensors = batch[0].cuda().contiguous().float()
87 | img_tensors = torch.zeros((bs, T, args.class_dim_graph + args.class_dim_llama + args.action_dim + args.observation_dim))
88 | img_tensors[:, 0, args.class_dim_graph + args.class_dim_llama+args.action_dim:] = global_img_tensors[:, 0, :]
89 | img_tensors[:, -1, args.class_dim_graph + args.class_dim_llama+args.action_dim:] = global_img_tensors[:, -1, :]
90 | img_tensors = img_tensors.cuda()
91 | video_label = batch[1].view(-1).cuda() # [bs*T]
92 | LLM_label = batch[2].cuda() # [bs]
93 | PKG_label = batch[3].cuda()
94 | llm_label_onehot = LLMLabelOnehot(bs, T, args.num_seq_LLM,[2/3, 1/3])
95 | pkg_label_onehot = PKGLabelOnehot(bs, T, args.num_seq_PKG,[2/3, 1/3])
96 |
97 |
98 | action_label_onehot = torch.zeros((video_label.size(0), self.model.module.action_dim))
99 | # [bs*T, ac_dim]
100 | ind = torch.arange(0, len(video_label))
101 | action_label_onehot[ind, video_label] = 1.
102 | action_label_onehot = action_label_onehot.reshape(bs, T, -1).cuda()
103 | img_tensors[:, :, args.class_dim_graph + args.class_dim_llama:args.class_dim_graph + args.class_dim_llama+args.action_dim] = action_label_onehot
104 |
105 |
106 | PKG_label_onehot = pkg_label_onehot(PKG_label)
107 | LLM_label_onehot = llm_label_onehot(LLM_label)
108 |
109 |
110 | img_tensors[:, :, args.class_dim_graph : args.class_dim_graph + args.class_dim_llama] = LLM_label_onehot
111 | img_tensors[:, :, : args.class_dim_graph] = PKG_label_onehot
112 |
113 | cond = {0: global_img_tensors[:, 0, :].float(), T - 1: global_img_tensors[:, -1, :].float(),'LLM_action_path': LLM_label_onehot , 'PKG_action_path': PKG_label_onehot}
114 | x = img_tensors.float()
115 |
116 | loss = self.model.module.loss(x, cond)
117 | loss = loss / self.gradient_accumulate_every
118 | loss.backward()
119 | losses.update(loss.item(), bs)
120 |
121 | self.optimizer.step()
122 | self.optimizer.zero_grad()
123 | scheduler.step()
124 |
125 | if self.step % self.update_ema_every == 0:
126 | self.step_ema()
127 | self.step += 1
128 |
129 | if if_calculate_acc:
130 | with torch.no_grad():
131 | output = self.ema_model(cond)
132 | actions_pred = output[:, :, args.class_dim_graph + args.class_dim_llama:args.class_dim_graph + args.class_dim_llama+self.model.module.action_dim]\
133 | .contiguous().view(-1, self.model.module.action_dim) # [bs*T, action_dim]
134 |
135 | (acc1, acc5), trajectory_success_rate, MIoU1, MIoU2, a0_acc, aT_acc = \
136 | accuracy(actions_pred.cpu(), video_label.cpu(), topk=(1, 5),
137 | max_traj_len=self.model.module.horizon)
138 |
139 | return torch.tensor(losses.avg), acc1, acc5, torch.tensor(trajectory_success_rate), \
140 | torch.tensor(MIoU1), torch.tensor(MIoU2), a0_acc, aT_acc
141 |
142 | else:
143 | return torch.tensor(losses.avg)
144 |
--------------------------------------------------------------------------------
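For the plan model the transition vector is wider; a sketch of the slice boundaries used repeatedly in Trainer.train() and validate() (the dimension values are placeholders for the real args):

class_dim_graph, class_dim_llama, action_dim, observation_dim = 105, 105, 105, 1536

pkg_slice = slice(0, class_dim_graph)                                   # PKG path one-hot
llm_slice = slice(pkg_slice.stop, pkg_slice.stop + class_dim_llama)     # LLM path one-hot
act_slice = slice(llm_slice.stop, llm_slice.stop + action_dim)          # action one-hot
obs_slice = slice(act_slice.stop, act_slice.stop + observation_dim)     # visual features (start/goal steps only)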
/step/dataloader/data_load.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import torch
4 | from torch.utils.data import Dataset
5 | import json
6 | from collections import namedtuple
7 |
8 |
9 | Batch = namedtuple('Batch', 'Observations Actions')
10 |
11 |
12 | class PlanningDataset(Dataset):
13 | """
14 | load video and action features from dataset
15 | """
16 |
17 | def __init__(self,
18 | root,
19 | args=None,
20 | is_val=False,
21 | model=None,
22 | ):
23 | self.is_val = is_val
24 | self.data_root = root
25 | self.args = args
26 | self.max_traj_len = args.horizon
27 | self.vid_names = None
28 | self.frame_cnts = None
29 | self.images = None
30 | self.last_vid = ''
31 |
32 | if args.dataset == 'crosstask':
33 | if is_val:
34 | cross_task_data_name = args.json_path_val
35 | print('cross task data name val', cross_task_data_name)
36 |
37 | else:
38 | cross_task_data_name = args.json_path_train
39 | print('cross task data name train', cross_task_data_name)
40 |
41 |
42 | if os.path.exists(cross_task_data_name):
43 | with open(cross_task_data_name, 'r') as f:
44 | self.json_data = json.load(f)
45 | print('Loaded {}'.format(cross_task_data_name))
46 | else:
47 | raise FileNotFoundError(cross_task_data_name)
48 | elif args.dataset == 'coin':
49 | if is_val:
50 | coin_data_name = args.json_path_val
51 |
52 | else:
53 | coin_data_name = args.json_path_train
54 |
55 | if os.path.exists(coin_data_name):
56 | with open(coin_data_name, 'r') as f:
57 | self.json_data = json.load(f)
58 | print('Loaded {}'.format(coin_data_name))
59 | else:
60 | raise FileNotFoundError(coin_data_name)
61 | elif args.dataset == 'NIV':
62 | if is_val:
63 | niv_data_name = args.json_path_val
64 |
65 | else:
66 | niv_data_name = args.json_path_train
67 |
68 | if os.path.exists(niv_data_name):
69 | with open(niv_data_name, 'r') as f:
70 | self.json_data = json.load(f)
71 | print('Loaded {}'.format(niv_data_name))
72 | else:
73 | raise FileNotFoundError(niv_data_name)
74 | else:
75 | raise NotImplementedError(
76 | 'Dataset {} is not implemented'.format(args.dataset))
77 |
78 | self.model = model
79 | self.prepare_data()
80 | self.M = 3
81 |
82 | def prepare_data(self):
83 | vid_names = []
84 | frame_cnts = []
85 | for listdata in self.json_data:
86 | vid_names.append(listdata['id'])
87 | frame_cnts.append(listdata['instruction_len'])
88 | self.vid_names = vid_names
89 | self.frame_cnts = frame_cnts
90 | print('vid name list length', len(vid_names))
91 |
92 |
93 | def curate_dataset(self, images, legal_range, M=2):
94 | images_list = []
95 | labels_onehot_list = []
96 | idx_list = []
97 | for start_idx, end_idx, action_label in legal_range:
98 | idx = start_idx
99 | idx_list.append(idx)
100 | image_start_idx = max(0, idx)
101 |
102 | if image_start_idx + M <= len(images):
103 | #image_start = images[image_start_idx: image_start_idx + M]
104 |
105 | if image_start_idx == 0:
106 | image_start = images[image_start_idx: image_start_idx + M]
107 | else:
108 | image_start = images[image_start_idx - 1: image_start_idx + M - 1]  # Modified to load data similar to other papers
109 |
110 | else:
111 | image_start = images[len(images) - M: len(images)]
112 | image_start_cat = image_start[0]
113 | for w in range(len(image_start) - 1):
114 | image_start_cat = np.concatenate((image_start_cat, image_start[w + 1]), axis=0)
115 | images_list.append(image_start_cat)
116 | labels_onehot_list.append(action_label)
117 |
118 | end_idx = max(2, end_idx)
119 | #image_end = images[end_idx - 2:end_idx + M - 2]
120 |
121 | if end_idx >= len(images)-1:
122 |
123 | image_end = images[end_idx - 2:end_idx + M - 2]
124 | else:
125 |
126 | image_end = images[end_idx - 1:end_idx + M - 1]  # Modified to load data similar to other papers
127 |
128 | image_end_cat = image_end[0]
129 | for w in range(len(image_end) - 1):
130 | image_end_cat = np.concatenate((image_end_cat, image_end[w + 1]), axis=0)
131 | images_list.append(image_end_cat)
132 |
133 | return images_list, labels_onehot_list, idx_list
134 |
135 | def sample_single(self, index):
136 | folder_id = self.vid_names[index]
137 | if self.args.dataset == 'crosstask':
138 | if folder_id['vid'] != self.last_vid:
139 | images_ = np.load(folder_id['feature'], allow_pickle=True)
140 | self.images = images_['frames_features']
141 | self.last_vid = folder_id['vid']
142 | else:
143 | images_ = np.load(folder_id['feature'], allow_pickle=True)
144 | self.images = images_['frames_features']
145 | images, labels_matrix, idx_list = self.curate_dataset(
146 | self.images, folder_id['legal_range'], M=self.M)
147 |
148 | shapes = [arr.shape for arr in images]
149 |
150 | frames = torch.tensor(np.array(images))
151 | labels_tensor = torch.tensor(labels_matrix, dtype=torch.long)
152 |
153 | return frames, labels_tensor
154 |
155 | def __getitem__(self, index):
156 | frames, labels = self.sample_single(index)
157 | return Batch(frames, labels)
158 |
159 | def __len__(self):
160 | return len(self.json_data)
--------------------------------------------------------------------------------
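The loader above expects each record of the step-model train/test JSON list to look roughly like the following. Field names are taken from the code (`id`, `instruction_len`, and inside `id`: `vid`, `feature`, `legal_range`); the concrete values are hypothetical.

```python
# Hypothetical example of one record in --json_path_train / --json_path_val for the step model.
example_record = {
    "id": {
        "vid": "video_0001",                             # used (for CrossTask) to cache features between samples
        "feature": "/path/to/video_0001_features.npy",   # loaded via np.load(...)["frames_features"]
        "legal_range": [                                 # one (start_idx, end_idx, action_label) per step
            [10, 25, 3],
            [26, 40, 17],
        ],
    },
    "instruction_len": 2,                                # number of steps; stored in frame_cnts by prepare_data()
}
```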
/plan/dataloader/data_load.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import torch
4 | from torch.utils.data import Dataset
5 | import json
6 | from collections import namedtuple
7 |
8 |
9 | Batch = namedtuple('Batch', 'Observations Actions LLM Knowledge')
10 |
11 |
12 | class PlanningDataset(Dataset):
13 | """
14 | load video and action features from dataset
15 | """
16 |
17 | def __init__(self,
18 | root,
19 | args=None,
20 | is_val=False,
21 | model=None,
22 | ):
23 | self.is_val = is_val
24 | self.data_root = root
25 | self.args = args
26 | self.max_traj_len = args.horizon
27 | self.vid_names = None
28 | self.frame_cnts = None
29 | self.images = None
30 | self.last_vid = ''
31 |
32 | if args.dataset == 'crosstask':
33 | if is_val:
34 | cross_task_data_name = args.json_path_val
35 | print('cross task data name val', cross_task_data_name)
36 |
37 | else:
38 | cross_task_data_name = args.json_path_train
39 | print('cross task data name train', cross_task_data_name)
40 |
41 | if os.path.exists(cross_task_data_name):
42 | with open(cross_task_data_name, 'r') as f:
43 | self.json_data = json.load(f)
44 | print('Loaded {}'.format(cross_task_data_name))
45 | else:
46 | raise FileNotFoundError(cross_task_data_name)
47 | elif args.dataset == 'coin':
48 | if is_val:
49 | coin_data_name = args.json_path_val
50 |
51 | else:
52 | coin_data_name = args.json_path_train
53 |
54 | if os.path.exists(coin_data_name):
55 | with open(coin_data_name, 'r') as f:
56 | self.json_data = json.load(f)
57 | print('Loaded {}'.format(coin_data_name))
58 | else:
59 | raise FileNotFoundError(coin_data_name)
60 | elif args.dataset == 'NIV':
61 | if is_val:
62 | niv_data_name = args.json_path_val
63 |
64 | else:
65 | niv_data_name = args.json_path_train
66 |
67 | if os.path.exists(niv_data_name):
68 | with open(niv_data_name, 'r') as f:
69 | self.json_data = json.load(f)
70 | print('Loaded {}'.format(niv_data_name))
71 | else:
72 | raise FileNotFoundError(niv_data_name)
73 | else:
74 | raise NotImplementedError(
75 | 'Dataset {} is not implemented'.format(args.dataset))
76 |
77 | self.model = model
78 | self.prepare_data()
79 | self.M = 3
80 |
81 | def prepare_data(self):
82 | vid_names = []
83 | frame_cnts = []
84 | for listdata in self.json_data:
85 | vid_names.append(listdata['id'])
86 | frame_cnts.append(listdata['instruction_len'])
87 | self.vid_names = vid_names
88 | self.frame_cnts = frame_cnts
89 | print('vid name list length', len(vid_names))
90 |
91 |
92 | def curate_dataset(self, images, legal_range, M=2):
93 | images_list = []
94 | labels_onehot_list = []
95 | idx_list = []
96 | for start_idx, end_idx, action_label in legal_range:
97 | idx = start_idx
98 | idx_list.append(idx)
99 | image_start_idx = max(0, idx)
100 | if image_start_idx + M <= len(images):
101 | #image_start = images[image_start_idx: image_start_idx + M]
102 |
103 | if image_start_idx == 0:
104 | image_start = images[image_start_idx: image_start_idx + M]
105 | else:
106 | image_start = images[image_start_idx - 1: image_start_idx + M - 1]  # Modified to load data similar to other papers
107 |
108 | else:
109 | image_start = images[len(images) - M: len(images)]
110 | image_start_cat = image_start[0]
111 | for w in range(len(image_start) - 1):
112 | image_start_cat = np.concatenate((image_start_cat, image_start[w + 1]), axis=0)
113 | images_list.append(image_start_cat)
114 | labels_onehot_list.append(action_label)
115 |
116 | end_idx = max(2, end_idx)
117 | #image_end = images[end_idx - 2:end_idx + M - 2]
118 |
119 | if end_idx >= len(images)-1:
120 |
121 | image_end = images[end_idx - 2:end_idx + M - 2]
122 | else:
123 |
124 | image_end = images[end_idx - 1:end_idx + M - 1]  # Modified to load data similar to other papers
125 |
126 | image_end_cat = image_end[0]
127 | for w in range(len(image_end) - 1):
128 | image_end_cat = np.concatenate((image_end_cat, image_end[w + 1]), axis=0)
129 | images_list.append(image_end_cat)
130 | return images_list, labels_onehot_list, idx_list
131 |
132 | def sample_single(self, index):
133 | folder_id = self.vid_names[index]
134 |
135 | graph_path = folder_id['graph_action_path']
136 | LLM_path = folder_id['graph_action_path']  # Change to folder_id['LLM_action_path'] if LLM actions are available
137 |
138 |
139 | if self.args.dataset == 'crosstask':
140 | if folder_id['vid'] != self.last_vid:
141 | images_ = np.load(folder_id['feature'], allow_pickle=True)
142 | self.images = images_['frames_features']
143 | self.last_vid = folder_id['vid']
144 | else:
145 | images_ = np.load(folder_id['feature'], allow_pickle=True)
146 | self.images = images_['frames_features']
147 | images, labels_matrix, idx_list = self.curate_dataset(
148 | self.images, folder_id['legal_range'], M=self.M)
149 |
150 |
151 | # Get the shape of each array
152 | shapes = [arr.shape for arr in images]
153 |
154 | frames = torch.tensor(np.array(images))
155 | labels_tensor = torch.tensor(labels_matrix, dtype=torch.long)
156 |
157 | graph_path = torch.tensor(graph_path, dtype=torch.long)
158 |
159 | LLM_path = torch.tensor(LLM_path, dtype=torch.long)
160 | return frames, labels_tensor, LLM_path, graph_path
161 |
162 | def __getitem__(self, index):
163 | frames, labels, LLM_path, graph_path = self.sample_single(index)
164 | batch = Batch(frames, labels, LLM_path, graph_path)
165 | return batch
166 |
167 | def __len__(self):
168 | return len(self.json_data)
--------------------------------------------------------------------------------
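Compared to the step-model loader, each record here additionally carries a `graph_action_path` (and, if available, an `LLM_action_path`) field, and the returned `Batch` has four fields. A small sketch (hypothetical toy tensors) of how those fields line up with the positional indexing `batch[1]`, `batch[2]`, `batch[3]` used in the plan-model training loop shown earlier:

```python
from collections import namedtuple
import torch

Batch = namedtuple('Batch', 'Observations Actions LLM Knowledge')

# Toy tensors only; real shapes depend on horizon, feature dims, and num_seq_PKG / num_seq_LLM.
frames = torch.randn(4, 3072)                 # placeholder observation features
labels = torch.tensor([4, 9])                 # ground-truth action ids
llm_path = torch.tensor([[4, 9], [4, 7]])     # num_seq_LLM candidate action sequences
pkg_path = torch.tensor([[4, 9], [5, 9]])     # num_seq_PKG candidate action sequences

b = Batch(frames, labels, llm_path, pkg_path)
# Positional access matches the training loop: batch[1] -> Actions, batch[2] -> LLM, batch[3] -> Knowledge (PKG).
assert b[1] is b.Actions and b[2] is b.LLM and b[3] is b.Knowledge
```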
/plan/utils/args.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 |
4 | def get_args(description='whl'):
5 | parser = argparse.ArgumentParser(description=description)
6 | parser.add_argument('--checkpoint_root',
7 | type=str,
8 | default='checkpoint',
9 | help='checkpoint dir root')
10 | parser.add_argument('--log_root',
11 | type=str,
12 | default='log',
13 | help='log dir root')
14 | parser.add_argument('--checkpoint_dir',
15 | type=str,
16 | default='',
17 | help='checkpoint model folder')
18 | parser.add_argument('--optimizer',
19 | type=str,
20 | default='adam',
21 | help='opt algorithm')
22 | parser.add_argument('--num_thread_reader',
23 | type=int,
24 | default=40,
25 | help='')
26 | parser.add_argument('--batch_size',
27 | type=int,
28 | default=256,
29 | help='batch size')
30 | parser.add_argument('--batch_size_val',
31 | type=int,
32 | default=1024,
33 | help='batch size eval')
34 | parser.add_argument('--pretrain_cnn_path',
35 | type=str,
36 | default='',
37 | help='')
38 | parser.add_argument('--momemtum',
39 | type=float,
40 | default=0.9,
41 | help='SGD momentum')
42 | parser.add_argument('--log_freq',
43 | type=int,
44 | default=500,
45 | help='how many steps do we log once')
46 | parser.add_argument('--save_freq',
47 | type=int,
48 | default=1,
49 | help='how many epochs do we save once')
50 | parser.add_argument('--gradient_accumulate_every',
51 | type=int,
52 | default=1,
53 | help='accumulation_steps')
54 | parser.add_argument('--ema_decay',
55 | type=float,
56 | default=0.995,
57 | help='')
58 | parser.add_argument('--step_start_ema',
59 | type=int,
60 | default=400,
61 | help='')
62 | parser.add_argument('--update_ema_every',
63 | type=int,
64 | default=10,
65 | help='')
66 | parser.add_argument('--crop_only',
67 | type=int,
68 | default=1,
69 | help='random seed')
70 | parser.add_argument('--centercrop',
71 | type=int,
72 | default=0,
73 | help='random seed')
74 | parser.add_argument('--random_flip',
75 | type=int,
76 | default=1,
77 | help='random seed')
78 | parser.add_argument('--verbose',
79 | type=int,
80 | default=1,
81 | help='')
82 | parser.add_argument('--fps',
83 | type=int,
84 | default=1,
85 | help='')
86 | parser.add_argument('--cudnn_benchmark',
87 | type=int,
88 | default=0,
89 | help='')
90 | parser.add_argument('--horizon',
91 | type=int,
92 | default=4,
93 | help='planning horizon (number of action steps to predict)')
94 | parser.add_argument('--dataset',
95 | type=str,
96 | default='crosstask',
97 | #default='coin',
98 | help='dataset')
99 | parser.add_argument('--action_dim',
100 | type=int,
101 | default=105,
102 | help='number of action classes')
103 | parser.add_argument('--observation_dim',
104 | type=int,
105 | default=1536,
106 | help='dimension of visual observation features')
107 | parser.add_argument('--class_dim_graph',
108 | type=int,
109 | default=105, ####### modify according to the dataset.
110 | help='dimension of the PKG condition')
111 | parser.add_argument('--class_dim_llama',
112 | type=int,
113 | default=105, ####### modify according to the dataset.
114 | help='dimension of the LLM condition')
115 | parser.add_argument('--num_seq_PKG',
116 | type=int,
117 | default=2, ####### modify according to the dataset.
118 | help='number of PKG condition sequences')
119 | parser.add_argument('--num_seq_LLM',
120 | type=int,
121 | default=2, ####### modify according to the dataset.
122 | help='number of LLM condition sequences')
123 | parser.add_argument('--n_diffusion_steps',
124 | type=int,
125 | default=200,
126 | help='')
127 | parser.add_argument('--n_train_steps',
128 | type=int,
129 | default=200,
130 | help='training_steps_per_epoch')
131 | parser.add_argument('--root',
132 | type=str,
133 | default='/l/users/ravindu.nagasinghe/KEPP/datasets/PDPP/CrossTask_assets',
134 | help='root path of dataset crosstask')
135 | parser.add_argument('--json_path_train',
136 | type=str,
137 | default='/l/users/ravindu.nagasinghe/KEPP/plan/data_lists/train_list.json',
138 | help='path of the generated json file for train')
139 | parser.add_argument('--json_path_val',
140 | type=str,
141 | default='/l/users/ravindu.nagasinghe/KEPP/plan/data_lists/test_list.json',
142 | help='path of the generated json file for val')
143 |
144 | parser.add_argument('--epochs', default=120, type=int, metavar='N',
145 | help='number of total epochs to run')
146 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
147 | help='manual epoch number (useful on restarts)')
148 | parser.add_argument('--lr', '--learning-rate', default=5e-4, type=float,
149 | metavar='LR', help='initial learning rate', dest='lr')
150 | parser.add_argument('--resume', dest='resume', action='store_true',
151 | help='resume training from last checkpoint')
152 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
153 | help='evaluate model on validation set')
154 | parser.add_argument('--pretrained', dest='pretrained', action='store_true',
155 | help='use pre-trained model')
156 | parser.add_argument('--pin_memory', dest='pin_memory', action='store_true',
157 | help='use pin_memory')
158 | parser.add_argument('--world-size', default=1, type=int,
159 | help='number of nodes for distributed training')
160 | parser.add_argument('--rank', default=0, type=int,
161 | help='node rank for distributed training')
162 | parser.add_argument('--dist-file', default='dist-file', type=str,
163 | help='url used to set up distributed training')
164 | parser.add_argument('--dist-url', default='tcp://localhost:21712', type=str,
165 | help='url used to set up distributed training')
166 | parser.add_argument('--dist-backend', default='nccl', type=str,
167 | help='distributed backend')
168 | parser.add_argument('--seed', default=217, type=int,
169 | help='seed for initializing training. ')
170 | parser.add_argument('--gpu', default=None, type=int,
171 | help='GPU id to use.')
172 | parser.add_argument('--multiprocessing-distributed', action='store_true',
173 | help='Use multi-processing distributed training to launch '
174 | 'N processes per node, which has N GPUs. This is the '
175 | 'fastest way to use PyTorch for either single node or '
176 | 'multi node data parallel training')
177 | args = parser.parse_args()
178 | return args
--------------------------------------------------------------------------------
/step/utils/args.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 |
4 | def get_args(description='whl'):
5 | parser = argparse.ArgumentParser(description=description)
6 | parser.add_argument('--checkpoint_root',
7 | type=str,
8 | default='checkpoint',
9 | help='checkpoint dir root')
10 | parser.add_argument('--log_root',
11 | type=str,
12 | default='log',
13 | help='log dir root')
14 | parser.add_argument('--checkpoint_dir',
15 | type=str,
16 | default='',
17 | help='checkpoint model folder')
18 | parser.add_argument('--optimizer',
19 | type=str,
20 | default='adam',
21 | help='opt algorithm')
22 | parser.add_argument('--num_thread_reader',
23 | type=int,
24 | default=40,
25 | help='')
26 | parser.add_argument('--batch_size',
27 | type=int,
28 | default=256,
29 | help='batch size')
30 | parser.add_argument('--batch_size_val',
31 | type=int,
32 | default=1024,
33 | help='batch size eval')
34 | parser.add_argument('--pretrain_cnn_path',
35 | type=str,
36 | default='',
37 | help='')
38 | parser.add_argument('--momemtum',
39 | type=float,
40 | default=0.9,
41 | help='SGD momentum')
42 | parser.add_argument('--log_freq',
43 | type=int,
44 | default=500,
45 | help='how many steps do we log once')
46 | parser.add_argument('--save_freq',
47 | type=int,
48 | default=1,
49 | help='how many epochs do we save once')
50 | parser.add_argument('--gradient_accumulate_every',
51 | type=int,
52 | default=1,
53 | help='accumulation_steps')
54 | parser.add_argument('--ema_decay',
55 | type=float,
56 | default=0.995,
57 | help='')
58 | parser.add_argument('--step_start_ema',
59 | type=int,
60 | default=400,
61 | help='')
62 | parser.add_argument('--update_ema_every',
63 | type=int,
64 | default=10,
65 | help='')
66 | parser.add_argument('--crop_only',
67 | type=int,
68 | default=1,
69 | help='random seed')
70 | parser.add_argument('--centercrop',
71 | type=int,
72 | default=0,
73 | help='random seed')
74 | parser.add_argument('--random_flip',
75 | type=int,
76 | default=1,
77 | help='random seed')
78 | parser.add_argument('--verbose',
79 | type=int,
80 | default=1,
81 | help='')
82 | parser.add_argument('--fps',
83 | type=int,
84 | default=1,
85 | help='')
86 | parser.add_argument('--cudnn_benchmark',
87 | type=int,
88 | default=0,
89 | help='')
90 | parser.add_argument('--horizon',
91 | type=int,
92 | default=4,
93 | help='planning horizon (number of action steps to predict)')
94 | parser.add_argument('--dataset',
95 | type=str,
96 | default='crosstask',
97 | #default='coin',
98 | help='dataset')
99 | parser.add_argument('--action_dim',
100 | type=int,
101 | default=105,
102 | #default=778,
103 | help='number of action classes')
104 | parser.add_argument('--observation_dim',
105 | type=int,
106 | default=1536,
107 | help='dimension of visual observation features')
108 | parser.add_argument('--n_diffusion_steps',
109 | type=int,
110 | default=200,
111 | help='')
112 | parser.add_argument('--n_train_steps',
113 | type=int,
114 | default=200,
115 | help='training_steps_per_epoch')
116 | parser.add_argument('--root',
117 | type=str,
118 | default='/l/users/ravindu.nagasinghe/KEPP/datasets/CrossTask_assets',
119 | help='root path of dataset crosstask')
120 | parser.add_argument('--json_path_train',
121 | type=str,
122 | default='/l/users/ravindu.nagasinghe/KEPP/step/data_lists/train_data_list.json',
123 | help='path of the generated json file for train')
124 | parser.add_argument('--json_path_val',
125 | type=str,
126 | default='/l/users/ravindu.nagasinghe/KEPP/step/data_lists/test_data_list.json', #modify for train set /l/users/ravindu.nagasinghe/KEPP/step/data_lists/train_data_list.json
127 | help='path of the generated json file for val')
128 | parser.add_argument('--steps_path',
129 | type=str,
130 | default='/l/users/ravindu.nagasinghe/KEPP/step/test_list_steps.json', # modify for train set '/l/users/ravindu.nagasinghe/KEPP/step/train_list_steps.json'
131 | help='the path for predicted steps only')
132 | parser.add_argument('--step_model_output',
133 | type=str,
134 | default='/l/users/ravindu.nagasinghe/KEPP/step/test_data_step_model.json', # modify for train set '/l/users/ravindu.nagasinghe/KEPP/step/train_data_step_model.json'
135 | help='the path for predicted steps final output')
136 |
137 | parser.add_argument('--epochs', default=120, type=int, metavar='N',
138 | help='number of total epochs to run')
139 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
140 | help='manual epoch number (useful on restarts)')
141 | parser.add_argument('--lr', '--learning-rate', default=5e-4, type=float,
142 | metavar='LR', help='initial learning rate', dest='lr')
143 | parser.add_argument('--resume', dest='resume', action='store_true',
144 | help='resume training from last checkpoint')
145 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
146 | help='evaluate model on validation set')
147 | parser.add_argument('--pretrained', dest='pretrained', action='store_true',
148 | help='use pre-trained model')
149 | parser.add_argument('--pin_memory', dest='pin_memory', action='store_true',
150 | help='use pin_memory')
151 | parser.add_argument('--world-size', default=1, type=int,
152 | help='number of nodes for distributed training')
153 | parser.add_argument('--rank', default=0, type=int,
154 | help='node rank for distributed training')
155 | parser.add_argument('--dist-file', default='dist-file', type=str,
156 | help='url used to set up distributed training')
157 | parser.add_argument('--dist-url', default='tcp://localhost:21712', type=str,
158 | help='url used to set up distributed training')
159 | parser.add_argument('--dist-backend', default='nccl', type=str,
160 | help='distributed backend')
161 | parser.add_argument('--seed', default=217, type=int,
162 | help='seed for initializing training. ')
163 | parser.add_argument('--gpu', default=None, type=int,
164 | help='GPU id to use.')
165 | parser.add_argument('--multiprocessing-distributed', action='store_true',
166 | help='Use multi-processing distributed training to launch '
167 | 'N processes per node, which has N GPUs. This is the '
168 | 'fastest way to use PyTorch for either single node or '
169 | 'multi node data parallel training')
170 | args = parser.parse_args()
171 | return args
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # KEPP: Why Not Use Your Textbook? Knowledge-Enhanced Procedure Planning of Instructional Videos (CVPR 2024)
2 |
3 | [](https://arxiv.org/abs/2403.02782)
4 | [](https://ravindu-yasas-nagasinghe.github.io/KEPP-Project_Page/)
5 |
6 | This repository provides the official implementation of KEPP: Why Not Use Your Textbook? Knowledge-Enhanced Procedure Planning of Instructional Videos (CVPR 2024).
7 |
8 | In our project, we explore the capability of an agent to construct a logical sequence of action steps, thereby assembling a strategic procedural plan. This plan is crucial for navigating from an initial visual observation to a target visual outcome, as depicted in real-life instructional videos. Existing works have attained partial success by extensively leveraging various sources of information available in the datasets, such as heavy intermediate visual observations, procedural names, or natural language step-by-step instructions, for features or supervision signals. However, the task remains formidable due to the implicit causal constraints in the sequencing of steps and the variability inherent in multiple feasible plans. To tackle these intricacies that previous efforts have overlooked, we propose to enhance the agent's capabilities by infusing it with procedural knowledge. This knowledge, sourced from training procedure plans and structured as a directed weighted graph, equips the agent to better navigate the complexities of step sequencing and its potential variations. We coin our approach KEPP, a novel Knowledge-Enhanced Procedure Planning system, which harnesses a probabilistic procedural knowledge graph extracted from training data, effectively acting as a comprehensive textbook for the training domain. Experimental evaluations across three widely-used datasets under settings of varying complexity reveal that KEPP attains superior, state-of-the-art results while requiring only minimal supervision. The main architecture of our model is as follows.
9 |
10 |
21 |
22 | 
23 |
24 | ### Contents
25 | 1) [Setup](#Setup)
26 | 2) [Data Preparation](#Data-Preparation)
27 | 3) [Train Step model](#Train-Step-model)
28 | 4) [Generate paths from procedure knowledge graph](#Generate-paths-from-procedure-knowledge-graph)
29 | 5) [Inference](#Inference)
30 | ## Setup
31 | In a conda env with cuda available, run:
32 | ```shell
33 | pip install -r requirements.txt
34 | ```
35 | ## Data Preparation
36 | ### CrossTask
37 | 1. Download datasets & features
38 | ```shell
39 | cd {root}/dataset/crosstask
40 | bash download.sh
41 | ```
42 | 2. Move your data split files and the action one-hot encoding file to `{root}/dataset/crosstask/crosstask_release/`
43 | ```shell
44 | mv *.json crosstask_release
45 | mv actions_one_hot.npy crosstask_release
46 | ```
47 | ### COIN
48 | 1. Download datasets & features
49 | ```shell
50 | cd {root}/dataset/coin
51 | bash download.sh
52 | ```
53 | ### NIV
54 | 1. Download datasets & features
55 | ```shell
56 | cd {root}/dataset/NIV
57 | bash download.sh
58 | ```
59 | ## Train Step model
60 | 1. First, generate the training and testing dataset JSON files. You can modify the dataset, training steps, horizon (prediction length), JSON file save paths, etc. in `args.py`. Set `--json_path_train` and `--json_path_val` in `args.py` to the dataset JSON file paths.
61 | ```shell
62 | cd {root}/step
63 | python loading_data.py
64 | ```
65 | Dimensions for different datasets are listed below:
66 |
67 | | Dataset | observation_dim | action_dim | class_dim |
68 | |----| ----| ----| ----|
69 | | CrossTask | 1536(how) 9600(base) | 105 | 18 |
70 | | COIN | 1536 | 778 | 180 |
71 | | NIV | 1536 | 48 | 5 |
72 |
73 | 2. Train the step model
74 | ```shell
75 | python main_distributed.py --multiprocessing-distributed --num_thread_reader=8 --cudnn_benchmark=1 --pin_memory --checkpoint_dir=whl --resume --batch_size=256 --batch_size_val=256 --evaluate
76 | ```
77 | The trained models will be saved in `{root}/step/save_max`.
78 |
79 | 3. Generate first and last action predictions for the train and test datasets.
80 | * Modify the checkpoint path (L329) in `inference.py` to point to the evaluated model (in `save_max`).
81 | * Modify the `--json_path_val`, `--steps_path`, and `--step_model_output` arguments in `args.py` to set the step-prediction dataset JSON file paths for the train and test datasets separately. Run the following command for the train and test datasets separately, applying the modifications described above.
82 |
83 | ```shell
84 | python inference.py --multiprocessing-distributed --num_thread_reader=8 --cudnn_benchmark=1 --pin_memory --checkpoint_dir=whl --resume --batch_size=256 --batch_size_val=256 --evaluate > output.txt
85 | ```
86 | ## Generate paths from procedure knowledge graph
87 | 1. Train the graph for the relevant dataset (optional)
88 | ```shell
89 | cd {root}/PKG
90 | python graph_creation.py
91 | Select mode "train_out_n"
92 | ```
93 | Trained graphs for the CrossTask, COIN, and NIV datasets are available in `{root}/PKG/graphs`. Change `graph_save_path` (L13) in `graph_creation.py` to load procedure knowledge graphs trained on different datasets.
94 |
95 | 2. Obtain PKG conditions for train and test datasets.
96 | * Modify line 540 of `graph_creation.py` to point to the output of the step model (`--step_model_output`).
97 | * Modify line 568 of `graph_creation.py` to set the output path for the generated procedure knowledge graph conditioned train and test dataset JSON files.
98 | * Run the following for both the train and test dataset files generated by the step model, modifying `graph_creation.py` as described above.
99 | ```shell
100 | python graph_creation.py
101 | Select mode "validate"
102 | ```
103 | ## Train plan model
104 | 1. Modify the `json_path_train` and `json_path_val` arguments in the plan model's `args.py` to the outputs generated from the procedure knowledge graph for the train and test data, respectively.
105 |
106 | Modify the parameter `--num_seq_PKG` in `args.py` to match the number of generated PKG conditions. (Set `--num_seq_LLM` to the same number if LLM conditions are not used separately.)
107 | ```shell
108 | cd {root}/plan
109 | python main_distributed.py --multiprocessing-distributed --num_thread_reader=8 --cudnn_benchmark=1 --pin_memory --checkpoint_dir=whl --resume --batch_size=256 --batch_size_val=256 --evaluate
110 | ```
111 | ## Inference
112 |
113 | For metrics:
114 | Modify the checkpoint path (L339) in `inference.py` to point to the evaluated model (in `save_max`) and run:
115 | ```shell
116 | python inference.py --multiprocessing-distributed --num_thread_reader=8 --cudnn_benchmark=1 --pin_memory --checkpoint_dir=whl --resume --batch_size=256 --batch_size_val=256 --evaluate > output.txt
117 | ```
118 | Results of given checkpoints:
119 |
120 | | dataset | SR | mAcc | MIoU |
121 | | ---- | -- | -- | -- |
122 | | Crosstask_T=4 | 21.02 | 56.08 | 64.15 |
123 | | COIN_T=4 | 15.63 | 39.53 | 53.27 |
124 | | NIV_T=4 | 22.71 | 41.59 | 91.49 |
125 |
126 | Here we present qualitative examples of our proposed method. Intermediate steps are padded in the step model because it predicts only the start and end actions.
127 |
128 |
129 |
130 |
131 |
132 |
133 |
134 | Checkpoint links will be uploaded soon
135 |
136 |
137 | ### Citation
138 | ```shell
139 | @InProceedings{Nagasinghe_2024_CVPR,
140 | author = {Nagasinghe, Kumaranage Ravindu Yasas and Zhou, Honglu and Gunawardhana, Malitha and Min, Martin Renqiang and Harari, Daniel and Khan, Muhammad Haris},
141 | title = {Why Not Use Your Textbook? Knowledge-Enhanced Procedure Planning of Instructional Videos},
142 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
143 | month = {June},
144 | year = {2024},
145 | pages = {18816-18826}
146 | }
147 | ```
148 | ### Contact
149 | In case of any query, create an issue or contact ravindunagasinghe1998@gmail.com.
150 |
151 | ### Acknowledgement
152 | * This work was supported by joint MBZUAI-WIS grant P007. The authors are grateful for their generous support, which made this research possible.
153 | * This codebase is built on PDPP
154 |
155 |
--------------------------------------------------------------------------------
/step/model/helpers.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from einops.layers.torch import Rearrange
6 | from torch.optim.lr_scheduler import LambdaLR
7 | import os
8 | import numpy as np
9 | import logging
10 | from tensorboardX import SummaryWriter
11 |
12 |
13 | # -----------------------------------------------------------------------------#
14 | # ---------------------------------- modules ----------------------------------#
15 | # -----------------------------------------------------------------------------#
16 |
17 | def zero_module(module):
18 | """
19 | Zero out the parameters of a module and return it.
20 | """
21 | for p in module.parameters():
22 | p.detach().zero_()
23 | return module
24 |
25 |
26 | class SinusoidalPosEmb(nn.Module):
27 | def __init__(self, dim):
28 | super().__init__()
29 | self.dim = dim
30 |
31 | def forward(self, x):
32 | device = x.device
33 | half_dim = self.dim // 2
34 | emb = math.log(10000) / (half_dim - 1)
35 | emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
36 | emb = x[:, None] * emb[None, :]
37 | emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
38 | return emb
39 |
40 |
41 | class Downsample1d(nn.Module):
42 | def __init__(self, dim):
43 | super().__init__()
44 | print(dim, 'down!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!')
45 | self.conv = nn.Conv1d(dim, dim, 2, 1, 0)  # Edited kernel size; previously had (dim, dim, 2, 1, 0). Make 2 -> 1 when horizon is 2.
46 |
47 | def forward(self, x):
48 | return self.conv(x)
49 |
50 |
51 | class Upsample1d(nn.Module):
52 | def __init__(self, dim):
53 | super().__init__()
54 | self.conv = nn.ConvTranspose1d(dim, dim, 2, 1, 0)  # Edited kernel size; previously had (dim, dim, 2, 1, 0). Make 2 -> 1 when horizon is 2.
55 |
56 | def forward(self, x):
57 | return self.conv(x)
58 |
59 |
60 | class Conv1dBlock(nn.Module):
61 | """
62 | Conv1d --> GroupNorm --> Mish
63 | """
64 |
65 | def __init__(self, inp_channels, out_channels, kernel_size, n_groups=32, drop_out=0.0, if_zero=False):
66 | super().__init__()
67 | if drop_out > 0.0:
68 | self.block = nn.Sequential(
69 | zero_module(
70 | nn.Conv1d(inp_channels, out_channels, kernel_size, padding=1),
71 | ),
72 | Rearrange('batch channels horizon -> batch channels 1 horizon'),
73 | nn.GroupNorm(n_groups, out_channels),
74 | Rearrange('batch channels 1 horizon -> batch channels horizon'),
75 | nn.Mish(),
76 | nn.Dropout(p=drop_out),
77 | )
78 | elif if_zero:
79 | self.block = nn.Sequential(
80 | zero_module(
81 | nn.Conv1d(inp_channels, out_channels, kernel_size, padding=1),
82 | ),
83 | Rearrange('batch channels horizon -> batch channels 1 horizon'),
84 | nn.GroupNorm(n_groups, out_channels),
85 | Rearrange('batch channels 1 horizon -> batch channels horizon'),
86 | nn.Mish(),
87 |
88 | )
89 | else:
90 | self.block = nn.Sequential(
91 | nn.Conv1d(inp_channels, out_channels, kernel_size, padding=1),
92 | Rearrange('batch channels horizon -> batch channels 1 horizon'),
93 | nn.GroupNorm(n_groups, out_channels),
94 | Rearrange('batch channels 1 horizon -> batch channels horizon'),
95 | nn.Mish(),
96 | )
97 |
98 | def forward(self, x):
99 | return self.block(x)
100 |
101 |
102 | # -----------------------------------------------------------------------------#
103 | # ---------------------------------- sampling ---------------------------------#
104 | # -----------------------------------------------------------------------------#
105 |
106 | def extract(a, t, x_shape):
107 | b, *_ = t.shape
108 | out = a.gather(-1, t)
109 | return out.reshape(b, *((1,) * (len(x_shape) - 1)))
110 |
111 |
112 | def cosine_beta_schedule(timesteps, s=0.008, dtype=torch.float32):
113 | """
114 | cosine schedule
115 | as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
116 | """
117 | steps = timesteps + 1
118 | x = np.linspace(0, steps, steps)
119 | alphas_cumprod = np.cos(((x / steps) + s) / (1 + s) * np.pi * 0.5) ** 2
120 | alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
121 | betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
122 | betas_clipped = np.clip(betas, a_min=0, a_max=0.999)
123 | return torch.tensor(betas_clipped, dtype=dtype)
124 |
125 |
126 | def condition_projection(x, conditions, action_dim):
127 | for t, val in conditions.items():
128 | if t != 'task':
129 | x[:, t, action_dim:] = val.clone()
130 |
131 | x[:, 1:-1, :] = 0. ########################condition padding
132 |
133 | return x
134 |
135 |
136 | # -----------------------------------------------------------------------------#
137 | # ---------------------------------- Loss -------------------------------------#
138 | # -----------------------------------------------------------------------------#
139 |
140 | class Weighted_MSE(nn.Module):
141 |
142 | def __init__(self, weights, action_dim):
143 | super().__init__()
144 | # self.register_buffer('weights', weights)
145 | self.action_dim = action_dim
146 |
147 | def forward(self, pred, targ):
148 | """
149 | :param pred: [B, T, action_dim+observation_dim]
150 | :param targ: [B, T, action_dim+observation_dim]
151 | :return:
152 | """
153 |
154 | loss_action = F.mse_loss(pred, targ, reduction='none')
155 | loss_action[:, 0, :self.action_dim] *= 10.
156 | loss_action[:, -1, :self.action_dim] *= 10.
157 | loss_action[:, 1:-1, :self.action_dim] *= 0.
158 | loss_action = loss_action.sum()
159 | return loss_action
160 |
161 |
162 | Losses = {
163 | 'Weighted_MSE': Weighted_MSE,
164 | }
165 |
166 | # -----------------------------------------------------------------------------#
167 | # -------------------------------- lr_schedule --------------------------------#
168 | # -----------------------------------------------------------------------------#
169 |
170 | def get_lr_schedule_with_warmup(optimizer, num_training_steps, last_epoch=-1):
171 | num_warmup_steps = num_training_steps * 20 / 120
172 | decay_steps = num_training_steps * 30 / 120
173 |
174 | def lr_lambda(current_step):
175 | if current_step <= num_warmup_steps:
176 | return max(0., float(current_step) / float(max(1, num_warmup_steps)))
177 | else:
178 | return max(0.5 ** ((current_step - num_warmup_steps) // decay_steps), 0.)
179 |
180 | return LambdaLR(optimizer, lr_lambda, last_epoch)
181 |
182 | # -----------------------------------------------------------------------------#
183 | # ---------------------------------- logging ----------------------------------#
184 | # -----------------------------------------------------------------------------#
185 |
186 | # Taken from PyTorch's examples.imagenet.main
187 | class AverageMeter(object):
188 | """Computes and stores the average and current value"""
189 |
190 | def __init__(self):
191 | self.reset()
192 |
193 | def reset(self):
194 | self.val = 0
195 | self.avg = 0
196 | self.sum = 0
197 | self.count = 0
198 |
199 | def update(self, val, n=1):
200 | self.val = val
201 | self.sum += val * n
202 | self.count += n
203 | self.avg = self.sum / self.count
204 |
205 |
206 | class Logger:
207 | def __init__(self, log_dir, n_logged_samples=10, summary_writer=SummaryWriter, if_exist=False):
208 | self._log_dir = log_dir
209 | print('logging outputs to ', log_dir)
210 | self._n_logged_samples = n_logged_samples
211 | self._summ_writer = summary_writer(log_dir, flush_secs=120, max_queue=10)
212 | if not if_exist:
213 | log = logging.getLogger(log_dir)
214 | if not log.handlers:
215 | log.setLevel(logging.DEBUG)
216 | if not os.path.exists(log_dir):
217 | os.mkdir(log_dir)
218 | fh = logging.FileHandler(os.path.join(log_dir, 'log.txt'))
219 | fh.setLevel(logging.INFO)
220 | formatter = logging.Formatter(fmt='%(asctime)s %(message)s', datefmt='%m/%d/%Y %I:%M:%S')
221 | fh.setFormatter(formatter)
222 | log.addHandler(fh)
223 | self.log = log
224 |
225 | def log_scalar(self, scalar, name, step_):
226 | self._summ_writer.add_scalar('{}'.format(name), scalar, step_)
227 |
228 | def log_scalars(self, scalar_dict, group_name, step, phase):
229 | """Will log all scalars in the same plot."""
230 | self._summ_writer.add_scalars('{}_{}'.format(group_name, phase), scalar_dict, step)
231 |
232 | def flush(self):
233 | self._summ_writer.flush()
234 |
235 | def log_info(self, info):
236 | self.log.info("{}".format(info))
237 |
--------------------------------------------------------------------------------
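As a quick illustration of the warmup-then-step-decay behaviour of `get_lr_schedule_with_warmup` above, here is the multiplier it applies to the base learning rate, evaluated by hand for a hypothetical run of 120 scheduler steps (so warmup covers 20 steps and the decay interval is 30 steps):

```python
# Illustrative only: reproduces the lr_lambda logic from get_lr_schedule_with_warmup above.
num_training_steps = 120
num_warmup_steps = num_training_steps * 20 / 120   # 20.0
decay_steps = num_training_steps * 30 / 120        # 30.0

def lr_multiplier(step):
    if step <= num_warmup_steps:
        return max(0., step / max(1, num_warmup_steps))                    # linear warmup from 0 to 1
    return max(0.5 ** ((step - num_warmup_steps) // decay_steps), 0.)      # halve every decay interval

for step in (0, 10, 20, 50, 80, 110):
    print(step, lr_multiplier(step))
# 0 -> 0.0, 10 -> 0.5, 20 -> 1.0, 50 -> 0.5, 80 -> 0.25, 110 -> 0.125
# With --lr 5e-4, the actual learning rate is 5e-4 times this multiplier.
```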
/plan/model/helpers.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from einops.layers.torch import Rearrange
6 | from torch.optim.lr_scheduler import LambdaLR
7 | import os
8 | import numpy as np
9 | import logging
10 | from tensorboardX import SummaryWriter
11 |
12 |
13 | # -----------------------------------------------------------------------------#
14 | # ---------------------------------- modules ----------------------------------#
15 | # -----------------------------------------------------------------------------#
16 |
17 | def zero_module(module):
18 | """
19 | Zero out the parameters of a module and return it.
20 | """
21 | for p in module.parameters():
22 | p.detach().zero_()
23 | return module
24 |
25 |
26 | class SinusoidalPosEmb(nn.Module):
27 | def __init__(self, dim):
28 | super().__init__()
29 | self.dim = dim
30 |
31 | def forward(self, x):
32 | device = x.device
33 | half_dim = self.dim // 2
34 | emb = math.log(10000) / (half_dim - 1)
35 | emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
36 | emb = x[:, None] * emb[None, :]
37 | emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
38 | return emb
39 |
40 |
41 | class Downsample1d(nn.Module):
42 | def __init__(self, dim):
43 | super().__init__()
44 | self.conv = nn.Conv1d(dim, dim, 2, 1, 0)  # Edited kernel size; previously had (dim, dim, 2, 1, 0)
45 |
46 | def forward(self, x):
47 | return self.conv(x)
48 |
49 |
50 | class Upsample1d(nn.Module):
51 | def __init__(self, dim):
52 | super().__init__()
53 | print(dim, 'down!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!')
54 | self.conv = nn.ConvTranspose1d(dim, dim, 2, 1, 0)  # Edited kernel size; previously had (dim, dim, 2, 1, 0)
55 |
56 | def forward(self, x):
57 | return self.conv(x)
58 |
59 |
60 | class Conv1dBlock(nn.Module):
61 | """
62 | Conv1d --> GroupNorm --> Mish
63 | """
64 |
65 | def __init__(self, inp_channels, out_channels, kernel_size, n_groups=32, drop_out=0.0, if_zero=False):
66 | super().__init__()
67 | if drop_out > 0.0:
68 | self.block = nn.Sequential(
69 | zero_module(
70 | nn.Conv1d(inp_channels, out_channels, kernel_size, padding=1),
71 | ),
72 | Rearrange('batch channels horizon -> batch channels 1 horizon'),
73 | nn.GroupNorm(n_groups, out_channels),
74 | Rearrange('batch channels 1 horizon -> batch channels horizon'),
75 | nn.Mish(),
76 | nn.Dropout(p=drop_out),
77 | )
78 | elif if_zero:
79 | self.block = nn.Sequential(
80 | zero_module(
81 | nn.Conv1d(inp_channels, out_channels, kernel_size, padding=1),
82 | ),
83 | Rearrange('batch channels horizon -> batch channels 1 horizon'),
84 | nn.GroupNorm(n_groups, out_channels),
85 | Rearrange('batch channels 1 horizon -> batch channels horizon'),
86 | nn.Mish(),
87 |
88 | )
89 | else:
90 | self.block = nn.Sequential(
91 | nn.Conv1d(inp_channels, out_channels, kernel_size, padding=1),
92 | Rearrange('batch channels horizon -> batch channels 1 horizon'),
93 | nn.GroupNorm(n_groups, out_channels),
94 | Rearrange('batch channels 1 horizon -> batch channels horizon'),
95 | nn.Mish(),
96 | )
97 |
98 | def forward(self, x):
99 | return self.block(x)
100 |
101 |
102 | # -----------------------------------------------------------------------------#
103 | # ---------------------------------- sampling ---------------------------------#
104 | # -----------------------------------------------------------------------------#
105 |
106 | def extract(a, t, x_shape):
107 | b, *_ = t.shape
108 | out = a.gather(-1, t)
109 | return out.reshape(b, *((1,) * (len(x_shape) - 1)))
110 |
111 |
112 | def cosine_beta_schedule(timesteps, s=0.008, dtype=torch.float32):
113 | """
114 | cosine schedule
115 | as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
116 | """
117 | steps = timesteps + 1
118 | x = np.linspace(0, steps, steps)
119 | alphas_cumprod = np.cos(((x / steps) + s) / (1 + s) * np.pi * 0.5) ** 2
120 | alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
121 | betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
122 | betas_clipped = np.clip(betas, a_min=0, a_max=0.999)
123 | return torch.tensor(betas_clipped, dtype=dtype)
124 |
125 |
126 | def condition_projection(x, conditions, action_dim, LLM_dim , PKG_dim):
127 | for t, val in conditions.items():
128 | #print('t---------------',t)
129 | if t != 'PKG_action_path' and t != 'LLM_action_path':
130 | x[:, t, LLM_dim + PKG_dim + action_dim:] = val.clone()
131 |
132 | x[:, 1:-1, LLM_dim + PKG_dim + action_dim:] = 0.
133 | x[:, :, :PKG_dim] = conditions['PKG_action_path']
134 | x[:, :, PKG_dim:PKG_dim + LLM_dim] = conditions['LLM_action_path']
135 |
136 | return x
137 |
138 |
139 | # -----------------------------------------------------------------------------#
140 | # ---------------------------------- Loss -------------------------------------#
141 | # -----------------------------------------------------------------------------#
142 |
143 | class Weighted_MSE(nn.Module):
144 |
145 | def __init__(self, weights, action_dim, LLM_dim, PKG_dim):
146 | super().__init__()
147 | # self.register_buffer('weights', weights)
148 | self.action_dim = action_dim
149 | self.LLM_dim = LLM_dim
150 | self.PKG_dim = PKG_dim
151 |
152 | def forward(self, pred, targ):
153 | """
154 | :param pred: [B, T, PKG_dim+LLM_dim+action_dim+observation_dim]
155 | :param targ: [B, T, PKG_dim+LLM_dim+action_dim+observation_dim]
156 | :return:
157 | """
158 |
159 | loss_action = F.mse_loss(pred, targ, reduction='none')
160 | loss_action[:, 0, self.LLM_dim + self.PKG_dim: self.LLM_dim + self.PKG_dim + self.action_dim] *= 5.
161 | loss_action[:, -1, self.LLM_dim + self.PKG_dim: self.LLM_dim + self.PKG_dim + self.action_dim] *= 5.
162 |
163 | #loss_action[:, 1:-1, self.class_dim:self.class_dim + self.action_dim] *= 0.
164 | loss_action = loss_action.sum()
165 | return loss_action
166 |
167 |
168 | Losses = {
169 | 'Weighted_MSE': Weighted_MSE,
170 | }
171 |
172 | # -----------------------------------------------------------------------------#
173 | # -------------------------------- lr_schedule --------------------------------#
174 | # -----------------------------------------------------------------------------#
175 |
176 | def get_lr_schedule_with_warmup(optimizer, num_training_steps, last_epoch=-1):
177 | num_warmup_steps = num_training_steps * 20 / 120
178 | decay_steps = num_training_steps * 30 / 120
179 |
180 | def lr_lambda(current_step):
181 | if current_step <= num_warmup_steps:
182 | return max(0., float(current_step) / float(max(1, num_warmup_steps)))
183 | else:
184 | return max(0.5 ** ((current_step - num_warmup_steps) // decay_steps), 0.)
185 |
186 | return LambdaLR(optimizer, lr_lambda, last_epoch)
187 |
188 | # -----------------------------------------------------------------------------#
189 | # ---------------------------------- logging ----------------------------------#
190 | # -----------------------------------------------------------------------------#
191 |
192 | # Taken from PyTorch's examples.imagenet.main
193 | class AverageMeter(object):
194 | """Computes and stores the average and current value"""
195 |
196 | def __init__(self):
197 | self.reset()
198 |
199 | def reset(self):
200 | self.val = 0
201 | self.avg = 0
202 | self.sum = 0
203 | self.count = 0
204 |
205 | def update(self, val, n=1):
206 | self.val = val
207 | self.sum += val * n
208 | self.count += n
209 | self.avg = self.sum / self.count
210 |
211 |
212 | class Logger:
213 | def __init__(self, log_dir, n_logged_samples=10, summary_writer=SummaryWriter, if_exist=False):
214 | self._log_dir = log_dir
215 | print('logging outputs to ', log_dir)
216 | self._n_logged_samples = n_logged_samples
217 | self._summ_writer = summary_writer(log_dir, flush_secs=120, max_queue=10)
218 | if not if_exist:
219 | log = logging.getLogger(log_dir)
220 | if not log.handlers:
221 | log.setLevel(logging.DEBUG)
222 | if not os.path.exists(log_dir):
223 | os.mkdir(log_dir)
224 | fh = logging.FileHandler(os.path.join(log_dir, 'log.txt'))
225 | fh.setLevel(logging.INFO)
226 | formatter = logging.Formatter(fmt='%(asctime)s %(message)s', datefmt='%m/%d/%Y %I:%M:%S')
227 | fh.setFormatter(formatter)
228 | log.addHandler(fh)
229 | self.log = log
230 |
231 | def log_scalar(self, scalar, name, step_):
232 | self._summ_writer.add_scalar('{}'.format(name), scalar, step_)
233 |
234 | def log_scalars(self, scalar_dict, group_name, step, phase):
235 | """Will log all scalars in the same plot."""
236 | self._summ_writer.add_scalars('{}_{}'.format(group_name, phase), scalar_dict, step)
237 |
238 | def flush(self):
239 | self._summ_writer.flush()
240 |
241 | def log_info(self, info):
242 | self.log.info("{}".format(info))
243 |
--------------------------------------------------------------------------------
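To make the channel layout concrete, a toy-sized sketch (hypothetical dimensions and values) of what the plan model's `condition_projection` above does: start/goal observations are written into the first and last steps and zeroed at intermediate steps, while the PKG and LLM condition paths are written into their reserved channel blocks at every step.

```python
import torch

# Hypothetical toy dims; per-step channels are laid out as [PKG | LLM | action | observation].
PKG_dim, LLM_dim, action_dim, observation_dim, T, bs = 2, 2, 3, 4, 4, 1
x = torch.randn(bs, T, PKG_dim + LLM_dim + action_dim + observation_dim)

conditions = {
    0: torch.ones(bs, observation_dim),               # start observation
    T - 1: 2 * torch.ones(bs, observation_dim),       # goal observation
    'PKG_action_path': torch.full((bs, T, PKG_dim), 0.5),
    'LLM_action_path': torch.full((bs, T, LLM_dim), 0.25),
}

for t, val in conditions.items():
    if t != 'PKG_action_path' and t != 'LLM_action_path':
        x[:, t, LLM_dim + PKG_dim + action_dim:] = val.clone()
x[:, 1:-1, LLM_dim + PKG_dim + action_dim:] = 0.       # only start/goal observations are kept
x[:, :, :PKG_dim] = conditions['PKG_action_path']
x[:, :, PKG_dim:PKG_dim + LLM_dim] = conditions['LLM_action_path']

print(x[0])   # PKG/LLM channels fixed at every step; observation channels non-zero only at steps 0 and T-1
```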
/step/model/diffusion.py:
--------------------------------------------------------------------------------
1 | import random
2 | import numpy as np
3 | import torch
4 | from torch import nn
5 |
6 | from .helpers import (
7 | cosine_beta_schedule,
8 | extract,
9 | condition_projection,
10 | Losses,
11 | )
12 |
13 |
14 | class GaussianDiffusion(nn.Module):
15 | def __init__(self, model, horizon, observation_dim, action_dim, n_timesteps=200,
16 | loss_type='Weighted_MSE', clip_denoised=False, ddim_discr_method='uniform',
17 | ):
18 | super().__init__()
19 | self.horizon = horizon
20 | self.observation_dim = observation_dim
21 | self.action_dim = action_dim
22 | self.model = model
23 |
24 | betas = cosine_beta_schedule(n_timesteps)
25 | alphas = 1. - betas
26 | alphas_cumprod = torch.cumprod(alphas, dim=0)
27 | alphas_cumprod_prev = torch.cat([torch.ones(1), alphas_cumprod[:-1]])
28 |
29 | self.n_timesteps = n_timesteps
30 | self.clip_denoised = clip_denoised
31 | self.eta = 0.0
32 | self.random_ratio = 1.0
33 |
34 | # ---------------------------ddim--------------------------------
35 | ddim_timesteps = 10
36 |
37 | if ddim_discr_method == 'uniform':
38 | c = n_timesteps // ddim_timesteps
39 | ddim_timestep_seq = np.asarray(list(range(0, n_timesteps, c)))
40 | elif ddim_discr_method == 'quad':
41 | ddim_timestep_seq = (
42 | (np.linspace(0, np.sqrt(n_timesteps), ddim_timesteps)) ** 2
43 | ).astype(int)
44 | else:
45 | raise RuntimeError('unknown ddim_discr_method: {}'.format(ddim_discr_method))
46 |
47 | self.ddim_timesteps = ddim_timesteps
48 | self.ddim_timestep_seq = ddim_timestep_seq
49 | # ----------------------------------------------------------------
50 |
51 | self.register_buffer('betas', betas)
52 | self.register_buffer('alphas_cumprod', alphas_cumprod)
53 | self.register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)
54 |
55 | # calculations for diffusion q(x_t | x_{t-1}) and others
56 | self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
57 | self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))
58 | self.register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod))
59 | self.register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))
60 | self.register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))
61 |
62 | # calculations for posterior q(x_{t-1} | x_t, x_0)
63 | posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
64 | self.register_buffer('posterior_variance', posterior_variance)
65 |
66 | # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
67 | self.register_buffer('posterior_log_variance_clipped',
68 | torch.log(torch.clamp(posterior_variance, min=1e-20)))
69 | self.register_buffer('posterior_mean_coef1',
70 | betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
71 | self.register_buffer('posterior_mean_coef2',
72 | (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod))
73 | self.loss_type = loss_type
74 | self.loss_fn = Losses[loss_type](None, self.action_dim)
75 |
76 | # ------------------------------------------ sampling ------------------------------------------#
77 |
78 | def q_posterior(self, x_start, x_t, t):
79 | posterior_mean = (
80 | extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
81 | extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
82 | )
83 | posterior_variance = extract(self.posterior_variance, t, x_t.shape)
84 | posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
85 | return posterior_mean, posterior_variance, posterior_log_variance_clipped
86 |
87 | def p_mean_variance(self, x, cond, t):
88 | x_recon = self.model(x, t)
89 |
90 | if self.clip_denoised:
91 | x_recon.clamp_(-1., 1.)
92 | else:
93 | assert RuntimeError()
94 |
95 | model_mean, posterior_variance, posterior_log_variance = self.q_posterior(
96 | x_start=x_recon, x_t=x, t=t)
97 | return model_mean, posterior_variance, posterior_log_variance
98 |
99 | @torch.no_grad()
100 | def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
101 | return \
102 | (extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart) \
103 | / extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
104 |
105 | @torch.no_grad()
106 | def p_sample_ddim(self, x, cond, t, t_prev, if_prev=False):
107 | b, *_, device = *x.shape, x.device
108 | x_recon = self.model(x, t)
109 |
110 | if self.clip_denoised:
111 | x_recon.clamp_(-1., 1.)
112 | else:
113 | assert RuntimeError()
114 |
115 | eps = self._predict_eps_from_xstart(x, t, x_recon)
116 | alpha_bar = extract(self.alphas_cumprod, t, x.shape)
117 | if if_prev:
118 | alpha_bar_prev = extract(self.alphas_cumprod_prev, t_prev, x.shape)
119 | else:
120 | alpha_bar_prev = extract(self.alphas_cumprod, t_prev, x.shape)
121 | sigma = (
122 | self.eta
123 | * torch.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
124 | * torch.sqrt(1 - alpha_bar / alpha_bar_prev)
125 | )
126 |
127 | noise = torch.randn_like(x) * self.random_ratio
128 | mean_pred = (
129 | x_recon * torch.sqrt(alpha_bar_prev)
130 | + torch.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps
131 | )
132 |
133 | nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
134 | return mean_pred + nonzero_mask * sigma * noise
135 |
136 | @torch.no_grad()
137 | def p_sample(self, x, cond, t):
138 | b, *_, device = *x.shape, x.device
139 | model_mean, _, model_log_variance = self.p_mean_variance(x=x, cond=cond, t=t)
140 | noise = torch.randn_like(x) * self.random_ratio
141 | nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
142 | return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
143 |
144 | @torch.no_grad()
145 | def p_sample_loop(self, cond, if_jump):
146 | device = self.betas.device
147 | batch_size = len(cond[0])
148 | horizon = self.horizon
149 | shape = (batch_size, horizon, self.action_dim + self.observation_dim)
150 |
151 | x = torch.randn(shape, device=device) * self.random_ratio # xt for Noise and diffusion
152 | # x = torch.zeros(shape, device=device) # for Deterministic
153 | x = condition_projection(x, cond, self.action_dim)
154 |
155 | '''
156 | The if-else below is for diffusion, should be removed for Noise and Deterministic
157 | '''
158 | if not if_jump:
159 | for i in reversed(range(0, self.n_timesteps)):
160 | timesteps = torch.full((batch_size,), i, device=device, dtype=torch.long)
161 | x = self.p_sample(x, cond, timesteps)
162 | x = condition_projection(x, cond, self.action_dim)
163 |
164 | else:
165 | for i in reversed(range(0, self.ddim_timesteps)):
166 | timesteps = torch.full((batch_size,), self.ddim_timestep_seq[i], device=device, dtype=torch.long)
167 | if i == 0:
168 | timesteps_prev = torch.full((batch_size,), 0, device=device, dtype=torch.long)
169 | x = self.p_sample_ddim(x, cond, timesteps, timesteps_prev, True)
170 | else:
171 | timesteps_prev = torch.full((batch_size,), self.ddim_timestep_seq[i-1], device=device, dtype=torch.long)
172 | x = self.p_sample_ddim(x, cond, timesteps, timesteps_prev)
173 | x = condition_projection(x, cond, self.action_dim)
174 |
175 | '''
176 | The two lines below are for Noise and Deterministic
177 | '''
178 | # x = self.model(x, None)
179 | # x = condition_projection(x, cond, self.action_dim, self.class_dim)
180 |
181 | return x
182 |
183 | # ------------------------------------------ training ------------------------------------------#
184 |
185 | def q_sample(self, x_start, t, noise=None):
186 | if noise is None:
187 | noise = torch.randn_like(x_start) * self.random_ratio
188 |
189 | sample = (
190 | extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
191 | extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
192 | )
193 |
194 | return sample
195 |
196 | def p_losses(self, x_start, cond, t):
197 | noise = torch.randn_like(x_start) * self.random_ratio # for Noise and diffusion
198 | # noise = torch.zeros_like(x_start) # for Deterministic
199 | # x_noisy = noise # for Noise and Deterministic
200 |
201 | x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) # for diffusion, should be removed for Noise and Deterministic
202 | x_noisy = condition_projection(x_noisy, cond, self.action_dim)
203 | x_recon = self.model(x_noisy, t)
204 | x_recon = condition_projection(x_recon, cond, self.action_dim)
205 |
206 | loss = self.loss_fn(x_recon, x_start)
207 | return loss
208 |
209 | def loss(self, x, cond):
210 | batch_size = len(x) # for diffusion
211 | t = torch.randint(0, self.n_timesteps, (batch_size,), device=x.device).long() # for diffusion
212 | # t = None # for Noise and Deterministic
213 | return self.p_losses(x, cond, t)
214 |
215 | def forward(self, cond, if_jump=False):
216 | return self.p_sample_loop(cond, if_jump)
217 |
--------------------------------------------------------------------------------
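The DDIM sampling path above uses a fixed sub-sequence of the diffusion timesteps. A small sketch of how that sub-sequence is built with the defaults (`n_timesteps=200`, `ddim_timesteps=10`, uniform discretization) and how `p_sample_loop` walks it in reverse when `if_jump=True`:

```python
import numpy as np

# Defaults used in GaussianDiffusion above: 200 training timesteps, 10 DDIM steps, uniform spacing.
n_timesteps, ddim_timesteps = 200, 10
c = n_timesteps // ddim_timesteps                    # 20
ddim_timestep_seq = np.asarray(list(range(0, n_timesteps, c)))
print(ddim_timestep_seq)                             # [  0  20  40 ... 180]

# Reverse traversal as in p_sample_loop(if_jump=True): 10 calls to p_sample_ddim replace the
# full 200-step chain; at i == 0 the code falls back to alphas_cumprod_prev (if_prev=True).
for i in reversed(range(ddim_timesteps)):
    t = int(ddim_timestep_seq[i])
    t_prev = 0 if i == 0 else int(ddim_timestep_seq[i - 1])
    # x = self.p_sample_ddim(x, cond, t, t_prev, if_prev=(i == 0)) would be called here
```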
/plan/model/diffusion.py:
--------------------------------------------------------------------------------
1 | import random
2 | import numpy as np
3 | import torch
4 | from torch import nn
5 |
6 | from .helpers import (
7 | cosine_beta_schedule,
8 | extract,
9 | condition_projection,
10 | Losses,
11 | )
12 |
13 |
14 | class GaussianDiffusion(nn.Module):
15 | def __init__(self, model, horizon, observation_dim, action_dim, LLM_dim, PKG_dim, n_timesteps=200,
16 | loss_type='Weighted_MSE', clip_denoised=False, ddim_discr_method='uniform',
17 | ):
18 | super().__init__()
19 | self.horizon = horizon
20 | self.observation_dim = observation_dim
21 | self.action_dim = action_dim
22 | self.LLM_dim = LLM_dim
23 | self.PKG_dim = PKG_dim
24 | self.model = model
25 |
26 | betas = cosine_beta_schedule(n_timesteps)
27 | alphas = 1. - betas
28 | alphas_cumprod = torch.cumprod(alphas, dim=0)
29 | alphas_cumprod_prev = torch.cat([torch.ones(1), alphas_cumprod[:-1]])
30 |
31 | self.n_timesteps = n_timesteps
32 | self.clip_denoised = clip_denoised
33 | self.eta = 0.0
34 | self.random_ratio = 1.0
35 |
36 | # ---------------------------ddim--------------------------------
37 | ddim_timesteps = 10
38 |
39 | if ddim_discr_method == 'uniform':
40 | c = n_timesteps // ddim_timesteps
41 | ddim_timestep_seq = np.asarray(list(range(0, n_timesteps, c)))
42 | elif ddim_discr_method == 'quad':
43 | ddim_timestep_seq = (
44 | (np.linspace(0, np.sqrt(n_timesteps), ddim_timesteps)) ** 2
45 | ).astype(int)
46 | else:
47 |             raise RuntimeError('unknown ddim_discr_method: {}'.format(ddim_discr_method))
48 |
49 | self.ddim_timesteps = ddim_timesteps
50 | self.ddim_timestep_seq = ddim_timestep_seq
51 | # ----------------------------------------------------------------
52 |
53 | self.register_buffer('betas', betas)
54 | self.register_buffer('alphas_cumprod', alphas_cumprod)
55 | self.register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)
56 |
57 | # calculations for diffusion q(x_t | x_{t-1}) and others
58 | self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
59 | self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))
60 | self.register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod))
61 | self.register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))
62 | self.register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))
63 |
64 | # calculations for posterior q(x_{t-1} | x_t, x_0)
65 | posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
66 | self.register_buffer('posterior_variance', posterior_variance)
67 |
68 | # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
69 | self.register_buffer('posterior_log_variance_clipped',
70 | torch.log(torch.clamp(posterior_variance, min=1e-20)))
71 | self.register_buffer('posterior_mean_coef1',
72 |                              betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
73 | self.register_buffer('posterior_mean_coef2',
74 |                              (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))
75 | self.loss_type = loss_type
76 | self.loss_fn = Losses[loss_type](None, self.action_dim, self.LLM_dim, self.PKG_dim)
77 |
78 | # ------------------------------------------ sampling ------------------------------------------#
79 |
80 | def q_posterior(self, x_start, x_t, t):
81 | posterior_mean = (
82 | extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
83 | extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
84 | )
85 | posterior_variance = extract(self.posterior_variance, t, x_t.shape)
86 | posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
87 | return posterior_mean, posterior_variance, posterior_log_variance_clipped
88 |
89 | def p_mean_variance(self, x, cond, t):
90 | x_recon = self.model(x, t)
91 |
92 |         if self.clip_denoised:
93 |             x_recon = x_recon.clamp(-1., 1.)
94 |         else:
95 |             raise RuntimeError('clip_denoised=False is not supported')
96 |
97 | model_mean, posterior_variance, posterior_log_variance = self.q_posterior(
98 | x_start=x_recon, x_t=x, t=t)
99 | return model_mean, posterior_variance, posterior_log_variance
100 |
101 | @torch.no_grad()
102 | def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
103 | return \
104 | (extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart) \
105 | / extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
106 |
107 | @torch.no_grad()
108 | def p_sample_ddim(self, x, cond, t, t_prev, if_prev=False):
109 | b, *_, device = *x.shape, x.device
110 | x_recon = self.model(x, t)
111 |
112 |         if self.clip_denoised:
113 |             x_recon = x_recon.clamp(-1., 1.)
114 |         else:
115 |             raise RuntimeError('clip_denoised=False is not supported')
116 |
117 | eps = self._predict_eps_from_xstart(x, t, x_recon)
118 | alpha_bar = extract(self.alphas_cumprod, t, x.shape)
119 | if if_prev:
120 | alpha_bar_prev = extract(self.alphas_cumprod_prev, t_prev, x.shape)
121 | else:
122 | alpha_bar_prev = extract(self.alphas_cumprod, t_prev, x.shape)
123 | sigma = (
124 | self.eta
125 | * torch.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
126 | * torch.sqrt(1 - alpha_bar / alpha_bar_prev)
127 | )
128 |
129 | noise = torch.randn_like(x) * self.random_ratio
130 | mean_pred = (
131 | x_recon * torch.sqrt(alpha_bar_prev)
132 | + torch.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps
133 | )
134 |
135 | nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
136 | return mean_pred + nonzero_mask * sigma * noise
137 |
138 | @torch.no_grad()
139 | def p_sample(self, x, cond, t):
140 | b, *_, device = *x.shape, x.device
141 | model_mean, _, model_log_variance = self.p_mean_variance(x=x, cond=cond, t=t)
142 | noise = torch.randn_like(x) * self.random_ratio
143 | nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
144 | return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
145 |
146 | @torch.no_grad()
147 | def p_sample_loop(self, cond, if_jump):
148 | device = self.betas.device
149 | batch_size = len(cond[0])
150 | horizon = self.horizon
151 | shape = (batch_size, horizon, self.PKG_dim + self.LLM_dim + self.action_dim + self.observation_dim)
152 |
153 | x = torch.randn(shape, device=device) * self.random_ratio # xt for Noise and diffusion
154 | # x = torch.zeros(shape, device=device) # for Deterministic
155 | x = condition_projection(x, cond, self.action_dim, self.LLM_dim, self.PKG_dim)
156 |
157 | '''
158 |         The if-else below is for the diffusion variant; remove it for the Noise and Deterministic variants
159 | '''
160 | if not if_jump:
161 | for i in reversed(range(0, self.n_timesteps)):
162 | timesteps = torch.full((batch_size,), i, device=device, dtype=torch.long)
163 | x = self.p_sample(x, cond, timesteps)
164 | x = condition_projection(x, cond, self.action_dim, self.LLM_dim, self.PKG_dim)
165 |
166 | else:
167 | for i in reversed(range(0, self.ddim_timesteps)):
168 | timesteps = torch.full((batch_size,), self.ddim_timestep_seq[i], device=device, dtype=torch.long)
169 | if i == 0:
170 | timesteps_prev = torch.full((batch_size,), 0, device=device, dtype=torch.long)
171 | x = self.p_sample_ddim(x, cond, timesteps, timesteps_prev, True)
172 | else:
173 | timesteps_prev = torch.full((batch_size,), self.ddim_timestep_seq[i-1], device=device, dtype=torch.long)
174 | x = self.p_sample_ddim(x, cond, timesteps, timesteps_prev)
175 | x = condition_projection(x, cond, self.action_dim, self.LLM_dim, self.PKG_dim)
176 |
177 | '''
178 |         The two lines below are for the Noise and Deterministic variants
179 | '''
180 | # x = self.model(x, None)
181 | # x = condition_projection(x, cond, self.action_dim, self.class_dim)
182 |
183 | return x
184 |
185 | # ------------------------------------------ training ------------------------------------------#
186 |
187 | def q_sample(self, x_start, t, noise=None):
188 | if noise is None:
189 | noise = torch.randn_like(x_start) * self.random_ratio
190 |
191 | sample = (
192 | extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
193 | extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
194 | )
195 |
196 | return sample
197 |
198 | def p_losses(self, x_start, cond, t):
199 | noise = torch.randn_like(x_start) * self.random_ratio # for Noise and diffusion
200 | # noise = torch.zeros_like(x_start) # for Deterministic
201 | # x_noisy = noise # for Noise and Deterministic
202 |
203 | x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) # for diffusion, should be removed for Noise and Deterministic
204 | x_noisy = condition_projection(x_noisy, cond, self.action_dim, self.LLM_dim, self.PKG_dim)
205 |
206 | x_recon = self.model(x_noisy, t)
207 | x_recon = condition_projection(x_recon, cond, self.action_dim, self.LLM_dim, self.PKG_dim)
208 |
209 | loss = self.loss_fn(x_recon, x_start)
210 | return loss
211 |
212 | def loss(self, x, cond):
213 | batch_size = len(x) # for diffusion
214 | t = torch.randint(0, self.n_timesteps, (batch_size,), device=x.device).long() # for diffusion
215 | # t = None # for Noise and Deterministic
216 | return self.p_losses(x, cond, t)
217 |
218 | def forward(self, cond, if_jump=False):
219 | return self.p_sample_loop(cond, if_jump)
220 |
--------------------------------------------------------------------------------
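
A minimal, self-contained sketch of how the uniform DDIM sub-sequence built in __init__ is consumed by p_sample_loop when if_jump is set: each selected timestep is paired with its predecessor, and 0 is used as the predecessor of the final step. The values below assume the defaults shown above (n_timesteps=200, ddim_timesteps=10):

import numpy as np

n_timesteps = 200
ddim_timesteps = 10
ddim_timestep_seq = np.asarray(list(range(0, n_timesteps, n_timesteps // ddim_timesteps)))

for i in reversed(range(ddim_timesteps)):
    t = int(ddim_timestep_seq[i])
    t_prev = 0 if i == 0 else int(ddim_timestep_seq[i - 1])
    # p_sample_ddim(x, cond, t, t_prev, if_prev=(i == 0)) would be applied here
    print(t, '->', t_prev)        # 180 -> 160, ..., 20 -> 0, 0 -> 0
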
/dataset/NIV/train70.json:
--------------------------------------------------------------------------------
1 | [{"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/changing_tire_0002.npy", "task_id": 0}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/changing_tire_0003.npy", "task_id": 0}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/changing_tire_0006.npy", "task_id": 0}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/changing_tire_0008.npy", "task_id": 0}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/changing_tire_0009.npy", "task_id": 0}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/changing_tire_0010.npy", "task_id": 0}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/changing_tire_0011.npy", "task_id": 0}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/changing_tire_0013.npy", "task_id": 0}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/changing_tire_0014.npy", "task_id": 0}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/changing_tire_0015.npy", "task_id": 0}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/changing_tire_0016.npy", "task_id": 0}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/changing_tire_0017.npy", "task_id": 0}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/changing_tire_0018.npy", "task_id": 0}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/changing_tire_0019.npy", "task_id": 0}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/changing_tire_0023.npy", "task_id": 0}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/changing_tire_0024.npy", "task_id": 0}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/changing_tire_0026.npy", "task_id": 0}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/changing_tire_0027.npy", "task_id": 0}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/changing_tire_0028.npy", "task_id": 0}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/changing_tire_0029.npy", "task_id": 0}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/coffee_0002.npy", "task_id": 1}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/coffee_0003.npy", "task_id": 1}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/coffee_0004.npy", "task_id": 1}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/coffee_0005.npy", "task_id": 1}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/coffee_0006.npy", "task_id": 1}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/coffee_0008.npy", "task_id": 1}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/coffee_0009.npy", "task_id": 1}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/coffee_0010.npy", "task_id": 1}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/coffee_0011.npy", "task_id": 1}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/coffee_0012.npy", "task_id": 1}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/coffee_0013.npy", "task_id": 1}, 
{"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/coffee_0015.npy", "task_id": 1}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/coffee_0016.npy", "task_id": 1}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/coffee_0017.npy", "task_id": 1}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/coffee_0020.npy", "task_id": 1}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/coffee_0021.npy", "task_id": 1}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/coffee_0022.npy", "task_id": 1}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/coffee_0024.npy", "task_id": 1}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/coffee_0026.npy", "task_id": 1}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/coffee_0028.npy", "task_id": 1}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/coffee_0029.npy", "task_id": 1}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/coffee_0030.npy", "task_id": 1}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/cpr_0001.npy", "task_id": 2}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/cpr_0003.npy", "task_id": 2}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/cpr_0004.npy", "task_id": 2}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/cpr_0005.npy", "task_id": 2}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/cpr_0007.npy", "task_id": 2}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/cpr_0008.npy", "task_id": 2}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/cpr_0009.npy", "task_id": 2}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/cpr_0010.npy", "task_id": 2}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/cpr_0011.npy", "task_id": 2}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/cpr_0014.npy", "task_id": 2}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/cpr_0015.npy", "task_id": 2}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/cpr_0018.npy", "task_id": 2}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/cpr_0020.npy", "task_id": 2}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/cpr_0022.npy", "task_id": 2}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/cpr_0024.npy", "task_id": 2}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/cpr_0025.npy", "task_id": 2}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/cpr_0027.npy", "task_id": 2}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/cpr_0028.npy", "task_id": 2}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/cpr_0029.npy", "task_id": 2}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/cpr_0030.npy", "task_id": 2}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/jump_car_0002.npy", "task_id": 3}, {"feature": 
"/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/jump_car_0003.npy", "task_id": 3}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/jump_car_0004.npy", "task_id": 3}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/jump_car_0005.npy", "task_id": 3}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/jump_car_0006.npy", "task_id": 3}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/jump_car_0007.npy", "task_id": 3}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/jump_car_0008.npy", "task_id": 3}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/jump_car_0010.npy", "task_id": 3}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/jump_car_0012.npy", "task_id": 3}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/jump_car_0013.npy", "task_id": 3}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/jump_car_0014.npy", "task_id": 3}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/jump_car_0015.npy", "task_id": 3}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/jump_car_0016.npy", "task_id": 3}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/jump_car_0019.npy", "task_id": 3}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/jump_car_0021.npy", "task_id": 3}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/jump_car_0022.npy", "task_id": 3}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/jump_car_0023.npy", "task_id": 3}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/jump_car_0024.npy", "task_id": 3}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/jump_car_0026.npy", "task_id": 3}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/jump_car_0027.npy", "task_id": 3}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/jump_car_0030.npy", "task_id": 3}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/repot_0002.npy", "task_id": 4}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/repot_0003.npy", "task_id": 4}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/repot_0007.npy", "task_id": 4}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/repot_0008.npy", "task_id": 4}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/repot_0010.npy", "task_id": 4}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/repot_0011.npy", "task_id": 4}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/repot_0013.npy", "task_id": 4}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/repot_0014.npy", "task_id": 4}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/repot_0016.npy", "task_id": 4}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/repot_0017.npy", "task_id": 4}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/repot_0019.npy", "task_id": 4}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/repot_0021.npy", "task_id": 4}, {"feature": 
"/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/repot_0023.npy", "task_id": 4}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/repot_0024.npy", "task_id": 4}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/repot_0027.npy", "task_id": 4}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/repot_0028.npy", "task_id": 4}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/repot_0030.npy", "task_id": 4}]
--------------------------------------------------------------------------------
/plan/log/whl.txt:
--------------------------------------------------------------------------------
1 | => no checkpoint found at 'True'
2 | Starting training loop for rank: 0, total batch size: 256
3 | => no checkpoint found at 'True'
4 | Starting training loop for rank: 0, total batch size: 256
5 | => no checkpoint found at 'True'
6 | Starting training loop for rank: 0, total batch size: 256
7 | => no checkpoint found at 'True'
8 | Starting training loop for rank: 0, total batch size: 256
9 | => no checkpoint found at 'True'
10 | Starting training loop for rank: 0, total batch size: 256
11 | => no checkpoint found at 'True'
12 | Starting training loop for rank: 0, total batch size: 256
13 | => no checkpoint found at 'True'
14 | Starting training loop for rank: 0, total batch size: 256
15 | => no checkpoint found at 'True'
16 | Starting training loop for rank: 0, total batch size: 256
17 | => no checkpoint found at 'True'
18 | Starting training loop for rank: 0, total batch size: 256
19 | => no checkpoint found at 'True'
20 | Starting training loop for rank: 0, total batch size: 256
21 | => no checkpoint found at 'True'
22 | Starting training loop for rank: 0, total batch size: 256
23 | => no checkpoint found at 'True'
24 | Starting training loop for rank: 1, total batch size: 256
25 | => no checkpoint found at 'True'
26 | Starting training loop for rank: 0, total batch size: 256
27 | => no checkpoint found at 'True'
28 | => no checkpoint found at 'True'
29 | Starting training loop for rank: 1, total batch size: 256
30 | Starting training loop for rank: 0, total batch size: 256
31 | => loading checkpoint '/l/users/ravindu.nagasinghe/MAIN_codes/final/PDPP/checkpoint/whl/epoch0004.pth.tar'
32 | => loading checkpoint '/l/users/ravindu.nagasinghe/MAIN_codes/final/PDPP/checkpoint/whl/epoch0004.pth.tar'
33 | Starting training loop for rank: 1, total batch size: 256
34 | => loaded checkpoint '/l/users/ravindu.nagasinghe/MAIN_codes/final/PDPP/checkpoint/whl/epoch0004.pth.tar' (epoch 4)0
35 | Starting training loop for rank: 0, total batch size: 256
36 | => loading checkpoint '/l/users/ravindu.nagasinghe/MAIN_codes/final/PDPP/checkpoint/whl/epoch0004.pth.tar'
37 | => loading checkpoint '/l/users/ravindu.nagasinghe/MAIN_codes/final/PDPP/checkpoint/whl/epoch0004.pth.tar'
38 | Starting training loop for rank: 1, total batch size: 256
39 | => loaded checkpoint '/l/users/ravindu.nagasinghe/MAIN_codes/final/PDPP/checkpoint/whl/epoch0004.pth.tar' (epoch 4)0
40 | Starting training loop for rank: 0, total batch size: 256
41 | => loading checkpoint '/l/users/ravindu.nagasinghe/MAIN_codes/final/PDPP/checkpoint/whl/epoch0004.pth.tar'
42 | => loading checkpoint '/l/users/ravindu.nagasinghe/MAIN_codes/final/PDPP/checkpoint/whl/epoch0004.pth.tar'
43 | Starting training loop for rank: 1, total batch size: 256
44 | => loaded checkpoint '/l/users/ravindu.nagasinghe/MAIN_codes/final/PDPP/checkpoint/whl/epoch0004.pth.tar' (epoch 4)0
45 | Starting training loop for rank: 0, total batch size: 256
46 | => loading checkpoint '/l/users/ravindu.nagasinghe/MAIN_codes/final/PDPP/checkpoint/whl/epoch0005.pth.tar'
47 | => loading checkpoint '/l/users/ravindu.nagasinghe/MAIN_codes/final/PDPP/checkpoint/whl/epoch0005.pth.tar'
48 | Starting training loop for rank: 1, total batch size: 256
49 | => loaded checkpoint '/l/users/ravindu.nagasinghe/MAIN_codes/final/PDPP/checkpoint/whl/epoch0005.pth.tar' (epoch 5)0
50 | Starting training loop for rank: 0, total batch size: 256
51 | => no checkpoint found at 'True'
52 | Starting training loop for rank: 0, total batch size: 256
53 | => no checkpoint found at 'True'
54 | Starting training loop for rank: 0, total batch size: 256
55 | => no checkpoint found at 'True'
56 | Starting training loop for rank: 1, total batch size: 256
57 | => no checkpoint found at 'True'
58 | Starting training loop for rank: 0, total batch size: 256
59 | => no checkpoint found at 'True'
60 | => no checkpoint found at 'True'
61 | => no checkpoint found at 'True'
62 | Starting training loop for rank: 2, total batch size: 256
63 | Starting training loop for rank: 1, total batch size: 256
64 | Starting training loop for rank: 3, total batch size: 256
65 | => no checkpoint found at 'True'
66 | Starting training loop for rank: 0, total batch size: 256
67 | => no checkpoint found at 'True'
68 | Starting training loop for rank: 1, total batch size: 256
69 | => no checkpoint found at 'True'
70 | Starting training loop for rank: 0, total batch size: 256
71 | => no checkpoint found at 'True'
72 | Starting training loop for rank: 1, total batch size: 256
73 | => no checkpoint found at 'True'
74 | Starting training loop for rank: 0, total batch size: 256
75 | => no checkpoint found at 'True'
76 | Starting training loop for rank: 1, total batch size: 256
77 | => no checkpoint found at 'True'
78 | Starting training loop for rank: 0, total batch size: 256
79 | => no checkpoint found at 'True'
80 | Starting training loop for rank: 1, total batch size: 256
81 | => no checkpoint found at 'True'
82 | Starting training loop for rank: 0, total batch size: 256
83 | => no checkpoint found at 'True'
84 | Starting training loop for rank: 1, total batch size: 256
85 | => no checkpoint found at 'True'
86 | Starting training loop for rank: 0, total batch size: 256
87 | => no checkpoint found at 'True'
88 | Starting training loop for rank: 1, total batch size: 256
89 | => no checkpoint found at 'True'
90 | Starting training loop for rank: 0, total batch size: 256
91 | => no checkpoint found at 'True'
92 | Starting training loop for rank: 1, total batch size: 256
93 | => no checkpoint found at 'True'
94 | Starting training loop for rank: 0, total batch size: 256
95 | => no checkpoint found at 'True'
96 | Starting training loop for rank: 1, total batch size: 256
97 | => no checkpoint found at 'True'
98 | Starting training loop for rank: 0, total batch size: 256
99 | => no checkpoint found at 'True'
100 | Starting training loop for rank: 1, total batch size: 256
101 | => no checkpoint found at 'True'
102 | Starting training loop for rank: 0, total batch size: 256
103 | => no checkpoint found at 'True'
104 | Starting training loop for rank: 1, total batch size: 256
105 | => no checkpoint found at 'True'
106 | Starting training loop for rank: 0, total batch size: 256
107 | => no checkpoint found at 'True'
108 | Starting training loop for rank: 1, total batch size: 256
109 | => no checkpoint found at 'True'
110 | Starting training loop for rank: 0, total batch size: 256
111 | => no checkpoint found at 'True'
112 | Starting training loop for rank: 1, total batch size: 256
113 | => no checkpoint found at 'True'
114 | Starting training loop for rank: 0, total batch size: 256
115 | => no checkpoint found at 'True'
116 | Starting training loop for rank: 1, total batch size: 256
117 | => no checkpoint found at 'True'
118 | Starting training loop for rank: 0, total batch size: 256
119 | => no checkpoint found at 'True'
120 | Starting training loop for rank: 1, total batch size: 256
121 | => no checkpoint found at 'True'
122 | Starting training loop for rank: 0, total batch size: 256
123 | => no checkpoint found at 'True'
124 | Starting training loop for rank: 1, total batch size: 256
125 | => no checkpoint found at 'True'
126 | Starting training loop for rank: 0, total batch size: 256
127 | => no checkpoint found at 'True'
128 | Starting training loop for rank: 1, total batch size: 256
129 | => no checkpoint found at 'True'
130 | Starting training loop for rank: 0, total batch size: 256
131 | => no checkpoint found at 'True'
132 | Starting training loop for rank: 1, total batch size: 256
133 | => no checkpoint found at 'True'
134 | Starting training loop for rank: 0, total batch size: 256
135 | => no checkpoint found at 'True'
136 | Starting training loop for rank: 1, total batch size: 256
137 | => no checkpoint found at 'True'
138 | Starting training loop for rank: 0, total batch size: 256
139 | => no checkpoint found at 'True'
140 | Starting training loop for rank: 1, total batch size: 256
141 | => no checkpoint found at 'True'
142 | Starting training loop for rank: 0, total batch size: 256
143 | => loading checkpoint '/l/users/ravindu.nagasinghe/MAIN_codes/multiple_plans/PDPP/checkpoint/whl/epoch0004.pth.tar'
144 | => loading checkpoint '/l/users/ravindu.nagasinghe/MAIN_codes/multiple_plans/PDPP/checkpoint/whl/epoch0004.pth.tar'
145 | Starting training loop for rank: 1, total batch size: 256
146 | => loaded checkpoint '/l/users/ravindu.nagasinghe/MAIN_codes/multiple_plans/PDPP/checkpoint/whl/epoch0004.pth.tar' (epoch 4)0
147 | Starting training loop for rank: 0, total batch size: 256
148 | => no checkpoint found at 'True'
149 | Starting training loop for rank: 1, total batch size: 256
150 | => no checkpoint found at 'True'
151 | Starting training loop for rank: 0, total batch size: 256
152 | => loading checkpoint '/l/users/ravindu.nagasinghe/MAIN_codes/multiple_plans/PDPP/checkpoint/whl/epoch0004.pth.tar'
153 | => loading checkpoint '/l/users/ravindu.nagasinghe/MAIN_codes/multiple_plans/PDPP/checkpoint/whl/epoch0004.pth.tar'
154 | => loaded checkpoint '/l/users/ravindu.nagasinghe/MAIN_codes/multiple_plans/PDPP/checkpoint/whl/epoch0004.pth.tar' (epoch 4)0
155 | Starting training loop for rank: 0, total batch size: 256
156 | Starting training loop for rank: 1, total batch size: 256
157 | => loading checkpoint '/l/users/ravindu.nagasinghe/MAIN_codes/multiple_plans/PDPP/checkpoint/whl/epoch0004.pth.tar'
158 | => loading checkpoint '/l/users/ravindu.nagasinghe/MAIN_codes/multiple_plans/PDPP/checkpoint/whl/epoch0004.pth.tar'
159 | => loaded checkpoint '/l/users/ravindu.nagasinghe/MAIN_codes/multiple_plans/PDPP/checkpoint/whl/epoch0004.pth.tar' (epoch 4)0
160 | Starting training loop for rank: 0, total batch size: 256
161 | Starting training loop for rank: 1, total batch size: 256
162 | => loading checkpoint '/l/users/ravindu.nagasinghe/MAIN_codes/multiple_plans/PDPP/checkpoint/whl/epoch0004.pth.tar'
163 | => loading checkpoint '/l/users/ravindu.nagasinghe/MAIN_codes/multiple_plans/PDPP/checkpoint/whl/epoch0004.pth.tar'
164 | => loaded checkpoint '/l/users/ravindu.nagasinghe/MAIN_codes/multiple_plans/PDPP/checkpoint/whl/epoch0004.pth.tar' (epoch 4)0
165 | Starting training loop for rank: 0, total batch size: 256
166 | Starting training loop for rank: 1, total batch size: 256
167 | => loading checkpoint '/l/users/ravindu.nagasinghe/MAIN_codes/multiple_plans/PDPP/checkpoint/whl/epoch0004.pth.tar'
168 | => loading checkpoint '/l/users/ravindu.nagasinghe/MAIN_codes/multiple_plans/PDPP/checkpoint/whl/epoch0004.pth.tar'
169 | Starting training loop for rank: 1, total batch size: 256
170 | => loaded checkpoint '/l/users/ravindu.nagasinghe/MAIN_codes/multiple_plans/PDPP/checkpoint/whl/epoch0004.pth.tar' (epoch 4)0
171 | Starting training loop for rank: 0, total batch size: 256
172 | => loading checkpoint '/l/users/ravindu.nagasinghe/MAIN_codes/multiple_plans/PDPP/checkpoint/whl/epoch0004.pth.tar'
173 | => loading checkpoint '/l/users/ravindu.nagasinghe/MAIN_codes/multiple_plans/PDPP/checkpoint/whl/epoch0004.pth.tar'
174 | => loaded checkpoint '/l/users/ravindu.nagasinghe/MAIN_codes/multiple_plans/PDPP/checkpoint/whl/epoch0004.pth.tar' (epoch 4)0
175 | Starting training loop for rank: 0, total batch size: 256
176 | Starting training loop for rank: 1, total batch size: 256
177 | => loading checkpoint '/l/users/ravindu.nagasinghe/MAIN_codes/multiple_plans/PDPP/checkpoint/whl/epoch0004.pth.tar'
178 | => loading checkpoint '/l/users/ravindu.nagasinghe/MAIN_codes/multiple_plans/PDPP/checkpoint/whl/epoch0004.pth.tar'
179 | Starting training loop for rank: 1, total batch size: 256
180 | => loaded checkpoint '/l/users/ravindu.nagasinghe/MAIN_codes/multiple_plans/PDPP/checkpoint/whl/epoch0004.pth.tar' (epoch 4)0
181 | Starting training loop for rank: 0, total batch size: 256
182 | => no checkpoint found at 'True'
183 | Starting training loop for rank: 1, total batch size: 256
184 | => no checkpoint found at 'True'
185 | Starting training loop for rank: 0, total batch size: 256
186 | => no checkpoint found at 'True'
187 | Starting training loop for rank: 1, total batch size: 256
188 | => no checkpoint found at 'True'
189 | Starting training loop for rank: 0, total batch size: 256
190 | => no checkpoint found at 'True'
191 | Starting training loop for rank: 1, total batch size: 256
192 | => no checkpoint found at 'True'
193 | Starting training loop for rank: 0, total batch size: 256
194 | => loading checkpoint '/l/users/ravindu.nagasinghe/T4How/multiple_plans/PDPP/checkpoint/whl/epoch0004.pth.tar'
195 | => loading checkpoint '/l/users/ravindu.nagasinghe/T4How/multiple_plans/PDPP/checkpoint/whl/epoch0004.pth.tar'
196 | Starting training loop for rank: 1, total batch size: 256
197 | => loaded checkpoint '/l/users/ravindu.nagasinghe/T4How/multiple_plans/PDPP/checkpoint/whl/epoch0004.pth.tar' (epoch 4)0
198 | Starting training loop for rank: 0, total batch size: 256
199 | => loading checkpoint '/l/users/ravindu.nagasinghe/T4How/multiple_plans/PDPP/checkpoint/whl/epoch0004.pth.tar'
200 | => loading checkpoint '/l/users/ravindu.nagasinghe/T4How/multiple_plans/PDPP/checkpoint/whl/epoch0004.pth.tar'
201 | Starting training loop for rank: 1, total batch size: 256
202 | => loaded checkpoint '/l/users/ravindu.nagasinghe/T4How/multiple_plans/PDPP/checkpoint/whl/epoch0004.pth.tar' (epoch 4)0
203 | Starting training loop for rank: 0, total batch size: 256
204 | => loading checkpoint '/l/users/ravindu.nagasinghe/T4How/multiple_plans/PDPP/checkpoint/whl/epoch0004.pth.tar'
205 | => loading checkpoint '/l/users/ravindu.nagasinghe/T4How/multiple_plans/PDPP/checkpoint/whl/epoch0004.pth.tar'
206 | => loaded checkpoint '/l/users/ravindu.nagasinghe/T4How/multiple_plans/PDPP/checkpoint/whl/epoch0004.pth.tar' (epoch 4)0
207 | Starting training loop for rank: 0, total batch size: 256
208 | Starting training loop for rank: 1, total batch size: 256
209 | => loading checkpoint '/l/users/ravindu.nagasinghe/T4How/multiple_plans/PDPP/checkpoint/whl/epoch0004.pth.tar'
210 | => loading checkpoint '/l/users/ravindu.nagasinghe/T4How/multiple_plans/PDPP/checkpoint/whl/epoch0004.pth.tar'
211 | => loaded checkpoint '/l/users/ravindu.nagasinghe/T4How/multiple_plans/PDPP/checkpoint/whl/epoch0004.pth.tar' (epoch 4)0
212 | Starting training loop for rank: 0, total batch size: 256
213 | Starting training loop for rank: 1, total batch size: 256
214 | => no checkpoint found at 'True'
215 | Starting training loop for rank: 1, total batch size: 256
216 | => no checkpoint found at 'True'
217 | Starting training loop for rank: 0, total batch size: 256
218 |
--------------------------------------------------------------------------------
/plan/main_distributed.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import os
3 | import random
4 | import time
5 | from collections import OrderedDict
6 |
7 | import torch.nn.parallel
8 | import torch.backends.cudnn as cudnn
9 | import torch.distributed as dist
10 | import torch.optim
11 | import torch.multiprocessing as mp
12 | import torch.utils.data
13 | import torch.utils.data.distributed
14 | from torch.distributed import ReduceOp
15 |
16 | import utils
17 | from dataloader.data_load import PlanningDataset
18 | from model import diffusion, temporal
19 | from model.helpers import get_lr_schedule_with_warmup
20 |
21 | from utils import *
22 | from logging import log
23 | from utils.args import get_args
24 | import numpy as np
25 | from model.helpers import Logger
26 |
27 |
28 | def reduce_tensor(tensor):
29 | rt = tensor.clone()
30 | torch.distributed.all_reduce(rt, op=ReduceOp.SUM)
31 | rt /= dist.get_world_size()
32 | return rt
33 |
34 |
35 | def main():
36 | # os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
37 | args = get_args()
38 | os.environ['PYTHONHASHSEED'] = str(args.seed)
39 |
40 | if args.verbose:
41 | print(args)
42 | if args.seed is not None:
43 | random.seed(args.seed)
44 | np.random.seed(args.seed)
45 | torch.manual_seed(args.seed)
46 | torch.cuda.manual_seed_all(args.seed)
47 |
48 | args.distributed = args.world_size > 1 or args.multiprocessing_distributed
49 | ngpus_per_node = torch.cuda.device_count()
50 |
51 | if args.multiprocessing_distributed:
52 | args.world_size = ngpus_per_node * args.world_size
53 | mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
54 | else:
55 | main_worker(args.gpu, ngpus_per_node, args)
56 |
57 |
58 | def main_worker(gpu, ngpus_per_node, args):
59 | args.gpu = gpu
60 | # print('gpuid:', args.gpu)
61 |
62 | if args.distributed:
63 | if args.multiprocessing_distributed:
64 | args.rank = args.rank * ngpus_per_node + gpu
65 | dist.init_process_group(
66 | backend=args.dist_backend,
67 | init_method=args.dist_url,
68 | world_size=args.world_size,
69 | rank=args.rank,
70 | )
71 | if args.gpu is not None:
72 | torch.cuda.set_device(args.gpu)
73 | args.batch_size = int(args.batch_size / ngpus_per_node)
74 | args.batch_size_val = int(args.batch_size_val / ngpus_per_node)
75 | args.num_thread_reader = int(args.num_thread_reader / ngpus_per_node)
76 | elif args.gpu is not None:
77 | torch.cuda.set_device(args.gpu)
78 |
79 | # Data loading code
80 | train_dataset = PlanningDataset(
81 | args.root,
82 | args=args,
83 | is_val=False,
84 | model=None,
85 | )
86 | # Test data loading code
87 | test_dataset = PlanningDataset(
88 | args.root,
89 | args=args,
90 | is_val=True,
91 | model=None,
92 | )
93 |     print(train_dataset, test_dataset)
94 | if args.distributed:
95 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
96 | test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset)
97 | else:
98 | train_sampler = None
99 | test_sampler = None
100 |
101 | train_loader = torch.utils.data.DataLoader(
102 | train_dataset,
103 | batch_size=args.batch_size,
104 | shuffle=(train_sampler is None),
105 | drop_last=True,
106 | num_workers=args.num_thread_reader,
107 | pin_memory=args.pin_memory,
108 | sampler=train_sampler,
109 | )
110 | test_loader = torch.utils.data.DataLoader(
111 | test_dataset,
112 | batch_size=args.batch_size_val,
113 | shuffle=False,
114 | drop_last=False,
115 | num_workers=args.num_thread_reader,
116 | sampler=test_sampler,
117 | )
118 |
119 | # create model
120 | temporal_model = temporal.TemporalUnet( # temporal
121 | args.action_dim + args.observation_dim + args.class_dim_llama + args.class_dim_graph,
122 | dim=256,
123 | dim_mults=(1, 2, 4), )
124 |
125 | diffusion_model = diffusion.GaussianDiffusion(
126 | temporal_model, args.horizon, args.observation_dim, args.action_dim, args.class_dim_llama, args.class_dim_graph, args.n_diffusion_steps,
127 | loss_type='Weighted_MSE', clip_denoised=True, )
128 |
129 |
130 | model = utils.Trainer(diffusion_model, train_loader, args.ema_decay, args.lr, args.gradient_accumulate_every,
131 | args.step_start_ema, args.update_ema_every, args.log_freq)
132 |
133 | if args.pretrain_cnn_path:
134 | net_data = torch.load(args.pretrain_cnn_path)
135 | model.model.load_state_dict(net_data)
136 | model.ema_model.load_state_dict(net_data)
137 |
138 | if args.distributed:
139 | if args.gpu is not None:
140 | model.model.cuda(args.gpu)
141 | model.ema_model.cuda(args.gpu)
142 | model.model = torch.nn.parallel.DistributedDataParallel(
143 | model.model, device_ids=[args.gpu], find_unused_parameters=True)
144 | model.ema_model = torch.nn.parallel.DistributedDataParallel(
145 | model.ema_model, device_ids=[args.gpu], find_unused_parameters=True)
146 | else:
147 | model.model.cuda()
148 | model.ema_model.cuda()
149 | model.model = torch.nn.parallel.DistributedDataParallel(model.model, find_unused_parameters=True)
150 | model.ema_model = torch.nn.parallel.DistributedDataParallel(model.ema_model,
151 | find_unused_parameters=True)
152 |
153 | elif args.gpu is not None:
154 | model.model = model.model.cuda(args.gpu)
155 | model.ema_model = model.ema_model.cuda(args.gpu)
156 | else:
157 | model.model = torch.nn.DataParallel(model.model).cuda()
158 | model.ema_model = torch.nn.DataParallel(model.ema_model).cuda()
159 |
160 | scheduler = get_lr_schedule_with_warmup(model.optimizer, int(args.n_train_steps * args.epochs))
161 |
162 | checkpoint_dir = os.path.join(os.path.dirname(__file__), 'checkpoint', args.checkpoint_dir)
163 | if args.checkpoint_dir != '' and not (os.path.isdir(checkpoint_dir)) and args.rank == 0:
164 | os.mkdir(checkpoint_dir)
165 |
166 | if args.resume:
167 | checkpoint_path = get_last_checkpoint(checkpoint_dir)
168 | if checkpoint_path:
169 | log("=> loading checkpoint '{}'".format(checkpoint_path), args)
170 | checkpoint = torch.load(checkpoint_path, map_location='cuda:{}'.format(args.rank))
171 | args.start_epoch = checkpoint["epoch"]
172 | model.model.load_state_dict(checkpoint["model"])
173 | model.ema_model.load_state_dict(checkpoint["ema"])
174 | model.optimizer.load_state_dict(checkpoint["optimizer"])
175 | model.step = checkpoint["step"]
176 | # for p in model.optimizer.param_groups:
177 | # p['lr'] = 1e-5
178 | scheduler.load_state_dict(checkpoint["scheduler"])
179 | tb_logdir = checkpoint["tb_logdir"]
180 | if args.rank == 0:
180 |                 # create logger
182 | tb_logger = Logger(tb_logdir)
183 | log("=> loaded checkpoint '{}' (epoch {}){}".format(checkpoint_path, checkpoint["epoch"], args.gpu), args)
184 |
185 |
186 | else:
187 |
188 | time_pre = time.strftime("%Y%m%d%H%M%S", time.localtime())
189 | logname = args.log_root + '_' + time_pre + '_' + args.dataset
190 | tb_logdir = os.path.join(args.log_root, logname)
191 | if args.rank == 0:
192 |             # create logger
193 | if not (os.path.exists(tb_logdir)):
194 | os.makedirs(tb_logdir)
195 | tb_logger = Logger(tb_logdir)
196 | tb_logger.log_info(args)
197 | log("=> no checkpoint found at '{}'".format(args.resume), args)
198 |
199 | if args.cudnn_benchmark:
200 | cudnn.benchmark = True
201 | total_batch_size = args.world_size * args.batch_size
202 |
203 | log(
204 | "Starting training loop for rank: {}, total batch size: {}".format(
205 | args.rank, total_batch_size
206 | ), args
207 | )
208 |
209 | max_eva = 0
210 | max_acc = 0
211 | old_max_epoch = 0
212 | save_max = os.path.join(os.path.dirname(__file__), 'save_max')
213 |
214 | for epoch in range(args.start_epoch, args.epochs):
215 | print('epoch : ', epoch)
216 | if args.distributed:
217 | train_sampler.set_epoch(epoch)
218 |
219 | # train for one epoch
220 | if (epoch + 1) % 10 == 0: # calculate on training set
221 | losses, acc_top1, acc_top5, trajectory_success_rate_meter, MIoU1_meter, MIoU2_meter, \
222 | acc_a0, acc_aT = model.train(args.n_train_steps, True, args, scheduler)
223 | losses_reduced = reduce_tensor(losses.cuda()).item()
224 | acc_top1_reduced = reduce_tensor(acc_top1.cuda()).item()
225 | acc_top5_reduced = reduce_tensor(acc_top5.cuda()).item()
226 | trajectory_success_rate_meter_reduced = reduce_tensor(trajectory_success_rate_meter.cuda()).item()
227 | MIoU1_meter_reduced = reduce_tensor(MIoU1_meter.cuda()).item()
228 | MIoU2_meter_reduced = reduce_tensor(MIoU2_meter.cuda()).item()
229 | acc_a0_reduced = reduce_tensor(acc_a0.cuda()).item()
230 | acc_aT_reduced = reduce_tensor(acc_aT.cuda()).item()
231 |
232 | if args.rank == 0:
233 | logs = OrderedDict()
234 | logs['Train/EpochLoss'] = losses_reduced
235 | logs['Train/EpochAcc@1'] = acc_top1_reduced
236 | logs['Train/EpochAcc@5'] = acc_top5_reduced
237 | logs['Train/Traj_Success_Rate'] = trajectory_success_rate_meter_reduced
238 | logs['Train/MIoU1'] = MIoU1_meter_reduced
239 | logs['Train/MIoU2'] = MIoU2_meter_reduced
240 | logs['Train/acc_a0'] = acc_a0_reduced
241 | logs['Train/acc_aT'] = acc_aT_reduced
242 | for key, value in logs.items():
243 | tb_logger.log_scalar(value, key, epoch + 1)
244 |
245 | tb_logger.flush()
246 | else:
247 | losses = model.train(args.n_train_steps, False, args, scheduler).cuda()
248 | losses_reduced = reduce_tensor(losses).item()
249 | if args.rank == 0:
250 | print('lrs:')
251 | for p in model.optimizer.param_groups:
252 | print(p['lr'])
253 | print('---------------------------------')
254 |
255 | logs = OrderedDict()
256 | logs['Train/EpochLoss'] = losses_reduced
257 | for key, value in logs.items():
258 | tb_logger.log_scalar(value, key, epoch + 1)
259 |
260 | tb_logger.flush()
261 |
262 | if ((epoch + 1) % 5 == 0) and args.evaluate: # or epoch > 18
263 | losses, acc_top1, acc_top5, \
264 | trajectory_success_rate_meter, MIoU1_meter, MIoU2_meter, \
265 | acc_a0, acc_aT = validate(test_loader, model.ema_model, args)
266 |
267 | losses_reduced = reduce_tensor(losses.cuda()).item()
268 | acc_top1_reduced = reduce_tensor(acc_top1.cuda()).item()
269 | acc_top5_reduced = reduce_tensor(acc_top5.cuda()).item()
270 | trajectory_success_rate_meter_reduced = reduce_tensor(trajectory_success_rate_meter.cuda()).item()
271 | MIoU1_meter_reduced = reduce_tensor(MIoU1_meter.cuda()).item()
272 | MIoU2_meter_reduced = reduce_tensor(MIoU2_meter.cuda()).item()
273 | acc_a0_reduced = reduce_tensor(acc_a0.cuda()).item()
274 | acc_aT_reduced = reduce_tensor(acc_aT.cuda()).item()
275 |
276 | if args.rank == 0:
277 | logs = OrderedDict()
278 | logs['Val/EpochLoss'] = losses_reduced
279 | logs['Val/EpochAcc@1'] = acc_top1_reduced
280 | logs['Val/EpochAcc@5'] = acc_top5_reduced
281 | logs['Val/Traj_Success_Rate'] = trajectory_success_rate_meter_reduced
282 | logs['Val/MIoU1'] = MIoU1_meter_reduced
283 | logs['Val/MIoU2'] = MIoU2_meter_reduced
284 | logs['Val/acc_a0'] = acc_a0_reduced
285 | logs['Val/acc_aT'] = acc_aT_reduced
286 | for key, value in logs.items():
287 | tb_logger.log_scalar(value, key, epoch + 1)
288 |
289 | tb_logger.flush()
290 | print(trajectory_success_rate_meter_reduced, max_eva)
291 | if trajectory_success_rate_meter_reduced >= max_eva:
292 | if not (trajectory_success_rate_meter_reduced == max_eva and acc_top1_reduced < max_acc):
293 | save_checkpoint2(
294 | {
295 | "epoch": epoch + 1,
296 | "model": model.model.state_dict(),
297 | "ema": model.ema_model.state_dict(),
298 | "optimizer": model.optimizer.state_dict(),
299 | "step": model.step,
300 | "tb_logdir": tb_logdir,
301 | "scheduler": scheduler.state_dict(),
302 | }, save_max, old_max_epoch, epoch + 1, args.rank
303 | )
304 | max_eva = trajectory_success_rate_meter_reduced
305 | max_acc = acc_top1_reduced
306 | old_max_epoch = epoch + 1
307 |
308 | if (epoch + 1) % args.save_freq == 0:
309 | if args.rank == 0:
310 | save_checkpoint(
311 | {
312 | "epoch": epoch + 1,
313 | "model": model.model.state_dict(),
314 | "ema": model.ema_model.state_dict(),
315 | "optimizer": model.optimizer.state_dict(),
316 | "step": model.step,
317 | "tb_logdir": tb_logdir,
318 | "scheduler": scheduler.state_dict(),
319 | }, checkpoint_dir, epoch + 1
320 | )
321 |
322 |
323 | def log(output, args):
324 | with open(os.path.join(os.path.dirname(__file__), 'log', args.checkpoint_dir + '.txt'), "a") as f:
325 | f.write(output + '\n')
326 |
327 |
328 | def save_checkpoint(state, checkpoint_dir, epoch, n_ckpt=3):
329 | torch.save(state, os.path.join(checkpoint_dir, "epoch{:0>4d}.pth.tar".format(epoch)))
330 | if epoch - n_ckpt >= 0:
331 | oldest_ckpt = os.path.join(checkpoint_dir, "epoch{:0>4d}.pth.tar".format(epoch - n_ckpt))
332 | if os.path.isfile(oldest_ckpt):
333 | os.remove(oldest_ckpt)
334 |
335 |
336 | def save_checkpoint2(state, checkpoint_dir, old_epoch, epoch, rank):
337 | torch.save(state, os.path.join(checkpoint_dir, "epoch{:0>4d}_{}.pth.tar".format(epoch, rank)))
338 | if old_epoch > 0:
339 | oldest_ckpt = os.path.join(checkpoint_dir, "epoch{:0>4d}_{}.pth.tar".format(old_epoch, rank))
340 | if os.path.isfile(oldest_ckpt):
341 | os.remove(oldest_ckpt)
342 |
343 |
344 | def get_last_checkpoint(checkpoint_dir):
345 | all_ckpt = glob.glob(os.path.join(checkpoint_dir, 'epoch*.pth.tar'))
346 | if all_ckpt:
347 | all_ckpt = sorted(all_ckpt)
348 | return all_ckpt[-1]
349 | else:
350 | return ''
351 |
352 |
353 | if __name__ == "__main__":
354 | main()
355 |
--------------------------------------------------------------------------------
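
A small worked example of how the batch sizes in main_worker relate to the "total batch size" lines in log/whl.txt. A single node with two GPUs, args.world_size=1 before spawning, and --batch_size 256 are assumed values for illustration, not settings read from the repo's args:

ngpus_per_node = 2                             # assumed: one node with two GPUs
world_size = 1 * ngpus_per_node                # args.world_size after the mp.spawn branch
batch_size = int(256 / ngpus_per_node)         # per-process batch size used by the DataLoader: 128
total_batch_size = world_size * batch_size     # value reported by log(): 256
print(world_size, batch_size, total_batch_size)
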
/step/dataloader/train_test_data_load.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import torch
4 | from torch.utils.data import Dataset
5 | import json
6 | import math
7 | from collections import namedtuple
8 |
9 | def get_vids_from_json(path):
10 | task_vids = {}
11 | with open(path, 'r') as f:
12 | json_data = json.load(f)
13 |
14 | for i in json_data:
15 | task = i['task']
16 | vid = i['vid']
17 | if task not in task_vids:
18 | task_vids[task] = []
19 | task_vids[task].append(vid)
20 | return task_vids
21 |
22 |
23 | def get_vids(path):
24 | task_vids = {}
25 | with open(path, 'r') as f:
26 | for line in f:
27 | task, vid, url = line.strip().split(',')
28 | if task not in task_vids:
29 | task_vids[task] = []
30 | task_vids[task].append(vid)
31 | return task_vids
32 |
33 |
34 | def read_task_info(path):
35 | titles = {}
36 | urls = {}
37 | n_steps = {}
38 | steps = {}
39 | with open(path, 'r') as f:
40 | idx = f.readline()
41 | while idx != '':
42 | idx = idx.strip()
43 | titles[idx] = f.readline().strip()
44 | urls[idx] = f.readline().strip()
45 | n_steps[idx] = int(f.readline().strip())
46 | steps[idx] = f.readline().strip().split(',')
47 | next(f)
48 | idx = f.readline()
49 | return {'title': titles, 'url': urls, 'n_steps': n_steps, 'steps': steps}
50 |
51 |
52 | class PlanningDataset(Dataset):
53 | def __init__(self,
54 | root,
55 | args=None,
56 | is_val=False,
57 | model=None,
58 | crosstask_use_feature_how=True,
59 | ):
60 | self.is_val = is_val
61 | self.data_root = root
62 | self.args = args
63 | self.max_traj_len = args.horizon
64 | self.vid_names = None
65 | self.frame_cnts = None
66 | self.images = None
67 | self.last_vid = ''
68 | self.crosstask_use_feature_how = crosstask_use_feature_how
69 | if args.dataset == 'crosstask':
70 | """
71 | .
72 | └── crosstask
73 | ├── crosstask_features
74 | └── crosstask_release
75 | ├── tasks_primary.txt
76 | ├── videos.csv or json
77 | └── videos_val.csv or json
78 | """
79 | val_csv_path = os.path.join(
80 | root, 'crosstask_release', 'test_list.json')
81 | video_csv_path = os.path.join(
82 | root, 'crosstask_release', 'train_list.json')
83 |
84 | if crosstask_use_feature_how:
85 | self.features_path = os.path.join(root, 'processed_data')
86 | else:
87 | self.features_path = os.path.join(root, 'crosstask_features')
88 |
89 | self.constraints_path = os.path.join(
90 | root, 'crosstask_release', 'annotations')
91 |
92 | self.action_one_hot = np.load(
93 | os.path.join(root, 'crosstask_release', 'actions_one_hot.npy'),
94 | allow_pickle=True).item()
95 |
96 | if is_val:
97 | cross_task_data_name = args.json_path_val
98 | else:
99 | cross_task_data_name = args.json_path_train
100 |
101 | if os.path.exists(cross_task_data_name):
102 | with open(cross_task_data_name, 'r') as f:
103 | self.json_data = json.load(f)
104 | print('Loaded {}'.format(cross_task_data_name))
105 | else:
106 | file_type = val_csv_path.split('.')[-1]
107 | if file_type == 'json':
108 | all_task_vids = get_vids_from_json(video_csv_path)
109 | val_vids = get_vids_from_json(val_csv_path)
110 | else:
111 | all_task_vids = get_vids(video_csv_path)
112 | val_vids = get_vids(val_csv_path)
113 |
114 | if is_val:
115 | task_vids = val_vids
116 | else:
117 | task_vids = {task: [vid for vid in vids if task not in val_vids or vid not in val_vids[task]] for
118 | task, vids in
119 | all_task_vids.items()}
120 |
121 | primary_info = read_task_info(os.path.join(
122 | root, 'crosstask_release', 'tasks_primary.txt'))
123 |
124 | self.n_steps = primary_info['n_steps']
125 | all_tasks = set(self.n_steps.keys())
126 |
127 | task_vids = {task: vids for task,
128 | vids in task_vids.items() if task in all_tasks}
129 |
130 | all_vids = []
131 | for task, vids in task_vids.items():
132 | all_vids.extend([(task, vid) for vid in vids])
133 | json_data = []
134 | for idx in range(len(all_vids)):
135 | task, vid = all_vids[idx]
136 | if self.crosstask_use_feature_how:
137 | video_path = os.path.join(
138 | self.features_path, str(task) + '_' + str(vid) + '.npy')
139 | else:
140 | video_path = os.path.join(
141 | self.features_path, str(vid) + '.npy')
142 | legal_range = self.process_single(task, vid)
143 | if not legal_range:
144 | continue
145 |
146 | temp_len = len(legal_range)
147 | temp = []
148 | while temp_len < self.max_traj_len:
149 | temp.append(legal_range[0])
150 | temp_len += 1
151 |
152 | temp.extend(legal_range)
153 | legal_range = temp
154 |
155 | for i in range(len(legal_range) - self.max_traj_len + 1):
156 | legal_range_current = legal_range[i:i + self.max_traj_len]
157 | json_data.append({'id': {'vid': vid, 'task': task, 'feature': video_path,
158 | 'legal_range': legal_range_current, },'instruction_len': self.n_steps[task]})
159 |
160 | self.json_data = json_data
161 | with open(cross_task_data_name, 'w') as f:
162 | json.dump(json_data, f)
163 |
164 | elif args.dataset == 'coin':
165 | print(root)
166 | coin_path = os.path.join(root, 'coin', 'full_npy/')
167 | val_csv_path = os.path.join(
168 | root, 'coin', 'coin_test_30.json')
169 | video_csv_path = os.path.join(
170 | root, 'coin', 'coin_train_70.json')
171 |
172 | # coin_data_name = "/data1/wanghanlin/diffusion_planning/jsons_coin/sliding_window_cross_task_data_{}_{}_new_task_id_73.json".format(
173 | # is_val, self.max_traj_len)
174 | if is_val:
175 | coin_data_name = args.json_path_val
176 | else:
177 | coin_data_name = args.json_path_train
178 |
179 | if os.path.exists(coin_data_name):
180 | with open(coin_data_name, 'r') as f:
181 | self.json_data = json.load(f)
182 | print('Loaded {}'.format(coin_data_name))
183 | else:
184 | json_data = []
185 | num = 0
186 | if is_val:
187 | with open(val_csv_path, 'r') as f:
188 | coin_data = json.load(f)
189 | else:
190 | with open(video_csv_path, 'r') as f:
191 | coin_data = json.load(f)
192 | for i in coin_data:
193 | for (k, v) in i.items():
194 | file_name = v['class'] + '_' + str(v['recipe_type']) + '_' + k + '.npy'
195 | file_path = coin_path + file_name
196 | images_ = np.load(file_path, allow_pickle=True)
197 | images = images_['frames_features']
198 | legal_range = []
199 |
200 | last_action = v['annotation'][-1]['segment'][1]
201 | last_action = math.ceil(last_action)
202 | if last_action > len(images):
203 | print(k, last_action, len(images))
204 | num += 1
205 | continue
206 |
207 | for annotation in v['annotation']:
208 | action_label = int(annotation['id']) - 1
209 | start_idx, end_idx = annotation['segment']
210 | start_idx = math.floor(start_idx)
211 | end_idx = math.ceil(end_idx)
212 |
213 | if end_idx < images.shape[0]:
214 | legal_range.append((start_idx, end_idx, action_label))
215 | else:
216 | legal_range.append((start_idx, images.shape[0] - 1, action_label))
217 |
218 | temp_len = len(legal_range)
219 | temp = []
220 | while temp_len < self.max_traj_len:
221 | temp.append(legal_range[0])
222 | temp_len += 1
223 |
224 | temp.extend(legal_range)
225 | legal_range = temp
226 |
227 | for i in range(len(legal_range) - self.max_traj_len + 1):
228 | legal_range_current = legal_range[i:i + self.max_traj_len]
229 | json_data.append({'id': {'vid': k, 'feature': file_path,
230 | 'legal_range': legal_range_current, 'task_id': v['recipe_type']},
231 | 'instruction_len': 0})
232 | print(num)
233 | self.json_data = json_data
234 | with open(coin_data_name, 'w') as f:
235 | json.dump(json_data, f)
236 |
237 | elif args.dataset == 'NIV':
238 | val_csv_path = os.path.join(
239 | root, '../NIV', 'test30.json')
240 | video_csv_path = os.path.join(
241 | root, '../NIV', 'train70.json')
242 |
243 | if is_val:
244 | niv_data_name = args.json_path_val
245 | else:
246 | niv_data_name = args.json_path_train
247 |
248 |
249 | if os.path.exists(niv_data_name):
250 | with open(niv_data_name, 'r') as f:
251 | self.json_data = json.load(f)
252 | print('Loaded {}'.format(niv_data_name))
253 | else:
254 | json_data = []
255 | if is_val:
256 | with open(val_csv_path, 'r') as f:
257 | niv_data = json.load(f)
258 | else:
259 | with open(video_csv_path, 'r') as f:
260 | niv_data = json.load(f)
261 | for d in niv_data:
262 | legal_range = []
263 | path = d['feature']
264 | info = np.load(path, allow_pickle=True)
265 | num_steps = int(info['num_steps'])
266 | assert num_steps == len(info['steps_ids'])
267 | assert info['num_steps'] == len(info['steps_starts'])
268 | assert info['num_steps'] == len(info['steps_ends'])
269 | starts = info['steps_starts']
270 | ends = info['steps_ends']
271 | action_labels = info['steps_ids']
272 | images = info['frames_features']
273 |
274 | for i in range(num_steps):
275 | action_label = int(action_labels[i])
276 | start_idx = math.floor(float(starts[i]))
277 | end_idx = math.ceil(float(ends[i]))
278 |
279 | if end_idx < images.shape[0]:
280 | legal_range.append((start_idx, end_idx, action_label))
281 | else:
282 | legal_range.append((start_idx, images.shape[0] - 1, action_label))
283 |
284 | temp_len = len(legal_range)
285 | temp = []
286 | while temp_len < self.max_traj_len:
287 | temp.append(legal_range[0])
288 | temp_len += 1
289 |
290 | temp.extend(legal_range)
291 | legal_range = temp
292 |
293 | for i in range(len(legal_range) - self.max_traj_len + 1):
294 | legal_range_current = legal_range[i:i + self.max_traj_len]
295 | json_data.append({'id': {'feature': path, 'legal_range': legal_range_current, 'task_id': d['task_id']}, 'instruction_len': 0})
296 | self.json_data = json_data
297 | with open(niv_data_name, 'w') as f:
298 | json.dump(json_data, f)
299 | print('number of NIV samples:', len(json_data))
300 | else:
301 | raise NotImplementedError(
302 | 'Dataset {} is not implemented'.format(args.dataset))
303 |
304 | self.model = model
305 | self.prepare_data()
306 | self.M = 3
307 |
308 | def process_single(self, task, vid):
309 | if self.crosstask_use_feature_how:
310 | if not os.path.exists(os.path.join(self.features_path, str(task) + '_' + str(vid) + '.npy')):
311 | return False
312 | images_ = np.load(os.path.join(self.features_path, str(task) + '_' + str(vid) + '.npy'), allow_pickle=True)
313 | images = images_['frames_features']
314 | else:
315 | if not os.path.exists(os.path.join(self.features_path, vid + '.npy')):
316 | return False
317 | images = np.load(os.path.join(self.features_path, vid + '.npy'))
318 |
319 | cnst_path = os.path.join(
320 | self.constraints_path, task + '_' + vid + '.csv')
321 | legal_range = self.read_assignment(task, cnst_path)
322 | legal_range_ret = []
323 | for (start_idx, end_idx, action_label) in legal_range:
324 | if not start_idx < images.shape[0]:
325 | print(task, vid, end_idx, images.shape[0])
326 | return False
327 | if end_idx < images.shape[0]:
328 | legal_range_ret.append((start_idx, end_idx, action_label))
329 | else:
330 | legal_range_ret.append((start_idx, images.shape[0] - 1, action_label))
331 |
332 | return legal_range_ret
333 |
334 | def read_assignment(self, task_id, path):
335 | legal_range = []
336 | with open(path, 'r') as f:
337 | for line in f:
338 | step, start, end = line.strip().split(',')
339 | start = int(math.floor(float(start)))
340 | end = int(math.ceil(float(end)))
341 | action_label_ind = self.action_one_hot[task_id + '_' + step]
342 | legal_range.append((start, end, action_label_ind))
343 |
344 | return legal_range
345 |
346 | def prepare_data(self):
347 | vid_names = []
348 | frame_cnts = []
349 | for listdata in self.json_data:
350 | vid_names.append(listdata['id'])
351 | frame_cnts.append(listdata['instruction_len'])
352 | self.vid_names = vid_names
353 | self.frame_cnts = frame_cnts
354 |
355 | def __len__(self):
356 | return len(self.json_data)
357 |
--------------------------------------------------------------------------------
/step/main_distributed.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import os
3 | import random
4 | import time
5 | from collections import OrderedDict
6 |
7 | import torch.nn.parallel
8 | import torch.backends.cudnn as cudnn
9 | import torch.distributed as dist
10 | import torch.optim
11 | import torch.multiprocessing as mp
12 | import torch.utils.data
13 | import torch.utils.data.distributed
14 | from torch.distributed import ReduceOp
15 |
16 | import utils
17 | from dataloader.data_load import PlanningDataset
18 | from model import diffusion, temporal
19 | from model.helpers import get_lr_schedule_with_warmup
20 |
21 | from utils import *
22 | # note: the file-level log() helper defined below is used for logging
23 | from utils.args import get_args
24 | import numpy as np
25 | from model.helpers import Logger
26 |
27 |
28 | def reduce_tensor(tensor):
29 | rt = tensor.clone()
30 | torch.distributed.all_reduce(rt, op=ReduceOp.SUM)
31 | rt /= dist.get_world_size()
32 | return rt
33 |
34 |
35 | def main():
36 | # os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
37 | args = get_args()
38 | os.environ['PYTHONHASHSEED'] = str(args.seed)
39 |
40 | if args.verbose:
41 | print(args)
42 | if args.seed is not None:
43 | random.seed(args.seed)
44 | np.random.seed(args.seed)
45 | torch.manual_seed(args.seed)
46 | torch.cuda.manual_seed_all(args.seed)
47 |
48 | args.distributed = args.world_size > 1 or args.multiprocessing_distributed
49 | ngpus_per_node = torch.cuda.device_count()
50 |
51 | if args.multiprocessing_distributed:
52 | args.world_size = ngpus_per_node * args.world_size
53 | mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
54 | else:
55 | main_worker(args.gpu, ngpus_per_node, args)
56 |
57 |
58 | def main_worker(gpu, ngpus_per_node, args):
59 | args.gpu = gpu
60 | # print('gpuid:', args.gpu)
61 |
62 | if args.distributed:
63 | if args.multiprocessing_distributed:
64 | args.rank = args.rank * ngpus_per_node + gpu
65 | dist.init_process_group(
66 | backend=args.dist_backend,
67 | init_method=args.dist_url,
68 | world_size=args.world_size,
69 | rank=args.rank,
70 | )
71 | if args.gpu is not None:
72 | torch.cuda.set_device(args.gpu)
73 | args.batch_size = int(args.batch_size / ngpus_per_node)
74 | args.batch_size_val = int(args.batch_size_val / ngpus_per_node)
75 | args.num_thread_reader = int(args.num_thread_reader / ngpus_per_node)
76 | elif args.gpu is not None:
77 | torch.cuda.set_device(args.gpu)
78 |
79 | # Data loading code
80 | train_dataset = PlanningDataset(
81 | args.root,
82 | args=args,
83 | is_val=False,
84 | model=None,
85 | )
86 | # Test data loading code
87 | test_dataset = PlanningDataset(
88 | args.root,
89 | args=args,
90 | is_val=True,
91 | model=None,
92 | )
93 | if args.distributed:
94 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
95 | test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset)
96 | else:
97 | train_sampler = None
98 | test_sampler = None
99 |
100 | train_loader = torch.utils.data.DataLoader(
101 | train_dataset,
102 | batch_size=args.batch_size,
103 | shuffle=(train_sampler is None),
104 | drop_last=True,
105 | num_workers=args.num_thread_reader,
106 | pin_memory=args.pin_memory,
107 | sampler=train_sampler,
108 | )
109 | test_loader = torch.utils.data.DataLoader(
110 | test_dataset,
111 | batch_size=args.batch_size_val,
112 | shuffle=False,
113 | drop_last=False,
114 | num_workers=args.num_thread_reader,
115 | sampler=test_sampler,
116 | )
117 |
118 | # create model
119 | temporal_model = temporal.TemporalUnet( # temporal
120 | args.action_dim + args.observation_dim,
121 | dim=256,
122 | dim_mults=(1, 2, 4), )
123 |
124 | # for param in temporal_model.named_parameters():
125 | # if 'time_mlp' not in param[0]:
126 | # param[1].requires_grad = False
127 |
128 | diffusion_model = diffusion.GaussianDiffusion(
129 | temporal_model, args.horizon, args.observation_dim, args.action_dim, args.n_diffusion_steps,
130 | loss_type='Weighted_MSE', clip_denoised=True, )
131 |
132 | model = utils.Trainer(diffusion_model, train_loader, args.ema_decay, args.lr, args.gradient_accumulate_every,
133 | args.step_start_ema, args.update_ema_every, args.log_freq)
134 |
135 | if args.pretrain_cnn_path:
136 | net_data = torch.load(args.pretrain_cnn_path)
137 | model.model.load_state_dict(net_data)
138 | model.ema_model.load_state_dict(net_data)
139 | if args.distributed:
140 | if args.gpu is not None:
141 | model.model.cuda(args.gpu)
142 | model.ema_model.cuda(args.gpu)
143 | model.model = torch.nn.parallel.DistributedDataParallel(
144 | model.model, device_ids=[args.gpu], find_unused_parameters=True)
145 | model.ema_model = torch.nn.parallel.DistributedDataParallel(
146 | model.ema_model, device_ids=[args.gpu], find_unused_parameters=True)
147 | else:
148 | model.model.cuda()
149 | model.ema_model.cuda()
150 | model.model = torch.nn.parallel.DistributedDataParallel(model.model, find_unused_parameters=True)
151 | model.ema_model = torch.nn.parallel.DistributedDataParallel(model.ema_model,
152 | find_unused_parameters=True)
153 |
154 | elif args.gpu is not None:
155 | model.model = model.model.cuda(args.gpu)
156 | model.ema_model = model.ema_model.cuda(args.gpu)
157 | else:
158 | model.model = torch.nn.DataParallel(model.model).cuda()
159 | model.ema_model = torch.nn.DataParallel(model.ema_model).cuda()
160 |
161 | scheduler = get_lr_schedule_with_warmup(model.optimizer, int(args.n_train_steps * args.epochs))
162 |
163 | checkpoint_dir = os.path.join(os.path.dirname(__file__), 'checkpoint', args.checkpoint_dir)
164 | if args.checkpoint_dir != '' and not (os.path.isdir(checkpoint_dir)) and args.rank == 0:
165 | os.mkdir(checkpoint_dir)
166 |
167 | if args.resume:
168 | checkpoint_path = get_last_checkpoint(checkpoint_dir)
169 | if checkpoint_path:
170 | log("=> loading checkpoint '{}'".format(checkpoint_path), args)
171 | checkpoint = torch.load(checkpoint_path, map_location='cuda:{}'.format(args.rank))
172 | args.start_epoch = checkpoint["epoch"]
173 | model.model.load_state_dict(checkpoint["model"])
174 | model.ema_model.load_state_dict(checkpoint["ema"])
175 | model.optimizer.load_state_dict(checkpoint["optimizer"])
176 | model.step = checkpoint["step"]
177 | # for p in model.optimizer.param_groups:
178 | # p['lr'] = 1e-5
179 | scheduler.load_state_dict(checkpoint["scheduler"])
180 | tb_logdir = checkpoint["tb_logdir"]
181 | if args.rank == 0:
182 | # create logger
183 | tb_logger = Logger(tb_logdir)
184 | log("=> loaded checkpoint '{}' (epoch {}){}".format(checkpoint_path, checkpoint["epoch"], args.gpu), args)
185 | else:
186 | # log("=> loading checkpoint '{}' to initialize".format(checkpoint_path), args)
187 | # checkpoint = torch.load(checkpoint_path, map_location='cuda:{}'.format(args.rank))
188 | # model.model.load_state_dict(checkpoint["model"], strict=False)
189 | # model.ema_model.load_state_dict(checkpoint["ema"], strict=False)
190 | time_pre = time.strftime("%Y%m%d%H%M%S", time.localtime())
191 | logname = args.log_root + '_' + time_pre + '_' + args.dataset
192 | tb_logdir = os.path.join(args.log_root, logname)
193 | if args.rank == 0:
194 | # create logger
195 | if not (os.path.exists(tb_logdir)):
196 | os.makedirs(tb_logdir)
197 | tb_logger = Logger(tb_logdir)
198 | tb_logger.log_info(args)
199 | log("=> no checkpoint found at '{}'".format(args.resume), args)
200 |
201 | if args.cudnn_benchmark:
202 | cudnn.benchmark = True
203 | total_batch_size = args.world_size * args.batch_size
204 | log(
205 | "Starting training loop for rank: {}, total batch size: {}".format(
206 | args.rank, total_batch_size
207 | ), args
208 | )
209 |
210 | max_eva = 0
211 | max_acc = 0
212 | old_max_epoch = 0
213 | save_max = os.path.join(os.path.dirname(__file__), 'save_max')
214 |
215 | for epoch in range(args.start_epoch, args.epochs):
216 | print('epoch : ', epoch)
217 | if args.distributed:
218 | train_sampler.set_epoch(epoch)
219 |
220 | # train for one epoch
221 | if (epoch + 1) % 10 == 0: # calculate on training set
222 | losses, acc_top1, acc_top5, trajectory_success_rate_meter, MIoU1_meter, MIoU2_meter, \
223 | acc_a0, acc_aT = model.train(args.n_train_steps, True, args, scheduler)
224 | losses_reduced = reduce_tensor(losses.cuda()).item()
225 | acc_top1_reduced = reduce_tensor(acc_top1.cuda()).item()
226 | acc_top5_reduced = reduce_tensor(acc_top5.cuda()).item()
227 | trajectory_success_rate_meter_reduced = reduce_tensor(trajectory_success_rate_meter.cuda()).item()
228 | MIoU1_meter_reduced = reduce_tensor(MIoU1_meter.cuda()).item()
229 | MIoU2_meter_reduced = reduce_tensor(MIoU2_meter.cuda()).item()
230 | acc_a0_reduced = reduce_tensor(acc_a0.cuda()).item()
231 | acc_aT_reduced = reduce_tensor(acc_aT.cuda()).item()
232 |
233 | if args.rank == 0:
234 | logs = OrderedDict()
235 | logs['Train/EpochLoss'] = losses_reduced
236 | logs['Train/EpochAcc@1'] = acc_top1_reduced
237 | logs['Train/EpochAcc@5'] = acc_top5_reduced
238 | logs['Train/Traj_Success_Rate'] = trajectory_success_rate_meter_reduced
239 | logs['Train/MIoU1'] = MIoU1_meter_reduced
240 | logs['Train/MIoU2'] = MIoU2_meter_reduced
241 | logs['Train/acc_a0'] = acc_a0_reduced
242 | logs['Train/acc_aT'] = acc_aT_reduced
243 | for key, value in logs.items():
244 | tb_logger.log_scalar(value, key, epoch + 1)
245 |
246 | tb_logger.flush()
247 | else:
248 | losses = model.train(args.n_train_steps, False, args, scheduler).cuda()
249 | losses_reduced = reduce_tensor(losses).item()
250 | if args.rank == 0:
251 | print('lrs:')
252 | for p in model.optimizer.param_groups:
253 | print(p['lr'])
254 | print('---------------------------------')
255 |
256 | logs = OrderedDict()
257 | logs['Train/EpochLoss'] = losses_reduced
258 | for key, value in logs.items():
259 | tb_logger.log_scalar(value, key, epoch + 1)
260 |
261 | tb_logger.flush()
262 |
263 | if ((epoch + 1) % 5 == 0) and args.evaluate: # or epoch > 18
264 | losses, acc_top1, acc_top5, \
265 | trajectory_success_rate_meter, MIoU1_meter, MIoU2_meter, \
266 | acc_a0, acc_aT = validate(test_loader, model.ema_model, args)
267 |
268 | losses_reduced = reduce_tensor(losses.cuda()).item()
269 | acc_top1_reduced = reduce_tensor(acc_top1.cuda()).item()
270 | acc_top5_reduced = reduce_tensor(acc_top5.cuda()).item()
271 | trajectory_success_rate_meter_reduced = reduce_tensor(trajectory_success_rate_meter.cuda()).item()
272 | MIoU1_meter_reduced = reduce_tensor(MIoU1_meter.cuda()).item()
273 | MIoU2_meter_reduced = reduce_tensor(MIoU2_meter.cuda()).item()
274 | acc_a0_reduced = reduce_tensor(acc_a0.cuda()).item()
275 | acc_aT_reduced = reduce_tensor(acc_aT.cuda()).item()
276 |
277 | if args.rank == 0:
278 | logs = OrderedDict()
279 | logs['Val/EpochLoss'] = losses_reduced
280 | logs['Val/EpochAcc@1'] = acc_top1_reduced
281 | logs['Val/EpochAcc@5'] = acc_top5_reduced
282 | logs['Val/Traj_Success_Rate'] = trajectory_success_rate_meter_reduced
283 | logs['Val/MIoU1'] = MIoU1_meter_reduced
284 | logs['Val/MIoU2'] = MIoU2_meter_reduced
285 | logs['Val/acc_a0'] = acc_a0_reduced
286 | logs['Val/acc_aT'] = acc_aT_reduced
287 | for key, value in logs.items():
288 | tb_logger.log_scalar(value, key, epoch + 1)
289 |
290 | tb_logger.flush()
291 | print('val success rate:', trajectory_success_rate_meter_reduced, 'best so far:', max_eva)
292 | if trajectory_success_rate_meter_reduced >= max_eva:
293 | if not (trajectory_success_rate_meter_reduced == max_eva and acc_top1_reduced < max_acc):
294 | save_checkpoint2(
295 | {
296 | "epoch": epoch + 1,
297 | "model": model.model.state_dict(),
298 | "ema": model.ema_model.state_dict(),
299 | "optimizer": model.optimizer.state_dict(),
300 | "step": model.step,
301 | "tb_logdir": tb_logdir,
302 | "scheduler": scheduler.state_dict(),
303 | }, save_max, old_max_epoch, epoch + 1, args.rank
304 | )
305 | max_eva = trajectory_success_rate_meter_reduced
306 | max_acc = acc_top1_reduced
307 | old_max_epoch = epoch + 1
308 |
309 | if (epoch + 1) % args.save_freq == 0:
310 | if args.rank == 0:
311 | save_checkpoint(
312 | {
313 | "epoch": epoch + 1,
314 | "model": model.model.state_dict(),
315 | "ema": model.ema_model.state_dict(),
316 | "optimizer": model.optimizer.state_dict(),
317 | "step": model.step,
318 | "tb_logdir": tb_logdir,
319 | "scheduler": scheduler.state_dict(),
320 | }, checkpoint_dir, epoch + 1
321 | )
322 |
323 |
324 | def log(output, args):
325 | with open(os.path.join(os.path.dirname(__file__), 'log', args.checkpoint_dir + '.txt'), "a") as f:
326 | f.write(output + '\n')
327 |
328 |
329 | def save_checkpoint(state, checkpoint_dir, epoch, n_ckpt=3):
330 | torch.save(state, os.path.join(checkpoint_dir, "epoch{:0>4d}.pth.tar".format(epoch)))
331 | if epoch - n_ckpt >= 0:
332 | oldest_ckpt = os.path.join(checkpoint_dir, "epoch{:0>4d}.pth.tar".format(epoch - n_ckpt))
333 | if os.path.isfile(oldest_ckpt):
334 | os.remove(oldest_ckpt)
335 |
336 |
337 | def save_checkpoint2(state, checkpoint_dir, old_epoch, epoch, rank):
338 | torch.save(state, os.path.join(checkpoint_dir, "epoch{:0>4d}_{}.pth.tar".format(epoch, rank)))
339 | if old_epoch > 0:
340 | oldest_ckpt = os.path.join(checkpoint_dir, "epoch{:0>4d}_{}.pth.tar".format(old_epoch, rank))
341 | if os.path.isfile(oldest_ckpt):
342 | os.remove(oldest_ckpt)
343 |
344 |
345 | def get_last_checkpoint(checkpoint_dir):
346 | all_ckpt = glob.glob(os.path.join(checkpoint_dir, 'epoch*.pth.tar'))
347 | if all_ckpt:
348 | all_ckpt = sorted(all_ckpt)
349 | return all_ckpt[-1]
350 | else:
351 | return ''
352 |
353 |
354 | if __name__ == "__main__":
355 | main()
356 |
--------------------------------------------------------------------------------
/step/inference.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import time
4 | import json
5 | import numpy as np
6 | import torch.nn.parallel
7 | import torch.backends.cudnn as cudnn
8 | import torch.distributed as dist
9 | import torch.optim
10 | import torch.multiprocessing as mp
11 | import torch.utils.data
12 | import torch.utils.data.distributed
13 | import utils
14 | from torch.distributed import ReduceOp
15 | from dataloader.data_load import PlanningDataset
16 | from model import diffusion, temporal
17 | from utils import *
18 | from utils.args import get_args
19 | from action_dictionary import action_dictionary
20 |
21 | def map_numbers_to_values(input_list, mapping_dict):
22 | result = []
23 | for sublist in input_list:
24 | result.append([mapping_dict[num + 1] for num in sublist])  # predicted ids are 0-indexed; dictionary keys start at 1
25 | return result
26 |
27 | def accuracy2(output, target, topk=(1,), max_traj_len=0):
28 | with torch.no_grad():
29 | maxk = max(topk)
30 | batch_size = target.size(0)
31 | _, pred = output.topk(maxk, 1, True, True)
32 | pred = pred.t()
33 | correct = pred.eq(target.view(1, -1).expand_as(pred))
34 |
35 | correct_a = correct[:1].view(-1, max_traj_len)
36 | correct_a0 = correct_a[:, 0].reshape(-1).float().mean().mul_(100.0)
37 | correct_aT = correct_a[:, -1].reshape(-1).float().mean().mul_(100.0)
38 |
39 | res = []
40 | for k in topk:
41 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
42 | res.append(correct_k.mul_(100.0 / batch_size))
43 |
44 | correct_1 = correct[:1]
45 |
46 | # Success Rate
47 | trajectory_success = torch.all(correct_1.view(correct_1.shape[1] // max_traj_len, -1), dim=1)
48 | trajectory_success_rate = trajectory_success.sum() * 100.0 / trajectory_success.shape[0]
49 |
50 | # MIoU
51 | _, pred_token = output.topk(1, 1, True, True)
52 | pred_inst = pred_token.view(correct_1.shape[1], -1)
53 | pred_inst_set = set()
54 | target_inst = target.view(correct_1.shape[1], -1)
55 | target_inst_set = set()
56 | for i in range(pred_inst.shape[0]):
57 | pred_inst_set.add(tuple(pred_inst[i].tolist()))
58 | target_inst_set.add(tuple(target_inst[i].tolist()))
59 | MIoU1 = 100.0 * len(pred_inst_set.intersection(target_inst_set)) / len(pred_inst_set.union(target_inst_set))
60 |
61 | batch_size = batch_size // max_traj_len
62 | pred_inst = pred_token.view(batch_size, -1) # [bs, T]
63 | pred_inst_set = set()
64 | target_inst = target.view(batch_size, -1) # [bs, T]
65 | target_inst_set = set()
66 | MIoU_sum = 0
67 | for i in range(pred_inst.shape[0]):
68 | pred_inst_set.update(pred_inst[i].tolist())
69 | target_inst_set.update(target_inst[i].tolist())
70 | MIoU_current = 100.0 * len(pred_inst_set.intersection(target_inst_set)) / len(
71 | pred_inst_set.union(target_inst_set))
72 | MIoU_sum += MIoU_current
73 | pred_inst_set.clear()
74 | target_inst_set.clear()
75 |
76 | MIoU2 = MIoU_sum / batch_size
77 | return res[0], trajectory_success_rate, MIoU1, MIoU2, correct_a0, correct_aT
78 |
79 |
80 | def test(vid_names, val_loader, model, args):
81 | model.eval()
82 | acc_top1 = AverageMeter()
83 | trajectory_success_rate_meter = AverageMeter()
84 | MIoU1_meter = AverageMeter()
85 | MIoU2_meter = AverageMeter()
86 | A0_acc = AverageMeter()
87 | AT_acc = AverageMeter()
88 | index_list = []
89 | action_list = []
90 | differing_sublists = {}
91 | all_sublists = {}
92 | final_pred_list = []
93 | file_final_list_test = args.steps_path
94 |
95 |
96 | for i_batch, sample_batch in enumerate(val_loader):
97 |
98 | global_img_tensors = sample_batch[0].cuda().contiguous()
99 |
100 | video_label = sample_batch[1].cuda()
101 | batch_size_current, T = video_label.size()
102 | print('batch size', batch_size_current)
103 | print('T =', T)
104 |
105 | cond = {}
106 |
107 | with torch.no_grad():
108 | cond[0] = global_img_tensors[:, 0, :].float()
109 | cond[T - 1] = global_img_tensors[:, -1, :].float()
110 |
111 | video_label_reshaped = video_label.view(-1)
112 | print('video label reshaped:', video_label_reshaped.shape)
113 | output = model(cond, if_jump=True)
114 |
115 | actions_pred = output.contiguous()
116 | actions_pred = actions_pred[:, :, :args.action_dim].contiguous()
117 | print('shape actions pred:', actions_pred.shape)  # e.g. [256, 3, 105] = [batch, horizon, action_dim]
118 | argmax_index = torch.argmax(actions_pred, dim=-1)
119 | index_list.append(argmax_index)
120 | print('index list length:', len(index_list))
121 | print('argmax actions pred shape:', argmax_index.shape)
122 |
123 |
124 | tensor_list_action_sequence = argmax_index.tolist()
125 |
126 | output_list_action_sequence = map_numbers_to_values(tensor_list_action_sequence, action_dictionary)
127 |
128 | print('length action sequence:', len(output_list_action_sequence))
129 |
130 | index_differing = i_batch
131 |
132 | for i in range(len(video_label)):
133 | bs = 256  # must match the test_loader batch size (args.batch_size_val)
134 | index_vid = (index_differing*bs)+i
135 | print('index:', index_vid+1, 'GT sequence: ', vid_names[index_vid]['legal_range'],'predicted sequence : ', argmax_index[i].tolist())
136 | final_pred_list.append(argmax_index[i].tolist())
137 | if not (torch.equal(video_label[i][:1], argmax_index[i][:1]) and torch.equal(video_label[i][-1:], argmax_index[i][-1:])):
138 |
139 | differing_sublists[index_vid+1] = {
140 | 'i_batch': i_batch,
141 | 'video_label_list': video_label[i].tolist(),
142 | 'predicted_list': argmax_index[i].tolist(),
143 | 'predicted_sequence': output_list_action_sequence[i],
144 | 'vid': vid_names[index_vid]['vid'],
145 | 'legal_range':vid_names[index_vid]['legal_range']
146 | }
147 |
148 | for i in range(len(video_label)):
149 | bs = 256  # must match the test_loader batch size (args.batch_size_val)
150 | index_vid_sim = (index_differing*bs)+i
151 | all_sublists[index_vid_sim] = {
152 | 'i_batch': i_batch,
153 | 'video_label_list': video_label[i].tolist(),
154 | 'predicted_list': argmax_index[i].tolist(),
155 | 'predicted_sequence': output_list_action_sequence[i],
156 | 'vid': vid_names[index_vid_sim]['vid'],
157 | 'legal_range':vid_names[index_vid_sim]['legal_range'],
158 |
159 | }
160 |
161 | actions_pred = actions_pred.view(-1, args.action_dim)
162 | acc1, trajectory_success_rate, MIoU1, MIoU2, a0_acc, aT_acc = \
163 | accuracy2(actions_pred.cpu(), video_label_reshaped.cpu(), topk=(1,), max_traj_len=args.horizon)
164 |
165 | acc_top1.update(acc1.item(), batch_size_current)
166 | trajectory_success_rate_meter.update(trajectory_success_rate.item(), batch_size_current)
167 | MIoU1_meter.update(MIoU1, batch_size_current)
168 | MIoU2_meter.update(MIoU2, batch_size_current)
169 | A0_acc.update(a0_acc, batch_size_current)
170 | AT_acc.update(aT_acc, batch_size_current)
171 |
172 |
173 | print('Differing.............................................................................................')
174 | for index, sublist_data in differing_sublists.items():
175 | print(f"Index: {index}", '||', f"i_batch: {sublist_data['i_batch']}", '||', f"Video Label List: {sublist_data['video_label_list']}", '||',
176 | f"Predicted List: {sublist_data['predicted_list']}", '||', f"Predicted Sequence: {sublist_data['predicted_sequence']}", '||', f"Vid: {sublist_data['vid']}",
177 | '||', f"Legal range: {sublist_data['legal_range']}")
178 | print()
179 | print('Number of failure cases:', len(differing_sublists))
180 |
181 | print('All lists.............................................................................................')
182 | for index, sublist_data in all_sublists.items():
183 | print(f"Index: {index}", '||', f"i_batch: {sublist_data['i_batch']}", '||', f"Video Label List: {sublist_data['video_label_list']}", '||',
184 | f"Predicted List: {sublist_data['predicted_list']}", '||', f"Predicted Sequence: {sublist_data['predicted_sequence']}", '||', f"Vid: {sublist_data['vid']}"
185 | , '||', f"Legal range: {sublist_data['legal_range']}")
186 | print()
187 |
188 | with open(file_final_list_test, 'w') as ou:
189 | json.dump(final_pred_list, ou)
190 | return torch.tensor(acc_top1.avg), \
191 | torch.tensor(trajectory_success_rate_meter.avg), \
192 | torch.tensor(MIoU1_meter.avg), torch.tensor(MIoU2_meter.avg), \
193 | torch.tensor(A0_acc.avg), torch.tensor(AT_acc.avg)
194 |
195 |
196 | def reduce_tensor(tensor):
197 | rt = tensor.clone()
198 | torch.distributed.all_reduce(rt, op=ReduceOp.SUM)
199 | rt /= dist.get_world_size()
200 | return rt
201 |
202 |
203 | def main():
204 | # os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
205 | args = get_args()
206 | os.environ['PYTHONHASHSEED'] = str(args.seed)
207 |
208 | if args.verbose:
209 | print(args)
210 | if args.seed is not None:
211 | random.seed(args.seed)
212 | np.random.seed(args.seed)
213 | torch.manual_seed(args.seed)
214 | torch.cuda.manual_seed_all(args.seed)
215 |
216 | args.distributed = args.world_size > 1 or args.multiprocessing_distributed
217 | ngpus_per_node = torch.cuda.device_count()
218 | # print('ngpus_per_node:', ngpus_per_node)
219 |
220 | if args.multiprocessing_distributed:
221 | args.world_size = ngpus_per_node * args.world_size
222 | mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
223 | else:
224 | main_worker(args.gpu, ngpus_per_node, args)
225 |
226 | # Generate final step predictions
227 | with open(args.json_path_val, 'r') as original_data_file:
228 | original_data = json.load(original_data_file)
229 |
230 | with open(args.steps_path, 'r') as large_list_file:
231 | large_list = json.load(large_list_file)
232 |
233 |
234 | for i, item in enumerate(original_data):
235 | item["id"]["pred_list"] = large_list[i]
236 |
237 | with open(args.step_model_output, 'w') as modified_data_file:
238 | json.dump(original_data, modified_data_file)
239 |
240 |
241 |
242 | def main_worker(gpu, ngpus_per_node, args):
243 | args.gpu = gpu
244 | # print('gpuid:', args.gpu)
245 |
246 | if args.distributed:
247 | if args.multiprocessing_distributed:
248 | args.rank = args.rank * ngpus_per_node + gpu
249 | dist.init_process_group(
250 | backend=args.dist_backend,
251 | init_method=args.dist_url,
252 | world_size=args.world_size,
253 | rank=args.rank,
254 | )
255 | if args.gpu is not None:
256 | torch.cuda.set_device(args.gpu)
257 | args.batch_size = int(args.batch_size / ngpus_per_node)
258 | args.batch_size_val = int(args.batch_size_val / ngpus_per_node)
259 | args.num_thread_reader = int(args.num_thread_reader / ngpus_per_node)
260 | elif args.gpu is not None:
261 | torch.cuda.set_device(args.gpu)
262 |
263 | # Test data loading code
264 | test_dataset = PlanningDataset(
265 | args.root,
266 | args=args,
267 | is_val=True,
268 | model=None,
269 | )
270 |
271 | vid_names = test_dataset.vid_names
272 |
273 | if args.distributed:
274 | test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset)
275 | test_sampler.shuffle = False
276 | else:
277 | test_sampler = None
278 | print('no test sampler (non-distributed run)')
279 |
280 | test_loader = torch.utils.data.DataLoader(
281 | test_dataset,
282 | batch_size=args.batch_size_val,
283 | shuffle=False,
284 | drop_last=False,
285 | num_workers=args.num_thread_reader,
286 | sampler=test_sampler,
287 | )
288 |
289 | # create model
290 | temporal_model = temporal.TemporalUnet(
291 | args.action_dim + args.observation_dim,
292 | dim=256,
293 | dim_mults=(1, 2, 4), )
294 |
295 | diffusion_model = diffusion.GaussianDiffusion(
296 | temporal_model, args.horizon, args.observation_dim, args.action_dim, args.n_diffusion_steps,
297 | loss_type='Weighted_MSE', clip_denoised=True,)
298 |
299 | model = utils.Trainer(diffusion_model, None, args.ema_decay, args.lr, args.gradient_accumulate_every,
300 | args.step_start_ema, args.update_ema_every, args.log_freq)
301 |
302 | if args.pretrain_cnn_path:
303 | net_data = torch.load(args.pretrain_cnn_path)
304 | model.model.load_state_dict(net_data)
305 | model.ema_model.load_state_dict(net_data)
306 | if args.distributed:
307 | if args.gpu is not None:
308 | model.model.cuda(args.gpu)
309 | model.ema_model.cuda(args.gpu)
310 | model.model = torch.nn.parallel.DistributedDataParallel(
311 | model.model, device_ids=[args.gpu], find_unused_parameters=True)
312 | model.ema_model = torch.nn.parallel.DistributedDataParallel(
313 | model.ema_model, device_ids=[args.gpu], find_unused_parameters=True)
314 | else:
315 | model.model.cuda()
316 | model.ema_model.cuda()
317 | model.model = torch.nn.parallel.DistributedDataParallel(model.model, find_unused_parameters=True)
318 | model.ema_model = torch.nn.parallel.DistributedDataParallel(model.ema_model,
319 | find_unused_parameters=True)
320 |
321 | elif args.gpu is not None:
322 | model.model = model.model.cuda(args.gpu)
323 | model.ema_model = model.ema_model.cuda(args.gpu)
324 | else:
325 | model.model = torch.nn.DataParallel(model.model).cuda()
326 | model.ema_model = torch.nn.DataParallel(model.ema_model).cuda()
327 |
328 | if args.resume:
329 | checkpoint_path = ""  # set to the trained checkpoint (.pth.tar) to evaluate
330 | if checkpoint_path:
331 | print("=> loading checkpoint '{}'".format(checkpoint_path), args)
332 | checkpoint = torch.load(checkpoint_path, map_location='cuda:{}'.format(args.rank))
333 | args.start_epoch = checkpoint["epoch"]
334 | model.model.load_state_dict(checkpoint["model"], strict=True)
335 | model.ema_model.load_state_dict(checkpoint["ema"], strict=True)
336 | model.step = checkpoint["step"]
337 | else:
338 | assert checkpoint_path, 'checkpoint_path must be set to a trained checkpoint for inference'
339 |
340 | if args.cudnn_benchmark:
341 | cudnn.benchmark = True
342 |
343 | time_start = time.time()
344 | acc_top1_reduced_sum = []
345 | trajectory_success_rate_meter_reduced_sum = []
346 | MIoU1_meter_reduced_sum = []
347 | MIoU2_meter_reduced_sum = []
348 | acc_a0_reduced_sum = []
349 | acc_aT_reduced_sum = []
350 | test_times = 1
351 |
352 | for epoch in range(0, test_times):
353 | tmp = epoch
354 | random.seed(tmp)
355 | np.random.seed(tmp)
356 | torch.manual_seed(tmp)
357 | torch.cuda.manual_seed_all(tmp)
358 |
359 | acc_top1, trajectory_success_rate_meter, MIoU1_meter, MIoU2_meter, acc_a0, acc_aT = test(vid_names, test_loader, model.ema_model, args)
360 |
361 | acc_top1_reduced = reduce_tensor(acc_top1.cuda()).item()
362 | trajectory_success_rate_meter_reduced = reduce_tensor(trajectory_success_rate_meter.cuda()).item()
363 | MIoU1_meter_reduced = reduce_tensor(MIoU1_meter.cuda()).item()
364 | MIoU2_meter_reduced = reduce_tensor(MIoU2_meter.cuda()).item()
365 | acc_a0_reduced = reduce_tensor(acc_a0.cuda()).item()
366 | acc_aT_reduced = reduce_tensor(acc_aT.cuda()).item()
367 |
368 | acc_top1_reduced_sum.append(acc_top1_reduced)
369 | trajectory_success_rate_meter_reduced_sum.append(trajectory_success_rate_meter_reduced)
370 | MIoU1_meter_reduced_sum.append(MIoU1_meter_reduced)
371 | MIoU2_meter_reduced_sum.append(MIoU2_meter_reduced)
372 | acc_a0_reduced_sum.append(acc_a0_reduced)
373 | acc_aT_reduced_sum.append(acc_aT_reduced)
374 |
375 | if args.rank == 0:
376 | time_end = time.time()
377 | print('time: ', time_end - time_start)
378 | print('-----------------Mean&Var-----------------------')
379 | print('Val/EpochAcc@1', sum(acc_top1_reduced_sum) / test_times, np.var(acc_top1_reduced_sum))
380 | print('Val/Traj_Success_Rate', sum(trajectory_success_rate_meter_reduced_sum) / test_times, np.var(trajectory_success_rate_meter_reduced_sum))
381 | print('Val/MIoU1', sum(MIoU1_meter_reduced_sum) / test_times, np.var(MIoU1_meter_reduced_sum))
382 | print('Val/MIoU2', sum(MIoU2_meter_reduced_sum) / test_times, np.var(MIoU2_meter_reduced_sum))
383 | print('Val/acc_a0', sum(acc_a0_reduced_sum) / test_times, np.var(acc_a0_reduced_sum))
384 | print('Val/acc_aT', sum(acc_aT_reduced_sum) / test_times, np.var(acc_aT_reduced_sum))
385 |
386 |
387 | if __name__ == "__main__":
388 | main()
389 |
--------------------------------------------------------------------------------