├── .gitignore ├── HNTT_data_analysis.ipynb ├── LICENSE ├── LICENSE-MIT └── LICENSE-MSRLA.pdf ├── README.md ├── SECURITY.md ├── barcodes ├── __init__.py ├── barcode_dataset.py ├── barcodes_classifier.py └── create_barcodes.py ├── cross_validation.py ├── evaluate_ANTT_model.py ├── hyperparameters.json ├── plot_ANTT_evaluation.py ├── plot_ANTT_training.py ├── requirements.txt ├── symbolic ├── __init__.py ├── symbolic_classifier.py └── symbolic_dataset.py ├── topdown ├── __init__.py ├── create_topdown_img.py ├── topdown_classifier.py └── topdown_dataset.py ├── train.py ├── utils.py └── visuals ├── __init__.py ├── visuals_classifier.py └── visuals_dataset.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | # User specific Python virtual env settings 107 | **/.vscode/settings.json 108 | 109 | # data 110 | data 111 | barcode_data 112 | td_data 113 | 114 | # logs 115 | runs 116 | logs -------------------------------------------------------------------------------- /LICENSE/LICENSE-MIT: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) Microsoft Corporation 4 | 5 | All rights reserved. 6 | 7 | Permission is hereby granted, free of charge, to any person obtaining a copy 8 | of this software and associated documentation files (the "Software"), to deal 9 | in the Software without restriction, including without limitation the rights 10 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | copies of the Software, and to permit persons to whom the Software is 12 | furnished to do so, subject to the following conditions: 13 | 14 | The above copyright notice and this permission notice shall be included in all 15 | copies or substantial portions of the Software. 
16 | 17 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | SOFTWARE. -------------------------------------------------------------------------------- /LICENSE/LICENSE-MSRLA.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/NTT/ab73f9d0945670054863d53163c65addb0aa2700/LICENSE/LICENSE-MSRLA.pdf -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Navigation Turing Test (NTT): Learning to Evaluate Human-Like Navigation 2 | All code and data to reproduce results in Section 5 of [Navigation Turing Test (NTT): Learning to Evaluate Human-Like Navigation [ICML 2021]](https://arxiv.org/abs/2105.09637) 3 | 4 | # Getting Started 5 | 6 | All code was developed on Ubuntu 18.04.2 (via WSL) with Python 3.6.7. 7 | 8 | Optional, but recommended: set up a virtual environment (e.g. [VirtualEnv](https://virtualenv.pypa.io/)). 9 | 10 | Install Dependencies: 11 | 12 | pip install -r requirements.txt 13 | 14 | # Human Navigation Turing Test (HNTT) 15 | To reproduce our analysis of responses to the HNTT included in Section 5.1 of the ICML paper: 16 | 17 | Download HNTT_data.csv from [this link](https://icml2021.z5.web.core.windows.net/HNTT_data.zip) (30KiB) 18 | 19 | Step through the HNTT_data_analysis.ipynb notebook 20 | 21 | ## HNTT Survey Templates and Videos 22 | The HNTT survey templates, with the answer key embedded, can be downloaded from [this link](https://icml2021.z5.web.core.windows.net/icml2021-hntt-survey-templates.zip) (301KiB) 23 | 24 | The corresponding HNTT videos can be downloaded from [this link](https://icml2021.z5.web.core.windows.net/icml2021-hntt-videos.zip) (134MiB) 25 | 26 | # Automated Navigation Turing Test (ANTT) 27 | ## Training ANTT Models (Section 3.3) 28 | To train ANTT models, download the training dataset from [this link](https://icml2021.z5.web.core.windows.net/ICML2021-train-data.zip) (1.95 GiB), then run: 29 | 30 | python train.py --model-type ['visuals', 'symbolic', 'topdown', 'barcode'] --human-train --human-test --agent-train --agent-test 31 | 32 | Alternatively, to run a hyperparameter sweep with 5-fold cross-validation, first update hyperparameters.json, then run: 33 | 34 | python cross_validation.py --model-type ['visuals', 'symbolic', 'topdown', 'barcode'] --human-dirs --agent-dirs 35 | 36 | To see all the parameters along with their default values, run `python cross_validation.py --help`. 37 | 38 | To monitor training runs: 39 | 40 | tensorboard --logdir ./logs/ 41 | 42 | To plot learning curves with variance (e.g.
to reproduce Figure 2 in the paper): 43 | 44 | python plot_ANTT_training.py 45 | 46 | [Optional] To reproduce the barcode or topdown data from the raw trajectories, run: 47 | 48 | python topdown/create_topdown_img.py --folders --outdir 49 | python barcodes/create_barcodes.py --indir --outdir 50 | 51 | ## Evaluation (Section 5.2) 52 | 53 | To reproduce the ANTT analysis included in Section 5.2 of the ICML paper: 54 | 55 | Download HNTT_data.csv from [this link](https://icml2021.z5.web.core.windows.net/HNTT_data.zip) (30KiB) 56 | 57 | Download the evaluation dataset from [this link](https://icml2021.z5.web.core.windows.net/ICML2021-eval-data.zip) (264MiB) 58 | 59 | Then either: 60 | + Download trained models (.pt files) and saved model outputs (.pkl) from [this link](https://icml2021.z5.web.core.windows.net/ICML2021-trained-models.zip) (9GiB) 61 | + Train your own ANTT models as described above 62 | 63 | To evaluate a trained model: 64 | 65 | python evaluate_ANTT_model.py --path-to-models PATH --model-type ['BARCODE', 'CNN', 'SYMBOLIC', 'TOPDOWN'] 66 | 67 | If the model is a recurrent CNN or SYMBOLIC model, also pass --subsequence-length N 68 | 69 | When a model is evaluated, its output for each question in the behavioural study is saved to a .pkl file. For faster re-evaluation (without classifying the replays again), add --load-model-output 70 | 71 | To reproduce Figures 9 and 10, plot the evaluation of all ANTT models by running: 72 | 73 | python plot_ANTT_evaluation.py 74 | 75 | # Passing the Navigation Turing Test 76 | 77 | In later work on ["How Humans Perceive Human-like Behavior in Video Game Navigation"](https://www.microsoft.com/en-us/research/publication/how-humans-perceive-human-like-behavior-in-video-game-navigation/), published at CHI 2022, we presented the first agent to pass the Navigation Turing Test. [Videos of the CHI 2022 agent are available here.](https://icml2021.z5.web.core.windows.net/videos-new-agent.zip) 78 | 79 | # License 80 | Code is licensed under MIT; data and all other content are licensed under the Microsoft Research License Agreement (MSR-LA). See the LICENSE folder. 81 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Security 4 | 5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/). 6 | 7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/opensource/security/definition), please report it to us as described below. 8 | 9 | ## Reporting Security Issues 10 | 11 | **Please do not report security vulnerabilities through public GitHub issues.** 12 | 13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/opensource/security/create-report). 14 | 15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com).
If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/opensource/security/pgpkey). 16 | 17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://aka.ms/opensource/security/msrc). 18 | 19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: 20 | 21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.) 22 | * Full paths of source file(s) related to the manifestation of the issue 23 | * The location of the affected source code (tag/branch/commit or direct URL) 24 | * Any special configuration required to reproduce the issue 25 | * Step-by-step instructions to reproduce the issue 26 | * Proof-of-concept or exploit code (if possible) 27 | * Impact of the issue, including how an attacker might exploit the issue 28 | 29 | This information will help us triage your report more quickly. 30 | 31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/opensource/security/bounty) page for more details about our active programs. 32 | 33 | ## Preferred Languages 34 | 35 | We prefer all communications to be in English. 36 | 37 | ## Policy 38 | 39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/opensource/security/cvd). 40 | 41 | 42 | -------------------------------------------------------------------------------- /barcodes/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/NTT/ab73f9d0945670054863d53163c65addb0aa2700/barcodes/__init__.py -------------------------------------------------------------------------------- /barcodes/barcode_dataset.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------------------------------------------------- 2 | # Copyright (c) 2021 Microsoft Corporation 3 | # 4 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and 5 | # associated documentation files (the "Software"), to deal in the Software without restriction, 6 | # including without limitation the rights to use, copy, modify, merge, publish, distribute, 7 | # sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is 8 | # furnished to do so, subject to the following conditions: 9 | # 10 | # The above copyright notice and this permission notice shall be included in all copies or 11 | # substantial portions of the Software. 12 | # 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT 14 | # NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND 15 | # NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, 16 | # DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 17 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
18 | # -------------------------------------------------------------------------------------------------- 19 | 20 | import os 21 | import numpy as np 22 | from torch.utils.data.dataset import Dataset 23 | import torch.tensor 24 | from PIL import Image 25 | import random 26 | 27 | 28 | def read_trajectories(dirs, label): 29 | files_to_use = [] 30 | if not isinstance(dirs, list): 31 | dirs = [dirs] # if there is only one directory 32 | for dir in dirs: 33 | files = os.listdir(dir) 34 | if 'sets.json' in files: 35 | files.remove('sets.json') 36 | files = [os.path.join(dir, file) for file in files] 37 | files_to_use += files 38 | 39 | barcodes = [] 40 | # For each episode 41 | for filename in files_to_use: 42 | barcode = Image.open(filename) 43 | barcode_data = np.array(barcode) / 255 44 | barcode_data = np.transpose(barcode_data, (2, 0, 1)) 45 | barcodes.append({'data': barcode_data, 'label': label}) 46 | return barcodes 47 | 48 | 49 | class TrajectoryDatasetBarcodes(Dataset): 50 | def __init__(self, human_dirs, agent_dirs, seq_length=None): 51 | # seq_length is only used here for compatibility with the other 52 | # datasets 53 | assert seq_length is None or seq_length == 1, "Barcode data do not have a sequence length, you may need to remove it from the hyperparameter file." 54 | # Label Human trajectories 1.0, Agent trajectories 0.0 55 | self.data = read_trajectories( 56 | human_dirs, 1.0) + read_trajectories(agent_dirs, 0.0) 57 | 58 | def __len__(self): 59 | return len(self.data) 60 | 61 | def __getitem__(self, idx): 62 | barcode = self.data[idx] 63 | data = barcode['data'] 64 | label = barcode['label'] 65 | start_y = random.randint(0, max(0, data.shape[1] - 200)) 66 | cut_barcode = data[:, start_y:start_y + 200, :] 67 | y_shape = cut_barcode.shape[1] 68 | if y_shape < 200: 69 | cut_barcode = np.pad( 70 | cut_barcode, ((0, 0), (0, 200 - y_shape), (0, 0)), mode='edge') 71 | return torch.tensor( 72 | cut_barcode, dtype=torch.float32), torch.tensor( 73 | label, dtype=torch.long) 74 | -------------------------------------------------------------------------------- /barcodes/barcodes_classifier.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------------------------------------------------- 2 | # Copyright (c) 2021 Microsoft Corporation 3 | # 4 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and 5 | # associated documentation files (the "Software"), to deal in the Software without restriction, 6 | # including without limitation the rights to use, copy, modify, merge, publish, distribute, 7 | # sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is 8 | # furnished to do so, subject to the following conditions: 9 | # 10 | # The above copyright notice and this permission notice shall be included in all copies or 11 | # substantial portions of the Software. 12 | # 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT 14 | # NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND 15 | # NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, 16 | # DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 17 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
18 | # -------------------------------------------------------------------------------------------------- 19 | 20 | import torch 21 | from torch import nn 22 | from torchvision import models 23 | 24 | 25 | class BarcodesClassifier(nn.Module): 26 | def __init__(self, device, dropout=0.0, hidden=32): 27 | super(BarcodesClassifier, self).__init__() 28 | 29 | vgg16 = models.vgg16(pretrained=True) 30 | 31 | # freeze convolution weights 32 | for param in vgg16.features.parameters(): 33 | param.requires_grad = False 34 | 35 | # replace last 2 layers (dropout and fc) with a dropout layer and 1 or 36 | # 2 fc layers 37 | num_features = vgg16.classifier[6].in_features 38 | # Remove last 2 layers (dropout and fc) 39 | features = list(vgg16.classifier.children())[:-2] 40 | features.extend([nn.Dropout(dropout)]) 41 | if hidden is not None: 42 | features.extend([nn.Linear(num_features, hidden)]) 43 | num_features = hidden 44 | # Add our layer with 2 outputs 45 | features.extend([nn.Linear(num_features, 2)]) 46 | vgg16.classifier = nn.Sequential( 47 | *features) # Replace the model classifier 48 | 49 | self.model = vgg16 50 | self.model.to(device) 51 | 52 | def forward(self, x): 53 | x = self.model(x.to(torch.float)) 54 | return x 55 | 56 | def loss_function(self, x, y): 57 | return nn.CrossEntropyLoss()(x, y.long()) 58 | 59 | def correct_predictions(self, model_output, labels): 60 | _, predictions = torch.max(model_output.data, 1) 61 | return (predictions == labels).sum().item() 62 | -------------------------------------------------------------------------------- /barcodes/create_barcodes.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------------------------------------------------- 2 | # Copyright (c) 2021 Microsoft Corporation 3 | # 4 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and 5 | # associated documentation files (the "Software"), to deal in the Software without restriction, 6 | # including without limitation the rights to use, copy, modify, merge, publish, distribute, 7 | # sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is 8 | # furnished to do so, subject to the following conditions: 9 | # 10 | # The above copyright notice and this permission notice shall be included in all copies or 11 | # substantial portions of the Software. 12 | # 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT 14 | # NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND 15 | # NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, 16 | # DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 17 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
18 | # -------------------------------------------------------------------------------------------------- 19 | 20 | import os 21 | from PIL import Image 22 | import numpy as np 23 | import itertools 24 | import base64 25 | import io 26 | import json 27 | import argparse 28 | 29 | 30 | if __name__ == "__main__": 31 | parser = argparse.ArgumentParser( 32 | description='Takes Bleeding Edge JSON replays and creates "barcodes".', 33 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 34 | parser.add_argument( 35 | '--indir', 36 | type=str, 37 | help='Path to folders with the JSON to convert to barcodes.') 38 | parser.add_argument('--outdir', type=str, default="barcode_data", 39 | help='Output directory.') 40 | 41 | args = parser.parse_args() 42 | 43 | IN_DIR = args.indir 44 | OUT_DIR = args.outdir 45 | 46 | os.makedirs(OUT_DIR, exist_ok=True) 47 | list_in_dirs = os.listdir(IN_DIR) 48 | for f, filename in enumerate(list_in_dirs): 49 | file = os.path.join(IN_DIR, filename) 50 | if filename == 'sets.json': 51 | print(f"{f+1}/{len(list_in_dirs)}: Skipping sets.json") 52 | continue 53 | video = [] 54 | with open(file) as main_file: 55 | for line in itertools.islice(main_file, 0, None, 1): 56 | step = json.loads(line) 57 | key = list(step.keys())[0] 58 | encoded_img = step[key]["Observations"]["Players"][0]["Image"]["ImageBytes"] 59 | decoded_image_data = base64.decodebytes( 60 | encoded_img.encode('utf-8')) 61 | image = Image.open(io.BytesIO(decoded_image_data)) 62 | img = np.array(image) 63 | video.append(img) 64 | 65 | # compute barcodes 66 | videodata = np.array(video) 67 | size = videodata.shape 68 | barcode = np.zeros((size[0], size[2], 3)) 69 | 70 | for t in range(0, size[0]): 71 | frame = videodata[t, :] 72 | x_sum = np.sum(frame, axis=0) 73 | barcode[t] = x_sum / size[1] 74 | 75 | print( 76 | f"{f+1}/{len(list_in_dirs)}: Creating barcode with shape", 77 | barcode.shape) 78 | img = Image.fromarray(barcode.astype(np.uint8)) 79 | png_filename = os.path.splitext(filename)[0] + '.png' 80 | img.save(os.path.join(OUT_DIR, png_filename)) 81 | -------------------------------------------------------------------------------- /cross_validation.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------------------------------------------------- 2 | # Copyright (c) 2021 Microsoft Corporation 3 | # 4 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and 5 | # associated documentation files (the "Software"), to deal in the Software without restriction, 6 | # including without limitation the rights to use, copy, modify, merge, publish, distribute, 7 | # sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is 8 | # furnished to do so, subject to the following conditions: 9 | # 10 | # The above copyright notice and this permission notice shall be included in all copies or 11 | # substantial portions of the Software. 12 | # 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT 14 | # NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND 15 | # NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, 16 | # DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 17 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
18 | # -------------------------------------------------------------------------------------------------- 19 | 20 | # Training script for 5-fold cross-validation (human vs scripted) of 21 | # Bleeding Edge gameplay data 22 | 23 | from __future__ import print_function 24 | import json 25 | import argparse 26 | import os 27 | import time 28 | import numpy as np 29 | import torch 30 | from torch.utils.data import DataLoader, Subset 31 | from torch import optim 32 | 33 | from utils import get_model, get_dataset 34 | from train import training_loop 35 | 36 | # Record already loaded dataset to not reload again 37 | LOADED_DATASETS = dict() 38 | 39 | 40 | def validate( 41 | model_type, 42 | human_root_dirs, 43 | agent_root_dirs, 44 | train_subsets, 45 | val_subsets, 46 | log_dir, 47 | args, 48 | hp): 49 | device = torch.device("cuda" if args.cuda else "cpu") 50 | model_class = get_model(model_type) 51 | dataset_class = get_dataset(model_type) 52 | 53 | if hp["sequence_length"] in LOADED_DATASETS: 54 | dataset = LOADED_DATASETS[hp["sequence_length"]] 55 | else: 56 | print("Loading dataset...") 57 | dataset = dataset_class(human_dirs=human_root_dirs, 58 | agent_dirs=agent_root_dirs, 59 | seq_length=hp["sequence_length"]) 60 | LOADED_DATASETS[hp["sequence_length"]] = dataset 61 | 62 | validation_split = len(val_subsets) / \ 63 | (len(train_subsets) + len(val_subsets)) 64 | one_indices = [] 65 | zero_indices = [] 66 | # The built-in __next__ method does not work for our custom datasets 67 | # => enumerate() and other form of iterations do not work either (e.g. for d in dataset) 68 | for i in range(len(dataset)): 69 | if dataset[i][1] == 0: 70 | zero_indices.append(i) 71 | else: 72 | one_indices.append(i) 73 | fold_size_ones = int(validation_split * len(one_indices)) 74 | fold_size_zeros = int(validation_split * len(zero_indices)) 75 | assert len(val_subsets) == 1 76 | val_subset_num = val_subsets[0] 77 | val_indices = zero_indices[val_subset_num * fold_size_zeros: (val_subset_num + 1) * fold_size_zeros] + \ 78 | one_indices[val_subset_num * fold_size_ones: (val_subset_num + 1) * fold_size_ones] 79 | train_indices = [i for i in range(len(dataset)) if i not in val_indices] 80 | train_dataset = Subset(dataset, train_indices) 81 | val_dataset = Subset(dataset, val_indices) 82 | 83 | train_loader = DataLoader( 84 | train_dataset, 85 | batch_size=args.batch_size, 86 | shuffle=True, 87 | drop_last=False, 88 | num_workers=1) 89 | val_loader = DataLoader( 90 | val_dataset, 91 | batch_size=args.batch_size, 92 | shuffle=True, 93 | drop_last=False, 94 | num_workers=1) 95 | 96 | model = model_class(device, hp["dropout"], hp["hidden_size"]).to(device) 97 | optimizer = optim.Adam(model.parameters(), lr=hp["lr"]) 98 | run_name = f"validation-fold-{val_subsets[0]}" 99 | best_acc, test_acc = training_loop( 100 | log_dir, run_name, model, optimizer, train_loader, val_loader, args.log_interval, device, args.epochs) 101 | 102 | print(f"{run_name}: Final test accuracy {test_acc}, best accuracy {best_acc}") 103 | return best_acc 104 | 105 | 106 | def cross_validation(model_type, human_root_dirs, agent_root_dirs, args): 107 | log_dir = os.path.join( 108 | args.log_dir, 109 | f"{model_type}-crossval-{time.strftime('%Y%m%d-%H%M%S')}") 110 | total_subsets = [0, 1, 2, 3, 4] 111 | triedHP = set() 112 | with open(args.hp_info) as hp_file: 113 | hp_info = json.load(hp_file) 114 | allHP = hp_info["allHP"] 115 | defaultHP = hp_info["defaultHP"] 116 | hp_order = hp_info["hp_order"] 117 | 118 | totalHPcombinations = 0 119 | for v in 
allHP.values(): 120 | if len(v) > 1: 121 | totalHPcombinations += len(v) 122 | print("Total hyperparameter combinations:", totalHPcombinations) 123 | best_acc = 0 124 | best_hp = defaultHP.copy() 125 | 126 | for hp_key in hp_order: 127 | hp = best_hp.copy() 128 | for hp_value in allHP[hp_key]: 129 | hp[hp_key] = hp_value 130 | if tuple(hp.values()) in triedHP: 131 | # This combination of hyperparameters has already been tried 132 | continue 133 | print(f"Starting new cross validation with hyperparameters: {hp}") 134 | accs = [] 135 | sweep_log_dir = os.path.join( 136 | log_dir, '-'.join(map(str, list(hp.values())))) 137 | for s in total_subsets: 138 | print(f"Current cross validation {s+1}/{len(total_subsets)}") 139 | train_subsets = [i for i in total_subsets if i != s] 140 | acc = validate( 141 | model_type, 142 | human_root_dirs, 143 | agent_root_dirs, 144 | train_subsets, 145 | [s], 146 | sweep_log_dir, 147 | args, 148 | hp) 149 | accs.append(acc) 150 | mean_acc = sum(accs) / len(accs) 151 | print( 152 | "Mean best accuracy across all", 153 | len(total_subsets), 154 | "sweeps is", 155 | mean_acc) 156 | if mean_acc > best_acc: 157 | best_acc = mean_acc 158 | best_hp = hp.copy() 159 | print( 160 | "New best mean accuracy of", 161 | best_acc, 162 | "with hyperparameters", 163 | best_hp) 164 | triedHP.add(tuple(hp.values())) 165 | print( 166 | "Final best accuracy", 167 | best_acc, 168 | "with best hyperparameters", 169 | best_hp) 170 | 171 | 172 | if __name__ == "__main__": 173 | parser = argparse.ArgumentParser( 174 | description='5-fold cross-validation for human-agent discriminator on Bleeding Edge trajectories.', 175 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 176 | parser.add_argument('--batch-size', type=int, default=256, 177 | help='Input batch size for training.') 178 | parser.add_argument('--epochs', type=int, default=10, 179 | help='Number of epochs to train.') 180 | parser.add_argument('--no-cuda', action='store_true', default=False, 181 | help='Disables CUDA training.') 182 | parser.add_argument('--seed', type=int, default=1, 183 | help='Random seed.') 184 | parser.add_argument( 185 | '--log-interval', 186 | type=int, 187 | default=10000, 188 | help='Number of batches to wait for before logging training status.') 189 | parser.add_argument('--log-dir', type=str, default="logs", 190 | help='Path to save logs and models.') 191 | parser.add_argument('--model-type', type=str, default='symbolic', 192 | choices=["visuals", "symbolic", "topdown", "barcode"], 193 | help='Name of the classifier to train.') 194 | parser.add_argument( 195 | '--human-dirs', 196 | type=str, 197 | default=['data/ICML2021-train-data/human'], 198 | nargs='+', 199 | help='List of directories to human data.') 200 | parser.add_argument( 201 | '--agent-dirs', 202 | type=str, 203 | nargs='+', 204 | help='List of directories to agent data.', 205 | default=[ 206 | 'data/ICML2021-train-data/hybrid/checkpoint_12700', 207 | 'data/ICML2021-train-data/hybrid/checkpoint_11400', 208 | 'data/ICML2021-train-data/symbolic/checkpoint10900', 209 | 'data/ICML2021-train-data/symbolic/checkpoint11600']) 210 | parser.add_argument( 211 | '--hp-info', 212 | type=str, 213 | default="./hyperparameters.json", 214 | help="File path with all the hyperparameters info.") 215 | 216 | args = parser.parse_args() 217 | args.cuda = not args.no_cuda and torch.cuda.is_available() 218 | 219 | torch.manual_seed(args.seed) 220 | np.random.seed(args.seed) 221 | 222 | cross_validation(args.model_type, args.human_dirs, args.agent_dirs, args) 
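# Example hyperparameters.json consumed by cross_validation() above (these are the
# values shipped in this repository's hyperparameters.json; edit them to define your
# own sweep):
#
#   {
#       "allHP":     {"sequence_length": [5, 10, 20], "dropout": [0, 0.50, 0.85],
#                     "hidden_size": [16, 32], "lr": [0.001]},
#       "defaultHP": {"sequence_length": 20, "dropout": 0.85, "hidden_size": 16, "lr": 0.001},
#       "hp_order":  ["sequence_length", "dropout", "hidden_size", "lr"]
#   }
#
# Note that the sweep is a greedy, coordinate-wise search rather than a full grid
# search: keys are visited in "hp_order", each candidate value from "allHP" is tried
# while the remaining hyperparameters are held at the best configuration found so far,
# and a value is adopted only if it improves the mean 5-fold validation accuracy. The
# "Total hyperparameter combinations" printout therefore counts the candidate values
# to be tried, not the size of the full Cartesian grid.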
223 | -------------------------------------------------------------------------------- /evaluate_ANTT_model.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------------------------------------------------- 2 | # Copyright (c) 2021 Microsoft Corporation 3 | # 4 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and 5 | # associated documentation files (the "Software"), to deal in the Software without restriction, 6 | # including without limitation the rights to use, copy, modify, merge, publish, distribute, 7 | # sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is 8 | # furnished to do so, subject to the following conditions: 9 | # 10 | # The above copyright notice and this permission notice shall be included in all copies or 11 | # substantial portions of the Software. 12 | # 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT 14 | # NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND 15 | # NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, 16 | # DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 17 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 18 | # -------------------------------------------------------------------------------------------------- 19 | 20 | # Script to evaluate trained ANTT model on video pairs used in HNTT behavioural study 21 | # Generates data for reproducing Table 2 in the appendix 22 | 23 | import argparse 24 | import os 25 | import torch 26 | import numpy as np 27 | import pandas as pd 28 | import pickle 29 | from sklearn.metrics import accuracy_score 30 | from scipy.stats import spearmanr 31 | 32 | parser = argparse.ArgumentParser( 33 | description='Script to evaluate a trained ANTT model', 34 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 35 | parser.add_argument( 36 | '--path-to-eval-data', 37 | type=str, 38 | default='./data/ICML2021-eval-data', 39 | metavar='STR', 40 | help='Path to folder of trajectories in format required by ANTT models') 41 | parser.add_argument( 42 | '--path-to-models', 43 | type=str, 44 | default='./data/ICML2021-trained-models/SYM-FF', 45 | metavar='STR', 46 | help='Path to folder of trained models (.pt) or saved model outputs (.pkl)') 47 | parser.add_argument( 48 | '--model-type', 49 | choices=[ 50 | 'BARCODE', 51 | 'CNN', 52 | 'SYMBOLIC', 53 | 'TOPDOWN'], 54 | default='SYMBOLIC', 55 | help='Type of model to be evaluated') 56 | parser.add_argument( 57 | '--subsequence-length', 58 | type=int, 59 | default=1, 60 | metavar='INT', 61 | help='length of subsequence input to recurrent CNN or SYMBOLIC models') 62 | parser.add_argument('--load-model-output', action='store_true', default=False, 63 | help='Load saved model output') 64 | args = parser.parse_args() 65 | 66 | if args.model_type == "BARCODE": 67 | from barcodes.barcodes_classifier import BarcodesClassifier as modelClass 68 | from PIL import Image 69 | elif args.model_type == "CNN": 70 | from visuals.visuals_classifier import VisualsClassifier as modelClass 71 | import base64 72 | import json 73 | import io 74 | import itertools 75 | from PIL import Image 76 | elif args.model_type == "SYMBOLIC": 77 | from symbolic.symbolic_classifier import SymbolicClassifier as modelClass 78 | 
from symbolic.symbolic_dataset import read_trajectories 79 | elif args.model_type == "TOPDOWN": 80 | from topdown.topdown_classifier import TopdownClassifier as modelClass 81 | import torchvision 82 | 83 | # Each sublist is a pair of trajectories shown to study participants 84 | user_study_1_human_hybrid = [["___ReplayDebug-Map_Rooftops_Seeds_Main-2021.01.15-12.17.36", 85 | "___ReplayDebug-Map_Rooftops_Seeds_Main-2020.12.16-16.08.06"], 86 | ["___ReplayDebug-Map_Rooftops_Seeds_Main-2020.12.15-18.25.22", 87 | "___ReplayDebug-Map_Rooftops_Seeds_Main-2021.01.15-12.23.26"], 88 | ["___ReplayDebug-Map_Rooftops_Seeds_Main-2021.01.15-15.30.37", 89 | "___ReplayDebug-Map_Rooftops_Seeds_Main-2020.12.17-11.40.11"], 90 | ["___ReplayDebug-Map_Rooftops_Seeds_Main-2020.12.15-18.14.12", 91 | "___ReplayDebug-Map_Rooftops_Seeds_Main-2021.01.15-15.35.22"], 92 | ["___ReplayDebug-Map_Rooftops_Seeds_Main-2020.12.16-15.57.17", 93 | "___ReplayDebug-Map_Rooftops_Seeds_Main-2021.01.15-14.37.26"], 94 | ["___ReplayDebug-Map_Rooftops_Seeds_Main-2021.01.15-15.24.55", 95 | "___ReplayDebug-Map_Rooftops_Seeds_Main-2020.12.16-16.03.38"]] 96 | 97 | # Labels 1 if 2nd video is human, 0 if human is the 1st video 98 | user_study_1_human_hybrid_labels = np.array([1.0, 0.0, 1.0, 0.0, 0.0, 1.0]) 99 | 100 | # Each sublist is a pair of trajectories shown to study participants 101 | user_study_1_symbolic_hybrid = [["___ReplayDebug-Map_Rooftops_Seeds_Main-2021.01.11-18.11.48", 102 | "___ReplayDebug-Map_Rooftops_Seeds_Main-2021.01.15-15.22.10"], 103 | ["___ReplayDebug-Map_Rooftops_Seeds_Main-2021.01.11-18.16.45", 104 | "___ReplayDebug-Map_Rooftops_Seeds_Main-2021.01.15-15.23.34"], 105 | ["___ReplayDebug-Map_Rooftops_Seeds_Main-2021.01.11-18.14.45", 106 | "___ReplayDebug-Map_Rooftops_Seeds_Main-2021.01.15-15.23.57"], 107 | ["___ReplayDebug-Map_Rooftops_Seeds_Main-2021.01.15-14.34.52", 108 | "___ReplayDebug-Map_Rooftops_Seeds_Main-2021.01.11-18.09.52"]] 109 | 110 | # Each sublist is a pair of trajectories shown to study participants 111 | user_study_2_human_symbolic = [["___ReplayDebug-Map_Rooftops_Seeds_Main-2021.01.11-18.13.07", 112 | "___ReplayDebug-Map_Rooftops_Seeds_Main-2020.12.15-18.23.57"], 113 | ["___ReplayDebug-Map_Rooftops_Seeds_Main-2020.12.17-11.33.59", 114 | "___ReplayDebug-Map_Rooftops_Seeds_Main-2021.01.11-18.26.15"], 115 | ["___ReplayDebug-Map_Rooftops_Seeds_Main-2021.01.11-18.25.18", 116 | "___ReplayDebug-Map_Rooftops_Seeds_Main-2020.12.15-18.21.30"], 117 | ["___ReplayDebug-Map_Rooftops_Seeds_Main-2020.12.17-11.41.46", 118 | "___ReplayDebug-Map_Rooftops_Seeds_Main-2021.01.11-18.16.22"], 119 | ["___ReplayDebug-Map_Rooftops_Seeds_Main-2020.12.16-16.10.12", 120 | "___ReplayDebug-Map_Rooftops_Seeds_Main-2021.01.11-16.54.44"], 121 | ["___ReplayDebug-Map_Rooftops_Seeds_Main-2021.01.11-18.25.42", 122 | "___ReplayDebug-Map_Rooftops_Seeds_Main-2020.12.17-11.38.34"]] 123 | 124 | # Labels 1 if 2nd video is human, 0 if human is the 1st video 125 | user_study_2_human_symbolic_labels = np.array([1.0, 0.0, 1.0, 0.0, 0.0, 1.0]) 126 | 127 | # Each sublist is a pair of trajectories shown to study participants 128 | user_study_2_symbolic_hybrid = [["___ReplayDebug-Map_Rooftops_Seeds_Main-2021.01.11-18.23.05", 129 | "___ReplayDebug-Map_Rooftops_Seeds_Main-2021.01.15-15.34.59"], 130 | ["___ReplayDebug-Map_Rooftops_Seeds_Main-2021.01.11-18.12.27", 131 | "___ReplayDebug-Map_Rooftops_Seeds_Main-2021.01.15-13.34.24"], 132 | ["___ReplayDebug-Map_Rooftops_Seeds_Main-2021.01.11-18.17.19", 133 | 
"___ReplayDebug-Map_Rooftops_Seeds_Main-2021.01.15-13.34.55"], 134 | ["___ReplayDebug-Map_Rooftops_Seeds_Main-2021.01.15-15.34.05", 135 | "___ReplayDebug-Map_Rooftops_Seeds_Main-2021.01.11-18.17.50"]] 136 | 137 | # Labels 1 if 2nd video is human for all human vs agent comparisons ([0:6] & [10:16]) 138 | # Or if 2nd video is hybrid for hybrid vs symbolic comparisons ([6:10] & [16:20]) 139 | all_study_labels = np.array([1.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 140 | 1.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0]) 141 | 142 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 143 | 144 | 145 | def loadHumanResponses(): 146 | max_vote_user_response = np.array([]) 147 | percentage_user_response = np.array([]) 148 | 149 | df = pd.read_csv('./data/HNTT_data.csv', header=0) 150 | for study in [1, 2]: 151 | study_df = df[df.studyno == study] 152 | for question in range(1, 11): 153 | question_df = study_df[study_df.question_id == question] 154 | 155 | response_counts = {} 156 | value_counts = question_df['subj_resp'].value_counts() 157 | response_counts["A"] = value_counts.get("A", 0) 158 | response_counts["B"] = value_counts.get("B", 0) 159 | 160 | # User study participants vote in favour of left video 161 | if response_counts["A"] > response_counts["B"]: 162 | max_vote_user_response = np.append(max_vote_user_response, 0.0) 163 | elif response_counts["A"] == response_counts["B"]: 164 | # Break ties randomly 165 | max_vote_user_response = np.append( 166 | max_vote_user_response, np.random.randint(2)) 167 | else: 168 | max_vote_user_response = np.append(max_vote_user_response, 1.0) 169 | 170 | # Human or hybrid is on right so get percentage that agree with 171 | # this 172 | if all_study_labels[question - 1] == 1.0: 173 | percentage_user_response = np.append( 174 | percentage_user_response, 175 | response_counts["B"] / (response_counts["A"] + response_counts["B"])) 176 | else: # Human or hybrid is on left so get percentage that agree with this 177 | percentage_user_response = np.append( 178 | percentage_user_response, 179 | response_counts["A"] / (response_counts["A"] + response_counts["B"])) 180 | 181 | print("Max Vote User Responses: {}".format(max_vote_user_response)) 182 | print("Percentage User Responses: {}".format(percentage_user_response)) 183 | 184 | print("Max Vote Human Response Accuracy On All: {}".format( 185 | accuracy_score(max_vote_user_response, all_study_labels))) 186 | print("Max Vote Human Response Accuracy In Human-Agent Q's: {}".format( 187 | accuracy_score(np.append(max_vote_user_response[0:6], max_vote_user_response[10:16]), 188 | np.append(all_study_labels[0:6], all_study_labels[10:16])))) 189 | print("Max Vote Human Response Accuracy Picking Hybrid Agent In Hybrid-Symbolic Agent Q's: {}".format( 190 | accuracy_score(np.append(max_vote_user_response[6:10], max_vote_user_response[16:20]), 191 | np.append(all_study_labels[6:10], all_study_labels[16:20])))) 192 | print("------------------------------------------------------------") 193 | 194 | return max_vote_user_response, percentage_user_response 195 | 196 | 197 | if __name__ == "__main__": 198 | print("LOADING HUMAN USER STUDY RESPONSES TO COMPARE MODEL OUTPUT AGAINST") 199 | max_vote_user_response, percentage_user_response = loadHumanResponses() 200 | 201 | # Initialise lists to store stats for every model in directory 202 | ground_truth_accuracy_list = [] 203 | human_agent_userlabel_accuracy_list = [] 204 | hybrid_symbolic_userlabel_accuracy_list = [] 205 | spearman_rank_human_agent = [] 206 | 
spearman_rank_hybrid_symbolic = [] 207 | 208 | # Loop over all trained models in directory 209 | for filename in os.listdir(args.path_to_models): 210 | if not filename.endswith(".pt"): 211 | continue 212 | PATH_TO_MODEL = os.path.join(args.path_to_models, filename) 213 | PATH_TO_MODEL_OUTPUT = os.path.join(args.path_to_models, filename[:-3] + "-model_output.pkl") 214 | 215 | if args.load_model_output: 216 | print("LOADING SAVED OUTPUT FOR MODEL: {}".format(PATH_TO_MODEL)) 217 | print("FROM: {}".format(PATH_TO_MODEL_OUTPUT)) 218 | model_output_dict = pickle.load(open(PATH_TO_MODEL_OUTPUT, "rb")) 219 | else: 220 | print("LOADING TRAINED MODEL: {}".format(PATH_TO_MODEL)) 221 | model = modelClass(device).to(device) 222 | model.load_state_dict(torch.load(PATH_TO_MODEL, 223 | map_location=device)) 224 | model.eval() # Do not update params of model 225 | # Create empty dictionary to fill then save 226 | model_output_dict = {} 227 | 228 | # For every pair of trajectories shown to human participants predict most human-like trajectory 229 | # For models that classify only one trajectory, classify both trajectories separately 230 | # then pick the one given highest probability of being human 231 | model_predictions = np.array([]) 232 | percentage_model = np.array([]) 233 | for j, traj_pair in enumerate(user_study_1_human_hybrid + 234 | user_study_1_symbolic_hybrid + 235 | user_study_2_human_symbolic + 236 | user_study_2_symbolic_hybrid): 237 | percentage_humanlike = [] 238 | for traj in traj_pair: 239 | if not args.load_model_output: 240 | model_output_dict[traj] = [] 241 | if args.model_type == "BARCODE": 242 | # load the barcode corresponding to this trajectory 243 | in_barcode = os.path.join( 244 | args.path_to_eval_data, "barcodes", traj + 'Trajectories.png') 245 | 246 | img = Image.open(in_barcode) 247 | img = np.array(img) / 255 248 | img = np.transpose(img, (2, 0, 1)) 249 | print("barcode trajectory shape:", img.shape) 250 | 251 | with torch.no_grad(): 252 | human_count = 0 253 | agent_count = 0 254 | 255 | # sample four random 320x200 windows from the barcode 256 | for i in range(0, 4): 257 | if img.shape[1] - 200 < 0: 258 | start_y = 0 259 | else: 260 | start_y = np.random.randint( 261 | 0, img.shape[1] - 200) 262 | cut_barcode = img[:, start_y:start_y + 200, :] 263 | y_shape = cut_barcode.shape[1] 264 | if y_shape < 200: 265 | cut_barcode = np.pad( 266 | cut_barcode, ((0, 0), (0, 200 - y_shape), (0, 0)), mode='edge') 267 | 268 | input_bc = torch.Tensor(cut_barcode) 269 | input_bc = torch.unsqueeze(input_bc, 0).to(device) 270 | 271 | if args.load_model_output: 272 | model_output = model_output_dict[traj][0] 273 | else: 274 | model_output = model(input_bc) 275 | model_output_dict[traj].append(model_output) 276 | 277 | _, prediction = torch.max(model_output.data, 1) 278 | 279 | if prediction == 1: 280 | human_count += 1 281 | else: 282 | agent_count += 1 283 | 284 | percentage_humanlike.append( 285 | human_count / (human_count + agent_count)) 286 | 287 | elif args.model_type == "TOPDOWN": 288 | transform = torchvision.transforms.Compose( 289 | [torchvision.transforms.Resize((512, 512)), 290 | torchvision.transforms.ToTensor(), 291 | torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) 292 | 293 | PATH_TO_IMAGE_FOLDER = os.path.join( 294 | args.path_to_eval_data, 'topdown_320x200', traj + 'Trajectories.json/') 295 | data = torchvision.datasets.ImageFolder( 296 | root=PATH_TO_IMAGE_FOLDER, transform=transform) 297 | if args.load_model_output: 298 | model_output = 
model_output_dict[traj][0] 299 | else: 300 | model_output = model(torch.unsqueeze(data[0][0], 0)) 301 | model_output_dict[traj].append(model_output) 302 | 303 | # model_output is in the range [-1,1], so normalise to 304 | # [0,1] like all other models 305 | normalised_model_output = ( 306 | model_output.data[0][1].item() + 1) / 2 307 | percentage_humanlike.append(normalised_model_output) 308 | 309 | elif args.model_type == "CNN": 310 | PATH_TO_TRAJECTORY = os.path.join( 311 | args.path_to_eval_data, 312 | 'study_videos_cut_jpg', 313 | traj + 'Trajectories.json') 314 | with open(PATH_TO_TRAJECTORY) as main_file: 315 | video = [] 316 | for line in itertools.islice(main_file, 0, None, 10): 317 | step = json.loads(line) 318 | key = list(step.keys())[0] 319 | 320 | encoded_img = step[key]["Observations"]["Players"][0]["Image"]["ImageBytes"] 321 | decoded_image_data = base64.decodebytes( 322 | encoded_img.encode('utf-8')) 323 | image = Image.open(io.BytesIO(decoded_image_data)) 324 | img = np.array(image) 325 | video.append(img) 326 | 327 | videodata = np.array(video) / 255 328 | videodata = np.transpose(videodata, (0, 3, 1, 2)) 329 | print("video trajectory shape:", videodata.shape) 330 | 331 | with torch.no_grad(): 332 | human_count = 0 333 | agent_count = 0 334 | number_sequences = len( 335 | video) // args.subsequence_length 336 | for i in range(number_sequences): 337 | sequence_start_idx = i * args.subsequence_length 338 | 339 | input_seq = torch.Tensor( 340 | videodata[sequence_start_idx:sequence_start_idx + args.subsequence_length, :]) 341 | input_seq = torch.unsqueeze( 342 | input_seq, 0).to(device) 343 | 344 | if args.load_model_output: 345 | model_output = model_output_dict[traj][i] 346 | else: 347 | model_output = model(input_seq) 348 | model_output_dict[traj].append( 349 | model_output) 350 | 351 | _, prediction = torch.max(model_output.data, 1) 352 | 353 | if prediction == 1: 354 | human_count += 1 355 | else: 356 | agent_count += 1 357 | 358 | percentage_humanlike.append( 359 | human_count / (human_count + agent_count)) 360 | 361 | elif args.model_type == "SYMBOLIC": 362 | PATH_TO_TRAJECTORY = os.path.join( 363 | args.path_to_eval_data, 364 | 'study_videos_cut_jpg', 365 | traj + 'Trajectories.json') 366 | traj_data = read_trajectories(PATH_TO_TRAJECTORY, -1)[0][0] 367 | 368 | with torch.no_grad(): 369 | human_count = 0 370 | agent_count = 0 371 | number_sequences = len( 372 | traj_data["obs"]) // args.subsequence_length 373 | for i in range(number_sequences): 374 | sequence_start_idx = i * args.subsequence_length 375 | sample_trajectory = traj_data["obs"][sequence_start_idx: 376 | sequence_start_idx + args.subsequence_length] 377 | 378 | if args.load_model_output: 379 | model_output = model_output_dict[traj][i] 380 | else: 381 | model_output = model( 382 | torch.tensor([sample_trajectory])) 383 | model_output_dict[traj].append(model_output) 384 | 385 | if round(model_output.item()) == 1: 386 | human_count += 1 387 | else: 388 | agent_count += 1 389 | 390 | percentage_humanlike.append( 391 | human_count / (human_count + agent_count)) 392 | else: 393 | raise NotImplementedError( 394 | "Model type " + args.model_type + " evaluation not implemented") 395 | 396 | # Model votes left video is more humanlike 397 | if percentage_humanlike[0] > percentage_humanlike[1]: 398 | model_predictions = np.append(model_predictions, 0.0) 399 | elif percentage_humanlike[0] == percentage_humanlike[1]: 400 | # Break ties randomly 401 | model_predictions = np.append( 402 | model_predictions, 
np.random.randint(2)) 403 | else: # Model votes right video is more humanlike 404 | model_predictions = np.append(model_predictions, 1.0) 405 | 406 | # Human or hybrid is on right so get percentage that agree with 407 | # this 408 | if all_study_labels[j] == 1.0: 409 | percentage_model = np.append( 410 | percentage_model, percentage_humanlike[1]) 411 | else: # Human or hybrid is on left so get percentage that agree with this 412 | percentage_model = np.append( 413 | percentage_model, percentage_humanlike[0]) 414 | 415 | # Save model output to enable faster stats re-running 416 | pickle.dump(model_output_dict, open(PATH_TO_MODEL_OUTPUT, "wb")) 417 | print("Ground Truth Labels: {}".format(all_study_labels)) 418 | print("Model Predictions: {}".format(model_predictions)) 419 | 420 | # Calculate model accuracy on held-out test dataset compared to ground 421 | # truth label (only on human vs agent examples) 422 | ground_truth_accuracy = accuracy_score(np.append(user_study_1_human_hybrid_labels, user_study_2_human_symbolic_labels), np.append( 423 | model_predictions[0:6], model_predictions[10:16])) # 1st 6 questions in both studies are human vs agent 424 | print('Per Trajectory Model Accuracy With Ground Truth Labels: {:.4f}'.format( 425 | ground_truth_accuracy)) 426 | ground_truth_accuracy_list.append(ground_truth_accuracy) 427 | 428 | model_accuracy_userlabels_human_agent = accuracy_score(np.append( 429 | max_vote_user_response[0:6], max_vote_user_response[10:16]), np.append(model_predictions[0:6], model_predictions[10:16])) 430 | print('Model Accuracy on Human-Agent Comparisons With Max Vote User Study Responses As Labels: {:.4f}'.format( 431 | model_accuracy_userlabels_human_agent)) 432 | human_agent_userlabel_accuracy_list.append( 433 | model_accuracy_userlabels_human_agent) 434 | 435 | # Spearman rank correlation of model predictions to percentage user 436 | # ranking 437 | print(percentage_user_response[0:6]) 438 | print(percentage_user_response[10:16]) 439 | coef, p = spearmanr(np.append(percentage_user_response[0:6], percentage_user_response[10:16]), 440 | np.append(percentage_model[0:6], percentage_model[10:16])) 441 | print( 442 | 'Spearmans correlation coefficient of all human vs agent comparisons: {} (p={})'.format( 443 | coef, 444 | p)) 445 | if not np.isnan(coef): 446 | spearman_rank_human_agent.append(coef) 447 | 448 | model_accuracy_userlabels_hybrid_symbolic = accuracy_score(np.append( 449 | max_vote_user_response[6:10], max_vote_user_response[16:20]), np.append(model_predictions[6:10], model_predictions[16:20])) 450 | print('Model Accuracy on Hybrid-Symbolic Agent Comparisons With Max Vote User Study Responses As Labels: {:.4f}'.format( 451 | model_accuracy_userlabels_hybrid_symbolic)) 452 | hybrid_symbolic_userlabel_accuracy_list.append( 453 | model_accuracy_userlabels_hybrid_symbolic) 454 | 455 | coef, p = spearmanr(np.append(percentage_user_response[6:10], percentage_user_response[16:20]), 456 | np.append(percentage_model[6:10], percentage_model[16:20])) 457 | print( 458 | 'Spearmans correlation coefficient of all hybrid vs symbolic agent comparisons: {} (p={})'.format( 459 | coef, 460 | p)) 461 | if not np.isnan(coef): 462 | spearman_rank_hybrid_symbolic.append(coef) 463 | print("------------------------------------------------------------") 464 | 465 | print("Results Summary From All Models in: {}".format(args.path_to_models)) 466 | print( 467 | "Model Ground Truth Accuracy: Mean {} - STD {}".format( 468 | np.array(ground_truth_accuracy_list).mean(), 469 | 
np.array(ground_truth_accuracy_list).std())) 470 | 471 | print("Model Accuracy on Human-Agent Comparisons With Max Vote User Study Responses As Labels: Mean {} - STD {}".format( 472 | np.array(human_agent_userlabel_accuracy_list).mean(), np.array(human_agent_userlabel_accuracy_list).std())) 473 | 474 | print("Spearman Rank Correlation Coefficient on Human vs Agent Rankings: Mean {} - STD {}".format( 475 | np.array(spearman_rank_human_agent).mean(), np.array(spearman_rank_human_agent).std())) 476 | 477 | print("Model Accuracy on Hybrid-Symbolic Agent Comparisons With Max Vote User Study Responses As Labels: Mean {} - STD {}".format( 478 | np.array(hybrid_symbolic_userlabel_accuracy_list).mean(), np.array(hybrid_symbolic_userlabel_accuracy_list).std())) 479 | 480 | print("Spearman Rank Correlation Coefficient on Hybrid vs Symbolic Agent Rankings: Mean {} - STD {}".format( 481 | np.array(spearman_rank_hybrid_symbolic).mean(), np.array(spearman_rank_hybrid_symbolic).std())) 482 | -------------------------------------------------------------------------------- /hyperparameters.json: -------------------------------------------------------------------------------- 1 | { 2 | "allHP": { 3 | "sequence_length": [5, 10, 20], 4 | "dropout": [0, 0.50, 0.85], 5 | "hidden_size": [16, 32], 6 | "lr": [0.001] 7 | }, 8 | "defaultHP": { 9 | "sequence_length": 20, 10 | "dropout": 0.85, 11 | "hidden_size": 16, 12 | "lr": 0.001 13 | }, 14 | "hp_order": ["sequence_length", "dropout", "hidden_size", "lr"] 15 | } -------------------------------------------------------------------------------- /plot_ANTT_evaluation.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------------------------------------------------- 2 | # Copyright (c) 2021 Microsoft Corporation 3 | # 4 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and 5 | # associated documentation files (the "Software"), to deal in the Software without restriction, 6 | # including without limitation the rights to use, copy, modify, merge, publish, distribute, 7 | # sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is 8 | # furnished to do so, subject to the following conditions: 9 | # 10 | # The above copyright notice and this permission notice shall be included in all copies or 11 | # substantial portions of the Software. 12 | # 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT 14 | # NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND 15 | # NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, 16 | # DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 17 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
18 | # -------------------------------------------------------------------------------------------------- 19 | 20 | # Script to reproduce Figures 9 and 10 in Section 5.2 21 | # Provided data from table 2 in the appendix 22 | # or generated by evaluate_ANTT_model.py 23 | 24 | import pandas as pd 25 | import seaborn as sns 26 | import matplotlib.pyplot as plt 27 | import numpy as np 28 | 29 | # Data from table 2 in the appendix 30 | df = pd.DataFrame({ 31 | 'ANTT Model': ['SYM-FF', 'SYM-GRU', 'VIS-FF', 'VIS-GRU', 'TD-CNN', 'BC-CNN', 32 | 'SYM-FF', 'SYM-GRU', 'VIS-FF', 'VIS-GRU', 'TD-CNN', 'BC-CNN', 33 | 'SYM-FF', 'SYM-GRU', 'VIS-FF', 'VIS-GRU', 'TD-CNN', 'BC-CNN'], 34 | 'Comparison': ["Identity Accuracy", "Identity Accuracy", "Identity Accuracy", 35 | "Identity Accuracy", "Identity Accuracy", "Identity Accuracy", 36 | "Human-Agent", "Human-Agent", "Human-Agent", 37 | "Human-Agent", "Human-Agent", "Human-Agent", 38 | "Hybrid-Symbolic", "Hybrid-Symbolic", "Hybrid-Symbolic", 39 | "Hybrid-Symbolic", "Hybrid-Symbolic", "Hybrid-Symbolic"], 40 | 'Accuracy': [0.85, 0.85, 0.633, 0.767, 0.583, 0.717, 41 | 0.85, 0.85, 0.633, 0.767, 0.583, 0.717, 42 | 0.475, 0.400, 0.225, 0.425, 0.525, 0.475], 43 | 'std': [0.062, 0.082, 0.041, 0.097, 0.075, 0.145, 44 | 0.062, 0.082, 0.041, 0.097, 0.075, 0.145, 45 | 0.166, 0.200, 0.050, 0.127, 0.094, 0.050]}) 46 | 47 | # Bootstrap observations to get std bars 48 | dfCopy = df.copy() 49 | duplicates = 3000 # increase this number to increase precision 50 | for _, row in df.iterrows(): 51 | for times in range(duplicates): 52 | new_row = row.copy() 53 | new_row['Accuracy'] = np.random.normal(row['Accuracy'], row['std']) 54 | dfCopy = dfCopy.append(new_row, ignore_index=True) 55 | 56 | sns.catplot( 57 | x="Comparison", 58 | y="Accuracy", 59 | hue="ANTT Model", 60 | kind="bar", 61 | data=dfCopy) 62 | 63 | # Data from table 2 in the appendix 64 | rankdf = pd.DataFrame({ 65 | 'ANTT Model': ['SYM-FF', 'SYM-GRU', 'VIS-FF', 'VIS-GRU', 'TD-CNN', 'BC-CNN', 66 | 'SYM-FF', 'SYM-GRU', 'VIS-FF', 'VIS-GRU', 'TD-CNN', 'BC-CNN'], 67 | 'Comparison': ["Human-Agent", "Human-Agent", "Human-Agent", 68 | "Human-Agent", "Human-Agent", "Human-Agent", 69 | "Hybrid-Symbolic", "Hybrid-Symbolic", "Hybrid-Symbolic", 70 | "Hybrid-Symbolic", "Hybrid-Symbolic", "Hybrid-Symbolic"], 71 | 'Spearman Rank Correlation': [0.364, 0.173, -0.041, 0.220, 0.222, -0.009, 72 | -0.244, -0.249, -0.165, -0.056, -0.093, -0.095], 73 | 'std': [0.043, 0.049, 0.160, 0.267, 0.059, 0.131, 74 | 0.252, 0.210, 0.286, 0.331, 0.149, 0.412]}) 75 | 76 | # Bootstrap observations to get std bars 77 | rankdfCopy = rankdf.copy() 78 | duplicates = 3000 # increase this number to increase precision 79 | for index, row in rankdf.iterrows(): 80 | for times in range(duplicates): 81 | new_row = row.copy() 82 | new_row['Spearman Rank Correlation'] = np.random.normal( 83 | row['Spearman Rank Correlation'], row['std']) 84 | rankdfCopy = rankdfCopy.append(new_row, ignore_index=True) 85 | 86 | sns.catplot( 87 | x="Comparison", 88 | y="Spearman Rank Correlation", 89 | hue="ANTT Model", 90 | kind="bar", 91 | data=rankdfCopy) 92 | 93 | plt.show() 94 | -------------------------------------------------------------------------------- /plot_ANTT_training.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------------------------------------------------- 2 | # Copyright (c) 2021 Microsoft Corporation 3 | # 4 | # Permission is hereby granted, free of charge, to any person 
obtaining a copy of this software and 5 | # associated documentation files (the "Software"), to deal in the Software without restriction, 6 | # including without limitation the rights to use, copy, modify, merge, publish, distribute, 7 | # sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is 8 | # furnished to do so, subject to the following conditions: 9 | # 10 | # The above copyright notice and this permission notice shall be included in all copies or 11 | # substantial portions of the Software. 12 | # 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT 14 | # NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND 15 | # NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, 16 | # DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 17 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 18 | # -------------------------------------------------------------------------------------------------- 19 | 20 | # Script to reproduce Figure 2 in Section 3.3: learning curves of ANTT models 21 | 22 | import os 23 | import matplotlib.pyplot as plt 24 | import pandas as pd 25 | import seaborn as sns 26 | 27 | 28 | def main(): 29 | path = os.path.abspath('.') # Run from pwd or specify a path here 30 | print("Plotting data from ", path) 31 | 32 | # list all subfolders of the folder - each subfolder is considered an experiment and subfolders 33 | # within that subfolder are separate runs of that experiment 34 | list_subfolders_with_paths = [ 35 | f.path for f in os.scandir(path) if f.is_dir()] 36 | print("Found following experiments: ", list_subfolders_with_paths) 37 | 38 | experiment_names = [] 39 | colours = ['red', 'green', 'blue', 'orange', 'pink', 'yellow', 'black'] 40 | fig, axes = plt.subplots(2, 2, figsize=(20, 10), sharey=False) 41 | for experiment, color in zip(list_subfolders_with_paths, colours): 42 | print('{} = {}'.format(color, experiment)) 43 | run_cvss = [f.path for f in os.scandir(experiment)] 44 | experiment_name = os.path.basename(os.path.normpath(experiment)) 45 | experiment_names.append(experiment_name) 46 | 47 | run_dfs = [] 48 | for run in run_cvss: 49 | run_data_frame = pd.read_csv(run) 50 | run_dfs.append(run_data_frame) 51 | 52 | experiment_df = pd.concat(run_dfs) 53 | sns.lineplot( 54 | ax=axes[0][0], 55 | data=experiment_df, 56 | x='epoch', 57 | y='train_loss', 58 | ci='sd', 59 | legend='brief', 60 | label=experiment_name) 61 | sns.lineplot( 62 | ax=axes[0][1], 63 | data=experiment_df, 64 | x='epoch', 65 | y='train_acc', 66 | ci='sd', 67 | legend='brief', 68 | label=experiment_name) 69 | sns.lineplot( 70 | ax=axes[1][0], 71 | data=experiment_df, 72 | x='epoch', 73 | y='val_loss', 74 | ci='sd', 75 | legend='brief', 76 | label=experiment_name) 77 | sns.lineplot( 78 | ax=axes[1][1], 79 | data=experiment_df, 80 | x='epoch', 81 | y='val_acc', 82 | ci='sd', 83 | legend='brief', 84 | label=experiment_name) 85 | 86 | plt.show() 87 | 88 | 89 | if __name__ == '__main__': 90 | main() 91 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | jupyter 2 | pandas==1.1.5 3 | pygments==2.9.0 4 | rsa==4.7 5 | scikit-learn==0.24.0 6 | seaborn==0.11.1 7 | tensorboard==2.5.0 8 | torch==1.7.1 9 | torchvision==0.8.2 10 | 
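A note on plotting the training logs: plot_ANTT_training.py above treats every immediate subfolder of the directory it is run from as one experiment (up to seven, the length of its colours list) and every file inside that subfolder as a per-run CSV with the columns written by train.py (epoch, train_loss, train_acc, val_loss, val_acc). The layout below is purely illustrative; the folder and file names are hypothetical and are not created by the scripts themselves:

    learning_curves/
        symbolic-gru/            <- one experiment, one colour in the plots
            20210101-120000.csv  <- a csv log produced by train.py
            20210102-093000.csv
        visuals-gru/
            20210103-154500.csv

Running `python plot_ANTT_training.py` from inside learning_curves/ (or setting `path` in its main()) then produces the 2x2 grid of training/validation loss and accuracy curves.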
-------------------------------------------------------------------------------- /symbolic/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/NTT/ab73f9d0945670054863d53163c65addb0aa2700/symbolic/__init__.py -------------------------------------------------------------------------------- /symbolic/symbolic_classifier.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------------------------------------------------- 2 | # Copyright (c) 2021 Microsoft Corporation 3 | # 4 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and 5 | # associated documentation files (the "Software"), to deal in the Software without restriction, 6 | # including without limitation the rights to use, copy, modify, merge, publish, distribute, 7 | # sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is 8 | # furnished to do so, subject to the following conditions: 9 | # 10 | # The above copyright notice and this permission notice shall be included in all copies or 11 | # substantial portions of the Software. 12 | # 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT 14 | # NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND 15 | # NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, 16 | # DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 17 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 18 | # -------------------------------------------------------------------------------------------------- 19 | 20 | import torch 21 | import torch.utils.data 22 | from torch import nn 23 | from torch.nn import functional as F 24 | 25 | 26 | class SymbolicClassifier(nn.Module): 27 | def __init__(self, device, dropout=0, hidden_size=32): 28 | super(SymbolicClassifier, self).__init__() 29 | self.hidden_size = hidden_size 30 | self.dropout = dropout 31 | 32 | self.gru_enc = nn.GRU(input_size=3, # xyz location 33 | hidden_size=self.hidden_size, 34 | num_layers=1, 35 | batch_first=True) 36 | for name, param in self.gru_enc.named_parameters(): 37 | if 'bias' in name: 38 | nn.init.constant_(param, 0) 39 | elif 'weight' in name: 40 | nn.init.orthogonal_(param) 41 | 42 | self.fc2 = nn.Linear(self.hidden_size, 1) 43 | self.device = device 44 | 45 | def forward(self, x): 46 | batch_size = x.size(0) 47 | hidden_state = torch.zeros( 48 | (1, batch_size, self.hidden_size), requires_grad=True).to( 49 | self.device) 50 | _, final_gru_output = self.gru_enc(x, hidden_state) 51 | final_gru_output = nn.Dropout(p=self.dropout)(final_gru_output) 52 | return torch.sigmoid(self.fc2(final_gru_output)) 53 | 54 | def loss_function(self, x, y): 55 | # x has shape (1, batch_size, 1) 56 | # y has shape (batch_size) 57 | return F.binary_cross_entropy(x[0, :, 0], y) 58 | 59 | def correct_predictions(self, model_output, labels): 60 | predictions = (model_output[0, :, 0] > 0.5).float() 61 | return (predictions == labels).sum().item() 62 | -------------------------------------------------------------------------------- /symbolic/symbolic_dataset.py: -------------------------------------------------------------------------------- 1 | # 
-------------------------------------------------------------------------------------------------- 2 | # Copyright (c) 2021 Microsoft Corporation 3 | # 4 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and 5 | # associated documentation files (the "Software"), to deal in the Software without restriction, 6 | # including without limitation the rights to use, copy, modify, merge, publish, distribute, 7 | # sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is 8 | # furnished to do so, subject to the following conditions: 9 | # 10 | # The above copyright notice and this permission notice shall be included in all copies or 11 | # substantial portions of the Software. 12 | # 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT 14 | # NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND 15 | # NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, 16 | # DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 17 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 18 | # -------------------------------------------------------------------------------------------------- 19 | 20 | import os 21 | import json 22 | import numpy as np 23 | from torch.utils.data.dataset import Dataset 24 | import torch.tensor 25 | 26 | 27 | def read_trajectories(directories, label): 28 | if not isinstance(directories, list): 29 | directories = [directories] # if there is only one directory 30 | traj_data = [] 31 | 32 | x_range = [np.inf, -np.inf] 33 | y_range = [np.inf, -np.inf] 34 | z_range = [np.inf, -np.inf] 35 | 36 | files_to_use = [] 37 | for dir in directories: 38 | if dir.endswith(".json"): 39 | # Direct path to a trajectory 40 | files_to_use.append(dir) 41 | else: 42 | files = os.listdir(dir) 43 | if 'sets.json' in files: 44 | files.remove('sets.json') 45 | files = [os.path.join(dir, file) for file in files] 46 | files_to_use += files 47 | 48 | # For each episode 49 | for filename in files_to_use: 50 | if filename == "sets.json": 51 | continue 52 | traj_data.append({}) 53 | traj_data[-1]["obs"] = [] 54 | traj_data[-1]["label"] = label 55 | 56 | with open(filename) as main_file: 57 | for line in main_file: 58 | step = json.loads(line) 59 | key = list(step.keys())[0] 60 | 61 | # Normalize x,y,z location of player/agent as obs 62 | player_pos = step[key]["Observations"]["Players"][0]["Position"][0] 63 | 64 | x = player_pos["X"] 65 | y = player_pos["Y"] 66 | z = player_pos["Z"] 67 | obs_list = normalize_pos([x, y, z]) 68 | 69 | if x > x_range[1]: 70 | x_range[1] = x 71 | if x < x_range[0]: 72 | x_range[0] = x 73 | 74 | if y > y_range[1]: 75 | y_range[1] = y 76 | if y < y_range[0]: 77 | y_range[0] = y 78 | 79 | if z > z_range[1]: 80 | z_range[1] = z 81 | if z < z_range[0]: 82 | z_range[0] = z 83 | 84 | traj_data[-1]["obs"].append(obs_list) 85 | 86 | ranges = {"x_range": x_range, "y_range": y_range, "z_range": z_range} 87 | return traj_data, ranges 88 | 89 | 90 | def normalize_pos(pos): 91 | normalized_x = (pos[0] + 126.5697) / 2995.3220 92 | normalized_y = (pos[1] - 10903.4404) / 10060.0283 93 | normalized_z = (pos[2] + 313.1935) / 880.6552 94 | return [normalized_x, normalized_y, normalized_z] 95 | 96 | 97 | class TrajectoryDatasetSymbolic(Dataset): 98 | def __init__(self, human_dirs, agent_dirs, seq_length): 99 | # Label Human 
trajectories 1.0, Agent trajectories 0.0 100 | self.data = read_trajectories(human_dirs, 1.0)[ 101 | 0] + read_trajectories(agent_dirs, 0.0)[0] 102 | self.seq_length = seq_length 103 | 104 | def __len__(self): 105 | count = 0 106 | for episode in self.data: 107 | count += len(episode["obs"]) // self.seq_length 108 | return count 109 | 110 | def __getitem__(self, idx): 111 | count = 0 112 | for episode in self.data: 113 | sequences_in_episode = len(episode["obs"]) // self.seq_length 114 | if idx >= count + sequences_in_episode: 115 | count += sequences_in_episode 116 | else: 117 | sequence_idx_in_episode = idx - count 118 | sequence_start_idx = sequence_idx_in_episode * self.seq_length 119 | sample_trajectory = episode["obs"][sequence_start_idx: 120 | sequence_start_idx + self.seq_length] 121 | label = episode["label"] 122 | return torch.tensor(sample_trajectory), torch.tensor(label) 123 | -------------------------------------------------------------------------------- /topdown/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/NTT/ab73f9d0945670054863d53163c65addb0aa2700/topdown/__init__.py -------------------------------------------------------------------------------- /topdown/create_topdown_img.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------------------------------------------------- 2 | # Copyright (c) 2021 Microsoft Corporation 3 | # 4 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and 5 | # associated documentation files (the "Software"), to deal in the Software without restriction, 6 | # including without limitation the rights to use, copy, modify, merge, publish, distribute, 7 | # sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is 8 | # furnished to do so, subject to the following conditions: 9 | # 10 | # The above copyright notice and this permission notice shall be included in all copies or 11 | # substantial portions of the Software. 12 | # 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT 14 | # NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND 15 | # NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, 16 | # DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 17 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
18 | # -------------------------------------------------------------------------------------------------- 19 | 20 | 21 | import os 22 | import sys 23 | import argparse 24 | import numpy as np 25 | from PIL import Image 26 | 27 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 28 | from symbolic.symbolic_dataset import read_trajectories 29 | 30 | if __name__ == "__main__": 31 | parser = argparse.ArgumentParser( 32 | description='Takes Bleeding Edge JSON replays and creates images of the top-down trajectories.', 33 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 34 | parser.add_argument( 35 | '--background', 36 | type=str, 37 | default="black", 38 | choices=[ 39 | "white", 40 | "black"], 41 | help='Background color of the image (black or white).') 42 | parser.add_argument( 43 | '--resolution', 44 | type=int, 45 | default=[ 46 | 200, 47 | 320], 48 | nargs='+', 49 | help='Image resolution.') 50 | parser.add_argument( 51 | '--folders', 52 | type=str, 53 | nargs='+', 54 | help='Path to folders with the JSON to convert to images.') 55 | parser.add_argument( 56 | '--outdir', 57 | type=str, 58 | default=f"{os.path.dirname(os.path.dirname(os.path.abspath(__file__)))}/td_data", 59 | help='Output directory.') 60 | parser.add_argument( 61 | '--border', 62 | type=int, 63 | default=0, 64 | help='Size of the image border, useful for visuals, not models.') 65 | 66 | args = parser.parse_args() 67 | 68 | im_res = args.resolution 69 | folders_to_convert = args.folders 70 | 71 | assert len(im_res) == 2 72 | assert len(folders_to_convert) > 0 73 | 74 | color = 255 if args.background == "black" else 0 75 | 76 | i = 0 77 | for folder in folders_to_convert: 78 | target_folder = os.path.join( 79 | args.outdir, os.path.basename( 80 | os.path.normpath(folder))) 81 | os.makedirs(target_folder, exist_ok=True) 82 | 83 | all_trajectories, ranges = read_trajectories([folder], "unknown") 84 | for traj_data in all_trajectories: 85 | if args.background == "white": 86 | image = np.ones(im_res, dtype=np.uint8) * 255 87 | elif args.background == "black": 88 | image = np.zeros(im_res, dtype=np.uint8) 89 | else: 90 | raise NotImplementedError( 91 | "Only supported background colors are black and white.") 92 | 93 | for pos in traj_data["obs"]: 94 | arr_x = min(int(pos[0] * im_res[0]), im_res[0] - 1) 95 | arr_y = min(int(pos[1] * im_res[1]), im_res[1] - 1) 96 | image[arr_x, arr_y] = color 97 | for t in range(args.border): 98 | image[:, t] = color 99 | image[t, :] = color 100 | image[im_res[0] - 1 - t, :] = color 101 | image[:, im_res[1] - 1 - t] = color 102 | im = Image.fromarray(image) 103 | target_file = os.path.join(target_folder, "td_" + str(i) + ".png") 104 | i += 1 105 | im.save(target_file) 106 | -------------------------------------------------------------------------------- /topdown/topdown_classifier.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------------------------------------------------- 2 | # Copyright (c) 2021 Microsoft Corporation 3 | # 4 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and 5 | # associated documentation files (the "Software"), to deal in the Software without restriction, 6 | # including without limitation the rights to use, copy, modify, merge, publish, distribute, 7 | # sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is 8 | # furnished to do so, subject to the following 
conditions: 9 | # 10 | # The above copyright notice and this permission notice shall be included in all copies or 11 | # substantial portions of the Software. 12 | # 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT 14 | # NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND 15 | # NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, 16 | # DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 17 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 18 | # -------------------------------------------------------------------------------------------------- 19 | 20 | import torchvision 21 | import torch 22 | import torch.nn as nn 23 | 24 | 25 | class TopdownClassifier(nn.Module): 26 | def __init__(self, device, dropout=0.0, hidden_size=0): 27 | super(TopdownClassifier, self).__init__() 28 | self.device = device 29 | 30 | # Define topdown model 31 | self.vgg16 = torchvision.models.vgg16(pretrained=True) 32 | 33 | # freeze convolution weights 34 | for param in self.vgg16.features.parameters(): 35 | param.requires_grad = False 36 | 37 | # replace last layer with a new layer with only 2 outputs 38 | self.num_features = self.vgg16.classifier[6].in_features 39 | self.features = list( 40 | self.vgg16.classifier.children())[ 41 | :-2] # Remove last layer 42 | self.features.extend([nn.Dropout(dropout)]) 43 | if hidden_size > 0: 44 | self.features.extend([nn.Linear(self.num_features, hidden_size)]) 45 | self.num_features = hidden_size 46 | # Add our layer with 2 outputs 47 | self.features.extend([nn.Linear(self.num_features, 2)]) 48 | self.vgg16.classifier = nn.Sequential( 49 | *self.features) # Replace the model classifier 50 | 51 | self.model = self.vgg16.to(device) 52 | 53 | def forward(self, x): 54 | return self.model(x) 55 | 56 | def loss_function(self, x, y): 57 | return nn.CrossEntropyLoss()(x, y.long()) 58 | 59 | def correct_predictions(self, model_output, labels): 60 | _, predictions = torch.max(model_output.data, 1) 61 | return (predictions == labels).sum().item() 62 | 63 | def load_state_dict(self, f): 64 | self.model.load_state_dict(f) 65 | -------------------------------------------------------------------------------- /topdown/topdown_dataset.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------------------------------------------------- 2 | # Copyright (c) 2021 Microsoft Corporation 3 | # 4 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and 5 | # associated documentation files (the "Software"), to deal in the Software without restriction, 6 | # including without limitation the rights to use, copy, modify, merge, publish, distribute, 7 | # sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is 8 | # furnished to do so, subject to the following conditions: 9 | # 10 | # The above copyright notice and this permission notice shall be included in all copies or 11 | # substantial portions of the Software. 12 | # 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT 14 | # NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND 15 | # NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, 16 | # DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 17 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 18 | # -------------------------------------------------------------------------------------------------- 19 | 20 | import os 21 | import numpy as np 22 | from torch.utils.data.dataset import Dataset 23 | import torch.tensor 24 | from PIL import Image 25 | 26 | 27 | def read_trajectories(directories, label): 28 | files_to_use = [] 29 | if not isinstance(directories, list): 30 | directories = [directories] # if there is only one directory 31 | for dir in directories: 32 | files = os.listdir(dir) 33 | if 'sets.json' in files: 34 | files.remove('sets.json') 35 | files = [os.path.join(dir, file) for file in files] 36 | files_to_use += files 37 | images = [] 38 | # For each episode 39 | for filename in files_to_use: 40 | image = Image.open(filename) 41 | image_data = np.array(image) / 255 42 | image_data = np.resize( 43 | image_data, (*image_data.shape, 3)) # add 3 channels 44 | image_data = np.transpose(image_data, (2, 0, 1)) 45 | images.append({'data': image_data, 'label': label}) 46 | return images 47 | 48 | 49 | class TrajectoryDatasetTopdown(Dataset): 50 | def __init__(self, human_dirs, agent_dirs, seq_length=None): 51 | # seq_length is only used here for compatibility with the other 52 | # datasets 53 | assert seq_length is None or seq_length == 1, "Topdown data do not have a sequence length, you may need to remove it from the hyperparameter file." 54 | # Label Human trajectories 1.0, Agent trajectories 0.0 55 | self.data = read_trajectories( 56 | human_dirs, 1.0) + read_trajectories(agent_dirs, 0.0) 57 | 58 | def __len__(self): 59 | return len(self.data) 60 | 61 | def __getitem__(self, idx): 62 | return torch.tensor( 63 | self.data[idx]['data'], dtype=torch.float32), torch.tensor( 64 | self.data[idx]['label'], dtype=torch.float32) 65 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------------------------------------------------- 2 | # Copyright (c) 2021 Microsoft Corporation 3 | # 4 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and 5 | # associated documentation files (the "Software"), to deal in the Software without restriction, 6 | # including without limitation the rights to use, copy, modify, merge, publish, distribute, 7 | # sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is 8 | # furnished to do so, subject to the following conditions: 9 | # 10 | # The above copyright notice and this permission notice shall be included in all copies or 11 | # substantial portions of the Software. 12 | # 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT 14 | # NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND 15 | # NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, 16 | # DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 17 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
18 | # -------------------------------------------------------------------------------------------------- 19 | 20 | # Training script for classification (human vs scripted) of Bleeding Edge 21 | # gameplay data 22 | 23 | from __future__ import print_function 24 | import argparse 25 | import csv 26 | import os 27 | import time 28 | import torch 29 | import torch.utils.data 30 | from torch import optim 31 | from torch.utils.tensorboard import SummaryWriter 32 | 33 | from utils import get_model, get_dataset 34 | 35 | 36 | def train( 37 | model, 38 | optimizer, 39 | epoch, 40 | train_loader, 41 | log_interval, 42 | device, 43 | writer, 44 | update_model): 45 | if update_model: 46 | model.train() 47 | else: 48 | model.eval() 49 | train_loss = 0 50 | train_running_correct = 0 51 | for batch_idx, data in enumerate(train_loader): 52 | optimizer.zero_grad() 53 | 54 | obs = data[0].to(device) 55 | model_output = model(obs) 56 | 57 | label = data[1].to(device) 58 | loss = model.loss_function(model_output, label) 59 | train_loss += loss.item() 60 | 61 | train_running_correct += model.correct_predictions(model_output, label) 62 | 63 | if update_model: 64 | loss.backward() 65 | optimizer.step() 66 | 67 | if (batch_idx + 1) % log_interval == 0: 68 | print('Train Epoch: {} [{}/{} ({:.0f}%)]\tloss: {:.6f}'.format( 69 | epoch, batch_idx * len(data[0]), len(train_loader.dataset), 70 | 100. * batch_idx / len(train_loader), 71 | loss.item() / len(data[0]))) 72 | 73 | avg_train_loss = train_loss / len(train_loader.dataset) 74 | train_accuracy = 100. * train_running_correct / len(train_loader.dataset) 75 | print('====> Epoch: {} Average Train loss: {:.4f} - Accuracy: {:.2f}'.format(epoch, 76 | avg_train_loss, train_accuracy)) 77 | if writer is not None: 78 | writer.add_scalar('Train/LOSS', avg_train_loss, epoch) 79 | writer.add_scalar('Train/ACCURACY', train_accuracy, epoch) 80 | return avg_train_loss, train_accuracy 81 | 82 | 83 | def test(model, optimizer, epoch, test_loader, device, writer): 84 | model.eval() 85 | test_loss = 0 86 | test_running_correct = 0 87 | with torch.no_grad(): 88 | for data in test_loader: 89 | optimizer.zero_grad() 90 | 91 | obs = data[0].to(device) 92 | model_output = model(obs) 93 | 94 | label = data[1].to(device) 95 | loss = model.loss_function(model_output, label) 96 | test_loss += loss.item() 97 | 98 | test_running_correct += model.correct_predictions( 99 | model_output, label) 100 | 101 | test_loss = test_loss / len(test_loader.dataset) 102 | test_accuracy = 100. 
* test_running_correct / len(test_loader.dataset) 103 | print('====> Epoch: {} Test set loss: {:.4f} - Accuracy: {:.2f}'.format(epoch, 104 | test_loss, test_accuracy)) 105 | if writer is not None: 106 | writer.add_scalar('Test/LOSS', test_loss, epoch) 107 | writer.add_scalar('Test/ACCURACY', test_accuracy, epoch) 108 | return test_loss, test_accuracy 109 | 110 | 111 | def training_loop( 112 | log_dir, 113 | run_name, 114 | model, 115 | optimizer, 116 | train_loader, 117 | test_loader, 118 | log_interval, 119 | device, 120 | total_epochs): 121 | timestamp_str = time.strftime("%Y%m%d-%H%M%S") 122 | # Start Tensorboard Logging 123 | os.makedirs(os.path.join(log_dir, run_name, "tensorboard"), exist_ok=True) 124 | os.makedirs(os.path.join(log_dir, run_name, "models"), exist_ok=True) 125 | writer = SummaryWriter( 126 | os.path.join( 127 | log_dir, 128 | run_name, 129 | "tensorboard", 130 | timestamp_str)) 131 | # A CSV file for storing results to plot 132 | os.makedirs(os.path.join(log_dir, run_name, "csv"), exist_ok=True) 133 | csv_filename = os.path.join( 134 | log_dir, run_name, "csv", f"{timestamp_str}.csv") 135 | csvfile = open(csv_filename, 'w') 136 | fields = ['epoch', 'train_loss', 'train_acc', 'val_loss', 'val_acc'] 137 | csvwriter = csv.writer(csvfile) 138 | csvwriter.writerow(fields) 139 | 140 | # Main training loop 141 | best_acc = 0 142 | test_acc = 0 143 | for epoch in range(0, total_epochs + 1): 144 | train_loss, train_acc = train( 145 | model, optimizer, epoch, train_loader, log_interval, device, writer, epoch != 0) 146 | test_loss, test_acc = test( 147 | model, optimizer, epoch, test_loader, device, writer) 148 | if test_acc > best_acc: 149 | print("New best model with validation accuracy", test_acc) 150 | best_acc = test_acc 151 | torch.save( 152 | model.state_dict(), 153 | os.path.join( 154 | log_dir, 155 | run_name, 156 | "models", 157 | "best.pt")) 158 | torch.save( 159 | model.state_dict(), 160 | os.path.join( 161 | log_dir, 162 | run_name, 163 | "models", 164 | "last.pt")) 165 | csvwriter.writerow([epoch, train_loss, train_acc, test_loss, test_acc]) 166 | 167 | csvfile.close() # Close results csv file 168 | writer.close() # Close Tensorboard logging 169 | return best_acc, test_acc 170 | 171 | 172 | if __name__ == "__main__": 173 | parser = argparse.ArgumentParser( 174 | description='Predict if trajectory is human or scripted.', 175 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 176 | parser.add_argument('--batch-size', type=int, default=256, 177 | help='Input batch size for training.') 178 | parser.add_argument('--epochs', type=int, default=10, 179 | help='Number of epochs to train.') 180 | parser.add_argument('--no-cuda', action='store_true', default=False, 181 | help='Disables CUDA training.') 182 | parser.add_argument('--seed', type=int, default=1, 183 | help='Random seed.') 184 | parser.add_argument( 185 | '--log-interval', 186 | type=int, 187 | default=10, 188 | help='Number of batches to wait for before logging training status.') 189 | parser.add_argument('--log-dir', type=str, default='logs', 190 | help='Path where logs will be saved.') 191 | parser.add_argument('--sequence-length', type=int, default=5, 192 | help='Number of observations input in sequence.') 193 | parser.add_argument('--dropout', type=float, default=0, 194 | help='Dropout likelihood in the classifier model.') 195 | parser.add_argument('--lr', type=float, default=0.001, 196 | help='Learning rate.') 197 | parser.add_argument('--hidden-dim', type=int, default=32, 198 | help='Hidden 
dimensions in the classifier model.') 199 | parser.add_argument('--model-type', type=str, default='symbolic', 200 | choices=["visuals", "symbolic", "topdown", "barcode"]) 201 | parser.add_argument( 202 | '--human-train', 203 | type=str, 204 | help='Path to human train data.') 205 | parser.add_argument( 206 | '--human-test', 207 | type=str, 208 | help='Path to human test data.') 209 | parser.add_argument( 210 | '--agent-train', 211 | type=str, 212 | help='Path to agent train data.') 213 | parser.add_argument( 214 | '--agent-test', 215 | type=str, 216 | help='Path to agent test data.') 217 | args = parser.parse_args() 218 | args.cuda = not args.no_cuda and torch.cuda.is_available() 219 | 220 | assert args.human_train is not None, "Human train dataset must be specified with --human-train." 221 | assert args.human_test is not None, "Human test dataset must be specified with --human-test." 222 | assert args.agent_train is not None, "Agent train dataset must be specified with --agent-train." 223 | assert args.agent_test is not None, "Agent test dataset must be specified with --agent-test." 224 | 225 | torch.manual_seed(args.seed) 226 | 227 | device = torch.device("cuda" if args.cuda else "cpu") 228 | 229 | kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {} 230 | 231 | dataset = get_dataset(args.model_type) 232 | print("Loading training dataset...") 233 | train_dataset = dataset(human_dirs=args.human_train, 234 | agent_dirs=args.agent_train, 235 | seq_length=args.sequence_length) 236 | train_loader = torch.utils.data.DataLoader( 237 | train_dataset, batch_size=args.batch_size, shuffle=True, **kwargs) 238 | 239 | print("Loading testing dataset...") 240 | test_dataset = dataset(human_dirs=args.human_test, 241 | agent_dirs=args.agent_test, 242 | seq_length=args.sequence_length) 243 | test_loader = torch.utils.data.DataLoader( 244 | test_dataset, batch_size=args.batch_size, shuffle=False, **kwargs) 245 | 246 | print("Initializing model...") 247 | model = get_model( 248 | args.model_type)( 249 | device, 250 | args.dropout, 251 | args.hidden_dim).to(device) 252 | optimizer = optim.Adam(model.parameters(), lr=args.lr) 253 | 254 | run_name = f"{args.model_type}-{time.strftime('%Y%m%d-%H%M%S')}" 255 | training_loop( 256 | args.log_dir, 257 | run_name, 258 | model, 259 | optimizer, 260 | train_loader, 261 | test_loader, 262 | args.log_interval, 263 | device, 264 | args.epochs) 265 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------------------------------------------------- 2 | # Copyright (c) 2021 Microsoft Corporation 3 | # 4 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and 5 | # associated documentation files (the "Software"), to deal in the Software without restriction, 6 | # including without limitation the rights to use, copy, modify, merge, publish, distribute, 7 | # sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is 8 | # furnished to do so, subject to the following conditions: 9 | # 10 | # The above copyright notice and this permission notice shall be included in all copies or 11 | # substantial portions of the Software. 
12 | # 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT 14 | # NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND 15 | # NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, 16 | # DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 17 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 18 | # -------------------------------------------------------------------------------------------------- 19 | 20 | from symbolic.symbolic_classifier import SymbolicClassifier 21 | from visuals.visuals_classifier import VisualsClassifier 22 | from barcodes.barcodes_classifier import BarcodesClassifier 23 | from topdown.topdown_classifier import TopdownClassifier 24 | from symbolic.symbolic_dataset import TrajectoryDatasetSymbolic 25 | from visuals.visuals_dataset import TrajectoryDatasetVisuals 26 | from barcodes.barcode_dataset import TrajectoryDatasetBarcodes 27 | from topdown.topdown_dataset import TrajectoryDatasetTopdown 28 | 29 | 30 | def get_model(model_type): 31 | if model_type == "visuals": 32 | return VisualsClassifier 33 | if model_type == "symbolic": 34 | return SymbolicClassifier 35 | if model_type == "barcode": 36 | return BarcodesClassifier 37 | if model_type == "topdown": 38 | return TopdownClassifier 39 | raise NotImplementedError 40 | 41 | 42 | def get_dataset(model_type): 43 | if model_type == "visuals": 44 | return TrajectoryDatasetVisuals 45 | if model_type == "symbolic": 46 | return TrajectoryDatasetSymbolic 47 | if model_type == "barcode": 48 | return TrajectoryDatasetBarcodes 49 | if model_type == "topdown": 50 | return TrajectoryDatasetTopdown 51 | raise NotImplementedError 52 | -------------------------------------------------------------------------------- /visuals/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/NTT/ab73f9d0945670054863d53163c65addb0aa2700/visuals/__init__.py -------------------------------------------------------------------------------- /visuals/visuals_classifier.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------------------------------------------------- 2 | # Copyright (c) 2021 Microsoft Corporation 3 | # 4 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and 5 | # associated documentation files (the "Software"), to deal in the Software without restriction, 6 | # including without limitation the rights to use, copy, modify, merge, publish, distribute, 7 | # sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is 8 | # furnished to do so, subject to the following conditions: 9 | # 10 | # The above copyright notice and this permission notice shall be included in all copies or 11 | # substantial portions of the Software. 12 | # 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT 14 | # NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND 15 | # NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, 16 | # DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 17 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 18 | # -------------------------------------------------------------------------------------------------- 19 | 20 | import torch 21 | import torch.utils.data 22 | from torch import nn 23 | from torchvision import models 24 | 25 | 26 | class VisualsClassifier(nn.Module): 27 | def __init__(self, device, dropout=0.5, hidden_dim=32): 28 | super(VisualsClassifier, self).__init__() 29 | 30 | self.hidden_dim = hidden_dim 31 | vgg16 = models.vgg16(pretrained=True) 32 | # freeze convolution weights 33 | for param in vgg16.features.parameters(): 34 | param.requires_grad = False 35 | 36 | # replace last layer with a new layer with only 2 outputs 37 | self.num_features = vgg16.classifier[6].in_features 38 | features = list(vgg16.classifier.children())[:-1] # Remove last layer 39 | vgg16.classifier = nn.Sequential( 40 | *features) # Replace the model classifier 41 | 42 | self.cnn_encoder = vgg16 43 | 44 | self.gru_enc = nn.GRU(input_size=self.num_features, 45 | hidden_size=hidden_dim, 46 | num_layers=1, 47 | batch_first=True) 48 | 49 | for name, param in self.gru_enc.named_parameters(): 50 | if 'bias' in name: 51 | nn.init.constant_(param, 0) 52 | elif 'weight' in name: 53 | nn.init.orthogonal_(param) 54 | 55 | self.fc2 = nn.Linear(hidden_dim, 2) 56 | self.dropout = nn.Dropout(p=dropout) 57 | self.device = device 58 | 59 | def forward(self, x): 60 | batch_size = x.size(0) 61 | hidden_state = torch.zeros( 62 | (1, batch_size, self.hidden_dim), requires_grad=True).to( 63 | self.device) 64 | # just treat sequence data as batch for the vgg model 65 | original_shape = x.size() 66 | x = torch.reshape( 67 | x, 68 | (original_shape[0] * original_shape[1], 69 | original_shape[2], 70 | original_shape[3], 71 | original_shape[4])).to( 72 | torch.float) 73 | x = self.cnn_encoder(x) 74 | x = self.dropout(x) 75 | # convert back to batch,sequence,obs 76 | x = torch.reshape( 77 | x, 78 | (original_shape[0], 79 | original_shape[1], 80 | self.num_features)) 81 | _, final_gru_output = self.gru_enc(x, hidden_state) 82 | final_gru_output = torch.squeeze(final_gru_output, 0) 83 | final_gru_output = self.dropout(final_gru_output) 84 | return self.fc2(final_gru_output) 85 | 86 | def loss_function(self, x, y): 87 | return nn.CrossEntropyLoss()(x, y.long()) 88 | 89 | def correct_predictions(self, model_output, labels): 90 | _, predictions = torch.max(model_output.data, 1) 91 | return (predictions == labels).sum().item() 92 | -------------------------------------------------------------------------------- /visuals/visuals_dataset.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------------------------------------------------- 2 | # Copyright (c) 2021 Microsoft Corporation 3 | # 4 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and 5 | # associated documentation files (the "Software"), to deal in the Software without restriction, 6 | # including without limitation the rights to use, copy, modify, merge, publish, distribute, 7 | # sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is 8 | # furnished to do so, subject to the following conditions: 9 | # 10 | # The above copyright notice 
and this permission notice shall be included in all copies or 11 | # substantial portions of the Software. 12 | # 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT 14 | # NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND 15 | # NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, 16 | # DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 17 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 18 | # -------------------------------------------------------------------------------------------------- 19 | 20 | import os 21 | import json 22 | import numpy as np 23 | from torch.utils.data.dataset import Dataset 24 | import torch.tensor 25 | import base64 26 | from PIL import Image 27 | import io 28 | import itertools 29 | 30 | 31 | def read_trajectories(directories, label): 32 | print("Loading trajectories in", directories) 33 | if not isinstance(directories, list): 34 | directories = [directories] # if there is only one directory 35 | traj_data = [] 36 | 37 | files_to_use = [] 38 | for dir in directories: 39 | files = os.listdir(dir) 40 | if 'sets.json' in files: 41 | files.remove('sets.json') 42 | files = [os.path.join(dir, file) for file in files] 43 | files_to_use += files 44 | 45 | # For each episode 46 | for filename in files_to_use: 47 | traj_data.append({}) 48 | traj_data[-1]["obs"] = [] 49 | traj_data[-1]["label"] = label 50 | 51 | with open(filename) as main_file: 52 | video = [] 53 | for line in itertools.islice(main_file, 0, None, 10): 54 | step = json.loads(line) 55 | key = list(step.keys())[0] 56 | 57 | encoded_img = step[key]["Observations"]["Players"][0]["Image"]["ImageBytes"] 58 | decoded_image_data = base64.decodebytes( 59 | encoded_img.encode('utf-8')) 60 | image = Image.open(io.BytesIO(decoded_image_data)) 61 | img = np.array(image) 62 | video.append(img) 63 | 64 | videodata = np.array(video) / 255 65 | videodata = np.transpose(videodata, (0, 3, 1, 2)) 66 | traj_data[-1]["obs"] = videodata 67 | 68 | print("Files loaded: ", len(traj_data)) 69 | return traj_data 70 | 71 | 72 | class TrajectoryDatasetVisuals(Dataset): 73 | def __init__(self, human_dirs, agent_dirs, seq_length): 74 | # Label Human trajectories 1.0, Agent trajectories 0.0 75 | self.data = read_trajectories( 76 | human_dirs, 1.0) + read_trajectories(agent_dirs, 0.0) 77 | self.seq_length = seq_length 78 | 79 | def __len__(self): 80 | count = 0 81 | for episode in self.data: 82 | count += len(episode["obs"]) // self.seq_length 83 | return count 84 | 85 | def __getitem__(self, idx): 86 | count = 0 87 | for episode in self.data: 88 | sequences_in_episode = len(episode["obs"]) // self.seq_length 89 | if idx >= count + sequences_in_episode: 90 | count += sequences_in_episode 91 | else: 92 | sequence_idx_in_episode = idx - count 93 | sequence_start_idx = sequence_idx_in_episode * self.seq_length 94 | sample_trajectory = episode["obs"][sequence_start_idx: 95 | sequence_start_idx + self.seq_length] 96 | label = episode["label"] 97 | return torch.tensor(sample_trajectory), torch.tensor(label) 98 | --------------------------------------------------------------------------------
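End-to-end usage sketch (not part of the repository): the snippet below shows how the factory functions in utils.py, the symbolic dataset, and the symbolic classifier above fit together for a single optimisation step, mirroring what train.py does internally. The data directories are hypothetical placeholders for the --human-train/--agent-train arguments, and the hyperparameter values are the defaultHP entries from hyperparameters.json.

    import torch
    from torch.utils.data import DataLoader

    from utils import get_dataset, get_model

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # "data/human_train" and "data/agent_train" are placeholder directories of
    # Bleeding Edge JSON replays (what --human-train / --agent-train would point at).
    dataset_cls = get_dataset("symbolic")            # -> TrajectoryDatasetSymbolic
    train_set = dataset_cls(human_dirs=["data/human_train"],
                            agent_dirs=["data/agent_train"],
                            seq_length=20)           # defaultHP sequence_length
    loader = DataLoader(train_set, batch_size=256, shuffle=True)

    # defaultHP: dropout=0.85, hidden_size=16, lr=0.001
    model = get_model("symbolic")(device, dropout=0.85, hidden_size=16).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    obs, label = next(iter(loader))                  # obs: (batch, 20, 3), label: (batch,)
    output = model(obs.to(device))                   # (1, batch, 1) sigmoid probabilities
    loss = model.loss_function(output, label.to(device))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

Swapping the "symbolic" string for "visuals", "topdown" or "barcode" selects the corresponding classifier and dataset through the same two factories; only the constructor arguments differ (for example, the topdown dataset asserts that seq_length is None or 1).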