├── plan
│   ├── log
│   │   ├── readme.txt
│   │   └── whl.txt
│   ├── checkpoint
│   │   └── whl
│   │       └── readme.txt
│   ├── utils
│   │   ├── __init__.py
│   │   ├── one_hot.py
│   │   ├── accuracy.py
│   │   ├── eval.py
│   │   ├── training.py
│   │   └── args.py
│   ├── save_max
│   │   └── readme.txt
│   ├── action_dictionary.py
│   ├── model
│   │   ├── temporal.py
│   │   ├── helpers.py
│   │   └── diffusion.py
│   ├── dataloader
│   │   └── data_load.py
│   └── main_distributed.py
├── step
│   ├── log
│   │   └── readme.txt
│   ├── checkpoint
│   │   └── whl
│   │       └── readme.txt
│   ├── utils
│   │   ├── __init__.py
│   │   ├── accuracy.py
│   │   ├── eval.py
│   │   ├── training.py
│   │   └── args.py
│   ├── loading_data.py
│   ├── action_dictionary.py
│   ├── model
│   │   ├── temporal.py
│   │   ├── helpers.py
│   │   └── diffusion.py
│   ├── dataloader
│   │   ├── data_load.py
│   │   └── train_test_data_load.py
│   ├── main_distributed.py
│   └── inference.py
├── requirements.txt
├── dataset
│   ├── NIV
│   │   ├── download.sh
│   │   ├── test30.json
│   │   └── train70.json
│   ├── coin
│   │   └── download.sh
│   └── crosstask
│       ├── actions_one_hot.npy
│       └── download.sh
├── PKG
│   ├── graphs
│   │   ├── NIV
│   │   │   └── NIV_trained_graph.pkl
│   │   ├── COIN
│   │   │   └── coin_trained_graph.pkl
│   │   ├── CrossTaskHow
│   │   │   ├── weighted_graph
│   │   │   │   └── trained_graph.pkl
│   │   │   ├── min_max_N_graph
│   │   │   │   └── trained_graph.pkl
│   │   │   └── out_edge_N_graph
│   │   │       └── trained_graph.pkl
│   │   └── CrossTaskBase
│   │       └── trained_graph_CrossTask_base.pkl
│   ├── graph_visualize.py
│   └── action_dictionary.py
└── README.md

/plan/log/readme.txt:
--------------------------------------------------------------------------------
1 | save KEPP logs here
--------------------------------------------------------------------------------
/step/log/readme.txt:
--------------------------------------------------------------------------------
1 | save KEPP logs here
--------------------------------------------------------------------------------
/plan/checkpoint/whl/readme.txt:
--------------------------------------------------------------------------------
1 | save checkpoints here
--------------------------------------------------------------------------------
/step/checkpoint/whl/readme.txt:
--------------------------------------------------------------------------------
1 | save MLP checkpoints here
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | einops
2 | tensorboard
3 | tensorboardX
4 | torch
5 |
--------------------------------------------------------------------------------
/plan/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .training import *
2 | from .eval import *
3 | from .args import *
4 |
--------------------------------------------------------------------------------
/step/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .training import *
2 | from .eval import *
3 | from .args import *
4 |
--------------------------------------------------------------------------------
/dataset/NIV/download.sh:
--------------------------------------------------------------------------------
1 | #! /bin/bash
2 |
3 | wget https://huggingface.co/nyanko7/LLaMA-7B/resolve/main/*
--------------------------------------------------------------------------------
/dataset/coin/download.sh:
--------------------------------------------------------------------------------
1 | #! /bin/bash
2 |
3 | wget https://vision.eecs.yorku.ca/WebShare/COIN_s3d.zip
4 |
5 | unzip COIN_s3d.zip
--------------------------------------------------------------------------------
/PKG/graphs/NIV/NIV_trained_graph.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ravindu-Yasas-Nagasinghe/KEPP/HEAD/PKG/graphs/NIV/NIV_trained_graph.pkl
--------------------------------------------------------------------------------
/PKG/graphs/COIN/coin_trained_graph.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ravindu-Yasas-Nagasinghe/KEPP/HEAD/PKG/graphs/COIN/coin_trained_graph.pkl
--------------------------------------------------------------------------------
/dataset/crosstask/actions_one_hot.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ravindu-Yasas-Nagasinghe/KEPP/HEAD/dataset/crosstask/actions_one_hot.npy
--------------------------------------------------------------------------------
/plan/save_max/readme.txt:
--------------------------------------------------------------------------------
1 | The checkpoint with the maximum validation accuracy will be saved here. Use the path of that checkpoint as the input to the inference file.
--------------------------------------------------------------------------------
/PKG/graphs/CrossTaskHow/weighted_graph/trained_graph.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ravindu-Yasas-Nagasinghe/KEPP/HEAD/PKG/graphs/CrossTaskHow/weighted_graph/trained_graph.pkl
--------------------------------------------------------------------------------
/PKG/graphs/CrossTaskBase/trained_graph_CrossTask_base.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ravindu-Yasas-Nagasinghe/KEPP/HEAD/PKG/graphs/CrossTaskBase/trained_graph_CrossTask_base.pkl
--------------------------------------------------------------------------------
/PKG/graphs/CrossTaskHow/min_max_N_graph/trained_graph.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ravindu-Yasas-Nagasinghe/KEPP/HEAD/PKG/graphs/CrossTaskHow/min_max_N_graph/trained_graph.pkl
--------------------------------------------------------------------------------
/PKG/graphs/CrossTaskHow/out_edge_N_graph/trained_graph.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ravindu-Yasas-Nagasinghe/KEPP/HEAD/PKG/graphs/CrossTaskHow/out_edge_N_graph/trained_graph.pkl
--------------------------------------------------------------------------------
/dataset/crosstask/download.sh:
--------------------------------------------------------------------------------
1 | #!
/bin/bash 2 | 3 | wget https://www.di.ens.fr/~dzhukov/crosstask/crosstask_release.zip 4 | 5 | wget https://www.di.ens.fr/~dzhukov/crosstask/crosstask_features.zip 6 | 7 | wget https://vision.eecs.yorku.ca/WebShare/CrossTask_s3d.zip 8 | 9 | unzip '*.zip' -------------------------------------------------------------------------------- /step/loading_data.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import random 4 | import time 5 | from collections import OrderedDict 6 | 7 | import torch.nn.parallel 8 | import torch.backends.cudnn as cudnn 9 | import torch.distributed as dist 10 | import torch.optim 11 | import torch.multiprocessing as mp 12 | import torch.utils.data 13 | import torch.utils.data.distributed 14 | from torch.distributed import ReduceOp 15 | import torch.nn.functional as F 16 | from dataloader.train_test_data_load import PlanningDataset 17 | from model.helpers import get_lr_schedule_with_warmup, Logger 18 | import torch.nn as nn 19 | from utils import * 20 | from logging import log 21 | from utils.args import get_args 22 | import numpy as np 23 | 24 | def main(): 25 | args = get_args() 26 | os.environ['PYTHONHASHSEED'] = str(args.seed) 27 | if os.path.exists(args.json_path_val): 28 | pass 29 | else: 30 | train_dataset = PlanningDataset( 31 | args.root, 32 | args=args, 33 | is_val=False, 34 | model=None, 35 | ) 36 | print('Train loaded') 37 | test_dataset = PlanningDataset( 38 | args.root, 39 | args=args, 40 | is_val=True, 41 | model=None, 42 | ) 43 | print('test loaded') 44 | 45 | if __name__ == "__main__": 46 | main() 47 | -------------------------------------------------------------------------------- /plan/utils/one_hot.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from utils.args import get_args 3 | args = get_args() 4 | class LLMLabelOnehot(torch.nn.Module): 5 | def __init__(self, batch_size, T, num_lists, probabilities): 6 | super(LLMLabelOnehot, self).__init__() 7 | 8 | self.batch_size = batch_size 9 | self.T = T 10 | self.num_lists = num_lists 11 | self.probabilities = probabilities 12 | 13 | self.onehot = torch.zeros((batch_size * T, args.class_dim_llama)).cuda() 14 | 15 | def forward(self, LLM_label): 16 | for batch_idx in range(self.batch_size): 17 | for t in range(self.T): 18 | list_prob = torch.zeros(args.class_dim_llama).cuda() 19 | for list_idx in range(self.num_lists): 20 | if self.probabilities==[1]: 21 | list_prob[LLM_label[batch_idx][t]] = self.probabilities[list_idx] 22 | else: 23 | list_prob[LLM_label[batch_idx][list_idx][t]] = self.probabilities[list_idx] 24 | 25 | self.onehot[batch_idx * self.T + t, :] = list_prob 26 | 27 | self.onehot = self.onehot.reshape(self.batch_size, self.T, -1).cuda() 28 | return self.onehot 29 | 30 | 31 | class PKGLabelOnehot(torch.nn.Module): 32 | def __init__(self, batch_size, T, num_lists, probabilities): 33 | super(PKGLabelOnehot, self).__init__() 34 | 35 | self.batch_size = batch_size 36 | self.T = T 37 | self.num_lists = num_lists 38 | self.probabilities = probabilities 39 | 40 | self.onehot = torch.zeros((batch_size * T, args.class_dim_graph)).cuda() 41 | 42 | def forward(self, PKG_label): 43 | for batch_idx in range(self.batch_size): 44 | for t in range(self.T): 45 | list_prob = torch.zeros(args.class_dim_graph).cuda() 46 | for list_idx in range(self.num_lists): 47 | if self.probabilities==[1]: 48 | list_prob[PKG_label[batch_idx][t]] = self.probabilities[list_idx] 49 | else: 50 | 
list_prob[PKG_label[batch_idx][list_idx][t]] = self.probabilities[list_idx] 51 | 52 | self.onehot[batch_idx * self.T + t, :] = list_prob 53 | 54 | self.onehot = self.onehot.reshape(self.batch_size, self.T, -1).cuda() 55 | return self.onehot -------------------------------------------------------------------------------- /plan/utils/accuracy.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def accuracy(output, target, topk=(1,), max_traj_len=0): 5 | with torch.no_grad(): 6 | maxk = max(topk) 7 | batch_size = target.size(0) 8 | 9 | _, pred = output.topk(maxk, 1, True, True) 10 | pred = pred.t() # [k, bs*T] 11 | correct = pred.eq(target.view(1, -1).expand_as(pred)) # [k, bs*T] 12 | 13 | correct_a = correct[:1].view(-1, max_traj_len) # [bs, T] 14 | correct_a0 = correct_a[:, 0].reshape(-1).float().mean().mul_(100.0) 15 | correct_aT = correct_a[:, -1].reshape(-1).float().mean().mul_(100.0) 16 | 17 | res = [] 18 | for k in topk: 19 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) 20 | res.append(correct_k.mul_(100.0 / batch_size)) 21 | 22 | correct_1 = correct[:1] # (1, bs*T) 23 | 24 | # Success Rate 25 | trajectory_success = torch.all(correct_1.view(correct_1.shape[1] // max_traj_len, -1), dim=1) 26 | trajectory_success_rate = trajectory_success.sum() * 100.0 / trajectory_success.shape[0] 27 | 28 | # MIoU 29 | _, pred_token = output.topk(1, 1, True, True) # [bs*T, 1] 30 | pred_inst = pred_token.view(correct_1.shape[1], -1) # [bs*T, 1] 31 | pred_inst_set = set() 32 | target_inst = target.view(correct_1.shape[1], -1) # [bs*T, 1] 33 | target_inst_set = set() 34 | for i in range(pred_inst.shape[0]): 35 | # print(pred_inst[i], target_inst[i]) 36 | pred_inst_set.add(tuple(pred_inst[i].tolist())) 37 | target_inst_set.add(tuple(target_inst[i].tolist())) 38 | MIoU1 = 100.0 * len(pred_inst_set.intersection(target_inst_set)) / len(pred_inst_set.union(target_inst_set)) 39 | 40 | batch_size = batch_size // max_traj_len 41 | pred_inst = pred_token.view(batch_size, -1) # [bs, T] 42 | pred_inst_set = set() 43 | target_inst = target.view(batch_size, -1) # [bs, T] 44 | target_inst_set = set() 45 | MIoU_sum = 0 46 | for i in range(pred_inst.shape[0]): 47 | # print(pred_inst[i], target_inst[i]) 48 | pred_inst_set.update(pred_inst[i].tolist()) 49 | target_inst_set.update(target_inst[i].tolist()) 50 | MIoU_current = 100.0 * len(pred_inst_set.intersection(target_inst_set)) / len( 51 | pred_inst_set.union(target_inst_set)) 52 | MIoU_sum += MIoU_current 53 | pred_inst_set.clear() 54 | target_inst_set.clear() 55 | 56 | MIoU2 = MIoU_sum / batch_size 57 | return res, trajectory_success_rate, MIoU1, MIoU2, correct_a0, correct_aT 58 | 59 | -------------------------------------------------------------------------------- /step/utils/accuracy.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def accuracy(output, target, topk=(1,), max_traj_len=0): 5 | with torch.no_grad(): 6 | maxk = max(topk) 7 | batch_size = target.size(0) 8 | 9 | _, pred = output.topk(maxk, 1, True, True) 10 | pred = pred.t() # [k, bs*T] 11 | correct = pred.eq(target.view(1, -1).expand_as(pred)) # [k, bs*T] 12 | 13 | correct_a = correct[:1].view(-1, max_traj_len) # [bs, T] 14 | correct_a0 = correct_a[:, 0].reshape(-1).float().mean().mul_(100.0) 15 | correct_aT = correct_a[:, -1].reshape(-1).float().mean().mul_(100.0) 16 | 17 | res = [] 18 | for k in topk: 19 | correct_k = correct[:k].reshape(-1).float().sum(0, 
keepdim=True) 20 | res.append(correct_k.mul_(100.0 / batch_size)) 21 | 22 | correct_1 = correct[:1] # (1, bs*T) 23 | 24 | # Success Rate 25 | trajectory_success = torch.all(correct_1.view(correct_1.shape[1] // max_traj_len, -1), dim=1) 26 | trajectory_success_rate = trajectory_success.sum() * 100.0 / trajectory_success.shape[0] 27 | 28 | # MIoU 29 | _, pred_token = output.topk(1, 1, True, True) # [bs*T, 1] 30 | pred_inst = pred_token.view(correct_1.shape[1], -1) # [bs*T, 1] 31 | pred_inst_set = set() 32 | target_inst = target.view(correct_1.shape[1], -1) # [bs*T, 1] 33 | target_inst_set = set() 34 | for i in range(pred_inst.shape[0]): 35 | # print(pred_inst[i], target_inst[i]) 36 | pred_inst_set.add(tuple(pred_inst[i].tolist())) 37 | target_inst_set.add(tuple(target_inst[i].tolist())) 38 | MIoU1 = 100.0 * len(pred_inst_set.intersection(target_inst_set)) / len(pred_inst_set.union(target_inst_set)) 39 | 40 | batch_size = batch_size // max_traj_len 41 | pred_inst = pred_token.view(batch_size, -1) # [bs, T] 42 | pred_inst_set = set() 43 | target_inst = target.view(batch_size, -1) # [bs, T] 44 | target_inst_set = set() 45 | MIoU_sum = 0 46 | for i in range(pred_inst.shape[0]): 47 | # print(pred_inst[i], target_inst[i]) 48 | pred_inst_set.update(pred_inst[i].tolist()) 49 | target_inst_set.update(target_inst[i].tolist()) 50 | MIoU_current = 100.0 * len(pred_inst_set.intersection(target_inst_set)) / len( 51 | pred_inst_set.union(target_inst_set)) 52 | MIoU_sum += MIoU_current 53 | pred_inst_set.clear() 54 | target_inst_set.clear() 55 | 56 | MIoU2 = MIoU_sum / batch_size 57 | return res, trajectory_success_rate, MIoU1, MIoU2, correct_a0, correct_aT 58 | 59 | -------------------------------------------------------------------------------- /PKG/graph_visualize.py: -------------------------------------------------------------------------------- 1 | import networkx as nx 2 | import cv2 3 | import numpy as np 4 | 5 | import os 6 | import pickle 7 | import networkx as nx 8 | import logging 9 | import json 10 | import torch 11 | from tabulate import tabulate 12 | from action_dictionary import action_dictionary 13 | 14 | # Function to find all paths from an input node within a fixed length 15 | def find_paths_within_length(graph, start_node, max_length): 16 | def dfs_paths(node, path, length): 17 | if length <= max_length: 18 | path.append(node) 19 | if node == start_node: 20 | yield list(path) 21 | else: 22 | for neighbor in graph.neighbors(node): 23 | if neighbor not in path: 24 | yield from dfs_paths(neighbor, path, length + 1) 25 | path.pop() 26 | 27 | paths = list(dfs_paths(start_node, [], 0)) 28 | return paths 29 | 30 | 31 | graph_save_path = '/home/ravindu.nagasinghe/GithubCodes/RaviPP/trained_graph.pkl' 32 | 33 | with open(graph_save_path, 'rb') as graph_file: 34 | graph = pickle.load(graph_file) 35 | 36 | # Input parameters 37 | input_node = 1 # Replace with the desired input node 38 | max_length = 4 # Replace with the desired fixed length 39 | 40 | # Find all paths within the fixed length from the input node 41 | all_paths = find_paths_within_length(graph, input_node, max_length) 42 | 43 | # Create an image for visualization 44 | node_positions = nx.circular_layout(graph) 45 | image_width = len(all_paths[0]) * 100 # Width of the image based on the number of steps 46 | image_height = len(graph.nodes()) * 100 # Height of the image based on the number of nodes 47 | 48 | image = np.zeros((image_height, image_width, 3), dtype=np.uint8) # Initialize an empty image 49 | image.fill(255) # Set the 
background to white 50 | 51 | # Draw nodes 52 | for node, pos in node_positions.items(): 53 | x, y = int(pos[0] * image_width), int(pos[1] * image_height) 54 | cv2.circle(image, (x, y), 10, (0, 0, 0), -1) # Draw black nodes 55 | 56 | # Draw paths with thickness based on weight 57 | for path in all_paths: 58 | for i in range(len(path) - 1): 59 | u, v = path[i], path[i + 1] 60 | weight = graph[u][v]['weight'] 61 | thickness = int((weight / 5) * 3) # Adjust the scaling factor as needed 62 | u_pos, v_pos = node_positions[u], node_positions[v] 63 | u_x, u_y = int(u_pos[0] * image_width), int(u_pos[1] * image_height) 64 | v_x, v_y = int(v_pos[0] * image_width), int(v_pos[1] * image_height) 65 | cv2.line(image, (u_x, u_y), (v_x, v_y), (0, 0, 255), thickness) 66 | 67 | # Display the image (you can save it using cv2.imwrite) 68 | file_path = "graph_visualization.png" 69 | cv2.imwrite(file_path, image) 70 | 71 | -------------------------------------------------------------------------------- /PKG/action_dictionary.py: -------------------------------------------------------------------------------- 1 | action_dictionary = {1: 'pour water', 2: 'pour juice', 3: 'pour jello powder', 4: 'pour alcohol', 2 | 5: 'stir mixture', 6: 'pour mixture into cup', 7: 'cut shelve', 8: 'assemble shelve', 3 | 9: 'sand shelve', 10: 'paint shelve', 11: 'attach shelve', 12: 'add onion', 13: 'add taco', 4 | 14: 'add lettuce', 15: 'add meat', 16: 'add tomato', 17: 'add cheese', 18: 'stir', 5 | 19: 'add tortilla', 20: 'season steak', 21: 'put steak on grill', 22: 'close lid', 6 | 23: 'open lid', 24: 'move steak on grill', 25: 'flip steak', 26: 'check temperature', 7 | 27: 'take steak from grill', 28: 'top steak', 29: 'cut steak', 30: 'taste steak', 31: 'add rice', 8 | 32: 'add ham', 33: 'add kimchi', 34: 'pour sesame oil', 35: 'pour egg', 36: 'add sugar', 9 | 37: 'whisk mixture', 38: 'put mixture into bag', 39: 'spread mixture', 40: 'put meringue into oven', 10 | 41: 'add coffee', 42: 'press coffee', 43: 'pour espresso', 44: 'steam milk', 45: 'pour milk', 11 | 46: 'cut cucumber', 47: 'cut onion', 48: 'add salt', 49: 'pour vinegar', 50: 'add spices', 12 | 51: 'put vegetables in water', 52: 'pack cucumbers in jar', 53: 'seal jar', 13 | 54: 'put jar in water', 55: 'cut lemon', 56: 'squeeze lemon', 57: 'pour lemon juice', 14 | 58: 'add ice', 59: 'pour lemonade into glass', 60: 'dip bread in mixture', 61: 'melt butter', 15 | 62: 'put bread in pan', 63: 'add vanilla extract', 64: 'flip bread', 65: 'remove bread from pan', 16 | 66: 'top toast', 67: 'brake on', 68: 'raise jack', 69: 'lower jack', 70: 'add chili powder', 17 | 71: 'add mustard seeds', 72: 'add curry leaves', 73: 'add fish', 74: 'peel banana', 18 | 75: 'cut banana', 76: 'put bananas into blender', 77: 'mix ingredients', 78: 'remove cap', 19 | 79: 'put funnel', 80: 'pour oil', 81: 'remove funnel', 82: 'close cap', 83: 'pull out dipstick', 20 | 84: 'wipe off dipstick', 85: 'insert dipstick', 86: 'get things out', 87: 'start loose', 21 | 88: 'jack up', 89: 'unscrew wheel', 90: 'withdraw wheel', 91: 'put wheel', 92: 'screw wheel', 22 | 93: 'jack down', 94: 'tight wheel', 95: 'put things back', 96: 'add whipped cream', 23 | 97: 'add flour', 98: 'add butter', 99: 'put dough into form', 100: 'spread creme upon cake', 24 | 101: 'cut strawberries', 102: 'add strawberries to cake', 103: 'pour mixture into pan', 25 | 104: 'flip pancake', 105: 'take pancake from pan'} -------------------------------------------------------------------------------- /plan/action_dictionary.py: 
-------------------------------------------------------------------------------- 1 | action_dictionary = {1: 'pour water', 2: 'pour juice', 3: 'pour jello powder', 4: 'pour alcohol', 2 | 5: 'stir mixture', 6: 'pour mixture into cup', 7: 'cut shelve', 8: 'assemble shelve', 3 | 9: 'sand shelve', 10: 'paint shelve', 11: 'attach shelve', 12: 'add onion', 13: 'add taco', 4 | 14: 'add lettuce', 15: 'add meat', 16: 'add tomato', 17: 'add cheese', 18: 'stir', 5 | 19: 'add tortilla', 20: 'season steak', 21: 'put steak on grill', 22: 'close lid', 6 | 23: 'open lid', 24: 'move steak on grill', 25: 'flip steak', 26: 'check temperature', 7 | 27: 'take steak from grill', 28: 'top steak', 29: 'cut steak', 30: 'taste steak', 31: 'add rice', 8 | 32: 'add ham', 33: 'add kimchi', 34: 'pour sesame oil', 35: 'pour egg', 36: 'add sugar', 9 | 37: 'whisk mixture', 38: 'put mixture into bag', 39: 'spread mixture', 40: 'put meringue into oven', 10 | 41: 'add coffee', 42: 'press coffee', 43: 'pour espresso', 44: 'steam milk', 45: 'pour milk', 11 | 46: 'cut cucumber', 47: 'cut onion', 48: 'add salt', 49: 'pour vinegar', 50: 'add spices', 12 | 51: 'put vegetables in water', 52: 'pack cucumbers in jar', 53: 'seal jar', 13 | 54: 'put jar in water', 55: 'cut lemon', 56: 'squeeze lemon', 57: 'pour lemon juice', 14 | 58: 'add ice', 59: 'pour lemonade into glass', 60: 'dip bread in mixture', 61: 'melt butter', 15 | 62: 'put bread in pan', 63: 'add vanilla extract', 64: 'flip bread', 65: 'remove bread from pan', 16 | 66: 'top toast', 67: 'brake on', 68: 'raise jack', 69: 'lower jack', 70: 'add chili powder', 17 | 71: 'add mustard seeds', 72: 'add curry leaves', 73: 'add fish', 74: 'peel banana', 18 | 75: 'cut banana', 76: 'put bananas into blender', 77: 'mix ingredients', 78: 'remove cap', 19 | 79: 'put funnel', 80: 'pour oil', 81: 'remove funnel', 82: 'close cap', 83: 'pull out dipstick', 20 | 84: 'wipe off dipstick', 85: 'insert dipstick', 86: 'get things out', 87: 'start loose', 21 | 88: 'jack up', 89: 'unscrew wheel', 90: 'withdraw wheel', 91: 'put wheel', 92: 'screw wheel', 22 | 93: 'jack down', 94: 'tight wheel', 95: 'put things back', 96: 'add whipped cream', 23 | 97: 'add flour', 98: 'add butter', 99: 'put dough into form', 100: 'spread creme upon cake', 24 | 101: 'cut strawberries', 102: 'add strawberries to cake', 103: 'pour mixture into pan', 25 | 104: 'flip pancake', 105: 'take pancake from pan'} -------------------------------------------------------------------------------- /step/action_dictionary.py: -------------------------------------------------------------------------------- 1 | action_dictionary = {1: 'pour water', 2: 'pour juice', 3: 'pour jello powder', 4: 'pour alcohol', 2 | 5: 'stir mixture', 6: 'pour mixture into cup', 7: 'cut shelve', 8: 'assemble shelve', 3 | 9: 'sand shelve', 10: 'paint shelve', 11: 'attach shelve', 12: 'add onion', 13: 'add taco', 4 | 14: 'add lettuce', 15: 'add meat', 16: 'add tomato', 17: 'add cheese', 18: 'stir', 5 | 19: 'add tortilla', 20: 'season steak', 21: 'put steak on grill', 22: 'close lid', 6 | 23: 'open lid', 24: 'move steak on grill', 25: 'flip steak', 26: 'check temperature', 7 | 27: 'take steak from grill', 28: 'top steak', 29: 'cut steak', 30: 'taste steak', 31: 'add rice', 8 | 32: 'add ham', 33: 'add kimchi', 34: 'pour sesame oil', 35: 'pour egg', 36: 'add sugar', 9 | 37: 'whisk mixture', 38: 'put mixture into bag', 39: 'spread mixture', 40: 'put meringue into oven', 10 | 41: 'add coffee', 42: 'press coffee', 43: 'pour espresso', 44: 'steam milk', 45: 'pour 
milk', 11 | 46: 'cut cucumber', 47: 'cut onion', 48: 'add salt', 49: 'pour vinegar', 50: 'add spices', 12 | 51: 'put vegetables in water', 52: 'pack cucumbers in jar', 53: 'seal jar', 13 | 54: 'put jar in water', 55: 'cut lemon', 56: 'squeeze lemon', 57: 'pour lemon juice', 14 | 58: 'add ice', 59: 'pour lemonade into glass', 60: 'dip bread in mixture', 61: 'melt butter', 15 | 62: 'put bread in pan', 63: 'add vanilla extract', 64: 'flip bread', 65: 'remove bread from pan', 16 | 66: 'top toast', 67: 'brake on', 68: 'raise jack', 69: 'lower jack', 70: 'add chili powder', 17 | 71: 'add mustard seeds', 72: 'add curry leaves', 73: 'add fish', 74: 'peel banana', 18 | 75: 'cut banana', 76: 'put bananas into blender', 77: 'mix ingredients', 78: 'remove cap', 19 | 79: 'put funnel', 80: 'pour oil', 81: 'remove funnel', 82: 'close cap', 83: 'pull out dipstick', 20 | 84: 'wipe off dipstick', 85: 'insert dipstick', 86: 'get things out', 87: 'start loose', 21 | 88: 'jack up', 89: 'unscrew wheel', 90: 'withdraw wheel', 91: 'put wheel', 92: 'screw wheel', 22 | 93: 'jack down', 94: 'tight wheel', 95: 'put things back', 96: 'add whipped cream', 23 | 97: 'add flour', 98: 'add butter', 99: 'put dough into form', 100: 'spread creme upon cake', 24 | 101: 'cut strawberries', 102: 'add strawberries to cake', 103: 'pour mixture into pan', 25 | 104: 'flip pancake', 105: 'take pancake from pan'} -------------------------------------------------------------------------------- /step/utils/eval.py: -------------------------------------------------------------------------------- 1 | from .accuracy import * 2 | from model.helpers import AverageMeter 3 | 4 | 5 | def validate(val_loader, model, args): 6 | model.eval() 7 | losses = AverageMeter() 8 | acc_top1 = AverageMeter() 9 | acc_top5 = AverageMeter() 10 | trajectory_success_rate_meter = AverageMeter() 11 | MIoU1_meter = AverageMeter() 12 | MIoU2_meter = AverageMeter() 13 | 14 | A0_acc = AverageMeter() 15 | AT_acc = AverageMeter() 16 | 17 | for i_batch, sample_batch in enumerate(val_loader): 18 | # compute output 19 | global_img_tensors = sample_batch[0].cuda().contiguous().float() 20 | video_label = sample_batch[1].cuda() 21 | batch_size_current, T = video_label.size() 22 | #task_class = sample_batch[2].view(-1).cuda() 23 | cond = {} 24 | 25 | with torch.no_grad(): 26 | cond[0] = global_img_tensors[:, 0, :] 27 | cond[T - 1] = global_img_tensors[:, -1, :] 28 | ''' 29 | task_onehot = torch.zeros((task_class.size(0), args.class_dim)) # [bs*T, ac_dim] 30 | ind = torch.arange(0, len(task_class)) 31 | task_onehot[ind, task_class] = 1. 32 | task_onehot = task_onehot.cuda() 33 | temp = task_onehot.unsqueeze(1) 34 | task_class_ = temp.repeat(1, T, 1) # [bs, T, args.class_dim] 35 | cond['task'] = task_class_ 36 | ''' 37 | video_label_reshaped = video_label.view(-1) 38 | 39 | action_label_onehot = torch.zeros((video_label_reshaped.size(0), args.action_dim)) 40 | ind = torch.arange(0, len(video_label_reshaped)) 41 | action_label_onehot[ind, video_label_reshaped] = 1. 42 | action_label_onehot = action_label_onehot.reshape(batch_size_current, T, -1).cuda() 43 | 44 | x_start = torch.zeros((batch_size_current, T, args.action_dim + args.observation_dim)) 45 | x_start[:, 0, args.action_dim:] = global_img_tensors[:, 0, :] 46 | x_start[:, -1, args.action_dim:] = global_img_tensors[:, -1, :] 47 | action_label_onehot[:,1:-1,:] = 0. 
48 | x_start[:, :, :args.action_dim] = action_label_onehot 49 | #x_start[:, :, :args.class_dim] = task_class_ 50 | output = model(cond) 51 | actions_pred = output.contiguous() 52 | loss = model.module.loss_fn(actions_pred, x_start.cuda()) 53 | 54 | actions_pred = actions_pred[:, :, :args.action_dim].contiguous() 55 | actions_pred = actions_pred.view(-1, args.action_dim) # [bs*T, action_dim] 56 | 57 | (acc1, acc5), trajectory_success_rate, MIoU1, MIoU2, a0_acc, aT_acc = \ 58 | accuracy(actions_pred.cpu(), video_label_reshaped.cpu(), topk=(1, 5), max_traj_len=args.horizon) 59 | 60 | losses.update(loss.item(), batch_size_current) 61 | acc_top1.update(acc1.item(), batch_size_current) 62 | acc_top5.update(acc5.item(), batch_size_current) 63 | trajectory_success_rate_meter.update(trajectory_success_rate.item(), batch_size_current) 64 | MIoU1_meter.update(MIoU1, batch_size_current) 65 | MIoU2_meter.update(MIoU2, batch_size_current) 66 | A0_acc.update(a0_acc, batch_size_current) 67 | AT_acc.update(aT_acc, batch_size_current) 68 | 69 | return torch.tensor(losses.avg), torch.tensor(acc_top1.avg), torch.tensor(acc_top5.avg), \ 70 | torch.tensor(trajectory_success_rate_meter.avg), \ 71 | torch.tensor(MIoU1_meter.avg), torch.tensor(MIoU2_meter.avg), \ 72 | torch.tensor(A0_acc.avg), torch.tensor(AT_acc.avg) 73 | -------------------------------------------------------------------------------- /plan/utils/eval.py: -------------------------------------------------------------------------------- 1 | from .accuracy import * 2 | from model.helpers import AverageMeter 3 | from .one_hot import PKGLabelOnehot 4 | from .one_hot import LLMLabelOnehot 5 | 6 | 7 | def validate(val_loader, model, args): 8 | model.eval() 9 | losses = AverageMeter() 10 | acc_top1 = AverageMeter() 11 | acc_top5 = AverageMeter() 12 | trajectory_success_rate_meter = AverageMeter() 13 | MIoU1_meter = AverageMeter() 14 | MIoU2_meter = AverageMeter() 15 | 16 | A0_acc = AverageMeter() 17 | AT_acc = AverageMeter() 18 | for i_batch, sample_batch in enumerate(val_loader): 19 | # compute output 20 | global_img_tensors = sample_batch[0].cuda().contiguous().float() 21 | video_label = sample_batch[1].cuda() 22 | batch_size_current, T = video_label.size() 23 | LLM_label = sample_batch[2].cuda() 24 | 25 | PKG_label = sample_batch[3].cuda() 26 | cond = {} 27 | llm_label_onehot = LLMLabelOnehot(batch_size_current, T, args.num_seq_LLM,[2/3, 1/3]) 28 | pkg_label_onehot = PKGLabelOnehot(batch_size_current, T, args.num_seq_PKG,[2/3, 1/3]) 29 | 30 | 31 | with torch.no_grad(): 32 | cond[0] = global_img_tensors[:, 0, :] 33 | cond[T - 1] = global_img_tensors[:, -1, :] 34 | 35 | LLM_label_onehot = llm_label_onehot(LLM_label) 36 | 37 | PKG_label_onehot = pkg_label_onehot(PKG_label) 38 | 39 | cond['LLM_action_path'] = LLM_label_onehot 40 | cond['PKG_action_path'] = PKG_label_onehot 41 | 42 | video_label_reshaped = video_label.view(-1) 43 | 44 | action_label_onehot = torch.zeros((video_label_reshaped.size(0), args.action_dim)) 45 | ind = torch.arange(0, len(video_label_reshaped)) 46 | action_label_onehot[ind, video_label_reshaped] = 1. 
47 | action_label_onehot = action_label_onehot.reshape(batch_size_current, T, -1).cuda() 48 | 49 | x_start = torch.zeros((batch_size_current, T, args.class_dim_graph + args.class_dim_llama + args.action_dim + args.observation_dim)) 50 | x_start[:, 0, args.class_dim_graph + args.class_dim_llama + args.action_dim:] = global_img_tensors[:, 0, :] 51 | x_start[:, -1, args.class_dim_graph + args.class_dim_llama + args.action_dim:] = global_img_tensors[:, -1, :] 52 | 53 | x_start[:, :,args.class_dim_graph + args.class_dim_llama:args.class_dim_graph + args.class_dim_llama + args.action_dim] = action_label_onehot 54 | x_start[:, :, :args.class_dim_graph] = PKG_label_onehot 55 | x_start[:,:,args.class_dim_graph:args.class_dim_graph + args.class_dim_llama] = LLM_label_onehot 56 | output = model(cond) 57 | actions_pred = output.contiguous() 58 | loss = model.module.loss_fn(actions_pred, x_start.cuda()) 59 | 60 | actions_pred = actions_pred[:, :, args.class_dim_graph + args.class_dim_llama:args.class_dim_graph + args.class_dim_llama + args.action_dim].contiguous() 61 | actions_pred = actions_pred.view(-1, args.action_dim) # [bs*T, action_dim] 62 | 63 | (acc1, acc5), trajectory_success_rate, MIoU1, MIoU2, a0_acc, aT_acc = \ 64 | accuracy(actions_pred.cpu(), video_label_reshaped.cpu(), topk=(1, 5), max_traj_len=args.horizon) 65 | 66 | losses.update(loss.item(), batch_size_current) 67 | acc_top1.update(acc1.item(), batch_size_current) 68 | acc_top5.update(acc5.item(), batch_size_current) 69 | trajectory_success_rate_meter.update(trajectory_success_rate.item(), batch_size_current) 70 | MIoU1_meter.update(MIoU1, batch_size_current) 71 | MIoU2_meter.update(MIoU2, batch_size_current) 72 | A0_acc.update(a0_acc, batch_size_current) 73 | AT_acc.update(aT_acc, batch_size_current) 74 | 75 | return torch.tensor(losses.avg), torch.tensor(acc_top1.avg), torch.tensor(acc_top5.avg), \ 76 | torch.tensor(trajectory_success_rate_meter.avg), \ 77 | torch.tensor(MIoU1_meter.avg), torch.tensor(MIoU2_meter.avg), \ 78 | torch.tensor(A0_acc.avg), torch.tensor(AT_acc.avg) 79 | -------------------------------------------------------------------------------- /step/model/temporal.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import einops 4 | from einops.layers.torch import Rearrange 5 | 6 | from .helpers import ( 7 | SinusoidalPosEmb, 8 | Downsample1d, 9 | Upsample1d, 10 | Conv1dBlock, 11 | ) 12 | 13 | 14 | class ResidualTemporalBlock(nn.Module): 15 | 16 | def __init__(self, inp_channels, out_channels, embed_dim, kernel_size=3): 17 | super().__init__() 18 | 19 | self.blocks = nn.ModuleList([ 20 | Conv1dBlock(inp_channels, out_channels, kernel_size), 21 | Conv1dBlock(out_channels, out_channels, kernel_size, if_zero=True) 22 | ]) 23 | self.time_mlp = nn.Sequential( # should be removed for Noise and Deterministic Baselines 24 | nn.Mish(), 25 | nn.Linear(embed_dim, out_channels), 26 | Rearrange('batch t -> batch t 1'), 27 | ) 28 | self.residual_conv = nn.Conv1d(inp_channels, out_channels, 1) \ 29 | if inp_channels != out_channels else nn.Identity() 30 | 31 | def forward(self, x, t): 32 | out = self.blocks[0](x) + self.time_mlp(t) # for diffusion 33 | # out = self.blocks[0](x) # for Noise and Deterministic Baselines 34 | out = self.blocks[1](out) 35 | return out + self.residual_conv(x) 36 | 37 | 38 | 39 | class TemporalUnet(nn.Module): 40 | def __init__( 41 | self, 42 | transition_dim, 43 | dim=32, 44 | dim_mults=(1, 2, 4, 8), 45 | ): 46 | 
super().__init__() 47 | 48 | dims = [transition_dim, *map(lambda m: dim * m, dim_mults)] 49 | in_out = list(zip(dims[:-1], dims[1:])) 50 | 51 | time_dim = dim 52 | self.time_mlp = nn.Sequential( # should be removed for Noise and Deterministic Baselines 53 | SinusoidalPosEmb(dim), 54 | nn.Linear(dim, dim * 4), 55 | nn.Mish(), 56 | nn.Linear(dim * 4, dim), 57 | ) 58 | 59 | self.downs = nn.ModuleList([]) 60 | self.ups = nn.ModuleList([]) 61 | num_resolutions = len(in_out) 62 | 63 | # print(in_out) 64 | for ind, (dim_in, dim_out) in enumerate(in_out): 65 | is_last = ind >= (num_resolutions - 1) 66 | 67 | self.downs.append(nn.ModuleList([ 68 | ResidualTemporalBlock(dim_in, dim_out, embed_dim=time_dim), 69 | ResidualTemporalBlock(dim_out, dim_out, embed_dim=time_dim), 70 | Downsample1d(dim_out) if not is_last else nn.Identity() 71 | ])) 72 | 73 | mid_dim = dims[-1] 74 | self.mid_block1 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=time_dim) 75 | self.mid_block2 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=time_dim) 76 | 77 | for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])): 78 | is_last = ind >= (num_resolutions - 1) 79 | 80 | self.ups.append(nn.ModuleList([ 81 | ResidualTemporalBlock(dim_out * 2, dim_in, embed_dim=time_dim), 82 | ResidualTemporalBlock(dim_in, dim_in, embed_dim=time_dim), 83 | Upsample1d(dim_in) if not is_last else nn.Identity() 84 | ])) 85 | 86 | self.final_conv = nn.Sequential( 87 | Conv1dBlock(dim, dim, kernel_size=3, if_zero=True), 88 | nn.Conv1d(dim, transition_dim, 1), 89 | ) 90 | 91 | def forward(self, x, time): 92 | x = einops.rearrange(x, 'b h t -> b t h') 93 | 94 | # t = None # for Noise and Deterministic Baselines 95 | t = self.time_mlp(time) # for diffusion 96 | h = [] 97 | 98 | for resnet, resnet2, downsample in self.downs: 99 | x = resnet(x, t) 100 | x = resnet2(x, t) 101 | h.append(x) 102 | x = downsample(x) 103 | 104 | 105 | x = self.mid_block1(x, t) 106 | x = self.mid_block2(x, t) 107 | 108 | for resnet, resnet2, upsample in self.ups: 109 | x = torch.cat((x, h.pop()), dim=1) 110 | x = resnet(x, t) 111 | x = resnet2(x, t) 112 | x = upsample(x) 113 | 114 | x = self.final_conv(x) 115 | x = einops.rearrange(x, 'b t h -> b h t') 116 | return x 117 | -------------------------------------------------------------------------------- /plan/model/temporal.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import einops 4 | from einops.layers.torch import Rearrange 5 | 6 | from .helpers import ( 7 | SinusoidalPosEmb, 8 | Downsample1d, 9 | Upsample1d, 10 | Conv1dBlock, 11 | ) 12 | 13 | 14 | class ResidualTemporalBlock(nn.Module): 15 | 16 | def __init__(self, inp_channels, out_channels, embed_dim, kernel_size=3): 17 | super().__init__() 18 | 19 | self.blocks = nn.ModuleList([ 20 | Conv1dBlock(inp_channels, out_channels, kernel_size), 21 | Conv1dBlock(out_channels, out_channels, kernel_size, if_zero=True) 22 | ]) 23 | self.time_mlp = nn.Sequential( # should be removed for Noise and Deterministic Baselines 24 | nn.Mish(), 25 | nn.Linear(embed_dim, out_channels), 26 | Rearrange('batch t -> batch t 1'), 27 | ) 28 | self.residual_conv = nn.Conv1d(inp_channels, out_channels, 1) \ 29 | if inp_channels != out_channels else nn.Identity() 30 | 31 | def forward(self, x, t): 32 | out = self.blocks[0](x) + self.time_mlp(t) # for diffusion 33 | # out = self.blocks[0](x) # for Noise and Deterministic Baselines 34 | out = self.blocks[1](out) 35 | return out + self.residual_conv(x) 36 | 
37 | 38 | class TemporalUnet(nn.Module): 39 | def __init__( 40 | self, 41 | transition_dim, 42 | dim=32, 43 | dim_mults=(1, 2, 4, 8), 44 | ): 45 | super().__init__() 46 | 47 | dims = [transition_dim, *map(lambda m: dim * m, dim_mults)] 48 | in_out = list(zip(dims[:-1], dims[1:])) 49 | 50 | time_dim = dim 51 | self.time_mlp = nn.Sequential( # should be removed for Noise and Deterministic Baselines 52 | SinusoidalPosEmb(dim), 53 | nn.Linear(dim, dim * 4), 54 | nn.Mish(), 55 | nn.Linear(dim * 4, dim), 56 | ) 57 | 58 | self.downs = nn.ModuleList([]) 59 | self.ups = nn.ModuleList([]) 60 | num_resolutions = len(in_out) 61 | 62 | # print(in_out) 63 | for ind, (dim_in, dim_out) in enumerate(in_out): 64 | is_last = ind >= (num_resolutions - 1) 65 | 66 | self.downs.append(nn.ModuleList([ 67 | ResidualTemporalBlock(dim_in, dim_out, embed_dim=time_dim), 68 | ResidualTemporalBlock(dim_out, dim_out, embed_dim=time_dim), 69 | Downsample1d(dim_out) if not is_last else nn.Identity() 70 | ])) 71 | 72 | mid_dim = dims[-1] 73 | self.mid_block1 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=time_dim) 74 | self.mid_block2 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=time_dim) 75 | 76 | for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])): 77 | is_last = ind >= (num_resolutions - 1) 78 | 79 | self.ups.append(nn.ModuleList([ 80 | ResidualTemporalBlock(dim_out * 2, dim_in, embed_dim=time_dim), 81 | ResidualTemporalBlock(dim_in, dim_in, embed_dim=time_dim), 82 | Upsample1d(dim_in) if not is_last else nn.Identity() 83 | ])) 84 | 85 | self.final_conv = nn.Sequential( 86 | Conv1dBlock(dim, dim, kernel_size=3, if_zero=True), 87 | nn.Conv1d(dim, transition_dim, 1), 88 | ) 89 | 90 | def forward(self, x, time): 91 | x = einops.rearrange(x, 'b h t -> b t h') 92 | 93 | # t = None # for Noise and Deterministic Baselines 94 | t = self.time_mlp(time) # for diffusion 95 | h = [] 96 | 97 | for resnet, resnet2, downsample in self.downs: 98 | x = resnet(x, t) 99 | 100 | x = resnet2(x, t) 101 | 102 | h.append(x) 103 | 104 | x = downsample(x) 105 | 106 | 107 | x = self.mid_block1(x, t) 108 | x = self.mid_block2(x, t) 109 | 110 | for resnet, resnet2, upsample in self.ups: 111 | x = torch.cat((x, h.pop()), dim=1) 112 | x = resnet(x, t) 113 | x = resnet2(x, t) 114 | x = upsample(x) 115 | 116 | x = self.final_conv(x) 117 | x = einops.rearrange(x, 'b t h -> b h t') 118 | return x 119 | -------------------------------------------------------------------------------- /dataset/NIV/test30.json: -------------------------------------------------------------------------------- 1 | [{"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/changing_tire_0001.npy", "task_id": 0}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/changing_tire_0004.npy", "task_id": 0}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/changing_tire_0005.npy", "task_id": 0}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/changing_tire_0007.npy", "task_id": 0}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/changing_tire_0012.npy", "task_id": 0}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/changing_tire_0020.npy", "task_id": 0}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/changing_tire_0021.npy", "task_id": 0}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/changing_tire_0022.npy", "task_id": 0}, {"feature": 
"/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/changing_tire_0025.npy", "task_id": 0}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/changing_tire_0030.npy", "task_id": 0}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/coffee_0001.npy", "task_id": 1}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/coffee_0007.npy", "task_id": 1}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/coffee_0014.npy", "task_id": 1}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/coffee_0018.npy", "task_id": 1}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/coffee_0019.npy", "task_id": 1}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/coffee_0023.npy", "task_id": 1}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/coffee_0025.npy", "task_id": 1}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/coffee_0027.npy", "task_id": 1}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/cpr_0002.npy", "task_id": 2}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/cpr_0006.npy", "task_id": 2}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/cpr_0012.npy", "task_id": 2}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/cpr_0013.npy", "task_id": 2}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/cpr_0016.npy", "task_id": 2}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/cpr_0017.npy", "task_id": 2}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/cpr_0019.npy", "task_id": 2}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/cpr_0021.npy", "task_id": 2}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/cpr_0023.npy", "task_id": 2}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/cpr_0026.npy", "task_id": 2}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/jump_car_0001.npy", "task_id": 3}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/jump_car_0009.npy", "task_id": 3}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/jump_car_0011.npy", "task_id": 3}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/jump_car_0017.npy", "task_id": 3}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/jump_car_0018.npy", "task_id": 3}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/jump_car_0020.npy", "task_id": 3}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/jump_car_0025.npy", "task_id": 3}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/jump_car_0028.npy", "task_id": 3}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/jump_car_0029.npy", "task_id": 3}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/repot_0004.npy", "task_id": 4}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/repot_0015.npy", "task_id": 4}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/repot_0018.npy", "task_id": 4}, {"feature": 
"/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/repot_0025.npy", "task_id": 4}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/repot_0029.npy", "task_id": 4}] -------------------------------------------------------------------------------- /step/utils/training.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from model.helpers import AverageMeter 3 | from .accuracy import * 4 | 5 | 6 | def cycle(dl): 7 | while True: 8 | for data in dl: 9 | yield data 10 | 11 | 12 | class EMA(): 13 | """ 14 | empirical moving average 15 | """ 16 | 17 | def __init__(self, beta): 18 | super().__init__() 19 | self.beta = beta 20 | 21 | def update_model_average(self, ma_model, current_model): 22 | for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()): 23 | old_weight, up_weight = ma_params.data, current_params.data 24 | ma_params.data = self.update_average(old_weight, up_weight) 25 | 26 | def update_average(self, old, new): 27 | if old is None: 28 | return new 29 | return old * self.beta + (1 - self.beta) * new 30 | 31 | 32 | class Trainer(object): 33 | def __init__( 34 | self, 35 | diffusion_model, 36 | datasetloader, 37 | ema_decay=0.995, 38 | train_lr=1e-5, 39 | gradient_accumulate_every=1, 40 | step_start_ema=400, 41 | update_ema_every=10, 42 | log_freq=100, 43 | ): 44 | super().__init__() 45 | self.model = diffusion_model 46 | self.ema = EMA(ema_decay) 47 | self.ema_model = copy.deepcopy(self.model) 48 | self.update_ema_every = update_ema_every 49 | 50 | self.step_start_ema = step_start_ema 51 | self.log_freq = log_freq 52 | self.gradient_accumulate_every = gradient_accumulate_every 53 | 54 | self.dataloader = cycle(datasetloader) 55 | self.optimizer = torch.optim.AdamW(diffusion_model.parameters(), lr=train_lr, weight_decay=0.0) 56 | # self.optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, diffusion_model.parameters()), lr=train_lr, weight_decay=0.0) 57 | 58 | self.reset_parameters() 59 | self.step = 0 60 | 61 | def reset_parameters(self): 62 | self.ema_model.load_state_dict(self.model.state_dict()) 63 | 64 | def step_ema(self): 65 | if self.step < self.step_start_ema: 66 | self.reset_parameters() 67 | return 68 | self.ema.update_model_average(self.ema_model, self.model) 69 | 70 | # -----------------------------------------------------------------------------# 71 | # ------------------------------------ api ------------------------------------# 72 | # -----------------------------------------------------------------------------# 73 | 74 | def train(self, n_train_steps, if_calculate_acc, args, scheduler): 75 | self.model.train() 76 | self.ema_model.train() 77 | losses = AverageMeter() 78 | self.optimizer.zero_grad() 79 | 80 | for step in range(n_train_steps): 81 | for i in range(self.gradient_accumulate_every): 82 | batch = next(self.dataloader) 83 | bs, T = batch[1].shape # [bs, (T+1), ob_dim] 84 | #print('shape: ', batch[1].shape , batch[0].shape) 85 | global_img_tensors = batch[0].cuda().contiguous().float() 86 | img_tensors = torch.zeros((bs, T, args.action_dim + args.observation_dim)) 87 | img_tensors[:, 0, args.action_dim:] = global_img_tensors[:, 0, :] 88 | img_tensors[:, -1, args.action_dim:] = global_img_tensors[:, -1, :] 89 | img_tensors = img_tensors.cuda() 90 | 91 | video_label = batch[1].view(-1).cuda() # [bs*T] 92 | #task_class = batch[2].view(-1).cuda() # [bs] 93 | 94 | action_label_onehot = torch.zeros((video_label.size(0), 
self.model.module.action_dim)) 95 | # [bs*T, ac_dim] 96 | ind = torch.arange(0, len(video_label)) 97 | action_label_onehot[ind, video_label] = 1. 98 | action_label_onehot = action_label_onehot.reshape(bs, T, -1).cuda() 99 | action_label_onehot[:,1:-1,:] = 0. 100 | img_tensors[:, :, :args.action_dim] = action_label_onehot 101 | 102 | ''' 103 | task_onehot = torch.zeros((task_class.size(0), args.class_dim)) 104 | # [bs*T, ac_dim] 105 | ind = torch.arange(0, len(task_class)) 106 | task_onehot[ind, task_class] = 1. 107 | task_onehot = task_onehot.cuda() 108 | temp = task_onehot.unsqueeze(1) 109 | task_class_ = temp.repeat(1, T, 1) # [bs, T, args.class_dim] 110 | img_tensors[:, :, :args.class_dim] = task_class_ 111 | ''' 112 | cond = {0: global_img_tensors[:, 0, :].float(), T - 1: global_img_tensors[:, -1, :].float()} 113 | x = img_tensors.float() 114 | loss = self.model.module.loss(x, cond) 115 | loss = loss / self.gradient_accumulate_every 116 | loss.backward() 117 | losses.update(loss.item(), bs) 118 | 119 | self.optimizer.step() 120 | self.optimizer.zero_grad() 121 | scheduler.step() 122 | 123 | if self.step % self.update_ema_every == 0: 124 | self.step_ema() 125 | self.step += 1 126 | 127 | if if_calculate_acc: 128 | with torch.no_grad(): 129 | output = self.ema_model(cond) 130 | actions_pred = output[:, :, :self.model.module.action_dim]\ 131 | .contiguous().view(-1, self.model.module.action_dim) # [bs*T, action_dim] 132 | 133 | (acc1, acc5), trajectory_success_rate, MIoU1, MIoU2, a0_acc, aT_acc = \ 134 | accuracy(actions_pred.cpu(), video_label.cpu(), topk=(1, 5), 135 | max_traj_len=self.model.module.horizon) 136 | 137 | return torch.tensor(losses.avg), acc1, acc5, torch.tensor(trajectory_success_rate), \ 138 | torch.tensor(MIoU1), torch.tensor(MIoU2), a0_acc, aT_acc 139 | 140 | else: 141 | return torch.tensor(losses.avg) 142 | -------------------------------------------------------------------------------- /plan/utils/training.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from model.helpers import AverageMeter 3 | from .accuracy import * 4 | from .one_hot import PKGLabelOnehot 5 | from .one_hot import LLMLabelOnehot 6 | 7 | 8 | def cycle(dl): 9 | while True: 10 | for data in dl: 11 | yield data 12 | 13 | 14 | class EMA(): 15 | """ 16 | empirical moving average 17 | """ 18 | 19 | def __init__(self, beta): 20 | super().__init__() 21 | self.beta = beta 22 | 23 | def update_model_average(self, ma_model, current_model): 24 | for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()): 25 | old_weight, up_weight = ma_params.data, current_params.data 26 | ma_params.data = self.update_average(old_weight, up_weight) 27 | 28 | def update_average(self, old, new): 29 | if old is None: 30 | return new 31 | return old * self.beta + (1 - self.beta) * new 32 | 33 | 34 | class Trainer(object): 35 | def __init__( 36 | self, 37 | diffusion_model, 38 | datasetloader, 39 | ema_decay=0.995, 40 | train_lr=1e-5, 41 | gradient_accumulate_every=1, 42 | step_start_ema=400, 43 | update_ema_every=10, 44 | log_freq=100, 45 | ): 46 | super().__init__() 47 | self.model = diffusion_model 48 | self.ema = EMA(ema_decay) 49 | self.ema_model = copy.deepcopy(self.model) 50 | self.update_ema_every = update_ema_every 51 | 52 | self.step_start_ema = step_start_ema 53 | self.log_freq = log_freq 54 | self.gradient_accumulate_every = gradient_accumulate_every 55 | 56 | self.dataloader = cycle(datasetloader) 57 | self.optimizer = 
torch.optim.AdamW(diffusion_model.parameters(), lr=train_lr, weight_decay=0.0) 58 | # self.optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, diffusion_model.parameters()), lr=train_lr, weight_decay=0.0) 59 | 60 | self.reset_parameters() 61 | self.step = 0 62 | 63 | def reset_parameters(self): 64 | self.ema_model.load_state_dict(self.model.state_dict()) 65 | 66 | def step_ema(self): 67 | if self.step < self.step_start_ema: 68 | self.reset_parameters() 69 | return 70 | self.ema.update_model_average(self.ema_model, self.model) 71 | 72 | # -----------------------------------------------------------------------------# 73 | # ------------------------------------ api ------------------------------------# 74 | # -----------------------------------------------------------------------------# 75 | 76 | def train(self, n_train_steps, if_calculate_acc, args, scheduler): 77 | self.model.train() 78 | self.ema_model.train() 79 | losses = AverageMeter() 80 | self.optimizer.zero_grad() 81 | 82 | for step in range(n_train_steps): 83 | for i in range(self.gradient_accumulate_every): 84 | batch = next(self.dataloader) 85 | bs, T = batch[1].shape # [bs, (T+1), ob_dim] 86 | global_img_tensors = batch[0].cuda().contiguous().float() 87 | img_tensors = torch.zeros((bs, T, args.class_dim_graph + args.class_dim_llama + args.action_dim + args.observation_dim)) 88 | img_tensors[:, 0, args.class_dim_graph + args.class_dim_llama+args.action_dim:] = global_img_tensors[:, 0, :] 89 | img_tensors[:, -1, args.class_dim_graph + args.class_dim_llama+args.action_dim:] = global_img_tensors[:, -1, :] 90 | img_tensors = img_tensors.cuda() 91 | video_label = batch[1].view(-1).cuda() # [bs*T] 92 | LLM_label = batch[2].cuda() # [bs] 93 | PKG_label = batch[3].cuda() 94 | llm_label_onehot = LLMLabelOnehot(bs, T, args.num_seq_LLM,[2/3, 1/3]) 95 | pkg_label_onehot = PKGLabelOnehot(bs, T, args.num_seq_PKG,[2/3, 1/3]) 96 | 97 | 98 | action_label_onehot = torch.zeros((video_label.size(0), self.model.module.action_dim)) 99 | # [bs*T, ac_dim] 100 | ind = torch.arange(0, len(video_label)) 101 | action_label_onehot[ind, video_label] = 1. 
102 | action_label_onehot = action_label_onehot.reshape(bs, T, -1).cuda() 103 | img_tensors[:, :, args.class_dim_graph + args.class_dim_llama:args.class_dim_graph + args.class_dim_llama+args.action_dim] = action_label_onehot 104 | 105 | 106 | PKG_label_onehot = pkg_label_onehot(PKG_label) 107 | LLM_label_onehot = llm_label_onehot(LLM_label) 108 | 109 | 110 | img_tensors[:, :, args.class_dim_graph : args.class_dim_graph + args.class_dim_llama] = LLM_label_onehot 111 | img_tensors[:, :, : args.class_dim_graph] = PKG_label_onehot 112 | 113 | cond = {0: global_img_tensors[:, 0, :].float(), T - 1: global_img_tensors[:, -1, :].float(),'LLM_action_path': LLM_label_onehot , 'PKG_action_path': PKG_label_onehot} 114 | x = img_tensors.float() 115 | 116 | loss = self.model.module.loss(x, cond) 117 | loss = loss / self.gradient_accumulate_every 118 | loss.backward() 119 | losses.update(loss.item(), bs) 120 | 121 | self.optimizer.step() 122 | self.optimizer.zero_grad() 123 | scheduler.step() 124 | 125 | if self.step % self.update_ema_every == 0: 126 | self.step_ema() 127 | self.step += 1 128 | 129 | if if_calculate_acc: 130 | with torch.no_grad(): 131 | output = self.ema_model(cond) 132 | actions_pred = output[:, :, args.class_dim_graph + args.class_dim_llama:args.class_dim_graph + args.class_dim_llama+self.model.module.action_dim]\ 133 | .contiguous().view(-1, self.model.module.action_dim) # [bs*T, action_dim] 134 | 135 | (acc1, acc5), trajectory_success_rate, MIoU1, MIoU2, a0_acc, aT_acc = \ 136 | accuracy(actions_pred.cpu(), video_label.cpu(), topk=(1, 5), 137 | max_traj_len=self.model.module.horizon) 138 | 139 | return torch.tensor(losses.avg), acc1, acc5, torch.tensor(trajectory_success_rate), \ 140 | torch.tensor(MIoU1), torch.tensor(MIoU2), a0_acc, aT_acc 141 | 142 | else: 143 | return torch.tensor(losses.avg) 144 | -------------------------------------------------------------------------------- /step/dataloader/data_load.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | from torch.utils.data import Dataset 5 | import json 6 | from collections import namedtuple 7 | 8 | 9 | Batch = namedtuple('Batch', 'Observations Actions') 10 | 11 | 12 | class PlanningDataset(Dataset): 13 | """ 14 | load video and action features from dataset 15 | """ 16 | 17 | def __init__(self, 18 | root, 19 | args=None, 20 | is_val=False, 21 | model=None, 22 | ): 23 | self.is_val = is_val 24 | self.data_root = root 25 | self.args = args 26 | self.max_traj_len = args.horizon 27 | self.vid_names = None 28 | self.frame_cnts = None 29 | self.images = None 30 | self.last_vid = '' 31 | 32 | if args.dataset == 'crosstask': 33 | if is_val: 34 | cross_task_data_name = args.json_path_val 35 | print('cross task data name val', cross_task_data_name) 36 | 37 | else: 38 | cross_task_data_name = args.json_path_train 39 | print('cross task data name train', cross_task_data_name) 40 | 41 | 42 | if os.path.exists(cross_task_data_name): 43 | with open(cross_task_data_name, 'r') as f: 44 | self.json_data = json.load(f) 45 | print('Loaded {}'.format(cross_task_data_name)) 46 | else: 47 | assert 0 48 | elif args.dataset == 'coin': 49 | if is_val: 50 | coin_data_name = args.json_path_val 51 | 52 | else: 53 | coin_data_name = args.json_path_train 54 | 55 | if os.path.exists(coin_data_name): 56 | with open(coin_data_name, 'r') as f: 57 | self.json_data = json.load(f) 58 | print('Loaded {}'.format(coin_data_name)) 59 | else: 60 | assert 0 61 | elif 
args.dataset == 'NIV': 62 | if is_val: 63 | niv_data_name = args.json_path_val 64 | 65 | else: 66 | niv_data_name = args.json_path_train 67 | 68 | if os.path.exists(niv_data_name): 69 | with open(niv_data_name, 'r') as f: 70 | self.json_data = json.load(f) 71 | print('Loaded {}'.format(niv_data_name)) 72 | else: 73 | assert 0 74 | else: 75 | raise NotImplementedError( 76 | 'Dataset {} is not implemented'.format(args.dataset)) 77 | 78 | self.model = model 79 | self.prepare_data() 80 | self.M = 3 81 | 82 | def prepare_data(self): 83 | vid_names = [] 84 | frame_cnts = [] 85 | for listdata in self.json_data: 86 | vid_names.append(listdata['id']) 87 | frame_cnts.append(listdata['instruction_len']) 88 | self.vid_names = vid_names 89 | self.frame_cnts = frame_cnts 90 | print('vid name list length', len(vid_names)) 91 | 92 | 93 | def curate_dataset(self, images, legal_range, M=2): 94 | images_list = [] 95 | labels_onehot_list = [] 96 | idx_list = [] 97 | for start_idx, end_idx, action_label in legal_range: 98 | idx = start_idx 99 | idx_list.append(idx) 100 | image_start_idx = max(0, idx) 101 | 102 | if image_start_idx + M <= len(images): 103 | #image_start = images[image_start_idx: image_start_idx + M] 104 | 105 | if image_start_idx == 0: 106 | image_start = images[image_start_idx: image_start_idx + M] 107 | else: 108 | image_start = images[image_start_idx - 1: image_start_idx + M - 1] ############################### Modified to load data similar to other papers 109 | 110 | else: 111 | image_start = images[len(images) - M: len(images)] 112 | image_start_cat = image_start[0] 113 | for w in range(len(image_start) - 1): 114 | image_start_cat = np.concatenate((image_start_cat, image_start[w + 1]), axis=0) 115 | images_list.append(image_start_cat) 116 | labels_onehot_list.append(action_label) 117 | 118 | end_idx = max(2, end_idx) 119 | #image_end = images[end_idx - 2:end_idx + M - 2] 120 | 121 | if end_idx >= len(images)-1: 122 | 123 | image_end = images[end_idx - 2:end_idx + M - 2] 124 | else: 125 | 126 | image_end = images[end_idx - 1:end_idx + M - 1] #########################Modified to load data similar to other papers ##################################### 127 | 128 | image_end_cat = image_end[0] 129 | for w in range(len(image_end) - 1): 130 | image_end_cat = np.concatenate((image_end_cat, image_end[w + 1]), axis=0) 131 | images_list.append(image_end_cat) 132 | 133 | return images_list, labels_onehot_list, idx_list 134 | 135 | def sample_single(self, index): 136 | folder_id = self.vid_names[index] 137 | if self.args.dataset == 'crosstask': 138 | if folder_id['vid'] != self.last_vid: 139 | images_ = np.load(folder_id['feature'], allow_pickle=True) 140 | self.images = images_['frames_features'] 141 | self.last_vid = folder_id['vid'] 142 | else: 143 | images_ = np.load(folder_id['feature'], allow_pickle=True) 144 | self.images = images_['frames_features'] 145 | images, labels_matrix, idx_list = self.curate_dataset( 146 | self.images, folder_id['legal_range'], M=self.M) 147 | 148 | shapes = [arr.shape for arr in images] 149 | 150 | frames = torch.tensor(np.array(images)) 151 | labels_tensor = torch.tensor(labels_matrix, dtype=torch.long) 152 | 153 | return frames, labels_tensor 154 | 155 | def __getitem__(self, index): 156 | if self.is_val: 157 | frames, labels = self.sample_single(index) 158 | 159 | else: 160 | frames, labels = self.sample_single(index) 161 | if self.is_val: 162 | batch = Batch(frames, labels) 163 | else: 164 | batch = Batch(frames, labels) 165 | return batch 166 | 167 | def 
__len__(self): 168 | return len(self.json_data) -------------------------------------------------------------------------------- /plan/dataloader/data_load.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | from torch.utils.data import Dataset 5 | import json 6 | from collections import namedtuple 7 | 8 | 9 | Batch = namedtuple('Batch', 'Observations Actions LLM Knowledge') 10 | 11 | 12 | class PlanningDataset(Dataset): 13 | """ 14 | load video and action features from dataset 15 | """ 16 | 17 | def __init__(self, 18 | root, 19 | args=None, 20 | is_val=False, 21 | model=None, 22 | ): 23 | self.is_val = is_val 24 | self.data_root = root 25 | self.args = args 26 | self.max_traj_len = args.horizon 27 | self.vid_names = None 28 | self.frame_cnts = None 29 | self.images = None 30 | self.last_vid = '' 31 | 32 | if args.dataset == 'crosstask': 33 | if is_val: 34 | cross_task_data_name = args.json_path_val 35 | print('cross task data name val', cross_task_data_name) 36 | 37 | else: 38 | cross_task_data_name = args.json_path_train 39 | print('cross task data name train', cross_task_data_name) 40 | 41 | if os.path.exists(cross_task_data_name): 42 | with open(cross_task_data_name, 'r') as f: 43 | self.json_data = json.load(f) 44 | print('Loaded {}'.format(cross_task_data_name)) 45 | else: 46 | assert 0 47 | elif args.dataset == 'coin': 48 | if is_val: 49 | coin_data_name = args.json_path_val 50 | 51 | else: 52 | coin_data_name = args.json_path_train 53 | 54 | if os.path.exists(coin_data_name): 55 | with open(coin_data_name, 'r') as f: 56 | self.json_data = json.load(f) 57 | print('Loaded {}'.format(coin_data_name)) 58 | else: 59 | assert 0 60 | elif args.dataset == 'NIV': 61 | if is_val: 62 | niv_data_name = args.json_path_val 63 | 64 | else: 65 | niv_data_name = args.json_path_train 66 | 67 | if os.path.exists(niv_data_name): 68 | with open(niv_data_name, 'r') as f: 69 | self.json_data = json.load(f) 70 | print('Loaded {}'.format(niv_data_name)) 71 | else: 72 | assert 0 73 | else: 74 | raise NotImplementedError( 75 | 'Dataset {} is not implemented'.format(args.dataset)) 76 | 77 | self.model = model 78 | self.prepare_data() 79 | self.M = 3 80 | 81 | def prepare_data(self): 82 | vid_names = [] 83 | frame_cnts = [] 84 | for listdata in self.json_data: 85 | vid_names.append(listdata['id']) 86 | frame_cnts.append(listdata['instruction_len']) 87 | self.vid_names = vid_names 88 | self.frame_cnts = frame_cnts 89 | print('vid name list length', len(vid_names)) 90 | 91 | 92 | def curate_dataset(self, images, legal_range, M=2): 93 | images_list = [] 94 | labels_onehot_list = [] 95 | idx_list = [] 96 | for start_idx, end_idx, action_label in legal_range: 97 | idx = start_idx 98 | idx_list.append(idx) 99 | image_start_idx = max(0, idx) 100 | if image_start_idx + M <= len(images): 101 | #image_start = images[image_start_idx: image_start_idx + M] 102 | 103 | if image_start_idx == 0: 104 | image_start = images[image_start_idx: image_start_idx + M] 105 | else: 106 | image_start = images[image_start_idx - 1: image_start_idx + M - 1] ############################### Modified to load data similar to other papers 107 | 108 | else: 109 | image_start = images[len(images) - M: len(images)] 110 | image_start_cat = image_start[0] 111 | for w in range(len(image_start) - 1): 112 | image_start_cat = np.concatenate((image_start_cat, image_start[w + 1]), axis=0) 113 | images_list.append(image_start_cat) 114 | 
labels_onehot_list.append(action_label) 115 | 116 | end_idx = max(2, end_idx) 117 | #image_end = images[end_idx - 2:end_idx + M - 2] 118 | 119 | if end_idx >= len(images)-1: 120 | 121 | image_end = images[end_idx - 2:end_idx + M - 2] 122 | else: 123 | 124 | image_end = images[end_idx - 1:end_idx + M - 1] #########################Modified to load data similar to other papers ##################################### 125 | 126 | image_end_cat = image_end[0] 127 | for w in range(len(image_end) - 1): 128 | image_end_cat = np.concatenate((image_end_cat, image_end[w + 1]), axis=0) 129 | images_list.append(image_end_cat) 130 | return images_list, labels_onehot_list, idx_list 131 | 132 | def sample_single(self, index): 133 | folder_id = self.vid_names[index] 134 | 135 | graph_path = folder_id['graph_action_path'] 136 | LLM_path = folder_id['graph_action_path'] ############################################################## Modify this into LLM_action_path if LLM_actions_are_available############################# 137 | 138 | 139 | if self.args.dataset == 'crosstask': 140 | if folder_id['vid'] != self.last_vid: 141 | images_ = np.load(folder_id['feature'], allow_pickle=True) 142 | self.images = images_['frames_features'] 143 | self.last_vid = folder_id['vid'] 144 | else: 145 | images_ = np.load(folder_id['feature'], allow_pickle=True) 146 | self.images = images_['frames_features'] 147 | images, labels_matrix, idx_list = self.curate_dataset( 148 | self.images, folder_id['legal_range'], M=self.M) 149 | 150 | 151 | # Get the shape of each array 152 | shapes = [arr.shape for arr in images] 153 | 154 | frames = torch.tensor(np.array(images)) 155 | labels_tensor = torch.tensor(labels_matrix, dtype=torch.long) 156 | 157 | graph_path = torch.tensor(graph_path, dtype=torch.long) 158 | 159 | LLM_path = torch.tensor(LLM_path, dtype=torch.long) 160 | return frames, labels_tensor, LLM_path, graph_path 161 | 162 | def __getitem__(self, index): 163 | frames, labels, LLM_path, graph_path = self.sample_single(index) 164 | batch = Batch(frames, labels, LLM_path, graph_path) 165 | return batch 166 | 167 | def __len__(self): 168 | return len(self.json_data) -------------------------------------------------------------------------------- /plan/utils/args.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | 4 | def get_args(description='whl'): 5 | parser = argparse.ArgumentParser(description=description) 6 | parser.add_argument('--checkpoint_root', 7 | type=str, 8 | default='checkpoint', 9 | help='checkpoint dir root') 10 | parser.add_argument('--log_root', 11 | type=str, 12 | default='log', 13 | help='log dir root') 14 | parser.add_argument('--checkpoint_dir', 15 | type=str, 16 | default='', 17 | help='checkpoint model folder') 18 | parser.add_argument('--optimizer', 19 | type=str, 20 | default='adam', 21 | help='opt algorithm') 22 | parser.add_argument('--num_thread_reader', 23 | type=int, 24 | default=40, 25 | help='') 26 | parser.add_argument('--batch_size', 27 | type=int, 28 | default=256, 29 | help='batch size') 30 | parser.add_argument('--batch_size_val', 31 | type=int, 32 | default=1024, 33 | help='batch size eval') 34 | parser.add_argument('--pretrain_cnn_path', 35 | type=str, 36 | default='', 37 | help='') 38 | parser.add_argument('--momemtum', 39 | type=float, 40 | default=0.9, 41 | help='SGD momemtum') 42 | parser.add_argument('--log_freq', 43 | type=int, 44 | default=500, 45 | help='how many steps do we log once') 46 | parser.add_argument('--save_freq', 
47 | type=int, 48 | default=1, 49 | help='how many epochs do we save once') 50 | parser.add_argument('--gradient_accumulate_every', 51 | type=int, 52 | default=1, 53 | help='accumulation_steps') 54 | parser.add_argument('--ema_decay', 55 | type=float, 56 | default=0.995, 57 | help='') 58 | parser.add_argument('--step_start_ema', 59 | type=int, 60 | default=400, 61 | help='') 62 | parser.add_argument('--update_ema_every', 63 | type=int, 64 | default=10, 65 | help='') 66 | parser.add_argument('--crop_only', 67 | type=int, 68 | default=1, 69 | help='random seed') 70 | parser.add_argument('--centercrop', 71 | type=int, 72 | default=0, 73 | help='random seed') 74 | parser.add_argument('--random_flip', 75 | type=int, 76 | default=1, 77 | help='random seed') 78 | parser.add_argument('--verbose', 79 | type=int, 80 | default=1, 81 | help='') 82 | parser.add_argument('--fps', 83 | type=int, 84 | default=1, 85 | help='') 86 | parser.add_argument('--cudnn_benchmark', 87 | type=int, 88 | default=0, 89 | help='') 90 | parser.add_argument('--horizon', 91 | type=int, 92 | default=4, 93 | help='') 94 | parser.add_argument('--dataset', 95 | type=str, 96 | default='crosstask', 97 | #default='coin', 98 | help='dataset') 99 | parser.add_argument('--action_dim', 100 | type=int, 101 | default=105, 102 | help='') 103 | parser.add_argument('--observation_dim', 104 | type=int, 105 | default=1536, 106 | help='') 107 | parser.add_argument('--class_dim_graph', 108 | type=int, 109 | default=105, ####### modify according to the dataset. 110 | help='') 111 | parser.add_argument('--class_dim_llama', 112 | type=int, 113 | default=105, ####### modify according to the dataset. 114 | help='') 115 | parser.add_argument('--num_seq_PKG', 116 | type=int, 117 | default=2, ####### modify according to the dataset. 118 | help='') 119 | parser.add_argument('--num_seq_LLM', 120 | type=int, 121 | default=2, ####### modify according to the dataset. 
122 | help='') 123 | parser.add_argument('--n_diffusion_steps', 124 | type=int, 125 | default=200, 126 | help='') 127 | parser.add_argument('--n_train_steps', 128 | type=int, 129 | default=200, 130 | help='training_steps_per_epoch') 131 | parser.add_argument('--root', 132 | type=str, 133 | default='/l/users/ravindu.nagasinghe/KEPP/datasets/PDPP/CrossTask_assets', 134 | help='root path of dataset crosstask') 135 | parser.add_argument('--json_path_train', 136 | type=str, 137 | default='/l/users/ravindu.nagasinghe/KEPP/plan/data_lists/train_list.json', 138 | help='path of the generated json file for train') 139 | parser.add_argument('--json_path_val', 140 | type=str, 141 | default='/l/users/ravindu.nagasinghe/KEPP/plan/data_lists/test_list.json', 142 | help='path of the generated json file for val') 143 | 144 | parser.add_argument('--epochs', default=120, type=int, metavar='N', 145 | help='number of total epochs to run') 146 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 147 | help='manual epoch number (useful on restarts)') 148 | parser.add_argument('--lr', '--learning-rate', default=5e-4, type=float, 149 | metavar='LR', help='initial learning rate', dest='lr') 150 | parser.add_argument('--resume', dest='resume', action='store_true', 151 | help='resume training from last checkpoint') 152 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', 153 | help='evaluate model on validation set') 154 | parser.add_argument('--pretrained', dest='pretrained', action='store_true', 155 | help='use pre-trained model') 156 | parser.add_argument('--pin_memory', dest='pin_memory', action='store_true', 157 | help='use pin_memory') 158 | parser.add_argument('--world-size', default=1, type=int, 159 | help='number of nodes for distributed training') 160 | parser.add_argument('--rank', default=0, type=int, 161 | help='node rank for distributed training') 162 | parser.add_argument('--dist-file', default='dist-file', type=str, 163 | help='url used to set up distributed training') 164 | parser.add_argument('--dist-url', default='tcp://localhost:21712', type=str, 165 | help='url used to set up distributed training') 166 | parser.add_argument('--dist-backend', default='nccl', type=str, 167 | help='distributed backend') 168 | parser.add_argument('--seed', default=217, type=int, 169 | help='seed for initializing training. ') 170 | parser.add_argument('--gpu', default=None, type=int, 171 | help='GPU id to use.') 172 | parser.add_argument('--multiprocessing-distributed', action='store_true', 173 | help='Use multi-processing distributed training to launch ' 174 | 'N processes per node, which has N GPUs. 
This is the ' 175 | 'fastest way to use PyTorch for either single node or ' 176 | 'multi node data parallel training') 177 | args = parser.parse_args() 178 | return args -------------------------------------------------------------------------------- /step/utils/args.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | 4 | def get_args(description='whl'): 5 | parser = argparse.ArgumentParser(description=description) 6 | parser.add_argument('--checkpoint_root', 7 | type=str, 8 | default='checkpoint', 9 | help='checkpoint dir root') 10 | parser.add_argument('--log_root', 11 | type=str, 12 | default='log', 13 | help='log dir root') 14 | parser.add_argument('--checkpoint_dir', 15 | type=str, 16 | default='', 17 | help='checkpoint model folder') 18 | parser.add_argument('--optimizer', 19 | type=str, 20 | default='adam', 21 | help='opt algorithm') 22 | parser.add_argument('--num_thread_reader', 23 | type=int, 24 | default=40, 25 | help='') 26 | parser.add_argument('--batch_size', 27 | type=int, 28 | default=256, 29 | help='batch size') 30 | parser.add_argument('--batch_size_val', 31 | type=int, 32 | default=1024, 33 | help='batch size eval') 34 | parser.add_argument('--pretrain_cnn_path', 35 | type=str, 36 | default='', 37 | help='') 38 | parser.add_argument('--momemtum', 39 | type=float, 40 | default=0.9, 41 | help='SGD momemtum') 42 | parser.add_argument('--log_freq', 43 | type=int, 44 | default=500, 45 | help='how many steps do we log once') 46 | parser.add_argument('--save_freq', 47 | type=int, 48 | default=1, 49 | help='how many epochs do we save once') 50 | parser.add_argument('--gradient_accumulate_every', 51 | type=int, 52 | default=1, 53 | help='accumulation_steps') 54 | parser.add_argument('--ema_decay', 55 | type=float, 56 | default=0.995, 57 | help='') 58 | parser.add_argument('--step_start_ema', 59 | type=int, 60 | default=400, 61 | help='') 62 | parser.add_argument('--update_ema_every', 63 | type=int, 64 | default=10, 65 | help='') 66 | parser.add_argument('--crop_only', 67 | type=int, 68 | default=1, 69 | help='random seed') 70 | parser.add_argument('--centercrop', 71 | type=int, 72 | default=0, 73 | help='random seed') 74 | parser.add_argument('--random_flip', 75 | type=int, 76 | default=1, 77 | help='random seed') 78 | parser.add_argument('--verbose', 79 | type=int, 80 | default=1, 81 | help='') 82 | parser.add_argument('--fps', 83 | type=int, 84 | default=1, 85 | help='') 86 | parser.add_argument('--cudnn_benchmark', 87 | type=int, 88 | default=0, 89 | help='') 90 | parser.add_argument('--horizon', 91 | type=int, 92 | default=4, 93 | help='') 94 | parser.add_argument('--dataset', 95 | type=str, 96 | default='crosstask', 97 | #default='coin', 98 | help='dataset') 99 | parser.add_argument('--action_dim', 100 | type=int, 101 | default=105, 102 | #default=778, 103 | help='') 104 | parser.add_argument('--observation_dim', 105 | type=int, 106 | default=1536, 107 | help='') 108 | parser.add_argument('--n_diffusion_steps', 109 | type=int, 110 | default=200, 111 | help='') 112 | parser.add_argument('--n_train_steps', 113 | type=int, 114 | default=200, 115 | help='training_steps_per_epoch') 116 | parser.add_argument('--root', 117 | type=str, 118 | default='/l/users/ravindu.nagasinghe/KEPP/datasets/CrossTask_assets', 119 | help='root path of dataset crosstask') 120 | parser.add_argument('--json_path_train', 121 | type=str, 122 | default='/l/users/ravindu.nagasinghe/KEPP/step/data_lists/train_data_list.json', 123 | help='path of 
the generated json file for train') 124 | parser.add_argument('--json_path_val', 125 | type=str, 126 | default='/l/users/ravindu.nagasinghe/KEPP/step/data_lists/test_data_list.json', #modify for train set /l/users/ravindu.nagasinghe/KEPP/step/data_lists/train_data_list.json 127 | help='path of the generated json file for val') 128 | parser.add_argument('--steps_path', 129 | type=str, 130 | default='/l/users/ravindu.nagasinghe/KEPP/step/test_list_steps.json', # modify for train set '/l/users/ravindu.nagasinghe/KEPP/step/train_list_steps.json' 131 | help='the path for predicted steps only') 132 | parser.add_argument('--step_model_output', 133 | type=str, 134 | default='/l/users/ravindu.nagasinghe/KEPP/step/test_data_step_model.json', # modify for train set '/l/users/ravindu.nagasinghe/KEPP/step/train_data_step_model.json' 135 | help='the path for predicted steps final output') 136 | 137 | parser.add_argument('--epochs', default=120, type=int, metavar='N', 138 | help='number of total epochs to run') 139 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 140 | help='manual epoch number (useful on restarts)') 141 | parser.add_argument('--lr', '--learning-rate', default=5e-4, type=float, 142 | metavar='LR', help='initial learning rate', dest='lr') 143 | parser.add_argument('--resume', dest='resume', action='store_true', 144 | help='resume training from last checkpoint') 145 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', 146 | help='evaluate model on validation set') 147 | parser.add_argument('--pretrained', dest='pretrained', action='store_true', 148 | help='use pre-trained model') 149 | parser.add_argument('--pin_memory', dest='pin_memory', action='store_true', 150 | help='use pin_memory') 151 | parser.add_argument('--world-size', default=1, type=int, 152 | help='number of nodes for distributed training') 153 | parser.add_argument('--rank', default=0, type=int, 154 | help='node rank for distributed training') 155 | parser.add_argument('--dist-file', default='dist-file', type=str, 156 | help='url used to set up distributed training') 157 | parser.add_argument('--dist-url', default='tcp://localhost:21712', type=str, 158 | help='url used to set up distributed training') 159 | parser.add_argument('--dist-backend', default='nccl', type=str, 160 | help='distributed backend') 161 | parser.add_argument('--seed', default=217, type=int, 162 | help='seed for initializing training. ') 163 | parser.add_argument('--gpu', default=None, type=int, 164 | help='GPU id to use.') 165 | parser.add_argument('--multiprocessing-distributed', action='store_true', 166 | help='Use multi-processing distributed training to launch ' 167 | 'N processes per node, which has N GPUs. This is the ' 168 | 'fastest way to use PyTorch for either single node or ' 169 | 'multi node data parallel training') 170 | args = parser.parse_args() 171 | return args -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # KEPP: Why Not Use Your Textbook? Knowledge-Enhanced Procedure Planning of Instructional Videos-CVPR 2024 2 | 3 | [![paper](https://img.shields.io/badge/arXiv-Paper-42FF33)](https://arxiv.org/abs/2403.02782) 4 | [![Project Page](https://img.shields.io/badge/Project-Page-blue)](https://ravindu-yasas-nagasinghe.github.io/KEPP-Project_Page/) 5 | 6 | This repository gives the official implementation of KEPP:Why Not Use Your Textbook? 
Knowledge-Enhanced Procedure Planning of Instructional Videos (CVPR 2024) 7 | 8 | In our project, we explore the capability of an agent to construct a logical sequence of action steps, thereby assembling a strategic procedural plan. This plan is crucial for navigating from an initial visual observation to a target visual outcome, as depicted in real-life instructional videos. Existing works have attained partial success by extensively leveraging various sources of information available in the datasets, such as heavy intermediate visual observations, procedural names, or natural language step-by-step instructions, for features or supervision signals. However, the task remains formidable due to the implicit causal constraints in the sequencing of steps and the variability inherent in multiple feasible plans. To tackle these intricacies that previous efforts have overlooked, we propose to enhance the agent's capabilities by infusing it with procedural knowledge. This knowledge, sourced from training procedure plans and structured as a directed weighted graph, equips the agent to better navigate the complexities of step sequencing and its potential variations. We coin our approach KEPP, a novel Knowledge-Enhanced Procedure Planning system, which harnesses a probabilistic procedural knowledge graph extracted from training data, effectively acting as a comprehensive textbook for the training domain. Experimental evaluations across three widely-used datasets under settings of varying complexity reveal that KEPP attains superior, state-of-the-art results while requiring only minimal supervision. The main architecture of our model is as follows. 9 | 10 | 21 | 22 | ![kepp (2)_page-0001](https://github.com/Ravindu-Yasas-Nagasinghe/KEPP/assets/56619402/ef7a12f5-bf7d-461d-a03b-4630ccd23751) 23 | 24 | ### Contents 25 | 1) [Setup](#Setup) 26 | 2) [Data Preparation](#Data-Preparation) 27 | 3) [Train Step model](#Train-Step-model) 28 | 4) [Generate paths from procedure knowlege graph](#Generate-paths-from-procedure-knowlege-graph) 29 | 5) [Inference](#Inference) 30 | ## Setup 31 | In a conda env with cuda available, run: 32 | ```shell 33 | pip install -r requirements.txt 34 | ``` 35 | ## Data Preparation 36 | ### CrossTask 37 | 1. Download datasets&features 38 | ```shell 39 | cd {root}/dataset/crosstask 40 | bash download.sh 41 | ``` 42 | 2. move your datasplit files and action one-hot coding file to `{root}/dataset/crosstask/crosstask_release/` 43 | ```shell 44 | mv *.json crosstask_release 45 | mv actions_one_hot.npy crosstask_release 46 | ``` 47 | ### COIN 48 | 1. Download datasets&features 49 | ```shell 50 | cd {root}/dataset/coin 51 | bash download.sh 52 | ``` 53 | ### NIV 54 | 1. Download datasets&features 55 | ```shell 56 | cd {root}/dataset/NIV 57 | bash download.sh 58 | ``` 59 | ## Train Step model 60 | 1. First generate the training and testing dataset json files. You can modify the dataset, train steps, horizon(prediction length), json files savepath etc. in `args.py`. Set the `--json_path_train`, and `--json_path_val` in `args.py` as the dataset json file paths. 61 | ```shell 62 | cd {root}/step 63 | python loading_data.py 64 | ``` 65 | Dimensions for different datasets are listed below: 66 | 67 | | Dataset | observation_dim | action_dim | class_dim | 68 | |----| ----| ----| ----| 69 | | CrossTask | 1536(how) 9600(base) | 105 | 18 | 70 | | COIN | 1536 | 778 | 180 | 71 | | NIV | 1536 | 48 | 5 | 72 | 73 | 2. 
Train the step model 74 | ```shell 75 | python main_distributed.py --multiprocessing-distributed --num_thread_reader=8 --cudnn_benchmark=1 --pin_memory --checkpoint_dir=whl --resume --batch_size=256 --batch_size_val=256 --evaluate 76 | ``` 77 | The trained models will be saved in {root}/step/save_max. 78 | 79 | 3. Generate first and last action predictions for train and test dataset. 80 | * Modify the checkpoint path(L329) as the evaluated model(in save_max) in inference.py. 81 | * Modify the `--json_path_val` , `--steps_path` , and `--step_model_output` arguments in `args.py` to generate step predicted dataset json file paths for train and test datasets seperately. Run following command for train and test datasets seperately by modifying as afore mentioned. 82 | 83 | ```shell 84 | python inference.py --multiprocessing-distributed --num_thread_reader=8 --cudnn_benchmark=1 --pin_memory --checkpoint_dir=whl --resume --batch_size=256 --batch_size_val=256 --evaluate > output.txt 85 | ``` 86 | ## Generate paths from procedure knowlege graph 87 | 1. Train the graph for the relavent dataset (Not compulsory) 88 | ```shell 89 | cd {root}/PKG 90 | python graph_creation.py 91 | Select mode "train_out_n" 92 | ``` 93 | Trained graphs for CrossTask, COIN, NIV datasets are available on `cd {root}/PKG/graphs`. Change (L13) `graph_save_path` of `graph_creation.py` to load procedure knowledge graphs trained on different datasets. 94 | 95 | 2. Obtain PKG conditions for train and test datasets. 96 | * Modify line 540 of `graph_creation.py` as the output of step model (`--step_model_output`). 97 | * Modify line 568 of `graph_creation.py` to set the output path for the generated procedure knowlwdge graph conditioned train and test dataset json files. 98 | * run the following for both train and test dataset files generated from the step model by modifying `graph_creation.py` file as afore mentioned. 99 | ```shell 100 | python graph_creation.py 101 | Select mode "validate" 102 | ``` 103 | ## Train plan model 104 | 1. Modify the `json_path_train` and `json_path_val` arguments of `args.py` in plan model as the outputs generated from procedure knowlwdge graph for train and test data respectively. 105 | 106 | Modify the parameter `--num_seq_PKG` in `args.py` to match the generated amount of PKG conditions. (Modify `--num_seq_LLM` to the same number as well if LLM conditions are not used seperately.) 107 | ```shell 108 | cd {root}/plan 109 | python main_distributed.py --multiprocessing-distributed --num_thread_reader=8 --cudnn_benchmark=1 --pin_memory --checkpoint_dir=whl --resume --batch_size=256 --batch_size_val=256 --evaluate 110 | ``` 111 | ## Inference 112 | 113 | For Metrics 114 | ​Modify the max checkpoint path(L339) as the evaluated model in inference.py and run: 115 | ```shell 116 | python inference.py --multiprocessing-distributed --num_thread_reader=8 --cudnn_benchmark=1 --pin_memory --checkpoint_dir=whl --resume --batch_size=256 --batch_size_val=256 --evaluate > output.txt 117 | ``` 118 | Results of given checkpoints: 119 | 120 | | dataset | SR | mAcc | MIoU | 121 | | ---- | -- | -- | -- | 122 | | Crosstask_T=4 | 21.02 | 56.08 | 64.15 | 123 | | COIN_T=4 | 15.63 | 39.53 | 53.27 | 124 | | NIV_T=4 | 22.71 | 41.59 | 91.49 | 125 | 126 | Here we present the qualitative examples of our proposed method. Intermediate steps are padded in the step model because it only predicts the start and end actions. 127 | 128 |
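For illustration, padding means that in a length-`T` plan only positions `0` and `T-1` carry the step model's predicted actions, while every intermediate position holds a placeholder. A minimal sketch (the helper name and the placeholder value `0` are illustrative, not part of the released code):

```python
# Pad a (start action, end action) prediction from the step model into a horizon-T plan.
# Intermediate positions get a placeholder value; only the two endpoints are predicted.
def pad_step_prediction(a_start: int, a_end: int, horizon: int, pad_value: int = 0):
    plan = [pad_value] * horizon
    plan[0], plan[-1] = a_start, a_end
    return plan

# Example: pad_step_prediction(3, 17, horizon=4) -> [3, 0, 0, 17]
```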

129 | 130 | 131 |

132 | 133 | 134 | Checkpoint links will be uploaded soon 135 | 136 | 137 | ### Citation 138 | ```shell 139 | @InProceedings{Nagasinghe_2024_CVPR, 140 | author = {Nagasinghe, Kumaranage Ravindu Yasas and Zhou, Honglu and Gunawardhana, Malitha and Min, Martin Renqiang and Harari, Daniel and Khan, Muhammad Haris}, 141 | title = {Why Not Use Your Textbook? Knowledge-Enhanced Procedure Planning of Instructional Videos}, 142 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, 143 | month = {June}, 144 | year = {2024}, 145 | pages = {18816-18826} 146 | } 147 | ``` 148 | ### Contact 149 | In case of any query, create issue or contact ravindunagasinghe1998@gmail.com 150 | 151 | ### Acknowledgement 152 | * This work was supported by joint MBZUAI-WIS grant P007. The authors are grateful for their generous support, which made this research possible. 153 | * This codebase is built on PDPP 154 | 155 | -------------------------------------------------------------------------------- /step/model/helpers.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from einops.layers.torch import Rearrange 6 | from torch.optim.lr_scheduler import LambdaLR 7 | import os 8 | import numpy as np 9 | import logging 10 | from tensorboardX import SummaryWriter 11 | 12 | 13 | # -----------------------------------------------------------------------------# 14 | # ---------------------------------- modules ----------------------------------# 15 | # -----------------------------------------------------------------------------# 16 | 17 | def zero_module(module): 18 | """ 19 | Zero out the parameters of a module and return it. 20 | """ 21 | for p in module.parameters(): 22 | p.detach().zero_() 23 | return module 24 | 25 | 26 | class SinusoidalPosEmb(nn.Module): 27 | def __init__(self, dim): 28 | super().__init__() 29 | self.dim = dim 30 | 31 | def forward(self, x): 32 | device = x.device 33 | half_dim = self.dim // 2 34 | emb = math.log(10000) / (half_dim - 1) 35 | emb = torch.exp(torch.arange(half_dim, device=device) * -emb) 36 | emb = x[:, None] * emb[None, :] 37 | emb = torch.cat((emb.sin(), emb.cos()), dim=-1) 38 | return emb 39 | 40 | 41 | class Downsample1d(nn.Module): 42 | def __init__(self, dim): 43 | super().__init__() 44 | print(dim, 'down!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!') 45 | self.conv = nn.Conv1d(dim, dim, 2, 1, 0) #####Edited kernal size previously had (dim, dim, 2, 1, 0) . make 2 to 1 when horizon is 2. 46 | 47 | def forward(self, x): 48 | return self.conv(x) 49 | 50 | 51 | class Upsample1d(nn.Module): 52 | def __init__(self, dim): 53 | super().__init__() 54 | self.conv = nn.ConvTranspose1d(dim, dim, 2, 1, 0) #####Edited kernal size previously had (dim, dim, 2, 1, 0) . make 2 to 1 when horizon is 2. 
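# Length bookkeeping for the kernel-size notes above: with stride 1 and no padding,
# Conv1d(kernel_size=2) shortens the horizon axis by one step per Downsample1d
# (L_out = L_in - 1), while ConvTranspose1d(kernel_size=2) lengthens it by one step per
# Upsample1d (L_out = L_in + 1). With horizon 2 the downsampled length would already hit 1,
# which is why the notes above recommend a length-preserving kernel size of 1 in that case.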
55 | 56 | def forward(self, x): 57 | return self.conv(x) 58 | 59 | 60 | class Conv1dBlock(nn.Module): 61 | """ 62 | Conv1d --> GroupNorm --> Mish 63 | """ 64 | 65 | def __init__(self, inp_channels, out_channels, kernel_size, n_groups=32, drop_out=0.0, if_zero=False): 66 | super().__init__() 67 | if drop_out > 0.0: 68 | self.block = nn.Sequential( 69 | zero_module( 70 | nn.Conv1d(inp_channels, out_channels, kernel_size, padding=1), 71 | ), 72 | Rearrange('batch channels horizon -> batch channels 1 horizon'), 73 | nn.GroupNorm(n_groups, out_channels), 74 | Rearrange('batch channels 1 horizon -> batch channels horizon'), 75 | nn.Mish(), 76 | nn.Dropout(p=drop_out), 77 | ) 78 | elif if_zero: 79 | self.block = nn.Sequential( 80 | zero_module( 81 | nn.Conv1d(inp_channels, out_channels, kernel_size, padding=1), 82 | ), 83 | Rearrange('batch channels horizon -> batch channels 1 horizon'), 84 | nn.GroupNorm(n_groups, out_channels), 85 | Rearrange('batch channels 1 horizon -> batch channels horizon'), 86 | nn.Mish(), 87 | 88 | ) 89 | else: 90 | self.block = nn.Sequential( 91 | nn.Conv1d(inp_channels, out_channels, kernel_size, padding=1), 92 | Rearrange('batch channels horizon -> batch channels 1 horizon'), 93 | nn.GroupNorm(n_groups, out_channels), 94 | Rearrange('batch channels 1 horizon -> batch channels horizon'), 95 | nn.Mish(), 96 | ) 97 | 98 | def forward(self, x): 99 | return self.block(x) 100 | 101 | 102 | # -----------------------------------------------------------------------------# 103 | # ---------------------------------- sampling ---------------------------------# 104 | # -----------------------------------------------------------------------------# 105 | 106 | def extract(a, t, x_shape): 107 | b, *_ = t.shape 108 | out = a.gather(-1, t) 109 | return out.reshape(b, *((1,) * (len(x_shape) - 1))) 110 | 111 | 112 | def cosine_beta_schedule(timesteps, s=0.008, dtype=torch.float32): 113 | """ 114 | cosine schedule 115 | as proposed in https://openreview.net/forum?id=-NEXDKk8gZ 116 | """ 117 | steps = timesteps + 1 118 | x = np.linspace(0, steps, steps) 119 | alphas_cumprod = np.cos(((x / steps) + s) / (1 + s) * np.pi * 0.5) ** 2 120 | alphas_cumprod = alphas_cumprod / alphas_cumprod[0] 121 | betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1]) 122 | betas_clipped = np.clip(betas, a_min=0, a_max=0.999) 123 | return torch.tensor(betas_clipped, dtype=dtype) 124 | 125 | 126 | def condition_projection(x, conditions, action_dim): 127 | for t, val in conditions.items(): 128 | if t != 'task': 129 | x[:, t, action_dim:] = val.clone() 130 | 131 | x[:, 1:-1, :] = 0. ########################condition padding 132 | 133 | return x 134 | 135 | 136 | # -----------------------------------------------------------------------------# 137 | # ---------------------------------- Loss -------------------------------------# 138 | # -----------------------------------------------------------------------------# 139 | 140 | class Weighted_MSE(nn.Module): 141 | 142 | def __init__(self, weights, action_dim): 143 | super().__init__() 144 | # self.register_buffer('weights', weights) 145 | self.action_dim = action_dim 146 | 147 | def forward(self, pred, targ): 148 | """ 149 | :param pred: [B, T, task_dim+action_dim+observation_dim] 150 | :param targ: [B, T, task_dim+action_dim+observation_dim] 151 | :return: 152 | """ 153 | 154 | loss_action = F.mse_loss(pred, targ, reduction='none') 155 | loss_action[:, 0, :self.action_dim] *= 10. 156 | loss_action[:, -1, :self.action_dim] *= 10. 
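# Weighting scheme: the two lines above up-weight (x10) the action dimensions at the first
# and last time steps, and the line below zeroes the action loss at every intermediate step,
# so only the start and end actions receive supervision; the observation dimensions
# (indices >= action_dim) keep their plain MSE weight of 1.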
157 | loss_action[:, 1:-1, :self.action_dim] *= 0. 158 | loss_action = loss_action.sum() 159 | return loss_action 160 | 161 | 162 | Losses = { 163 | 'Weighted_MSE': Weighted_MSE, 164 | } 165 | 166 | # -----------------------------------------------------------------------------# 167 | # -------------------------------- lr_schedule --------------------------------# 168 | # -----------------------------------------------------------------------------# 169 | 170 | def get_lr_schedule_with_warmup(optimizer, num_training_steps, last_epoch=-1): 171 | num_warmup_steps = num_training_steps * 20 / 120 172 | decay_steps = num_training_steps * 30 / 120 173 | 174 | def lr_lambda(current_step): 175 | if current_step <= num_warmup_steps: 176 | return max(0., float(current_step) / float(max(1, num_warmup_steps))) 177 | else: 178 | return max(0.5 ** ((current_step - num_warmup_steps) // decay_steps), 0.) 179 | 180 | return LambdaLR(optimizer, lr_lambda, last_epoch) 181 | 182 | # -----------------------------------------------------------------------------# 183 | # ---------------------------------- logging ----------------------------------# 184 | # -----------------------------------------------------------------------------# 185 | 186 | # Taken from PyTorch's examples.imagenet.main 187 | class AverageMeter(object): 188 | """Computes and stores the average and current value""" 189 | 190 | def __init__(self): 191 | self.reset() 192 | 193 | def reset(self): 194 | self.val = 0 195 | self.avg = 0 196 | self.sum = 0 197 | self.count = 0 198 | 199 | def update(self, val, n=1): 200 | self.val = val 201 | self.sum += val * n 202 | self.count += n 203 | self.avg = self.sum / self.count 204 | 205 | 206 | class Logger: 207 | def __init__(self, log_dir, n_logged_samples=10, summary_writer=SummaryWriter, if_exist=False): 208 | self._log_dir = log_dir 209 | print('logging outputs to ', log_dir) 210 | self._n_logged_samples = n_logged_samples 211 | self._summ_writer = summary_writer(log_dir, flush_secs=120, max_queue=10) 212 | if not if_exist: 213 | log = logging.getLogger(log_dir) 214 | if not log.handlers: 215 | log.setLevel(logging.DEBUG) 216 | if not os.path.exists(log_dir): 217 | os.mkdir(log_dir) 218 | fh = logging.FileHandler(os.path.join(log_dir, 'log.txt')) 219 | fh.setLevel(logging.INFO) 220 | formatter = logging.Formatter(fmt='%(asctime)s %(message)s', datefmt='%m/%d/%Y %I:%M:%S') 221 | fh.setFormatter(formatter) 222 | log.addHandler(fh) 223 | self.log = log 224 | 225 | def log_scalar(self, scalar, name, step_): 226 | self._summ_writer.add_scalar('{}'.format(name), scalar, step_) 227 | 228 | def log_scalars(self, scalar_dict, group_name, step, phase): 229 | """Will log all scalars in the same plot.""" 230 | self._summ_writer.add_scalars('{}_{}'.format(group_name, phase), scalar_dict, step) 231 | 232 | def flush(self): 233 | self._summ_writer.flush() 234 | 235 | def log_info(self, info): 236 | self.log.info("{}".format(info)) 237 | -------------------------------------------------------------------------------- /plan/model/helpers.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from einops.layers.torch import Rearrange 6 | from torch.optim.lr_scheduler import LambdaLR 7 | import os 8 | import numpy as np 9 | import logging 10 | from tensorboardX import SummaryWriter 11 | 12 | 13 | # -----------------------------------------------------------------------------# 14 | # 
---------------------------------- modules ----------------------------------# 15 | # -----------------------------------------------------------------------------# 16 | 17 | def zero_module(module): 18 | """ 19 | Zero out the parameters of a module and return it. 20 | """ 21 | for p in module.parameters(): 22 | p.detach().zero_() 23 | return module 24 | 25 | 26 | class SinusoidalPosEmb(nn.Module): 27 | def __init__(self, dim): 28 | super().__init__() 29 | self.dim = dim 30 | 31 | def forward(self, x): 32 | device = x.device 33 | half_dim = self.dim // 2 34 | emb = math.log(10000) / (half_dim - 1) 35 | emb = torch.exp(torch.arange(half_dim, device=device) * -emb) 36 | emb = x[:, None] * emb[None, :] 37 | emb = torch.cat((emb.sin(), emb.cos()), dim=-1) 38 | return emb 39 | 40 | 41 | class Downsample1d(nn.Module): 42 | def __init__(self, dim): 43 | super().__init__() 44 | self.conv = nn.Conv1d(dim, dim, 2, 1, 0) #####Edited kernal size previously had (dim, dim, 2, 1, 0) 45 | 46 | def forward(self, x): 47 | return self.conv(x) 48 | 49 | 50 | class Upsample1d(nn.Module): 51 | def __init__(self, dim): 52 | super().__init__() 53 | print(dim, 'down!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!') 54 | self.conv = nn.ConvTranspose1d(dim, dim, 2, 1, 0) #####Edited kernal size previously had (dim, dim, 2, 1, 0) 55 | 56 | def forward(self, x): 57 | return self.conv(x) 58 | 59 | 60 | class Conv1dBlock(nn.Module): 61 | """ 62 | Conv1d --> GroupNorm --> Mish 63 | """ 64 | 65 | def __init__(self, inp_channels, out_channels, kernel_size, n_groups=32, drop_out=0.0, if_zero=False): 66 | super().__init__() 67 | if drop_out > 0.0: 68 | self.block = nn.Sequential( 69 | zero_module( 70 | nn.Conv1d(inp_channels, out_channels, kernel_size, padding=1), 71 | ), 72 | Rearrange('batch channels horizon -> batch channels 1 horizon'), 73 | nn.GroupNorm(n_groups, out_channels), 74 | Rearrange('batch channels 1 horizon -> batch channels horizon'), 75 | nn.Mish(), 76 | nn.Dropout(p=drop_out), 77 | ) 78 | elif if_zero: 79 | self.block = nn.Sequential( 80 | zero_module( 81 | nn.Conv1d(inp_channels, out_channels, kernel_size, padding=1), 82 | ), 83 | Rearrange('batch channels horizon -> batch channels 1 horizon'), 84 | nn.GroupNorm(n_groups, out_channels), 85 | Rearrange('batch channels 1 horizon -> batch channels horizon'), 86 | nn.Mish(), 87 | 88 | ) 89 | else: 90 | self.block = nn.Sequential( 91 | nn.Conv1d(inp_channels, out_channels, kernel_size, padding=1), 92 | Rearrange('batch channels horizon -> batch channels 1 horizon'), 93 | nn.GroupNorm(n_groups, out_channels), 94 | Rearrange('batch channels 1 horizon -> batch channels horizon'), 95 | nn.Mish(), 96 | ) 97 | 98 | def forward(self, x): 99 | return self.block(x) 100 | 101 | 102 | # -----------------------------------------------------------------------------# 103 | # ---------------------------------- sampling ---------------------------------# 104 | # -----------------------------------------------------------------------------# 105 | 106 | def extract(a, t, x_shape): 107 | b, *_ = t.shape 108 | out = a.gather(-1, t) 109 | return out.reshape(b, *((1,) * (len(x_shape) - 1))) 110 | 111 | 112 | def cosine_beta_schedule(timesteps, s=0.008, dtype=torch.float32): 113 | """ 114 | cosine schedule 115 | as proposed in https://openreview.net/forum?id=-NEXDKk8gZ 116 | """ 117 | steps = timesteps + 1 118 | x = np.linspace(0, steps, steps) 119 | alphas_cumprod = np.cos(((x / steps) + s) / (1 + s) * np.pi * 0.5) ** 2 120 | alphas_cumprod = alphas_cumprod / alphas_cumprod[0] 121 | betas = 1 
- (alphas_cumprod[1:] / alphas_cumprod[:-1]) 122 | betas_clipped = np.clip(betas, a_min=0, a_max=0.999) 123 | return torch.tensor(betas_clipped, dtype=dtype) 124 | 125 | 126 | def condition_projection(x, conditions, action_dim, LLM_dim , PKG_dim): 127 | for t, val in conditions.items(): 128 | #print('t---------------',t) 129 | if t != 'PKG_action_path' and t != 'LLM_action_path': 130 | x[:, t, LLM_dim + PKG_dim + action_dim:] = val.clone() 131 | 132 | x[:, 1:-1, LLM_dim + PKG_dim + action_dim:] = 0. 133 | x[:, :, :PKG_dim] = conditions['PKG_action_path'] 134 | x[:, :, PKG_dim:PKG_dim + LLM_dim] = conditions['LLM_action_path'] 135 | 136 | return x 137 | 138 | 139 | # -----------------------------------------------------------------------------# 140 | # ---------------------------------- Loss -------------------------------------# 141 | # -----------------------------------------------------------------------------# 142 | 143 | class Weighted_MSE(nn.Module): 144 | 145 | def __init__(self, weights, action_dim, LLM_dim, PKG_dim): 146 | super().__init__() 147 | # self.register_buffer('weights', weights) 148 | self.action_dim = action_dim 149 | self.LLM_dim = LLM_dim 150 | self.PKG_dim = PKG_dim 151 | 152 | def forward(self, pred, targ): 153 | """ 154 | :param pred: [B, T, task_dim+action_dim+observation_dim] 155 | :param targ: [B, T, task_dim+action_dim+observation_dim] 156 | :return: 157 | """ 158 | 159 | loss_action = F.mse_loss(pred, targ, reduction='none') 160 | loss_action[:, 0, self.LLM_dim + self.PKG_dim: self.LLM_dim + self.PKG_dim + self.action_dim] *= 5. 161 | loss_action[:, -1, self.LLM_dim + self.PKG_dim: self.LLM_dim + self.PKG_dim + self.action_dim] *= 5. 162 | 163 | #loss_action[:, 1:-1, self.class_dim:self.class_dim + self.action_dim] *= 0. 164 | loss_action = loss_action.sum() 165 | return loss_action 166 | 167 | 168 | Losses = { 169 | 'Weighted_MSE': Weighted_MSE, 170 | } 171 | 172 | # -----------------------------------------------------------------------------# 173 | # -------------------------------- lr_schedule --------------------------------# 174 | # -----------------------------------------------------------------------------# 175 | 176 | def get_lr_schedule_with_warmup(optimizer, num_training_steps, last_epoch=-1): 177 | num_warmup_steps = num_training_steps * 20 / 120 178 | decay_steps = num_training_steps * 30 / 120 179 | 180 | def lr_lambda(current_step): 181 | if current_step <= num_warmup_steps: 182 | return max(0., float(current_step) / float(max(1, num_warmup_steps))) 183 | else: 184 | return max(0.5 ** ((current_step - num_warmup_steps) // decay_steps), 0.) 
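# Resulting schedule: the multiplier ramps linearly from 0 to 1 over the first
# num_warmup_steps (20/120 of all training steps), then is halved once every decay_steps
# (30/120 of all training steps), since the integer division above increases by one per
# full decay window.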
185 | 186 | return LambdaLR(optimizer, lr_lambda, last_epoch) 187 | 188 | # -----------------------------------------------------------------------------# 189 | # ---------------------------------- logging ----------------------------------# 190 | # -----------------------------------------------------------------------------# 191 | 192 | # Taken from PyTorch's examples.imagenet.main 193 | class AverageMeter(object): 194 | """Computes and stores the average and current value""" 195 | 196 | def __init__(self): 197 | self.reset() 198 | 199 | def reset(self): 200 | self.val = 0 201 | self.avg = 0 202 | self.sum = 0 203 | self.count = 0 204 | 205 | def update(self, val, n=1): 206 | self.val = val 207 | self.sum += val * n 208 | self.count += n 209 | self.avg = self.sum / self.count 210 | 211 | 212 | class Logger: 213 | def __init__(self, log_dir, n_logged_samples=10, summary_writer=SummaryWriter, if_exist=False): 214 | self._log_dir = log_dir 215 | print('logging outputs to ', log_dir) 216 | self._n_logged_samples = n_logged_samples 217 | self._summ_writer = summary_writer(log_dir, flush_secs=120, max_queue=10) 218 | if not if_exist: 219 | log = logging.getLogger(log_dir) 220 | if not log.handlers: 221 | log.setLevel(logging.DEBUG) 222 | if not os.path.exists(log_dir): 223 | os.mkdir(log_dir) 224 | fh = logging.FileHandler(os.path.join(log_dir, 'log.txt')) 225 | fh.setLevel(logging.INFO) 226 | formatter = logging.Formatter(fmt='%(asctime)s %(message)s', datefmt='%m/%d/%Y %I:%M:%S') 227 | fh.setFormatter(formatter) 228 | log.addHandler(fh) 229 | self.log = log 230 | 231 | def log_scalar(self, scalar, name, step_): 232 | self._summ_writer.add_scalar('{}'.format(name), scalar, step_) 233 | 234 | def log_scalars(self, scalar_dict, group_name, step, phase): 235 | """Will log all scalars in the same plot.""" 236 | self._summ_writer.add_scalars('{}_{}'.format(group_name, phase), scalar_dict, step) 237 | 238 | def flush(self): 239 | self._summ_writer.flush() 240 | 241 | def log_info(self, info): 242 | self.log.info("{}".format(info)) 243 | -------------------------------------------------------------------------------- /step/model/diffusion.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | import torch 4 | from torch import nn 5 | 6 | from .helpers import ( 7 | cosine_beta_schedule, 8 | extract, 9 | condition_projection, 10 | Losses, 11 | ) 12 | 13 | 14 | class GaussianDiffusion(nn.Module): 15 | def __init__(self, model, horizon, observation_dim, action_dim, n_timesteps=200, 16 | loss_type='Weighted_MSE', clip_denoised=False, ddim_discr_method='uniform', 17 | ): 18 | super().__init__() 19 | self.horizon = horizon 20 | self.observation_dim = observation_dim 21 | self.action_dim = action_dim 22 | self.model = model 23 | 24 | betas = cosine_beta_schedule(n_timesteps) 25 | alphas = 1. 
- betas 26 | alphas_cumprod = torch.cumprod(alphas, dim=0) 27 | alphas_cumprod_prev = torch.cat([torch.ones(1), alphas_cumprod[:-1]]) 28 | 29 | self.n_timesteps = n_timesteps 30 | self.clip_denoised = clip_denoised 31 | self.eta = 0.0 32 | self.random_ratio = 1.0 33 | 34 | # ---------------------------ddim-------------------------------- 35 | ddim_timesteps = 10 36 | 37 | if ddim_discr_method == 'uniform': 38 | c = n_timesteps // ddim_timesteps 39 | ddim_timestep_seq = np.asarray(list(range(0, n_timesteps, c))) 40 | elif ddim_discr_method == 'quad': 41 | ddim_timestep_seq = ( 42 | (np.linspace(0, np.sqrt(n_timesteps), ddim_timesteps)) ** 2 43 | ).astype(int) 44 | else: 45 | assert RuntimeError() 46 | 47 | self.ddim_timesteps = ddim_timesteps 48 | self.ddim_timestep_seq = ddim_timestep_seq 49 | # ---------------------------------------------------------------- 50 | 51 | self.register_buffer('betas', betas) 52 | self.register_buffer('alphas_cumprod', alphas_cumprod) 53 | self.register_buffer('alphas_cumprod_prev', alphas_cumprod_prev) 54 | 55 | # calculations for diffusion q(x_t | x_{t-1}) and others 56 | self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod)) 57 | self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod)) 58 | self.register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod)) 59 | self.register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod)) 60 | self.register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1)) 61 | 62 | # calculations for posterior q(x_{t-1} | x_t, x_0) 63 | posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod) 64 | self.register_buffer('posterior_variance', posterior_variance) 65 | 66 | # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain 67 | self.register_buffer('posterior_log_variance_clipped', 68 | torch.log(torch.clamp(posterior_variance, min=1e-20))) 69 | self.register_buffer('posterior_mean_coef1', 70 | betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)) 71 | self.register_buffer('posterior_mean_coef2', 72 | (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)) 73 | self.loss_type = loss_type 74 | self.loss_fn = Losses[loss_type](None, self.action_dim) 75 | 76 | # ------------------------------------------ sampling ------------------------------------------# 77 | 78 | def q_posterior(self, x_start, x_t, t): 79 | posterior_mean = ( 80 | extract(self.posterior_mean_coef1, t, x_t.shape) * x_start + 81 | extract(self.posterior_mean_coef2, t, x_t.shape) * x_t 82 | ) 83 | posterior_variance = extract(self.posterior_variance, t, x_t.shape) 84 | posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape) 85 | return posterior_mean, posterior_variance, posterior_log_variance_clipped 86 | 87 | def p_mean_variance(self, x, cond, t): 88 | x_recon = self.model(x, t) 89 | 90 | if self.clip_denoised: 91 | x_recon.clamp(-1., 1.) 
92 | else: 93 | assert RuntimeError() 94 | 95 | model_mean, posterior_variance, posterior_log_variance = self.q_posterior( 96 | x_start=x_recon, x_t=x, t=t) 97 | return model_mean, posterior_variance, posterior_log_variance 98 | 99 | @torch.no_grad() 100 | def _predict_eps_from_xstart(self, x_t, t, pred_xstart): 101 | return \ 102 | (extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart) \ 103 | / extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) 104 | 105 | @torch.no_grad() 106 | def p_sample_ddim(self, x, cond, t, t_prev, if_prev=False): 107 | b, *_, device = *x.shape, x.device 108 | x_recon = self.model(x, t) 109 | 110 | if self.clip_denoised: 111 | x_recon.clamp(-1., 1.) 112 | else: 113 | assert RuntimeError() 114 | 115 | eps = self._predict_eps_from_xstart(x, t, x_recon) 116 | alpha_bar = extract(self.alphas_cumprod, t, x.shape) 117 | if if_prev: 118 | alpha_bar_prev = extract(self.alphas_cumprod_prev, t_prev, x.shape) 119 | else: 120 | alpha_bar_prev = extract(self.alphas_cumprod, t_prev, x.shape) 121 | sigma = ( 122 | self.eta 123 | * torch.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) 124 | * torch.sqrt(1 - alpha_bar / alpha_bar_prev) 125 | ) 126 | 127 | noise = torch.randn_like(x) * self.random_ratio 128 | mean_pred = ( 129 | x_recon * torch.sqrt(alpha_bar_prev) 130 | + torch.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps 131 | ) 132 | 133 | nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) 134 | return mean_pred + nonzero_mask * sigma * noise 135 | 136 | @torch.no_grad() 137 | def p_sample(self, x, cond, t): 138 | b, *_, device = *x.shape, x.device 139 | model_mean, _, model_log_variance = self.p_mean_variance(x=x, cond=cond, t=t) 140 | noise = torch.randn_like(x) * self.random_ratio 141 | nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) 142 | return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise 143 | 144 | @torch.no_grad() 145 | def p_sample_loop(self, cond, if_jump): 146 | device = self.betas.device 147 | batch_size = len(cond[0]) 148 | horizon = self.horizon 149 | shape = (batch_size, horizon, self.action_dim + self.observation_dim) 150 | 151 | x = torch.randn(shape, device=device) * self.random_ratio # xt for Noise and diffusion 152 | # x = torch.zeros(shape, device=device) # for Deterministic 153 | x = condition_projection(x, cond, self.action_dim) 154 | 155 | ''' 156 | The if-else below is for diffusion, should be removed for Noise and Deterministic 157 | ''' 158 | if not if_jump: 159 | for i in reversed(range(0, self.n_timesteps)): 160 | timesteps = torch.full((batch_size,), i, device=device, dtype=torch.long) 161 | x = self.p_sample(x, cond, timesteps) 162 | x = condition_projection(x, cond, self.action_dim) 163 | 164 | else: 165 | for i in reversed(range(0, self.ddim_timesteps)): 166 | timesteps = torch.full((batch_size,), self.ddim_timestep_seq[i], device=device, dtype=torch.long) 167 | if i == 0: 168 | timesteps_prev = torch.full((batch_size,), 0, device=device, dtype=torch.long) 169 | x = self.p_sample_ddim(x, cond, timesteps, timesteps_prev, True) 170 | else: 171 | timesteps_prev = torch.full((batch_size,), self.ddim_timestep_seq[i-1], device=device, dtype=torch.long) 172 | x = self.p_sample_ddim(x, cond, timesteps, timesteps_prev) 173 | x = condition_projection(x, cond, self.action_dim) 174 | 175 | ''' 176 | The two lines below is for Noise and Deterministic 177 | ''' 178 | # x = self.model(x, None) 179 | # x = condition_projection(x, cond, self.action_dim, 
self.class_dim) 180 | 181 | return x 182 | 183 | # ------------------------------------------ training ------------------------------------------# 184 | 185 | def q_sample(self, x_start, t, noise=None): 186 | if noise is None: 187 | noise = torch.randn_like(x_start) * self.random_ratio 188 | 189 | sample = ( 190 | extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + 191 | extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise 192 | ) 193 | 194 | return sample 195 | 196 | def p_losses(self, x_start, cond, t): 197 | noise = torch.randn_like(x_start) * self.random_ratio # for Noise and diffusion 198 | # noise = torch.zeros_like(x_start) # for Deterministic 199 | # x_noisy = noise # for Noise and Deterministic 200 | 201 | x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) # for diffusion, should be removed for Noise and Deterministic 202 | x_noisy = condition_projection(x_noisy, cond, self.action_dim) 203 | x_recon = self.model(x_noisy, t) 204 | x_recon = condition_projection(x_recon, cond, self.action_dim) 205 | 206 | loss = self.loss_fn(x_recon, x_start) 207 | return loss 208 | 209 | def loss(self, x, cond): 210 | batch_size = len(x) # for diffusion 211 | t = torch.randint(0, self.n_timesteps, (batch_size,), device=x.device).long() # for diffusion 212 | # t = None # for Noise and Deterministic 213 | return self.p_losses(x, cond, t) 214 | 215 | def forward(self, cond, if_jump=False): 216 | return self.p_sample_loop(cond, if_jump) 217 | -------------------------------------------------------------------------------- /plan/model/diffusion.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | import torch 4 | from torch import nn 5 | 6 | from .helpers import ( 7 | cosine_beta_schedule, 8 | extract, 9 | condition_projection, 10 | Losses, 11 | ) 12 | 13 | 14 | class GaussianDiffusion(nn.Module): 15 | def __init__(self, model, horizon, observation_dim, action_dim, LLM_dim, PKG_dim, n_timesteps=200, 16 | loss_type='Weighted_MSE', clip_denoised=False, ddim_discr_method='uniform', 17 | ): 18 | super().__init__() 19 | self.horizon = horizon 20 | self.observation_dim = observation_dim 21 | self.action_dim = action_dim 22 | self.LLM_dim = LLM_dim 23 | self.PKG_dim = PKG_dim 24 | self.model = model 25 | 26 | betas = cosine_beta_schedule(n_timesteps) 27 | alphas = 1. 
- betas 28 | alphas_cumprod = torch.cumprod(alphas, dim=0) 29 | alphas_cumprod_prev = torch.cat([torch.ones(1), alphas_cumprod[:-1]]) 30 | 31 | self.n_timesteps = n_timesteps 32 | self.clip_denoised = clip_denoised 33 | self.eta = 0.0 34 | self.random_ratio = 1.0 35 | 36 | # ---------------------------ddim-------------------------------- 37 | ddim_timesteps = 10 38 | 39 | if ddim_discr_method == 'uniform': 40 | c = n_timesteps // ddim_timesteps 41 | ddim_timestep_seq = np.asarray(list(range(0, n_timesteps, c))) 42 | elif ddim_discr_method == 'quad': 43 | ddim_timestep_seq = ( 44 | (np.linspace(0, np.sqrt(n_timesteps), ddim_timesteps)) ** 2 45 | ).astype(int) 46 | else: 47 | assert RuntimeError() 48 | 49 | self.ddim_timesteps = ddim_timesteps 50 | self.ddim_timestep_seq = ddim_timestep_seq 51 | # ---------------------------------------------------------------- 52 | 53 | self.register_buffer('betas', betas) 54 | self.register_buffer('alphas_cumprod', alphas_cumprod) 55 | self.register_buffer('alphas_cumprod_prev', alphas_cumprod_prev) 56 | 57 | # calculations for diffusion q(x_t | x_{t-1}) and others 58 | self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod)) 59 | self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod)) 60 | self.register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod)) 61 | self.register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod)) 62 | self.register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1)) 63 | 64 | # calculations for posterior q(x_{t-1} | x_t, x_0) 65 | posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod) 66 | self.register_buffer('posterior_variance', posterior_variance) 67 | 68 | # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain 69 | self.register_buffer('posterior_log_variance_clipped', 70 | torch.log(torch.clamp(posterior_variance, min=1e-20))) 71 | self.register_buffer('posterior_mean_coef1', 72 | betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)) 73 | self.register_buffer('posterior_mean_coef2', 74 | (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)) 75 | self.loss_type = loss_type 76 | self.loss_fn = Losses[loss_type](None, self.action_dim, self.LLM_dim, self.PKG_dim) 77 | 78 | # ------------------------------------------ sampling ------------------------------------------# 79 | 80 | def q_posterior(self, x_start, x_t, t): 81 | posterior_mean = ( 82 | extract(self.posterior_mean_coef1, t, x_t.shape) * x_start + 83 | extract(self.posterior_mean_coef2, t, x_t.shape) * x_t 84 | ) 85 | posterior_variance = extract(self.posterior_variance, t, x_t.shape) 86 | posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape) 87 | return posterior_mean, posterior_variance, posterior_log_variance_clipped 88 | 89 | def p_mean_variance(self, x, cond, t): 90 | x_recon = self.model(x, t) 91 | 92 | if self.clip_denoised: 93 | x_recon.clamp(-1., 1.) 
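# NOTE: `Tensor.clamp` on the line above is not an in-place op, so the clipped result is
# discarded as written; if clipping is really intended when clip_denoised=True, this would
# need `x_recon = x_recon.clamp(-1., 1.)` (or `x_recon.clamp_(-1., 1.)`). Likewise,
# `assert RuntimeError()` in the else-branch below never raises, since the exception
# instance is truthy; `raise RuntimeError(...)` appears to be the intent. The same two
# patterns also appear in p_sample_ddim and in step/model/diffusion.py.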
94 | else: 95 | assert RuntimeError() 96 | 97 | model_mean, posterior_variance, posterior_log_variance = self.q_posterior( 98 | x_start=x_recon, x_t=x, t=t) 99 | return model_mean, posterior_variance, posterior_log_variance 100 | 101 | @torch.no_grad() 102 | def _predict_eps_from_xstart(self, x_t, t, pred_xstart): 103 | return \ 104 | (extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart) \ 105 | / extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) 106 | 107 | @torch.no_grad() 108 | def p_sample_ddim(self, x, cond, t, t_prev, if_prev=False): 109 | b, *_, device = *x.shape, x.device 110 | x_recon = self.model(x, t) 111 | 112 | if self.clip_denoised: 113 | x_recon.clamp(-1., 1.) 114 | else: 115 | assert RuntimeError() 116 | 117 | eps = self._predict_eps_from_xstart(x, t, x_recon) 118 | alpha_bar = extract(self.alphas_cumprod, t, x.shape) 119 | if if_prev: 120 | alpha_bar_prev = extract(self.alphas_cumprod_prev, t_prev, x.shape) 121 | else: 122 | alpha_bar_prev = extract(self.alphas_cumprod, t_prev, x.shape) 123 | sigma = ( 124 | self.eta 125 | * torch.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) 126 | * torch.sqrt(1 - alpha_bar / alpha_bar_prev) 127 | ) 128 | 129 | noise = torch.randn_like(x) * self.random_ratio 130 | mean_pred = ( 131 | x_recon * torch.sqrt(alpha_bar_prev) 132 | + torch.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps 133 | ) 134 | 135 | nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) 136 | return mean_pred + nonzero_mask * sigma * noise 137 | 138 | @torch.no_grad() 139 | def p_sample(self, x, cond, t): 140 | b, *_, device = *x.shape, x.device 141 | model_mean, _, model_log_variance = self.p_mean_variance(x=x, cond=cond, t=t) 142 | noise = torch.randn_like(x) * self.random_ratio 143 | nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) 144 | return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise 145 | 146 | @torch.no_grad() 147 | def p_sample_loop(self, cond, if_jump): 148 | device = self.betas.device 149 | batch_size = len(cond[0]) 150 | horizon = self.horizon 151 | shape = (batch_size, horizon, self.PKG_dim + self.LLM_dim + self.action_dim + self.observation_dim) 152 | 153 | x = torch.randn(shape, device=device) * self.random_ratio # xt for Noise and diffusion 154 | # x = torch.zeros(shape, device=device) # for Deterministic 155 | x = condition_projection(x, cond, self.action_dim, self.LLM_dim, self.PKG_dim) 156 | 157 | ''' 158 | The if-else below is for diffusion, should be removed for Noise and Deterministic 159 | ''' 160 | if not if_jump: 161 | for i in reversed(range(0, self.n_timesteps)): 162 | timesteps = torch.full((batch_size,), i, device=device, dtype=torch.long) 163 | x = self.p_sample(x, cond, timesteps) 164 | x = condition_projection(x, cond, self.action_dim, self.LLM_dim, self.PKG_dim) 165 | 166 | else: 167 | for i in reversed(range(0, self.ddim_timesteps)): 168 | timesteps = torch.full((batch_size,), self.ddim_timestep_seq[i], device=device, dtype=torch.long) 169 | if i == 0: 170 | timesteps_prev = torch.full((batch_size,), 0, device=device, dtype=torch.long) 171 | x = self.p_sample_ddim(x, cond, timesteps, timesteps_prev, True) 172 | else: 173 | timesteps_prev = torch.full((batch_size,), self.ddim_timestep_seq[i-1], device=device, dtype=torch.long) 174 | x = self.p_sample_ddim(x, cond, timesteps, timesteps_prev) 175 | x = condition_projection(x, cond, self.action_dim, self.LLM_dim, self.PKG_dim) 176 | 177 | ''' 178 | The two lines below is for Noise and 
Deterministic 179 | ''' 180 | # x = self.model(x, None) 181 | # x = condition_projection(x, cond, self.action_dim, self.class_dim) 182 | 183 | return x 184 | 185 | # ------------------------------------------ training ------------------------------------------# 186 | 187 | def q_sample(self, x_start, t, noise=None): 188 | if noise is None: 189 | noise = torch.randn_like(x_start) * self.random_ratio 190 | 191 | sample = ( 192 | extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + 193 | extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise 194 | ) 195 | 196 | return sample 197 | 198 | def p_losses(self, x_start, cond, t): 199 | noise = torch.randn_like(x_start) * self.random_ratio # for Noise and diffusion 200 | # noise = torch.zeros_like(x_start) # for Deterministic 201 | # x_noisy = noise # for Noise and Deterministic 202 | 203 | x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) # for diffusion, should be removed for Noise and Deterministic 204 | x_noisy = condition_projection(x_noisy, cond, self.action_dim, self.LLM_dim, self.PKG_dim) 205 | 206 | x_recon = self.model(x_noisy, t) 207 | x_recon = condition_projection(x_recon, cond, self.action_dim, self.LLM_dim, self.PKG_dim) 208 | 209 | loss = self.loss_fn(x_recon, x_start) 210 | return loss 211 | 212 | def loss(self, x, cond): 213 | batch_size = len(x) # for diffusion 214 | t = torch.randint(0, self.n_timesteps, (batch_size,), device=x.device).long() # for diffusion 215 | # t = None # for Noise and Deterministic 216 | return self.p_losses(x, cond, t) 217 | 218 | def forward(self, cond, if_jump=False): 219 | return self.p_sample_loop(cond, if_jump) 220 | -------------------------------------------------------------------------------- /dataset/NIV/train70.json: -------------------------------------------------------------------------------- 1 | [{"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/changing_tire_0002.npy", "task_id": 0}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/changing_tire_0003.npy", "task_id": 0}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/changing_tire_0006.npy", "task_id": 0}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/changing_tire_0008.npy", "task_id": 0}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/changing_tire_0009.npy", "task_id": 0}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/changing_tire_0010.npy", "task_id": 0}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/changing_tire_0011.npy", "task_id": 0}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/changing_tire_0013.npy", "task_id": 0}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/changing_tire_0014.npy", "task_id": 0}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/changing_tire_0015.npy", "task_id": 0}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/changing_tire_0016.npy", "task_id": 0}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/changing_tire_0017.npy", "task_id": 0}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/changing_tire_0018.npy", "task_id": 0}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/changing_tire_0019.npy", "task_id": 0}, {"feature": 
"/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/changing_tire_0023.npy", "task_id": 0}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/changing_tire_0024.npy", "task_id": 0}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/changing_tire_0026.npy", "task_id": 0}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/changing_tire_0027.npy", "task_id": 0}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/changing_tire_0028.npy", "task_id": 0}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/changing_tire_0029.npy", "task_id": 0}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/coffee_0002.npy", "task_id": 1}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/coffee_0003.npy", "task_id": 1}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/coffee_0004.npy", "task_id": 1}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/coffee_0005.npy", "task_id": 1}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/coffee_0006.npy", "task_id": 1}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/coffee_0008.npy", "task_id": 1}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/coffee_0009.npy", "task_id": 1}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/coffee_0010.npy", "task_id": 1}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/coffee_0011.npy", "task_id": 1}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/coffee_0012.npy", "task_id": 1}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/coffee_0013.npy", "task_id": 1}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/coffee_0015.npy", "task_id": 1}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/coffee_0016.npy", "task_id": 1}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/coffee_0017.npy", "task_id": 1}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/coffee_0020.npy", "task_id": 1}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/coffee_0021.npy", "task_id": 1}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/coffee_0022.npy", "task_id": 1}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/coffee_0024.npy", "task_id": 1}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/coffee_0026.npy", "task_id": 1}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/coffee_0028.npy", "task_id": 1}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/coffee_0029.npy", "task_id": 1}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/coffee_0030.npy", "task_id": 1}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/cpr_0001.npy", "task_id": 2}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/cpr_0003.npy", "task_id": 2}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/cpr_0004.npy", "task_id": 2}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/cpr_0005.npy", "task_id": 2}, {"feature": 
"/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/cpr_0007.npy", "task_id": 2}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/cpr_0008.npy", "task_id": 2}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/cpr_0009.npy", "task_id": 2}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/cpr_0010.npy", "task_id": 2}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/cpr_0011.npy", "task_id": 2}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/cpr_0014.npy", "task_id": 2}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/cpr_0015.npy", "task_id": 2}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/cpr_0018.npy", "task_id": 2}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/cpr_0020.npy", "task_id": 2}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/cpr_0022.npy", "task_id": 2}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/cpr_0024.npy", "task_id": 2}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/cpr_0025.npy", "task_id": 2}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/cpr_0027.npy", "task_id": 2}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/cpr_0028.npy", "task_id": 2}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/cpr_0029.npy", "task_id": 2}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/cpr_0030.npy", "task_id": 2}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/jump_car_0002.npy", "task_id": 3}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/jump_car_0003.npy", "task_id": 3}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/jump_car_0004.npy", "task_id": 3}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/jump_car_0005.npy", "task_id": 3}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/jump_car_0006.npy", "task_id": 3}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/jump_car_0007.npy", "task_id": 3}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/jump_car_0008.npy", "task_id": 3}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/jump_car_0010.npy", "task_id": 3}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/jump_car_0012.npy", "task_id": 3}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/jump_car_0013.npy", "task_id": 3}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/jump_car_0014.npy", "task_id": 3}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/jump_car_0015.npy", "task_id": 3}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/jump_car_0016.npy", "task_id": 3}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/jump_car_0019.npy", "task_id": 3}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/jump_car_0021.npy", "task_id": 3}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/jump_car_0022.npy", "task_id": 3}, {"feature": 
"/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/jump_car_0023.npy", "task_id": 3}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/jump_car_0024.npy", "task_id": 3}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/jump_car_0026.npy", "task_id": 3}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/jump_car_0027.npy", "task_id": 3}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/jump_car_0030.npy", "task_id": 3}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/repot_0002.npy", "task_id": 4}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/repot_0003.npy", "task_id": 4}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/repot_0007.npy", "task_id": 4}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/repot_0008.npy", "task_id": 4}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/repot_0010.npy", "task_id": 4}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/repot_0011.npy", "task_id": 4}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/repot_0013.npy", "task_id": 4}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/repot_0014.npy", "task_id": 4}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/repot_0016.npy", "task_id": 4}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/repot_0017.npy", "task_id": 4}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/repot_0019.npy", "task_id": 4}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/repot_0021.npy", "task_id": 4}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/repot_0023.npy", "task_id": 4}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/repot_0024.npy", "task_id": 4}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/repot_0027.npy", "task_id": 4}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/repot_0028.npy", "task_id": 4}, {"feature": "/data0/wanghanlin/planning_diffusion/dataset/NIV/processed_data/repot_0030.npy", "task_id": 4}] -------------------------------------------------------------------------------- /plan/log/whl.txt: -------------------------------------------------------------------------------- 1 | => no checkpoint found at 'True' 2 | Starting training loop for rank: 0, total batch size: 256 3 | => no checkpoint found at 'True' 4 | Starting training loop for rank: 0, total batch size: 256 5 | => no checkpoint found at 'True' 6 | Starting training loop for rank: 0, total batch size: 256 7 | => no checkpoint found at 'True' 8 | Starting training loop for rank: 0, total batch size: 256 9 | => no checkpoint found at 'True' 10 | Starting training loop for rank: 0, total batch size: 256 11 | => no checkpoint found at 'True' 12 | Starting training loop for rank: 0, total batch size: 256 13 | => no checkpoint found at 'True' 14 | Starting training loop for rank: 0, total batch size: 256 15 | => no checkpoint found at 'True' 16 | Starting training loop for rank: 0, total batch size: 256 17 | => no checkpoint found at 'True' 18 | Starting training loop for rank: 0, total batch size: 256 19 | => no checkpoint found at 'True' 20 | Starting training loop for 
rank: 0, total batch size: 256 21 | => no checkpoint found at 'True' 22 | Starting training loop for rank: 0, total batch size: 256 23 | => no checkpoint found at 'True' 24 | Starting training loop for rank: 1, total batch size: 256 25 | => no checkpoint found at 'True' 26 | Starting training loop for rank: 0, total batch size: 256 27 | => no checkpoint found at 'True' 28 | => no checkpoint found at 'True' 29 | Starting training loop for rank: 1, total batch size: 256 30 | Starting training loop for rank: 0, total batch size: 256 31 | => loading checkpoint '/l/users/ravindu.nagasinghe/MAIN_codes/final/PDPP/checkpoint/whl/epoch0004.pth.tar' 32 | => loading checkpoint '/l/users/ravindu.nagasinghe/MAIN_codes/final/PDPP/checkpoint/whl/epoch0004.pth.tar' 33 | Starting training loop for rank: 1, total batch size: 256 34 | => loaded checkpoint '/l/users/ravindu.nagasinghe/MAIN_codes/final/PDPP/checkpoint/whl/epoch0004.pth.tar' (epoch 4)0 35 | Starting training loop for rank: 0, total batch size: 256 36 | => loading checkpoint '/l/users/ravindu.nagasinghe/MAIN_codes/final/PDPP/checkpoint/whl/epoch0004.pth.tar' 37 | => loading checkpoint '/l/users/ravindu.nagasinghe/MAIN_codes/final/PDPP/checkpoint/whl/epoch0004.pth.tar' 38 | Starting training loop for rank: 1, total batch size: 256 39 | => loaded checkpoint '/l/users/ravindu.nagasinghe/MAIN_codes/final/PDPP/checkpoint/whl/epoch0004.pth.tar' (epoch 4)0 40 | Starting training loop for rank: 0, total batch size: 256 41 | => loading checkpoint '/l/users/ravindu.nagasinghe/MAIN_codes/final/PDPP/checkpoint/whl/epoch0004.pth.tar' 42 | => loading checkpoint '/l/users/ravindu.nagasinghe/MAIN_codes/final/PDPP/checkpoint/whl/epoch0004.pth.tar' 43 | Starting training loop for rank: 1, total batch size: 256 44 | => loaded checkpoint '/l/users/ravindu.nagasinghe/MAIN_codes/final/PDPP/checkpoint/whl/epoch0004.pth.tar' (epoch 4)0 45 | Starting training loop for rank: 0, total batch size: 256 46 | => loading checkpoint '/l/users/ravindu.nagasinghe/MAIN_codes/final/PDPP/checkpoint/whl/epoch0005.pth.tar' 47 | => loading checkpoint '/l/users/ravindu.nagasinghe/MAIN_codes/final/PDPP/checkpoint/whl/epoch0005.pth.tar' 48 | Starting training loop for rank: 1, total batch size: 256 49 | => loaded checkpoint '/l/users/ravindu.nagasinghe/MAIN_codes/final/PDPP/checkpoint/whl/epoch0005.pth.tar' (epoch 5)0 50 | Starting training loop for rank: 0, total batch size: 256 51 | => no checkpoint found at 'True' 52 | Starting training loop for rank: 0, total batch size: 256 53 | => no checkpoint found at 'True' 54 | Starting training loop for rank: 0, total batch size: 256 55 | => no checkpoint found at 'True' 56 | Starting training loop for rank: 1, total batch size: 256 57 | => no checkpoint found at 'True' 58 | Starting training loop for rank: 0, total batch size: 256 59 | => no checkpoint found at 'True' 60 | => no checkpoint found at 'True' 61 | => no checkpoint found at 'True' 62 | Starting training loop for rank: 2, total batch size: 256 63 | Starting training loop for rank: 1, total batch size: 256 64 | Starting training loop for rank: 3, total batch size: 256 65 | => no checkpoint found at 'True' 66 | Starting training loop for rank: 0, total batch size: 256 67 | => no checkpoint found at 'True' 68 | Starting training loop for rank: 1, total batch size: 256 69 | => no checkpoint found at 'True' 70 | Starting training loop for rank: 0, total batch size: 256 71 | => no checkpoint found at 'True' 72 | Starting training loop for rank: 1, total batch size: 256 73 | => no 
checkpoint found at 'True' 74 | Starting training loop for rank: 0, total batch size: 256 75 | => no checkpoint found at 'True' 76 | Starting training loop for rank: 1, total batch size: 256 77 | => no checkpoint found at 'True' 78 | Starting training loop for rank: 0, total batch size: 256 79 | => no checkpoint found at 'True' 80 | Starting training loop for rank: 1, total batch size: 256 81 | => no checkpoint found at 'True' 82 | Starting training loop for rank: 0, total batch size: 256 83 | => no checkpoint found at 'True' 84 | Starting training loop for rank: 1, total batch size: 256 85 | => no checkpoint found at 'True' 86 | Starting training loop for rank: 0, total batch size: 256 87 | => no checkpoint found at 'True' 88 | Starting training loop for rank: 1, total batch size: 256 89 | => no checkpoint found at 'True' 90 | Starting training loop for rank: 0, total batch size: 256 91 | => no checkpoint found at 'True' 92 | Starting training loop for rank: 1, total batch size: 256 93 | => no checkpoint found at 'True' 94 | Starting training loop for rank: 0, total batch size: 256 95 | => no checkpoint found at 'True' 96 | Starting training loop for rank: 1, total batch size: 256 97 | => no checkpoint found at 'True' 98 | Starting training loop for rank: 0, total batch size: 256 99 | => no checkpoint found at 'True' 100 | Starting training loop for rank: 1, total batch size: 256 101 | => no checkpoint found at 'True' 102 | Starting training loop for rank: 0, total batch size: 256 103 | => no checkpoint found at 'True' 104 | Starting training loop for rank: 1, total batch size: 256 105 | => no checkpoint found at 'True' 106 | Starting training loop for rank: 0, total batch size: 256 107 | => no checkpoint found at 'True' 108 | Starting training loop for rank: 1, total batch size: 256 109 | => no checkpoint found at 'True' 110 | Starting training loop for rank: 0, total batch size: 256 111 | => no checkpoint found at 'True' 112 | Starting training loop for rank: 1, total batch size: 256 113 | => no checkpoint found at 'True' 114 | Starting training loop for rank: 0, total batch size: 256 115 | => no checkpoint found at 'True' 116 | Starting training loop for rank: 1, total batch size: 256 117 | => no checkpoint found at 'True' 118 | Starting training loop for rank: 0, total batch size: 256 119 | => no checkpoint found at 'True' 120 | Starting training loop for rank: 1, total batch size: 256 121 | => no checkpoint found at 'True' 122 | Starting training loop for rank: 0, total batch size: 256 123 | => no checkpoint found at 'True' 124 | Starting training loop for rank: 1, total batch size: 256 125 | => no checkpoint found at 'True' 126 | Starting training loop for rank: 0, total batch size: 256 127 | => no checkpoint found at 'True' 128 | Starting training loop for rank: 1, total batch size: 256 129 | => no checkpoint found at 'True' 130 | Starting training loop for rank: 0, total batch size: 256 131 | => no checkpoint found at 'True' 132 | Starting training loop for rank: 1, total batch size: 256 133 | => no checkpoint found at 'True' 134 | Starting training loop for rank: 0, total batch size: 256 135 | => no checkpoint found at 'True' 136 | Starting training loop for rank: 1, total batch size: 256 137 | => no checkpoint found at 'True' 138 | Starting training loop for rank: 0, total batch size: 256 139 | => no checkpoint found at 'True' 140 | Starting training loop for rank: 1, total batch size: 256 141 | => no checkpoint found at 'True' 142 | Starting training loop for rank: 0, total 
batch size: 256 143 | => loading checkpoint '/l/users/ravindu.nagasinghe/MAIN_codes/multiple_plans/PDPP/checkpoint/whl/epoch0004.pth.tar' 144 | => loading checkpoint '/l/users/ravindu.nagasinghe/MAIN_codes/multiple_plans/PDPP/checkpoint/whl/epoch0004.pth.tar' 145 | Starting training loop for rank: 1, total batch size: 256 146 | => loaded checkpoint '/l/users/ravindu.nagasinghe/MAIN_codes/multiple_plans/PDPP/checkpoint/whl/epoch0004.pth.tar' (epoch 4)0 147 | Starting training loop for rank: 0, total batch size: 256 148 | => no checkpoint found at 'True' 149 | Starting training loop for rank: 1, total batch size: 256 150 | => no checkpoint found at 'True' 151 | Starting training loop for rank: 0, total batch size: 256 152 | => loading checkpoint '/l/users/ravindu.nagasinghe/MAIN_codes/multiple_plans/PDPP/checkpoint/whl/epoch0004.pth.tar' 153 | => loading checkpoint '/l/users/ravindu.nagasinghe/MAIN_codes/multiple_plans/PDPP/checkpoint/whl/epoch0004.pth.tar' 154 | => loaded checkpoint '/l/users/ravindu.nagasinghe/MAIN_codes/multiple_plans/PDPP/checkpoint/whl/epoch0004.pth.tar' (epoch 4)0 155 | Starting training loop for rank: 0, total batch size: 256 156 | Starting training loop for rank: 1, total batch size: 256 157 | => loading checkpoint '/l/users/ravindu.nagasinghe/MAIN_codes/multiple_plans/PDPP/checkpoint/whl/epoch0004.pth.tar' 158 | => loading checkpoint '/l/users/ravindu.nagasinghe/MAIN_codes/multiple_plans/PDPP/checkpoint/whl/epoch0004.pth.tar' 159 | => loaded checkpoint '/l/users/ravindu.nagasinghe/MAIN_codes/multiple_plans/PDPP/checkpoint/whl/epoch0004.pth.tar' (epoch 4)0 160 | Starting training loop for rank: 0, total batch size: 256 161 | Starting training loop for rank: 1, total batch size: 256 162 | => loading checkpoint '/l/users/ravindu.nagasinghe/MAIN_codes/multiple_plans/PDPP/checkpoint/whl/epoch0004.pth.tar' 163 | => loading checkpoint '/l/users/ravindu.nagasinghe/MAIN_codes/multiple_plans/PDPP/checkpoint/whl/epoch0004.pth.tar' 164 | => loaded checkpoint '/l/users/ravindu.nagasinghe/MAIN_codes/multiple_plans/PDPP/checkpoint/whl/epoch0004.pth.tar' (epoch 4)0 165 | Starting training loop for rank: 0, total batch size: 256 166 | Starting training loop for rank: 1, total batch size: 256 167 | => loading checkpoint '/l/users/ravindu.nagasinghe/MAIN_codes/multiple_plans/PDPP/checkpoint/whl/epoch0004.pth.tar' 168 | => loading checkpoint '/l/users/ravindu.nagasinghe/MAIN_codes/multiple_plans/PDPP/checkpoint/whl/epoch0004.pth.tar' 169 | Starting training loop for rank: 1, total batch size: 256 170 | => loaded checkpoint '/l/users/ravindu.nagasinghe/MAIN_codes/multiple_plans/PDPP/checkpoint/whl/epoch0004.pth.tar' (epoch 4)0 171 | Starting training loop for rank: 0, total batch size: 256 172 | => loading checkpoint '/l/users/ravindu.nagasinghe/MAIN_codes/multiple_plans/PDPP/checkpoint/whl/epoch0004.pth.tar' 173 | => loading checkpoint '/l/users/ravindu.nagasinghe/MAIN_codes/multiple_plans/PDPP/checkpoint/whl/epoch0004.pth.tar' 174 | => loaded checkpoint '/l/users/ravindu.nagasinghe/MAIN_codes/multiple_plans/PDPP/checkpoint/whl/epoch0004.pth.tar' (epoch 4)0 175 | Starting training loop for rank: 0, total batch size: 256 176 | Starting training loop for rank: 1, total batch size: 256 177 | => loading checkpoint '/l/users/ravindu.nagasinghe/MAIN_codes/multiple_plans/PDPP/checkpoint/whl/epoch0004.pth.tar' 178 | => loading checkpoint '/l/users/ravindu.nagasinghe/MAIN_codes/multiple_plans/PDPP/checkpoint/whl/epoch0004.pth.tar' 179 | Starting training loop for rank: 1, total batch size: 256 
180 | => loaded checkpoint '/l/users/ravindu.nagasinghe/MAIN_codes/multiple_plans/PDPP/checkpoint/whl/epoch0004.pth.tar' (epoch 4)0 181 | Starting training loop for rank: 0, total batch size: 256 182 | => no checkpoint found at 'True' 183 | Starting training loop for rank: 1, total batch size: 256 184 | => no checkpoint found at 'True' 185 | Starting training loop for rank: 0, total batch size: 256 186 | => no checkpoint found at 'True' 187 | Starting training loop for rank: 1, total batch size: 256 188 | => no checkpoint found at 'True' 189 | Starting training loop for rank: 0, total batch size: 256 190 | => no checkpoint found at 'True' 191 | Starting training loop for rank: 1, total batch size: 256 192 | => no checkpoint found at 'True' 193 | Starting training loop for rank: 0, total batch size: 256 194 | => loading checkpoint '/l/users/ravindu.nagasinghe/T4How/multiple_plans/PDPP/checkpoint/whl/epoch0004.pth.tar' 195 | => loading checkpoint '/l/users/ravindu.nagasinghe/T4How/multiple_plans/PDPP/checkpoint/whl/epoch0004.pth.tar' 196 | Starting training loop for rank: 1, total batch size: 256 197 | => loaded checkpoint '/l/users/ravindu.nagasinghe/T4How/multiple_plans/PDPP/checkpoint/whl/epoch0004.pth.tar' (epoch 4)0 198 | Starting training loop for rank: 0, total batch size: 256 199 | => loading checkpoint '/l/users/ravindu.nagasinghe/T4How/multiple_plans/PDPP/checkpoint/whl/epoch0004.pth.tar' 200 | => loading checkpoint '/l/users/ravindu.nagasinghe/T4How/multiple_plans/PDPP/checkpoint/whl/epoch0004.pth.tar' 201 | Starting training loop for rank: 1, total batch size: 256 202 | => loaded checkpoint '/l/users/ravindu.nagasinghe/T4How/multiple_plans/PDPP/checkpoint/whl/epoch0004.pth.tar' (epoch 4)0 203 | Starting training loop for rank: 0, total batch size: 256 204 | => loading checkpoint '/l/users/ravindu.nagasinghe/T4How/multiple_plans/PDPP/checkpoint/whl/epoch0004.pth.tar' 205 | => loading checkpoint '/l/users/ravindu.nagasinghe/T4How/multiple_plans/PDPP/checkpoint/whl/epoch0004.pth.tar' 206 | => loaded checkpoint '/l/users/ravindu.nagasinghe/T4How/multiple_plans/PDPP/checkpoint/whl/epoch0004.pth.tar' (epoch 4)0 207 | Starting training loop for rank: 0, total batch size: 256 208 | Starting training loop for rank: 1, total batch size: 256 209 | => loading checkpoint '/l/users/ravindu.nagasinghe/T4How/multiple_plans/PDPP/checkpoint/whl/epoch0004.pth.tar' 210 | => loading checkpoint '/l/users/ravindu.nagasinghe/T4How/multiple_plans/PDPP/checkpoint/whl/epoch0004.pth.tar' 211 | => loaded checkpoint '/l/users/ravindu.nagasinghe/T4How/multiple_plans/PDPP/checkpoint/whl/epoch0004.pth.tar' (epoch 4)0 212 | Starting training loop for rank: 0, total batch size: 256 213 | Starting training loop for rank: 1, total batch size: 256 214 | => no checkpoint found at 'True' 215 | Starting training loop for rank: 1, total batch size: 256 216 | => no checkpoint found at 'True' 217 | Starting training loop for rank: 0, total batch size: 256 218 | -------------------------------------------------------------------------------- /plan/main_distributed.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import random 4 | import time 5 | from collections import OrderedDict 6 | 7 | import torch.nn.parallel 8 | import torch.backends.cudnn as cudnn 9 | import torch.distributed as dist 10 | import torch.optim 11 | import torch.multiprocessing as mp 12 | import torch.utils.data 13 | import torch.utils.data.distributed 14 | from torch.distributed import 
ReduceOp 15 | 16 | import utils 17 | from dataloader.data_load import PlanningDataset 18 | from model import diffusion, temporal 19 | from model.helpers import get_lr_schedule_with_warmup 20 | 21 | from utils import * 22 | from logging import log 23 | from utils.args import get_args 24 | import numpy as np 25 | from model.helpers import Logger 26 | 27 | 28 | def reduce_tensor(tensor): 29 | rt = tensor.clone() 30 | torch.distributed.all_reduce(rt, op=ReduceOp.SUM) 31 | rt /= dist.get_world_size() 32 | return rt 33 | 34 | 35 | def main(): 36 | # os.environ['CUDA_LAUNCH_BLOCKING'] = '1' 37 | args = get_args() 38 | os.environ['PYTHONHASHSEED'] = str(args.seed) 39 | 40 | if args.verbose: 41 | print(args) 42 | if args.seed is not None: 43 | random.seed(args.seed) 44 | np.random.seed(args.seed) 45 | torch.manual_seed(args.seed) 46 | torch.cuda.manual_seed_all(args.seed) 47 | 48 | args.distributed = args.world_size > 1 or args.multiprocessing_distributed 49 | ngpus_per_node = torch.cuda.device_count() 50 | 51 | if args.multiprocessing_distributed: 52 | args.world_size = ngpus_per_node * args.world_size 53 | mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args)) 54 | else: 55 | main_worker(args.gpu, ngpus_per_node, args) 56 | 57 | 58 | def main_worker(gpu, ngpus_per_node, args): 59 | args.gpu = gpu 60 | # print('gpuid:', args.gpu) 61 | 62 | if args.distributed: 63 | if args.multiprocessing_distributed: 64 | args.rank = args.rank * ngpus_per_node + gpu 65 | dist.init_process_group( 66 | backend=args.dist_backend, 67 | init_method=args.dist_url, 68 | world_size=args.world_size, 69 | rank=args.rank, 70 | ) 71 | if args.gpu is not None: 72 | torch.cuda.set_device(args.gpu) 73 | args.batch_size = int(args.batch_size / ngpus_per_node) 74 | args.batch_size_val = int(args.batch_size_val / ngpus_per_node) 75 | args.num_thread_reader = int(args.num_thread_reader / ngpus_per_node) 76 | elif args.gpu is not None: 77 | torch.cuda.set_device(args.gpu) 78 | 79 | # Data loading code 80 | train_dataset = PlanningDataset( 81 | args.root, 82 | args=args, 83 | is_val=False, 84 | model=None, 85 | ) 86 | # Test data loading code 87 | test_dataset = PlanningDataset( 88 | args.root, 89 | args=args, 90 | is_val=True, 91 | model=None, 92 | ) 93 | print(train_dataset , test_dataset) 94 | if args.distributed: 95 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) 96 | test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset) 97 | else: 98 | train_sampler = None 99 | test_sampler = None 100 | 101 | train_loader = torch.utils.data.DataLoader( 102 | train_dataset, 103 | batch_size=args.batch_size, 104 | shuffle=(train_sampler is None), 105 | drop_last=True, 106 | num_workers=args.num_thread_reader, 107 | pin_memory=args.pin_memory, 108 | sampler=train_sampler, 109 | ) 110 | test_loader = torch.utils.data.DataLoader( 111 | test_dataset, 112 | batch_size=args.batch_size_val, 113 | shuffle=False, 114 | drop_last=False, 115 | num_workers=args.num_thread_reader, 116 | sampler=test_sampler, 117 | ) 118 | 119 | # create model 120 | temporal_model = temporal.TemporalUnet( # temporal 121 | args.action_dim + args.observation_dim + args.class_dim_llama + args.class_dim_graph, 122 | dim=256, 123 | dim_mults=(1, 2, 4), ) 124 | 125 | diffusion_model = diffusion.GaussianDiffusion( 126 | temporal_model, args.horizon, args.observation_dim, args.action_dim, args.class_dim_llama, args.class_dim_graph, args.n_diffusion_steps, 127 | loss_type='Weighted_MSE', clip_denoised=True, ) 128 
| 129 | 130 | model = utils.Trainer(diffusion_model, train_loader, args.ema_decay, args.lr, args.gradient_accumulate_every, 131 | args.step_start_ema, args.update_ema_every, args.log_freq) 132 | 133 | if args.pretrain_cnn_path: 134 | net_data = torch.load(args.pretrain_cnn_path) 135 | model.model.load_state_dict(net_data) 136 | model.ema_model.load_state_dict(net_data) 137 | 138 | if args.distributed: 139 | if args.gpu is not None: 140 | model.model.cuda(args.gpu) 141 | model.ema_model.cuda(args.gpu) 142 | model.model = torch.nn.parallel.DistributedDataParallel( 143 | model.model, device_ids=[args.gpu], find_unused_parameters=True) 144 | model.ema_model = torch.nn.parallel.DistributedDataParallel( 145 | model.ema_model, device_ids=[args.gpu], find_unused_parameters=True) 146 | else: 147 | model.model.cuda() 148 | model.ema_model.cuda() 149 | model.model = torch.nn.parallel.DistributedDataParallel(model.model, find_unused_parameters=True) 150 | model.ema_model = torch.nn.parallel.DistributedDataParallel(model.ema_model, 151 | find_unused_parameters=True) 152 | 153 | elif args.gpu is not None: 154 | model.model = model.model.cuda(args.gpu) 155 | model.ema_model = model.ema_model.cuda(args.gpu) 156 | else: 157 | model.model = torch.nn.DataParallel(model.model).cuda() 158 | model.ema_model = torch.nn.DataParallel(model.ema_model).cuda() 159 | 160 | scheduler = get_lr_schedule_with_warmup(model.optimizer, int(args.n_train_steps * args.epochs)) 161 | 162 | checkpoint_dir = os.path.join(os.path.dirname(__file__), 'checkpoint', args.checkpoint_dir) 163 | if args.checkpoint_dir != '' and not (os.path.isdir(checkpoint_dir)) and args.rank == 0: 164 | os.mkdir(checkpoint_dir) 165 | 166 | if args.resume: 167 | checkpoint_path = get_last_checkpoint(checkpoint_dir) 168 | if checkpoint_path: 169 | log("=> loading checkpoint '{}'".format(checkpoint_path), args) 170 | checkpoint = torch.load(checkpoint_path, map_location='cuda:{}'.format(args.rank)) 171 | args.start_epoch = checkpoint["epoch"] 172 | model.model.load_state_dict(checkpoint["model"]) 173 | model.ema_model.load_state_dict(checkpoint["ema"]) 174 | model.optimizer.load_state_dict(checkpoint["optimizer"]) 175 | model.step = checkpoint["step"] 176 | # for p in model.optimizer.param_groups: 177 | # p['lr'] = 1e-5 178 | scheduler.load_state_dict(checkpoint["scheduler"]) 179 | tb_logdir = checkpoint["tb_logdir"] 180 | if args.rank == 0: 181 | # creat logger 182 | tb_logger = Logger(tb_logdir) 183 | log("=> loaded checkpoint '{}' (epoch {}){}".format(checkpoint_path, checkpoint["epoch"], args.gpu), args) 184 | 185 | 186 | else: 187 | 188 | time_pre = time.strftime("%Y%m%d%H%M%S", time.localtime()) 189 | logname = args.log_root + '_' + time_pre + '_' + args.dataset 190 | tb_logdir = os.path.join(args.log_root, logname) 191 | if args.rank == 0: 192 | # creat logger 193 | if not (os.path.exists(tb_logdir)): 194 | os.makedirs(tb_logdir) 195 | tb_logger = Logger(tb_logdir) 196 | tb_logger.log_info(args) 197 | log("=> no checkpoint found at '{}'".format(args.resume), args) 198 | 199 | if args.cudnn_benchmark: 200 | cudnn.benchmark = True 201 | total_batch_size = args.world_size * args.batch_size 202 | 203 | log( 204 | "Starting training loop for rank: {}, total batch size: {}".format( 205 | args.rank, total_batch_size 206 | ), args 207 | ) 208 | 209 | max_eva = 0 210 | max_acc = 0 211 | old_max_epoch = 0 212 | save_max = os.path.join(os.path.dirname(__file__), 'save_max') 213 | 214 | for epoch in range(args.start_epoch, args.epochs): 215 | print('epoch : ', 
epoch) 216 | if args.distributed: 217 | train_sampler.set_epoch(epoch) 218 | 219 | # train for one epoch 220 | if (epoch + 1) % 10 == 0: # calculate on training set 221 | losses, acc_top1, acc_top5, trajectory_success_rate_meter, MIoU1_meter, MIoU2_meter, \ 222 | acc_a0, acc_aT = model.train(args.n_train_steps, True, args, scheduler) 223 | losses_reduced = reduce_tensor(losses.cuda()).item() 224 | acc_top1_reduced = reduce_tensor(acc_top1.cuda()).item() 225 | acc_top5_reduced = reduce_tensor(acc_top5.cuda()).item() 226 | trajectory_success_rate_meter_reduced = reduce_tensor(trajectory_success_rate_meter.cuda()).item() 227 | MIoU1_meter_reduced = reduce_tensor(MIoU1_meter.cuda()).item() 228 | MIoU2_meter_reduced = reduce_tensor(MIoU2_meter.cuda()).item() 229 | acc_a0_reduced = reduce_tensor(acc_a0.cuda()).item() 230 | acc_aT_reduced = reduce_tensor(acc_aT.cuda()).item() 231 | 232 | if args.rank == 0: 233 | logs = OrderedDict() 234 | logs['Train/EpochLoss'] = losses_reduced 235 | logs['Train/EpochAcc@1'] = acc_top1_reduced 236 | logs['Train/EpochAcc@5'] = acc_top5_reduced 237 | logs['Train/Traj_Success_Rate'] = trajectory_success_rate_meter_reduced 238 | logs['Train/MIoU1'] = MIoU1_meter_reduced 239 | logs['Train/MIoU2'] = MIoU2_meter_reduced 240 | logs['Train/acc_a0'] = acc_a0_reduced 241 | logs['Train/acc_aT'] = acc_aT_reduced 242 | for key, value in logs.items(): 243 | tb_logger.log_scalar(value, key, epoch + 1) 244 | 245 | tb_logger.flush() 246 | else: 247 | losses = model.train(args.n_train_steps, False, args, scheduler).cuda() 248 | losses_reduced = reduce_tensor(losses).item() 249 | if args.rank == 0: 250 | print('lrs:') 251 | for p in model.optimizer.param_groups: 252 | print(p['lr']) 253 | print('---------------------------------') 254 | 255 | logs = OrderedDict() 256 | logs['Train/EpochLoss'] = losses_reduced 257 | for key, value in logs.items(): 258 | tb_logger.log_scalar(value, key, epoch + 1) 259 | 260 | tb_logger.flush() 261 | 262 | if ((epoch + 1) % 5 == 0) and args.evaluate: # or epoch > 18 263 | losses, acc_top1, acc_top5, \ 264 | trajectory_success_rate_meter, MIoU1_meter, MIoU2_meter, \ 265 | acc_a0, acc_aT = validate(test_loader, model.ema_model, args) 266 | 267 | losses_reduced = reduce_tensor(losses.cuda()).item() 268 | acc_top1_reduced = reduce_tensor(acc_top1.cuda()).item() 269 | acc_top5_reduced = reduce_tensor(acc_top5.cuda()).item() 270 | trajectory_success_rate_meter_reduced = reduce_tensor(trajectory_success_rate_meter.cuda()).item() 271 | MIoU1_meter_reduced = reduce_tensor(MIoU1_meter.cuda()).item() 272 | MIoU2_meter_reduced = reduce_tensor(MIoU2_meter.cuda()).item() 273 | acc_a0_reduced = reduce_tensor(acc_a0.cuda()).item() 274 | acc_aT_reduced = reduce_tensor(acc_aT.cuda()).item() 275 | 276 | if args.rank == 0: 277 | logs = OrderedDict() 278 | logs['Val/EpochLoss'] = losses_reduced 279 | logs['Val/EpochAcc@1'] = acc_top1_reduced 280 | logs['Val/EpochAcc@5'] = acc_top5_reduced 281 | logs['Val/Traj_Success_Rate'] = trajectory_success_rate_meter_reduced 282 | logs['Val/MIoU1'] = MIoU1_meter_reduced 283 | logs['Val/MIoU2'] = MIoU2_meter_reduced 284 | logs['Val/acc_a0'] = acc_a0_reduced 285 | logs['Val/acc_aT'] = acc_aT_reduced 286 | for key, value in logs.items(): 287 | tb_logger.log_scalar(value, key, epoch + 1) 288 | 289 | tb_logger.flush() 290 | print(trajectory_success_rate_meter_reduced, max_eva) 291 | if trajectory_success_rate_meter_reduced >= max_eva: 292 | if not (trajectory_success_rate_meter_reduced == max_eva and acc_top1_reduced < max_acc): 293 | 
save_checkpoint2( 294 | { 295 | "epoch": epoch + 1, 296 | "model": model.model.state_dict(), 297 | "ema": model.ema_model.state_dict(), 298 | "optimizer": model.optimizer.state_dict(), 299 | "step": model.step, 300 | "tb_logdir": tb_logdir, 301 | "scheduler": scheduler.state_dict(), 302 | }, save_max, old_max_epoch, epoch + 1, args.rank 303 | ) 304 | max_eva = trajectory_success_rate_meter_reduced 305 | max_acc = acc_top1_reduced 306 | old_max_epoch = epoch + 1 307 | 308 | if (epoch + 1) % args.save_freq == 0: 309 | if args.rank == 0: 310 | save_checkpoint( 311 | { 312 | "epoch": epoch + 1, 313 | "model": model.model.state_dict(), 314 | "ema": model.ema_model.state_dict(), 315 | "optimizer": model.optimizer.state_dict(), 316 | "step": model.step, 317 | "tb_logdir": tb_logdir, 318 | "scheduler": scheduler.state_dict(), 319 | }, checkpoint_dir, epoch + 1 320 | ) 321 | 322 | 323 | def log(output, args): 324 | with open(os.path.join(os.path.dirname(__file__), 'log', args.checkpoint_dir + '.txt'), "a") as f: 325 | f.write(output + '\n') 326 | 327 | 328 | def save_checkpoint(state, checkpoint_dir, epoch, n_ckpt=3): 329 | torch.save(state, os.path.join(checkpoint_dir, "epoch{:0>4d}.pth.tar".format(epoch))) 330 | if epoch - n_ckpt >= 0: 331 | oldest_ckpt = os.path.join(checkpoint_dir, "epoch{:0>4d}.pth.tar".format(epoch - n_ckpt)) 332 | if os.path.isfile(oldest_ckpt): 333 | os.remove(oldest_ckpt) 334 | 335 | 336 | def save_checkpoint2(state, checkpoint_dir, old_epoch, epoch, rank): 337 | torch.save(state, os.path.join(checkpoint_dir, "epoch{:0>4d}_{}.pth.tar".format(epoch, rank))) 338 | if old_epoch > 0: 339 | oldest_ckpt = os.path.join(checkpoint_dir, "epoch{:0>4d}_{}.pth.tar".format(old_epoch, rank)) 340 | if os.path.isfile(oldest_ckpt): 341 | os.remove(oldest_ckpt) 342 | 343 | 344 | def get_last_checkpoint(checkpoint_dir): 345 | all_ckpt = glob.glob(os.path.join(checkpoint_dir, 'epoch*.pth.tar')) 346 | if all_ckpt: 347 | all_ckpt = sorted(all_ckpt) 348 | return all_ckpt[-1] 349 | else: 350 | return '' 351 | 352 | 353 | if __name__ == "__main__": 354 | main() 355 | -------------------------------------------------------------------------------- /step/dataloader/train_test_data_load.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | from torch.utils.data import Dataset 5 | import json 6 | import math 7 | from collections import namedtuple 8 | 9 | def get_vids_from_json(path): 10 | task_vids = {} 11 | with open(path, 'r') as f: 12 | json_data = json.load(f) 13 | 14 | for i in json_data: 15 | task = i['task'] 16 | vid = i['vid'] 17 | if task not in task_vids: 18 | task_vids[task] = [] 19 | task_vids[task].append(vid) 20 | return task_vids 21 | 22 | 23 | def get_vids(path): 24 | task_vids = {} 25 | with open(path, 'r') as f: 26 | for line in f: 27 | task, vid, url = line.strip().split(',') 28 | if task not in task_vids: 29 | task_vids[task] = [] 30 | task_vids[task].append(vid) 31 | return task_vids 32 | 33 | 34 | def read_task_info(path): 35 | titles = {} 36 | urls = {} 37 | n_steps = {} 38 | steps = {} 39 | with open(path, 'r') as f: 40 | idx = f.readline() 41 | while idx != '': 42 | idx = idx.strip() 43 | titles[idx] = f.readline().strip() 44 | urls[idx] = f.readline().strip() 45 | n_steps[idx] = int(f.readline().strip()) 46 | steps[idx] = f.readline().strip().split(',') 47 | next(f) 48 | idx = f.readline() 49 | return {'title': titles, 'url': urls, 'n_steps': n_steps, 'steps': steps} 50 | 51 | 52 | class 
PlanningDataset(Dataset): 53 | def __init__(self, 54 | root, 55 | args=None, 56 | is_val=False, 57 | model=None, 58 | crosstask_use_feature_how=True, 59 | ): 60 | self.is_val = is_val 61 | self.data_root = root 62 | self.args = args 63 | self.max_traj_len = args.horizon 64 | self.vid_names = None 65 | self.frame_cnts = None 66 | self.images = None 67 | self.last_vid = '' 68 | self.crosstask_use_feature_how = crosstask_use_feature_how 69 | if args.dataset == 'crosstask': 70 | """ 71 | . 72 | └── crosstask 73 | ├── crosstask_features 74 | └── crosstask_release 75 | ├── tasks_primary.txt 76 | ├── videos.csv or json 77 | └── videos_val.csv or json 78 | """ 79 | val_csv_path = os.path.join( 80 | root, 'crosstask_release', 'test_list.json') 81 | video_csv_path = os.path.join( 82 | root, 'crosstask_release', 'train_list.json') 83 | 84 | if crosstask_use_feature_how: 85 | self.features_path = os.path.join(root, 'processed_data') 86 | else: 87 | self.features_path = os.path.join(root, 'crosstask_features') 88 | 89 | self.constraints_path = os.path.join( 90 | root, 'crosstask_release', 'annotations') 91 | 92 | self.action_one_hot = np.load( 93 | os.path.join(root, 'crosstask_release', 'actions_one_hot.npy'), 94 | allow_pickle=True).item() 95 | 96 | if is_val: 97 | cross_task_data_name = args.json_path_val 98 | else: 99 | cross_task_data_name = args.json_path_train 100 | 101 | if os.path.exists(cross_task_data_name): 102 | with open(cross_task_data_name, 'r') as f: 103 | self.json_data = json.load(f) 104 | print('Loaded {}'.format(cross_task_data_name)) 105 | else: 106 | file_type = val_csv_path.split('.')[-1] 107 | if file_type == 'json': 108 | all_task_vids = get_vids_from_json(video_csv_path) 109 | val_vids = get_vids_from_json(val_csv_path) 110 | else: 111 | all_task_vids = get_vids(video_csv_path) 112 | val_vids = get_vids(val_csv_path) 113 | 114 | if is_val: 115 | task_vids = val_vids 116 | else: 117 | task_vids = {task: [vid for vid in vids if task not in val_vids or vid not in val_vids[task]] for 118 | task, vids in 119 | all_task_vids.items()} 120 | 121 | primary_info = read_task_info(os.path.join( 122 | root, 'crosstask_release', 'tasks_primary.txt')) 123 | 124 | self.n_steps = primary_info['n_steps'] 125 | all_tasks = set(self.n_steps.keys()) 126 | 127 | task_vids = {task: vids for task, 128 | vids in task_vids.items() if task in all_tasks} 129 | 130 | all_vids = [] 131 | for task, vids in task_vids.items(): 132 | all_vids.extend([(task, vid) for vid in vids]) 133 | json_data = [] 134 | for idx in range(len(all_vids)): 135 | task, vid = all_vids[idx] 136 | if self.crosstask_use_feature_how: 137 | video_path = os.path.join( 138 | self.features_path, str(task) + '_' + str(vid) + '.npy') 139 | else: 140 | video_path = os.path.join( 141 | self.features_path, str(vid) + '.npy') 142 | legal_range = self.process_single(task, vid) 143 | if not legal_range: 144 | continue 145 | 146 | temp_len = len(legal_range) 147 | temp = [] 148 | while temp_len < self.max_traj_len: 149 | temp.append(legal_range[0]) 150 | temp_len += 1 151 | 152 | temp.extend(legal_range) 153 | legal_range = temp 154 | 155 | for i in range(len(legal_range) - self.max_traj_len + 1): 156 | legal_range_current = legal_range[i:i + self.max_traj_len] 157 | json_data.append({'id': {'vid': vid, 'task': task, 'feature': video_path, 158 | 'legal_range': legal_range_current, },'instruction_len': self.n_steps[task]}) 159 | 160 | self.json_data = json_data 161 | with open(cross_task_data_name, 'w') as f: 162 | json.dump(json_data, f) 163 | 
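# The CrossTask branch above (and the COIN / NIV branches below) share the same
# sample-construction scheme: each video's annotated steps form `legal_range`; if a
# video has fewer steps than the planning horizon (`max_traj_len`), its first step is
# repeated as front padding, and then every window of `max_traj_len` consecutive steps
# is written out as one entry in the cached train/val JSON file.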
164 | elif args.dataset == 'coin': 165 | print(root) 166 | coin_path = os.path.join(root, 'coin', 'full_npy/') 167 | val_csv_path = os.path.join( 168 | root, 'coin', 'coin_test_30.json') 169 | video_csv_path = os.path.join( 170 | root, 'coin', 'coin_train_70.json') 171 | 172 | # coin_data_name = "/data1/wanghanlin/diffusion_planning/jsons_coin/sliding_window_cross_task_data_{}_{}_new_task_id_73.json".format( 173 | # is_val, self.max_traj_len) 174 | if is_val: 175 | coin_data_name = args.json_path_val 176 | else: 177 | coin_data_name = args.json_path_train 178 | 179 | if os.path.exists(coin_data_name): 180 | with open(coin_data_name, 'r') as f: 181 | self.json_data = json.load(f) 182 | print('Loaded {}'.format(coin_data_name)) 183 | else: 184 | json_data = [] 185 | num = 0 186 | if is_val: 187 | with open(val_csv_path, 'r') as f: 188 | coin_data = json.load(f) 189 | else: 190 | with open(video_csv_path, 'r') as f: 191 | coin_data = json.load(f) 192 | for i in coin_data: 193 | for (k, v) in i.items(): 194 | file_name = v['class'] + '_' + str(v['recipe_type']) + '_' + k + '.npy' 195 | file_path = coin_path + file_name 196 | images_ = np.load(file_path, allow_pickle=True) 197 | images = images_['frames_features'] 198 | legal_range = [] 199 | 200 | last_action = v['annotation'][-1]['segment'][1] 201 | last_action = math.ceil(last_action) 202 | if last_action > len(images): 203 | print(k, last_action, len(images)) 204 | num += 1 205 | continue 206 | 207 | for annotation in v['annotation']: 208 | action_label = int(annotation['id']) - 1 209 | start_idx, end_idx = annotation['segment'] 210 | start_idx = math.floor(start_idx) 211 | end_idx = math.ceil(end_idx) 212 | 213 | if end_idx < images.shape[0]: 214 | legal_range.append((start_idx, end_idx, action_label)) 215 | else: 216 | legal_range.append((start_idx, images.shape[0] - 1, action_label)) 217 | 218 | temp_len = len(legal_range) 219 | temp = [] 220 | while temp_len < self.max_traj_len: 221 | temp.append(legal_range[0]) 222 | temp_len += 1 223 | 224 | temp.extend(legal_range) 225 | legal_range = temp 226 | 227 | for i in range(len(legal_range) - self.max_traj_len + 1): 228 | legal_range_current = legal_range[i:i + self.max_traj_len] 229 | json_data.append({'id': {'vid': k, 'feature': file_path, 230 | 'legal_range': legal_range_current, 'task_id': v['recipe_type']}, 231 | 'instruction_len': 0}) 232 | print(num) 233 | self.json_data = json_data 234 | with open(coin_data_name, 'w') as f: 235 | json.dump(json_data, f) 236 | 237 | elif args.dataset == 'NIV': 238 | val_csv_path = os.path.join( 239 | root, '../NIV', 'test30.json') 240 | video_csv_path = os.path.join( 241 | root, '../NIV', 'train70.json') 242 | 243 | if is_val: 244 | niv_data_name = args.json_path_val 245 | else: 246 | niv_data_name = args.json_path_train 247 | 248 | 249 | if os.path.exists(niv_data_name): 250 | with open(niv_data_name, 'r') as f: 251 | self.json_data = json.load(f) 252 | print('Loaded {}'.format(niv_data_name)) 253 | else: 254 | json_data = [] 255 | if is_val: 256 | with open(val_csv_path, 'r') as f: 257 | niv_data = json.load(f) 258 | else: 259 | with open(video_csv_path, 'r') as f: 260 | niv_data = json.load(f) 261 | for d in niv_data: 262 | legal_range = [] 263 | path = d['feature'] 264 | info = np.load(path, allow_pickle=True) 265 | num_steps = int(info['num_steps']) 266 | assert num_steps == len(info['steps_ids']) 267 | assert info['num_steps'] == len(info['steps_starts']) 268 | assert info['num_steps'] == len(info['steps_ends']) 269 | starts = 
info['steps_starts'] 270 | ends = info['steps_ends'] 271 | action_labels = info['steps_ids'] 272 | images = info['frames_features'] 273 | 274 | for i in range(num_steps): 275 | action_label = int(action_labels[i]) 276 | start_idx = math.floor(float(starts[i])) 277 | end_idx = math.ceil(float(ends[i])) 278 | 279 | if end_idx < images.shape[0]: 280 | legal_range.append((start_idx, end_idx, action_label)) 281 | else: 282 | legal_range.append((start_idx, images.shape[0] - 1, action_label)) 283 | 284 | temp_len = len(legal_range) 285 | temp = [] 286 | while temp_len < self.max_traj_len: 287 | temp.append(legal_range[0]) 288 | temp_len += 1 289 | 290 | temp.extend(legal_range) 291 | legal_range = temp 292 | 293 | for i in range(len(legal_range) - self.max_traj_len + 1): 294 | legal_range_current = legal_range[i:i + self.max_traj_len] 295 | json_data.append({'id': {'feature': path, 'legal_range': legal_range_current, 'task_id': d['task_id']}, 'instruction_len': 0}) 296 | self.json_data = json_data 297 | with open(niv_data_name, 'w') as f: 298 | json.dump(json_data, f) 299 | print(len(json_data)) 300 | else: 301 | raise NotImplementedError( 302 | 'Dataset {} is not implemented'.format(args.dataset)) 303 | 304 | self.model = model 305 | self.prepare_data() 306 | self.M = 3 307 | 308 | def process_single(self, task, vid): 309 | if self.crosstask_use_feature_how: 310 | if not os.path.exists(os.path.join(self.features_path, str(task) + '_' + str(vid) + '.npy')): 311 | return False 312 | images_ = np.load(os.path.join(self.features_path, str(task) + '_' + str(vid) + '.npy'), allow_pickle=True) 313 | images = images_['frames_features'] 314 | else: 315 | if not os.path.exists(os.path.join(self.features_path, vid + '.npy')): 316 | return False 317 | images = np.load(os.path.join(self.features_path, vid + '.npy')) 318 | 319 | cnst_path = os.path.join( 320 | self.constraints_path, task + '_' + vid + '.csv') 321 | legal_range = self.read_assignment(task, cnst_path) 322 | legal_range_ret = [] 323 | for (start_idx, end_idx, action_label) in legal_range: 324 | if not start_idx < images.shape[0]: 325 | print(task, vid, end_idx, images.shape[0]) 326 | return False 327 | if end_idx < images.shape[0]: 328 | legal_range_ret.append((start_idx, end_idx, action_label)) 329 | else: 330 | legal_range_ret.append((start_idx, images.shape[0] - 1, action_label)) 331 | 332 | return legal_range_ret 333 | 334 | def read_assignment(self, task_id, path): 335 | legal_range = [] 336 | with open(path, 'r') as f: 337 | for line in f: 338 | step, start, end = line.strip().split(',') 339 | start = int(math.floor(float(start))) 340 | end = int(math.ceil(float(end))) 341 | action_label_ind = self.action_one_hot[task_id + '_' + step] 342 | legal_range.append((start, end, action_label_ind)) 343 | 344 | return legal_range 345 | 346 | def prepare_data(self): 347 | vid_names = [] 348 | frame_cnts = [] 349 | for listdata in self.json_data: 350 | vid_names.append(listdata['id']) 351 | frame_cnts.append(listdata['instruction_len']) 352 | self.vid_names = vid_names 353 | self.frame_cnts = frame_cnts 354 | 355 | def __len__(self): 356 | return len(self.json_data) 357 | -------------------------------------------------------------------------------- /step/main_distributed.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import random 4 | import time 5 | from collections import OrderedDict 6 | 7 | import torch.nn.parallel 8 | import torch.backends.cudnn as cudnn 9 | import 
torch.distributed as dist 10 | import torch.optim 11 | import torch.multiprocessing as mp 12 | import torch.utils.data 13 | import torch.utils.data.distributed 14 | from torch.distributed import ReduceOp 15 | 16 | import utils 17 | from dataloader.data_load import PlanningDataset 18 | from model import diffusion, temporal 19 | from model.helpers import get_lr_schedule_with_warmup 20 | 21 | from utils import * 22 | from logging import log 23 | from utils.args import get_args 24 | import numpy as np 25 | from model.helpers import Logger 26 | 27 | 28 | def reduce_tensor(tensor): 29 | rt = tensor.clone() 30 | torch.distributed.all_reduce(rt, op=ReduceOp.SUM) 31 | rt /= dist.get_world_size() 32 | return rt 33 | 34 | 35 | def main(): 36 | # os.environ['CUDA_LAUNCH_BLOCKING'] = '1' 37 | args = get_args() 38 | os.environ['PYTHONHASHSEED'] = str(args.seed) 39 | 40 | if args.verbose: 41 | print(args) 42 | if args.seed is not None: 43 | random.seed(args.seed) 44 | np.random.seed(args.seed) 45 | torch.manual_seed(args.seed) 46 | torch.cuda.manual_seed_all(args.seed) 47 | 48 | args.distributed = args.world_size > 1 or args.multiprocessing_distributed 49 | ngpus_per_node = torch.cuda.device_count() 50 | 51 | if args.multiprocessing_distributed: 52 | args.world_size = ngpus_per_node * args.world_size 53 | mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args)) 54 | else: 55 | main_worker(args.gpu, ngpus_per_node, args) 56 | 57 | 58 | def main_worker(gpu, ngpus_per_node, args): 59 | args.gpu = gpu 60 | # print('gpuid:', args.gpu) 61 | 62 | if args.distributed: 63 | if args.multiprocessing_distributed: 64 | args.rank = args.rank * ngpus_per_node + gpu 65 | dist.init_process_group( 66 | backend=args.dist_backend, 67 | init_method=args.dist_url, 68 | world_size=args.world_size, 69 | rank=args.rank, 70 | ) 71 | if args.gpu is not None: 72 | torch.cuda.set_device(args.gpu) 73 | args.batch_size = int(args.batch_size / ngpus_per_node) 74 | args.batch_size_val = int(args.batch_size_val / ngpus_per_node) 75 | args.num_thread_reader = int(args.num_thread_reader / ngpus_per_node) 76 | elif args.gpu is not None: 77 | torch.cuda.set_device(args.gpu) 78 | 79 | # Data loading code 80 | train_dataset = PlanningDataset( 81 | args.root, 82 | args=args, 83 | is_val=False, 84 | model=None, 85 | ) 86 | # Test data loading code 87 | test_dataset = PlanningDataset( 88 | args.root, 89 | args=args, 90 | is_val=True, 91 | model=None, 92 | ) 93 | if args.distributed: 94 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) 95 | test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset) 96 | else: 97 | train_sampler = None 98 | test_sampler = None 99 | 100 | train_loader = torch.utils.data.DataLoader( 101 | train_dataset, 102 | batch_size=args.batch_size, 103 | shuffle=(train_sampler is None), 104 | drop_last=True, 105 | num_workers=args.num_thread_reader, 106 | pin_memory=args.pin_memory, 107 | sampler=train_sampler, 108 | ) 109 | test_loader = torch.utils.data.DataLoader( 110 | test_dataset, 111 | batch_size=args.batch_size_val, 112 | shuffle=False, 113 | drop_last=False, 114 | num_workers=args.num_thread_reader, 115 | sampler=test_sampler, 116 | ) 117 | 118 | # create model 119 | temporal_model = temporal.TemporalUnet( # temporal 120 | args.action_dim + args.observation_dim, 121 | dim=256, 122 | dim_mults=(1, 2, 4), ) 123 | 124 | # for param in temporal_model.named_parameters(): 125 | # if 'time_mlp' not in param[0]: 126 | # param[1].requires_grad = False 127 | 128 | 
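# The temporal U-Net above is wrapped by the Gaussian diffusion process below (Weighted_MSE loss, clipped denoising);
# the Trainer also maintains an EMA copy of the model, and it is this EMA copy that is used for validation.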
diffusion_model = diffusion.GaussianDiffusion( 129 | temporal_model, args.horizon, args.observation_dim, args.action_dim, args.n_diffusion_steps, 130 | loss_type='Weighted_MSE', clip_denoised=True, ) 131 | 132 | model = utils.Trainer(diffusion_model, train_loader, args.ema_decay, args.lr, args.gradient_accumulate_every, 133 | args.step_start_ema, args.update_ema_every, args.log_freq) 134 | 135 | if args.pretrain_cnn_path: 136 | net_data = torch.load(args.pretrain_cnn_path) 137 | model.model.load_state_dict(net_data) 138 | model.ema_model.load_state_dict(net_data) 139 | if args.distributed: 140 | if args.gpu is not None: 141 | model.model.cuda(args.gpu) 142 | model.ema_model.cuda(args.gpu) 143 | model.model = torch.nn.parallel.DistributedDataParallel( 144 | model.model, device_ids=[args.gpu], find_unused_parameters=True) 145 | model.ema_model = torch.nn.parallel.DistributedDataParallel( 146 | model.ema_model, device_ids=[args.gpu], find_unused_parameters=True) 147 | else: 148 | model.model.cuda() 149 | model.ema_model.cuda() 150 | model.model = torch.nn.parallel.DistributedDataParallel(model.model, find_unused_parameters=True) 151 | model.ema_model = torch.nn.parallel.DistributedDataParallel(model.ema_model, 152 | find_unused_parameters=True) 153 | 154 | elif args.gpu is not None: 155 | model.model = model.model.cuda(args.gpu) 156 | model.ema_model = model.ema_model.cuda(args.gpu) 157 | else: 158 | model.model = torch.nn.DataParallel(model.model).cuda() 159 | model.ema_model = torch.nn.DataParallel(model.ema_model).cuda() 160 | 161 | scheduler = get_lr_schedule_with_warmup(model.optimizer, int(args.n_train_steps * args.epochs)) 162 | 163 | checkpoint_dir = os.path.join(os.path.dirname(__file__), 'checkpoint', args.checkpoint_dir) 164 | if args.checkpoint_dir != '' and not (os.path.isdir(checkpoint_dir)) and args.rank == 0: 165 | os.mkdir(checkpoint_dir) 166 | 167 | if args.resume: 168 | checkpoint_path = get_last_checkpoint(checkpoint_dir) 169 | if checkpoint_path: 170 | log("=> loading checkpoint '{}'".format(checkpoint_path), args) 171 | checkpoint = torch.load(checkpoint_path, map_location='cuda:{}'.format(args.rank)) 172 | args.start_epoch = checkpoint["epoch"] 173 | model.model.load_state_dict(checkpoint["model"]) 174 | model.ema_model.load_state_dict(checkpoint["ema"]) 175 | model.optimizer.load_state_dict(checkpoint["optimizer"]) 176 | model.step = checkpoint["step"] 177 | # for p in model.optimizer.param_groups: 178 | # p['lr'] = 1e-5 179 | scheduler.load_state_dict(checkpoint["scheduler"]) 180 | tb_logdir = checkpoint["tb_logdir"] 181 | if args.rank == 0: 182 | # creat logger 183 | tb_logger = Logger(tb_logdir) 184 | log("=> loaded checkpoint '{}' (epoch {}){}".format(checkpoint_path, checkpoint["epoch"], args.gpu), args) 185 | else: 186 | # log("=> loading checkpoint '{}' to initialize".format(checkpoint_path), args) 187 | # checkpoint = torch.load(checkpoint_path, map_location='cuda:{}'.format(args.rank)) 188 | # model.model.load_state_dict(checkpoint["model"], strict=False) 189 | # model.ema_model.load_state_dict(checkpoint["ema"], strict=False) 190 | time_pre = time.strftime("%Y%m%d%H%M%S", time.localtime()) 191 | logname = args.log_root + '_' + time_pre + '_' + args.dataset 192 | tb_logdir = os.path.join(args.log_root, logname) 193 | if args.rank == 0: 194 | # creat logger 195 | if not (os.path.exists(tb_logdir)): 196 | os.makedirs(tb_logdir) 197 | tb_logger = Logger(tb_logdir) 198 | tb_logger.log_info(args) 199 | log("=> no checkpoint found at '{}'".format(args.resume), args) 
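# Training loop below: cuDNN autotuning is enabled if requested, each epoch trains the diffusion model
# (full training-set metrics are logged every 10th epoch), the EMA model is validated every 5 epochs when
# args.evaluate is set, and the checkpoint with the best validation trajectory success rate is kept in save_max/.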
200 | 201 | if args.cudnn_benchmark: 202 | cudnn.benchmark = True 203 | total_batch_size = args.world_size * args.batch_size 204 | log( 205 | "Starting training loop for rank: {}, total batch size: {}".format( 206 | args.rank, total_batch_size 207 | ), args 208 | ) 209 | 210 | max_eva = 0 211 | max_acc = 0 212 | old_max_epoch = 0 213 | save_max = os.path.join(os.path.dirname(__file__), 'save_max') 214 | 215 | for epoch in range(args.start_epoch, args.epochs): 216 | print('epoch : ', epoch) 217 | if args.distributed: 218 | train_sampler.set_epoch(epoch) 219 | 220 | # train for one epoch 221 | if (epoch + 1) % 10 == 0: # calculate on training set 222 | losses, acc_top1, acc_top5, trajectory_success_rate_meter, MIoU1_meter, MIoU2_meter, \ 223 | acc_a0, acc_aT = model.train(args.n_train_steps, True, args, scheduler) 224 | losses_reduced = reduce_tensor(losses.cuda()).item() 225 | acc_top1_reduced = reduce_tensor(acc_top1.cuda()).item() 226 | acc_top5_reduced = reduce_tensor(acc_top5.cuda()).item() 227 | trajectory_success_rate_meter_reduced = reduce_tensor(trajectory_success_rate_meter.cuda()).item() 228 | MIoU1_meter_reduced = reduce_tensor(MIoU1_meter.cuda()).item() 229 | MIoU2_meter_reduced = reduce_tensor(MIoU2_meter.cuda()).item() 230 | acc_a0_reduced = reduce_tensor(acc_a0.cuda()).item() 231 | acc_aT_reduced = reduce_tensor(acc_aT.cuda()).item() 232 | 233 | if args.rank == 0: 234 | logs = OrderedDict() 235 | logs['Train/EpochLoss'] = losses_reduced 236 | logs['Train/EpochAcc@1'] = acc_top1_reduced 237 | logs['Train/EpochAcc@5'] = acc_top5_reduced 238 | logs['Train/Traj_Success_Rate'] = trajectory_success_rate_meter_reduced 239 | logs['Train/MIoU1'] = MIoU1_meter_reduced 240 | logs['Train/MIoU2'] = MIoU2_meter_reduced 241 | logs['Train/acc_a0'] = acc_a0_reduced 242 | logs['Train/acc_aT'] = acc_aT_reduced 243 | for key, value in logs.items(): 244 | tb_logger.log_scalar(value, key, epoch + 1) 245 | 246 | tb_logger.flush() 247 | else: 248 | losses = model.train(args.n_train_steps, False, args, scheduler).cuda() 249 | losses_reduced = reduce_tensor(losses).item() 250 | if args.rank == 0: 251 | print('lrs:') 252 | for p in model.optimizer.param_groups: 253 | print(p['lr']) 254 | print('---------------------------------') 255 | 256 | logs = OrderedDict() 257 | logs['Train/EpochLoss'] = losses_reduced 258 | for key, value in logs.items(): 259 | tb_logger.log_scalar(value, key, epoch + 1) 260 | 261 | tb_logger.flush() 262 | 263 | if ((epoch + 1) % 5 == 0) and args.evaluate: # or epoch > 18 264 | losses, acc_top1, acc_top5, \ 265 | trajectory_success_rate_meter, MIoU1_meter, MIoU2_meter, \ 266 | acc_a0, acc_aT = validate(test_loader, model.ema_model, args) 267 | 268 | losses_reduced = reduce_tensor(losses.cuda()).item() 269 | acc_top1_reduced = reduce_tensor(acc_top1.cuda()).item() 270 | acc_top5_reduced = reduce_tensor(acc_top5.cuda()).item() 271 | trajectory_success_rate_meter_reduced = reduce_tensor(trajectory_success_rate_meter.cuda()).item() 272 | MIoU1_meter_reduced = reduce_tensor(MIoU1_meter.cuda()).item() 273 | MIoU2_meter_reduced = reduce_tensor(MIoU2_meter.cuda()).item() 274 | acc_a0_reduced = reduce_tensor(acc_a0.cuda()).item() 275 | acc_aT_reduced = reduce_tensor(acc_aT.cuda()).item() 276 | 277 | if args.rank == 0: 278 | logs = OrderedDict() 279 | logs['Val/EpochLoss'] = losses_reduced 280 | logs['Val/EpochAcc@1'] = acc_top1_reduced 281 | logs['Val/EpochAcc@5'] = acc_top5_reduced 282 | logs['Val/Traj_Success_Rate'] = trajectory_success_rate_meter_reduced 283 | logs['Val/MIoU1'] = 
MIoU1_meter_reduced 284 | logs['Val/MIoU2'] = MIoU2_meter_reduced 285 | logs['Val/acc_a0'] = acc_a0_reduced 286 | logs['Val/acc_aT'] = acc_aT_reduced 287 | for key, value in logs.items(): 288 | tb_logger.log_scalar(value, key, epoch + 1) 289 | 290 | tb_logger.flush() 291 | print(trajectory_success_rate_meter_reduced, max_eva) 292 | if trajectory_success_rate_meter_reduced >= max_eva: 293 | if not (trajectory_success_rate_meter_reduced == max_eva and acc_top1_reduced < max_acc): 294 | save_checkpoint2( 295 | { 296 | "epoch": epoch + 1, 297 | "model": model.model.state_dict(), 298 | "ema": model.ema_model.state_dict(), 299 | "optimizer": model.optimizer.state_dict(), 300 | "step": model.step, 301 | "tb_logdir": tb_logdir, 302 | "scheduler": scheduler.state_dict(), 303 | }, save_max, old_max_epoch, epoch + 1, args.rank 304 | ) 305 | max_eva = trajectory_success_rate_meter_reduced 306 | max_acc = acc_top1_reduced 307 | old_max_epoch = epoch + 1 308 | 309 | if (epoch + 1) % args.save_freq == 0: 310 | if args.rank == 0: 311 | save_checkpoint( 312 | { 313 | "epoch": epoch + 1, 314 | "model": model.model.state_dict(), 315 | "ema": model.ema_model.state_dict(), 316 | "optimizer": model.optimizer.state_dict(), 317 | "step": model.step, 318 | "tb_logdir": tb_logdir, 319 | "scheduler": scheduler.state_dict(), 320 | }, checkpoint_dir, epoch + 1 321 | ) 322 | 323 | 324 | def log(output, args): 325 | with open(os.path.join(os.path.dirname(__file__), 'log', args.checkpoint_dir + '.txt'), "a") as f: 326 | f.write(output + '\n') 327 | 328 | 329 | def save_checkpoint(state, checkpoint_dir, epoch, n_ckpt=3): 330 | torch.save(state, os.path.join(checkpoint_dir, "epoch{:0>4d}.pth.tar".format(epoch))) 331 | if epoch - n_ckpt >= 0: 332 | oldest_ckpt = os.path.join(checkpoint_dir, "epoch{:0>4d}.pth.tar".format(epoch - n_ckpt)) 333 | if os.path.isfile(oldest_ckpt): 334 | os.remove(oldest_ckpt) 335 | 336 | 337 | def save_checkpoint2(state, checkpoint_dir, old_epoch, epoch, rank): 338 | torch.save(state, os.path.join(checkpoint_dir, "epoch{:0>4d}_{}.pth.tar".format(epoch, rank))) 339 | if old_epoch > 0: 340 | oldest_ckpt = os.path.join(checkpoint_dir, "epoch{:0>4d}_{}.pth.tar".format(old_epoch, rank)) 341 | if os.path.isfile(oldest_ckpt): 342 | os.remove(oldest_ckpt) 343 | 344 | 345 | def get_last_checkpoint(checkpoint_dir): 346 | all_ckpt = glob.glob(os.path.join(checkpoint_dir, 'epoch*.pth.tar')) 347 | if all_ckpt: 348 | all_ckpt = sorted(all_ckpt) 349 | return all_ckpt[-1] 350 | else: 351 | return '' 352 | 353 | 354 | if __name__ == "__main__": 355 | main() 356 | -------------------------------------------------------------------------------- /step/inference.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import time 4 | import json 5 | import numpy as np 6 | import torch.nn.parallel 7 | import torch.backends.cudnn as cudnn 8 | import torch.distributed as dist 9 | import torch.optim 10 | import torch.multiprocessing as mp 11 | import torch.utils.data 12 | import torch.utils.data.distributed 13 | import utils 14 | from torch.distributed import ReduceOp 15 | from dataloader.data_load import PlanningDataset 16 | from model import diffusion, temporal 17 | from utils import * 18 | from utils.args import get_args 19 | from action_dictionary import action_dictionary 20 | 21 | def map_numbers_to_values(input_list, mapping_dict): 22 | result = [] 23 | for sublist in input_list: 24 | result.append([mapping_dict[num+1] for num in sublist]) 25 | return 
result 26 | 27 | def accuracy2(output, target, topk=(1,), max_traj_len=0): 28 | with torch.no_grad(): 29 | maxk = max(topk) 30 | batch_size = target.size(0) 31 | _, pred = output.topk(maxk, 1, True, True) 32 | pred = pred.t() 33 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 34 | 35 | correct_a = correct[:1].view(-1, max_traj_len) 36 | correct_a0 = correct_a[:, 0].reshape(-1).float().mean().mul_(100.0) 37 | correct_aT = correct_a[:, -1].reshape(-1).float().mean().mul_(100.0) 38 | 39 | res = [] 40 | for k in topk: 41 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) 42 | res.append(correct_k.mul_(100.0 / batch_size)) 43 | 44 | correct_1 = correct[:1] 45 | 46 | # Success Rate 47 | trajectory_success = torch.all(correct_1.view(correct_1.shape[1] // max_traj_len, -1), dim=1) 48 | trajectory_success_rate = trajectory_success.sum() * 100.0 / trajectory_success.shape[0] 49 | 50 | # MIoU 51 | _, pred_token = output.topk(1, 1, True, True) 52 | pred_inst = pred_token.view(correct_1.shape[1], -1) 53 | pred_inst_set = set() 54 | target_inst = target.view(correct_1.shape[1], -1) 55 | target_inst_set = set() 56 | for i in range(pred_inst.shape[0]): 57 | pred_inst_set.add(tuple(pred_inst[i].tolist())) 58 | target_inst_set.add(tuple(target_inst[i].tolist())) 59 | MIoU1 = 100.0 * len(pred_inst_set.intersection(target_inst_set)) / len(pred_inst_set.union(target_inst_set)) 60 | 61 | batch_size = batch_size // max_traj_len 62 | pred_inst = pred_token.view(batch_size, -1) # [bs, T] 63 | pred_inst_set = set() 64 | target_inst = target.view(batch_size, -1) # [bs, T] 65 | target_inst_set = set() 66 | MIoU_sum = 0 67 | for i in range(pred_inst.shape[0]): 68 | pred_inst_set.update(pred_inst[i].tolist()) 69 | target_inst_set.update(target_inst[i].tolist()) 70 | MIoU_current = 100.0 * len(pred_inst_set.intersection(target_inst_set)) / len( 71 | pred_inst_set.union(target_inst_set)) 72 | MIoU_sum += MIoU_current 73 | pred_inst_set.clear() 74 | target_inst_set.clear() 75 | 76 | MIoU2 = MIoU_sum / batch_size 77 | return res[0], trajectory_success_rate, MIoU1, MIoU2, correct_a0, correct_aT 78 | 79 | 80 | def test(vid_names, val_loader, model, args): 81 | model.eval() 82 | acc_top1 = AverageMeter() 83 | trajectory_success_rate_meter = AverageMeter() 84 | MIoU1_meter = AverageMeter() 85 | MIoU2_meter = AverageMeter() 86 | A0_acc = AverageMeter() 87 | AT_acc = AverageMeter() 88 | index_list = [] 89 | action_list = [] 90 | differing_sublists = {} 91 | all_sublists = {} 92 | final_pred_list = [] 93 | file_final_list_test = args.steps_path 94 | i=1 95 | 96 | for i_batch, sample_batch in enumerate(val_loader): 97 | 98 | global_img_tensors = sample_batch[0].cuda().contiguous() 99 | 100 | video_label = sample_batch[1].cuda() 101 | batch_size_current, T = video_label.size() 102 | print('batch size', batch_size_current) 103 | print('T =', T) 104 | 105 | cond = {} 106 | 107 | with torch.no_grad(): 108 | cond[0] = global_img_tensors[:, 0, :].float() 109 | cond[T - 1] = global_img_tensors[:, -1, :].float() 110 | 111 | video_label_reshaped = video_label.view(-1) 112 | print('video label reshaped:' ,video_label_reshaped.shape) 113 | output = model(cond, if_jump=True) 114 | 115 | actions_pred = output.contiguous() 116 | actions_pred = actions_pred[:, :, :args.action_dim].contiguous() 117 | print('shape actions pred:', actions_pred.shape) #dim = [256, 3, 105] 118 | argmax_index = torch.argmax(actions_pred, dim = -1) 119 | index_list.append(argmax_index) 120 | print('index list length:', len(index_list)) 121 | 
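# argmax_index now holds the predicted action-id sequence for each video ([batch, horizon]), taken as the
# argmax over the action logits of shape [batch, horizon, action_dim].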
print('argmax actions pred length:', argmax_index.shape) 122 | 123 | 124 | tensor_list_action_sequence = argmax_index.tolist() 125 | 126 | output_list_action_sequence = map_numbers_to_values(tensor_list_action_sequence, action_dictionary) 127 | 128 | print('length action sequence', len(output_list_action_sequence)) 129 | 130 | index_differing = i_batch 131 | 132 | for i in range(len(video_label)): 133 | bs = 256 134 | index_vid = (index_differing*bs)+i 135 | print('index:', index_vid+1, 'GT sequence: ', vid_names[index_vid]['legal_range'],'predicted sequence : ', argmax_index[i].tolist()) 136 | final_pred_list.append(argmax_index[i].tolist()) 137 | if not (torch.equal(video_label[i][:1], argmax_index[i][:1]) and torch.equal(video_label[i][-1:], argmax_index[i][-1:])): 138 | 139 | differing_sublists[index_vid+1] = { 140 | 'i_batch': i_batch, 141 | 'video_label_list': video_label[i].tolist(), 142 | 'predicted_list': argmax_index[i].tolist(), 143 | 'predicted_sequence': output_list_action_sequence[i], 144 | 'vid': vid_names[index_vid]['vid'], 145 | 'legal_range':vid_names[index_vid]['legal_range'] 146 | } 147 | 148 | for i in range(len(video_label)): 149 | bs = 256 150 | index_vid_sim = (index_differing*bs)+i 151 | all_sublists[index_vid_sim] = { 152 | 'i_batch': i_batch, 153 | 'video_label_list': video_label[i].tolist(), 154 | 'predicted_list': argmax_index[i].tolist(), 155 | 'predicted_sequence': output_list_action_sequence[i], 156 | 'vid': vid_names[index_vid_sim]['vid'], 157 | 'legal_range':vid_names[index_vid_sim]['legal_range'], 158 | 159 | } 160 | 161 | actions_pred = actions_pred.view(-1, args.action_dim) 162 | acc1, trajectory_success_rate, MIoU1, MIoU2, a0_acc, aT_acc = \ 163 | accuracy2(actions_pred.cpu(), video_label_reshaped.cpu(), topk=(1,), max_traj_len=args.horizon) 164 | 165 | acc_top1.update(acc1.item(), batch_size_current) 166 | trajectory_success_rate_meter.update(trajectory_success_rate.item(), batch_size_current) 167 | MIoU1_meter.update(MIoU1, batch_size_current) 168 | MIoU2_meter.update(MIoU2, batch_size_current) 169 | A0_acc.update(a0_acc, batch_size_current) 170 | AT_acc.update(aT_acc, batch_size_current) 171 | 172 | 173 | print('Differing.............................................................................................') 174 | for index, sublist_data in differing_sublists.items(): 175 | print(f"Index: {index}", '||', f"i_batch: {sublist_data['i_batch']}", '||', f"Video Label List: {sublist_data['video_label_list']}", '||', 176 | f"Predicted List: {sublist_data['predicted_list']}", '||', f"Predicted Sequence: {sublist_data['predicted_sequence']}", '||', f"Vid: {sublist_data['vid']}", 177 | '||', f"Legal range: {sublist_data['legal_range']}") 178 | print() 179 | print('Length of the failure cases : ' , len(differing_sublists) ) 180 | 181 | print('All lists.............................................................................................') 182 | for index, sublist_data in all_sublists.items(): 183 | print(f"Index: {index}", '||', f"i_batch: {sublist_data['i_batch']}", '||', f"Video Label List: {sublist_data['video_label_list']}", '||', 184 | f"Predicted List: {sublist_data['predicted_list']}", '||', f"Predicted Sequence: {sublist_data['predicted_sequence']}", '||', f"Vid: {sublist_data['vid']}" 185 | , '||', f"Legal range: {sublist_data['legal_range']}") 186 | print() 187 | 188 | with open (file_final_list_test, 'w') as ou: 189 | json.dump(final_pred_list, ou) 190 | return torch.tensor(acc_top1.avg), \ 191 | 
torch.tensor(trajectory_success_rate_meter.avg), \ 192 | torch.tensor(MIoU1_meter.avg), torch.tensor(MIoU2_meter.avg), \ 193 | torch.tensor(A0_acc.avg), torch.tensor(AT_acc.avg) 194 | 195 | 196 | def reduce_tensor(tensor): 197 | rt = tensor.clone() 198 | torch.distributed.all_reduce(rt, op=ReduceOp.SUM) 199 | rt /= dist.get_world_size() 200 | return rt 201 | 202 | 203 | def main(): 204 | # os.environ['CUDA_LAUNCH_BLOCKING'] = '1' 205 | args = get_args() 206 | os.environ['PYTHONHASHSEED'] = str(args.seed) 207 | 208 | if args.verbose: 209 | print(args) 210 | if args.seed is not None: 211 | random.seed(args.seed) 212 | np.random.seed(args.seed) 213 | torch.manual_seed(args.seed) 214 | torch.cuda.manual_seed_all(args.seed) 215 | 216 | args.distributed = args.world_size > 1 or args.multiprocessing_distributed 217 | ngpus_per_node = torch.cuda.device_count() 218 | # print('ngpus_per_node:', ngpus_per_node) 219 | 220 | if args.multiprocessing_distributed: 221 | args.world_size = ngpus_per_node * args.world_size 222 | mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args)) 223 | else: 224 | main_worker(args.gpu, ngpus_per_node, args) 225 | 226 | #####Generate final step predictions 227 | with open(args.json_path_val, 'r') as original_data_file: 228 | original_data = json.load(original_data_file) 229 | 230 | with open(args.steps_path, 'r') as large_list_file: 231 | large_list = json.load(large_list_file) 232 | 233 | 234 | for i, item in enumerate(original_data): 235 | item["id"]["pred_list"] = large_list[i] 236 | 237 | with open(args.step_model_output, 'w') as modified_data_file: 238 | json.dump(original_data, modified_data_file) 239 | 240 | 241 | 242 | def main_worker(gpu, ngpus_per_node, args): 243 | args.gpu = gpu 244 | # print('gpuid:', args.gpu) 245 | 246 | if args.distributed: 247 | if args.multiprocessing_distributed: 248 | args.rank = args.rank * ngpus_per_node + gpu 249 | dist.init_process_group( 250 | backend=args.dist_backend, 251 | init_method=args.dist_url, 252 | world_size=args.world_size, 253 | rank=args.rank, 254 | ) 255 | if args.gpu is not None: 256 | torch.cuda.set_device(args.gpu) 257 | args.batch_size = int(args.batch_size / ngpus_per_node) 258 | args.batch_size_val = int(args.batch_size_val / ngpus_per_node) 259 | args.num_thread_reader = int(args.num_thread_reader / ngpus_per_node) 260 | elif args.gpu is not None: 261 | torch.cuda.set_device(args.gpu) 262 | 263 | # Test data loading code 264 | test_dataset = PlanningDataset( 265 | args.root, 266 | args=args, 267 | is_val=True, 268 | model=None, 269 | ) 270 | 271 | vid_names = test_dataset.vid_names 272 | 273 | if args.distributed: 274 | test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset) 275 | test_sampler.shuffle = False 276 | else: 277 | test_sampler = None 278 | print('none test sampler') 279 | 280 | test_loader = torch.utils.data.DataLoader( 281 | test_dataset, 282 | batch_size=args.batch_size_val, 283 | shuffle=False, 284 | drop_last=False, 285 | num_workers=args.num_thread_reader, 286 | sampler=test_sampler, 287 | ) 288 | 289 | # create model 290 | temporal_model = temporal.TemporalUnet( 291 | args.action_dim + args.observation_dim, 292 | dim=256, 293 | dim_mults=(1, 2, 4), ) 294 | 295 | diffusion_model = diffusion.GaussianDiffusion( 296 | temporal_model, args.horizon, args.observation_dim, args.action_dim, args.n_diffusion_steps, 297 | loss_type='Weighted_MSE', clip_denoised=True,) 298 | 299 | model = utils.Trainer(diffusion_model, None, args.ema_decay, args.lr, 
args.gradient_accumulate_every, 300 | args.step_start_ema, args.update_ema_every, args.log_freq) 301 | 302 | if args.pretrain_cnn_path: 303 | net_data = torch.load(args.pretrain_cnn_path) 304 | model.model.load_state_dict(net_data) 305 | model.ema_model.load_state_dict(net_data) 306 | if args.distributed: 307 | if args.gpu is not None: 308 | model.model.cuda(args.gpu) 309 | model.ema_model.cuda(args.gpu) 310 | model.model = torch.nn.parallel.DistributedDataParallel( 311 | model.model, device_ids=[args.gpu], find_unused_parameters=True) 312 | model.ema_model = torch.nn.parallel.DistributedDataParallel( 313 | model.ema_model, device_ids=[args.gpu], find_unused_parameters=True) 314 | else: 315 | model.model.cuda() 316 | model.ema_model.cuda() 317 | model.model = torch.nn.parallel.DistributedDataParallel(model.model, find_unused_parameters=True) 318 | model.ema_model = torch.nn.parallel.DistributedDataParallel(model.ema_model, 319 | find_unused_parameters=True) 320 | 321 | elif args.gpu is not None: 322 | model.model = model.model.cuda(args.gpu) 323 | model.ema_model = model.ema_model.cuda(args.gpu) 324 | else: 325 | model.model = torch.nn.DataParallel(model.model).cuda() 326 | model.ema_model = torch.nn.DataParallel(model.ema_model).cuda() 327 | 328 | if args.resume: 329 | checkpoint_path = "" 330 | if checkpoint_path: 331 | print("=> loading checkpoint '{}'".format(checkpoint_path), args) 332 | checkpoint = torch.load(checkpoint_path, map_location='cuda:{}'.format(args.rank)) 333 | args.start_epoch = checkpoint["epoch"] 334 | model.model.load_state_dict(checkpoint["model"], strict=True) 335 | model.ema_model.load_state_dict(checkpoint["ema"], strict=True) 336 | model.step = checkpoint["step"] 337 | else: 338 | assert 0 339 | 340 | if args.cudnn_benchmark: 341 | cudnn.benchmark = True 342 | 343 | time_start = time.time() 344 | acc_top1_reduced_sum = [] 345 | trajectory_success_rate_meter_reduced_sum = [] 346 | MIoU1_meter_reduced_sum = [] 347 | MIoU2_meter_reduced_sum = [] 348 | acc_a0_reduced_sum = [] 349 | acc_aT_reduced_sum = [] 350 | test_times = 1 351 | 352 | for epoch in range(0, test_times): 353 | tmp = epoch 354 | random.seed(tmp) 355 | np.random.seed(tmp) 356 | torch.manual_seed(tmp) 357 | torch.cuda.manual_seed_all(tmp) 358 | 359 | acc_top1, trajectory_success_rate_meter, MIoU1_meter, MIoU2_meter, acc_a0, acc_aT = test(vid_names, test_loader, model.ema_model, args) 360 | 361 | acc_top1_reduced = reduce_tensor(acc_top1.cuda()).item() 362 | trajectory_success_rate_meter_reduced = reduce_tensor(trajectory_success_rate_meter.cuda()).item() 363 | MIoU1_meter_reduced = reduce_tensor(MIoU1_meter.cuda()).item() 364 | MIoU2_meter_reduced = reduce_tensor(MIoU2_meter.cuda()).item() 365 | acc_a0_reduced = reduce_tensor(acc_a0.cuda()).item() 366 | acc_aT_reduced = reduce_tensor(acc_aT.cuda()).item() 367 | 368 | acc_top1_reduced_sum.append(acc_top1_reduced) 369 | trajectory_success_rate_meter_reduced_sum.append(trajectory_success_rate_meter_reduced) 370 | MIoU1_meter_reduced_sum.append(MIoU1_meter_reduced) 371 | MIoU2_meter_reduced_sum.append(MIoU2_meter_reduced) 372 | acc_a0_reduced_sum.append(acc_a0_reduced) 373 | acc_aT_reduced_sum.append(acc_aT_reduced) 374 | 375 | if args.rank == 0: 376 | time_end = time.time() 377 | print('time: ', time_end - time_start) 378 | print('-----------------Mean&Var-----------------------') 379 | print('Val/EpochAcc@1', sum(acc_top1_reduced_sum) / test_times, np.var(acc_top1_reduced_sum)) 380 | print('Val/Traj_Success_Rate', 
sum(trajectory_success_rate_meter_reduced_sum) / test_times, np.var(trajectory_success_rate_meter_reduced_sum)) 381 | print('Val/MIoU1', sum(MIoU1_meter_reduced_sum) / test_times, np.var(MIoU1_meter_reduced_sum)) 382 | print('Val/MIoU2', sum(MIoU2_meter_reduced_sum) / test_times, np.var(MIoU2_meter_reduced_sum)) 383 | print('Val/acc_a0', sum(acc_a0_reduced_sum) / test_times, np.var(acc_a0_reduced_sum)) 384 | print('Val/acc_aT', sum(acc_aT_reduced_sum) / test_times, np.var(acc_aT_reduced_sum)) 385 | 386 | 387 | if __name__ == "__main__": 388 | main() 389 | --------------------------------------------------------------------------------
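For reference, the planning metrics reported by inference.py (Acc@1, trajectory success rate, MIoU1/MIoU2, a0/aT accuracy) can be exercised in isolation. Below is a minimal sketch, assuming the requirements are installed and the step/ directory is the working directory so that accuracy2 can be imported from inference.py; the tensors and printed values are illustrative only.

import torch
from inference import accuracy2  # step/inference.py, shown above

# Two toy "videos", horizon T = 3, action vocabulary of 5 actions.
target = torch.tensor([1, 2, 3, 0, 4, 2])   # ground-truth action ids, flattened to [batch * T]
preds = torch.tensor([1, 2, 3, 0, 1, 2])    # predicted ids: the second video is wrong at its middle step
logits = torch.nn.functional.one_hot(preds, num_classes=5).float()  # stand-in for action logits, [batch * T, action_dim]

acc1, success_rate, miou1, miou2, acc_a0, acc_aT = accuracy2(
    logits, target, topk=(1,), max_traj_len=3)

# acc1 ~ 83.3 (5 of 6 steps correct), success_rate = 50.0 (only the first video is correct end to end),
# miou1 = 80.0 (set IoU over per-step tuples), miou2 = 75.0 (mean per-video IoU over action sets),
# acc_a0 = acc_aT = 100.0 (first and last steps are correct in both videos).
print(acc1.item(), success_rate.item(), miou1, miou2, acc_a0.item(), acc_aT.item())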