├── gpt ├── __init__.py ├── config-2M-DDG.py ├── configurator.py ├── aggregated_data_loader.py ├── fast_data_loader.py ├── inference.py └── observation_generator.h ├── utils ├── __init__.py ├── download_dataset.py ├── data_utils.py ├── multi_animation_runner.py ├── wrappers.py ├── data_collection.py └── svg_utils.py ├── finetuning ├── __init__.py ├── filter_data.py ├── data_aggregation_generator.py ├── scenario_generators.py └── delta_data_generator.py ├── tokenizer ├── __init__.py ├── parameters.py ├── cost2go.cpp ├── encoder.cpp └── tokenizer.py ├── docker ├── build.sh ├── requirements.txt └── dockerfile ├── .gitattributes ├── lacam ├── lacam3 │ ├── src │ │ ├── heuristic.cpp │ │ ├── lacam.cpp │ │ ├── lnode.cpp │ │ ├── translator.cpp │ │ ├── dist_table.cpp │ │ ├── utils.cpp │ │ ├── hnode.cpp │ │ ├── refiner.cpp │ │ ├── collision_table.cpp │ │ ├── metrics.cpp │ │ ├── instance.cpp │ │ ├── graph.cpp │ │ ├── scatter.cpp │ │ ├── post_processing.cpp │ │ ├── sipp.cpp │ │ └── pibt.cpp │ ├── include │ │ ├── heuristic.hpp │ │ ├── lacam.hpp │ │ ├── lnode.hpp │ │ ├── translator.hpp │ │ ├── dist_table.hpp │ │ ├── collision_table.hpp │ │ ├── post_processing.hpp │ │ ├── metrics.hpp │ │ ├── hnode.hpp │ │ ├── refiner.hpp │ │ ├── instance.hpp │ │ ├── scatter.hpp │ │ ├── sipp.hpp │ │ ├── graph.hpp │ │ ├── pibt.hpp │ │ ├── utils.hpp │ │ └── planner.hpp │ └── CMakeLists.txt ├── CMakeLists.txt ├── main.cpp └── inference.py ├── run.yaml ├── LICENSE ├── eval_configs ├── 05-puzzles │ ├── maps.yaml │ └── 05-puzzles.yaml ├── 03-warehouse │ ├── maps.yaml │ └── 03-warehouse.yaml ├── 04-movingai │ └── 04-movingai.yaml ├── 02-mazes │ └── 02-mazes.yaml └── 01-random │ └── 01-random.yaml ├── benchmark.py ├── ckpt_configs ├── 02-mazes │ └── 02-mazes.yaml └── 01-random │ └── 01-random.yaml ├── dagger.py ├── example.py ├── million_agents_run.py ├── macro_env.py ├── .gitignore ├── create_env.py └── README.md /gpt/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /finetuning/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tokenizer/__init__.py: -------------------------------------------------------------------------------- 1 | # __init__.py -------------------------------------------------------------------------------- /docker/build.sh: -------------------------------------------------------------------------------- 1 | docker build -t mapf-gpt . 
2 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /lacam/lacam3/src/heuristic.cpp: -------------------------------------------------------------------------------- 1 | #include "../include/heuristic.hpp" 2 | 3 | Heuristic::Heuristic(const Instance *_ins, DistTable *_D) : ins(_ins), D(_D) {} 4 | 5 | int Heuristic::get(const Config &Q) 6 | { 7 | auto cost = 0; 8 | for (size_t i = 0; i < ins->N; ++i) cost += D->get(i, Q[i]); 9 | return cost; 10 | } 11 | -------------------------------------------------------------------------------- /lacam/lacam3/src/lacam.cpp: -------------------------------------------------------------------------------- 1 | #include "../include/lacam.hpp" 2 | 3 | Solution solve(const Instance &ins, int verbose, const Deadline *deadline, 4 | int seed) 5 | { 6 | info(1, verbose, deadline, "pre-processing"); 7 | auto planner = Planner(&ins, verbose, deadline, seed); 8 | return planner.solve(); 9 | } 10 | -------------------------------------------------------------------------------- /lacam/lacam3/include/heuristic.hpp: -------------------------------------------------------------------------------- 1 | /* 2 | * heuristic definition 3 | */ 4 | 5 | #pragma once 6 | #include "dist_table.hpp" 7 | #include "graph.hpp" 8 | #include "instance.hpp" 9 | 10 | struct Heuristic { 11 | const Instance *ins; 12 | DistTable *D; 13 | 14 | Heuristic(const Instance *_ins, DistTable *_D); 15 | int get(const Config &C); 16 | }; 17 | -------------------------------------------------------------------------------- /lacam/lacam3/include/lacam.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "dist_table.hpp" 4 | #include "graph.hpp" 5 | #include "instance.hpp" 6 | #include "planner.hpp" 7 | #include "post_processing.hpp" 8 | #include "sipp.hpp" 9 | #include "utils.hpp" 10 | 11 | Solution solve(const Instance &ins, const int verbose = 0, 12 | const Deadline *deadline = nullptr, int seed = 0); 13 | -------------------------------------------------------------------------------- /lacam/lacam3/include/lnode.hpp: -------------------------------------------------------------------------------- 1 | /* 2 | * low-level node of LaCAM 3 | */ 4 | 5 | #pragma once 6 | #include "graph.hpp" 7 | 8 | // low-level search node 9 | struct LNode { 10 | static int COUNT; 11 | 12 | std::vector who; 13 | Vertices where; 14 | const int depth; 15 | LNode(); 16 | LNode(LNode *parent, int i, Vertex *v); // who and where 17 | ~LNode(); 18 | }; 19 | -------------------------------------------------------------------------------- /lacam/lacam3/include/translator.hpp: -------------------------------------------------------------------------------- 1 | /* 2 | * translate between representations by paths and configurations 3 | */ 4 | 5 | #pragma once 6 | #include "graph.hpp" 7 | #include "metrics.hpp" 8 | #include "utils.hpp" 9 | 10 | std::vector translateConfigsToPaths(const std::vector &configs); 11 | std::vector translatePathsToConfigs(const std::vector &paths); 12 | -------------------------------------------------------------------------------- /lacam/lacam3/src/lnode.cpp: -------------------------------------------------------------------------------- 1 | #include 
"../include/lnode.hpp" 2 | 3 | int LNode::COUNT = 0; 4 | 5 | LNode::LNode() : who(), where(), depth(0) { ++COUNT; } 6 | 7 | LNode::LNode(LNode *parent, int i, Vertex *v) 8 | : who(parent->who), where(parent->where), depth(parent->depth + 1) 9 | { 10 | ++COUNT; 11 | who.push_back(i); 12 | where.push_back(v); 13 | } 14 | 15 | LNode::~LNode(){}; 16 | -------------------------------------------------------------------------------- /lacam/lacam3/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.16) 2 | file(GLOB SRCS "./src/*.cpp") 3 | project(lacam3) 4 | add_library(${PROJECT_NAME} STATIC ${SRCS}) 5 | target_compile_options(${PROJECT_NAME} PUBLIC -O3 -Wall) 6 | target_compile_features(${PROJECT_NAME} PUBLIC cxx_std_17) 7 | target_include_directories(${PROJECT_NAME} INTERFACE ./include) 8 | 9 | set_property(TARGET lacam3 PROPERTY POSITION_INDEPENDENT_CODE ON) -------------------------------------------------------------------------------- /tokenizer/parameters.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel 2 | 3 | 4 | class InputParameters(BaseModel): 5 | num_agents: int = 13 6 | num_previous_actions: int = 5 7 | agents_radius: int = 5 8 | cost2go_value_limit: int = 20 9 | cost2go_radius: int = 5 10 | context_size: int = 256 11 | mask_greed_action: bool = False 12 | mask_actions_history: bool = False 13 | mask_goal: bool = False 14 | mask_cost2go: bool = False 15 | -------------------------------------------------------------------------------- /lacam/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.16) 2 | project(lacam-project CXX) 3 | 4 | # Enforce shared library suffix as .so on macOS 5 | if(APPLE) 6 | set(CMAKE_SHARED_LIBRARY_SUFFIX ".so") 7 | endif() 8 | 9 | add_subdirectory(./lacam3) 10 | 11 | add_library(lacam SHARED main.cpp) 12 | target_compile_features(lacam PUBLIC cxx_std_17) 13 | target_include_directories(lacam PRIVATE ./lacam3/include) 14 | target_link_libraries(lacam lacam3 libpthread.so) -------------------------------------------------------------------------------- /run.yaml: -------------------------------------------------------------------------------- 1 | container: 2 | image: "mapf-gpt:latest" 3 | command: /opt/conda/envs/pogema/bin/python3 example.py 4 | tty: True 5 | environment: 6 | - "OMP_NUM_THREADS=32" 7 | - "MKL_NUM_THREADS=1" 8 | - "OPENBLAS_NUM_THREADS=1" 9 | - "NVIDIA_VISIBLE_DEVICES=0" 10 | code: 11 | volumes: [] 12 | folder: "." 
13 | forward_environment_keys: [ "WANDB_API_KEY" ] 14 | ignore: [ ".git" ] 15 | host_config: 16 | runtime: nvidia 17 | shm_size: '8g' 18 | mem_limit: '128g' 19 | -------------------------------------------------------------------------------- /docker/requirements.txt: -------------------------------------------------------------------------------- 1 | numpy>=1.21.5,<=1.26.4 2 | torch>=2.0 3 | pydantic>=1.8.2,<=1.9.1 4 | pytest>=6.2.5,<=7.1.2 5 | pandas<=1.4 6 | tensorboard>=1.15.0 7 | tensorboardx>=2.0 8 | psutil>=5.7.0 9 | threadpoolctl>=2.0.0 10 | colorlog 11 | wandb>=0.12.9,<=0.13.4 12 | pybind11==2.13.1 13 | cppimport>=22.8.2 14 | matplotlib 15 | seaborn 16 | tabulate>=0.8.7,<=0.8.10 17 | protobuf==3.20.3 18 | typing-extensions==4.5.0 19 | importlib-metadata==4.13.0 20 | dask[distributed] 21 | loguru 22 | pogema-toolbox==0.1.1 23 | pogema==1.3.2a4 24 | pyarrow 25 | huggingface_hub 26 | tqdm 27 | -------------------------------------------------------------------------------- /finetuning/filter_data.py: -------------------------------------------------------------------------------- 1 | def filter_data(inputs, gt_actions): 2 | if len(inputs) == 0 or len(gt_actions) == 0: 3 | return None 4 | 5 | known_hashes = set() 6 | filtered_inputs = [] 7 | filtered_gt_actions = [] 8 | 9 | for input, gt_action in zip(inputs, gt_actions): 10 | input_tuple = tuple(input) 11 | input_hash = hash(input_tuple) 12 | 13 | if input_hash not in known_hashes: 14 | known_hashes.add(input_hash) 15 | filtered_inputs.append(input) 16 | filtered_gt_actions.append(gt_action) 17 | 18 | return {"inputs": filtered_inputs, "gt_actions": filtered_gt_actions} 19 | -------------------------------------------------------------------------------- /utils/download_dataset.py: -------------------------------------------------------------------------------- 1 | from huggingface_hub import hf_hub_download 2 | 3 | repo_id = "aandreychuk/MAPF-GPT" 4 | local_dir = "../dataset" 5 | 6 | # Download the 'validation' part 7 | hf_hub_download(repo_id=repo_id, repo_type='dataset', subfolder="validation", filename="chunk_0_part_0.arrow", local_dir=local_dir) 8 | 9 | # Download the 'train' part 10 | # reduce the number of chunks or parts if you don't need the whole dataset 11 | # each file contains 2**21 input tensors and requires 512 MB of disk space 12 | for chunk in range(50): 13 | for part in range(10): 14 | hf_hub_download(repo_id=repo_id, repo_type='dataset', subfolder="train", filename=f"chunk_{chunk}_part_{part}.arrow", local_dir=local_dir) 15 | -------------------------------------------------------------------------------- /lacam/lacam3/include/dist_table.hpp: -------------------------------------------------------------------------------- 1 | /* 2 | * distance table with lazy evaluation, using BFS 3 | */ 4 | #pragma once 5 | 6 | #include "graph.hpp" 7 | #include "instance.hpp" 8 | #include "utils.hpp" 9 | 10 | struct DistTable { 11 | const int K; // number of vertices 12 | std::vector> 13 | table; // distance table, index: agent-id & vertex-id 14 | std::vector> OPEN; // search queue 15 | 16 | int get(const int i, const int v_id); // agent, vertex-id 17 | int get(const int i, const Vertex *v); // agent, vertex 18 | 19 | DistTable(const Instance &ins); 20 | DistTable(const Instance *ins); 21 | 22 | void setup(const Instance *ins); // initialization 23 | }; 24 | -------------------------------------------------------------------------------- /lacam/lacam3/include/collision_table.hpp: 
--------------------------------------------------------------------------------
1 | /*
2 |  * fast collision checking, used in SUO and refiner
3 |  */
4 | #pragma once
5 | 
6 | #include "graph.hpp"
7 | #include "instance.hpp"
8 | #include "utils.hpp"
9 | 
10 | struct CollisionTable {
11 |   // vertex, time, agents
12 |   std::vector<std::vector<std::vector<int>>> body;
13 |   std::vector<std::vector<int>> body_last;
14 |   int collision_cnt;
15 |   int N;
16 | 
17 |   CollisionTable(const Instance *ins);
18 |   ~CollisionTable();
19 | 
20 |   int getCollisionCost(const Vertex *v_from, const Vertex *v_to,
21 |                        const int t_from);
22 |   void enrollPath(const int i, Path &path);
23 |   void clearPath(const int i, Path &path);
24 |   void shrink();
25 | };
26 | 
--------------------------------------------------------------------------------
/gpt/config-2M-DDG.py:
--------------------------------------------------------------------------------
1 | compile = True
2 | max_iters = 30000
3 | lr_decay_iters = 30000
4 | 
5 | batch_size = 4096
6 | n_layer = 5
7 | n_head = 5
8 | n_embd = 160
9 | 
10 | block_size = 256
11 | gradient_accumulation_steps = 16
12 | 
13 | # init_from = 'resume'
14 | 
15 | #DDG settings
16 | dagger_type = "ddg"
17 | device_id = 0 # i.e. MAPF-GPT during data collection is run on device cuda:0
18 | num_workers = 8 # number of workers during DDG
19 | file_size = 50 * 2 ** 11 # number of observation-action pairs collected by each worker during single DDG iteration
20 | max_ratio = 0.25 # maximum ratio of DDG data in training data
21 | train_data_files = ["dataset/train", f"dataset/{dagger_type}"] # to avoid overfitting on DDG data, we use both 1B MAPF-GPT dataset and DDG data for training
22 | valid_data_file = "dataset/validation"
--------------------------------------------------------------------------------
/lacam/lacam3/include/post_processing.hpp:
--------------------------------------------------------------------------------
1 | /*
2 |  * post processing, e.g., calculating solution quality
3 |  */
4 | #pragma once
5 | #include "dist_table.hpp"
6 | #include "instance.hpp"
7 | #include "metrics.hpp"
8 | #include "utils.hpp"
9 | 
10 | bool is_feasible_solution(const Instance &ins, const Solution &solution,
11 |                           const int verbose = 0);
12 | void print_stats(const int verbose, const Deadline *deadline,
13 |                  const Instance &ins, const Solution &solution,
14 |                  const double comp_time_ms);
15 | void make_log(const Instance &ins, const Solution &solution,
16 |               const std::string &output_name, const double comp_time_ms,
17 |               const std::string &map_name, const int seed,
18 |               const bool log_short = false  // true -> paths not appear
19 | );
20 | 
--------------------------------------------------------------------------------
/lacam/lacam3/include/metrics.hpp:
--------------------------------------------------------------------------------
1 | /*
2 |  * solution evaluation metrics
3 |  */
4 | 
5 | #pragma once
6 | 
7 | #include "dist_table.hpp"
8 | #include "instance.hpp"
9 | #include "utils.hpp"
10 | 
11 | int get_makespan(const Solution &solution);
12 | int get_makespan_paths(const std::vector<Path> &solution);
13 | 
14 | int get_path_cost(const Solution &solution, int i);  // single-agent path cost
15 | int get_path_cost(const Path &path);
16 | int get_sum_of_costs(const Solution &solution);
17 | int get_sum_of_costs_paths(const std::vector<Path> &solution);
18 | 
19 | int get_path_loss(const Path &path);
20 | int get_sum_of_loss(const Solution &solution);
21 | int get_sum_of_loss(const Solution &solution, std::vector<int> &agents_subset);
22 | int get_sum_of_loss_paths(const std::vector<Path> &solution);
23 | 
24 | int get_makespan_lower_bound(const Instance &ins, DistTable &D);
25 | int get_sum_of_costs_lower_bound(const Instance &ins, DistTable &D);
26 | 
--------------------------------------------------------------------------------
/lacam/lacam3/include/hnode.hpp:
--------------------------------------------------------------------------------
1 | /*
2 |  * high-level node of LaCAM
3 |  */
4 | 
5 | #pragma once
6 | 
7 | #include "dist_table.hpp"
8 | #include "lnode.hpp"
9 | 
10 | // high-level search node
11 | struct HNode;
12 | struct CompareHNodePointers {  // for determinism
13 |   bool operator()(const HNode *lhs, const HNode *rhs) const;
14 | };
15 | 
16 | struct HNode {
17 |   static int COUNT;
18 | 
19 |   const Config C;
20 |   HNode *parent;
21 |   std::set<HNode *, CompareHNodePointers> neighbor;
22 | 
23 |   // value
24 |   int g;
25 |   int h;
26 |   int f;
27 | 
28 |   // for low-level search
29 |   std::vector<float> priorities;
30 |   std::vector<int> order;
31 |   std::queue<LNode *> search_tree;
32 | 
33 |   HNode(Config _C, DistTable *D, HNode *_parent = nullptr, int _g = 0,
34 |         int _h = 0);
35 |   ~HNode();
36 | 
37 |   LNode *get_next_lowlevel_node(std::mt19937 &MT);
38 | };
39 | using HNodes = std::vector<HNode *>;
40 | 
41 | std::ostream &operator<<(std::ostream &os, const HNode *H);
42 | 
--------------------------------------------------------------------------------
/lacam/lacam3/src/translator.cpp:
--------------------------------------------------------------------------------
1 | #include "../include/translator.hpp"
2 | 
3 | std::vector<Path> translateConfigsToPaths(const std::vector<Config> &configs)
4 | {
5 |   const auto N = configs.front().size();
6 |   auto paths = std::vector<Path>(N);
7 |   for (auto i = 0; i < N; ++i) {
8 |     auto T_i = get_path_cost(configs, i);
9 |     for (auto t = 0; t <= T_i; ++t) {
10 |       paths[i].push_back(configs[t][i]);
11 |     }
12 |   }
13 |   return paths;
14 | }
15 | 
16 | std::vector<Config> translatePathsToConfigs(const std::vector<Path> &paths)
17 | {
18 |   const auto N = paths.size();
19 |   auto T = 0;
20 |   for (int i = 0; i < N; ++i) {
21 |     T = std::max(T, (int)paths[i].size() - 1);
22 |   }
23 | 
24 |   std::vector<Config> configs(T + 1, Config(N, nullptr));
25 |   for (auto i = 0; i < N; ++i) {
26 |     const auto T_i = (int)paths[i].size() - 1;
27 |     for (auto t = 0; t <= T; ++t) {
28 |       configs[t][i] = paths[i][std::min(t, T_i)];
29 |     }
30 |   }
31 |   return configs;
32 | }
33 | 
--------------------------------------------------------------------------------
/lacam/lacam3/include/refiner.hpp:
--------------------------------------------------------------------------------
1 | /*
2 |  * Implementation of refiners
3 |  *
4 |  * references:
5 |  * Iterative Refinement for Real-Time Multi-Robot Path Planning.
6 |  * Keisuke Okumura, Yasumasa Tamura, and Xavier Défago.
7 |  * In Proceedings of IEEE/RSJ International Conference on Intelligent Robots and
8 |  * Systems (IROS). 2021.
9 |  *
10 |  * Anytime multi-agent path finding via large neighborhood search.
11 |  * Jiaoyang Li, Zhe Chen, Daniel Harabor, P Stuckey, and Sven Koenig.
12 |  * In Proceedings of International Joint Conference on Artificial Intelligence
13 |  * (IJCAI). 2021.
14 | */ 15 | 16 | #pragma once 17 | 18 | #include "collision_table.hpp" 19 | #include "dist_table.hpp" 20 | #include "graph.hpp" 21 | #include "instance.hpp" 22 | #include "metrics.hpp" 23 | #include "sipp.hpp" 24 | #include "translator.hpp" 25 | #include "utils.hpp" 26 | 27 | Solution refine(const Instance *ins, const Deadline *deadline, 28 | const Solution &solution, DistTable *D, const int seed = 0, 29 | const int verbose = 0); 30 | -------------------------------------------------------------------------------- /utils/data_utils.py: -------------------------------------------------------------------------------- 1 | import pyarrow as pa 2 | import pyarrow.ipc as ipc 3 | import os 4 | 5 | def save_to_arrow(inputs, gt_actions, filepath): 6 | schema = pa.schema([ 7 | ('input_tensors', pa.list_(pa.int8())), 8 | ('gt_actions', pa.int8()) 9 | ]) 10 | 11 | input_tensors_col = pa.array(inputs, type=pa.list_(pa.int8())) 12 | gt_actions_col = pa.array(gt_actions) 13 | table = pa.Table.from_arrays([input_tensors_col, gt_actions_col], schema=schema) 14 | 15 | if not os.path.exists(os.path.dirname(filepath)): 16 | os.makedirs(os.path.dirname(filepath), exist_ok=True) 17 | with open(filepath, "wb") as f: 18 | with ipc.new_file(f, schema) as writer: 19 | writer.write(table) 20 | 21 | 22 | def compute_metrics_diff(left, right): 23 | result = {} 24 | for metric in ['ISR', 'CSR', 'makespan']: 25 | if not (metric in left and metric in right): 26 | continue 27 | result[metric] = right[metric] - left[metric] 28 | # if metric == 'makespan': 29 | # result[metric] *= -1 30 | return result 31 | -------------------------------------------------------------------------------- /lacam/lacam3/include/instance.hpp: -------------------------------------------------------------------------------- 1 | /* 2 | * instance definition 3 | */ 4 | #pragma once 5 | #include 6 | 7 | #include "graph.hpp" 8 | #include "utils.hpp" 9 | 10 | struct Instance { 11 | Graph *G; // graph 12 | Config starts; // initial configuration 13 | Config goals; // goal configuration 14 | const uint N; // number of agents 15 | bool delete_graph_after_used; 16 | 17 | Instance(Graph *_G, const Config &_starts, const Config &_goals, uint _N); 18 | Instance(const std::string &map_content, 19 | const std::vector &start_indexes, 20 | const std::vector &goal_indexes); 21 | // for MAPF benchmark 22 | Instance(const std::string &scen_content, const std::string &map_content, 23 | const int _N = 1); 24 | // random instance generation 25 | Instance(const std::string &map_content, const int _N = 1, 26 | const int seed = 0); 27 | ~Instance(); 28 | 29 | // simple feasibility check of instance 30 | bool is_valid(const int verbose = 0) const; 31 | }; 32 | 33 | // solution: a sequence of configurations 34 | using Solution = std::vector; 35 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Alexey Skrynnik 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this 
permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/lacam/lacam3/include/scatter.hpp:
--------------------------------------------------------------------------------
1 | /*
2 |  * Implementation of SUO
3 |  *
4 |  * references:
5 |  * Optimizing space utilization for more effective multi-robot path planning.
6 |  * Shuai D Han and Jingjin Yu.
7 |  * In Proceedings of IEEE International Conference on Robotics and Automation
8 |  * (ICRA). 2022.
9 |  */
10 | #pragma once
11 | 
12 | #include "collision_table.hpp"
13 | #include "dist_table.hpp"
14 | #include "graph.hpp"
15 | #include "utils.hpp"
16 | 
17 | struct Scatter {
18 |   const Instance *ins;
19 |   const Deadline *deadline;
20 |   std::mt19937 MT;
21 |   const int verbose;
22 |   const int N;
23 |   const int V_size;
24 |   const int T;  // makespan lower bound
25 |   DistTable *D;
26 |   const int cost_margin;
27 |   int sum_of_path_length;
28 | 
29 |   // outcome
30 |   std::vector<Path> paths;
31 |   // agent, vertex-id, next vertex
32 |   std::vector<std::unordered_map<int, Vertex *>> scatter_data;
33 | 
34 |   // collision data
35 |   CollisionTable CT;
36 | 
37 |   void construct();
38 | 
39 |   Scatter(const Instance *_ins, DistTable *_D, const Deadline *_deadline,
40 |           const int seed = 0, int _verbose = 0, int _cost_margin = 2);
41 | };
42 | 
--------------------------------------------------------------------------------
/gpt/configurator.py:
--------------------------------------------------------------------------------
1 | import sys
2 | from ast import literal_eval
3 | 
4 | for arg in sys.argv[1:]:
5 |     if '=' not in arg:
6 |         # assume it's the name of a config file
7 |         assert not arg.startswith('--')
8 |         config_file = arg
9 |         print(f"Overriding config with {config_file}:")
10 |         with open(config_file) as f:
11 |             print(f.read())
12 |         exec(open(config_file).read())
13 |     else:
14 |         # assume it's a --key=value argument
15 |         assert arg.startswith('--')
16 |         key, val = arg.split('=')
17 |         key = key[2:]
18 |         if key in globals():
19 |             try:
20 |                 # attempt to eval it (e.g. 
if bool, number, or etc) 21 | attempt = literal_eval(val) 22 | except (SyntaxError, ValueError): 23 | # if that goes wrong, just use the string 24 | attempt = val 25 | # ensure the types match ok 26 | assert type(attempt) == type(globals()[key]) 27 | # cross fingers 28 | print(f"Overriding: {key} = {attempt}") 29 | globals()[key] = attempt 30 | else: 31 | raise ValueError(f"Unknown config key: {key}") 32 | -------------------------------------------------------------------------------- /lacam/lacam3/src/dist_table.cpp: -------------------------------------------------------------------------------- 1 | #include "../include/dist_table.hpp" 2 | 3 | DistTable::DistTable(const Instance &ins) 4 | : K(ins.G->V.size()), table(ins.N, std::vector(K, K)) 5 | { 6 | setup(&ins); 7 | } 8 | 9 | DistTable::DistTable(const Instance *ins) 10 | : K(ins->G->V.size()), table(ins->N, std::vector(K, K)) 11 | { 12 | setup(ins); 13 | } 14 | 15 | void DistTable::setup(const Instance *ins) 16 | { 17 | auto bfs = [&](const int i) { 18 | auto g_i = ins->goals[i]; 19 | auto Q = std::queue({g_i}); 20 | table[i][g_i->id] = 0; 21 | while (!Q.empty()) { 22 | auto n = Q.front(); 23 | Q.pop(); 24 | const int d_n = table[i][n->id]; 25 | for (auto &m : n->neighbor) { 26 | const int d_m = table[i][m->id]; 27 | if (d_n + 1 >= d_m) continue; 28 | table[i][m->id] = d_n + 1; 29 | Q.push(m); 30 | } 31 | } 32 | }; 33 | 34 | auto pool = std::vector>(); 35 | for (size_t i = 0; i < ins->N; ++i) { 36 | pool.emplace_back(std::async(std::launch::async, bfs, i)); 37 | } 38 | } 39 | 40 | int DistTable::get(const int i, const int v_id) { return table[i][v_id]; } 41 | 42 | int DistTable::get(const int i, const Vertex *v) { return get(i, v->id); } 43 | -------------------------------------------------------------------------------- /eval_configs/05-puzzles/maps.yaml: -------------------------------------------------------------------------------- 1 | "puzzle-00": |- 2 | ....# 3 | ..#.# 4 | #.#.. 5 | .###. 6 | ..... 7 | "puzzle-01": |- 8 | ##... 9 | ....# 10 | ...## 11 | #.... 12 | ..#.. 13 | "puzzle-02": |- 14 | .##.. 15 | ..##. 16 | ..... 17 | .##.# 18 | ..... 19 | "puzzle-03": |- 20 | ..... 21 | #..#. 22 | ..... 23 | #..#. 24 | #.... 25 | "puzzle-04": |- 26 | ..... 27 | .#.#. 28 | ##.## 29 | .#.#. 30 | ..... 31 | "puzzle-05": |- 32 | ##.## 33 | ##.## 34 | ..... 35 | .###. 36 | ##### 37 | "puzzle-06": |- 38 | ####. 39 | ###.. 40 | ...#. 41 | .#.#. 42 | .#... 43 | "puzzle-07": |- 44 | ..... 45 | .#.#. 46 | .#.#. 47 | .###. 48 | ....# 49 | "puzzle-08": |- 50 | ..... 51 | .#.#. 52 | ##.## 53 | .#.#. 54 | ..... 55 | "puzzle-09": |- 56 | #.... 57 | ###.. 58 | .###. 59 | ..... 60 | .#### 61 | "puzzle-10": |- 62 | ..#.. 63 | .##.. 64 | ..... 65 | #.### 66 | ..... 67 | "puzzle-11": |- 68 | .###. 69 | ..#.. 70 | #...# 71 | ..#.. 72 | .###. 73 | "puzzle-12": |- 74 | ##.## 75 | ##.## 76 | ..... 77 | ##.## 78 | ##.## 79 | "puzzle-13": |- 80 | .##.# 81 | ..... 82 | #.#.# 83 | ....# 84 | ###.. 85 | "puzzle-14": |- 86 | ..#.. 87 | #...# 88 | ..#.. 89 | #...# 90 | ..#.. 91 | "puzzle-15": |- 92 | ..... 93 | .#.#. 94 | .###. 95 | .###. 96 | ..... 
-------------------------------------------------------------------------------- /docker/dockerfile: -------------------------------------------------------------------------------- 1 | FROM pytorch/pytorch:1.13.1-cuda11.6-cudnn8-runtime 2 | 3 | ADD requirements.txt /tmp/ 4 | 5 | RUN apt update && apt install -y \ 6 | build-essential \ 7 | ffmpeg \ 8 | libsm6 \ 9 | libxext6 \ 10 | git \ 11 | vim \ 12 | cmake \ 13 | g++ \ 14 | wget \ 15 | libboost-all-dev \ 16 | pkg-config \ 17 | && conda install -c conda-forge pybind11 18 | 19 | RUN conda create -n pogema python=3.10 -y \ 20 | && conda run -n pogema pip install --no-cache-dir -r /tmp/requirements.txt \ 21 | && conda run -n pogema pip install --no-cache-dir importlib-metadata==4.13.0 --force-reinstall 22 | 23 | RUN wget -q https://gitlab.com/libeigen/eigen/-/archive/3.3.9/eigen-3.3.9.tar.gz \ 24 | && tar -xzf eigen-3.3.9.tar.gz \ 25 | && mkdir -p eigen-3.3.9/build \ 26 | && cd eigen-3.3.9/build \ 27 | && cmake .. \ 28 | && make install \ 29 | && cd ../.. \ 30 | && rm -rf eigen-3.3.9 eigen-3.3.9.tar.gz 31 | 32 | RUN wget https://github.com/microsoft/onnxruntime/releases/download/v1.14.1/onnxruntime-linux-x64-1.14.1.tgz \ 33 | && tar -xf onnxruntime-linux-x64-1.14.1.tgz \ 34 | && cp onnxruntime-linux-x64-1.14.1/lib/* /usr/lib/ && cp onnxruntime-linux-x64-1.14.1/include/* /usr/include/ 35 | 36 | RUN apt-get clean && rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/* 37 | 38 | ENTRYPOINT ["/bin/bash", "-c", "source /opt/conda/etc/profile.d/conda.sh && conda activate pogema && exec \"$@\"", "--"] 39 | -------------------------------------------------------------------------------- /lacam/lacam3/include/sipp.hpp: -------------------------------------------------------------------------------- 1 | /* 2 | * implementation of SIPP 3 | * 4 | * references: 5 | * Sipp: Safe interval path planning for dynamic environments. 6 | * Mike Phillips and Maxim Likhachev. 7 | * In Proceedings of IEEE International Conference on Robotics and Automation 8 | * (ICRA). 2011. 
9 |  */
10 | 
11 | #pragma once
12 | 
13 | #include "collision_table.hpp"
14 | #include "dist_table.hpp"
15 | #include "graph.hpp"
16 | #include "utils.hpp"
17 | 
18 | // safe interval
19 | using SI = std::pair<int, int>;
20 | using SIs = std::vector<SI>;
21 | 
22 | struct SITable {
23 |   std::unordered_map<Vertex *, SIs> body;
24 |   CollisionTable *CT;
25 | 
26 |   SITable(CollisionTable *_CT);
27 |   ~SITable();
28 |   SIs &get(Vertex *v);
29 | };
30 | 
31 | struct SINode {
32 |   const int uuid;
33 |   const int time_start;
34 |   const int time_end;
35 |   Vertex *v;
36 |   const int t;  // arrival time
37 |   const int g;
38 |   const int f;
39 |   SINode *parent;
40 | 
41 |   SINode(const int uuid, const SI &si, Vertex *_v, int _t, int _g, int _f,
42 |          SINode *_parent);
43 |   bool operator==(const SINode &other) const;
44 | };
45 | using SINodes = std::vector<SINode *>;
46 | 
47 | struct SINodeHasher {
48 |   uint operator()(const SINode &n) const;
49 | };
50 | 
51 | Path sipp(const int i, Vertex *s_i, Vertex *g_i, DistTable *D,
52 |           CollisionTable *CT, const Deadline *deadline = nullptr,
53 |           const int f_upper_bound = INT_MAX);
54 | 
55 | std::ostream &operator<<(std::ostream &os, const SINode *n);
56 | 
--------------------------------------------------------------------------------
/lacam/lacam3/include/graph.hpp:
--------------------------------------------------------------------------------
1 | /*
2 |  * graph definition
3 |  */
4 | #pragma once
5 | #include "utils.hpp"
6 | 
7 | struct Vertex {
8 |   const int id;     // index for V in Graph
9 |   const int index;  // index for U (width * y + x) in Graph
10 |   const int x;
11 |   const int y;
12 |   std::vector<Vertex *> neighbor;
13 | 
14 |   Vertex(int _id, int _index, int _x, int _y);
15 | };
16 | using Vertices = std::vector<Vertex *>;
17 | using Config = std::vector<Vertex *>;  // locations for all agents
18 | using Path = std::vector<Vertex *>;    // path
19 | using Paths = std::vector<Path>;
20 | 
21 | struct Graph {
22 |   Vertices V;  // without nullptr
23 |   Vertices U;  // with nullptr, i.e., |U| = width * height
24 |   int width;   // grid width
25 |   int height;  // grid height
26 |   Graph();
27 |   Graph(const std::string &map_data);  // taking map content
28 |   ~Graph();
29 | 
30 |   int size() const;  // the number of vertices, |V|
31 | };
32 | 
33 | inline int manhattanDist(Vertex *a, Vertex *b)
34 | {
35 |   return std::abs(a->x - b->x) + std::abs(a->y - b->y);
36 | }
37 | 
38 | bool is_same_config(
39 |     const Config &C1,
40 |     const Config &C2);  // check equivalence of two configurations
41 | 
42 | // hash function of configuration
43 | // c.f.
44 | // https://stackoverflow.com/questions/10405030/c-unordered-map-fail-when-used-with-a-vector-as-key
45 | struct ConfigHasher {
46 |   uint operator()(const Config &C) const;
47 | };
48 | 
49 | std::ostream &operator<<(std::ostream &os, const Vertex *v);
50 | std::ostream &operator<<(std::ostream &os, const Config &Q);
51 | std::ostream &operator<<(std::ostream &os, const Paths &paths);
52 | 
--------------------------------------------------------------------------------
/eval_configs/03-warehouse/maps.yaml:
--------------------------------------------------------------------------------
1 | wfi_warehouse: |-
2 | !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
3 | !@@!@@!$$$$$$$$$$!$$$$$$$$$$!$$$$$$$$$$!@@!@@!
4 | !@@!@@!##########!##########!##########!@@!@@!
5 | !@@!@@!$$$$$$$$$$!$$$$$$$$$$!$$$$$$$$$$!@@!@@!
6 | !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
7 | !@@!@@!$$$$$$$$$$!$$$$$$$$$$!$$$$$$$$$$!@@!@@!
8 | !@@!@@!##########!##########!##########!@@!@@!
9 | !@@!@@!$$$$$$$$$$!$$$$$$$$$$!$$$$$$$$$$!@@!@@!
10 | !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! 11 | !@@!@@!$$$$$$$$$$!$$$$$$$$$$!$$$$$$$$$$!@@!@@! 12 | !@@!@@!##########!##########!##########!@@!@@! 13 | !@@!@@!$$$$$$$$$$!$$$$$$$$$$!$$$$$$$$$$!@@!@@! 14 | !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! 15 | !@@!@@!$$$$$$$$$$!$$$$$$$$$$!$$$$$$$$$$!@@!@@! 16 | !@@!@@!##########!##########!##########!@@!@@! 17 | !@@!@@!$$$$$$$$$$!$$$$$$$$$$!$$$$$$$$$$!@@!@@! 18 | !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! 19 | !@@!@@!$$$$$$$$$$!$$$$$$$$$$!$$$$$$$$$$!@@!@@! 20 | !@@!@@!##########!##########!##########!@@!@@! 21 | !@@!@@!$$$$$$$$$$!$$$$$$$$$$!$$$$$$$$$$!@@!@@! 22 | !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! 23 | !@@!@@!$$$$$$$$$$!$$$$$$$$$$!$$$$$$$$$$!@@!@@! 24 | !@@!@@!##########!##########!##########!@@!@@! 25 | !@@!@@!$$$$$$$$$$!$$$$$$$$$$!$$$$$$$$$$!@@!@@! 26 | !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! 27 | !@@!@@!$$$$$$$$$$!$$$$$$$$$$!$$$$$$$$$$!@@!@@! 28 | !@@!@@!##########!##########!##########!@@!@@! 29 | !@@!@@!$$$$$$$$$$!$$$$$$$$$$!$$$$$$$$$$!@@!@@! 30 | !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! 31 | !@@!@@!$$$$$$$$$$!$$$$$$$$$$!$$$$$$$$$$!@@!@@! 32 | !@@!@@!##########!##########!##########!@@!@@! 33 | !@@!@@!$$$$$$$$$$!$$$$$$$$$$!$$$$$$$$$$!@@!@@! 34 | !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! -------------------------------------------------------------------------------- /utils/multi_animation_runner.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | from pathlib import Path 3 | 4 | import yaml 5 | from pogema import BatchAStarAgent 6 | 7 | from pogema_toolbox.create_env import Environment 8 | 9 | from pogema_toolbox.registry import ToolboxRegistry 10 | from pogema_toolbox.run_episode import run_episode 11 | 12 | from create_env import create_eval_env 13 | from gpt.inference import MAPFGPTInference, MAPFGPTInferenceConfig 14 | from utils.svg_utils import create_multi_animation 15 | 16 | 17 | def run_episode_algos_to_svg(env, algos, filename='multi.svg'): 18 | histories = [] 19 | for algo in algos: 20 | run_episode(env, algo) 21 | histories.append(env.decompress_history(env.get_history())) 22 | 23 | obstacles = env.get_obstacles(ignore_borders=False) 24 | grid_config = deepcopy(env.grid.config) 25 | create_multi_animation(obstacles, histories=histories, grid_config=grid_config, name=filename) 26 | 27 | 28 | def main(): 29 | env_cfg = Environment( 30 | observation_type="MAPF", 31 | on_target="nothing", 32 | map_name='validation-random-seed-001', 33 | max_episode_steps=32, 34 | num_agents=32, 35 | seed=42, 36 | obs_radius=5, 37 | collision_system="soft", 38 | with_animation=True 39 | ) 40 | 41 | for maps_file in Path("eval_configs").rglob('maps.yaml'): 42 | with open(maps_file, 'r') as f: 43 | maps = yaml.safe_load(f) 44 | ToolboxRegistry.register_maps(maps) 45 | 46 | env = create_eval_env(env_cfg) 47 | algo = MAPFGPTInference(MAPFGPTInferenceConfig(path_to_weights=f'../weights/model-2M.pt', device='cuda')) 48 | run_episode_algos_to_svg(env, [algo, BatchAStarAgent()], filename='out.svg') 49 | 50 | 51 | if __name__ == "__main__": 52 | main() 53 | -------------------------------------------------------------------------------- /lacam/lacam3/include/pibt.hpp: -------------------------------------------------------------------------------- 1 | /* 2 | * implementation of PIBT 3 | * 4 | * references: 5 | * Priority Inheritance with Backtracking for Iterative Multi-agent Path 6 | * Finding. 
Keisuke Okumura, Manao Machida, Xavier Défago & Yasumasa Tamura. 7 | * Artificial Intelligence (AIJ). 2022. 8 | */ 9 | #pragma once 10 | #include "dist_table.hpp" 11 | #include "graph.hpp" 12 | #include "instance.hpp" 13 | #include "scatter.hpp" 14 | #include "utils.hpp" 15 | 16 | struct PIBT { 17 | const Instance *ins; 18 | std::mt19937 MT; 19 | 20 | // solver utils 21 | const int N; // number of agents 22 | const int V_size; 23 | DistTable *D; 24 | 25 | // specific to PIBT 26 | const int NO_AGENT; 27 | std::vector occupied_now; // for quick collision checking 28 | std::vector occupied_next; // for quick collision checking 29 | std::vector> C_next; // next location candidates 30 | std::vector tie_breakers; // random values, used in PIBT 31 | 32 | // swap, used in the LaCAM* paper 33 | bool flg_swap; 34 | 35 | // scatter 36 | Scatter *scatter; 37 | 38 | PIBT(const Instance *_ins, DistTable *_D, int seed = 0, bool _flg_swap = true, 39 | Scatter *_scatter = nullptr); 40 | ~PIBT(); 41 | 42 | bool set_new_config(const Config &Q_from, Config &Q_to, 43 | const std::vector &order); 44 | bool funcPIBT(const int i, const Config &Q_from, Config &Q_to); 45 | int is_swap_required_and_possible(const int ai, const Config &Q_from, 46 | Config &Q_to); 47 | bool is_swap_required(const int pusher, const int puller, 48 | Vertex *v_pusher_origin, Vertex *v_puller_origin); 49 | bool is_swap_possible(Vertex *v_pusher_origin, Vertex *v_puller_origin); 50 | }; 51 | -------------------------------------------------------------------------------- /benchmark.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import yaml 4 | from pogema_toolbox.create_env import Environment 5 | from pogema_toolbox.eval_utils import initialize_wandb, save_evaluation_results 6 | from pogema_toolbox.evaluator import evaluation 7 | from pogema_toolbox.registry import ToolboxRegistry 8 | 9 | from create_env import create_eval_env 10 | from gpt.inference import MAPFGPTInference, MAPFGPTInferenceConfig 11 | 12 | PROJECT_NAME = "Benchmark" 13 | BASE_PATH = Path("eval_configs") 14 | 15 | 16 | def ensure_weights(eval_config): 17 | for algo_name, algo_cfg in eval_config['algorithms'].items(): 18 | ToolboxRegistry.create_algorithm(algo_cfg['name'], **algo_cfg) 19 | 20 | 21 | def main(disable_wandb=False): 22 | env_cfg_name = "Environment" 23 | ToolboxRegistry.register_env(env_cfg_name, create_eval_env, Environment) 24 | ToolboxRegistry.register_algorithm( 25 | "MAPF-GPT", MAPFGPTInference, MAPFGPTInferenceConfig 26 | ) 27 | 28 | folder_names = [ 29 | "01-random", 30 | "02-mazes", 31 | "03-warehouse", 32 | "04-movingai", 33 | "05-puzzles", 34 | ] 35 | 36 | for folder in folder_names: 37 | maps_path = BASE_PATH / folder / "maps.yaml" 38 | with open(maps_path, "r") as f: 39 | maps = yaml.safe_load(f) 40 | ToolboxRegistry.register_maps(maps) 41 | 42 | config_path = BASE_PATH / folder / f"{Path(folder).name}.yaml" 43 | with open(config_path) as f: 44 | evaluation_config = yaml.safe_load(f) 45 | 46 | # ensuring model weights are downloaded 47 | ensure_weights(evaluation_config) 48 | 49 | eval_dir = BASE_PATH / folder 50 | initialize_wandb(evaluation_config, eval_dir, disable_wandb, PROJECT_NAME) 51 | evaluation(evaluation_config, eval_dir=eval_dir) 52 | save_evaluation_results(eval_dir) 53 | 54 | 55 | if __name__ == "__main__": 56 | main() 57 | -------------------------------------------------------------------------------- /eval_configs/05-puzzles/05-puzzles.yaml: 
-------------------------------------------------------------------------------- 1 | environment: 2 | name: Environment 3 | with_animation: False 4 | on_target: nothing 5 | max_episode_steps: 128 6 | observation_type: MAPF 7 | collision_system: soft 8 | seed: 9 | grid_search: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] 10 | num_agents: 11 | grid_search: [2, 3, 4] 12 | map_name: 13 | grid_search: 14 | [ 15 | puzzle-00, 16 | puzzle-01, 17 | puzzle-02, 18 | puzzle-03, 19 | puzzle-04, 20 | puzzle-05, 21 | puzzle-06, 22 | puzzle-07, 23 | puzzle-08, 24 | puzzle-09, 25 | puzzle-10, 26 | puzzle-11, 27 | puzzle-12, 28 | puzzle-13, 29 | puzzle-14, 30 | puzzle-15, 31 | ] 32 | 33 | algorithms: 34 | MAPF-GPT-2M: 35 | name: MAPF-GPT 36 | parallel_backend: balanced_dask 37 | num_process: 4 38 | path_to_weights: weights/model-2M.pt 39 | MAPF-GPT-6M: 40 | name: MAPF-GPT 41 | parallel_backend: balanced_dask 42 | num_process: 4 43 | path_to_weights: weights/model-6M.pt 44 | 45 | results_views: 46 | TabularView1: 47 | type: tabular 48 | drop_keys: 49 | [ 50 | seed, 51 | map_name, 52 | ISR, 53 | ep_length, 54 | runtime, 55 | avg_agents_density, 56 | makespan, 57 | num_agents, 58 | ] 59 | print_results: True 60 | # 04-puzzles-SoC: 61 | # type: plot 62 | # x: num_agents 63 | # y: SoC 64 | # width: 2.5 65 | # height: 2.5 66 | # line_width: 2 67 | # use_log_scale_x: True 68 | # legend_font_size: 8 69 | # font_size: 8 70 | # name: Puzzles $5\times5$ 71 | # 72 | # 04-puzzles-CSR: 73 | # type: plot 74 | # x: num_agents 75 | # y: CSR 76 | # width: 2.5 77 | # height: 2.5 78 | # line_width: 2 79 | # use_log_scale_x: True 80 | # legend_font_size: 8 81 | # font_size: 8 82 | # name: Puzzles $5\times5$ 83 | -------------------------------------------------------------------------------- /lacam/lacam3/src/utils.cpp: -------------------------------------------------------------------------------- 1 | #include "../include/utils.hpp" 2 | 3 | void info(const int level, const int verbose) { std::cout << std::endl; } 4 | 5 | Deadline::Deadline(double _time_limit_ms) 6 | : t_s(Time::now()), time_limit_ms(_time_limit_ms) 7 | { 8 | } 9 | 10 | double Deadline::elapsed_ms() const 11 | { 12 | return std::chrono::duration_cast(Time::now() - 13 | t_s) 14 | .count(); 15 | } 16 | 17 | double Deadline::elapsed_ns() const 18 | { 19 | return std::chrono::duration_cast(Time::now() - t_s) 20 | .count(); 21 | } 22 | 23 | double elapsed_ms(const Deadline *deadline) 24 | { 25 | if (deadline == nullptr) return 0; 26 | return deadline->elapsed_ms(); 27 | } 28 | 29 | double elapsed_ns(const Deadline *deadline) 30 | { 31 | if (deadline == nullptr) return 0; 32 | return deadline->elapsed_ns(); 33 | } 34 | 35 | bool is_expired(const Deadline *deadline) 36 | { 37 | if (deadline == nullptr) return false; 38 | return deadline->elapsed_ms() > deadline->time_limit_ms; 39 | } 40 | 41 | float get_random_float(std::mt19937 &MT, float from, float to) 42 | { 43 | return std::uniform_real_distribution(from, to)(MT); 44 | } 45 | 46 | float get_random_float(std::mt19937 *MT, float from, float to) 47 | { 48 | return get_random_float(*MT, from, to); 49 | } 50 | 51 | int get_random_int(std::mt19937 &MT, int from, int to) 52 | { 53 | return std::uniform_int_distribution(from, to)(MT); 54 | } 55 | 56 | int get_random_int(std::mt19937 *MT, int from, int to) 57 | { 58 | return get_random_int(*MT, from, to); 59 | } 60 | 61 | std::ostream &operator<<(std::ostream &os, const std::vector &arr) 62 | { 63 | for (auto ele : arr) os << ele << ","; 64 | return os; 65 | } 66 | 67 | std::ostream 
&operator<<(std::ostream &os, const std::set &arr) 68 | { 69 | for (auto ele : arr) os << ele << ","; 70 | return os; 71 | } 72 | -------------------------------------------------------------------------------- /ckpt_configs/02-mazes/02-mazes.yaml: -------------------------------------------------------------------------------- 1 | environment: 2 | name: Environment 3 | with_animation: False 4 | on_target: nothing 5 | max_episode_steps: 128 6 | observation_type: MAPF 7 | collision_system: soft 8 | seed: 0 9 | num_agents: 10 | grid_search: [32, 48, 64] 11 | map_name: 12 | grid_search: 13 | [ 14 | validation-mazes-seed-000, 15 | validation-mazes-seed-001, 16 | validation-mazes-seed-002, 17 | validation-mazes-seed-003, 18 | validation-mazes-seed-004, 19 | validation-mazes-seed-005, 20 | validation-mazes-seed-006, 21 | validation-mazes-seed-007, 22 | validation-mazes-seed-008, 23 | validation-mazes-seed-009, 24 | validation-mazes-seed-010, 25 | validation-mazes-seed-011, 26 | validation-mazes-seed-012, 27 | validation-mazes-seed-013, 28 | validation-mazes-seed-014, 29 | validation-mazes-seed-015, 30 | validation-mazes-seed-016, 31 | validation-mazes-seed-017, 32 | validation-mazes-seed-018, 33 | validation-mazes-seed-019, 34 | validation-mazes-seed-020, 35 | validation-mazes-seed-021, 36 | validation-mazes-seed-022, 37 | validation-mazes-seed-023, 38 | validation-mazes-seed-024, 39 | validation-mazes-seed-025, 40 | validation-mazes-seed-026, 41 | validation-mazes-seed-027, 42 | validation-mazes-seed-028, 43 | validation-mazes-seed-029, 44 | validation-mazes-seed-030, 45 | validation-mazes-seed-031 46 | ] 47 | 48 | algorithms: 49 | algorithms: 50 | MAPF-GPT-DDG-2M-1000: 51 | name: MAPF-GPT 52 | path_to_weights: out/ckpt_ddg_1000.pt 53 | 54 | results_views: 55 | TabularView1: 56 | type: tabular 57 | drop_keys: [seed, map_name] 58 | print_results: True 59 | TabularView2: 60 | type: tabular 61 | drop_keys: [seed, map_name, num_agents] 62 | print_results: True 63 | -------------------------------------------------------------------------------- /ckpt_configs/01-random/01-random.yaml: -------------------------------------------------------------------------------- 1 | environment: 2 | name: Environment 3 | with_animation: False 4 | on_target: nothing 5 | max_episode_steps: 128 6 | observation_type: MAPF 7 | collision_system: soft 8 | seed: 0 9 | num_agents: 10 | grid_search: [32, 48, 64] 11 | map_name: 12 | grid_search: 13 | [ 14 | validation-random-seed-000, 15 | validation-random-seed-001, 16 | validation-random-seed-002, 17 | validation-random-seed-003, 18 | validation-random-seed-004, 19 | validation-random-seed-005, 20 | validation-random-seed-006, 21 | validation-random-seed-007, 22 | validation-random-seed-008, 23 | validation-random-seed-009, 24 | validation-random-seed-010, 25 | validation-random-seed-011, 26 | validation-random-seed-012, 27 | validation-random-seed-013, 28 | validation-random-seed-014, 29 | validation-random-seed-015, 30 | validation-random-seed-016, 31 | validation-random-seed-017, 32 | validation-random-seed-018, 33 | validation-random-seed-019, 34 | validation-random-seed-020, 35 | validation-random-seed-021, 36 | validation-random-seed-022, 37 | validation-random-seed-023, 38 | validation-random-seed-024, 39 | validation-random-seed-025, 40 | validation-random-seed-026, 41 | validation-random-seed-027, 42 | validation-random-seed-028, 43 | validation-random-seed-029, 44 | validation-random-seed-030, 45 | validation-random-seed-031 46 | ] 47 | 48 | algorithms: 49 | 
MAPF-GPT-DDG-2M-1000: 50 | name: MAPF-GPT 51 | path_to_weights: out/ckpt_ddg_1000.pt 52 | 53 | results_views: 54 | TabularView1: 55 | type: tabular 56 | drop_keys: [seed, map_name] 57 | print_results: True 58 | TabularView2: 59 | type: tabular 60 | drop_keys: [seed, map_name, num_agents] 61 | print_results: True 62 | 63 | -------------------------------------------------------------------------------- /utils/wrappers.py: -------------------------------------------------------------------------------- 1 | from gymnasium import Wrapper 2 | 3 | 4 | class UnrollWrapper(Wrapper): 5 | def __init__(self, env): 6 | super().__init__(env) 7 | 8 | self._unroll_steps = None 9 | self._recorded_actions = [] 10 | self._recording_episode = None 11 | 12 | def step(self, action): 13 | if self._recording_episode: 14 | self._recorded_actions.append(action) 15 | return self.env.step(action) 16 | 17 | def get_actions_at_step(self, step): 18 | if step < 0: 19 | return [-1 for _ in range(self.env.num_agents)] 20 | elif step < len(self._recorded_actions): 21 | return self._recorded_actions[step] 22 | else: 23 | raise ValueError(f'Step {step} is out of range') 24 | 25 | def set_unroll_steps(self, num_steps): 26 | self._unroll_steps = num_steps 27 | 28 | def reset(self, seed=None, **kwargs): 29 | self._recording_episode = True if self._recording_episode is None else False 30 | if seed is None: 31 | seed = self.env.grid_config.seed 32 | obs, infos = self.env.reset(seed=seed) 33 | if self.env.grid_config.on_target == "restart": 34 | targets_xy = [o['global_lifelong_targets_xy'] for o in obs] 35 | max_episode_steps = obs[0]['max_episode_steps'] 36 | if self._unroll_steps and self._recorded_actions: 37 | for idx in range(self._unroll_steps): 38 | obs, rew, terminated, truncated, infos = self.env.step(self._recorded_actions[idx]) 39 | if self.env.grid_config.on_target == "restart": 40 | obs[0]['max_episode_steps'] = max_episode_steps 41 | for i in range(len(obs)): 42 | if obs[i]['global_target_xy'] != targets_xy[i][0]: 43 | targets_xy[i] = targets_xy[i][1:] 44 | obs[i]['global_lifelong_targets_xy'] = targets_xy[i] 45 | return obs, infos 46 | -------------------------------------------------------------------------------- /lacam/lacam3/include/utils.hpp: -------------------------------------------------------------------------------- 1 | /* 2 | * utility functions 3 | */ 4 | #pragma once 5 | 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | #include 17 | #include 18 | #include 19 | #include 20 | #include 21 | #include 22 | #include 23 | #include 24 | #include 25 | 26 | using Time = std::chrono::steady_clock; 27 | 28 | // time manager 29 | struct Deadline { 30 | const Time::time_point t_s; 31 | const double time_limit_ms; 32 | 33 | Deadline(double _time_limit_ms = 0); 34 | double elapsed_ms() const; 35 | double elapsed_ns() const; 36 | }; 37 | 38 | double elapsed_ms(const Deadline *deadline); 39 | double elapsed_ns(const Deadline *deadline); 40 | bool is_expired(const Deadline *deadline); 41 | 42 | float get_random_float(std::mt19937 &MT, float from = 0, float to = 1); 43 | float get_random_float(std::mt19937 *MT, float from = 0, float to = 1); 44 | int get_random_int(std::mt19937 &MT, int from = 0, int to = 1); 45 | int get_random_int(std::mt19937 *MT, int from = 0, int to = 1); 46 | 47 | template 48 | void info(const int level, const int verbose, Head &&head, Tail &&...tail); 49 | 50 | void info(const int level, const int verbose); 
51 | 52 | template 53 | void info(const int level, const int verbose, Head &&head, Tail &&...tail) 54 | { 55 | if (verbose < level) return; 56 | std::cout << head; 57 | info(level, verbose, std::forward(tail)...); 58 | } 59 | 60 | template 61 | void info(const int level, const int verbose, const Deadline *deadline, 62 | Body &&...body) 63 | { 64 | if (verbose < level) return; 65 | std::cout << "elapsed:" << std::setw(6) << elapsed_ms(deadline) << "ms "; 66 | info(level, verbose, (body)...); 67 | } 68 | 69 | std::ostream &operator<<(std::ostream &os, const std::vector &arr); 70 | std::ostream &operator<<(std::ostream &os, const std::list &arr); 71 | std::ostream &operator<<(std::ostream &os, const std::set &arr); 72 | -------------------------------------------------------------------------------- /lacam/lacam3/src/hnode.cpp: -------------------------------------------------------------------------------- 1 | #include "../include/hnode.hpp" 2 | 3 | #include 4 | 5 | int HNode::COUNT = 0; 6 | 7 | HNode::HNode(Config _C, DistTable *D, HNode *_parent, int _g, int _h) 8 | : C(_C), 9 | parent(_parent), 10 | neighbor(), 11 | g(_g), 12 | h(_h), 13 | f(g + h), 14 | priorities(C.size(), 0), 15 | order(C.size(), 0), 16 | search_tree(std::queue()) 17 | { 18 | ++COUNT; 19 | 20 | search_tree.push(new LNode()); 21 | const auto N = C.size(); 22 | 23 | // update neighbor 24 | if (parent != nullptr) { 25 | neighbor.insert(parent); 26 | parent->neighbor.insert(this); 27 | } 28 | 29 | // set priorities 30 | if (parent == nullptr) { 31 | // initialize 32 | for (auto i = 0; i < N; ++i) priorities[i] = (float)D->get(i, C[i]) / 10000; 33 | } else { 34 | // dynamic priorities, akin to PIBT 35 | for (auto i = 0; i < N; ++i) { 36 | if (D->get(i, C[i]) != 0) { 37 | priorities[i] = parent->priorities[i] + 1; 38 | } else { 39 | priorities[i] = parent->priorities[i] - (int)parent->priorities[i]; 40 | } 41 | } 42 | } 43 | 44 | // set order 45 | std::iota(order.begin(), order.end(), 0); 46 | std::sort(order.begin(), order.end(), 47 | [&](int i, int j) { return priorities[i] > priorities[j]; }); 48 | } 49 | 50 | HNode::~HNode() 51 | { 52 | while (!search_tree.empty()) { 53 | delete search_tree.front(); 54 | search_tree.pop(); 55 | } 56 | } 57 | 58 | LNode *HNode::get_next_lowlevel_node(std::mt19937 &MT) 59 | { 60 | if (search_tree.empty()) return nullptr; 61 | 62 | auto L = search_tree.front(); 63 | search_tree.pop(); 64 | if (L->depth < C.size()) { 65 | auto i = order[L->depth]; 66 | auto cands = C[i]->neighbor; 67 | cands.push_back(C[i]); 68 | std::shuffle(cands.begin(), cands.end(), MT); // randomize 69 | for (auto u : cands) search_tree.push(new LNode(L, i, u)); 70 | } 71 | return L; 72 | } 73 | 74 | std::ostream &operator<<(std::ostream &os, const HNode *H) 75 | { 76 | os << "f=" << std::setw(6) << H->f << "\tg=" << std::setw(6) << H->g 77 | << "\th=" << std::setw(6) << H->h << "\tQ=" << H->C; 78 | return os; 79 | } 80 | 81 | bool CompareHNodePointers::operator()(const HNode *l, const HNode *r) const 82 | { 83 | const auto N = l->C.size(); 84 | for (auto i = 0; i < N; ++i) { 85 | if (l->C[i] != r->C[i]) return l->C[i]->id < r->C[i]->id; 86 | } 87 | return false; 88 | } 89 | -------------------------------------------------------------------------------- /gpt/aggregated_data_loader.py: -------------------------------------------------------------------------------- 1 | from loguru import logger 2 | import torch 3 | from torch.utils.data import Dataset 4 | 5 | from gpt.fast_data_loader import MapfArrowDataset 6 | 7 | 8 | 
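# Note on batch composition: each combined batch concatenates a fixed-size
# sub-batch from every folder in folder_paths, so the relative values in
# batch_sizes set the mixing ratio of the sources. For instance, pairing
# "dataset/train" with DDG-collected data at sizes (3N, N) would presumably
# correspond to the max_ratio = 0.25 cap used in gpt/config-2M-DDG.py.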
class AggregatedMapfArrowDataset(Dataset): 9 | def __init__(self, folder_paths, device, batch_sizes): 10 | """ 11 | Aggregates datasets from multiple folders into a single dataset. 12 | 13 | Args: 14 | folder_paths (list): List of folder paths containing datasets. 15 | device (str): Device to load the data onto (e.g., 'cuda:0'). 16 | batch_sizes (list): List of batch sizes for each dataset. 17 | """ 18 | assert len(folder_paths) == len(batch_sizes), "Each dataset must have a corresponding batch size." 19 | 20 | self.datasets = [] 21 | self.batch_sizes = batch_sizes 22 | 23 | for folder_path, batch_size in zip(folder_paths, batch_sizes): 24 | dataset = MapfArrowDataset(folder_path, device, batch_size) 25 | self.datasets.append(dataset) 26 | 27 | self.device = device 28 | 29 | def __iter__(self): 30 | """ 31 | Creates an iterator that yields data batches from the aggregated datasets with specified batch sizes. 32 | """ 33 | dataset_iters = [iter(dataset) for dataset in self.datasets] 34 | 35 | while True: 36 | batch_inputs, batch_targets = zip(*[next(dataset_iter) for dataset_iter in dataset_iters]) 37 | yield torch.cat(batch_inputs, dim=0), torch.cat(batch_targets, dim=0) 38 | 39 | def get_full_dataset_size(self): 40 | return sum(dataset.get_full_dataset_size() for dataset in self.datasets) 41 | 42 | def get_shard_size(self): 43 | return sum(dataset.get_shard_size() for dataset in self.datasets) 44 | 45 | 46 | def main(): 47 | folder_paths = ["../dataset/train", "../dataset/validation"] 48 | # folder_paths = ["../dataset/train", "../dagger"] 49 | batch_sizes = [24, 8] # Exact batch sizes for train and validation datasets 50 | aggregated_dataset = AggregatedMapfArrowDataset(folder_paths, device='cuda:0', batch_sizes=batch_sizes) 51 | data = iter(aggregated_dataset) 52 | 53 | logger.info(aggregated_dataset.get_full_dataset_size()) 54 | logger.info(aggregated_dataset.get_shard_size()) 55 | 56 | x = 0 57 | while True: 58 | x += 1 59 | qx, qy = next(data) 60 | # logger.info(str(qx.shape) + ' ' + str(qy.shape)) 61 | 62 | 63 | if __name__ == "__main__": 64 | main() 65 | -------------------------------------------------------------------------------- /utils/data_collection.py: -------------------------------------------------------------------------------- 1 | # noinspection PyUnresolvedReferences 2 | from gpt.observation_generator import ObservationGenerator, InputParameters 3 | from pogema_toolbox.registry import ToolboxRegistry 4 | 5 | 6 | def fill_actions_with_solver(env, start_step, steps_to_collect, chosen_agents, expert_algo=None): 7 | if expert_algo is not None: 8 | expert_algo.reset_states() 9 | observations, *_ = env.reset() 10 | observation_generator = ObservationGenerator(observations[0]["global_obstacles"].copy().astype(int).tolist(), 11 | InputParameters(20, 13, 5, 256, 5, 5, 64, False)) 12 | positions = [obs["global_xy"] for obs in observations] 13 | goals = [obs["global_target_xy"] for obs in observations] 14 | observation_generator.create_agents(positions, goals) 15 | for i in range(5): 16 | observation_generator.update_agents(positions, goals, env.get_actions_at_step(start_step - 5 + i)) 17 | inputs = [] 18 | gt_actions = [] 19 | for i in range(steps_to_collect): 20 | input = observation_generator.generate_observations() 21 | if expert_algo is not None: 22 | actions = expert_algo.act(observations) 23 | if not expert_algo.solved: 24 | return None, None, {'ISR': 0.0, 'CSR': 0.0, 'ep_length': 256, 'SoC': -1, 'makespan': 256, 'runtime': 10} # placeholder metrics if expert algo is 
failed 25 | else: 26 | actions = env.get_actions_at_step(start_step + i) # if no expert algo => use the actions from MAPF-GPT 27 | for agent_idx in chosen_agents: 28 | inputs.append(input[agent_idx]) 29 | gt_actions.append(actions[agent_idx]) 30 | observations, rew, terminated, truncated, infos = env.step(actions) 31 | if all(terminated) or all(truncated): 32 | break 33 | 34 | positions = [obs["global_xy"] for obs in observations] 35 | goals = [obs["global_target_xy"] for obs in observations] 36 | observation_generator.update_agents(positions, goals, actions) 37 | ToolboxRegistry.debug(f'Tagged {len(inputs)} steps with expert data starting from step {start_step}') 38 | if expert_algo is not None and expert_algo.cfg.name == "LaCAM" and not (all(terminated) or all(truncated)): 39 | while True: 40 | input = observation_generator.generate_observations() 41 | actions = expert_algo.act(observations) 42 | observations, rew, terminated, truncated, infos = env.step(actions) 43 | if all(terminated) or all(truncated): 44 | break 45 | return inputs, gt_actions, infos[0]['metrics'] -------------------------------------------------------------------------------- /lacam/lacam3/src/refiner.cpp: -------------------------------------------------------------------------------- 1 | #include "../include/refiner.hpp" 2 | 3 | Solution refine(const Instance *ins, const Deadline *deadline, 4 | const Solution &solution, DistTable *D, const int seed, 5 | const int verbose) 6 | { 7 | if (solution.empty()) return Solution(); 8 | info(0, verbose, deadline, "refiner-", seed, "\tactivated"); 9 | // setup 10 | const auto N = ins->N; 11 | auto MT = std::mt19937(seed); 12 | auto paths = translateConfigsToPaths(solution); 13 | auto cost_before = get_sum_of_loss_paths(paths); 14 | std::vector order(N, 0); 15 | std::iota(order.begin(), order.end(), 0); 16 | auto CT = CollisionTable(ins); 17 | for (auto i = 0; i < N; ++i) CT.enrollPath(i, paths[i]); 18 | std::shuffle(order.begin(), order.end(), MT); 19 | 20 | const auto num_refine_agents = 21 | std::max(1, std::min(get_random_int(MT, 1, 30), int(N / 4))); 22 | info(1, verbose, deadline, "refiner-", seed, 23 | "\tsize of modif set: ", num_refine_agents); 24 | for (auto k = 0; (k + 1) * num_refine_agents < N; ++k) { 25 | if (is_expired(deadline)) return Solution(); 26 | 27 | auto old_cost = 0; 28 | auto new_cost = 0; 29 | 30 | // compute old cost 31 | for (auto _i = 0; _i < num_refine_agents; ++_i) { 32 | const auto i = order[k * num_refine_agents + _i]; 33 | old_cost += get_path_loss(paths[i]); 34 | CT.clearPath(i, paths[i]); 35 | } 36 | 37 | // re-planning 38 | Paths new_paths(num_refine_agents); 39 | for (auto _i = 0; _i < num_refine_agents; ++_i) { 40 | const auto i = order[k * num_refine_agents + _i]; 41 | // note: I also tested A*, but SIPP was better 42 | new_paths[_i] = sipp(i, ins->starts[i], ins->goals[i], D, &CT, deadline, 43 | old_cost - new_cost - 1); 44 | if (new_paths[_i].empty()) break; // failure 45 | new_cost += get_path_loss(new_paths[_i]); 46 | CT.enrollPath(i, new_paths[_i]); 47 | } 48 | 49 | if (!new_paths[num_refine_agents - 1].empty() && new_cost <= old_cost) { 50 | // success 51 | for (auto _i = 0; _i < num_refine_agents; ++_i) { 52 | const auto i = order[k * num_refine_agents + _i]; 53 | paths[i] = new_paths[_i]; 54 | } 55 | } else { 56 | // failure 57 | for (auto _i = 0; _i < num_refine_agents; ++_i) { 58 | const auto i = order[k * num_refine_agents + _i]; 59 | if (!new_paths[_i].empty()) CT.clearPath(i, new_paths[_i]); 60 | CT.enrollPath(i, paths[i]); 
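// the collision table now again tracks agent i's original path, so the next refinement round starts from a consistent state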
61 | } 62 | } 63 | } 64 | 65 | info(0, verbose, deadline, "refiner-", seed, "\tsum_of_loss: ", cost_before, 66 | " -> ", get_sum_of_loss_paths(paths)); 67 | 68 | return translatePathsToConfigs(paths); 69 | } 70 | -------------------------------------------------------------------------------- /dagger.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pogema_toolbox.registry import ToolboxRegistry 3 | import os 4 | import subprocess 5 | from concurrent.futures import ThreadPoolExecutor 6 | import glob 7 | 8 | def run_dagger(dagger_type, num_workers, device_id, seed, file_size, run_infinite=False): 9 | def run_worker(worker_id): 10 | if 'warehouse' in dagger_type: 11 | subprocess.run(f"python worker.py --worker_id {worker_id} --device_id {device_id} --map_seed {0} --scenario_seed {seed + worker_id * 1000} --dataset_path {path_to_dataset} --dagger_type {dagger_type} --path_to_weights {path_to_weights} --num_agents {','.join(map(str, [32, 64, 96, 128, 160, 192]))} --file_size {file_size}", shell=True) 12 | else: 13 | subprocess.run(f"python worker.py --worker_id {worker_id} --device_id {device_id} --map_seed {seed + worker_id * 1000} --scenario_seed {0} --dataset_path {path_to_dataset} --dagger_type {dagger_type} --path_to_weights {path_to_weights} --file_size {file_size}", shell=True) 14 | 15 | path_to_weights = f"out/ckpt_{dagger_type}.pt" 16 | path_to_dataset = f"dataset/{dagger_type}" 17 | ToolboxRegistry.setup_logger('INFO') 18 | while True: 19 | checkpoint_files = glob.glob(f"out/ckpt_{dagger_type}_*0.pt") 20 | if dagger_type == 'dagger' or dagger_type == 'ddg': 21 | checkpoint_files = [f for f in checkpoint_files if 'warehouse' not in f] 22 | if checkpoint_files: 23 | path_to_weights = max(checkpoint_files, key=lambda x: int(x.split('_')[-1].split('.')[0])) 24 | with ThreadPoolExecutor(max_workers=num_workers) as executor: 25 | executor.map(run_worker, range(num_workers)) 26 | seed += 1000*num_workers 27 | if not run_infinite: 28 | break 29 | 30 | def main(): 31 | parser = argparse.ArgumentParser(description='Dagger') 32 | parser.add_argument('--num_workers', type=int, default=8, help='Number of workers (default: %(default)d)') 33 | parser.add_argument('--dagger_type', type=str, choices=['ddg', 'ddg_warehouse', 'dagger', 'dagger_warehouse'], default='ddg', help='Dagger type (default: %(default)s)') 34 | parser.add_argument('--device_id', type=int, default=0, help='Device ID (default: %(default)d)') 35 | parser.add_argument('--seed', type=int, default=0, help='Seed (default: %(default)d)') 36 | parser.add_argument('--file_size', type=int, default=50 * 2 ** 11, help='File size (default: %(default)d)') 37 | parser.add_argument('--run_infinite', action='store_true', help='Run infinite (default: %(default)d)') 38 | args = parser.parse_args() 39 | run_dagger(args.dagger_type, args.num_workers, args.device_id, args.seed, args.file_size, run_infinite=args.run_infinite) 40 | 41 | if __name__ == '__main__': 42 | main() 43 | -------------------------------------------------------------------------------- /lacam/lacam3/src/collision_table.cpp: -------------------------------------------------------------------------------- 1 | #include "../include/collision_table.hpp" 2 | 3 | CollisionTable::CollisionTable(const Instance *ins) 4 | : body(ins->G->size()), 5 | body_last(ins->G->size()), 6 | collision_cnt(0), 7 | N(ins->N) 8 | { 9 | } 10 | 11 | CollisionTable::~CollisionTable() {} 12 | 13 | int CollisionTable::getCollisionCost(const Vertex 
*v_from, const Vertex *v_to, 14 | const int t_from) 15 | { 16 | const int t_to = t_from + 1; 17 | auto collision = 0; 18 | // vertex collision 19 | if (t_to < body[v_to->id].size()) { 20 | collision += body[v_to->id][t_to].size(); 21 | } 22 | // edge collision 23 | if (t_to < body[v_from->id].size() && t_from < body[v_to->id].size()) { 24 | for (auto j : body[v_from->id][t_to]) { 25 | for (auto k : body[v_to->id][t_from]) { 26 | if (j == k) ++collision; 27 | } 28 | } 29 | } 30 | // goal collision 31 | for (auto last_timestep : body_last[v_to->id]) { 32 | if (t_to > last_timestep) ++collision; 33 | } 34 | return collision; 35 | } 36 | 37 | void CollisionTable::enrollPath(const int i, Path &path) 38 | { 39 | if (path.empty()) return; 40 | const auto T_i = path.size() - 1; 41 | for (auto t = 0; t <= T_i; ++t) { 42 | auto v = path[t]; 43 | 44 | // update collision count 45 | if (t > 0) collision_cnt += getCollisionCost(path[t - 1], path[t], t - 1); 46 | 47 | // register 48 | while (body[v->id].size() <= t) body[v->id].emplace_back(); 49 | body[v->id][t].push_back(i); 50 | } 51 | 52 | // goal 53 | body_last[path.back()->id].push_back(T_i); 54 | auto &&entry = body[path.back()->id]; 55 | for (auto t = T_i + 1; t < entry.size(); ++t) { 56 | collision_cnt += entry[t].size(); 57 | } 58 | } 59 | 60 | void CollisionTable::clearPath(const int i, Path &path) 61 | { 62 | if (path.empty()) return; 63 | const auto T_i = (int)path.size() - 1; 64 | for (auto t = 0; t <= T_i; ++t) { 65 | auto v = path[t]; 66 | auto &&entry = body[v->id][t]; 67 | 68 | // remove entry 69 | for (auto itr = entry.begin(); itr != entry.end();) { 70 | if (*itr == i) { 71 | entry.erase(itr); 72 | break; 73 | } else { 74 | ++itr; 75 | } 76 | } 77 | 78 | // update collision count 79 | if (t > 0) collision_cnt -= getCollisionCost(path[t - 1], path[t], t - 1); 80 | } 81 | 82 | // goal 83 | auto &&entry_body_last = body_last[path.back()->id]; 84 | for (auto itr = entry_body_last.begin(); itr != entry_body_last.end();) { 85 | if (*itr == T_i) { 86 | entry_body_last.erase(itr); 87 | break; 88 | } else { 89 | ++itr; 90 | } 91 | } 92 | auto &&entry_body = body[path.back()->id]; 93 | for (auto t = T_i + 1; t < entry_body.size(); ++t) { 94 | collision_cnt -= entry_body[t].size(); 95 | } 96 | } 97 | -------------------------------------------------------------------------------- /example.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | 4 | import torch 5 | import yaml 6 | from pogema_toolbox.create_env import Environment 7 | from pogema_toolbox.run_episode import run_episode 8 | from pogema_toolbox.registry import ToolboxRegistry 9 | 10 | from create_env import create_eval_env 11 | from gpt.inference import MAPFGPTInference, MAPFGPTInferenceConfig 12 | 13 | 14 | def main(): 15 | parser = argparse.ArgumentParser(description='MAPF-GPT Inference Script') 16 | parser.add_argument('--animation', action='store_false', help='Enable animation (default: %(default)s)') 17 | parser.add_argument('--num_agents', type=int, default=32, help='Number of agents (default: %(default)d)') 18 | parser.add_argument('--seed', type=int, default=0, help='Random seed (default: %(default)d)') 19 | parser.add_argument('--map_name', type=str, default='validation-random-seed-001', help='Map name (default: %(default)s)') 20 | parser.add_argument('--device', type=str, default='cuda', help='Device to use: cuda, cpu, mps (default: %(default)s)') 21 | 
parser.add_argument('--max_episode_steps', type=int, default=128, 22 | help='Maximum episode steps (default: %(default)d)') 23 | parser.add_argument('--show_map_names', action='store_true', help='Shows names of all available maps') 24 | 25 | parser.add_argument('--model', type=str, choices=['2M', '6M', '85M', '2M-DDG'], default='2M-DDG', 26 | help='Model to use: 2M, 6M, 85M, 2M-DDG (default: %(default)s)') 27 | 28 | # loading maps from eval folders 29 | for maps_file in Path("eval_configs").rglob('maps.yaml'): 30 | with open(maps_file, 'r') as f: 31 | maps = yaml.safe_load(f) 32 | ToolboxRegistry.register_maps(maps) 33 | 34 | args = parser.parse_args() 35 | 36 | if args.show_map_names: 37 | for map_ in ToolboxRegistry.get_maps(): 38 | print(map_) 39 | return 40 | 41 | env_cfg = Environment( 42 | with_animation=args.animation, 43 | observation_type="MAPF", 44 | on_target="nothing", 45 | map_name=args.map_name, 46 | max_episode_steps=args.max_episode_steps, 47 | num_agents=args.num_agents, 48 | seed=args.seed, 49 | obs_radius=5, 50 | collision_system="soft", 51 | ) 52 | 53 | # pytorch seeding 54 | torch_seed = 42 55 | torch.manual_seed(torch_seed) 56 | if torch.cuda.is_available(): 57 | torch.cuda.manual_seed(torch_seed) 58 | torch.backends.mps.is_available() 59 | torch.backends.cudnn.deterministic = True 60 | 61 | env = create_eval_env(env_cfg) 62 | algo = MAPFGPTInference(MAPFGPTInferenceConfig(path_to_weights=f'hf_weights/model-{args.model}.pt', device=args.device)) 63 | algo.reset_states() 64 | results = run_episode(env, algo) 65 | 66 | svg_path = f"svg/{args.map_name}-{args.model}-seed-{args.seed}.svg" 67 | env.save_animation(svg_path) 68 | ToolboxRegistry.info(f'Saved animation to: {svg_path}') 69 | 70 | ToolboxRegistry.success(results) 71 | 72 | 73 | if __name__ == "__main__": 74 | main() -------------------------------------------------------------------------------- /lacam/lacam3/src/metrics.cpp: -------------------------------------------------------------------------------- 1 | #include "../include/metrics.hpp" 2 | 3 | int get_makespan(const Solution &solution) 4 | { 5 | if (solution.empty()) return 0; 6 | return solution.size() - 1; 7 | } 8 | 9 | int get_makespan_paths(const std::vector &solution) 10 | { 11 | auto c = 0; 12 | for (auto &&path : solution) { 13 | c = std::max(c, (int)path.size() - 1); 14 | } 15 | return c; 16 | } 17 | 18 | int get_path_cost(const Solution &solution, int i) 19 | { 20 | const auto makespan = solution.size(); 21 | const auto g = solution.back()[i]; 22 | auto c = makespan; 23 | while (c > 0 && solution[c - 1][i] == g) --c; 24 | return c; 25 | } 26 | 27 | int get_path_cost(const Path &path) 28 | { 29 | const auto g = path.back(); 30 | auto c = path.size(); 31 | while (c > 0 && path[c - 1] == g) --c; 32 | return c; 33 | } 34 | 35 | int get_sum_of_costs(const Solution &solution) 36 | { 37 | if (solution.empty()) return 0; 38 | int c = 0; 39 | const auto N = solution.front().size(); 40 | for (size_t i = 0; i < N; ++i) c += get_path_cost(solution, i); 41 | return c; 42 | } 43 | 44 | int get_sum_of_costs_paths(const std::vector &solution) 45 | { 46 | int c = 0; 47 | for (auto &&path : solution) c += get_path_cost(path); 48 | return c; 49 | } 50 | 51 | int get_path_loss(const Path &path) 52 | { 53 | const auto g = path.back(); 54 | const auto T = path.size(); 55 | auto c = 0; 56 | for (size_t t = 1; t < T; ++t) { 57 | if (path[t - 1] != g || path[t] != g) ++c; 58 | } 59 | return c; 60 | } 61 | 62 | int get_sum_of_loss(const Solution &solution, std::vector 
&agents_subset) 63 | { 64 | if (solution.empty()) return 0; 65 | int c = 0; 66 | const auto T = solution.size(); 67 | for (const auto i : agents_subset) { 68 | auto g = solution.back()[i]; 69 | for (size_t t = 1; t < T; ++t) { 70 | if (solution[t - 1][i] != g || solution[t][i] != g) ++c; 71 | } 72 | } 73 | return c; 74 | } 75 | 76 | int get_sum_of_loss(const Solution &solution) 77 | { 78 | if (solution.empty()) return 0; 79 | int c = 0; 80 | const auto N = solution.front().size(); 81 | const auto T = solution.size(); 82 | for (size_t i = 0; i < N; ++i) { 83 | auto g = solution.back()[i]; 84 | for (size_t t = 1; t < T; ++t) { 85 | if (solution[t - 1][i] != g || solution[t][i] != g) ++c; 86 | } 87 | } 88 | return c; 89 | } 90 | 91 | int get_sum_of_loss_paths(const std::vector &solution) 92 | { 93 | auto c = 0; 94 | for (auto &&path : solution) c += get_path_loss(path); 95 | return c; 96 | } 97 | 98 | int get_makespan_lower_bound(const Instance &ins, DistTable &dist_table) 99 | { 100 | int c = 0; 101 | for (size_t i = 0; i < ins.N; ++i) { 102 | c = std::max(c, dist_table.get(i, ins.starts[i])); 103 | } 104 | return c; 105 | } 106 | 107 | int get_sum_of_costs_lower_bound(const Instance &ins, DistTable &dist_table) 108 | { 109 | int c = 0; 110 | for (size_t i = 0; i < ins.N; ++i) { 111 | c += dist_table.get(i, ins.starts[i]); 112 | } 113 | return c; 114 | } 115 | -------------------------------------------------------------------------------- /million_agents_run.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import random 3 | import time 4 | from tqdm.auto import tqdm 5 | 6 | from pogema_toolbox.create_env import Environment 7 | 8 | from create_env import create_eval_env 9 | from gpt.inference import MAPFGPTInference, MAPFGPTInferenceConfig 10 | from pogema_toolbox.results_holder import ResultsHolder 11 | 12 | def run_episode(env, algo): 13 | """ 14 | Runs an episode in the environment using the given algorithm. 15 | 16 | Args: 17 | env: The environment to run the episode in. 18 | algo: The algorithm used for action selection. 19 | 20 | Returns: 21 | ResultsHolder: Object containing the results of the episode. 
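        Example (illustrative sketch, mirroring the call in main() below):
            results = run_episode(env, algo)
            print(results["runtime"])  # summary dict returned via ResultsHolder.get_final()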
22 | """ 23 | algo.reset_states() 24 | results_holder = ResultsHolder() 25 | 26 | obs, _ = env.reset() 27 | for _ in tqdm(range(env.grid.config.max_episode_steps), desc="Running episode"): 28 | obs, rew, terminated, truncated, infos = env.step(algo.act(obs)) 29 | results_holder.after_step(infos) 30 | 31 | if all(terminated) or all(truncated): 32 | break 33 | return results_holder.get_final() 34 | 35 | def main(): 36 | random.seed(42) 37 | start_xy = set() 38 | size = 2048 39 | while len(start_xy) < 2**20: 40 | start_xy.add((random.randint(0, size-1), random.randint(0, size-1))) 41 | start_xy = list(start_xy) 42 | 43 | delta = 64 # max distance between start and goal 44 | goal_xy = [] 45 | used_goals = set() # Track already used goal positions 46 | 47 | for sx, sy in start_xy: 48 | while True: 49 | dx = random.randint(-delta, delta) 50 | dy = random.randint(-(delta-abs(dx)), delta-abs(dx)) 51 | 52 | gx, gy = sx + dx, sy + dy 53 | if 0 <= gx < size and 0 <= gy < size and (gx, gy) not in used_goals: 54 | goal_xy.append([gx, gy]) 55 | used_goals.add((gx, gy)) # Mark this goal as used 56 | break 57 | 58 | for s in [32,64,128,256,512,1024,2048,4096,8192,16384,32768,65536,131072,262144,524288,1048576]: 59 | start_time = time.time() 60 | env_cfg = Environment( 61 | with_animation=False, 62 | observation_type="MAPF", 63 | on_target="nothing", 64 | size=size, 65 | density=0, 66 | agents_xy=start_xy[:s], 67 | targets_xy=goal_xy[:s], 68 | max_episode_steps=512, 69 | seed=0, 70 | obs_radius=5, 71 | collision_system="soft", 72 | ) 73 | 74 | env = create_eval_env(env_cfg) 75 | create_time = time.time() 76 | algo = MAPFGPTInference(MAPFGPTInferenceConfig(path_to_weights=f'hf_weights/model-2M-DDG.pt', device='cuda')) 77 | algo.reset_states() 78 | results = run_episode(env, algo) 79 | end_time = time.time() 80 | env_time = end_time - create_time - results['runtime'] 81 | total_time = end_time - start_time 82 | results['env_time'] = env_time 83 | results['total_time'] = total_time 84 | print(s, results) #, "Time to create env: ", create_time-start_time, "Time to run env: ", env_time, "Total time: ", total_time) 85 | 86 | 87 | 88 | if __name__ == "__main__": 89 | main() 90 | -------------------------------------------------------------------------------- /finetuning/data_aggregation_generator.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | from multiprocessing import Pool 3 | 4 | from gpt.inference import MAPFGPTInference, MAPFGPTInferenceConfig 5 | from lacam.inference import LacamInference, LacamInferenceConfig 6 | #from rhcr_cpp.rhcr import RHCRInference, RHCRConfig 7 | from pogema_toolbox.run_episode import run_episode 8 | from pogema_toolbox.registry import ToolboxRegistry 9 | from pydantic import BaseModel 10 | 11 | from finetuning.filter_data import filter_data 12 | 13 | from utils.data_collection import fill_actions_with_solver 14 | from finetuning.scenario_generators import make_pogema_maze_instance 15 | 16 | from utils.wrappers import UnrollWrapper 17 | from pogema.wrappers.metrics import RuntimeMetricWrapper 18 | 19 | 20 | class DataAggregationConfig(BaseModel): 21 | steps_delta: int = 8 22 | steps_saved: int = 8 23 | 24 | def run_solver(env, unroll_steps, steps_saved): 25 | env = deepcopy(env) 26 | solver = LacamInference(LacamInferenceConfig(time_limit=10, timeouts=[10])) 27 | env.set_unroll_steps(unroll_steps) 28 | chosen_agents = list(range(env.grid_config.num_agents)) 29 | ToolboxRegistry.debug(f'Collecting data from step 
{unroll_steps} to {unroll_steps + steps_saved}') 30 | input, gt_action, metrics = fill_actions_with_solver(env, unroll_steps, steps_saved, chosen_agents, solver) 31 | return input, gt_action, metrics 32 | 33 | def data_aggregation(env, learnable_algo, cfg: DataAggregationConfig): 34 | env = UnrollWrapper(env) 35 | env = RuntimeMetricWrapper(env) 36 | 37 | inputs = [] 38 | gt_actions = [] 39 | gpt_results = run_episode(env, learnable_algo) 40 | logs = {'map_name': env.grid.config.map_name, 'gpt_results': gpt_results, 'expert_results': []} 41 | episode_length = gpt_results['ep_length'] 42 | 43 | with Pool(processes=8) as pool: # limit the number of workers to avoid overloading the CPU 44 | unroll_steps_list = range(0, episode_length, cfg.steps_delta) 45 | results = pool.starmap(run_solver, [(env, unroll_steps, cfg.steps_saved, cfg.on_target) for unroll_steps in unroll_steps_list]) 46 | 47 | for input, gt_action, metrics in results: 48 | if input is not None: 49 | inputs.extend(input) 50 | gt_actions.extend(gt_action) 51 | if metrics is not None: 52 | logs['expert_results'].append(metrics) 53 | return filter_data(inputs, gt_actions), logs 54 | 55 | 56 | def main(): 57 | ToolboxRegistry.setup_logger('DEBUG') 58 | 59 | learnable_algo = MAPFGPTInference(MAPFGPTInferenceConfig(device='cuda', path_to_weights='../weights/model-2M.pt')) 60 | slow_time_limit = 10 61 | lacam_lib_path = "../lacam/liblacam.so" 62 | solver = LacamInference( 63 | LacamInferenceConfig(time_limit=slow_time_limit, timeouts=[slow_time_limit], lacam_lib_path=lacam_lib_path)) 64 | 65 | env = make_pogema_maze_instance(num_agents=32, 66 | max_episode_steps=128, 67 | map_seed=45, 68 | scenario_seed=45) 69 | 70 | data_aggregation(env=env, learnable_algo=learnable_algo, solver=solver, 71 | cfg=DataAggregationConfig()) 72 | 73 | 74 | if __name__ == '__main__': 75 | main() 76 | -------------------------------------------------------------------------------- /eval_configs/03-warehouse/03-warehouse.yaml: -------------------------------------------------------------------------------- 1 | environment: 2 | name: Environment 3 | with_animation: False 4 | on_target: nothing 5 | max_episode_steps: 128 6 | observation_type: MAPF 7 | collision_system: soft 8 | seed: 9 | grid_search: 10 | [ 11 | 0, 12 | 1, 13 | 2, 14 | 3, 15 | 4, 16 | 5, 17 | 6, 18 | 7, 19 | 8, 20 | 9, 21 | 10, 22 | 11, 23 | 12, 24 | 13, 25 | 14, 26 | 15, 27 | 16, 28 | 17, 29 | 18, 30 | 19, 31 | 20, 32 | 21, 33 | 22, 34 | 23, 35 | 24, 36 | 25, 37 | 26, 38 | 27, 39 | 28, 40 | 29, 41 | 30, 42 | 31, 43 | 32, 44 | 33, 45 | 34, 46 | 35, 47 | 36, 48 | 37, 49 | 38, 50 | 39, 51 | 40, 52 | 41, 53 | 42, 54 | 43, 55 | 44, 56 | 45, 57 | 46, 58 | 47, 59 | 48, 60 | 49, 61 | 50, 62 | 51, 63 | 52, 64 | 53, 65 | 54, 66 | 55, 67 | 56, 68 | 57, 69 | 58, 70 | 59, 71 | 60, 72 | 61, 73 | 62, 74 | 63, 75 | 64, 76 | 65, 77 | 66, 78 | 67, 79 | 68, 80 | 69, 81 | 70, 82 | 71, 83 | 72, 84 | 73, 85 | 74, 86 | 75, 87 | 76, 88 | 77, 89 | 78, 90 | 79, 91 | 80, 92 | 81, 93 | 82, 94 | 83, 95 | 84, 96 | 85, 97 | 86, 98 | 87, 99 | 88, 100 | 89, 101 | 90, 102 | 91, 103 | 92, 104 | 93, 105 | 94, 106 | 95, 107 | 96, 108 | 97, 109 | 98, 110 | 99, 111 | 100, 112 | 101, 113 | 102, 114 | 103, 115 | 104, 116 | 105, 117 | 106, 118 | 107, 119 | 108, 120 | 109, 121 | 110, 122 | 111, 123 | 112, 124 | 113, 125 | 114, 126 | 115, 127 | 116, 128 | 117, 129 | 118, 130 | 119, 131 | 120, 132 | 121, 133 | 122, 134 | 123, 135 | 124, 136 | 125, 137 | 126, 138 | 127, 139 | ] 140 | num_agents: 141 | grid_search: [32, 64, 96, 128, 160, 
192] 142 | map_name: wfi_warehouse 143 | 144 | algorithms: 145 | MAPF-GPT-2M: 146 | name: MAPF-GPT 147 | parallel_backend: balanced_dask 148 | num_process: 4 149 | path_to_weights: weights/model-2M.pt 150 | MAPF-GPT-6M: 151 | name: MAPF-GPT 152 | parallel_backend: balanced_dask 153 | num_process: 4 154 | path_to_weights: weights/model-6M.pt 155 | 156 | results_views: 157 | TabularView1: 158 | type: tabular 159 | drop_keys: [seed] 160 | print_results: True 161 | round_digits: 3 162 | 03-warehouse-runtime: 163 | name: Warehouse 164 | type: plot 165 | x: num_agents 166 | y: runtime 167 | width: 2.5 168 | height: 2.5 169 | line_width: 2 170 | use_log_scale_x: False 171 | use_log_scale_y: True 172 | legend_font_size: 8 173 | font_size: 8 174 | ticks: [32, 64, 96, 128, 160, 192] 175 | 176 | TabularCongestion: 177 | type: tabular 178 | drop_keys: [seed, map_name, ISR, CSR, ep_length, SoC, makespan, runtime] 179 | print_results: True 180 | round_digits: 3 181 | -------------------------------------------------------------------------------- /lacam/lacam3/src/instance.cpp: -------------------------------------------------------------------------------- 1 | #include "../include/instance.hpp" 2 | #include 3 | 4 | Instance::~Instance() 5 | { 6 | if (delete_graph_after_used) delete G; 7 | } 8 | 9 | Instance::Instance(Graph *_G, const Config &_starts, const Config &_goals, 10 | uint _N) 11 | : G(_G), starts(_starts), goals(_goals), N(_N) 12 | { 13 | } 14 | 15 | Instance::Instance(const std::string &map_content, 16 | const std::vector &start_indexes, 17 | const std::vector &goal_indexes) 18 | : G(new Graph(map_content)), 19 | starts(Config()), 20 | goals(Config()), 21 | N(start_indexes.size()), 22 | delete_graph_after_used(true) 23 | { 24 | for (auto k : start_indexes) starts.push_back(G->U[k]); 25 | for (auto k : goal_indexes) goals.push_back(G->U[k]); 26 | } 27 | 28 | // for load instance 29 | static const std::regex r_instance = 30 | std::regex(R"(\d+\t.+\.map\t\d+\t\d+\t(\d+)\t(\d+)\t(\d+)\t(\d+)\t.+)"); 31 | 32 | Instance::Instance(const std::string &scen_content, const std::string &map_content, const int _N) 33 | : G(new Graph(map_content)), 34 | starts(Config()), 35 | goals(Config()), 36 | N(_N), 37 | delete_graph_after_used(true) 38 | { 39 | // load start-goal pairs 40 | std::istringstream scen_stream(scen_content); 41 | std::string line; 42 | std::smatch results; 43 | 44 | while (getline(scen_stream, line)) { 45 | // for CRLF coding 46 | if (!line.empty() && line.back() == '\r') { 47 | line.pop_back(); 48 | } 49 | 50 | if (std::regex_match(line, results, r_instance)) { 51 | auto x_s = std::stoi(results[1].str()); 52 | auto y_s = std::stoi(results[2].str()); 53 | auto x_g = std::stoi(results[3].str()); 54 | auto y_g = std::stoi(results[4].str()); 55 | if (x_s < 0 || G->width <= x_s || x_g < 0 || G->width <= x_g) continue; 56 | if (y_s < 0 || G->height <= y_s || y_g < 0 || G->height <= y_g) continue; 57 | auto s = G->U[G->width * y_s + x_s]; 58 | auto g = G->U[G->width * y_g + x_g]; 59 | if (s == nullptr || g == nullptr) continue; 60 | starts.push_back(s); 61 | goals.push_back(g); 62 | } 63 | 64 | if (starts.size() == N) break; 65 | } 66 | } 67 | 68 | Instance::Instance(const std::string &map_content, const int _N, 69 | const int seed) 70 | : G(new Graph(map_content)), 71 | starts(Config()), 72 | goals(Config()), 73 | N(_N), 74 | delete_graph_after_used(true) 75 | { 76 | auto MT = std::mt19937(seed); 77 | // random assignment 78 | const auto K = G->size(); 79 | 80 | // set starts 81 | auto 
s_indexes = std::vector(K); 82 | std::iota(s_indexes.begin(), s_indexes.end(), 0); 83 | std::shuffle(s_indexes.begin(), s_indexes.end(), MT); 84 | int i = 0; 85 | while (true) { 86 | if (i >= K) return; 87 | starts.push_back(G->V[s_indexes[i]]); 88 | if (starts.size() == N) break; 89 | ++i; 90 | } 91 | 92 | // set goals 93 | auto g_indexes = std::vector(K); 94 | std::iota(g_indexes.begin(), g_indexes.end(), 0); 95 | std::shuffle(g_indexes.begin(), g_indexes.end(), MT); 96 | int j = 0; 97 | while (true) { 98 | if (j >= K) return; 99 | goals.push_back(G->V[g_indexes[j]]); 100 | if (goals.size() == N) break; 101 | ++j; 102 | } 103 | } 104 | 105 | bool Instance::is_valid(const int verbose) const 106 | { 107 | if (N != starts.size() || N != goals.size()) { 108 | info(1, verbose, "invalid N, check instance"); 109 | return false; 110 | } 111 | return true; 112 | } 113 | -------------------------------------------------------------------------------- /lacam/lacam3/include/planner.hpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Implementation of LaCAM* 3 | * 4 | * references: 5 | * LaCAM: Search-Based Algorithm for Quick Multi-Agent Pathfinding. 6 | * Keisuke Okumura. 7 | * Proc. AAAI Conf. on Artificial Intelligence (AAAI). 2023. 8 | * 9 | * Improving LaCAM for Scalable Eventually Optimal Multi-Agent Pathfinding. 10 | * Keisuke Okumura. 11 | * Proc. Int. Joint Conf. on Artificial Intelligence (IJCAI). 2023. 12 | * 13 | * Engineering LaCAM*: Towards Real-Time, Large-Scale, and Near-Optimal 14 | * Multi-Agent Pathfinding. Keisuke Okumura. Proc. Int. Conf. on Autonomous 15 | * Agents and Multiagent Systems. 2024. 16 | */ 17 | #pragma once 18 | 19 | #include "dist_table.hpp" 20 | #include "graph.hpp" 21 | #include "heuristic.hpp" 22 | #include "hnode.hpp" 23 | #include "instance.hpp" 24 | #include "pibt.hpp" 25 | #include "refiner.hpp" 26 | #include "scatter.hpp" 27 | #include "translator.hpp" 28 | #include "utils.hpp" 29 | 30 | struct Planner { 31 | const Instance *ins; 32 | const Deadline *deadline; 33 | const int seed; 34 | std::mt19937 MT; 35 | const int verbose; 36 | const int depth; 37 | 38 | // solver utils 39 | const int N; // number of agents 40 | const int V_size; 41 | DistTable *D; 42 | bool delete_dist_table_after_used; 43 | 44 | // heuristic 45 | Heuristic *heuristic; 46 | 47 | // scatter (SUO) 48 | Scatter *scatter; 49 | 50 | // configuration generator 51 | std::vector pibts; 52 | 53 | // for refiner 54 | int seed_refiner; 55 | std::list> refiner_pool; 56 | 57 | // for search utils 58 | std::deque OPEN; 59 | std::unordered_map EXPLORED; 60 | HNode *H_init; // start node 61 | HNode *H_goal; // goal node 62 | 63 | // parameters 64 | static bool FLG_SWAP; // whether to use swap technique in PIBT 65 | static bool 66 | FLG_STAR; // whether to refine solutions after initial solution discovery 67 | static bool FLG_MULTI_THREAD; 68 | static int SCATTER_MARGIN; // used in SUO 69 | static int PIBT_NUM; // number of PIBT run, i.e., Monte-Carlo configuration 70 | // generator 71 | static bool FLG_REFINER; // whether to use refiners 72 | static int REFINER_NUM; // number of refiners 73 | static bool 74 | FLG_SCATTER; // whether to use space utilization optimization (SUO) 75 | static float RANDOM_INSERT_PROB1; // inserting the start node 76 | static float RANDOM_INSERT_PROB2; // inserting a node after finding the goal 77 | static bool FLG_RANDOM_INSERT_INIT_NODE; 78 | static float RECURSIVE_RATE; 79 | static double RECURSIVE_TIME_LIMIT; 80 | 81 | // 
for logging 82 | static int CHECKPOINTS_DURATION; 83 | static std::string MSG; 84 | 85 | int search_iter; 86 | int time_initial_solution; 87 | int cost_initial_solution; 88 | std::vector checkpoints; 89 | 90 | Planner(const Instance *_ins, int _verbose = 0, 91 | const Deadline *_deadline = nullptr, int _seed = 0, 92 | int _depth = 0, // used in recursive LaCAM 93 | DistTable *_D = nullptr // used in recursive LaCAM 94 | ); 95 | ~Planner(); 96 | Solution solve(); 97 | bool set_new_config(HNode *S, LNode *M, Config &Q_to); 98 | HNode *create_highlevel_node(const Config &Q, HNode *parent); 99 | void rewrite(HNode *H_from, HNode *H_to); 100 | int get_edge_cost(const Config &C1, const Config &C2); 101 | Solution backtrack(HNode *H); 102 | void apply_new_solution(const Solution &plan); 103 | void set_scatter(); 104 | void set_pibt(); 105 | void set_refiner(); 106 | Solution get_refined_plan(const Solution &plan_origin); 107 | void update_checkpoints(); 108 | void logging(); 109 | }; 110 | -------------------------------------------------------------------------------- /macro_env.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from gymnasium import Wrapper 3 | 4 | class MAPFGPTObservationWrapper(Wrapper): 5 | def __init__(self, env, observation_generator): 6 | super().__init__(env) 7 | self.observation_generator = observation_generator 8 | 9 | 10 | def reset(self): 11 | observations, infos = self.env.reset() 12 | self.observation_generator.create_agents([o["global_xy"] for o in observations], [o["global_target_xy"] for o in observations]) 13 | return self.observation_generator.generate_observations(), infos 14 | 15 | def step(self, actions): 16 | desired_actions = actions.copy() 17 | observations, rewards, terminated, truncated, infos = self.env.step(actions) 18 | self.observation_generator.update_agents([o["global_xy"] for o in observations], [o["global_target_xy"] for o in observations], desired_actions) 19 | observations = self.observation_generator.generate_observations() 20 | return observations, rewards, terminated, truncated, infos 21 | 22 | def get_inner_env(self): 23 | return self.env 24 | 25 | class PogemaMacroEnvironment: 26 | def __init__(self, environments): 27 | self.environments = environments 28 | self.num_agents_per_env = None 29 | self.active_status = [True] * len(environments) # Tracks which environments are still active 30 | self.last_observations = [None] * len(environments) # Stores the last observations for inactive environments 31 | self.metrics_info = [{} for _ in environments] 32 | 33 | 34 | def step(self, actions): 35 | observations, rewards, terminated, truncated, infos = [], [], [], [], [] 36 | 37 | start_idx = 0 38 | for i, (env, num_agents) in enumerate(zip(self.environments, self.num_agents_per_env)): 39 | if self.active_status[i]: # Process only active environments 40 | env_actions = actions[start_idx:start_idx + num_agents] 41 | obs, reward, term, trunc, info = env.step(env_actions) 42 | self.active_status[i] = not (all(term) or all(trunc)) # Mark inactive if terminated or truncated 43 | if all(term) or all(trunc): 44 | info[0]['metrics']['map_name'] = env.grid_config.map_name 45 | self.metrics_info[i] = info 46 | self.last_observations[i] = obs # Store the observation for inactive reuse 47 | else: 48 | # Use last observation and set reward, terminated, and truncated to default values 49 | obs = self.last_observations[i] 50 | reward = np.zeros(num_agents) 51 | term = [True] * num_agents 52 | trunc = 
[True] * num_agents 53 | info = self.metrics_info[i] 54 | 55 | 56 | start_idx += num_agents 57 | 58 | observations.append(obs) 59 | rewards.append(reward) 60 | terminated.append(term) 61 | truncated.append(trunc) 62 | infos.append(info) 63 | 64 | return ( 65 | np.concatenate(observations), 66 | np.concatenate(rewards), 67 | np.concatenate(terminated), 68 | np.concatenate(truncated), 69 | infos, 70 | ) 71 | 72 | def reset(self): 73 | observations = [] 74 | self.active_status = [True] * len(self.environments) # Reset all environments to active 75 | for i, env in enumerate(self.environments): 76 | obs, info = env.reset() 77 | self.last_observations[i] = obs # Store the initial observation for reuse 78 | observations.append(obs) 79 | self.num_agents_per_env = [env.grid.config.num_agents for env in self.environments] 80 | 81 | return np.concatenate(observations), {} -------------------------------------------------------------------------------- /lacam/lacam3/src/graph.cpp: -------------------------------------------------------------------------------- 1 | #include "../include/graph.hpp" 2 | #include 3 | 4 | Vertex::Vertex(int _id, int _index, int _x, int _y) 5 | : id(_id), index(_index), x(_x), y(_y), neighbor() 6 | { 7 | } 8 | 9 | Graph::Graph() : V(Vertices()), width(0), height(0) {} 10 | 11 | Graph::~Graph() 12 | { 13 | for (auto &v : V) 14 | if (v != nullptr) delete v; 15 | V.clear(); 16 | } 17 | 18 | // to load graph 19 | static const std::regex r_height = std::regex(R"(height\s(\d+))"); 20 | static const std::regex r_width = std::regex(R"(width\s(\d+))"); 21 | static const std::regex r_map = std::regex(R"(map)"); 22 | 23 | Graph::Graph(const std::string &map_data) : V(Vertices()), width(0), height(0) 24 | { 25 | std::istringstream map_stream(map_data); 26 | std::string line; 27 | std::smatch results; 28 | 29 | // read fundamental graph parameters 30 | while (getline(map_stream, line)) { 31 | // for CRLF coding 32 | if (!line.empty() && line.back() == '\r') { 33 | line.pop_back(); 34 | } 35 | 36 | if (std::regex_match(line, results, r_height)) { 37 | height = std::stoi(results[1].str()); 38 | } 39 | if (std::regex_match(line, results, r_width)) { 40 | width = std::stoi(results[1].str()); 41 | } 42 | if (std::regex_match(line, results, r_map)) break; 43 | } 44 | 45 | U = Vertices(width * height, nullptr); 46 | 47 | // create vertices 48 | int y = 0; 49 | while (getline(map_stream, line)) { 50 | // for CRLF coding 51 | if (*(line.end() - 1) == 0x0d) line.pop_back(); 52 | for (int x = 0; x < width; ++x) { 53 | char s = line[x]; 54 | if (s == 'T' or s == '@') continue; // object 55 | auto index = width * y + x; 56 | auto v = new Vertex(V.size(), index, x, y); 57 | V.push_back(v); 58 | U[index] = v; 59 | } 60 | ++y; 61 | } 62 | 63 | // create edges 64 | for (int y = 0; y < height; ++y) { 65 | for (int x = 0; x < width; ++x) { 66 | auto v = U[width * y + x]; 67 | if (v == nullptr) continue; 68 | // left 69 | if (x > 0) { 70 | auto u = U[width * y + (x - 1)]; 71 | if (u != nullptr) v->neighbor.push_back(u); 72 | } 73 | // right 74 | if (x < width - 1) { 75 | auto u = U[width * y + (x + 1)]; 76 | if (u != nullptr) v->neighbor.push_back(u); 77 | } 78 | // up 79 | if (y < height - 1) { 80 | auto u = U[width * (y + 1) + x]; 81 | if (u != nullptr) v->neighbor.push_back(u); 82 | } 83 | // down 84 | if (y > 0) { 85 | auto u = U[width * (y - 1) + x]; 86 | if (u != nullptr) v->neighbor.push_back(u); 87 | } 88 | } 89 | } 90 | } 91 | 92 | int Graph::size() const { return V.size(); } 93 | 94 | bool 
is_same_config(const Config &C1, const Config &C2) 95 | { 96 | const auto N = C1.size(); 97 | for (size_t i = 0; i < N; ++i) { 98 | if (C1[i]->id != C2[i]->id) return false; 99 | } 100 | return true; 101 | } 102 | 103 | uint ConfigHasher::operator()(const Config &C) const 104 | { 105 | uint hash = C.size(); 106 | for (auto &v : C) { 107 | hash ^= v->id + 0x9e3779b9 + (hash << 6) + (hash >> 2); 108 | } 109 | return hash; 110 | } 111 | 112 | std::ostream &operator<<(std::ostream &os, const Vertex *v) 113 | { 114 | os << v->index; 115 | return os; 116 | } 117 | 118 | std::ostream &operator<<(std::ostream &os, const Config &Q) 119 | { 120 | for (auto v : Q) os << v << ","; 121 | return os; 122 | } 123 | 124 | std::ostream &operator<<(std::ostream &os, const Paths &paths) 125 | { 126 | for (auto i = 0; i < paths.size(); ++i) { 127 | os << i << ":"; 128 | for (auto &v : paths[i]) { 129 | os << std::setw(4) << v << "->"; 130 | } 131 | std::cout << std::endl; 132 | } 133 | return os; 134 | } 135 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | -------------------------------------------------------------------------------- /tokenizer/cost2go.cpp: -------------------------------------------------------------------------------- 1 | // cppimport 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | std::vector> get_cost_matrix(const std::vector> &grid, int si, int sj) 9 | { 10 | std::vector> moves = {{0, 0}, {-1, 0}, {1, 0}, {0, -1}, {0, 1}}; 11 | std::queue> fringe; 12 | fringe.push({si, sj}); 13 | auto result = std::vector>(grid.size(), std::vector(grid[0].size(), -1)); 14 | result[si][sj] = 0; 15 | while (!fringe.empty()) 16 | { 17 | auto pos = fringe.front(); 18 | fringe.pop(); 19 | for (const auto &move : moves) 20 | { 21 | int new_i(pos.first + move.first), new_j(pos.second + move.second); 22 | if(new_i >=0 && new_j >= 0 && new_i < grid.size() && new_j < grid.front().size()) 23 | if (grid[new_i][new_j] == 0 && result[new_i][new_j] < 0) 24 | { 25 | result[new_i][new_j] = result[pos.first][pos.second] + 1; 26 | fringe.push(std::make_pair(new_i, new_j)); 27 | } 28 | } 29 | } 30 | return result; 31 | } 32 | 33 | std::map, std::vector>> precompute_cost2go(const std::vector> &grid, int obs_radius) 34 | { 35 | std::map, std::vector>> cost2go; 36 | for (size_t i = obs_radius; i < grid.size() - obs_radius; i++) 37 | for (size_t j = obs_radius; j < grid[0].size() - obs_radius; j++) 38 | if (grid[i][j] == 0) 39 | cost2go[std::make_pair(i, j)] = get_cost_matrix(grid, i, j); 40 | return cost2go; 41 | } 42 | 43 | std::vector> generate_cost2go_obs(const std::vector> &cost2go, const std::pair &pos, int offset, int limit, bool only_obstacles) 44 | { 45 | if (offset == 0) 46 | return {}; 47 | int x = pos.first - offset; 48 | int y = pos.second - offset; 49 | 50 | std::vector> observation(2 * offset + 1, std::vector(2 * offset + 1)); 51 | if (only_obstacles) 52 | { 53 | for (int i = 0; i <= offset * 2; i++) 54 | for (int j 
= 0; j <= offset * 2; j++) 55 | { 56 | int nx = x + i; 57 | int ny = y + j; 58 | observation[i][j] = bool(cost2go[nx][ny] < 0); 59 | } 60 | return observation; 61 | } 62 | for (int i = 0; i <= offset * 2; i++) 63 | for (int j = 0; j <= offset * 2; j++) 64 | { 65 | int nx = x + i; 66 | int ny = y + j; 67 | observation[i][j] = cost2go[nx][ny]; 68 | } 69 | 70 | int middle_value = observation[offset][offset]; 71 | for (int i = 0; i < observation.size(); i++) 72 | for (int j = 0; j < observation[i].size(); j++) 73 | { 74 | if (observation[i][j] >= 0) 75 | { 76 | observation[i][j] -= middle_value; 77 | if (observation[i][j] > limit) 78 | observation[i][j] = limit * 2; 79 | else if (observation[i][j] < -limit) 80 | observation[i][j] = -limit * 2; 81 | } 82 | else 83 | observation[i][j] = -limit * 4; 84 | } 85 | 86 | return observation; 87 | } 88 | 89 | namespace py = pybind11; 90 | 91 | PYBIND11_MODULE(cost2go, m) 92 | { 93 | m.def("precompute_cost2go", &precompute_cost2go, "Precompute cost-to-go matrices for the grid", 94 | py::arg("grid"), py::arg("obs_radius")); 95 | m.def("get_cost_matrix", &get_cost_matrix, "Compute cost matrix from a starting position", 96 | py::arg("grid"), py::arg("si"), py::arg("sj")); 97 | m.def("generate_cost2go_obs", &generate_cost2go_obs, "Generate cost-to-go observations", 98 | py::arg("cost2go"), py::arg("pos"), py::arg("offset"), py::arg("limit"), py::arg("only_obstacles")); 99 | } 100 | 101 | <% 102 | cfg['extra_compile_args'] = ['-std=c++17'] 103 | setup_pybind11(cfg) 104 | %> -------------------------------------------------------------------------------- /gpt/fast_data_loader.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import time 4 | 5 | import numpy as np 6 | import pyarrow as pa 7 | import torch 8 | import random 9 | from loguru import logger 10 | from torch.utils.data import Dataset 11 | 12 | 13 | class MapfArrowDataset(torch.utils.data.Dataset): 14 | def __init__(self, folder_path, device, batch_size): 15 | self.all_data_files = self.file_paths = sorted(glob.glob(os.path.join(folder_path, "*.arrow"))) 16 | self.device = device 17 | self.batch_size = batch_size 18 | self.dtype = torch.int8 19 | 20 | ddp_local_rank = int(device.split(':')[-1]) 21 | ddp_world_size = os.environ.get("WORLD_SIZE") 22 | random.shuffle(self.file_paths) 23 | if "dagger" in folder_path or "ddg" in folder_path: 24 | self.file_paths = sorted(self.file_paths, 25 | key=lambda x: int(os.path.splitext(os.path.basename(x))[0].split('_')[-1]), 26 | reverse=True) 27 | if "validation" not in folder_path and ddp_local_rank is not None and ddp_world_size is not None: 28 | self.file_paths = self.file_paths[int(ddp_local_rank)::int(ddp_world_size)] 29 | 30 | # pre-allocate memory for the input and target tensors (same file size) 31 | sample_input_tensors, sample_gt_actions = self._get_data_from_file(self.file_paths[0]) 32 | 33 | self.input_tensors = torch.empty(sample_input_tensors.shape, dtype=self.dtype, device=self.device) 34 | self.target_tensors = torch.full(sample_input_tensors.shape, -1, dtype=self.dtype, device=self.device) 35 | 36 | logger.info( 37 | f"Single file tensor size: {self.input_tensors.numel() * self.input_tensors.element_size() / 1e9:.4f} GB") 38 | 39 | @staticmethod 40 | def _get_data_from_file(file_path): 41 | with pa.memory_map(file_path) as source: 42 | table = pa.ipc.open_file(source).read_all() 43 | input_tensors = table["input_tensors"].to_numpy() 44 | gt_actions = 
table["gt_actions"].to_numpy() 45 | 46 | # shuffle data within the current file 47 | indices = np.random.permutation(len(input_tensors)) 48 | input_tensors = np.stack(input_tensors[indices]) 49 | gt_actions = gt_actions[indices] 50 | 51 | return input_tensors, gt_actions 52 | 53 | def load_and_transfer_data_file(self, filename): 54 | start_time = time.monotonic() 55 | 56 | input_tensors, gt_actions = self._get_data_from_file(filename) 57 | 58 | self.input_tensors.copy_(torch.tensor(input_tensors, dtype=self.dtype), non_blocking=True) 59 | self.target_tensors[:, -1].copy_(torch.tensor(gt_actions, dtype=self.dtype), non_blocking=True) 60 | finish_time = time.monotonic() - start_time 61 | logger.debug(f'Data from {filename} for {self.device} device prepared in ~{round(finish_time, 5)}s') 62 | 63 | def __iter__(self): 64 | while True: 65 | for file_path in self.file_paths: 66 | self.load_and_transfer_data_file(file_path) 67 | num_samples = len(self.input_tensors) 68 | if num_samples < self.batch_size: 69 | raise KeyError('The dataset is too small to sample a single batch.') 70 | for i in range(0, num_samples - num_samples % self.batch_size, self.batch_size): 71 | yield self.input_tensors[i:i + self.batch_size], self.target_tensors[i:i + self.batch_size] 72 | 73 | def get_shard_size(self): 74 | return len(self.input_tensors) * len(self.file_paths) 75 | 76 | def get_full_dataset_size(self): 77 | return len(self.input_tensors) * len(self.all_data_files) 78 | 79 | 80 | def main(): 81 | # folder_path = "../dataset/validation" 82 | folder_path = "../dataset/train" 83 | dataset = MapfArrowDataset(folder_path, device='cuda:0', batch_size=3 * 256) 84 | data = iter(dataset) 85 | x = 0 86 | logger.info(dataset.get_full_dataset_size()) 87 | logger.info(dataset.get_shard_size()) 88 | 89 | while True: 90 | x += 1 91 | qx, qy = next(data) 92 | logger.info(str(qx.shape) + ' ' + str(qy.shape)) 93 | 94 | 95 | if __name__ == "__main__": 96 | main() 97 | -------------------------------------------------------------------------------- /finetuning/scenario_generators.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from pogema import pogema_v0 4 | from pogema_toolbox.create_env import Environment 5 | from pogema_toolbox.generators.maze_generator import MazeGenerator, MazeRangeSettings 6 | from pogema_toolbox.generators.random_generator import MapRangeSettings as RandomRangeSettings, generate_map 7 | from create_env import ProvideFutureTargetsWrapper 8 | 9 | def make_pogema_maze_instance(num_agents, max_episode_steps=256, size_min=17, size_max=21, wall_components_min=4, 10 | wall_components_max=8, on_target='nothing', map_seed=None, scenario_seed=None): 11 | rng = np.random.default_rng() 12 | 13 | settings_gen = MazeRangeSettings(width_min=size_min, width_max=size_max, 14 | height_min=size_max, height_max=size_max, 15 | wall_components_min=wall_components_min, 16 | wall_components_max=wall_components_max, 17 | ) 18 | if map_seed is None: 19 | map_seed = rng.integers(np.iinfo(np.int64).max) 20 | maze = MazeGenerator.generate_maze(**settings_gen.sample(seed=map_seed)) 21 | if scenario_seed is None: 22 | scenario_seed = rng.integers(np.iinfo(np.int64).max) 23 | env_cfg = Environment( 24 | num_agents=num_agents, 25 | observation_type="MAPF", 26 | max_episode_steps=max_episode_steps, 27 | map=maze, 28 | with_animation=False, 29 | on_target=on_target, 30 | seed=scenario_seed, 31 | collision_system='soft' 32 | ) 33 | 34 | env = pogema_v0(env_cfg) 35 | if 
on_target == 'restart': 36 | env = ProvideFutureTargetsWrapper(env) 37 | env_cfg.map_name = f'maze-seed-{str(map_seed)}-scenario-{str(scenario_seed)}' 38 | return env 39 | 40 | 41 | def make_pogema_random_instance(num_agents, max_episode_steps=256, size_min=17, size_max=21, on_target='nothing', map_seed=None, 42 | scenario_seed=None): 43 | rng = np.random.default_rng() 44 | settings_gen = RandomRangeSettings(width_min=size_min, width_max=size_max, height_min=size_max, height_max=size_max) 45 | if map_seed is None: 46 | map_seed = rng.integers(np.iinfo(np.int64).max) 47 | maze = generate_map(settings_gen.sample(map_seed)) 48 | if scenario_seed is None: 49 | scenario_seed = rng.integers(np.iinfo(np.int64).max) 50 | env_cfg = Environment( 51 | num_agents=num_agents, 52 | observation_type="MAPF", 53 | max_episode_steps=max_episode_steps, 54 | map=maze, 55 | with_animation=False, 56 | on_target=on_target, 57 | seed=scenario_seed, 58 | collision_system='soft' 59 | ) 60 | 61 | env = pogema_v0(env_cfg) 62 | if on_target == 'restart': 63 | env = ProvideFutureTargetsWrapper(env) 64 | env_cfg.map_name = f'random-seed-{str(map_seed)}-scenario-{str(scenario_seed)}' 65 | return env 66 | 67 | def make_pogema_map_instance(num_agents, map, max_episode_steps=256, on_target='nothing', scenario_seed=None): 68 | rng = np.random.default_rng() 69 | if scenario_seed is None: 70 | scenario_seed = rng.integers(np.iinfo(np.int64).max) 71 | env_cfg = Environment( 72 | num_agents=num_agents, 73 | observation_type="MAPF", 74 | max_episode_steps=max_episode_steps, 75 | map=map, 76 | with_animation=False, 77 | on_target=on_target, 78 | seed=scenario_seed, 79 | collision_system='soft' 80 | ) 81 | env = pogema_v0(env_cfg) 82 | if on_target == 'restart': 83 | env = ProvideFutureTargetsWrapper(env) 84 | env_cfg.map_name = f'warehouse-scenario-{str(scenario_seed)}' 85 | return env 86 | 87 | 88 | def main(): 89 | maze_env = make_pogema_maze_instance(num_agents=32, map_seed=42, scenario_seed=42) 90 | maze_env.reset() 91 | maze_env.render() 92 | 93 | random_env = make_pogema_random_instance(num_agents=32, map_seed=42, scenario_seed=42) 94 | random_env.reset() 95 | random_env.render() 96 | 97 | grid = ''' 98 | !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! 99 | !@@!@@!$$$$$$$$$$!$$$$$$$$$$!$$$$$$$$$$!@@!@@! 100 | !@@!@@!##########!##########!##########!@@!@@! 101 | !@@!@@!$$$$$$$$$$!$$$$$$$$$$!$$$$$$$$$$!@@!@@! 102 | !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! 
103 | ''' 104 | warehouse_env = make_pogema_map_instance(num_agents=8, map=grid, scenario_seed=42) 105 | warehouse_env.reset() 106 | warehouse_env.render() 107 | 108 | if __name__ == '__main__': 109 | main() 110 | -------------------------------------------------------------------------------- /lacam/main.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | extern "C" { 6 | const char* run_lacam(const char* map_content_cstr, 7 | const char* scene_content_cstr, int N, 8 | float time_limit_sec); 9 | } 10 | 11 | const char* run_lacam(const char* map_content_cstr, 12 | const char* scene_content_cstr, int N, 13 | float time_limit_sec) 14 | { 15 | std::string map_content(map_content_cstr); 16 | std::string scene_content(scene_content_cstr); 17 | 18 | const int seed = 0; 19 | const int verbose = 0; 20 | const bool log_short = false; 21 | 22 | // Solver parameters 23 | const bool flg_no_all = false; 24 | const bool flg_no_star = false; 25 | const bool flg_no_swap = false; 26 | const bool flg_no_multi_thread = true; 27 | const int pibt_num = 10; 28 | const bool flg_no_refiner = false; 29 | const int refiner_num = 4; 30 | const bool flg_no_scatter = false; 31 | const int scatter_margin = 10; 32 | const float random_insert_prob1 = 0.001f; 33 | const float random_insert_prob2 = 0.01f; 34 | const bool flg_random_insert_init_node = false; 35 | const float recursive_rate = 0.2f; 36 | const int recursive_time_limit = 1000; 37 | const int checkpoints_duration = 5000; 38 | 39 | // setup instance 40 | const auto ins = Instance(scene_content, map_content, N); 41 | if (!ins.is_valid(1)) return "ERROR_SCENE"; 42 | 43 | // solver parameters 44 | Planner::FLG_SWAP = !flg_no_swap && !flg_no_all; 45 | Planner::FLG_STAR = !flg_no_star && !flg_no_all; 46 | Planner::FLG_MULTI_THREAD = !flg_no_multi_thread && !flg_no_all; 47 | Planner::PIBT_NUM = flg_no_all ? 1 : pibt_num; 48 | Planner::FLG_REFINER = !flg_no_refiner && !flg_no_all; 49 | Planner::REFINER_NUM = refiner_num; 50 | Planner::FLG_SCATTER = !flg_no_scatter && !flg_no_all; 51 | Planner::SCATTER_MARGIN = scatter_margin; 52 | Planner::RANDOM_INSERT_PROB1 = flg_no_all ? 0 : random_insert_prob1; 53 | Planner::RANDOM_INSERT_PROB2 = flg_no_all ? 0 : random_insert_prob2; 54 | Planner::FLG_RANDOM_INSERT_INIT_NODE = 55 | flg_random_insert_init_node && !flg_no_all; 56 | Planner::RECURSIVE_RATE = flg_no_all ? 0 : recursive_rate; 57 | Planner::RECURSIVE_TIME_LIMIT = flg_no_all ? 
0 : recursive_time_limit; 58 | Planner::CHECKPOINTS_DURATION = checkpoints_duration; 59 | 60 | // solve 61 | const auto deadline = Deadline(time_limit_sec * 1000); 62 | const auto solution = solve(ins, verbose - 1, &deadline, seed); 63 | const auto comp_time_ms = deadline.elapsed_ms(); 64 | 65 | // failure 66 | if (solution.empty()) { 67 | info(1, verbose, &deadline, "failed to solve"); 68 | return "ERROR_EMPTY"; 69 | } 70 | 71 | // check feasibility 72 | if (!is_feasible_solution(ins, solution, verbose)) { 73 | info(0, verbose, &deadline, "invalid solution"); 74 | return "ERROR_SOLUTION"; 75 | } 76 | 77 | // post processing 78 | print_stats(verbose, &deadline, ins, solution, comp_time_ms); 79 | make_log(ins, solution, "lacam_log.txt", comp_time_ms, "tmp_map", seed, 80 | log_short); 81 | 82 | auto get_x = [&](int k) { return k % ins.G->width; }; 83 | auto get_y = [&](int k) { return k / ins.G->width; }; 84 | 85 | std::ostringstream result_string; 86 | for (size_t t = 0; t < solution.size(); ++t) { 87 | auto C = solution[t]; 88 | for (auto v : C) { 89 | result_string << get_x(v->index) << "," << get_y(v->index) << "|"; 90 | } 91 | result_string << "\n"; 92 | } 93 | 94 | static std::string result; 95 | result = result_string.str(); 96 | 97 | return result.c_str(); 98 | } 99 | 100 | int main(int argc, char* argv[]) 101 | { 102 | const char* map_name = "tmp.map"; 103 | const char* scene_name = "tmp.scene"; 104 | const int N = 1; 105 | const float time_limit_sec = 1.0; 106 | 107 | std::string map_content; 108 | std::string scene_content; 109 | 110 | std::ifstream map_file(map_name); 111 | if (map_file) { 112 | std::stringstream buffer; 113 | buffer << map_file.rdbuf(); 114 | map_content = buffer.str(); 115 | map_file.close(); 116 | } else { 117 | std::cerr << "Failed to open map_name: " << map_name << std::endl; 118 | return 1; 119 | } 120 | 121 | std::ifstream scen_file(scene_name); 122 | if (scen_file) { 123 | std::stringstream buffer; 124 | buffer << scen_file.rdbuf(); 125 | scene_content = buffer.str(); 126 | scen_file.close(); 127 | } else { 128 | std::cerr << "Failed to open scene_name: " << scene_name << std::endl; 129 | return 1; 130 | } 131 | 132 | const char* map_content_cstr = map_content.c_str(); 133 | const char* scene_content_cstr = scene_content.c_str(); 134 | 135 | const char* result = 136 | run_lacam(map_content_cstr, scene_content_cstr, N, time_limit_sec); 137 | std::cout << result << std::endl; 138 | return 0; 139 | } 140 | -------------------------------------------------------------------------------- /lacam/lacam3/src/scatter.cpp: -------------------------------------------------------------------------------- 1 | #include "../include/scatter.hpp" 2 | 3 | #include "../include/metrics.hpp" 4 | 5 | Scatter::Scatter(const Instance *_ins, DistTable *_D, const Deadline *_deadline, 6 | const int seed, int _verbose, int _cost_margin) 7 | : ins(_ins), 8 | deadline(_deadline), 9 | MT(std::mt19937(seed)), 10 | verbose(_verbose), 11 | N(ins->N), 12 | V_size(ins->G->size()), 13 | T(get_makespan_lower_bound(*ins, *_D) + _cost_margin), 14 | D(_D), 15 | cost_margin(_cost_margin), 16 | sum_of_path_length(0), 17 | paths(N), 18 | scatter_data(N), 19 | CT(ins) 20 | { 21 | } 22 | 23 | void Scatter::construct() 24 | { 25 | info(0, verbose, deadline, "scatter", "\tinvoked"); 26 | 27 | // define path finding utilities 28 | // vertex, cost-to-come, cost-to-go, collision, parent 29 | using ScatterNode = std::tuple; 30 | auto cmp = [&](ScatterNode &a, ScatterNode &b) { 31 | // collision 32 | if 
(std::get<3>(a) != std::get<3>(b)) 33 | return std::get<3>(a) > std::get<3>(b); 34 | auto f_a = std::get<1>(a) + std::get<2>(a); 35 | auto f_b = std::get<1>(b) + std::get<2>(b); 36 | if (f_a != f_b) return f_a > f_b; 37 | return std::get<0>(a)->id < std::get<0>(b)->id; 38 | }; 39 | auto CLOSED = std::vector(V_size, nullptr); // parent 40 | 41 | // metrics 42 | auto collision_cnt_last = 0; 43 | auto paths_prev = std::vector(); 44 | 45 | // main loop 46 | auto loop = 0; 47 | while (loop < 2 || CT.collision_cnt < collision_cnt_last) { 48 | ++loop; 49 | collision_cnt_last = CT.collision_cnt; 50 | 51 | // randomize planning order 52 | auto order = std::vector(N, 0); 53 | std::iota(order.begin(), order.end(), 0); 54 | std::shuffle(order.begin(), order.end(), MT); 55 | 56 | // single-agent path finding for agent-i 57 | for (int _i = 0; _i < N; ++_i) { 58 | if (is_expired(deadline)) break; 59 | 60 | const auto i = order[_i]; 61 | const auto cost_ub = D->get(i, ins->starts[i]) + cost_margin; 62 | 63 | if (!paths[i].empty()) sum_of_path_length -= (paths[i].size() - 1); 64 | 65 | // clear cache 66 | CT.clearPath(i, paths[i]); 67 | 68 | // setup A* 69 | auto OPEN = std::priority_queue, 70 | decltype(cmp)>(cmp); 71 | // used with CLOSED, vertex-id list 72 | const auto s_i = ins->starts[i]; 73 | OPEN.push(std::make_tuple(s_i, 0, D->get(i, s_i), 0, nullptr)); 74 | auto USED = std::vector(); 75 | 76 | // A* 77 | auto solved = false; 78 | while (!OPEN.empty() && !is_expired(deadline)) { 79 | // pop 80 | auto node = OPEN.top(); 81 | OPEN.pop(); 82 | 83 | // check CLOSED list 84 | const auto v = std::get<0>(node); 85 | const auto g_v = std::get<1>(node); // cost-to-come 86 | const auto c_v = std::get<3>(node); // collision 87 | if (CLOSED[v->id] != nullptr) continue; 88 | CLOSED[v->id] = std::get<4>(node); // parent 89 | USED.push_back(v->id); 90 | 91 | // check goal condition 92 | if (v == ins->goals[i]) { 93 | solved = true; 94 | break; 95 | } 96 | 97 | // expand 98 | for (auto u : v->neighbor) { 99 | auto d_u = D->get(i, u); 100 | if (u != s_i && CLOSED[u->id] == nullptr && 101 | d_u + g_v + 1 <= cost_ub) { 102 | // insert new node 103 | OPEN.push(std::make_tuple(u, g_v + 1, d_u, 104 | CT.getCollisionCost(v, u, g_v) + c_v, v)); 105 | } 106 | } 107 | } 108 | 109 | // backtrack 110 | if (solved) { 111 | paths[i].clear(); 112 | auto v = ins->goals[i]; 113 | while (v != nullptr) { 114 | paths[i].push_back(v); 115 | v = CLOSED[v->id]; 116 | } 117 | std::reverse(paths[i].begin(), paths[i].end()); 118 | } 119 | 120 | // register to CT & update collision count 121 | CT.enrollPath(i, paths[i]); 122 | sum_of_path_length += paths[i].size() - 1; 123 | 124 | // memory management 125 | for (auto k : USED) CLOSED[k] = nullptr; 126 | } 127 | 128 | paths_prev = paths; 129 | info(1, verbose, deadline, "scatter", "\titer:", loop, 130 | "\tcollision_cnt:", CT.collision_cnt); 131 | 132 | if (CT.collision_cnt == 0) break; 133 | if (is_expired(deadline)) break; 134 | } 135 | 136 | paths = paths_prev; 137 | 138 | // set scatter data 139 | for (auto i = 0; i < N; ++i) { 140 | if (paths[i].empty()) continue; 141 | for (auto t = 0; t < paths[i].size() - 1; ++t) { 142 | scatter_data[i][paths[i][t]->id] = paths[i][t + 1]; 143 | } 144 | } 145 | 146 | info(0, verbose, deadline, "scatter", "\tcompleted"); 147 | } 148 | -------------------------------------------------------------------------------- /create_env.py: -------------------------------------------------------------------------------- 1 | import enum 2 | import time 3 | from 
copy import deepcopy 4 | 5 | import numpy as np 6 | from gymnasium import Wrapper 7 | from loguru import logger 8 | from pogema import AnimationConfig, AnimationMonitor, pogema_v0 9 | from pogema.generator import generate_from_possible_targets, generate_new_target 10 | from pogema.wrappers.metrics import AgentsDensityWrapper, RuntimeMetricWrapper 11 | from pogema_toolbox.create_env import MultiMapWrapper 12 | 13 | 14 | class ProvideFutureTargetsWrapper(Wrapper): 15 | def _get_lifelong_global_targets_xy(self): 16 | all_goals = [] 17 | cur_goals = self.grid.get_targets_xy() 18 | generators = deepcopy(self.random_generators) 19 | for agent_idx in range(self.grid_config.num_agents): 20 | distance = 0 21 | cur_goal = cur_goals[agent_idx] 22 | goals = [cur_goal] 23 | while distance < self.grid_config.max_episode_steps + 100: 24 | if self.grid_config.possible_targets_xy is None: 25 | new_goal = generate_new_target( 26 | generators[agent_idx], 27 | self.grid.point_to_component, 28 | self.grid.component_to_points, 29 | cur_goal, 30 | ) 31 | else: 32 | new_goal = generate_from_possible_targets( 33 | generators[agent_idx], 34 | self.grid_config.possible_targets_xy, 35 | cur_goal, 36 | ) 37 | new_goal = ( 38 | new_goal[0] + self.grid_config.obs_radius, 39 | new_goal[1] + self.grid_config.obs_radius, 40 | ) 41 | distance += abs(cur_goal[0] - new_goal[0]) + abs( 42 | cur_goal[1] - new_goal[1] 43 | ) 44 | cur_goal = new_goal 45 | goals.append(cur_goal) 46 | all_goals.append(goals) 47 | return all_goals 48 | 49 | def reset(self, **kwargs): 50 | observations, infos = self.env.reset(seed=self.env.grid_config.seed) 51 | observations[0]["after_reset"] = True 52 | observations[0]["max_episode_steps"] = self.env.grid_config.max_episode_steps 53 | if self.env.grid_config.on_target == "restart": 54 | global_lifelong_targets_xy = self._get_lifelong_global_targets_xy() 55 | for idx, obs in enumerate(observations): 56 | obs["global_lifelong_targets_xy"] = global_lifelong_targets_xy[idx] 57 | return observations, infos 58 | 59 | 60 | class LogActions(Wrapper): 61 | def __init__(self, env): 62 | super().__init__(env) 63 | self.made_actions = None 64 | self.init_positions = None 65 | 66 | def step(self, actions): 67 | observations, rewards, terminated, truncated, infos = self.env.step(actions) 68 | for i, action in enumerate(actions): 69 | self.made_actions[i].append(action) 70 | if all(terminated) or all(truncated): 71 | infos[0]["metrics"]["made_actions"] = self.made_actions 72 | infos[0]["metrics"]["init_positions"] = self.init_positions 73 | if self.env.grid_config.on_target == "restart": 74 | infos[0]["metrics"][ 75 | "global_lifelong_targets_xy" 76 | ] = self.global_lifelong_targets_xy 77 | return observations, rewards, terminated, truncated, infos 78 | 79 | def reset(self, **kwargs): 80 | observations, info = self.env.reset(**kwargs) 81 | self.made_actions = [[] for _ in observations] 82 | self.init_positions = [obs["global_xy"] for obs in observations] 83 | if self.env.grid_config.on_target == "restart": 84 | self.global_lifelong_targets_xy = [ 85 | [[int(x), int(y)] for x, y in obs["global_lifelong_targets_xy"]] 86 | for obs in observations 87 | ] 88 | return observations, info 89 | 90 | 91 | def create_eval_env(config): 92 | env = pogema_v0(grid_config=config) 93 | env = AgentsDensityWrapper(env) 94 | env = MultiMapWrapper(env) 95 | env = ProvideFutureTargetsWrapper(env) 96 | if config.with_animation: 97 | logger.debug("Wrapping environment with AnimationMonitor") 98 | env = AnimationMonitor(env, 
AnimationConfig(save_every_idx_episode=None)) 99 | env = RuntimeMetricWrapper(env) 100 | return env 101 | 102 | 103 | def create_logging_env(config): 104 | env = pogema_v0(grid_config=config) 105 | env = AgentsDensityWrapper(env) 106 | env = ProvideFutureTargetsWrapper(env) 107 | env = MultiMapWrapper(env) 108 | env = LogActions(env) 109 | if config.with_animation: 110 | logger.debug("Wrapping environment with AnimationMonitor") 111 | env = AnimationMonitor(env, AnimationConfig(save_every_idx_episode=None)) 112 | 113 | # Adding runtime metrics 114 | env = RuntimeMetricWrapper(env) 115 | 116 | return env 117 | -------------------------------------------------------------------------------- /lacam/lacam3/src/post_processing.cpp: -------------------------------------------------------------------------------- 1 | #include "../include/post_processing.hpp" 2 | 3 | #include "../include/dist_table.hpp" 4 | #include "../include/planner.hpp" 5 | 6 | bool is_feasible_solution(const Instance &ins, const Solution &solution, 7 | const int verbose) 8 | { 9 | if (solution.empty()) return true; 10 | 11 | // check start locations 12 | if (!is_same_config(solution.front(), ins.starts)) { 13 | info(1, verbose, "invalid starts"); 14 | return false; 15 | } 16 | 17 | // check goal locations 18 | if (!is_same_config(solution.back(), ins.goals)) { 19 | info(1, verbose, "invalid goals"); 20 | return false; 21 | } 22 | 23 | for (size_t t = 1; t < solution.size(); ++t) { 24 | for (size_t i = 0; i < ins.N; ++i) { 25 | auto v_i_from = solution[t - 1][i]; 26 | auto v_i_to = solution[t][i]; 27 | // check connectivity 28 | if (v_i_from != v_i_to && 29 | std::find(v_i_to->neighbor.begin(), v_i_to->neighbor.end(), 30 | v_i_from) == v_i_to->neighbor.end()) { 31 | info(1, verbose, "invalid move"); 32 | return false; 33 | } 34 | 35 | // check conflicts 36 | for (size_t j = i + 1; j < ins.N; ++j) { 37 | auto v_j_from = solution[t - 1][j]; 38 | auto v_j_to = solution[t][j]; 39 | // vertex conflicts 40 | if (v_j_to == v_i_to) { 41 | info(1, verbose, "vertex conflict between agent-", i, " and agent-", 42 | j, " at vertex-", v_i_to->id, " at timestep ", t); 43 | return false; 44 | } 45 | // swap conflicts 46 | if (v_j_to == v_i_from && v_j_from == v_i_to) { 47 | info(1, verbose, "edge conflict"); 48 | return false; 49 | } 50 | } 51 | } 52 | } 53 | 54 | return true; 55 | } 56 | 57 | void print_stats(const int verbose, const Deadline *deadline, 58 | const Instance &ins, const Solution &solution, 59 | const double comp_time_ms) 60 | { 61 | auto ceil = [](float x) { return std::ceil(x * 100) / 100; }; 62 | auto dist_table = DistTable(ins); 63 | const auto makespan = get_makespan(solution); 64 | const auto makespan_lb = get_makespan_lower_bound(ins, dist_table); 65 | const auto sum_of_costs = get_sum_of_costs(solution); 66 | const auto sum_of_costs_lb = get_sum_of_costs_lower_bound(ins, dist_table); 67 | const auto sum_of_loss = get_sum_of_loss(solution); 68 | info(1, verbose, deadline, "solved", "\tmakespan: ", makespan, 69 | " (lb=", makespan_lb, ", ub=", ceil((float)makespan / makespan_lb), ")", 70 | "\tsum_of_costs: ", sum_of_costs, " (lb=", sum_of_costs_lb, 71 | ", ub=", ceil((float)sum_of_costs / sum_of_costs_lb), ")", 72 | "\tsum_of_loss: ", sum_of_loss, " (lb=", sum_of_costs_lb, 73 | ", ub=", ceil((float)sum_of_loss / sum_of_costs_lb), ")"); 74 | } 75 | 76 | // for log of map_name 77 | static const std::regex r_map_name = std::regex(R"(.+/(.+))"); 78 | 79 | void make_log(const Instance &ins, const Solution &solution, 80 | const 
std::string &output_name, const double comp_time_ms, 81 | const std::string &map_name, const int seed, const bool log_short) 82 | { 83 | // map name 84 | std::smatch results; 85 | const auto map_recorded_name = "tmp_map"; 86 | 87 | // for instance-specific values 88 | auto dist_table = DistTable(ins); 89 | 90 | // log for visualizer 91 | auto get_x = [&](int k) { return k % ins.G->width; }; 92 | auto get_y = [&](int k) { return k / ins.G->width; }; 93 | std::ofstream log; 94 | log.open(output_name, std::ios::out); 95 | log << "agents=" << ins.N << "\n"; 96 | log << "map_file=" << map_recorded_name << "\n"; 97 | log << "solver=planner\n"; 98 | log << "solved=" << !solution.empty() << "\n"; 99 | log << "soc=" << get_sum_of_costs(solution) << "\n"; 100 | log << "soc_lb=" << get_sum_of_costs_lower_bound(ins, dist_table) << "\n"; 101 | log << "makespan=" << get_makespan(solution) << "\n"; 102 | log << "makespan_lb=" << get_makespan_lower_bound(ins, dist_table) << "\n"; 103 | log << "sum_of_loss=" << get_sum_of_loss(solution) << "\n"; 104 | log << "sum_of_loss_lb=" << get_sum_of_costs_lower_bound(ins, dist_table) 105 | << "\n"; 106 | log << "comp_time=" << comp_time_ms << "\n"; 107 | log << "seed=" << seed << "\n"; 108 | // log << Planner::MSG << "\n"; 109 | if (log_short) return; 110 | log << "starts="; 111 | for (size_t i = 0; i < ins.N; ++i) { 112 | auto k = ins.starts[i]->index; 113 | log << "(" << get_x(k) << "," << get_y(k) << "),"; 114 | } 115 | log << "\ngoals="; 116 | for (size_t i = 0; i < ins.N; ++i) { 117 | auto k = ins.goals[i]->index; 118 | log << "(" << get_x(k) << "," << get_y(k) << "),"; 119 | } 120 | log << "\nsolution=\n"; 121 | for (size_t t = 0; t < solution.size(); ++t) { 122 | log << t << ":"; 123 | auto C = solution[t]; 124 | for (auto v : C) { 125 | log << "(" << get_x(v->index) << "," << get_y(v->index) << "),"; 126 | } 127 | log << "\n"; 128 | } 129 | log.close(); 130 | } 131 | -------------------------------------------------------------------------------- /eval_configs/04-movingai/04-movingai.yaml: -------------------------------------------------------------------------------- 1 | environment: 2 | name: Environment 3 | with_animation: False 4 | on_target: nothing 5 | max_episode_steps: 256 6 | observation_type: MAPF 7 | collision_system: soft 8 | seed: 0 9 | num_agents: 10 | grid_search: [64, 128, 192, 256] 11 | map_name: 12 | grid_search: 13 | [ 14 | Berlin_1_256_00, 15 | Berlin_1_256_01, 16 | Berlin_1_256_02, 17 | Berlin_1_256_03, 18 | Berlin_1_256_04, 19 | Berlin_1_256_05, 20 | Berlin_1_256_06, 21 | Berlin_1_256_07, 22 | Berlin_1_256_08, 23 | Berlin_1_256_09, 24 | Berlin_1_256_10, 25 | Berlin_1_256_11, 26 | Berlin_1_256_12, 27 | Berlin_1_256_13, 28 | Berlin_1_256_14, 29 | Berlin_1_256_15, 30 | Boston_0_256_00, 31 | Boston_0_256_01, 32 | Boston_0_256_02, 33 | Boston_0_256_03, 34 | Boston_0_256_04, 35 | Boston_0_256_05, 36 | Boston_0_256_06, 37 | Boston_0_256_07, 38 | Boston_0_256_08, 39 | Boston_0_256_09, 40 | Boston_0_256_10, 41 | Boston_0_256_11, 42 | Boston_0_256_12, 43 | Boston_0_256_13, 44 | Boston_0_256_14, 45 | Boston_0_256_15, 46 | London_2_256_00, 47 | London_2_256_01, 48 | London_2_256_02, 49 | London_2_256_03, 50 | London_2_256_04, 51 | London_2_256_05, 52 | London_2_256_06, 53 | London_2_256_07, 54 | London_2_256_08, 55 | London_2_256_09, 56 | London_2_256_10, 57 | London_2_256_11, 58 | London_2_256_12, 59 | London_2_256_13, 60 | London_2_256_14, 61 | London_2_256_15, 62 | Milan_0_256_00, 63 | Milan_0_256_01, 64 | Milan_0_256_02, 65 | Milan_0_256_03, 
66 | Milan_0_256_04, 67 | Milan_0_256_05, 68 | Milan_0_256_06, 69 | Milan_0_256_07, 70 | Milan_0_256_08, 71 | Milan_0_256_09, 72 | Milan_0_256_10, 73 | Milan_0_256_11, 74 | Milan_0_256_12, 75 | Milan_0_256_13, 76 | Milan_0_256_14, 77 | Milan_0_256_15, 78 | Moscow_0_256_00, 79 | Moscow_0_256_01, 80 | Moscow_0_256_02, 81 | Moscow_0_256_03, 82 | Moscow_0_256_04, 83 | Moscow_0_256_05, 84 | Moscow_0_256_06, 85 | Moscow_0_256_07, 86 | Moscow_0_256_08, 87 | Moscow_0_256_09, 88 | Moscow_0_256_10, 89 | Moscow_0_256_11, 90 | Moscow_0_256_12, 91 | Moscow_0_256_13, 92 | Moscow_0_256_14, 93 | Moscow_0_256_15, 94 | NewYork_1_256_00, 95 | NewYork_1_256_01, 96 | NewYork_1_256_02, 97 | NewYork_1_256_03, 98 | NewYork_1_256_04, 99 | NewYork_1_256_05, 100 | NewYork_1_256_06, 101 | NewYork_1_256_07, 102 | NewYork_1_256_08, 103 | NewYork_1_256_09, 104 | NewYork_1_256_10, 105 | NewYork_1_256_11, 106 | NewYork_1_256_12, 107 | NewYork_1_256_13, 108 | NewYork_1_256_14, 109 | NewYork_1_256_15, 110 | Paris_1_256_00, 111 | Paris_1_256_01, 112 | Paris_1_256_02, 113 | Paris_1_256_03, 114 | Paris_1_256_04, 115 | Paris_1_256_05, 116 | Paris_1_256_06, 117 | Paris_1_256_07, 118 | Paris_1_256_08, 119 | Paris_1_256_09, 120 | Paris_1_256_10, 121 | Paris_1_256_11, 122 | Paris_1_256_12, 123 | Paris_1_256_13, 124 | Paris_1_256_14, 125 | Paris_1_256_15, 126 | Paris_2_256_00, 127 | Paris_2_256_01, 128 | Paris_2_256_02, 129 | Paris_2_256_03, 130 | Paris_2_256_04, 131 | Paris_2_256_05, 132 | Paris_2_256_06, 133 | Paris_2_256_07, 134 | Paris_2_256_08, 135 | Paris_2_256_09, 136 | Paris_2_256_10, 137 | Paris_2_256_11, 138 | Paris_2_256_12, 139 | Paris_2_256_13, 140 | Paris_2_256_14, 141 | Paris_2_256_15, 142 | ] 143 | 144 | algorithms: 145 | MAPF-GPT-2M: 146 | name: MAPF-GPT 147 | parallel_backend: balanced_dask 148 | num_process: 4 149 | path_to_weights: weights/model-2M.pt 150 | MAPF-GPT-6M: 151 | name: MAPF-GPT 152 | parallel_backend: balanced_dask 153 | num_process: 4 154 | path_to_weights: weights/model-6M.pt 155 | 156 | results_views: 157 | TabularView1: 158 | type: tabular 159 | drop_keys: [seed, map_name] 160 | print_results: True 161 | 162 | 04-movingai-mapf-soc: 163 | type: plot 164 | x: num_agents 165 | y: SoC 166 | width: 3 167 | height: 2.5 168 | line_width: 2 169 | use_log_scale_x: False 170 | legend_font_size: 8 171 | font_size: 8 172 | name: Out-of-Distribution 173 | ticks: [64, 128, 192, 256] 174 | 175 | 04-movingai-mapf-csr: 176 | type: plot 177 | x: num_agents 178 | y: CSR 179 | width: 3 180 | height: 2.5 181 | line_width: 2 182 | use_log_scale_x: False 183 | legend_font_size: 8 184 | font_size: 8 185 | name: Out-of-Distribution 186 | ticks: [64, 128, 192, 256] 187 | -------------------------------------------------------------------------------- /gpt/inference.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Literal, Optional 3 | 4 | import cppimport.import_hook 5 | import torch 6 | from huggingface_hub import hf_hub_download 7 | from pogema_toolbox.algorithm_config import AlgoBase 8 | from pogema_toolbox.registry import ToolboxRegistry 9 | from pydantic import Extra 10 | 11 | from gpt.model import GPT, GPTConfig 12 | from gpt.observation_generator import ObservationGenerator, InputParameters 13 | 14 | 15 | class MAPFGPTInferenceConfig(AlgoBase, extra=Extra.forbid): 16 | name: Literal["MAPF-GPT"] = "MAPF-GPT" 17 | num_agents: int = 13 18 | num_previous_actions: int = 5 19 | cost2go_value_limit: int = 20 20 | agents_radius: int = 5 
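    # NOTE (descriptive comment, inferred from the tokenizer/observation-generator code in this repo):
    # the observation is a flat token sequence of length context_size (256). It starts with a
    # (2 * cost2go_radius + 1)^2 = 11x11 = 121-token patch of cost-to-go values clipped to
    # cost2go_value_limit, followed by (5 + num_previous_actions) = 10 tokens for each of up to
    # num_agents = 13 visible agents, and is padded with the '!' token up to the full context.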
21 | cost2go_radius: int = 5 22 | path_to_weights: Optional[str] = "hf_weights/model-2M-DDG.pt" 23 | device: str = "cuda" 24 | context_size: int = 256 25 | mask_actions_history: bool = False 26 | mask_goal: bool = False 27 | mask_cost2go: bool = False 28 | mask_greed_action: bool = False 29 | repo_id: str = 'aandreychuk/MAPF-GPT' 30 | grid_step: int = 64 31 | save_cost2go: bool = False 32 | batch_size: int = 2048 33 | num_process: int = 8 34 | 35 | def strip_prefix_from_state_dict(state_dict, prefix="_orig_mod."): 36 | """ 37 | strips the given prefix from the keys in the state dictionary 38 | """ 39 | new_state_dict = {} 40 | for k, v in state_dict.items(): 41 | if k.startswith(prefix): 42 | new_key = k[len(prefix):] 43 | new_state_dict[new_key] = v 44 | else: 45 | new_state_dict[k] = v 46 | return new_state_dict 47 | 48 | 49 | class MAPFGPTInference: 50 | def __init__(self, cfg: MAPFGPTInferenceConfig, net=None): 51 | self.cfg: MAPFGPTInferenceConfig = cfg 52 | self.input_parameters = InputParameters( 53 | cfg.cost2go_value_limit, 54 | cfg.num_agents, 55 | cfg.num_previous_actions, 56 | cfg.context_size, 57 | cfg.cost2go_radius, 58 | cfg.agents_radius, 59 | cfg.grid_step, 60 | cfg.save_cost2go 61 | ) 62 | self.observation_generator = None 63 | self.last_actions = None 64 | 65 | path_to_weights = Path(self.cfg.path_to_weights) 66 | if "hf_weights" in self.cfg.path_to_weights: 67 | hf_hub_download(repo_id=self.cfg.repo_id, filename=path_to_weights.name, local_dir=path_to_weights.parent) 68 | ToolboxRegistry.info(f'Using weights loaded from huggingface: {path_to_weights}') 69 | 70 | if ('cuda' in self.cfg.device and not torch.cuda.is_available()) or (self.cfg.device == 'mps' and not torch.backends.mps.is_available()): 71 | ToolboxRegistry.warning(f'{self.cfg.device} is not available, using cpu instead!') 72 | self.cfg.device = 'cpu' 73 | 74 | self.torch_generator = torch.Generator(device=self.cfg.device) 75 | self.torch_generator.manual_seed(0) 76 | 77 | checkpoint = torch.load( 78 | Path(self.cfg.path_to_weights), map_location=self.cfg.device 79 | ) 80 | 81 | model_state_dict = strip_prefix_from_state_dict(checkpoint["model"]) 82 | config_dict = checkpoint.get("model_args") 83 | gpt_config = GPTConfig(**config_dict) 84 | if net is not None: 85 | self.net = net 86 | else: 87 | self.net = GPT(gpt_config) 88 | self.net.load_state_dict(model_state_dict, strict=False) 89 | self.net.to(self.cfg.device) 90 | self.net.eval() 91 | 92 | def act(self, observations): 93 | if isinstance(observations[0], dict): 94 | positions = [obs["global_xy"] for obs in observations] 95 | goals = [obs["global_target_xy"] for obs in observations] 96 | if self.observation_generator is None: 97 | self.observation_generator = ObservationGenerator(observations[0]["global_obstacles"].copy().astype(int).tolist(), self.input_parameters) 98 | self.observation_generator.create_agents(positions, goals) 99 | self.last_actions = [-1 for _ in range(len(observations))] 100 | self.observation_generator.update_agents(positions, goals, self.last_actions) 101 | inputs = self.observation_generator.generate_observations() 102 | else: 103 | inputs = observations 104 | if len(inputs) > self.cfg.batch_size: 105 | actions = [] 106 | for i in range(0, len(inputs), self.cfg.batch_size): 107 | batch_inputs = inputs[i:i + self.cfg.batch_size] 108 | tensor_obs = torch.tensor(batch_inputs, dtype=torch.long, device=self.cfg.device) 109 | batch_actions = torch.squeeze(self.net.act(tensor_obs, generator=self.torch_generator)).tolist() 110 | 
actions.extend(batch_actions) 111 | else: 112 | tensor_obs = torch.tensor(inputs, dtype=torch.long, device=self.cfg.device) 113 | actions = torch.squeeze(self.net.act(tensor_obs, generator=self.torch_generator)).tolist() 114 | if not isinstance(actions, list): 115 | actions = [actions] 116 | self.last_actions = actions.copy() 117 | return actions 118 | 119 | def reset_states(self): 120 | self.observation_generator = None 121 | self.torch_generator.manual_seed(0) 122 | -------------------------------------------------------------------------------- /lacam/lacam3/src/sipp.cpp: -------------------------------------------------------------------------------- 1 | #include "../include/sipp.hpp" 2 | 3 | SITable::SITable(CollisionTable *_CT) : CT(_CT) {} 4 | 5 | SITable::~SITable() {} 6 | 7 | SIs &SITable::get(Vertex *v) 8 | { 9 | auto &b_v = body[v->id]; 10 | if (!b_v.empty()) return b_v; 11 | auto &entry = CT->body[v->id]; 12 | auto &entry_last = CT->body_last[v->id]; 13 | auto t_last = entry_last.empty() 14 | ? INT_MAX 15 | : *std::min_element(entry_last.begin(), entry_last.end()); 16 | 17 | // insert safe interval 18 | auto time_start = 0; 19 | for (auto t = 0; t < entry.size(); ++t) { 20 | if (entry[t].empty()) continue; 21 | auto time_end = t - 1; 22 | if (time_start <= time_end) { 23 | b_v.push_back(std::make_pair(time_start, time_end)); 24 | } 25 | time_start = t + 1; 26 | if (t_last == t) break; 27 | } 28 | // add last safe interval 29 | if (t_last == INT_MAX) { 30 | b_v.push_back(std::make_pair(time_start, INT_MAX - 1)); 31 | } 32 | return b_v; 33 | } 34 | 35 | SINode::SINode(const int _uuid, const SI &si, Vertex *_v, int _t, int _g, 36 | int _f, SINode *_parent) 37 | : uuid(_uuid), 38 | time_start(si.first), 39 | time_end(si.second), 40 | v(_v), 41 | t(_t), 42 | g(_g), 43 | f(_f), 44 | parent(_parent) 45 | { 46 | } 47 | 48 | bool SINode::operator==(const SINode &other) const 49 | { 50 | return (other.v->id == v->id && other.time_start == time_start && 51 | other.time_end == time_end); 52 | } 53 | 54 | uint SINodeHasher::operator()(const SINode &n) const 55 | { 56 | uint hash = n.v->id; 57 | hash ^= n.time_start + 0x9e3779b9 + (hash << 6) + (hash >> 2); 58 | hash ^= n.time_end + 0x9e3779b9 + (hash << 6) + (hash >> 2); 59 | return hash; 60 | } 61 | 62 | // minimizing path-loss - not cost! 
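// NOTE (summary of the routine below): single-agent Safe Interval Path Planning (SIPP)
// for agent i. Search nodes are (vertex, safe interval) pairs; the safe intervals are built
// lazily from the collision table CT via SITable. Nodes are expanded in order of
// f = g + cost-to-go taken from the distance table D, candidates with f > f_upper_bound are
// pruned, and an empty path is returned if the search fails or the deadline expires.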
63 | Path sipp(const int i, Vertex *s_i, Vertex *g_i, DistTable *D, 64 | CollisionTable *CT, const Deadline *deadline, const int f_upper_bound) 65 | { 66 | auto solution_path = Path(); 67 | auto ST = SITable(CT); // safe interval table 68 | 69 | // setup goal 70 | auto &intervals_goal = ST.get(g_i); 71 | if (intervals_goal.empty()) return solution_path; 72 | const auto t_goal_after = intervals_goal.back().first - 1; 73 | 74 | // setup OPEN lists 75 | auto cmpNodes = [&](SINode *a, SINode *b) { 76 | if (a->f != b->f) return a->f > b->f; 77 | if (a->g != b->g) return a->g < b->g; 78 | if (a->time_start != b->time_start) return a->time_start > b->time_start; 79 | return a->uuid < b->uuid; 80 | }; 81 | 82 | int node_id = 0; 83 | auto OPEN = 84 | std::priority_queue(cmpNodes); 85 | std::unordered_map EXPLORED; 86 | OPEN.push(new SINode(++node_id, ST.get(s_i)[0], s_i, 0, 0, D->get(i, s_i), 87 | nullptr)); 88 | 89 | // main loop 90 | while (!OPEN.empty() && !is_expired(deadline)) { 91 | auto n = OPEN.top(); 92 | OPEN.pop(); 93 | 94 | // check known node 95 | auto itr_e = EXPLORED.find(*n); 96 | if (itr_e != EXPLORED.end() && itr_e->second->g <= n->g) { 97 | delete n; 98 | continue; 99 | } 100 | EXPLORED[*n] = n; 101 | 102 | // goal check 103 | if (n->v == g_i && n->t > t_goal_after) { 104 | // backtrack 105 | auto t = n->t; 106 | while (t >= 0) { 107 | solution_path.push_back(n->v); 108 | if (t == n->t) n = n->parent; 109 | --t; 110 | } 111 | std::reverse(solution_path.begin(), solution_path.end()); 112 | break; 113 | } 114 | 115 | // expand neighbors 116 | for (auto &u : n->v->neighbor) { 117 | for (auto &si : ST.get(u)) { 118 | // invalid transition 119 | if (si.first > n->time_end + 1) break; 120 | if (si.second <= n->time_start) continue; 121 | 122 | // check existence of t 123 | auto t_earliest = INT_MAX; 124 | if (n->v != g_i) { 125 | for (auto t = std::max(n->t, si.first - 1); 126 | t <= std::min(n->time_end, si.second - 1); ++t) { 127 | if (CT->getCollisionCost(n->v, u, t) == 0) { 128 | t_earliest = t + 1; 129 | break; 130 | } 131 | } 132 | } else { 133 | // for goal node -> reverse 134 | for (auto t = std::min(n->time_end, si.second - 1); 135 | t >= std::max(n->t, si.first - 1); --t) { 136 | if (CT->getCollisionCost(n->v, u, t) == 0) { 137 | t_earliest = t + 1; 138 | break; 139 | } 140 | } 141 | } 142 | if (t_earliest >= INT_MAX) continue; 143 | 144 | // valid neighbor 145 | auto g_val = n->g + (n->v != g_i ? t_earliest - n->t : 1); 146 | auto f_val = g_val + D->get(i, u); 147 | auto n_new = new SINode(++node_id, si, u, t_earliest, g_val, f_val, n); 148 | 149 | auto itr = EXPLORED.find(*n_new); 150 | if (f_val > f_upper_bound || 151 | (itr != EXPLORED.end() && g_val >= itr->second->g)) { 152 | delete n_new; 153 | } else { 154 | OPEN.push(n_new); 155 | } 156 | } 157 | } 158 | } 159 | 160 | // memory management 161 | while (!OPEN.empty()) { 162 | delete OPEN.top(); 163 | OPEN.pop(); 164 | } 165 | for (auto iter : EXPLORED) delete iter.second; 166 | return solution_path; 167 | } 168 | 169 | std::ostream &operator<<(std::ostream &os, const SINode *n) 170 | { 171 | os << "f=" << std::setw(4) << n->f << ", v=" << std::setw(6) << n->v 172 | << ", t=" << std::setw(4) << n->t << ", si: [" << std::setw(4) 173 | << n->time_start << ", " << std::setw(4) 174 | << ((n->time_end < INT_MAX - 1) ? 
std::to_string(n->time_end) : "inf") 175 | << "]"; 176 | return os; 177 | } 178 | -------------------------------------------------------------------------------- /gpt/observation_generator.h: -------------------------------------------------------------------------------- 1 | // cppimport 2 | #pragma once 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | #define PYBIND11_MODULE 17 | #ifdef PYBIND11_MODULE 18 | #include 19 | #include 20 | #include 21 | #endif 22 | struct InputParameters 23 | { 24 | InputParameters(int cvl = 20, int na = 13, int npa = 5, int cs = 256, int obsr = 5, int ar = 5, int gs = 64, bool sc = false) : cost2go_value_limit(cvl), 25 | num_agents(na), 26 | num_previous_actions(npa), 27 | context_size(cs), 28 | obs_radius(obsr), 29 | agents_radius(ar), 30 | grid_step(gs), 31 | save_cost2go(sc) {} 32 | int cost2go_value_limit; 33 | int num_agents; 34 | int num_previous_actions; 35 | int context_size; 36 | int obs_radius; 37 | int agents_radius; 38 | int grid_step; 39 | bool save_cost2go; 40 | }; 41 | 42 | struct HashPair 43 | { 44 | uint64_t operator()(const std::pair& p) const { 45 | return (uint64_t(p.first) << 32) | uint64_t(p.second); 46 | } 47 | }; 48 | 49 | struct AgentsInfo 50 | { 51 | AgentsInfo(std::pair rp, std::pair rg, std::deque pa, std::string na) : relative_pos(rp), relative_goal(rg), previous_actions(pa), next_action(na) {} 52 | AgentsInfo() {} 53 | std::pair relative_pos; 54 | std::pair relative_goal; 55 | std::deque previous_actions; 56 | std::string next_action; 57 | }; 58 | 59 | struct Agent 60 | { 61 | std::pair pos; 62 | std::pair goal; 63 | std::deque action_history; 64 | std::string next_action; 65 | }; 66 | 67 | struct Cost2GoPartial 68 | { 69 | std::pair goal; 70 | int left_border; 71 | int right_border; 72 | int top_border; 73 | int bottom_border; 74 | std::vector> cost2go; 75 | Cost2GoPartial(const std::pair &goal = std::make_pair(-1, -1), 76 | int left_border = -1, 77 | int right_border = -1, 78 | int top_border = -1, 79 | int bottom_border = -1) : 80 | goal(goal), left_border(left_border), right_border(right_border), top_border(top_border), bottom_border(bottom_border) 81 | {} 82 | }; 83 | 84 | class Encoder 85 | { 86 | public: 87 | InputParameters cfg; 88 | std::vector coord_range; 89 | std::vector actions_range; 90 | std::vector next_action_range; 91 | std::unordered_map str_vocab; 92 | std::unordered_map int_vocab; 93 | std::unordered_map inverse_int_vocab; 94 | std::unordered_map inverse_str_vocab; 95 | Encoder(const InputParameters &cfg); 96 | std::vector encode(const std::vector &agents, const std::vector> &cost2go); 97 | }; 98 | 99 | class ObservationGenerator 100 | { 101 | public: 102 | std::vector agents; 103 | InputParameters cfg; 104 | Encoder encoder; 105 | std::vector> agents_locations; 106 | std::vector>> cost2go_obs_buffer; // Buffer for each agent 107 | std::vector> grid; 108 | std::vector> components; 109 | std::vector cost2go_partials; 110 | std::vector> precomputed_cost2go; 111 | std::unordered_map, int, HashPair> precomputed_cells_map; 112 | ObservationGenerator(const std::vector> &grid, const InputParameters &cfg) 113 | : grid(grid), cfg(cfg), encoder(cfg) 114 | { 115 | omp_set_num_threads(1); 116 | agents_locations = std::vector>(grid.size(), std::vector(grid[0].size(), -1)); 117 | mark_components(); 118 | precompute_cost2go(); 119 | } 120 | ~ObservationGenerator() {} 121 | void 
mark_components(); 122 | void compute_cost2go_partial(int agent_idx); 123 | void generate_cost2go_obs(int agent_idx, bool only_obstacles, std::vector> &buffer); 124 | int get_distance(int agent_idx, const std::pair &pos); 125 | void precompute_cost2go(); 126 | std::pair>, std::vector>> get_goal_border_and_cost2go(const std::pair &goal); 127 | std::vector> get_cells_on_border(const std::pair ¢er); 128 | void create_agents(const std::vector> &positions, const std::vector> &goals); 129 | void update_next_action(int agent_idx); 130 | void update_agents(const std::vector> &positions, const std::vector> &goals, const std::vector &actions); 131 | std::vector get_agents_info(int agent_idx); 132 | std::vector> generate_observations(); 133 | }; -------------------------------------------------------------------------------- /utils/svg_utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | from copy import deepcopy 3 | from itertools import cycle 4 | 5 | import numpy as np 6 | from pogema import AnimationConfig, GridConfig 7 | from pogema.svg_animation.animation_drawer import SvgSettings, GridHolder, AnimationDrawer, Drawing 8 | from pogema.wrappers.persistence import AgentState 9 | 10 | 11 | def create_multi_animation(obstacles, histories: list[list[list[AgentState]]], grid_config: GridConfig, 12 | name='render.svg', 13 | animation_config: AnimationConfig = AnimationConfig()): 14 | working_radius = grid_config.obs_radius - 1 15 | wr = working_radius 16 | cut_obstacles = np.concatenate([obstacles for _ in range(len(histories))], axis=1) 17 | cut_obstacles = cut_obstacles[wr:-wr, wr:-wr] 18 | 19 | global_num_agents = sum([len(x) for x in histories]) 20 | history = [] 21 | offset_x = obstacles.shape[1] 22 | current_offset = 0 23 | for data in histories: 24 | history += get_moved_history(data, dy=current_offset, dx=0) 25 | current_offset += offset_x 26 | 27 | svg_settings = SvgSettings(time_scale=0.4) 28 | 29 | global_idx = 0 30 | agents_colors = {} 31 | for num_agents in [len(x) for x in histories]: 32 | colors_cycle = cycle(svg_settings.colors) 33 | cur_colors = {index + global_idx: next(colors_cycle) for index in range(num_agents)} 34 | agents_colors = {**agents_colors, **cur_colors} 35 | 36 | global_idx += num_agents 37 | 38 | episode_sizes = [len(q[0]) for q in histories] 39 | episode_length = max(episode_sizes) 40 | for agent_idx in range(global_num_agents): 41 | while len(history[agent_idx]) <= episode_length: 42 | q = history[agent_idx][-1] 43 | inactive = AgentState(q.x, q.y, q.tx, q.ty, q.step, False) 44 | history[agent_idx].append(inactive) 45 | 46 | grid_holder = GridHolder( 47 | width=len(cut_obstacles), height=len(cut_obstacles[0]), 48 | obstacles=cut_obstacles, 49 | episode_length=episode_length, 50 | history=history, 51 | obs_radius=grid_config.obs_radius, 52 | on_target=grid_config.on_target, 53 | colors=agents_colors, 54 | config=animation_config, 55 | svg_settings=svg_settings 56 | ) 57 | 58 | animation = CustomAnimationDrawer().create_animation(grid_holder) 59 | with open(name, "w") as f: 60 | f.write(animation.render()) 61 | 62 | 63 | def get_moved_history(history: list[list[AgentState]], dx=0, dy=0): 64 | results = [] 65 | for agents in history: 66 | result_for_agent = [] 67 | for state in agents: 68 | moved_state = AgentState(state.x + dx, state.y + dy, state.tx + dx, state.ty + dy, state.step, state.active) 69 | result_for_agent.append(moved_state) 70 | results.append(result_for_agent) 71 | return results 72 | 73 | 74 | def 
cut_history(history, start, finish): 75 | history = deepcopy(history) 76 | for idx, agents_history in enumerate(history): 77 | history[idx] = agents_history[start:finish] 78 | return history 79 | 80 | 81 | class CustomAnimationDrawer(AnimationDrawer): 82 | def create_animation(self, grid_holder: GridHolder): 83 | gh = grid_holder 84 | render_width = gh.height * gh.svg_settings.scale_size + gh.svg_settings.scale_size 85 | render_height = gh.width * gh.svg_settings.scale_size + gh.svg_settings.scale_size 86 | drawing = CustomDrawing(width=render_width, height=render_height, svg_settings=SvgSettings()) 87 | obstacles = self.create_obstacles(gh) 88 | 89 | agents = [] 90 | targets = [] 91 | 92 | if gh.config.show_agents: 93 | agents = self.create_agents(gh) 94 | targets = self.create_targets(gh) 95 | 96 | if not gh.config.static: 97 | self.animate_agents(agents, gh) 98 | self.animate_targets(targets, gh) 99 | if gh.config.show_grid_lines: 100 | grid_lines = self.create_grid_lines(gh, render_width, render_height) 101 | for line in grid_lines: 102 | drawing.add_element(line) 103 | for obj in [*obstacles, *agents, *targets]: 104 | drawing.add_element(obj) 105 | 106 | if gh.config.egocentric_idx is not None: 107 | field_of_view = self.create_field_of_view(grid_holder=gh) 108 | if not gh.config.static: 109 | self.animate_obstacles(obstacles=obstacles, grid_holder=gh) 110 | self.animate_field_of_view(field_of_view, gh) 111 | drawing.add_element(field_of_view) 112 | 113 | return drawing 114 | 115 | 116 | class CustomDrawing(Drawing): 117 | 118 | def __init__(self, height, width, svg_settings): 119 | super().__init__(height, width, svg_settings) 120 | 121 | def render(self): 122 | scale = max(self.height, self.width) / 1024 123 | scaled_width = math.ceil(self.width / scale) 124 | scaled_height = math.ceil(self.height / scale) 125 | 126 | dx, dy = self.origin 127 | view_box = (dx, dy - self.height, self.width, self.height) 128 | 129 | svg_header = f''' 130 | ''' 132 | 133 | definitions = f''' 134 | 135 | 140 | ''' 141 | 142 | elements_svg = [svg_header, '', definitions, '\n'] 143 | elements_svg.extend(element.render() for element in self.elements) 144 | elements_svg.append('') 145 | return "\n".join(elements_svg) 146 | -------------------------------------------------------------------------------- /eval_configs/02-mazes/02-mazes.yaml: -------------------------------------------------------------------------------- 1 | environment: 2 | name: Environment 3 | with_animation: False 4 | on_target: nothing 5 | max_episode_steps: 128 6 | observation_type: MAPF 7 | collision_system: soft 8 | seed: 0 9 | num_agents: 10 | grid_search: [8, 16, 24, 32, 48, 64] 11 | map_name: 12 | grid_search: 13 | [ 14 | validation-mazes-seed-000, 15 | validation-mazes-seed-001, 16 | validation-mazes-seed-002, 17 | validation-mazes-seed-003, 18 | validation-mazes-seed-004, 19 | validation-mazes-seed-005, 20 | validation-mazes-seed-006, 21 | validation-mazes-seed-007, 22 | validation-mazes-seed-008, 23 | validation-mazes-seed-009, 24 | validation-mazes-seed-010, 25 | validation-mazes-seed-011, 26 | validation-mazes-seed-012, 27 | validation-mazes-seed-013, 28 | validation-mazes-seed-014, 29 | validation-mazes-seed-015, 30 | validation-mazes-seed-016, 31 | validation-mazes-seed-017, 32 | validation-mazes-seed-018, 33 | validation-mazes-seed-019, 34 | validation-mazes-seed-020, 35 | validation-mazes-seed-021, 36 | validation-mazes-seed-022, 37 | validation-mazes-seed-023, 38 | validation-mazes-seed-024, 39 | 
validation-mazes-seed-025, 40 | validation-mazes-seed-026, 41 | validation-mazes-seed-027, 42 | validation-mazes-seed-028, 43 | validation-mazes-seed-029, 44 | validation-mazes-seed-030, 45 | validation-mazes-seed-031, 46 | validation-mazes-seed-032, 47 | validation-mazes-seed-033, 48 | validation-mazes-seed-034, 49 | validation-mazes-seed-035, 50 | validation-mazes-seed-036, 51 | validation-mazes-seed-037, 52 | validation-mazes-seed-038, 53 | validation-mazes-seed-039, 54 | validation-mazes-seed-040, 55 | validation-mazes-seed-041, 56 | validation-mazes-seed-042, 57 | validation-mazes-seed-043, 58 | validation-mazes-seed-044, 59 | validation-mazes-seed-045, 60 | validation-mazes-seed-046, 61 | validation-mazes-seed-047, 62 | validation-mazes-seed-048, 63 | validation-mazes-seed-049, 64 | validation-mazes-seed-050, 65 | validation-mazes-seed-051, 66 | validation-mazes-seed-052, 67 | validation-mazes-seed-053, 68 | validation-mazes-seed-054, 69 | validation-mazes-seed-055, 70 | validation-mazes-seed-056, 71 | validation-mazes-seed-057, 72 | validation-mazes-seed-058, 73 | validation-mazes-seed-059, 74 | validation-mazes-seed-060, 75 | validation-mazes-seed-061, 76 | validation-mazes-seed-062, 77 | validation-mazes-seed-063, 78 | validation-mazes-seed-064, 79 | validation-mazes-seed-065, 80 | validation-mazes-seed-066, 81 | validation-mazes-seed-067, 82 | validation-mazes-seed-068, 83 | validation-mazes-seed-069, 84 | validation-mazes-seed-070, 85 | validation-mazes-seed-071, 86 | validation-mazes-seed-072, 87 | validation-mazes-seed-073, 88 | validation-mazes-seed-074, 89 | validation-mazes-seed-075, 90 | validation-mazes-seed-076, 91 | validation-mazes-seed-077, 92 | validation-mazes-seed-078, 93 | validation-mazes-seed-079, 94 | validation-mazes-seed-080, 95 | validation-mazes-seed-081, 96 | validation-mazes-seed-082, 97 | validation-mazes-seed-083, 98 | validation-mazes-seed-084, 99 | validation-mazes-seed-085, 100 | validation-mazes-seed-086, 101 | validation-mazes-seed-087, 102 | validation-mazes-seed-088, 103 | validation-mazes-seed-089, 104 | validation-mazes-seed-090, 105 | validation-mazes-seed-091, 106 | validation-mazes-seed-092, 107 | validation-mazes-seed-093, 108 | validation-mazes-seed-094, 109 | validation-mazes-seed-095, 110 | validation-mazes-seed-096, 111 | validation-mazes-seed-097, 112 | validation-mazes-seed-098, 113 | validation-mazes-seed-099, 114 | validation-mazes-seed-100, 115 | validation-mazes-seed-101, 116 | validation-mazes-seed-102, 117 | validation-mazes-seed-103, 118 | validation-mazes-seed-104, 119 | validation-mazes-seed-105, 120 | validation-mazes-seed-106, 121 | validation-mazes-seed-107, 122 | validation-mazes-seed-108, 123 | validation-mazes-seed-109, 124 | validation-mazes-seed-110, 125 | validation-mazes-seed-111, 126 | validation-mazes-seed-112, 127 | validation-mazes-seed-113, 128 | validation-mazes-seed-114, 129 | validation-mazes-seed-115, 130 | validation-mazes-seed-116, 131 | validation-mazes-seed-117, 132 | validation-mazes-seed-118, 133 | validation-mazes-seed-119, 134 | validation-mazes-seed-120, 135 | validation-mazes-seed-121, 136 | validation-mazes-seed-122, 137 | validation-mazes-seed-123, 138 | validation-mazes-seed-124, 139 | validation-mazes-seed-125, 140 | validation-mazes-seed-126, 141 | validation-mazes-seed-127, 142 | ] 143 | 144 | algorithms: 145 | MAPF-GPT-2M: 146 | name: MAPF-GPT 147 | parallel_backend: balanced_dask 148 | num_process: 4 149 | path_to_weights: weights/model-2M.pt 150 | MAPF-GPT-DDG-2M: 151 | name: MAPF-GPT 152 
| parallel_backend: balanced_dask 153 | num_process: 4 154 | path_to_weights: weights/model-2M-DDG.pt 155 | 156 | results_views: 157 | TabularView1: 158 | type: tabular 159 | drop_keys: [seed, map_name] 160 | print_results: True 161 | # 02-mazes-SoC: 162 | # name: Mazes $20\times20$ 163 | # type: plot 164 | # x: num_agents 165 | # y: SoC 166 | # width: 2.5 167 | # height: 2.5 168 | # line_width: 2 169 | # use_log_scale_x: True 170 | # legend_font_size: 8 171 | # font_size: 8 172 | # 173 | # 02-mazes-CSR: 174 | # name: Mazes $20\times20$ 175 | # type: plot 176 | # x: num_agents 177 | # y: CSR 178 | # width: 2.5 179 | # height: 2.5 180 | # line_width: 2 181 | # use_log_scale_x: True 182 | # legend_font_size: 8 183 | # font_size: 8 184 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Advancing Learnable Multi-Agent Pathfinding Solvers with Active Fine-Tuning 2 | 3 | 4 | 5 |
6 | 7 | --- 8 | [![arXiv](https://img.shields.io/badge/arXiv-2506.23793-b31b1b.svg)](https://arxiv.org/abs/2506.23793) 9 | [![License: MIT](https://img.shields.io/badge/License-MIT-blue.svg)](https://github.com/CognitiveAISystems/MAPF-GPT-DDG/blob/main/LICENSE) 10 | [![Hugging Face](https://img.shields.io/badge/Weights-MAPF--GPT-blue?logo=huggingface)](https://huggingface.co/aandreychuk/MAPF-GPT/tree/main) 11 | [![Hugging Face](https://img.shields.io/badge/Dataset-MAPF--GPT-blue?logo=huggingface)](https://huggingface.co/datasets/aandreychuk/MAPF-GPT/tree/main) 12 | [![Hugging Face](https://img.shields.io/badge/Dataset-MAPF--GPT--DDG-blue?logo=huggingface)](https://huggingface.co/datasets/aandreychuk/MAPF-GPT-DDG/tree/main) 13 |
14 | 15 | The repository is based on the original [MAPF-GPT](https://github.com/CognitiveAISystems/MAPF-GPT) repository. It consists of the following crucial parts: 16 | 17 | - `example.py` - an example of how to run the MAPF-GPT-DDG approach. 18 | - `benchmark.py` - a script that launches the evaluation of the MAPF-GPT-DDG model on the POGEMA benchmark set of maps. 19 | - `download_dataset.py` - a script that downloads the 1B training dataset and the 1M validation one. The dataset is uploaded to Hugging Face. 20 | - `train.py` - a script that launches the training of the MAPF-GPT-DDG model. 21 | - `eval_configs` - a folder that contains configs from the POGEMA benchmark. Required by the `benchmark.py` script. 22 | - `ckpt_configs` - a folder that contains configs used to validate the intermediate checkpoints in the ablation study. 23 | 24 | ## Installation 25 | 26 | It's recommended to utilize Docker to build an environment compatible with the MAPF-GPT code. The `docker` folder contains both `Dockerfile` and `requirements.txt` files to successfully build an appropriate container. 27 | 28 | ``` 29 | cd docker && sh build.sh 30 | ``` 31 | 32 | ## Running an example 33 | 34 | To test MAPF-GPT-DDG, you can simply run the `example.py` script. By default, it uses the MAPF-GPT-DDG-2M model, but this can be adjusted. 35 | Additionally, there is a list of optional arguments: `--map_name`, `--device`, `--num_agents`, `--seed`, `--max_episode_steps`, `--model`, `--show_map_names`. The `--map_name` argument allows you to select a map from those available in the `eval_configs` folder. To list all available maps, you can provide the `--show_map_names` option or look inside the `eval_configs` folder. Here are a few examples from each set: `validation-random-seed-000`, `validation-mazes-seed-000`, `wfi_warehouse`, `Berlin_1_256_00`, `puzzle-00`. 36 | 37 | It is recommended to use GPU-accelerated setups; however, smaller models can be run on a CPU. For Apple Silicon machines, it's recommended to use `--device mps`, which significantly speeds up inference. 38 | By default, the MAPF-GPT-DDG-2M model is used. However, the code is compatible with the original MAPF-GPT weights as well, since the architecture of the model and its input/output remained unchanged. 39 | Thus, you can additionally choose from the `2M`, `6M`, and `85M` model sizes of the original MAPF-GPT, which will be automatically downloaded from Hugging Face. Be aware that the 85M model requires 1 GB of disk space. 40 | 41 | 42 | Here is an example of running MAPF-GPT-DDG-2M on a maze map: 43 | ``` 44 | python3 example.py --map_name validation-mazes-seed-000 --model 2M-DDG --num_agents 32 45 | ``` 46 | 47 | 48 | Here is an example of running MAPF-GPT-85M on the `wfi_warehouse` map: 49 | ``` 50 | python3 example.py --map_name wfi_warehouse --model 85M --num_agents 192 51 | ``` 52 | 53 | In addition to statistics such as SoC and success rate, you will also get an SVG file that animates the solution found by MAPF-GPT; it will be saved to the `svg/` folder. 54 | 55 | 56 | ## Running evaluation 57 | 58 | You can run the `benchmark.py` script, which will run both the MAPF-GPT-2M and MAPF-GPT-DDG-2M models on all the scenarios from the POGEMA benchmark. 59 | You can also run the MAPF-GPT-85M model by setting `path_to_weights` to `hf_weights/model-85M.pt`. The weights for all models will be downloaded automatically. 60 | 61 | ``` 62 | python3 benchmark.py 63 | ``` 64 | 65 | The results will be stored in the `eval_configs` folder near the corresponding configs.
They can also be logged to wandb. The tables with average success rates will be displayed directly in the console. 66 | You can also find the results (raw data and scripts to build the plots) presented in the paper in the [metrics](https://github.com/Cognitive-AI-Systems/MAPF-GPT-DDG/tree/metrics) branch. 67 | 68 | ## Dataset 69 | 70 | To train MAPF-GPT-DDG, we utilized the 1B dataset generated to train MAPF-GPT. It can be downloaded from Hugging Face via the `download_dataset.py` script. During the training phase, we generate additional data by detecting hard cases that cannot be efficiently solved by MAPF-GPT. Solving these hard cases with an expert solver and adding the new observation-action pairs to the training dataset boosts the performance of MAPF-GPT. In contrast to the 1B dataset, this additional data cannot be generated or downloaded in advance, as it requires running MAPF-GPT on the instances to detect the hard cases for the current checkpoint. More details about the generation of the additional data and its usage during training are provided in the paper. 71 | 72 | 73 | ## Running training of MAPF-GPT with DDG 74 | 75 | To train MAPF-GPT from scratch or fine-tune the existing weights, you can use the `train.py` script. By providing it with a config, you can adjust the parameters of the model and the training setup. The script utilizes DDP, which allows training the model on multiple GPUs simultaneously. By adjusting the `nproc_per_node` value, you can choose the number of GPUs that are used for training. 76 | By adjusting the `dagger_type` parameter, you can choose how the model is trained: 77 | - `standard` - the default MAPF-GPT training setup is used, without any additional data collection phases 78 | - `ddg` - during training, an additional dataset will be collected following the logic of the proposed DDG method 79 | - `dagger` - during training, an additional dataset will be collected following the logic of the classic DAgger method 80 | ``` 81 | torchrun --standalone --nproc_per_node=1 train.py gpt/config-2M-DDG.py 82 | ``` 83 | 84 | ## Citation 85 | 86 | ```bibtex 87 | @article{andreychuk2025advancing, 88 | title={Advancing Learnable Multi-Agent Pathfinding Solvers with Active Fine-Tuning}, 89 | author={Anton Andreychuk and Konstantin Yakovlev and Aleksandr Panov and Alexey Skrynnik}, 90 | journal={arXiv preprint arXiv:2506.23793}, 91 | year={2025}, 92 | url={https://arxiv.org/abs/2506.23793} 93 | } 94 | ``` 95 | -------------------------------------------------------------------------------- /eval_configs/01-random/01-random.yaml: -------------------------------------------------------------------------------- 1 | environment: 2 | name: Environment 3 | with_animation: False 4 | on_target: nothing 5 | max_episode_steps: 128 6 | observation_type: MAPF 7 | collision_system: soft 8 | seed: 0 9 | num_agents: 10 | grid_search: [8, 16, 24, 32, 48, 64] 11 | map_name: 12 | grid_search: 13 | [ 14 | validation-random-seed-000, 15 | validation-random-seed-001, 16 | validation-random-seed-002, 17 | validation-random-seed-003, 18 | validation-random-seed-004, 19 | validation-random-seed-005, 20 | validation-random-seed-006, 21 | validation-random-seed-007, 22 | validation-random-seed-008, 23 | validation-random-seed-009, 24 | validation-random-seed-010, 25 | validation-random-seed-011, 26 | validation-random-seed-012, 27 | validation-random-seed-013, 28 | validation-random-seed-014, 29 | validation-random-seed-015, 30 | validation-random-seed-016, 31 | validation-random-seed-017, 32 | validation-random-seed-018, 33 |
validation-random-seed-019, 34 | validation-random-seed-020, 35 | validation-random-seed-021, 36 | validation-random-seed-022, 37 | validation-random-seed-023, 38 | validation-random-seed-024, 39 | validation-random-seed-025, 40 | validation-random-seed-026, 41 | validation-random-seed-027, 42 | validation-random-seed-028, 43 | validation-random-seed-029, 44 | validation-random-seed-030, 45 | validation-random-seed-031, 46 | validation-random-seed-032, 47 | validation-random-seed-033, 48 | validation-random-seed-034, 49 | validation-random-seed-035, 50 | validation-random-seed-036, 51 | validation-random-seed-037, 52 | validation-random-seed-038, 53 | validation-random-seed-039, 54 | validation-random-seed-040, 55 | validation-random-seed-041, 56 | validation-random-seed-042, 57 | validation-random-seed-043, 58 | validation-random-seed-044, 59 | validation-random-seed-045, 60 | validation-random-seed-046, 61 | validation-random-seed-047, 62 | validation-random-seed-048, 63 | validation-random-seed-049, 64 | validation-random-seed-050, 65 | validation-random-seed-051, 66 | validation-random-seed-052, 67 | validation-random-seed-053, 68 | validation-random-seed-054, 69 | validation-random-seed-055, 70 | validation-random-seed-056, 71 | validation-random-seed-057, 72 | validation-random-seed-058, 73 | validation-random-seed-059, 74 | validation-random-seed-060, 75 | validation-random-seed-061, 76 | validation-random-seed-062, 77 | validation-random-seed-063, 78 | validation-random-seed-064, 79 | validation-random-seed-065, 80 | validation-random-seed-066, 81 | validation-random-seed-067, 82 | validation-random-seed-068, 83 | validation-random-seed-069, 84 | validation-random-seed-070, 85 | validation-random-seed-071, 86 | validation-random-seed-072, 87 | validation-random-seed-073, 88 | validation-random-seed-074, 89 | validation-random-seed-075, 90 | validation-random-seed-076, 91 | validation-random-seed-077, 92 | validation-random-seed-078, 93 | validation-random-seed-079, 94 | validation-random-seed-080, 95 | validation-random-seed-081, 96 | validation-random-seed-082, 97 | validation-random-seed-083, 98 | validation-random-seed-084, 99 | validation-random-seed-085, 100 | validation-random-seed-086, 101 | validation-random-seed-087, 102 | validation-random-seed-088, 103 | validation-random-seed-089, 104 | validation-random-seed-090, 105 | validation-random-seed-091, 106 | validation-random-seed-092, 107 | validation-random-seed-093, 108 | validation-random-seed-094, 109 | validation-random-seed-095, 110 | validation-random-seed-096, 111 | validation-random-seed-097, 112 | validation-random-seed-098, 113 | validation-random-seed-099, 114 | validation-random-seed-100, 115 | validation-random-seed-101, 116 | validation-random-seed-102, 117 | validation-random-seed-103, 118 | validation-random-seed-104, 119 | validation-random-seed-105, 120 | validation-random-seed-106, 121 | validation-random-seed-107, 122 | validation-random-seed-108, 123 | validation-random-seed-109, 124 | validation-random-seed-110, 125 | validation-random-seed-111, 126 | validation-random-seed-112, 127 | validation-random-seed-113, 128 | validation-random-seed-114, 129 | validation-random-seed-115, 130 | validation-random-seed-116, 131 | validation-random-seed-117, 132 | validation-random-seed-118, 133 | validation-random-seed-119, 134 | validation-random-seed-120, 135 | validation-random-seed-121, 136 | validation-random-seed-122, 137 | validation-random-seed-123, 138 | validation-random-seed-124, 139 | 
validation-random-seed-125, 140 | validation-random-seed-126, 141 | validation-random-seed-127, 142 | ] 143 | 144 | algorithms: 145 | MAPF-GPT-2M: 146 | name: MAPF-GPT 147 | parallel_backend: balanced_dask 148 | num_process: 4 149 | path_to_weights: weights/model-2M.pt 150 | MAPF-GPT-DDG-2M: 151 | name: MAPF-GPT 152 | parallel_backend: balanced_dask 153 | num_process: 4 154 | path_to_weights: weights/model-2M-DDG.pt 155 | 156 | results_views: 157 | TabularView1: 158 | type: tabular 159 | drop_keys: [seed, map_name] 160 | print_results: True 161 | 162 | 01-random-mazes-SoC: 163 | type: plot 164 | x: num_agents 165 | y: SoC 166 | width: 3.0 167 | height: 2.5 168 | line_width: 2 169 | use_log_scale_x: True 170 | legend_font_size: 8 171 | font_size: 8 172 | name: Random / Mazes 173 | ticks: [8, 16, 24, 32, 48, 64] 174 | 175 | 01-random-mazes-CSR: 176 | type: plot 177 | x: num_agents 178 | y: CSR 179 | width: 3.0 180 | height: 2.5 181 | line_width: 2 182 | use_log_scale_x: True 183 | legend_font_size: 8 184 | font_size: 8 185 | name: Random / Mazes 186 | ticks: [8, 16, 24, 32, 48, 64] 187 | # 188 | # TabularAll: 189 | # type: tabular 190 | # drop_keys: [ seed ] 191 | # print_results: True 192 | -------------------------------------------------------------------------------- /tokenizer/encoder.cpp: -------------------------------------------------------------------------------- 1 | // cppimport 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | 11 | namespace py = pybind11; 12 | 13 | struct InputParameters { 14 | InputParameters(int cvl, int na, int npa, int cs): 15 | cost2go_value_limit(cvl), num_agents(na), num_previous_actions(npa), context_size(cs) {} 16 | int cost2go_value_limit; 17 | int num_agents; 18 | int num_previous_actions; 19 | int context_size = 256; 20 | }; 21 | 22 | struct AgentsInfo 23 | { 24 | AgentsInfo(std::pair rp, std::pair rg, std::vector pa, std::string na): 25 | relative_pos(rp), relative_goal(rg), previous_actions(pa), next_action(na) {} 26 | std::pair relative_pos; 27 | std::pair relative_goal; 28 | std::vector previous_actions; 29 | std::string next_action; 30 | }; 31 | 32 | std::string to_repr(const AgentsInfo& self) { 33 | std::ostringstream oss; 34 | oss << "(relative_pos=(" << self.relative_pos.first << ", " << self.relative_pos.second 35 | << "), relative_goal=(" << self.relative_goal.first << ", " << self.relative_goal.second 36 | << "), previous_actions=["; 37 | 38 | for (size_t i = 0; i < self.previous_actions.size(); ++i) { 39 | oss << self.previous_actions[i]; 40 | if (i < self.previous_actions.size() - 1) { 41 | oss << ", "; 42 | } 43 | } 44 | oss << "], next_action=" << self.next_action << ")"; 45 | 46 | return oss.str(); 47 | } 48 | 49 | class Encoder { 50 | public: 51 | Encoder(const InputParameters& cfg) 52 | : cfg(cfg) { 53 | for (int i = -cfg.cost2go_value_limit; i <= cfg.cost2go_value_limit; ++i) { 54 | coord_range.push_back(i); 55 | } 56 | coord_range.push_back(-cfg.cost2go_value_limit * 4); 57 | coord_range.push_back(-cfg.cost2go_value_limit * 2); 58 | coord_range.push_back(cfg.cost2go_value_limit * 2); 59 | 60 | actions_range = {'n', 'w', 'u', 'd', 'l', 'r'}; 61 | for (int i = 0; i < 16; ++i) { 62 | std::stringstream ss; 63 | ss << std::bitset<4>(i); 64 | next_action_range.push_back(ss.str()); 65 | } 66 | 67 | int idx = 0; 68 | for (auto& token : coord_range) { 69 | int_vocab[token] = idx++; 70 | } 71 | for (auto& token : actions_range) { 72 | str_vocab[std::string(1, token)] = idx++; 73 | 
} 74 | for (auto& token : next_action_range) { 75 | str_vocab[token] = idx++; 76 | } 77 | str_vocab["!"] = idx; // '!' is a trash/padding symbol 78 | 79 | for (auto& [token, idx] : int_vocab) { 80 | inverse_int_vocab[idx] = token; 81 | } 82 | for (auto& [token, idx] : str_vocab) { 83 | inverse_str_vocab[idx] = token; 84 | } 85 | } 86 | 87 | std::vector<int> encode(const std::vector<AgentsInfo>& agents, const std::vector<std::vector<int>> &cost2go) { 88 | std::vector<int> agents_indices; 89 | for (const auto& agent : agents) { 90 | std::vector<int> coord_indices = { 91 | int_vocab.at(agent.relative_pos.first), 92 | int_vocab.at(agent.relative_pos.second), 93 | int_vocab.at(agent.relative_goal.first), 94 | int_vocab.at(agent.relative_goal.second) 95 | }; 96 | 97 | std::vector<int> actions_indices; 98 | for (const auto& action : agent.previous_actions) { 99 | actions_indices.push_back(str_vocab.at(action)); 100 | } 101 | std::vector<int> next_action_indices = {str_vocab.at(agent.next_action)}; 102 | 103 | agents_indices.insert(agents_indices.end(), coord_indices.begin(), coord_indices.end()); 104 | agents_indices.insert(agents_indices.end(), actions_indices.begin(), actions_indices.end()); 105 | agents_indices.insert(agents_indices.end(), next_action_indices.begin(), next_action_indices.end()); 106 | } 107 | 108 | if (agents.size() < cfg.num_agents) 109 | agents_indices.insert(agents_indices.end(), (cfg.num_agents - agents.size()) * (5 + cfg.num_previous_actions), str_vocab["!"]); 110 | 111 | std::vector<int> cost2go_indices; 112 | for (const auto& row : cost2go) 113 | for (int value : row) 114 | cost2go_indices.push_back(int_vocab.at(value)); 115 | 116 | std::vector<int> result; 117 | result.insert(result.end(), cost2go_indices.begin(), cost2go_indices.end()); 118 | result.insert(result.end(), agents_indices.begin(), agents_indices.end()); 119 | while(result.size() < 256) 120 | result.push_back(str_vocab["!"]); 121 | return result; 122 | } 123 | 124 | private: 125 | InputParameters cfg; 126 | std::vector<int> coord_range; 127 | std::vector<char> actions_range; 128 | std::vector<std::string> next_action_range; 129 | std::unordered_map<std::string, int> str_vocab; 130 | std::unordered_map<int, int> int_vocab; 131 | std::unordered_map<int, int> inverse_int_vocab; 132 | std::unordered_map<int, std::string> inverse_str_vocab; 133 | 134 | std::string join(const std::vector<std::string>& vec, const std::string& delim) { 135 | std::ostringstream res; 136 | copy(vec.begin(), vec.end(), std::ostream_iterator<std::string>(res, delim.c_str())); 137 | return res.str().substr(0, res.str().length() - delim.length()); 138 | } 139 | }; 140 | 141 | PYBIND11_MODULE(encoder, m) { 142 | py::class_<InputParameters>(m, "InputParameters") 143 | .def(py::init<int, int, int, int>()) 144 | .def_readwrite("cost2go_value_limit", &InputParameters::cost2go_value_limit) 145 | .def_readwrite("num_agents", &InputParameters::num_agents) 146 | .def_readwrite("num_previous_actions", &InputParameters::num_previous_actions) 147 | ; 148 | 149 | py::class_<AgentsInfo>(m, "AgentsInfo") 150 | .def(py::init<std::pair<int, int>, std::pair<int, int>, std::vector<std::string>, std::string>()) 151 | .def_readwrite("relative_pos", &AgentsInfo::relative_pos) 152 | .def_readwrite("relative_goal", &AgentsInfo::relative_goal) 153 | .def_readwrite("previous_actions", &AgentsInfo::previous_actions) 154 | .def_readwrite("next_action", &AgentsInfo::next_action) 155 | .def("__repr__", &to_repr); 156 | ; 157 | 158 | py::class_<Encoder>(m, "Encoder") 159 | .def(py::init<const InputParameters&>()) 160 | .def("encode", &Encoder::encode) 161 | ; 162 | } 163 | 164 | <% 165 | cfg['extra_compile_args'] = ['-std=c++17'] 166 | setup_pybind11(cfg) 167 | %> 168 | -------------------------------------------------------------------------------- /tokenizer/tokenizer.py: 
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from tokenizer.parameters import InputParameters 5 | 6 | 7 | class Tokenizer: 8 | def __init__(self, cfg: InputParameters) -> None: 9 | self.encoder = Encoder(cfg) 10 | 11 | def __call__(self, obs, return_tensors="pt"): 12 | assert return_tensors == "pt", "Only pt (PyTorch) encoded tensor is supported" 13 | idx = self.encoder.encode(obs) 14 | out = torch.tensor(idx, dtype=torch.int8) 15 | return out 16 | 17 | def encode(self, obs): 18 | idx = self.encoder.encode(obs) 19 | out = np.array(idx, dtype=np.int8) 20 | return out 21 | 22 | def decode(self, idx): 23 | assert idx.ndim == 1 24 | obs = self.encoder.decode(idx.tolist()) 25 | return obs 26 | 27 | 28 | class Encoder: 29 | def __init__(self, cfg: InputParameters): 30 | self.cfg = cfg 31 | self.coord_range = list( 32 | range(-cfg.cost2go_value_limit, cfg.cost2go_value_limit + 1) 33 | ) + [ 34 | -cfg.cost2go_value_limit * 4, 35 | -cfg.cost2go_value_limit * 2, 36 | cfg.cost2go_value_limit * 2, 37 | ] 38 | self.actions_range = ["n", "w", "u", "d", "l", "r"] 39 | self.next_action_range = [format(i, "04b") for i in range(16)] # 0000 to 1111 40 | 41 | self.vocab = { 42 | token: idx 43 | for idx, token in enumerate( 44 | self.coord_range + self.actions_range + self.next_action_range + ["!"] 45 | ) 46 | } # '!' is a trash symbol 47 | self.inverse_vocab = {idx: token for token, idx in self.vocab.items()} 48 | 49 | def encode(self, observation): 50 | agents_indices = [] 51 | 52 | def clamp_value(value, max_abs_value=20): 53 | return max(-max_abs_value, min(max_abs_value, value)) 54 | 55 | for agent in observation["agents"]: 56 | coord_indices = [ 57 | self.vocab[clamp_value(agent["relative_pos"][0])], 58 | self.vocab[clamp_value(agent["relative_pos"][1])], 59 | self.vocab[clamp_value(agent["relative_goal"][0])], 60 | self.vocab[clamp_value(agent["relative_goal"][1])], 61 | ] 62 | actions_indices = [ 63 | self.vocab[action] for action in agent["previous_actions"] 64 | ] 65 | next_action_indices = [self.vocab[agent["next_action"]]] 66 | 67 | agent_obs = coord_indices + actions_indices + next_action_indices 68 | agents_indices.extend(agent_obs) 69 | if len(observation["agents"]) < self.cfg.num_agents: 70 | agents_indices.extend( 71 | [ 72 | self.vocab["!"] 73 | for _ in range( 74 | (self.cfg.num_agents - len(observation["agents"])) 75 | * (5 + self.cfg.num_previous_actions) 76 | ) 77 | ] 78 | ) 79 | cost2go_indices = [ 80 | self.vocab[v] for v in np.array(observation["cost2go"]).flatten() 81 | ] 82 | 83 | result = ( 84 | cost2go_indices 85 | + agents_indices 86 | + [ 87 | self.vocab["!"] 88 | for _ in range( 89 | self.cfg.context_size - len(cost2go_indices) - len(agents_indices) 90 | ) 91 | ] 92 | ) 93 | if any( 94 | [ 95 | self.cfg.mask_actions_history, 96 | self.cfg.mask_cost2go, 97 | self.cfg.mask_goal, 98 | self.cfg.mask_greed_action, 99 | ] 100 | ): 101 | result = self.mask(result) 102 | return result 103 | 104 | def mask(self, input): 105 | cost2go_size = (self.cfg.cost2go_radius * 2 + 1) ** 2 106 | if self.cfg.mask_actions_history: 107 | for i in range(self.cfg.num_agents): 108 | input[ 109 | cost2go_size 110 | + i * (5 + self.cfg.num_previous_actions) 111 | + 4 : cost2go_size 112 | + i * (5 + self.cfg.num_previous_actions) 113 | + 4 114 | + self.cfg.num_previous_actions 115 | ] = [self.vocab["!"] for _ in range(self.cfg.num_previous_actions)] 116 | if self.cfg.mask_cost2go: 117 | traversable_cell = self.vocab[0] 118 | 
blocked_cell = self.vocab[-self.cfg.cost2go_value_limit * 4] 119 | for i in range(cost2go_size): 120 | if input[i] != blocked_cell: 121 | input[i] = traversable_cell 122 | if self.cfg.mask_goal: 123 | for i in range(self.cfg.num_agents): 124 | input[cost2go_size + i * (5 + self.cfg.num_previous_actions) + 2] = ( 125 | self.vocab["!"] 126 | ) 127 | input[cost2go_size + i * (5 + self.cfg.num_previous_actions) + 3] = ( 128 | self.vocab["!"] 129 | ) 130 | if self.cfg.mask_greed_action: 131 | for i in range(self.cfg.num_agents): 132 | input[ 133 | cost2go_size 134 | + i * (5 + self.cfg.num_previous_actions) 135 | + 4 136 | + self.cfg.num_previous_actions 137 | ] = self.vocab["!"] 138 | return input 139 | 140 | def decode(self, idx): 141 | if any( 142 | [ 143 | self.cfg.mask_actions_history, 144 | self.cfg.mask_cost2go, 145 | self.cfg.mask_goal, 146 | self.cfg.mask_greed_action, 147 | ] 148 | ): 149 | idx = self.mask(idx) 150 | agents_info_size = 4 + self.cfg.num_previous_actions + 1 151 | cost2go_size = (self.cfg.cost2go_radius * 2 + 1) ** 2 152 | agents = [] 153 | for i in range(self.cfg.num_agents): 154 | agent_indices = idx[ 155 | cost2go_size 156 | + i * agents_info_size : cost2go_size 157 | + (i + 1) * agents_info_size 158 | ] 159 | 160 | relative_pos = ( 161 | self.inverse_vocab[agent_indices[0]], 162 | self.inverse_vocab[agent_indices[1]], 163 | ) 164 | relative_goal = ( 165 | self.inverse_vocab[agent_indices[2]], 166 | self.inverse_vocab[agent_indices[3]], 167 | ) 168 | previous_actions = [self.inverse_vocab[a] for a in agent_indices[4:-1]] 169 | next_action = self.inverse_vocab[agent_indices[-1]] 170 | 171 | agent = { 172 | "relative_pos": relative_pos, 173 | "relative_goal": relative_goal, 174 | "previous_actions": previous_actions, 175 | "next_action": next_action, 176 | } 177 | agents.append(agent) 178 | 179 | cost2go_indices = idx[:cost2go_size] 180 | cost2go = [self.inverse_vocab[v] for v in cost2go_indices] 181 | cost2go_size = self.cfg.cost2go_radius * 2 + 1 182 | cost2go = np.array(cost2go).reshape(cost2go_size, cost2go_size) 183 | observation = {"agents": agents, "cost2go": cost2go} 184 | 185 | return observation 186 | -------------------------------------------------------------------------------- /lacam/lacam3/src/pibt.cpp: -------------------------------------------------------------------------------- 1 | #include "../include/pibt.hpp" 2 | 3 | PIBT::PIBT(const Instance *_ins, DistTable *_D, int seed, bool _flg_swap, 4 | Scatter *_scatter) 5 | : ins(_ins), 6 | MT(std::mt19937(seed)), 7 | N(ins->N), 8 | V_size(ins->G->size()), 9 | D(_D), 10 | NO_AGENT(N), 11 | occupied_now(V_size, NO_AGENT), 12 | occupied_next(V_size, NO_AGENT), 13 | C_next(N, std::array<Vertex *, 5>()), 14 | tie_breakers(V_size, 0), 15 | flg_swap(_flg_swap), 16 | scatter(_scatter) 17 | { 18 | } 19 | 20 | PIBT::~PIBT() {} 21 | 22 | bool PIBT::set_new_config(const Config &Q_from, Config &Q_to, 23 | const std::vector<int> &order) 24 | { 25 | bool success = true; 26 | // setup cache & constraints check 27 | for (auto i = 0; i < N; ++i) { 28 | // set occupied now 29 | occupied_now[Q_from[i]->id] = i; 30 | 31 | // set occupied next 32 | if (Q_to[i] != nullptr) { 33 | // vertex collision 34 | if (occupied_next[Q_to[i]->id] != NO_AGENT) { 35 | success = false; 36 | break; 37 | } 38 | // swap collision 39 | auto j = occupied_now[Q_to[i]->id]; 40 | if (j != NO_AGENT && j != i && Q_to[j] == Q_from[i]) { 41 | success = false; 42 | break; 43 | } 44 | occupied_next[Q_to[i]->id] = i; 45 | } 46 | } 47 | 48 | if (success) { 49 | for (auto i 
: order) { 50 | if (Q_to[i] == nullptr && !funcPIBT(i, Q_from, Q_to)) { 51 | success = false; 52 | break; 53 | } 54 | } 55 | } 56 | 57 | // cleanup 58 | for (auto i = 0; i < N; ++i) { 59 | occupied_now[Q_from[i]->id] = NO_AGENT; 60 | if (Q_to[i] != nullptr) occupied_next[Q_to[i]->id] = NO_AGENT; 61 | } 62 | 63 | return success; 64 | } 65 | 66 | bool PIBT::funcPIBT(const int i, const Config &Q_from, Config &Q_to) 67 | { 68 | const auto K = Q_from[i]->neighbor.size(); 69 | 70 | // exploit scatter data 71 | Vertex *prioritized_vertex = nullptr; 72 | if (scatter != nullptr) { 73 | auto itr_s = scatter->scatter_data[i].find(Q_from[i]->id); 74 | if (itr_s != scatter->scatter_data[i].end()) { 75 | prioritized_vertex = itr_s->second; 76 | } 77 | } 78 | 79 | // set C_next 80 | for (size_t k = 0; k < K; ++k) { 81 | auto u = Q_from[i]->neighbor[k]; 82 | C_next[i][k] = u; 83 | tie_breakers[u->id] = get_random_float(MT); // set tie-breaker 84 | } 85 | C_next[i][K] = Q_from[i]; 86 | 87 | // sort, note: K + 1 is sufficient 88 | std::sort(C_next[i].begin(), C_next[i].begin() + K + 1, 89 | [&](Vertex *const v, Vertex *const u) { 90 | if (v == prioritized_vertex) return true; 91 | if (u == prioritized_vertex) return false; 92 | return D->get(i, v) + tie_breakers[v->id] < 93 | D->get(i, u) + tie_breakers[u->id]; 94 | }); 95 | 96 | // emulate swap 97 | auto swap_agent = NO_AGENT; 98 | if (flg_swap) { 99 | swap_agent = is_swap_required_and_possible(i, Q_from, Q_to); 100 | if (swap_agent != NO_AGENT) { 101 | // reverse vertex scoring 102 | std::reverse(C_next[i].begin(), C_next[i].begin() + K + 1); 103 | } 104 | } 105 | 106 | auto swap_operation = [&]() { 107 | if (swap_agent != NO_AGENT && // swap_agent exists 108 | Q_to[swap_agent] == nullptr && // not decided 109 | occupied_next[Q_from[i]->id] == NO_AGENT // free 110 | ) { 111 | // pull swap_agent 112 | occupied_next[Q_from[i]->id] = swap_agent; 113 | Q_to[swap_agent] = Q_from[i]; 114 | } 115 | }; 116 | 117 | // main loop 118 | for (size_t k = 0; k < K + 1; ++k) { 119 | auto u = C_next[i][k]; 120 | 121 | // avoid vertex conflicts 122 | if (occupied_next[u->id] != NO_AGENT) continue; 123 | 124 | const auto j = occupied_now[u->id]; 125 | 126 | // avoid swap conflicts with constraints 127 | if (j != NO_AGENT && Q_to[j] == Q_from[i]) continue; 128 | 129 | // reserve next location 130 | occupied_next[u->id] = i; 131 | Q_to[i] = u; 132 | 133 | // priority inheritance 134 | if (j != NO_AGENT && u != Q_from[i] && Q_to[j] == nullptr && 135 | !funcPIBT(j, Q_from, Q_to)) 136 | continue; 137 | 138 | // success to plan next one step 139 | if (flg_swap && k == 0) swap_operation(); 140 | return true; 141 | } 142 | 143 | // failed to secure node 144 | occupied_next[Q_from[i]->id] = i; 145 | Q_to[i] = Q_from[i]; 146 | return false; 147 | } 148 | 149 | int PIBT::is_swap_required_and_possible(const int i, const Config &Q_from, 150 | Config &Q_to) 151 | { 152 | // agent-j occupying the desired vertex for agent-i 153 | const auto j = occupied_now[C_next[i][0]->id]; 154 | if (j != NO_AGENT && j != i && // j exists 155 | Q_to[j] == nullptr && // j does not decide next location 156 | is_swap_required(i, j, Q_from[i], Q_from[j]) && // swap required 157 | is_swap_possible(Q_from[j], Q_from[i]) // swap possible 158 | ) { 159 | return j; 160 | } 161 | 162 | // for clear operation, c.f., push & swap 163 | if (C_next[i][0] != Q_from[i]) { 164 | for (auto u : Q_from[i]->neighbor) { 165 | const auto k = occupied_now[u->id]; 166 | if (k != NO_AGENT && // k exists 167 | C_next[i][0] != 
Q_from[k] && // this is for clear operation 168 | is_swap_required(k, i, Q_from[i], 169 | C_next[i][0]) && // emulating from one step ahead 170 | is_swap_possible(C_next[i][0], Q_from[i])) { 171 | return k; 172 | } 173 | } 174 | } 175 | return NO_AGENT; 176 | } 177 | 178 | bool PIBT::is_swap_required(const int pusher, const int puller, 179 | Vertex *v_pusher_origin, Vertex *v_puller_origin) 180 | { 181 | auto v_pusher = v_pusher_origin; 182 | auto v_puller = v_puller_origin; 183 | Vertex *tmp = nullptr; 184 | while (D->get(pusher, v_puller) < D->get(pusher, v_pusher)) { 185 | auto n = v_puller->neighbor.size(); 186 | // remove agents who need not to move 187 | for (auto u : v_puller->neighbor) { 188 | const auto i = occupied_now[u->id]; 189 | if (u == v_pusher || 190 | (u->neighbor.size() == 1 && i != NO_AGENT && ins->goals[i] == u)) { 191 | --n; 192 | } else { 193 | tmp = u; 194 | } 195 | } 196 | if (n >= 2) return false; // able to swap at v_l 197 | if (n <= 0) break; 198 | v_pusher = v_puller; 199 | v_puller = tmp; 200 | } 201 | 202 | return (D->get(puller, v_pusher) < D->get(puller, v_puller)) && 203 | (D->get(pusher, v_pusher) == 0 || 204 | D->get(pusher, v_puller) < D->get(pusher, v_pusher)); 205 | } 206 | 207 | bool PIBT::is_swap_possible(Vertex *v_pusher_origin, Vertex *v_puller_origin) 208 | { 209 | // simulate pull 210 | auto v_pusher = v_pusher_origin; 211 | auto v_puller = v_puller_origin; 212 | Vertex *tmp = nullptr; 213 | while (v_puller != v_pusher_origin) { // avoid loop 214 | auto n = v_puller->neighbor.size(); 215 | for (auto u : v_puller->neighbor) { 216 | const auto i = occupied_now[u->id]; 217 | if (u == v_pusher || 218 | (u->neighbor.size() == 1 && i != NO_AGENT && ins->goals[i] == u)) { 219 | --n; 220 | } else { 221 | tmp = u; 222 | } 223 | } 224 | if (n >= 2) return true; // able to swap at v_next 225 | if (n <= 0) return false; 226 | v_pusher = v_puller; 227 | v_puller = tmp; 228 | } 229 | return false; 230 | } 231 | -------------------------------------------------------------------------------- /finetuning/delta_data_generator.py: -------------------------------------------------------------------------------- 1 | from gpt.inference import MAPFGPTInference, MAPFGPTInferenceConfig 2 | 3 | from copy import deepcopy 4 | from pathlib import Path 5 | 6 | from pogema import AnimationMonitor, AnimationConfig 7 | 8 | from pogema_toolbox.run_episode import run_episode 9 | from pogema_toolbox.registry import ToolboxRegistry 10 | from pydantic import BaseModel 11 | 12 | from finetuning.filter_data import filter_data 13 | 14 | from utils.data_collection import fill_actions_with_solver 15 | from finetuning.scenario_generators import make_pogema_maze_instance 16 | 17 | from utils.svg_utils import cut_history, create_multi_animation 18 | from utils.wrappers import UnrollWrapper 19 | 20 | from multiprocessing import Pool 21 | from lacam.inference import LacamInference, LacamInferenceConfig 22 | from pogema.wrappers.metrics import RuntimeMetricWrapper 23 | from macro_env import PogemaMacroEnvironment, MAPFGPTObservationWrapper 24 | from gpt.observation_generator import ObservationGenerator, InputParameters 25 | 26 | class FastSolverDeltaConfig(BaseModel): 27 | steps_delta: int = 16 28 | steps_saved: int = 32 29 | save_debug_svg: bool = False 30 | diff_threshold = 3 31 | 32 | 33 | def run_solver(env, unroll_steps, time_limit): 34 | env = deepcopy(env) 35 | solver = LacamInference(LacamInferenceConfig(time_limit=time_limit, timeouts=[time_limit])) 36 | 
env.set_unroll_steps(unroll_steps) 37 | results = run_episode(env, solver) 38 | results['step'] = unroll_steps 39 | results['map_name'] = env.grid.config.map_name 40 | return results 41 | 42 | def run_episode_macro(env, algo): 43 | algo.reset_states() 44 | obs, _ = env.reset() 45 | while True: 46 | obs, rew, terminated, truncated, infos = env.step(algo.act(obs)) 47 | if all(terminated) or all(truncated): 48 | break 49 | return [info[0]['metrics'] for info in infos] 50 | 51 | def run_expert(env, unroll_steps, steps_saved, chosen_agents, time_limit): 52 | env = deepcopy(env) 53 | solver = LacamInference(LacamInferenceConfig(time_limit=time_limit, timeouts=[time_limit])) 54 | input, gt_action, metrics = fill_actions_with_solver(env, unroll_steps, steps_saved, chosen_agents, solver) 55 | if metrics is not None: 56 | metrics['step'] = unroll_steps 57 | metrics['map_name'] = env.grid.config.map_name 58 | return input, gt_action, metrics 59 | 60 | def fast_solver_delta(envs, learnable_algo, fast_solver, solver, cfg: FastSolverDeltaConfig): 61 | 62 | def create_svg(env, unroll_steps): 63 | obstacles = env.get_obstacles(ignore_borders=False) 64 | algo_history = env.get_full_history() 65 | fast_env = deepcopy(env) 66 | fast_env.set_unroll_steps(unroll_steps) 67 | run_episode(fast_env, fast_solver) 68 | fast_solver_history = fast_env.get_full_history() 69 | oracle_env = deepcopy(env) 70 | oracle_env.set_unroll_steps(unroll_steps) 71 | run_episode(oracle_env, solver) 72 | oracle_history = oracle_env.get_full_history() 73 | histories = [algo_history, fast_solver_history, oracle_history] 74 | ToolboxRegistry.debug('Histories sizes: ' + str([len(x[0]) for x in histories])) 75 | cut_histories = [cut_history(x, start=unroll_steps, finish=unroll_steps + cfg.steps_saved) for x in histories] 76 | ToolboxRegistry.debug('Cut histories sizes: ' + str([len(x[0]) for x in cut_histories])) 77 | 78 | svg_path = f'renders/seed-{env.grid.config.map_name}-step-{unroll_steps}.svg' 79 | Path(svg_path).parent.mkdir(exist_ok=True) 80 | create_multi_animation(obstacles, cut_histories, env.grid.config, name=svg_path) 81 | ToolboxRegistry.debug(f'Saved svg to: {svg_path}') 82 | 83 | inputs = [] 84 | gt_actions = [] 85 | gpt_envs = [] 86 | for env in envs: 87 | env = RuntimeMetricWrapper(env) 88 | if cfg.save_debug_svg: 89 | env = AnimationMonitor(env, AnimationConfig(save_every_idx_episode=None)) 90 | obs, _ = env.reset(seed=env.grid_config.seed) 91 | obs_generator = ObservationGenerator(obs[0]["global_obstacles"].copy().astype(int).tolist(), 92 | InputParameters(20, 13, 5, 256, 5, 5, 64, False)) 93 | obs_generator.create_agents([o["global_xy"] for o in obs], [o["global_target_xy"] for o in obs]) 94 | env = UnrollWrapper(env) 95 | env = MAPFGPTObservationWrapper(env, obs_generator) 96 | gpt_envs.append(env) 97 | macro_env = PogemaMacroEnvironment(gpt_envs) 98 | gpt_results = run_episode_macro(macro_env, learnable_algo) 99 | 100 | envs = [env.get_inner_env() for env in macro_env.environments] 101 | 102 | unroll_steps_lists = [] 103 | for gpt_result in gpt_results: 104 | unroll_steps_list = range(0, gpt_result['ep_length'], cfg.steps_delta) 105 | unroll_steps_lists.append(unroll_steps_list) 106 | 107 | with Pool(processes=8) as pool: 108 | fast_solver_results = pool.starmap(run_solver, 109 | [(env, unroll_steps, 2) for env, unroll_steps_list in zip(envs, unroll_steps_lists) for unroll_steps in unroll_steps_list]) 110 | 111 | fast_solver_results_by_map = {} 112 | for result in fast_solver_results: 113 | if result['map_name'] not 
in fast_solver_results_by_map: 114 | fast_solver_results_by_map[result['map_name']] = {} 115 | fast_solver_results_by_map[result['map_name']][result['step']] = result 116 | 117 | 118 | diffs_by_map = {} 119 | for map_name, results in fast_solver_results_by_map.items(): 120 | unroll_steps = sorted(results.keys()) 121 | diffs = [] 122 | for i in range(1, len(unroll_steps)): 123 | prev_step = unroll_steps[i - 1] 124 | curr_step = unroll_steps[i] 125 | diff = results[curr_step]['makespan'] - results[prev_step]['makespan'] 126 | diffs.append(diff) 127 | diffs_by_map[map_name] = diffs 128 | 129 | max_diff_indices = {map_name: diffs.index(max(diffs)) for map_name, diffs in diffs_by_map.items()} 130 | 131 | envs_with_positive_diffs = [] 132 | for env in envs: 133 | if diffs_by_map[env.grid.config.map_name][max_diff_indices[env.grid.config.map_name]] > cfg.diff_threshold: 134 | env.set_unroll_steps(cfg.steps_delta*max_diff_indices[env.grid.config.map_name]) 135 | envs_with_positive_diffs.append((env, cfg.steps_delta*max_diff_indices[env.grid.config.map_name])) 136 | chosen_agents = list(range(env.grid.config.num_agents)) 137 | ToolboxRegistry.debug(f'Makespan difference: {diffs_by_map}') 138 | with Pool(processes=8) as pool: 139 | expert_results = pool.starmap(run_expert, 140 | [(env, unroll_steps, cfg.steps_saved, chosen_agents, 10) for env, unroll_steps in envs_with_positive_diffs]) 141 | 142 | inputs = [] 143 | gt_actions = [] 144 | expert_logs = {} 145 | for result in expert_results: 146 | if result[0] is not None: 147 | filtered_data = filter_data(result[0], result[1]) 148 | inputs.extend(filtered_data['inputs']) 149 | gt_actions.extend(filtered_data['gt_actions']) 150 | expert_logs[result[2]['map_name']] = result[2] 151 | else: 152 | ToolboxRegistry.debug('No expert results for env', env.grid.config.map_name) 153 | if cfg.save_debug_svg: 154 | for env, unroll_steps in envs_with_positive_diffs: 155 | create_svg(env, unroll_steps) 156 | logs = [{'map_name': envs[i].grid.config.map_name, 157 | 'gpt_results': gpt_results[i], 158 | 'fast_expert_results': fast_solver_results_by_map[envs[i].grid.config.map_name], 159 | 'expert_results': expert_logs[envs[i].grid.config.map_name] if envs[i].grid.config.map_name in expert_logs else "Diff threshold not reached"} for i in range(len(envs))] 160 | return {'inputs': inputs, 'gt_actions': gt_actions}, logs 161 | 162 | 163 | def main(): 164 | ToolboxRegistry.setup_logger('DEBUG') 165 | 166 | learnable_algo = MAPFGPTInference(MAPFGPTInferenceConfig(device='cuda', path_to_weights='../weights/model-2M.pt')) 167 | fast_time_limit = 2 168 | slow_time_limit = 10 169 | lacam_lib_path = "../lacam/liblacam.so" 170 | fast_solver = LacamInference( 171 | LacamInferenceConfig(time_limit=fast_time_limit, timeouts=[fast_time_limit], lacam_lib_path=lacam_lib_path), ) 172 | solver = LacamInference( 173 | LacamInferenceConfig(time_limit=slow_time_limit, timeouts=[slow_time_limit], lacam_lib_path=lacam_lib_path)) 174 | 175 | env = make_pogema_maze_instance(num_agents=32, 176 | max_episode_steps=256, 177 | map_seed=45, 178 | scenario_seed=45) 179 | 180 | fast_solver_delta(env=env, learnable_algo=learnable_algo, fast_solver=fast_solver, solver=solver, 181 | cfg=FastSolverDeltaConfig(save_debug_svg=True)) 182 | 183 | 184 | if __name__ == '__main__': 185 | main() 186 | -------------------------------------------------------------------------------- /lacam/inference.py: -------------------------------------------------------------------------------- 1 | import ctypes 2 | from 
pathlib import Path 3 | 4 | import numpy as np 5 | from typing import Literal 6 | from pydantic import Extra 7 | from pogema_toolbox.algorithm_config import AlgoBase 8 | 9 | from pogema import GridConfig 10 | 11 | import subprocess 12 | 13 | file_folder = Path(__file__).resolve().parent 14 | lib_path = file_folder / 'liblacam.so' 15 | 16 | if not lib_path.exists(): 17 | subprocess.run(['cmake', '.'], check=True, cwd=file_folder) 18 | subprocess.run(['make', '-j8'], check=True, cwd=file_folder) 19 | 20 | 21 | class LacamLib: 22 | def __init__(self, lib_path): 23 | self.lib_path = lib_path 24 | self.load_library() 25 | 26 | def load_library(self): 27 | self._lacam_lib = ctypes.CDLL(self.lib_path) 28 | 29 | self._lacam_lib.run_lacam.argtypes = [ 30 | ctypes.c_char_p, # map_name 31 | ctypes.c_char_p, # scene_name 32 | ctypes.c_int, # N 33 | ctypes.c_float # time_limit_sec 34 | ] 35 | self._lacam_lib.run_lacam.restype = ctypes.c_char_p 36 | 37 | def run_lacam(self, map_file_content, scene_file_content, num_agents, lacam_timeouts): 38 | map_file_bytes = map_file_content.encode('utf-8') 39 | scenario_file_bytes = scene_file_content.encode('utf-8') 40 | 41 | num_agents_int = ctypes.c_int(num_agents) 42 | for time_limit_sec in lacam_timeouts: 43 | result = self._lacam_lib.run_lacam( 44 | map_file_bytes, 45 | scenario_file_bytes, 46 | num_agents_int, 47 | time_limit_sec 48 | ) 49 | 50 | try: 51 | result_str = result.decode('utf-8') 52 | except Exception as e: 53 | print(f'Exception occured while running Lacam: {e}') 54 | raise e 55 | 56 | if "ERROR" in result_str: 57 | print(f'Lacam failed to find path with time_limit_sec={time_limit_sec} | {result_str}') 58 | else: 59 | return True, result_str 60 | 61 | return False, None 62 | 63 | class LacamAgent: 64 | def __init__(self, idx): 65 | self._moves = GridConfig().MOVES 66 | self._reverse_actions = {tuple(self._moves[i]): i for i in range(len(self._moves))} 67 | 68 | self.idx = idx 69 | self.previous_goal = None 70 | self.path = [] 71 | 72 | def is_new_goal(self, new_goal): 73 | return not self.previous_goal == new_goal 74 | 75 | def set_new_goal(self, new_goal): 76 | self.previous_goal = new_goal 77 | 78 | def set_path(self, new_path): 79 | self.path = new_path[::-1] 80 | 81 | def format_task_string(self, start_xy, target_xy, map_shape): 82 | task_file_content = f"{self.idx} tmp.map {map_shape[0]} {map_shape[1]} " 83 | task_file_content += f"{start_xy[1]} {start_xy[0]} {target_xy[1]} {target_xy[0]} 1\n" 84 | return task_file_content 85 | 86 | def get_action(self): 87 | action = 0 88 | if len(self.path) > 1: 89 | x, y = self.path[-1] 90 | tx, ty = self.path[-2] 91 | action = self._reverse_actions[tx - x, ty - y] 92 | self.path.pop() 93 | return action 94 | 95 | def clear_state(self): 96 | self.previous_goal = None 97 | self.path = [] 98 | 99 | 100 | class LacamInferenceConfig(AlgoBase, extra=Extra.forbid): 101 | name: Literal['LaCAM'] = 'LaCAM' 102 | time_limit: float = 60 103 | timeouts: list = [1.0, 5.0, 10.0, 60.0] 104 | lacam_lib_path: str = "lacam/liblacam.so" 105 | 106 | 107 | class LacamInference: 108 | def __init__(self, cfg: LacamInferenceConfig): 109 | self.cfg = cfg 110 | self.lacam_agents = None 111 | self.lacam_lib = LacamLib(cfg.lacam_lib_path) 112 | self.solved = None 113 | 114 | def _parse_data(self, data): 115 | if data is None: 116 | return None 117 | lines = data.strip().split('\n') 118 | columns = None 119 | 120 | for line in lines: 121 | tuples = [tuple(map(int, item.split(','))) for item in line.strip().split('|') if item] 122 | 
if len(tuples) == 0: 123 | return None 124 | if columns is None: 125 | columns = [[] for _ in range(len(tuples))] 126 | for i, t in enumerate(tuples): 127 | columns[i].append(t[::-1]) 128 | 129 | return columns 130 | 131 | def _find_near_goal(self, start_xy, target_xy, map_array, processed_targets): 132 | for radius in range(1, 3): 133 | offset_list = [] 134 | for x_offset in range(-radius, radius+1): 135 | for y_offset in range(-radius, radius+1): 136 | if x_offset == 0 and y_offset == 0: 137 | continue 138 | offset_list.append((x_offset, y_offset)) 139 | 140 | offset_list.sort(key=lambda xy_off: (abs(xy_off[0]) + abs(xy_off[1]))**0.5) 141 | 142 | for (x_offset, y_offset) in offset_list: 143 | near_target_x = target_xy[0] + x_offset 144 | near_target_y = target_xy[1] + y_offset 145 | is_obstacle = map_array[near_target_x, near_target_y] 146 | assert map_array[target_xy[0], target_xy[1]] == 0 147 | if not is_obstacle and (near_target_x, near_target_y) not in processed_targets \ 148 | and start_xy != (near_target_x, near_target_y): 149 | return (near_target_x, near_target_y) 150 | 151 | def act(self, observations, rewards=None, dones=None, info=None, skip_agents=None): 152 | map_array = np.array(observations[0]['global_obstacles']) 153 | agent_starts_xy = [obs['global_xy'] for obs in observations] 154 | agent_targets_xy = [obs['global_target_xy'] for obs in observations] 155 | 156 | has_new_tasks = False 157 | 158 | processed_starts = set() 159 | processed_targets = set() 160 | if self.lacam_agents is None: 161 | self.lacam_agents = [LacamAgent(idx) for idx in range(len(observations))] 162 | # Process old tasks 163 | agent_tasks_dict = {} 164 | for idx, (start_xy, target_xy) in enumerate(zip(agent_starts_xy, agent_targets_xy)): 165 | if self.lacam_agents[idx].is_new_goal(target_xy): 166 | continue 167 | if start_xy == target_xy or target_xy in processed_targets: 168 | near_target_xy = self._find_near_goal(start_xy, target_xy, map_array, processed_targets) 169 | target_xy = near_target_xy 170 | 171 | processed_starts.add(start_xy) 172 | processed_targets.add(target_xy) 173 | 174 | agent_task = self.lacam_agents[idx].format_task_string(start_xy, target_xy, map_shape=map_array.shape) 175 | agent_tasks_dict[idx] = agent_task 176 | 177 | # Process new tasks 178 | for idx, (start_xy, target_xy) in enumerate(zip(agent_starts_xy, agent_targets_xy)): 179 | if not self.lacam_agents[idx].is_new_goal(target_xy): 180 | continue 181 | if target_xy in processed_targets: 182 | near_target_xy = self._find_near_goal(start_xy, target_xy, map_array, processed_targets) 183 | target_xy = near_target_xy 184 | self.lacam_agents[idx].set_new_goal(target_xy) 185 | has_new_tasks = True 186 | 187 | processed_starts.add(start_xy) 188 | processed_targets.add(target_xy) 189 | 190 | agent_task = self.lacam_agents[idx].format_task_string(start_xy, target_xy, map_shape=map_array.shape) 191 | agent_tasks_dict[idx] = agent_task 192 | 193 | task_file_content = "version 1\n" 194 | for idx in range(len(self.lacam_agents)): 195 | task_file_content += agent_tasks_dict[idx] 196 | 197 | if has_new_tasks: 198 | map_row = lambda row: ''.join('@' if x else '.' 
for x in row) 199 | map_content = '\n'.join(map_row(row) for row in map_array) 200 | map_file_content = f"type octile\nheight {map_array.shape[0]}\nwidth {map_array.shape[1]}\nmap\n{map_content}" 201 | self.solved, lacam_results = self.lacam_lib.run_lacam(map_file_content, task_file_content, len(self.lacam_agents), self.cfg.timeouts) 202 | if self.solved: 203 | agent_paths = self._parse_data(lacam_results) 204 | else: 205 | agent_paths = [[agent_starts_xy[i] for _ in range(256)] for i in range(len(agent_starts_xy))] # if failed - agents just wait in start locations 206 | if agent_paths is not None: 207 | for idx, agent_path in enumerate(agent_paths): 208 | self.lacam_agents[idx].set_path(agent_path) 209 | 210 | return [agent.get_action() for agent in self.lacam_agents] 211 | 212 | def after_step(self, dones): 213 | pass 214 | 215 | def reset_states(self): 216 | self.lacam_agents = None 217 | 218 | def after_reset(self): 219 | pass 220 | 221 | def get_additional_info(self): 222 | addinfo = {"rl_used": 0.0} 223 | return addinfo --------------------------------------------------------------------------------
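
Usage sketch (not part of the repository listing above): a minimal, hypothetical example of calling the LaCAM wrapper from lacam/inference.py directly. It assumes liblacam.so is available (the module triggers cmake/make on import if it is missing); the toy grid, agent positions, and timeout values are invented for illustration, while the observation keys ('global_obstacles', 'global_xy', 'global_target_xy') and the LacamInference/LacamInferenceConfig API are taken from the code above.

from lacam.inference import LacamInference, LacamInferenceConfig

# Toy 4x4 grid: 0 = free cell, 1 = obstacle (values invented for this sketch).
toy_map = [
    [0, 0, 0, 0],
    [0, 1, 1, 0],
    [0, 0, 0, 0],
    [0, 0, 0, 0],
]

# Each observation dict mirrors the keys read inside LacamInference.act().
observations = [
    {"global_obstacles": toy_map, "global_xy": (0, 0), "global_target_xy": (3, 3)},
    {"global_obstacles": toy_map, "global_xy": (3, 0), "global_target_xy": (0, 3)},
]

solver = LacamInference(LacamInferenceConfig(timeouts=[1.0, 5.0]))
actions = solver.act(observations)  # one pogema action index per agent
print(actions)
solver.reset_states()  # clear cached agents/paths before a new episode

Action indices follow pogema's GridConfig().MOVES ordering, so index 0 corresponds to waiting in place; this is also the fallback LacamAgent.get_action() returns when no further path steps are available.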