├── py.typed ├── py_factor_graph ├── io │ ├── __init__.py │ ├── tests │ │ ├── __init__.py │ │ ├── test_io.py │ │ └── data │ │ │ ├── pyfg_text_se2_test_data.txt │ │ │ └── pyfg_text_se3_test_data.txt │ ├── pickle_file.py │ ├── standard_graph.py │ ├── efg_file.py │ ├── plaza_experiments.py │ └── g2o_file.py ├── calibrations │ ├── __init__.py │ ├── odom_measurement_calibration.py │ └── range_measurement_calibration.py ├── __init__.py ├── utils │ ├── __init__.py │ ├── logging_utils.py │ ├── name_utils.py │ ├── attrib_utils.py │ ├── solver_utils.py │ └── plot_utils.py ├── priors.py ├── variables.py └── measurements.py ├── pyproject.toml ├── setup.py ├── docs ├── index.html └── py_factor_graph │ ├── utils │ └── index.html │ ├── index.html │ └── parse_factor_graph.html ├── .pre-commit-hooks.yaml ├── .pre-commit-config.yaml ├── setup.cfg ├── .gitignore └── README.md /py.typed: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /py_factor_graph/io/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /py_factor_graph/calibrations/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /py_factor_graph/__init__.py: -------------------------------------------------------------------------------- 1 | """Initialize factor graph package.""" 2 | -------------------------------------------------------------------------------- /py_factor_graph/io/tests/__init__.py: -------------------------------------------------------------------------------- 1 | """Directory for io tests""" 2 | -------------------------------------------------------------------------------- /py_factor_graph/utils/__init__.py: 
"""Shared logging configuration for the py_factor_graph package.

Importing this module installs a colored console handler (via ``coloredlogs``)
and exposes a module-level ``logger`` that the rest of the package imports.
"""
import logging

import coloredlogs

# Package-wide logger; other modules do ``from ...logging_utils import logger``.
logger = logging.getLogger(__name__)

# Per-field color styling applied by coloredlogs to each record attribute.
field_styles = {
    "filename": {"color": "green"},
    "levelname": {"bold": True, "color": "black"},
    "name": {"color": "blue"},
}

# Install the colored handler once at import time so every module that uses
# ``logger`` gets consistent formatting: "[file:line] name LEVEL - message".
coloredlogs.install(
    level="INFO",
    fmt="[%(filename)s:%(lineno)d] %(name)s %(levelname)s - %(message)s",
    field_styles=field_styles,
)
from os.path import isfile
import pickle

from py_factor_graph.factor_graph import (
    FactorGraphData,
)
from py_factor_graph.utils.logging_utils import logger


def parse_pickle_file(filepath: str) -> FactorGraphData:
    """Retrieve a pickled FactorGraphData object.

    Requires that the file ends with ``.pickle`` or ``.pkl``
    (e.g. "my_file.pickle").

    Args:
        filepath: The path to the pickled factor graph file.

    Returns:
        FactorGraphData: The factor graph data.

    Raises:
        AssertionError: if ``filepath`` does not exist or does not have a
            recognized pickle extension.
    """
    assert isfile(filepath), f"{filepath} is not a file"
    # Accept both common pickle extensions; endswith takes a tuple of suffixes.
    assert filepath.endswith(
        (".pickle", ".pkl")
    ), f"{filepath} is not a pickle file"

    with open(filepath, "rb") as f:
        data = pickle.load(f)
        logger.debug(f"Loaded factor graph data from {filepath}")
        return data
# Large Files # 93 | ############### 94 | *.png 95 | *.mp4 96 | *.jpg 97 | *.xlsx 98 | 99 | # Blender # 100 | ########### 101 | *.blend 102 | 103 | # MATLAB # 104 | ########## 105 | *.m~ 106 | 107 | # ROS # 108 | ####### 109 | *.bag 110 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PyFactorGraph 2 | 3 | The primary purpose of this is so that people working in Python can avoid having 4 | to re-write different data providers for each project and data set they work on. 5 | The factor graph object and functions provided here allow for {reading, writing, 6 | accessing} of measurements and variables. 7 | 8 | We provide examples for working with the Extended Factor Graph format (`.fg`) 9 | and pickled `FactorGraphData` objects (`.pickle`) in the `/examples` directory. 10 | 11 | Also potentially useful, auto-documentation of this code can be found in the 12 | `/docs` directory. 13 | 14 | ## Getting Started 15 | 16 | Installing this package is quick and easy! 17 | 18 | ```bash 19 | pip install git+https://github.com/MarineRoboticsGroup/PyFactorGraph 20 | ``` 21 | 22 | Ta-da you should be ready to go! 23 | 24 | ## Contributing 25 | 26 | If you want to contribute a new feature to this package please read this brief section. 27 | 28 | ### Code Standards 29 | 30 | Any necessary coding standards are enforced through `pre-commit`. This will run 31 | a series of `hooks` when attempting to commit code to this repo. Additionally, 32 | we run a `pre-commit` hook to auto-generate the documentation of this library to 33 | make sure it is always up to date. 
34 | 35 | To set up `pre-commit` 36 | 37 | ```bash 38 | cd ~/PyFactorGraph 39 | pip3 install pre-commit 40 | pre-commit install 41 | ``` 42 | 43 | ### Testing 44 | 45 | If you want to develop this package and test it from an external package you can 46 | also install via 47 | 48 | ```bash 49 | cd ~/PyFactorGraph 50 | pip install -e . 51 | ``` 52 | 53 | The `-e` flag will make sure that any changes you make here are automatically 54 | translated to the external library you are working in. 55 | -------------------------------------------------------------------------------- /py_factor_graph/io/tests/test_io.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from py_factor_graph.io.pyfg_text import read_from_pyfg_text, save_to_pyfg_text 4 | 5 | # get current directory and directory containing data 6 | cur_dir = os.path.dirname(os.path.realpath(__file__)) 7 | data_dir = os.path.join(cur_dir, "data") 8 | 9 | # create temporary folder for saving factor graph to file 10 | tmp_dir = os.path.join(cur_dir, "tmp") 11 | if not os.path.exists(tmp_dir): 12 | os.makedirs(tmp_dir) 13 | 14 | # read two text files and check if each line in both files is identical 15 | def _check_file_equality(file1, file2): 16 | with open(file1, "r") as f1, open(file2, "r") as f2: 17 | for line1, line2 in zip(f1, f2): 18 | if line1.strip() != line2.strip(): 19 | return False 20 | return True 21 | 22 | 23 | def test_pyfg_se3_file() -> None: 24 | # read factor graph data 25 | data_file = os.path.join(data_dir, "pyfg_text_se3_test_data.txt") 26 | factor_graph = read_from_pyfg_text(data_file) 27 | 28 | # write factor graph data 29 | write_file = os.path.join(tmp_dir, "pyfg_text_se3_test_tmp.txt") 30 | save_to_pyfg_text(factor_graph, write_file) 31 | 32 | # assert read and write files are equal 33 | assert _check_file_equality(data_file, write_file) 34 | 35 | # remove temporary file 36 | os.remove(write_file) 37 | 38 | 39 | def test_pyfg_se2_file() 
"""Utilities for robot/frame naming (frames look like ``A0`` or ``B12``).

Robots are identified by a single uppercase character followed by a numeric
time index. The character ``L`` is reserved for landmarks and is therefore
excluded from the robot alphabet.
"""
import re

# All uppercase letters except "L", which is reserved for landmark names.
robot_char_list = [chr(ord("A") + i) for i in range(26)]
robot_char_list.remove("L")


def get_robot_char_from_number(robot_number: int) -> str:
    """
    Get the robot character from the given robot number.

    Args:
        robot_number: zero-based robot index.

    Returns:
        The single-character robot name (skipping the reserved "L").
    """
    assert 0 <= robot_number < len(
        robot_char_list
    ), f"Cannot have more than {len(robot_char_list)} robots"
    char = robot_char_list[robot_number]
    # Defensive check; "L" was removed from the alphabet above.
    assert char != "L", "Character L is reserved for landmarks"
    return char


def get_robot_char_from_frame_name(frame: str) -> str:
    """
    Get the robot character (first character) from the given frame name.
    """
    check_is_valid_frame_name(frame)
    return frame[0]


def get_robot_idx_from_char(robot_char: str) -> int:
    """
    Get the robot index from the given robot character.
    """
    assert robot_char in robot_char_list, f"Invalid robot character: {robot_char}"
    return robot_char_list.index(robot_char)


def get_robot_idx_from_frame_name(frame: str) -> int:
    """
    Get the robot index from the given frame name.
    """
    check_is_valid_frame_name(frame)
    robot_char = get_robot_char_from_frame_name(frame)
    return get_robot_idx_from_char(robot_char)


def get_time_idx_from_frame_name(frame: str) -> int:
    """
    Get the time index (the numeric suffix) from the given frame name.
    """
    check_is_valid_frame_name(frame)
    # NOTE: previously the pattern was r"[\d+]+", a character class that also
    # matched literal "+" characters; r"\d+" matches only the digit run.
    match = re.search(r"\d+", frame)
    assert match is not None  # guaranteed by check_is_valid_frame_name
    return int(match.group(0))


def check_is_valid_frame_name(frame: str) -> None:
    """
    Runs assertions if the given frame name is valid.

    A valid frame name is exactly one uppercase robot character followed by a
    numeric pose index, e.g. "A0" or "B12".
    """
    assert isinstance(
        frame, str
    ), f"Frame name must be a string, not {type(frame)}: {frame}"
    # fullmatch enforces the entire name matches <single uppercase><digits>,
    # which also guarantees exactly one uppercase character.
    assert re.fullmatch(r"[A-Z]\d+", frame) is not None, (
        "Frame name must identify robot and pose number. "
        "E.g. A0 or B12 are both acceptable. "
        f"Received {frame}"
    )
2.500000 -0.500000 0.010000 0.000000 0.010000 15 | EDGE_SE2 0.123456789 A0 A1 1.033099 0.093536 0.4499639 0.010000 0.000000 0.000000 0.010000 0.000000 0.040000 16 | EDGE_SE2 1.123456789 A1 A2 0.589385 -0.557830 -3.1097626 0.010000 0.000000 0.000000 0.010000 0.000000 0.040000 17 | EDGE_SE2 2.123456789 A2 A3 -0.344795 -0.326952 -0.6128210 0.010000 0.000000 0.000000 0.010000 0.000000 0.040000 18 | EDGE_SE2 3.123456789 A3 A4 -0.862188 0.468834 1.7684509 0.010000 0.000000 0.000000 0.010000 0.000000 0.040000 19 | EDGE_SE2 4.123456789 A4 A5 -0.412970 0.367591 -2.1054829 0.010000 0.000000 0.000000 0.010000 0.000000 0.040000 20 | EDGE_SE2 6.123456789 A5 A6 1.021990 0.090432 -1.0532952 0.010000 0.000000 0.000000 0.010000 0.000000 0.040000 21 | EDGE_SE2 7.123456789 A6 A7 0.459972 0.672496 -0.8577842 0.010000 0.000000 0.000000 0.010000 0.000000 0.040000 22 | EDGE_SE2 8.123456789 A7 A8 0.620421 0.552512 -0.6361325 0.010000 0.000000 0.000000 0.010000 0.000000 0.040000 23 | EDGE_SE2 9.123456789 A1 A8 -0.062404 0.790626 1.2602358 0.010000 0.000000 0.000000 0.010000 0.000000 0.040000 24 | EDGE_SE2 10.123456789 A3 A6 -0.586019 0.114638 -1.6478037 0.010000 0.000000 0.000000 0.010000 0.000000 0.040000 25 | EDGE_SE2 11.123456789 A7 A2 -0.693071 0.663893 2.1813801 0.010000 0.000000 0.000000 0.010000 0.000000 0.040000 26 | EDGE_SE2_XY 2.123456789 A2 L0 0.635897 0.568827 0.010000 0.000000 0.010000 27 | EDGE_SE2_XY 6.123456789 A6 L0 -0.713011 -0.399245 0.010000 0.000000 0.010000 28 | EDGE_SE2_XY 2.123456789 A2 L1 0.635897 0.568827 0.010000 0.000000 0.010000 29 | EDGE_SE2_XY 6.123456789 A6 L1 -0.713011 -0.399245 0.010000 0.000000 0.010000 30 | EDGE_RANGE 2.123456789 A2 L0 0.853187 0.010000 31 | EDGE_RANGE 6.123456789 A6 L0 0.817178 0.010000 32 | EDGE_RANGE 2.123456789 A2 L1 0.853187 0.010000 33 | EDGE_RANGE 6.123456789 A6 L1 0.817178 0.010000 34 | EDGE_RANGE 0.123456789 L0 L1 1.000000 0.010000 35 | -------------------------------------------------------------------------------- 
/py_factor_graph/io/tests/data/pyfg_text_se3_test_data.txt: -------------------------------------------------------------------------------- 1 | VERTEX_SE3:QUAT 0.123456789 A0 0.000000 0.000000 0.000000 0.0000000 0.0000000 0.0000000 1.0000000 2 | VERTEX_SE3:QUAT 0.123456789 A1 1.033099 0.093536 -0.037961 0.3171845 -0.2366641 0.1427899 0.9071908 3 | VERTEX_SE3:QUAT 0.123456789 A2 1.864103 -0.068827 -0.224420 0.3990360 -0.1862907 -0.8967650 0.0433426 4 | VERTEX_SE3:QUAT 0.123456789 A3 2.778843 0.043020 -0.654026 -0.0946935 0.8516455 -0.5040938 0.1078076 5 | VERTEX_SE3:QUAT 0.123456789 A4 3.740591 0.018251 -1.258278 -0.2025126 0.0306155 -0.5368945 0.8184104 6 | VERTEX_SE3:QUAT 0.123456789 A5 4.033220 0.677269 -0.953695 0.2648076 0.3972635 0.8786534 0.0051805 7 | VERTEX_SE3:QUAT 0.123456789 A6 3.213011 0.899245 -0.372356 0.6065701 -0.6659431 0.4047869 0.1572895 8 | VERTEX_SE3:QUAT 0.123456789 A7 2.367769 0.848256 -0.024307 0.5611840 -0.2302828 0.7226670 0.3313529 9 | VERTEX_SE3:QUAT 0.123456789 A8 1.754363 0.732940 0.550029 0.7067708 -0.4274800 0.3028011 0.4754444 10 | VERTEX_XYZ L0 2.500000 0.500000 1.000000 11 | VERTEX_XYZ L1 2.500000 -0.500000 -1.000000 12 | VERTEX_SE3:QUAT:PRIOR 0.123456789 A0 0.000000 0.000000 0.000000 0.0000000 0.0000000 0.0000000 1.0000000 0.010000 0.000000 0.000000 0.000000 0.000000 0.000000 0.010000 0.000000 0.000000 0.000000 0.000000 0.010000 0.000000 0.000000 0.000000 0.040000 0.000000 0.000000 0.040000 0.000000 0.040000 13 | VERTEX_XYZ:PRIOR 0.123456789 L0 2.500000 0.500000 1.000000 0.010000 0.000000 0.000000 0.010000 0.000000 0.010000 14 | VERTEX_XYZ:PRIOR 0.123456789 L1 2.500000 -0.500000 -1.000000 0.010000 0.000000 0.000000 0.010000 0.000000 0.010000 15 | EDGE_SE3:QUAT 0.123456789 A0 A1 1.033099 0.093536 -0.037961 0.3171845 -0.2366641 0.1427899 0.9071908 0.010000 0.000000 0.000000 0.000000 0.000000 0.000000 0.010000 0.000000 0.000000 0.000000 0.000000 0.010000 0.000000 0.000000 0.000000 0.040000 0.000000 0.000000 0.040000 0.000000 
0.040000 16 | EDGE_SE3:QUAT 1.123456789 A1 A2 0.589385 -0.557830 -0.305201 0.1094217 -0.5001618 -0.8550748 0.0819273 0.010000 0.000000 0.000000 0.000000 0.000000 0.000000 0.010000 0.000000 0.000000 0.000000 0.000000 0.010000 0.000000 0.000000 0.000000 0.040000 0.000000 0.000000 0.040000 0.000000 0.040000 17 | EDGE_SE3:QUAT 2.123456789 A2 A3 -0.344795 -0.326952 -0.898909 -0.9047572 -0.2290733 -0.2473674 0.2602866 0.010000 0.000000 0.000000 0.000000 0.000000 0.000000 0.010000 0.000000 0.000000 0.000000 0.000000 0.010000 0.000000 0.000000 0.000000 0.040000 0.000000 0.000000 0.040000 0.000000 0.040000 18 | EDGE_SE3:QUAT 3.123456789 A3 A4 -0.862188 0.468834 0.572294 0.4974765 -0.7449399 0.1851044 0.4041262 0.010000 0.000000 0.000000 0.000000 0.000000 0.000000 0.010000 0.000000 0.000000 0.000000 0.000000 0.010000 0.000000 0.000000 0.000000 0.040000 0.000000 0.000000 0.040000 0.000000 0.040000 19 | EDGE_SE3:QUAT 4.123456789 A4 A5 -0.412970 0.367591 0.554111 0.0224185 -0.2892013 -0.8104386 0.5089689 0.010000 0.000000 0.000000 0.000000 0.000000 0.000000 0.010000 0.000000 0.000000 0.000000 0.000000 0.010000 0.000000 0.000000 0.000000 0.040000 0.000000 0.000000 0.040000 0.000000 0.040000 20 | EDGE_SE3:QUAT 6.123456789 A5 A6 1.021990 0.090432 0.085613 -0.7844494 -0.4917095 0.2812090 0.2525518 0.010000 0.000000 0.000000 0.000000 0.000000 0.000000 0.010000 0.000000 0.000000 0.000000 0.000000 0.010000 0.000000 0.000000 0.000000 0.040000 0.000000 0.000000 0.040000 0.000000 0.040000 21 | EDGE_SE3:QUAT 7.123456789 A6 A7 0.459972 0.672496 -0.417547 0.2753192 0.3956294 -0.2544933 0.8383972 0.010000 0.000000 0.000000 0.000000 0.000000 0.000000 0.010000 0.000000 0.000000 0.000000 0.000000 0.010000 0.000000 0.000000 0.000000 0.040000 0.000000 0.000000 0.040000 0.000000 0.040000 22 | EDGE_SE3:QUAT 8.123456789 A7 A8 0.620421 0.552512 -0.170980 -0.2718170 -0.3729929 -0.1661162 0.8714340 0.010000 0.000000 0.000000 0.000000 0.000000 0.000000 0.010000 0.000000 0.000000 0.000000 0.000000 
"""This will write the factor graph to a series of fairly basic graph files.

These files will largely throw away most of the information in the factor graph,
but will retain the structure of the graph. This is useful for:
- visualizing the graph using standard graph visualization tools
- projects that only need the structure of the graph
"""

import numpy as np
import os
import math
from py_factor_graph.factor_graph import FactorGraphData
from py_factor_graph.utils.name_utils import (
    get_robot_idx_from_frame_name,
    get_time_idx_from_frame_name,
)
from py_factor_graph.utils.matrix_utils import (
    get_rotation_matrix_from_quat,
    get_symmetric_matrix_from_list_column_major,
    get_list_column_major_from_symmetric_matrix,
    get_measurement_precisions_from_covariance_matrix,
)
from py_factor_graph.variables import (
    PoseVariable2D,
    PoseVariable3D,
    LandmarkVariable2D,
    LandmarkVariable3D,
    POSE_VARIABLE_TYPES,
    LANDMARK_VARIABLE_TYPES,
)
from py_factor_graph.priors import (
    PosePrior2D,
    PosePrior3D,
    LandmarkPrior2D,
    LandmarkPrior3D,
    POSE_PRIOR_TYPES,
    LANDMARK_PRIOR_TYPES,
)
from py_factor_graph.measurements import (
    PoseMeasurement2D,
    PoseMeasurement3D,
    PoseToLandmarkMeasurement2D,
    PoseToLandmarkMeasurement3D,
    FGRangeMeasurement,
    POSE_MEASUREMENT_TYPES,
    POSE_LANDMARK_MEASUREMENT_TYPES,
)
from py_factor_graph.utils.logging_utils import logger

from itertools import chain


def write_edge_list(pyfg: FactorGraphData, output_fpath: str) -> None:
    """Writes the factor graph to a simple edge list file.

    Each output line is "<i> <j> <weight>" where i and j are integer indices
    assigned to the factor graph's variable names (in the order given by
    ``all_variable_names``) and weight is the precision associated with the
    measurement connecting them.

    Args:
        pyfg (FactorGraphData): The factor graph to write.
        output_fpath (str): The path to the output file.
    """
    logger.info(f"Writing edge list to {output_fpath}")

    variable_names = pyfg.all_variable_names
    name_to_idx = {name: idx for idx, name in enumerate(variable_names)}

    # Use a context manager so the file is closed even if a lookup below
    # raises (the previous implementation leaked the handle on error).
    with open(output_fpath, "w") as writer:
        # odometry measurements are stored per-robot; flatten to one stream
        odom_measures = chain.from_iterable(pyfg.odom_measurements)
        for odo_measure in odom_measures:
            i, j = name_to_idx[odo_measure.base_pose], name_to_idx[odo_measure.to_pose]
            weight = odo_measure.rotation_precision
            writer.write(f"{i} {j} {weight}\n")

        # loop closures
        for loop_measure in pyfg.loop_closure_measurements:
            i, j = (
                name_to_idx[loop_measure.base_pose],
                name_to_idx[loop_measure.to_pose],
            )
            weight = loop_measure.rotation_precision
            writer.write(f"{i} {j} {weight}\n")

        # pose-to-landmark measurements
        for pl_measure in pyfg.pose_landmark_measurements:
            i, j = (
                name_to_idx[pl_measure.pose_name],
                name_to_idx[pl_measure.landmark_name],
            )
            weight = pl_measure.translation_precision
            writer.write(f"{i} {j} {weight}\n")

        # range measurements
        for r_measure in pyfg.range_measurements:
            i, j = name_to_idx[r_measure.first_key], name_to_idx[r_measure.second_key]
            weight = r_measure.precision
            writer.write(f"{i} {j} {weight}\n")


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        description="Writes a factor graph to a simple edge list file."
    )
    parser.add_argument(
        "input_graph",
        type=str,
        help="Path to the input file containing the factor graph. Currently supports .g2o and .pyfg formats.",
    )
    parser.add_argument(
        "output_edge_list",
        type=str,
        help="Path to the output edge list file (e.g., .edge).",
    )

    args = parser.parse_args()

    from py_factor_graph.io.pyfg_text import read_from_pyfg_text
    from py_factor_graph.io.g2o_file import parse_2d_g2o_file, parse_3d_g2o_file

    # Load the factor graph
    if not os.path.exists(args.input_graph):
        raise FileNotFoundError(f"Input graph file {args.input_graph} does not exist.")

    fg = FactorGraphData(dimension=2)  # Default to 2D, will be overwritten if needed
    if args.input_graph.endswith(".g2o"):
        # g2o files do not declare their dimension; try 2D first, then 3D.
        try:
            fg = parse_2d_g2o_file(args.input_graph)
            logger.info("Parsed input graph as 2D g2o file.")
        except Exception as e_2d:
            try:
                fg = parse_3d_g2o_file(args.input_graph)
                logger.info("Parsed input graph as 3D g2o file.")
            except Exception as e_3d:
                raise ValueError(
                    f"Failed to parse g2o file as 2D or 3D: {e_2d}, {e_3d}"
                )
    elif args.input_graph.endswith(".pyfg"):
        fg = read_from_pyfg_text(args.input_graph)
    else:
        raise ValueError("Input graph file must be either .g2o or .pyfg format.")

    # Write the edge list
    write_edge_list(fg, args.output_edge_list)

    print(f"Wrote edge list to {args.output_edge_list}")
"""Prior factors on pose and landmark variables (2D and 3D)."""
from typing import Optional, Tuple, Union

import numpy as np
from attrs import define, field
import scipy.spatial as spatial

from py_factor_graph.utils.attrib_utils import (
    make_variable_name_validator,
    float_tuple_validator,
    make_rot_matrix_validator,
    positive_float_validator,
)
from py_factor_graph.utils.matrix_utils import (
    get_covariance_matrix_from_measurement_precisions,
    get_rotation_matrix_from_theta,
    get_quat_from_rotation_matrix,
)


@define(frozen=False)
class PosePrior2D:
    """A prior on the robot pose.

    Arguments:
        name (str): the name of the pose variable
        position (Tuple[float, float]): the prior of the position
        theta (float): the prior of the heading angle
        translation_precision (float): precision of the translation prior
        rotation_precision (float): precision of the rotation prior
        timestamp (float): seconds since epoch
    """

    # NOTE: attrs `field` is used consistently here (the file previously mixed
    # the legacy `attr.ib` API with the modern `attrs.field` API).
    name: str = field()
    position: Tuple[float, float] = field()
    theta: float = field()
    translation_precision: float = field(validator=positive_float_validator)
    rotation_precision: float = field(validator=positive_float_validator)
    timestamp: Optional[float] = field(default=None)

    @property
    def x(self) -> float:
        return self.position[0]

    @property
    def y(self) -> float:
        return self.position[1]

    @property
    def translation_vector(self) -> np.ndarray:
        return np.array([self.x, self.y])

    @property
    def rotation_matrix(self) -> np.ndarray:
        return get_rotation_matrix_from_theta(self.theta)

    @property
    def covariance(self) -> np.ndarray:
        # 3x3 covariance over (x, y, theta), built from the two precisions.
        return get_covariance_matrix_from_measurement_precisions(
            self.translation_precision, self.rotation_precision, mat_dim=3
        )


@define(frozen=False)
class PosePrior3D:
    """A prior on a 3D robot pose.

    Arguments:
        name (str): the name of the pose variable
        position (Tuple[float, float, float]): the prior of the position
        rotation (np.ndarray): the prior of the rotation (3x3 matrix)
        translation_precision (float): precision of the translation prior
        rotation_precision (float): precision of the rotation prior
        timestamp (float): seconds since epoch
    """

    name: str = field()
    position: Tuple[float, float, float] = field()
    rotation: np.ndarray = field(validator=make_rot_matrix_validator(3))
    translation_precision: float = field(validator=positive_float_validator)
    rotation_precision: float = field(validator=positive_float_validator)
    timestamp: Optional[float] = field(default=None)

    @property
    def x(self) -> float:
        return self.position[0]

    @property
    def y(self) -> float:
        return self.position[1]

    @property
    def z(self) -> float:
        return self.position[2]

    @property
    def translation_vector(self) -> np.ndarray:
        return np.array(self.position)

    @property
    def rotation_matrix(self) -> np.ndarray:
        return self.rotation

    @property
    def yaw(self) -> float:
        # First euler angle of the "zyx" convention is the rotation about z.
        return spatial.transform.Rotation.from_matrix(self.rotation).as_euler("zyx")[0]

    @property
    def covariance(self) -> np.ndarray:
        # 6x6 covariance over (translation, rotation).
        return get_covariance_matrix_from_measurement_precisions(
            self.translation_precision, self.rotation_precision, mat_dim=6
        )

    @property
    def quat(self) -> np.ndarray:
        return get_quat_from_rotation_matrix(self.rotation)


@define(frozen=False)
class LandmarkPrior2D:
    """A prior on a 2D landmark.

    Arguments:
        name (str): the name of the landmark variable
        position (Tuple[float, float]): the prior of the position
        translation_precision (float): precision of the position prior
        timestamp (float): seconds since epoch
    """

    name: str = field(validator=make_variable_name_validator("landmark"))
    position: Tuple[float, float] = field(validator=float_tuple_validator)
    translation_precision: float = field(validator=positive_float_validator)
    timestamp: Optional[float] = field(default=None)

    @property
    def x(self) -> float:
        return self.position[0]

    @property
    def y(self) -> float:
        return self.position[1]

    @property
    def translation_vector(self) -> np.ndarray:
        return np.array(self.position)

    @property
    def covariance_matrix(self) -> np.ndarray:
        # Isotropic 2x2 covariance: variance = 1 / precision on each axis.
        return np.diag([1 / self.translation_precision] * 2)

    @property
    def covariance(self) -> np.ndarray:
        return self.covariance_matrix


@define(frozen=False)
class LandmarkPrior3D:
    """A prior on a 3D landmark.

    Arguments:
        name (str): the name of the landmark variable
        position (Tuple[float, float, float]): the prior of the position
        translation_precision (float): precision of the position prior
        timestamp (float): seconds since epoch
    """

    name: str = field(validator=make_variable_name_validator("landmark"))
    position: Tuple[float, float, float] = field(validator=float_tuple_validator)
    translation_precision: float = field(validator=positive_float_validator)
    timestamp: Optional[float] = field(default=None)

    @property
    def x(self) -> float:
        return self.position[0]

    @property
    def y(self) -> float:
        return self.position[1]

    @property
    def z(self) -> float:
        return self.position[2]

    @property
    def translation_vector(self) -> np.ndarray:
        return np.array(self.position)

    @property
    def covariance_matrix(self) -> np.ndarray:
        # Isotropic 3x3 covariance: variance = 1 / precision on each axis.
        return np.diag([1 / self.translation_precision] * 3)

    @property
    def covariance(self) -> np.ndarray:
        return self.covariance_matrix


POSE_PRIOR_TYPES = Union[PosePrior2D, PosePrior3D]
LANDMARK_PRIOR_TYPES = Union[LandmarkPrior2D, LandmarkPrior3D]
20 |
21 |
22 |

Module py_factor_graph.utils

23 |
24 |
25 |

Directory for utils files

26 |
27 | 28 | Expand source code 29 | 30 |
"""Directory for utils files """
31 |
32 |
33 |
34 |

Sub-modules

35 |
36 |
py_factor_graph.utils.data_utils
37 |
38 |
39 |
40 |
py_factor_graph.utils.name_utils
41 |
42 |
43 |
44 |
45 |
46 |
47 |
48 |
49 |
50 |
51 |
52 |
53 | 72 |
from os.path import isfile

from py_factor_graph.variables import PoseVariable2D, LandmarkVariable2D
from py_factor_graph.measurements import (
    PoseMeasurement2D,
    FGRangeMeasurement,
)
from py_factor_graph.priors import (
    PosePrior2D,
)
from py_factor_graph.factor_graph import (
    FactorGraphData,
)
from py_factor_graph.utils.name_utils import (
    get_robot_idx_from_frame_name,
    get_time_idx_from_frame_name,
)
from py_factor_graph.utils.data_utils import get_covariance_matrix_from_list
from py_factor_graph.utils.matrix_utils import (
    get_measurement_precisions_from_covariance_matrix,
)


def parse_efg_file(filepath: str) -> FactorGraphData:
    """
    Parse a factor graph file to extract the factors and variables. Requires
    that the file ends with .fg (e.g. "my_file.fg").

    Args:
        filepath: The path to the factor graph file.

    Returns:
        FactorGraphData: The factor graph data.

    Raises:
        ValueError: If the file does not end with .fg.
        ValueError: If the file does not exist.
    """
    if not isfile(filepath):
        raise ValueError(f"File {filepath} does not exist.")
    if not filepath.endswith(".fg"):
        raise ValueError(f"File {filepath} does not end with .fg.")

    # Line prefixes identifying each record type in the .fg format.
    pose_var_header = "Variable Pose SE2"
    landmark_var_header = "Variable Landmark R2"
    pose_measure_header = "Factor SE2RelativeGaussianLikelihoodFactor"
    amb_measure_header = "Factor AmbiguousDataAssociationFactor"
    range_measure_header = "Factor SE2R2RangeGaussianLikelihoodFactor"
    pose_prior_header = "Factor UnarySE2ApproximateGaussianPriorFactor"
    landmark_prior_header = "Landmark"  # don't have any of these yet

    new_fg_data = FactorGraphData(dimension=2)

    with open(filepath, "r") as f:
        for line in f:
            if line.startswith(pose_var_header):
                # "Variable Pose SE2 <name> <x> <y> <theta>"
                line_items = line.split()
                pose_name = line_items[3]
                x = float(line_items[4])
                y = float(line_items[5])
                theta = float(line_items[6])
                pose_var = PoseVariable2D(pose_name, (x, y), theta)
                new_fg_data.add_pose_variable(pose_var)
            elif line.startswith(landmark_var_header):
                # "Variable Landmark R2 <name> <x> <y>"
                line_items = line.split()
                landmark_name = line_items[3]
                x = float(line_items[4])
                y = float(line_items[5])
                landmark_var = LandmarkVariable2D(landmark_name, (x, y))
                new_fg_data.add_landmark_variable(landmark_var)
            elif line.startswith(pose_measure_header):
                # Relative SE2 measurement; trailing items hold the covariance.
                line_items = line.split()
                base_pose = line_items[2]
                local_pose = line_items[3]
                delta_x = float(line_items[4])
                delta_y = float(line_items[5])
                delta_theta = float(line_items[6])
                covar_list = [float(x) for x in line_items[8:]]
                covar = get_covariance_matrix_from_list(covar_list)
                (
                    trans_precision,
                    rot_precision,
                ) = get_measurement_precisions_from_covariance_matrix(
                    covar, matrix_dim=3
                )
                measure = PoseMeasurement2D(
                    base_pose,
                    local_pose,
                    delta_x,
                    delta_y,
                    delta_theta,
                    trans_precision,
                    rot_precision,
                )

                base_pose_idx = get_robot_idx_from_frame_name(base_pose)
                local_pose_idx = get_robot_idx_from_frame_name(local_pose)
                base_time_idx = get_time_idx_from_frame_name(base_pose)
                local_time_idx = get_time_idx_from_frame_name(local_pose)

                # if either the robot indices are different or the time indices
                # are not sequential then it is a loop closure
                if (
                    base_pose_idx != local_pose_idx
                    or local_time_idx != base_time_idx + 1
                ):
                    new_fg_data.add_loop_closure(measure)

                # otherwise it is an odometry measurement
                else:
                    new_fg_data.add_odom_measurement(base_pose_idx, measure)

            elif line.startswith(range_measure_header):
                # Range between two variables, parameterized by stddev.
                line_items = line.split()
                var1 = line_items[2]
                var2 = line_items[3]
                dist = float(line_items[4])
                stddev = float(line_items[5])
                range_measure = FGRangeMeasurement((var1, var2), dist, stddev)
                new_fg_data.add_range_measurement(range_measure)

            elif line.startswith(pose_prior_header):
                # Unary prior on a pose; trailing items hold the covariance.
                line_items = line.split()
                pose_name = line_items[2]
                x = float(line_items[3])
                y = float(line_items[4])
                theta = float(line_items[5])
                covar_list = [float(x) for x in line_items[7:]]
                covar = get_covariance_matrix_from_list(covar_list)
                (
                    translation_precision,
                    rotation_precision,
                ) = get_measurement_precisions_from_covariance_matrix(
                    covar, matrix_dim=3
                )
                pose_prior = PosePrior2D(
                    pose_name, (x, y), theta, translation_precision, rotation_precision
                )
                new_fg_data.add_pose_prior(pose_prior)

            elif line.startswith(landmark_prior_header):
                raise NotImplementedError("Landmark priors not implemented yet")
            elif line.startswith(amb_measure_header):
                # No ambiguous data-association factors are parsed yet; we only
                # classify the line so the error message is specific.

                # if it is a range measurement then add to ambiguous range
                # measurements list
                if "SE2R2RangeGaussianLikelihoodFactor" in line:
                    raise NotImplementedError(
                        "Need to parse for ambiguous range measurements measurement"
                    )

                # if it is a pose measurement then add to ambiguous pose
                # measurements list
                elif "SE2RelativeGaussianLikelihoodFactor" in line:
                    raise NotImplementedError(
                        "Need to parse for ambiguous pose measurement"
                    )

                # this is a case that we haven't planned for yet
                else:
                    raise NotImplementedError(
                        f"Unknown measurement type in ambiguous measurement: {line}"
                    )

    return new_fg_data
"SE2R2RangeGaussianLikelihoodFactor" in line: 149 | raise NotImplementedError( 150 | "Need to parse for ambiguous range measurements measurement" 151 | ) 152 | 153 | # if it is a pose measurement then add to ambiguous pose 154 | # measurements list 155 | elif "SE2RelativeGaussianLikelihoodFactor" in line: 156 | raise NotImplementedError( 157 | "Need to parse for ambiguous pose measurement" 158 | ) 159 | 160 | # this is a case that we haven't planned for yet 161 | else: 162 | raise NotImplementedError( 163 | f"Unknown measurement type in ambiguous measurement: {line}" 164 | ) 165 | 166 | return new_fg_data 167 | -------------------------------------------------------------------------------- /py_factor_graph/utils/attrib_utils.py: -------------------------------------------------------------------------------- 1 | from typing import Callable 2 | import numpy as np 3 | 4 | 5 | def is_dimension(instance, attribute, value) -> None: 6 | """ 7 | Return validator for dimension. 8 | 9 | Args: 10 | value (int): value to validate 11 | 12 | Returns: 13 | None 14 | """ 15 | if not isinstance(value, int): 16 | raise ValueError(f"{value} is not an int") 17 | if not value in [2, 3]: 18 | raise ValueError(f"Value {value} is not 2 or 3") 19 | 20 | 21 | def range_validator(instance, attribute, value): 22 | if value < 0: 23 | raise ValueError(f"Value {value} should not be negative") 24 | 25 | 26 | def probability_validator(instance, attribute, value): 27 | """ 28 | Return validator for probability. 29 | 30 | Args: 31 | value (float): value to validate 32 | 33 | Returns: 34 | None 35 | """ 36 | if not isinstance(value, float): 37 | raise ValueError(f"{value} is not a float") 38 | if not 0 <= value <= 1: 39 | raise ValueError(f"Value {value} is not within range [0,1]") 40 | 41 | 42 | def positive_float_validator(instance, attribute, value): 43 | """ 44 | Return validator for positive float. 
from typing import Callable
import numpy as np


def is_dimension(instance, attribute, value) -> None:
    """
    attrs validator: value must be the int 2 or 3.

    Args:
        instance: the object being validated (unused)
        attribute: the attrs attribute being validated (unused)
        value (int): value to validate

    Raises:
        ValueError: if value is not an int, or is not 2 or 3
    """
    if not isinstance(value, int):
        raise ValueError(f"{value} is not an int")
    if value not in [2, 3]:
        raise ValueError(f"Value {value} is not 2 or 3")


def range_validator(instance, attribute, value) -> None:
    """attrs validator: value must be non-negative."""
    if value < 0:
        raise ValueError(f"Value {value} should not be negative")


def probability_validator(instance, attribute, value) -> None:
    """
    attrs validator: value must be a float in [0, 1].

    Args:
        instance: the object being validated (unused)
        attribute: the attrs attribute being validated (unused)
        value (float): value to validate

    Raises:
        ValueError: if value is not a float or is outside [0, 1]
    """
    if not isinstance(value, float):
        raise ValueError(f"{value} is not a float")
    if not 0 <= value <= 1:
        raise ValueError(f"Value {value} is not within range [0,1]")


def positive_float_validator(instance, attribute, value) -> None:
    """
    attrs validator: value must be a non-negative, non-NaN float (ints are
    also accepted).

    Args:
        instance: the object being validated (unused)
        attribute: the attrs attribute being validated (unused)
        value (float): value to validate

    Raises:
        ValueError: if value is not a float/int, is negative, or is NaN
    """
    if not isinstance(value, float) and not isinstance(value, int):
        raise ValueError(f"{value} is not a float (ints are also accepted)")
    if value < 0:
        raise ValueError(f"Value {value} is not positive")
    if np.isnan(value):
        raise ValueError(f"Value {value} is nan")


def positive_int_validator(instance, attribute, value) -> None:
    """
    attrs validator: value must be a non-negative int.

    Args:
        instance: the object being validated (unused)
        attribute: the attrs attribute being validated (unused)
        value (int): value to validate

    Raises:
        ValueError: if value is not an int or is negative
    """
    if not isinstance(value, int):
        raise ValueError(f"{value} is not an int")
    if value < 0:
        raise ValueError(f"Value {value} is not positive")


def positive_int_tuple_validator(instance, attribute, value) -> None:
    """
    attrs validator: value must be a tuple of non-negative ints.

    Args:
        instance: the object being validated (unused)
        attribute: the attrs attribute being validated (unused)
        value (tuple): value to validate

    Raises:
        ValueError: if value is not a tuple, or any element is not a
            non-negative int
    """
    if not isinstance(value, tuple):
        raise ValueError(f"{value} is not a tuple")
    if not all(isinstance(x, int) for x in value):
        raise ValueError(f"At least one value in {value} is not an int")
    if not all(x >= 0 for x in value):
        raise ValueError(f"At least one value in {value} is negative")


def float_tuple_validator(instance, attribute, value) -> None:
    """
    attrs validator: value must be a tuple of floats.

    Args:
        instance: the object being validated (unused)
        attribute: the attrs attribute being validated (unused)
        value (tuple): value to validate

    Raises:
        ValueError: if value is not a tuple or any element is not a float
    """
    if not isinstance(value, tuple):
        raise ValueError(f"{value} is not a tuple")

    if not all(isinstance(x, float) for x in value):
        raise ValueError(f"At least one value in {value} is not a float")


def make_rot_matrix_validator(dimension: int) -> Callable:
    """
    Build an attrs validator checking that a value is a valid
    dimension x dimension rotation matrix (orthogonal, determinant 1).

    Args:
        dimension (int): required size of the (square) rotation matrix

    Returns:
        Callable: an attrs-style (instance, attribute, value) validator
    """

    def rot_matrix_validator(instance, attribute, value) -> None:
        if not isinstance(value, np.ndarray):
            raise ValueError(f"{value} is not a np.ndarray")

        if not value.shape[0] == value.shape[1]:
            raise ValueError(f"Rotation matrix is not square {value.shape}")

        if not value.shape[0] == dimension:
            raise ValueError(f"Rotation matrix is not {dimension}x{dimension}")

        # R @ R.T == I  <=>  R is orthogonal
        if not np.allclose(value @ value.T, np.eye(dimension)):
            raise ValueError(f"{value} is not orthogonal")

        # det(R) == 1 rules out reflections (det == -1)
        if not np.allclose(np.linalg.det(value), 1):
            raise ValueError(f"{value} has determinant {np.linalg.det(value)}")

    return rot_matrix_validator


def rot_matrix_validator(instance, attribute, value) -> None:
    """
    attrs validator: value must be a square rotation matrix of any size
    (orthogonal, determinant 1).

    Args:
        instance: the object being validated (unused)
        attribute: the attrs attribute being validated (unused)
        value (np.ndarray): value to validate

    Raises:
        ValueError: if value is not a square, orthogonal, determinant-1 matrix
    """
    if not isinstance(value, np.ndarray):
        raise ValueError(f"{value} is not a np.ndarray")

    if not value.shape[0] == value.shape[1]:
        raise ValueError(f"Rotation matrix is not square {value.shape}")

    dim = value.shape[0]
    if not np.allclose(value @ value.T, np.eye(dim)):
        raise ValueError(f"{value} is not orthogonal")

    if not np.allclose(np.linalg.det(value), 1):
        raise ValueError(f"{value} has determinant {np.linalg.det(value)}")


def optional_float_validator(instance, attribute, value) -> None:
    """
    attrs validator: value must be either None or a float.

    Args:
        instance: the object being validated (unused)
        attribute: the attrs attribute being validated (unused)
        value (Optional[float]): value to validate

    Raises:
        ValueError: if value is neither None nor a float
    """
    if value is not None:
        if not isinstance(value, float):
            raise ValueError(f"{value} is not a float")


def make_variable_name_validator(type: str) -> Callable:
    """
    Build an attrs validator for either pose or landmark names. Names are a
    capital letter followed by digits:

    Poses should NOT start with 'L'     : Poses = "A1", "B27", "C19"
    Landmarks should start with 'L'     : Landmarks = "L1", "L2", "L3"

    Args:
        type (str): which naming rule to enforce, "pose" or "landmark"

    Returns:
        Callable: an attrs-style (instance, attribute, value) validator
    """
    valid_types = ["pose", "landmark"]
    assert (
        type in valid_types
    ), f"Type {type} is not valid, should be one of {valid_types}"

    def variable_name_validator(instance, attribute, value) -> None:
        if not isinstance(value, str):
            raise ValueError(f"{value} is not a string")

        # allow exception for declaring an "origin" pose
        if type == "pose" and value == "origin":
            return

        first_char = value[0]

        if type == "pose" and first_char == "L":
            raise ValueError(f"{value} starts with L - reserved for landmarks")
        elif type == "landmark" and first_char != "L":
            raise ValueError(
                f"{value} does not start with L - landmarks should start with L"
            )

        if not first_char.isalpha() or first_char.islower():
            raise ValueError(f"{value} does not start with a capital letter")

        if not value[1:].isdigit():
            raise ValueError(f"{value} does not end with a number")

    return variable_name_validator


def general_variable_name_validator(instance, attribute, value) -> None:
    """
    attrs validator for any variable name (pose or landmark): a capital
    letter followed by digits, e.g. "A1", "B27", "L3".

    Args:
        instance: the object being validated (unused)
        attribute: the attrs attribute being validated (unused)
        value (str): value to validate

    Raises:
        ValueError: if value is not a capital letter followed by digits
    """
    if not isinstance(value, str):
        raise ValueError(f"{value} is not a string")

    first_char = value[0]
    if not first_char.isalpha() or first_char.islower():
        raise ValueError(f"{value} does not start with a capital letter")

    if not value[1:].isdigit():
        raise ValueError(f"{value} does not end with a number")
Should be a string of form 229 | 230 | "" 231 | Poses should NOT start with 'L' : Poses = "A1", "B27", "C19" 232 | Landmarks should start with 'L' : Landmarks = "L1", "L2", "L3" 233 | 234 | Args: 235 | value (str): value to validate 236 | 237 | Returns: 238 | Callable: validator 239 | """ 240 | if not isinstance(value, str): 241 | raise ValueError(f"{value} is not a string") 242 | 243 | first_char = value[0] 244 | if not first_char.isalpha() or first_char.islower(): 245 | raise ValueError(f"{value} does not start with a capital letter") 246 | 247 | if not value[1:].isdigit(): 248 | raise ValueError(f"{value} does not end with a number") 249 | -------------------------------------------------------------------------------- /docs/py_factor_graph/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | py_factor_graph API documentation 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 |
20 |
21 |
22 |

Package py_factor_graph

23 |
24 |
25 |

Initialize factor graph package.

26 |
27 | 28 | Expand source code 29 | 30 |
"""Initialize factor graph package."""
31 |
32 |
33 |
34 |

Sub-modules

35 |
36 |
py_factor_graph.factor_graph
37 |
38 |
39 |
40 |
py_factor_graph.measurements
41 |
42 |
43 |
44 |
py_factor_graph.parsing
45 |
46 |
47 |
48 |
py_factor_graph.priors
49 |
50 |
51 |
52 |
py_factor_graph.utils
53 |
54 |

Directory for utils files

55 |
56 |
py_factor_graph.variables
57 |
58 |
59 |
60 |
61 |
62 |
63 |
64 |
65 |
66 |
67 |
68 |
69 | 87 |
import attr
from typing import Tuple, Optional, Union
import numpy as np
from py_factor_graph.utils.matrix_utils import (
    get_quat_from_rotation_matrix,
    _check_transformation_matrix,
    get_rotation_matrix_from_transformation_matrix,
    get_translation_from_transformation_matrix,
    get_theta_from_transformation_matrix,
)
from py_factor_graph.utils.attrib_utils import (
    optional_float_validator,
    make_rot_matrix_validator,
    make_variable_name_validator,
)
import scipy.spatial as spatial


@attr.s()
class PoseVariable2D:
    """A variable which is a robot pose

    Args:
        name (str): the name of the variable (defines the frame)
        true_position (Tuple[float, float]): the true position of the robot
        true_theta (float): the true orientation of the robot
        timestamp (float): seconds since epoch
    """

    name: str = attr.ib(validator=make_variable_name_validator("pose"))
    true_position: Tuple[float, float] = attr.ib()
    true_theta: float = attr.ib(validator=attr.validators.instance_of(float))
    timestamp: Optional[float] = attr.ib(default=None)

    @true_position.validator
    def _check_true_position(self, attribute, value):
        if len(value) != 2:
            raise ValueError("true_position should be a tuple of length 2")
        assert all(isinstance(x, float) for x in value)

    @property
    def rotation_matrix(self) -> np.ndarray:
        """
        Get the 2x2 rotation matrix corresponding to true_theta
        """
        return np.array(
            [
                [np.cos(self.true_theta), -np.sin(self.true_theta)],
                [np.sin(self.true_theta), np.cos(self.true_theta)],
            ]
        )

    @property
    def position_vector(self) -> np.ndarray:
        """
        Get the true position as a length-2 array
        """
        return np.array(self.true_position)

    @property
    def true_x(self) -> float:
        return self.true_position[0]

    @property
    def true_y(self) -> float:
        return self.true_position[1]

    @property
    def true_z(self) -> float:
        # A 2D pose has no height; return 0.0 to honor the float contract.
        return 0.0

    @property
    def true_quat(self) -> np.ndarray:
        # Rotation about z only: (qx, qy, qz, qw) with qx = qy = 0.
        quat = np.array(
            [0.0, 0.0, np.sin(self.true_theta / 2), np.cos(self.true_theta / 2)]
        )
        return quat

    @property
    def transformation_matrix(self) -> np.ndarray:
        """Returns the transformation matrix representing the true latent pose
        of this variable

        Returns:
            np.ndarray: the transformation matrix
        """
        T = np.eye(3)
        T[0:2, 0:2] = self.rotation_matrix
        T[0, 2] = self.true_x
        T[1, 2] = self.true_y
        return T

    def transform(self, T: np.ndarray) -> "PoseVariable2D":
        """Right-multiplies this pose's transformation matrix by T and returns
        the resulting pose as a new variable (self is unchanged).

        Args:
            T (np.ndarray): 3x3 homogeneous transformation matrix to apply

        Returns:
            PoseVariable2D: the transformed pose
        """
        _check_transformation_matrix(T)
        assert T.shape == (3, 3)
        current_transformation = self.transformation_matrix
        new_transformation = current_transformation @ T
        new_position = get_translation_from_transformation_matrix(new_transformation)
        new_theta = get_theta_from_transformation_matrix(new_transformation)
        pos2d = (float(new_position[0]), float(new_position[1]))
        return PoseVariable2D(self.name, pos2d, new_theta, self.timestamp)


@attr.s()
class PoseVariable3D:
    """A variable which is a robot pose

    Args:
        name (str): the name of the variable (defines the frame)
        true_position (Tuple[float, float, float]): the true position of the robot
        true_rotation (np.ndarray): the true orientation of the robot
        timestamp (float): seconds since epoch
    """

    name: str = attr.ib(validator=make_variable_name_validator("pose"))
    true_position: Tuple[float, float, float] = attr.ib()
    true_rotation: np.ndarray = attr.ib(validator=make_rot_matrix_validator(3))
    timestamp: Optional[float] = attr.ib(
        default=None, validator=optional_float_validator
    )

    @true_position.validator
    def _check_true_position(self, attribute, value):
        if len(value) != 3:
            raise ValueError(f"true_position should be a tuple of length 3: {value}")
        assert all(isinstance(x, float) for x in value)

    @property
    def dimension(self) -> int:
        return 3

    @property
    def rotation_matrix(self) -> np.ndarray:
        """
        Get the 3x3 rotation matrix for this pose
        """
        return self.true_rotation

    @property
    def position_vector(self) -> np.ndarray:
        """
        Get the true position as a length-3 array
        """
        return np.array(self.true_position)

    @property
    def true_x(self) -> float:
        return self.true_position[0]

    @property
    def true_y(self) -> float:
        return self.true_position[1]

    @property
    def true_z(self) -> float:
        return self.true_position[2]

    @property
    def yaw(self) -> float:
        # First component of the "zyx" intrinsic Euler decomposition is the yaw.
        rot_mat = self.rotation_matrix
        yaw = spatial.transform.Rotation.from_matrix(rot_mat).as_euler("zyx")[0]
        return yaw

    @property
    def true_quat(self) -> np.ndarray:
        rot = self.rotation_matrix
        quat = get_quat_from_rotation_matrix(rot)
        return quat

    @property
    def transformation_matrix(self) -> np.ndarray:
        """Returns the transformation matrix representing the true latent pose
        of this variable

        Returns:
            np.ndarray: the transformation matrix
        """
        T = np.eye(self.dimension + 1)
        T[: self.dimension, : self.dimension] = self.true_rotation
        T[: self.dimension, self.dimension] = self.true_position
        assert T.shape == (4, 4)
        return T

    def transform(self, T: np.ndarray) -> "PoseVariable3D":
        """Right-multiplies this pose's transformation matrix by T and returns
        the resulting pose as a new variable (self is unchanged).

        Args:
            T (np.ndarray): 4x4 homogeneous transformation matrix to apply

        Returns:
            PoseVariable3D: the transformed pose
        """
        _check_transformation_matrix(T)
        assert T.shape == (4, 4)
        current_transformation = self.transformation_matrix
        new_transformation = current_transformation @ T
        new_position = get_translation_from_transformation_matrix(new_transformation)
        new_rotation = get_rotation_matrix_from_transformation_matrix(
            new_transformation
        )
        pos3d = (float(new_position[0]), float(new_position[1]), float(new_position[2]))
        return PoseVariable3D(self.name, pos3d, new_rotation, self.timestamp)


@attr.s()
class LandmarkVariable2D:
    """A variable which is a landmark

    Arguments:
        name (str): the name of the variable
        true_position (Tuple[float, float]): the true position of the landmark
    """

    name: str = attr.ib(validator=make_variable_name_validator("landmark"))
    true_position: Tuple[float, float] = attr.ib()

    @true_position.validator
    def _check_true_position(self, attribute, value):
        if len(value) != 2:
            raise ValueError("true_position should be a tuple of length 2")
        assert all(isinstance(x, float) for x in value)

    @property
    def true_x(self):
        return self.true_position[0]

    @property
    def true_y(self):
        return self.true_position[1]


@attr.s()
class LandmarkVariable3D:
    """A variable which is a landmark

    Arguments:
        name (str): the name of the variable
        true_position (Tuple[float, float, float]): the true position of the landmark
    """

    name: str = attr.ib(validator=make_variable_name_validator("landmark"))
    true_position: Tuple[float, float, float] = attr.ib()

    @true_position.validator
    def _check_true_position(self, attribute, value):
        if len(value) != 3:
            raise ValueError("true_position should be a tuple of length 3")
        assert all(isinstance(x, float) for x in value)

    @property
    def true_x(self):
        return self.true_position[0]

    @property
    def true_y(self):
        return self.true_position[1]

    @property
    def true_z(self):
        return self.true_position[2]


# Union aliases for functions that accept any variable of the given kind.
POSE_VARIABLE_TYPES = Union[PoseVariable2D, PoseVariable3D]
LANDMARK_VARIABLE_TYPES = Union[LandmarkVariable2D, LandmarkVariable3D]


def dist_between_variables(
    var1: Union[POSE_VARIABLE_TYPES, LANDMARK_VARIABLE_TYPES],
    var2: Union[POSE_VARIABLE_TYPES, LANDMARK_VARIABLE_TYPES],
) -> float:
    """Returns the Euclidean distance between the true positions of two
    variables (poses or landmarks)."""
    if isinstance(var1, (PoseVariable2D, PoseVariable3D)):
        pos1 = var1.position_vector
    elif isinstance(var1, (LandmarkVariable2D, LandmarkVariable3D)):
        pos1 = np.array(var1.true_position)
    else:
        raise ValueError(f"Variable {var1} not supported")

    if isinstance(var2, (PoseVariable2D, PoseVariable3D)):
        pos2 = var2.position_vector
    elif isinstance(var2, (LandmarkVariable2D, LandmarkVariable3D)):
        pos2 = np.array(var2.true_position)
    else:
        raise ValueError(f"Variable {var2} not supported")

    # float() so the declared return type holds (norm yields np.float64).
    dist = float(np.linalg.norm(pos1 - pos2))
    return dist
Tuple[float, float, float] = attr.ib() 247 | 248 | @true_position.validator 249 | def _check_true_position(self, attribute, value): 250 | if len(value) != 3: 251 | raise ValueError(f"true_position should be a tuple of length 3") 252 | assert all(isinstance(x, float) for x in value) 253 | 254 | @property 255 | def true_x(self): 256 | return self.true_position[0] 257 | 258 | @property 259 | def true_y(self): 260 | return self.true_position[1] 261 | 262 | @property 263 | def true_z(self): 264 | return self.true_position[2] 265 | 266 | 267 | POSE_VARIABLE_TYPES = Union[PoseVariable2D, PoseVariable3D] 268 | LANDMARK_VARIABLE_TYPES = Union[LandmarkVariable2D, LandmarkVariable3D] 269 | 270 | 271 | def dist_between_variables( 272 | var1: Union[POSE_VARIABLE_TYPES, LANDMARK_VARIABLE_TYPES], 273 | var2: Union[POSE_VARIABLE_TYPES, LANDMARK_VARIABLE_TYPES], 274 | ) -> float: 275 | """Returns the distance between two variables""" 276 | if isinstance(var1, PoseVariable2D) or isinstance(var1, PoseVariable3D): 277 | pos1 = var1.position_vector 278 | elif isinstance(var1, LandmarkVariable2D) or isinstance(var1, LandmarkVariable3D): 279 | pos1 = np.array(var1.true_position) 280 | else: 281 | raise ValueError(f"Variable {var1} not supported") 282 | 283 | if isinstance(var2, PoseVariable2D) or isinstance(var2, PoseVariable3D): 284 | pos2 = var2.position_vector 285 | elif isinstance(var2, LandmarkVariable2D) or isinstance(var2, LandmarkVariable3D): 286 | pos2 = np.array(var2.true_position) 287 | else: 288 | raise ValueError(f"Variable {var2} not supported") 289 | 290 | dist = np.linalg.norm(pos1 - pos2).astype(float) 291 | return dist 292 | -------------------------------------------------------------------------------- /py_factor_graph/utils/solver_utils.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Optional, Tuple, List 2 | import pickle 3 | from os.path import isfile, dirname, isdir 4 | from os import makedirs 5 | 
import numpy as np
import attr

from py_factor_graph.utils.matrix_utils import (
    get_rotation_matrix_from_transformation_matrix,
    get_theta_from_transformation_matrix,
    get_quat_from_rotation_matrix,
    get_translation_from_transformation_matrix,
    _check_transformation_matrix,
)
from py_factor_graph.utils.logging_utils import logger


@attr.s(frozen=True)
class VariableValues:
    """Immutable container for solved variable values (poses, landmarks, ranges)."""

    # problem dimension: 2 (SE(2)) or 3 (SE(3))
    dim: int = attr.ib(validator=attr.validators.instance_of(int))
    # pose name -> homogeneous transformation matrix (dim+1 x dim+1)
    poses: Dict[str, np.ndarray] = attr.ib()
    # landmark name -> position vector of shape (dim,)
    landmarks: Dict[str, np.ndarray] = attr.ib()
    # (var, var) association -> solved distance variable, shape (1,) or (dim,)
    distances: Optional[Dict[Tuple[str, str], np.ndarray]] = attr.ib(default=None)
    # pose name -> timestamp (used e.g. for TUM export)
    pose_times: Optional[Dict[str, float]] = attr.ib(default=None)

    @dim.validator
    def _check_dim(self, attribute, value: int):
        assert value in (2, 3)

    @poses.validator
    def _check_poses(self, attribute, value: Dict[str, np.ndarray]):
        for pose in value.values():
            _check_transformation_matrix(pose, dim=self.dim)

    @landmarks.validator
    def _check_landmarks(self, attribute, value: Dict[str, np.ndarray]):
        for landmark in value.values():
            assert landmark.shape == (self.dim,)

    @distances.validator
    def _check_distances(
        self, attribute, value: Optional[Dict[Tuple[str, str], np.ndarray]]
    ):
        if value is not None:
            for distance in value.values():
                assert distance.shape in [
                    (1,),
                    (self.dim,),
                ], f"Expected shape ({self.dim},) or (1,) but got {distance.shape} for distance"

    @property
    def rotations_theta(self) -> Dict[str, float]:
        # pose name -> heading angle extracted from the transformation matrix
        return {
            key: get_theta_from_transformation_matrix(value)
            for key, value in self.poses.items()
        }

    @property
    def rotations_matrix(self) -> Dict[str, np.ndarray]:
        # pose name -> rotation block of the transformation matrix
        return {
            key: get_rotation_matrix_from_transformation_matrix(value)
            for key, value in self.poses.items()
        }

    @property
    def rotations_quat(self) -> Dict[str, np.ndarray]:
        # pose name -> quaternion derived from the rotation matrix
        return {
            key: get_quat_from_rotation_matrix(value)
            for key, value in self.rotations_matrix.items()
        }

    @property
    def translations(self) -> Dict[str, np.ndarray]:
        """Merged map of pose translations and landmark positions."""
        trans_vals = {
            key: get_translation_from_transformation_matrix(value)
            for key, value in self.poses.items()
        }
        # NOTE(review): this comprehension is a plain shallow copy of
        # self.landmarks; kept byte-identical here.
        landmark_trans_vals = {key: value for key, value in self.landmarks.items()}
        trans_vals.update(landmark_trans_vals)
        return trans_vals

    @property
    def limits(self) -> Tuple[Tuple[float, float], Tuple[float, float]]:
        """
        Returns the x/y limits of the poses and landmarks

        Returns:
            (xmin, xmax), (ymin, ymax)
        """
        translations = self.translations.values()
        trans_as_array = np.array(list(translations))
        nrows, ncols = trans_as_array.shape
        assert ncols == self.dim
        # column-wise extrema over all translations; only x/y are reported
        max_vals = np.max(trans_as_array, axis=0)
        min_vals = np.min(trans_as_array, axis=0)
        return (min_vals[0], max_vals[0]), (min_vals[1], max_vals[1])


@attr.s(frozen=True)
class SolverResults:
    """Solver output: the variable values plus solve metadata.

    Most properties are thin pass-throughs to ``self.variables``.
    """

    variables: VariableValues = attr.ib()
    # wall-clock solve time (seconds)
    total_time: float = attr.ib()
    # whether the solver reported success
    solved: bool = attr.ib()
    pose_chain_names: Optional[list] = attr.ib(default=None)  # Default [[str]]
    solver_cost: Optional[float] = attr.ib(default=None)

    @property
    def dim(self) -> int:
        return self.variables.dim

    @property
    def poses(self):
        return self.variables.poses

    @property
    def translations(self):
        return self.variables.translations

    @property
    def rotations_quat(self):
        return self.variables.rotations_quat

    @property
    def rotations_theta(self):
        return self.variables.rotations_theta

    @property
    def landmarks(self):
        return self.variables.landmarks

    @property
    def distances(self):
        return self.variables.distances

    @property
    def limits(self) -> Tuple[Tuple[float, float], Tuple[float, float]]:
        """
        Returns the x/y limits of the poses and landmarks

        Returns:
            (xmin, xmax), (ymin, ymax)
        """
        return self.variables.limits

    @property
    def pose_times(self):
        return self.variables.pose_times


def save_results_to_file(
    solved_results: SolverResults,
    solved_cost: float,
    solve_success: bool,
    filepath: str,
):
    """
    Saves the results to a file

    Args:
        solved_results: The results of the solver
        solved_cost: The cost of the solved results
        solve_success: Whether the solver was successful
        filepath: The path to save the results to

    Raises:
        NotImplementedError: for ``.txt`` output (currently disabled)
        ValueError: for any other unsupported file extension
    """
    data_dir = dirname(filepath)
    if not isdir(data_dir):
        makedirs(data_dir)

    if filepath.endswith(".pickle") or filepath.endswith(".pkl"):
        # two sequential pickles in one file: results first, then solve info
        pickle_file = open(filepath, "wb")
        pickle.dump(solved_results, pickle_file)
        solve_info = {
            "success": solve_success,
            "optimal_cost": solved_cost,
        }
        pickle.dump(solve_info, pickle_file)
        pickle_file.close()

    elif filepath.endswith(".txt"):
        raise NotImplementedError(
            "Saving to txt not implemented yet since allowing for 3D"
        )
        # NOTE(review): everything below this raise is unreachable dead code
        # (the old 2D-only text writer); kept byte-identical pending cleanup.
        with open(filepath, "w") as f:
            translations = solved_results.translations
            rot_thetas = solved_results.rotations_theta
            for pose_key in translations.keys():
                trans_solve = translations[pose_key]
                theta_solve = rot_thetas[pose_key]

                trans_string = np.array2string(
                    trans_solve, precision=1, floatmode="fixed"
                )
                status = (
                    f"State {pose_key}"
                    + f" | Translation: {trans_string}"
                    + f" | Rotation: {theta_solve:.2f}\n"
                )
                f.write(status)

            landmarks = solved_results.landmarks
            for landmark_key in landmarks.keys():
                landmark_solve = landmarks[landmark_key]

                landmark_string = np.array2string(
                    landmark_solve, precision=1, floatmode="fixed"
                )
                status = (
                    f"State {landmark_key}" + f" | Translation: {landmark_string}\n"
                )
                f.write(status)

            f.write(f"Is optimization successful? {solve_success}\n")
            f.write(f"optimal cost: {solved_cost}")

    # Outputs each posechain as a separate file with timestamp in TUM format
    elif filepath.endswith(".tum"):
        save_to_tum(solved_results, filepath)
    else:
        raise ValueError(
            f"The file extension {filepath.split('.')[-1]} is not supported. "
        )

    logger.debug(f"Results saved to: {filepath}\n")


def save_to_tum(
    solved_results: SolverResults,
    filepath: str,
    strip_extension: bool = False,
    verbose: bool = False,
) -> List[str]:
    """Saves a given set of solver results to a number of TUM files, with one
    for each pose chain in the results.

    Args:
        solved_results (SolverResults): [description]
        filepath (str): the path to save the results to. The final files will
            have the pose chain letter appended to the end to indicate which pose chain.
        strip_extension (bool, optional): Whether to strip the file extension
            and replace with ".tum". This should be set to true if the file
            extension is not already ".tum". Defaults to False.

    Returns:
        List[str]: The list of filepaths that the results were saved to.
    """
    assert (
        solved_results.pose_chain_names is not None
    ), "Pose_chain_names must be provided for multi robot trajectories"
    # "L" is excluded — presumably reserved for landmark names; confirm
    acceptable_pose_chain_letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZ".replace("L", "")
    # TODO: Add support for exporting without pose_chain_names

    save_files = []
    for pose_chain in solved_results.pose_chain_names:
        if len(pose_chain) == 0:
            continue
        pose_chain_letter = pose_chain[0][0]  # Get first letter of first pose in chain
        assert (
            pose_chain_letter in acceptable_pose_chain_letters
        ), "Pose chain letter must be uppercase letter and not L"

        # Removes extension from filepath to add tum extension
        if strip_extension:
            filepath = filepath.split(".")[0] + ".tum"

        assert filepath.endswith(".tum"), "File extension must be .tum"
        modified_path = filepath.replace(".tum", f"_{pose_chain_letter}.tum")

        # if file already exists we won't write over it
        if verbose and isfile(modified_path) and "/tmp/" not in modified_path:
            logger.warning(f"{modified_path} already exists, overwriting")

        if not isdir(dirname(modified_path)):
            makedirs(dirname(modified_path))

        # NOTE(review): pose_times may be None (Optional field) — the lookup
        # below would then raise TypeError; confirm callers always set it.
        pose_times = solved_results.pose_times
        with open(modified_path, "w") as f:
            translations = solved_results.translations
            quats = solved_results.rotations_quat
            for pose_key in pose_chain:
                # pad 2D translations with z = 0 to fit the TUM format
                trans_solve = translations[pose_key]
                if len(trans_solve) == 2:
                    tx, ty = trans_solve
                    tz = 0.0
                elif len(trans_solve) == 3:
                    tx, ty, tz = trans_solve
                else:
                    raise ValueError(
                        f"Solved for translation of wrong dimension {len(trans_solve)}"
                    )

                quat_solve = quats[pose_key]
                qx, qy, qz, qw = quat_solve
                i = pose_times[pose_key]
                # TUM line: time tx ty tz qx qy qz qw
                f.write(
                    f"{i:6f} {tx:.5f} {ty:.5f} {tz:.5f} {qx:.8f} {qy:.8f} {qz:.8f} {qw:.8f}\n"
                )
                # f.write(f"{i} {tx} {ty} {tz} {qx} {qy} {qz} {qw}\n")

        if verbose and "/tmp/" not in modified_path:
            logger.info(f"Wrote: {modified_path}")
        save_files.append(modified_path)

    return save_files


def load_custom_init_file(file_path: str) -> VariableValues:
    """Loads the custom init file. Is either a pickled VariableValues object
    or a pickled SolverResults object.

    Args:
        file_path (str): path to the custom init file

    Raises:
        ValueError: if the pickled object is neither type
    """

    assert isfile(file_path), f"File {file_path} does not exist"
    assert file_path.endswith(".pickle") or file_path.endswith(
        ".pkl"
    ), f"File {file_path} must end with '.pickle' or '.pkl'"

    logger.debug(f"Loading custom init file: {file_path}")
    with open(file_path, "rb") as f:
        init_dict = pickle.load(f)
        if isinstance(init_dict, SolverResults):
            return init_dict.variables
        elif isinstance(init_dict, VariableValues):
            return init_dict
        else:
            raise ValueError(f"Unknown type: {type(init_dict)}")


def load_pickled_solution(pickled_solution_path: str) -> SolverResults:
    # unpickles the first object in the file (see save_results_to_file)
    with open(pickled_solution_path, "rb") as f:
        return pickle.load(f)
--------------------------------------------------------------------------------
/py_factor_graph/io/plaza_experiments.py:
--------------------------------------------------------------------------------
"""
For parsing the Plaza dataset

paper: https://onlinelibrary.wiley.com/doi/pdf/10.1002/rob.20311
dataset: https://infoscience.epfl.ch/record/283435
"""
from typing import List, Dict, Tuple, Optional
import os
import numpy as np
import pandas as pd

from py_factor_graph.variables import PoseVariable2D, LandmarkVariable2D
from py_factor_graph.measurements import (
    PoseMeasurement2D,
    FGRangeMeasurement,
)
from py_factor_graph.calibrations.range_measurement_calibration import (
UncalibratedRangeMeasurement, 19 | get_inlier_set_of_range_measurements, 20 | get_linearly_calibrated_measurements, 21 | ) 22 | from py_factor_graph.factor_graph import ( 23 | FactorGraphData, 24 | ) 25 | from py_factor_graph.utils.matrix_utils import ( 26 | get_measurement_precisions_from_covariances, 27 | ) 28 | from attrs import define, field 29 | 30 | ODOM_EXTENSION = "_DR.csv" 31 | ODOM_PATH_EXTENSION = "_DRp.csv" 32 | GT_ROBOT_EXTENSION = "_GT.csv" 33 | DIST_MEASURE_EXTENSION = "_TD.csv" 34 | GT_LANDMARK_EXTENSION = "_TL.csv" 35 | 36 | import logging, coloredlogs 37 | 38 | logger = logging.getLogger(__name__) 39 | field_styles = { 40 | "filename": {"color": "green"}, 41 | "levelname": {"bold": True, "color": "black"}, 42 | "name": {"color": "blue"}, 43 | } 44 | coloredlogs.install( 45 | level="INFO", 46 | fmt="[%(filename)s:%(lineno)d] %(name)s %(levelname)s - %(message)s", 47 | field_styles=field_styles, 48 | ) 49 | 50 | 51 | def _get_file_with_extension(files: List[str], extension: str) -> str: 52 | candidate_files = [f for f in files if f.endswith(extension)] 53 | assert ( 54 | len(candidate_files) == 1 55 | ), f"Found {len(candidate_files)} files with extension {extension} but expected 1." 56 | assert os.path.isfile( 57 | candidate_files[0] 58 | ), f"Selected {candidate_files[0]} but it is not a valid file." 
59 | return candidate_files[0] 60 | 61 | 62 | @define 63 | class PlazaDataFiles: 64 | dirpath: str = field() 65 | 66 | # have all of the files but construct after init 67 | odom_file: str = field(init=False) 68 | odom_path_file: str = field(init=False) 69 | gt_robot_file: str = field(init=False) 70 | dist_measure_file: str = field(init=False) 71 | gt_landmark_file: str = field(init=False) 72 | 73 | @dirpath.validator 74 | def _check_dirpath(self, attribute, value): 75 | if not os.path.isdir(value): 76 | raise ValueError(f"dirpath {value} is not a directory.") 77 | 78 | def __attrs_post_init__(self): 79 | all_files = [os.path.join(self.dirpath, f) for f in os.listdir(self.dirpath)] 80 | self.odom_file = _get_file_with_extension(all_files, ODOM_EXTENSION) 81 | self.odom_path_file = _get_file_with_extension(all_files, ODOM_PATH_EXTENSION) 82 | self.gt_robot_file = _get_file_with_extension(all_files, GT_ROBOT_EXTENSION) 83 | self.dist_measure_file = _get_file_with_extension( 84 | all_files, DIST_MEASURE_EXTENSION 85 | ) 86 | self.gt_landmark_file = _get_file_with_extension( 87 | all_files, GT_LANDMARK_EXTENSION 88 | ) 89 | 90 | def robot_gt_df(self) -> pd.DataFrame: 91 | headers = ["time", "x", "y", "theta"] 92 | return pd.read_csv(self.gt_robot_file, names=headers) 93 | 94 | def odom_df(self) -> pd.DataFrame: 95 | headers = ["time", "dx", "dtheta"] 96 | return pd.read_csv(self.odom_file, names=headers) 97 | 98 | def odom_path_df(self) -> pd.DataFrame: 99 | headers = ["time", "x", "y", "theta"] 100 | return pd.read_csv(self.odom_path_file, names=headers) 101 | 102 | def dist_measure_df(self) -> pd.DataFrame: 103 | headers = ["time", "robot_id", "beacon_id", "distance"] 104 | range_df = pd.read_csv(self.dist_measure_file, names=headers) 105 | assert ( 106 | len(range_df["robot_id"].unique()) == 1 107 | ), "Multiple robot ids found in range file." 
108 | return range_df 109 | 110 | def landmark_gt_df(self) -> pd.DataFrame: 111 | headers = ["beacon_id", "x", "y"] 112 | return pd.read_csv(self.gt_landmark_file, names=headers) 113 | 114 | def get_beacon_id_to_idx_mapping(self) -> Dict[int, int]: 115 | beacon_id_to_idx: Dict[int, int] = {} 116 | with open(self.gt_landmark_file, "r") as f: 117 | for line in f.readlines(): 118 | beacon_id, x, y = line.split(",") 119 | beacon_id_to_idx[int(beacon_id)] = len(beacon_id_to_idx) 120 | return beacon_id_to_idx 121 | 122 | 123 | def _set_beacon_variables(fg: FactorGraphData, data_files: PlazaDataFiles): 124 | beacon_id_to_idx = data_files.get_beacon_id_to_idx_mapping() 125 | with open(data_files.gt_landmark_file, "r") as f: 126 | for line in f.readlines(): 127 | beacon_id, x, y = line.split(",") 128 | beacon_idx = beacon_id_to_idx[int(beacon_id)] 129 | beacon_var = LandmarkVariable2D( 130 | name=f"L{beacon_idx}", 131 | true_position=(float(x), float(y)), 132 | ) 133 | fg.add_landmark_variable(beacon_var) 134 | 135 | 136 | def _set_pose_variables(fg: FactorGraphData, data_files: PlazaDataFiles): 137 | gt_pose_df = data_files.robot_gt_df() 138 | if "plaza2" in data_files.dirpath.lower(): 139 | logger.warning("Plaza2 data detected. 
Adding pi offset to theta.") 140 | theta_offset = np.pi 141 | else: 142 | theta_offset = 0.0 143 | for idx, row in gt_pose_df.iterrows(): 144 | pose_var = PoseVariable2D( 145 | name=f"A{idx}", 146 | true_position=(row["x"], row["y"]), 147 | true_theta=row["theta"] + theta_offset, 148 | timestamp=row["time"], 149 | ) 150 | fg.add_pose_variable(pose_var) 151 | 152 | 153 | def _add_odometry_measurements(fg: FactorGraphData, data_files: PlazaDataFiles): 154 | odom_df = data_files.odom_df() 155 | 156 | translation_cov = (0.1) ** 2 157 | rot_cov = (0.01) ** 2 158 | trans_precision, rot_precision = get_measurement_precisions_from_covariances( 159 | translation_cov, rot_cov, mat_dim=3 160 | ) 161 | 162 | for idx, row in odom_df.iterrows(): 163 | odom_measure = PoseMeasurement2D( 164 | base_pose=f"A{idx}", 165 | to_pose=f"A{idx+1}", 166 | x=row["dx"], 167 | y=0.0, 168 | theta=row["dtheta"], 169 | translation_precision=trans_precision, 170 | rotation_precision=rot_precision, 171 | timestamp=row["time"], 172 | ) 173 | fg.add_odom_measurement(robot_idx=0, odom_meas=odom_measure) 174 | 175 | 176 | def _find_nearest_time_index( 177 | time_series: pd.Series, target_time: float, start_idx: int 178 | ) -> int: 179 | """ 180 | We know that time is sorted and that we will be iterating through the 181 | time_series in order. As a result, we can start our search from the 182 | previous index we found. 
183 | """ 184 | for idx in range(start_idx, len(time_series)): 185 | if time_series[idx] >= target_time: 186 | prev_time = time_series[idx - 1] 187 | next_time = time_series[idx] 188 | prev_diff = abs(prev_time - target_time) 189 | next_diff = abs(next_time - target_time) 190 | return idx - 1 if prev_diff < next_diff else idx 191 | 192 | return len(time_series) - 1 193 | 194 | 195 | def _parse_uncalibrated_range_measures( 196 | data_files: PlazaDataFiles, 197 | ) -> List[UncalibratedRangeMeasurement]: 198 | beacon_id_to_idx = data_files.get_beacon_id_to_idx_mapping() 199 | gt_pose_df = data_files.robot_gt_df() 200 | range_df = data_files.dist_measure_df() 201 | range_df["beacon_id"] = range_df["beacon_id"].apply(lambda x: beacon_id_to_idx[x]) 202 | 203 | # collect a list of range measures for each robot-beacon pair - we will 204 | # average the measured distance and timestamps over these to get a single 205 | # range measurement for each robot-beacon pair 206 | range_measures: Dict[Tuple[str, str], List[Tuple[float, float]]] = {} 207 | most_recent_pose_idx = 0 208 | for _, row in range_df.iterrows(): 209 | range_measure_time = row["time"] 210 | nearest_robot_pose_idx = _find_nearest_time_index( 211 | gt_pose_df["time"], range_measure_time, most_recent_pose_idx 212 | ) 213 | most_recent_pose_idx = nearest_robot_pose_idx 214 | robot_pose_name = f"A{nearest_robot_pose_idx}" 215 | beacon_pose_name = f"L{int(row['beacon_id'])}" 216 | measured_distance = row["distance"] 217 | 218 | association = (robot_pose_name, beacon_pose_name) 219 | if association not in range_measures: 220 | range_measures[association] = [] 221 | range_measures[association].append((range_measure_time, measured_distance)) 222 | 223 | range_measure_list: List[UncalibratedRangeMeasurement] = [] 224 | for association, measures in range_measures.items(): 225 | robot_pose_name, beacon_pose_name = association 226 | avg_measured_distance = float(np.mean([x[1] for x in measures])) 227 | measured_timestamp = 
float(np.mean([x[0] for x in measures])) 228 | range_measure_list.append( 229 | UncalibratedRangeMeasurement( 230 | association=association, 231 | dist=avg_measured_distance, 232 | timestamp=measured_timestamp, 233 | ) 234 | ) 235 | 236 | return range_measure_list 237 | 238 | 239 | def _obtain_calibrated_measurements( 240 | data_files: PlazaDataFiles, 241 | uncalibrated_measures: List[UncalibratedRangeMeasurement], 242 | stddev: Optional[float] = None, 243 | ) -> List[FGRangeMeasurement]: 244 | gt_pose_df = data_files.robot_gt_df() 245 | beacon_idxs = data_files.get_beacon_id_to_idx_mapping().values() 246 | beacon_gt_df = data_files.landmark_gt_df() 247 | 248 | # group the range measures by beacon and add the true range (from GPS) to each 249 | calibration_pairs: Dict[int, List[UncalibratedRangeMeasurement]] = { 250 | x: [] for x in beacon_idxs 251 | } 252 | for uncal_measure in uncalibrated_measures: 253 | pose_name, beacon_name = uncal_measure.association 254 | robot_idx = int(pose_name[1:]) 255 | beacon_idx = int(beacon_name[1:]) 256 | 257 | true_robot_location = gt_pose_df.iloc[robot_idx][["x", "y"]].values 258 | true_beacon_location = beacon_gt_df.iloc[beacon_idx][["x", "y"]].values 259 | true_range = float(np.linalg.norm(true_robot_location - true_beacon_location)) 260 | 261 | uncal_measure.set_true_dist(true_range) 262 | calibration_pairs[beacon_idx].append(uncal_measure) 263 | 264 | inlier_measurements: Dict[int, List[UncalibratedRangeMeasurement]] = { 265 | beacon_idx: get_inlier_set_of_range_measurements(measures) 266 | for beacon_idx, measures in calibration_pairs.items() 267 | } 268 | 269 | all_calibrated_measurements: List[FGRangeMeasurement] = [] 270 | for beacon_idx, measures in inlier_measurements.items(): 271 | calibrated_measurements = get_linearly_calibrated_measurements(measures) 272 | all_calibrated_measurements.extend(calibrated_measurements) 273 | 274 | if stddev is not None: 275 | for measure in all_calibrated_measurements: 276 | 
measure.stddev = stddev 277 | 278 | return all_calibrated_measurements 279 | 280 | 281 | def _add_range_measurements( 282 | fg: FactorGraphData, 283 | data_files: PlazaDataFiles, 284 | range_stddev: Optional[float] = None, 285 | ): 286 | uncalibrated_range_measures = _parse_uncalibrated_range_measures(data_files) 287 | calibrated_ranges = _obtain_calibrated_measurements( 288 | data_files, uncalibrated_range_measures, stddev=range_stddev 289 | ) 290 | for range_measure in calibrated_ranges: 291 | fg.add_range_measurement(range_measure) 292 | 293 | 294 | def parse_plaza_files( 295 | dirpath: str, range_stddev: Optional[float] = None 296 | ) -> FactorGraphData: 297 | data_files = PlazaDataFiles(dirpath) 298 | if "gesling" in dirpath.lower(): 299 | raise NotImplementedError( 300 | """ 301 | Gesling data not yet supported. This data requires some 302 | additional calibration, as there are multiple radios attached to the robot 303 | (https://onlinelibrary.wiley.com/doi/pdf/10.1002/rob.20311) 304 | """ 305 | ) 306 | 307 | fg = FactorGraphData(dimension=2) 308 | _set_pose_variables(fg, data_files) 309 | _add_odometry_measurements(fg, data_files) 310 | _set_beacon_variables(fg, data_files) 311 | _add_range_measurements(fg, data_files, range_stddev=range_stddev) 312 | 313 | return fg 314 | 315 | 316 | if __name__ == "__main__": 317 | import os 318 | 319 | data_dir = os.path.expanduser("~/experimental_data/plaza/Plaza1") 320 | 321 | # parse and print summary 322 | fg = parse_plaza_files(data_dir) 323 | fg.print_summary() 324 | 325 | # animate if desired 326 | visualize = False 327 | if visualize: 328 | fg.animate_odometry( 329 | show_gt=True, 330 | pause_interval=0.01, 331 | draw_range_lines=True, 332 | draw_range_circles=False, 333 | num_timesteps_keep_ranges=1, 334 | ) 335 | 336 | # save the factor graph to file 337 | save_path = os.path.expanduser( 338 | "~/experimental_data/plaza/Plaza1/factor_graph.pickle" 339 | ) 340 | fg.save_to_file(save_path) 341 | 
--------------------------------------------------------------------------------
/py_factor_graph/utils/plot_utils.py:
--------------------------------------------------------------------------------
import numpy as np
import math
import os
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import matplotlib.lines as mlines
import mpl_toolkits.mplot3d.art3d as art3d
import mpl_toolkits.mplot3d.axes3d as axes3d

from typing import Tuple, Union, Optional, List, Sequence
from evo.tools import file_interface, plot as evoplot
from py_factor_graph.utils.solver_utils import (
    SolverResults,
    save_to_tum,
)
from py_factor_graph.variables import PoseVariable2D, PoseVariable3D, LandmarkVariable2D, LandmarkVariable3D
from py_factor_graph.measurements import FGRangeMeasurement
from py_factor_graph.utils.matrix_utils import get_theta_from_rotation_matrix

# fixed palette cycled through by get_color
COLORS = ["blue", "red", "green", "yellow", "black", "cyan", "magenta"]


def get_color(i: int) -> str:
    """Return the i-th palette color, wrapping around the palette."""
    return COLORS[i % len(COLORS)]


def draw_arrow(
    ax: plt.Axes,
    x: float,
    y: float,
    theta: float,
    quiver_length: float = 1,
    quiver_width: float = 0.1,
    color: str = "black",
) -> mpatches.FancyArrow:
    """Draws an arrow on the given axes

    Args:
        ax (plt.Axes): the axes to draw the arrow on
        x (float): the x position of the arrow
        y (float): the y position of the arrow
        theta (float): the angle of the arrow
        quiver_length (float, optional): the length of the arrow. Defaults to 1.
        quiver_width (float, optional): the width of the arrow. Defaults to 0.1.
        color (str, optional): color of the arrow. Defaults to "black".

    Returns:
        mpatches.FancyArrow: the arrow
    """
    dx = quiver_length * math.cos(theta)
    dy = quiver_length * math.sin(theta)
    return ax.arrow(
        x,
        y,
        dx,
        dy,
        head_width=quiver_length,
        head_length=quiver_length,
        width=quiver_width,
        color=color,
    )

def draw_arrow_3d(
    ax: plt.Axes,
    x: float,
    y: float,
    z: float,
    dx: float,
    dy: float,
    dz: float,
    quiver_length: float = 1,
    color: str = "black",
):
    """Draws a 3D arrow (quiver) on the given axes

    Args:
        ax (plt.Axes): the axes to draw the arrow on
        x (float): the x position of the arrow
        y (float): the y position of the arrow
        z (float): the z position of the arrow
        dx (float): the x direction of the arrow
        dy (float): the y direction of the arrow
        dz (float): the z direction of the arrow
        quiver_length (float, optional): the length of the arrow. Defaults to 1.
        color (str, optional): color of the arrow. Defaults to "black".

    Returns:
        art3d.Line3DCollection: the arrow
    """
    return ax.quiver(
        x,
        y,
        z,
        dx,
        dy,
        dz,
        # drawn length is 5x quiver_length for visibility; direction is normalized
        length = quiver_length * 5.0,
        color = color,
        normalize = True,
    )

def draw_line(
    ax: plt.Axes,
    x_start: float,
    y_start: float,
    x_end: float,
    y_end: float,
    color: str = "black",
) -> mlines.Line2D:
    """Draws a line on the given axes between the two points

    Args:
        ax (plt.Axes): the axes to draw the line on
        x_start (float): the x position of the start of the line
        y_start (float): the y position of the start of the line
        x_end (float): the x position of the end of the line
        y_end (float): the y position of the end of the line
        color (str, optional): color of the line. Defaults to "black".

    Returns:
        mlines.Line2D: the line
    """
    # if color is grey lets make the line dashed and reduce the line width
    if color == "grey":
        line = mlines.Line2D(
            [x_start, x_end],
            [y_start, y_end],
            color=color,
            linestyle="dashed",
            linewidth=0.5,
        )
    else:
        line = mlines.Line2D([x_start, x_end], [y_start, y_end], color=color)

    ax.add_line(line)
    return line

def draw_line_3d(
    ax: plt.Axes,
    x_start: float,
    y_start: float,
    z_start: float,
    x_end: float,
    y_end: float,
    z_end: float,
    color: str = "black",
) -> art3d.Line3D:
    """Draws a line on the given axes between the two points

    Args:
        ax (plt.Axes): the axes to draw the line on
        x_start (float): the x position of the start of the line
        y_start (float): the y position of the start of the line
        z_start (float): the z position of the start of the line
        x_end (float): the x position of the end of the line
        y_end (float): the y position of the end of the line
        z_end (float): the z position of the end of the line
        color (str, optional): color of the line. Defaults to "black".

    Returns:
        art3d.Line3D: the line
    """
    # if color is grey lets make the line dashed and reduce the line width
    if color == "grey":
        line = art3d.Line3D(
            [x_start, x_end],
            [y_start, y_end],
            [z_start, z_end],
            color=color,
            linestyle="dashed",
            linewidth=0.5,
        )
    else:
        line = art3d.Line3D([x_start, x_end], [y_start, y_end], [z_start, z_end],color=color)
    ax.add_line(line)
    return line

def draw_circle(ax: plt.Axes, circle: np.ndarray, color="red") -> mpatches.Circle:
    """Draw an unfilled circle; ``circle`` packs (x, y, radius)."""
    assert circle.size == 3
    return ax.add_patch(
        mpatches.Circle(circle[0:2], circle[2], color=color, fill=False)
    )


def _get_pose_xytheta(
    pose: Union[np.ndarray, PoseVariable2D],
) -> Tuple[float, float, float]:
    """Extract (x, y, theta) from a PoseVariable2D or a 3x3 SE(2) matrix."""
    assert isinstance(pose, np.ndarray) or isinstance(pose, PoseVariable2D)
    if isinstance(pose, PoseVariable2D):
        x = pose.true_x
        y = pose.true_y
        theta = pose.true_theta
    else:
        # homogeneous SE(2) matrix: translation in last column, rotation top-left
        x = pose[0, 2]
        y = pose[1, 2]
        theta = get_theta_from_rotation_matrix(pose[0:2, 0:2])
    return x, y, theta

def _get_pose_3d(
    pose: Union[np.ndarray, PoseVariable3D],
) -> Tuple[float, float, float, float, float, float]:
    """Extract position and a unit heading direction from a PoseVariable3D
    or a 4x4 SE(3) matrix.

    The direction (dx, dy, dz) is the first column of the rotation block
    (the body x-axis), normalized to unit length before returning.
    """
    assert isinstance(pose, np.ndarray) or isinstance(pose, PoseVariable3D)
    if isinstance(pose, PoseVariable3D):
        x = pose.true_x
        y = pose.true_y
        z = pose.true_z
        dx = pose.true_rotation[0, 0]
        dy = pose.true_rotation[1, 0]
        dz = pose.true_rotation[2, 0]
    else:
        x = pose[0, 3]
        y = pose[1, 3]
        z = pose[2, 3]
        dx = pose[0, 0]
        dy = pose[1, 0]
        dz = pose[2, 0]

    length = np.sqrt(dx**2.0 + dy**2.0 + dz**2.0)
    return x, y, z, dx / length, dy / length, dz / length


def draw_pose(
    ax: plt.Axes,
    pose: Union[np.ndarray, PoseVariable2D],
    color="blue",
    scale: float = 1,
) -> mpatches.FancyArrow:
    """Draw a 2D pose as an arrow at its position, pointing along its heading."""
    true_x, true_y, true_theta = _get_pose_xytheta(pose)
    return draw_arrow(
        ax,
        true_x,
        true_y,
        true_theta,
        color=color,
        quiver_length=scale,
        quiver_width=scale / 10,
    )

def draw_pose_3d(
    ax: plt.Axes,
    pose: Union[np.ndarray, PoseVariable3D],
    color="blue",
    scale: float = 1,
) -> art3d.Line3DCollection:
    """Draw a 3D pose as a quiver arrow along its body x-axis."""
    true_x, true_y, true_z, dx, dy, dz = _get_pose_3d(pose)
    return draw_arrow_3d(
        ax,
        true_x,
        true_y,
        true_z,
        dx,
        dy,
        dz,
        color=color,
        quiver_length=scale,
    )

def update_pose_arrow(
    arrow: mpatches.FancyArrow,
    pose: Union[np.ndarray, PoseVariable2D],
    scale: float = 1,
):
    """Move an existing pose arrow to a new pose (in-place animation update)."""
    x, y, theta = _get_pose_xytheta(pose)
    quiver_length = scale
    dx = quiver_length * math.cos(theta)
    dy = quiver_length * math.sin(theta)
    arrow.set_data(x=x, y=y, dx=dx, dy=dy)

def draw_traj(
    ax: plt.Axes,
    x_traj: Sequence[float],
    y_traj: Sequence[float],
    color: str = "black",
) -> mlines.Line2D:
    """Draw a 2D trajectory polyline; x and y must be the same length."""
    assert len(x_traj) == len(y_traj)
    line = mlines.Line2D(x_traj, y_traj, color=color)
    ax.add_line(line)
    return line

def draw_traj_3d(
    ax: plt.Axes,
    x_traj: Sequence[float],
    y_traj: Sequence[float],
    z_traj: Sequence[float],
    color: str = "black",
) -> art3d.Line3D:
    """Draw a 3D trajectory polyline; all three sequences must be the same length."""
    assert len(x_traj) == len(y_traj) and len(x_traj) == len(z_traj) and len(y_traj) == len(z_traj)
    line = art3d.Line3D(x_traj, y_traj, z_traj, color=color)
    ax.add_line(line)
    return line

def update_traj(
    line: mlines.Line2D,
    x_traj: Sequence[float],
    y_traj: Sequence[float],
):
    """Replace the data of an existing trajectory line (animation update)."""
    assert len(x_traj) == len(y_traj)
    line.set_xdata(x_traj)
    line.set_ydata(y_traj)


# (definition continues beyond this chunk)
def draw_landmark_variable(ax:
plt.Axes, landmark: LandmarkVariable2D): 304 | true_x = landmark.true_x 305 | true_y = landmark.true_y 306 | ax.scatter(true_x, true_y, color="green", marker=(5, 2)) 307 | 308 | def draw_landmark_variable_3d(ax: plt.Axes, landmark: LandmarkVariable3D): 309 | true_x = landmark.true_x 310 | true_y = landmark.true_y 311 | true_z = landmark.true_z 312 | ax.scatter(true_x, true_y, true_z, c="green", marker = (5, 2)) 313 | 314 | def draw_loop_closure_measurement( 315 | ax: plt.Axes, base_loc: np.ndarray, to_pose: PoseVariable2D 316 | ) -> Tuple[mlines.Line2D, mpatches.FancyArrow]: 317 | assert base_loc.size == 2 318 | 319 | x_start = base_loc[0] 320 | y_start = base_loc[1] 321 | x_end = to_pose.true_x 322 | y_end = to_pose.true_y 323 | 324 | line = draw_line(ax, x_start, y_start, x_end, y_end, color="green") 325 | arrow = draw_pose(ax, to_pose) 326 | 327 | return line, arrow 328 | 329 | 330 | def draw_range_measurement( 331 | ax: plt.Axes, 332 | range_measure: FGRangeMeasurement, 333 | from_pose: PoseVariable2D, 334 | to_landmark: Union[LandmarkVariable2D, PoseVariable2D], 335 | add_line: bool = True, 336 | add_circle: bool = True, 337 | ) -> Tuple[Optional[mlines.Line2D], Optional[mpatches.Circle]]: 338 | base_loc = from_pose.true_x, from_pose.true_y 339 | to_loc = to_landmark.true_x, to_landmark.true_y 340 | 341 | x_start, y_start = base_loc 342 | landmark_idx = int(to_landmark.name[1:]) 343 | c = get_color(landmark_idx) 344 | c = "grey" 345 | 346 | if add_line: 347 | x_end, y_end = to_loc 348 | line = draw_line(ax, x_start, y_start, x_end, y_end, color=c) 349 | else: 350 | line = None 351 | if add_circle: 352 | dist = range_measure.dist 353 | circle = draw_circle(ax, np.array([x_start, y_start, dist]), color=c) 354 | else: 355 | circle = None 356 | 357 | return line, circle 358 | 359 | def draw_range_measurement_3d( 360 | ax: plt.Axes, 361 | range_measure: FGRangeMeasurement, 362 | from_pose: PoseVariable3D, 363 | to_landmark: Union[LandmarkVariable3D, 
PoseVariable3D], 364 | add_line: bool = True, 365 | ) -> Optional[mlines.Line2D]: 366 | base_loc = from_pose.true_x, from_pose.true_y, from_pose.true_z 367 | to_loc = to_landmark.true_x, to_landmark.true_y, to_landmark.true_z 368 | 369 | x_start, y_start, z_start = base_loc 370 | landmark_idx = int(to_landmark.name[1:]) 371 | c = get_color(landmark_idx) 372 | c = "grey" 373 | 374 | if add_line: 375 | x_end, y_end, z_end = to_loc 376 | line = draw_line_3d(ax, x_start, y_start, z_start, x_end, y_end, z_end, color=c) 377 | else: 378 | line = None 379 | 380 | return line 381 | 382 | def visualize_solution( 383 | solution: SolverResults, 384 | gt_files: Optional[List[str]] = None, 385 | name: str = "estimate", 386 | xlim: Optional[Tuple[float, float]] = None, 387 | ylim: Optional[Tuple[float, float]] = None, 388 | save_path: Optional[str] = None, 389 | show: bool = True, 390 | ) -> None: 391 | """Visualizes the solution. 392 | 393 | Args: 394 | solution (SolverResults): the solution. 395 | gt_traj (str): the path to the groundtruth trajectory. 
396 | """ 397 | # save the solution to a temporary .tum file 398 | temp_file = f"/tmp/{name}.tum" 399 | soln_tum_files = save_to_tum(solution, temp_file) 400 | assert gt_files is None or len(gt_files) == len( 401 | soln_tum_files 402 | ), f"gt_files: {gt_files}, soln_tum_files: {soln_tum_files}" 403 | 404 | def _get_filename_without_extension(filepath: str) -> str: 405 | return os.path.splitext(os.path.basename(filepath))[0] 406 | 407 | traj_by_label = {} 408 | for file_idx in range(len(soln_tum_files)): 409 | file = soln_tum_files[file_idx] 410 | traj_est = file_interface.read_tum_trajectory_file(file) 411 | 412 | if gt_files is not None: 413 | gt_traj_path = gt_files[file_idx] 414 | gt_traj = file_interface.read_tum_trajectory_file(gt_traj_path) 415 | 416 | # get the traj label from the filename without extension 417 | traj_label = _get_filename_without_extension(gt_traj_path) 418 | traj_by_label[traj_label] = gt_traj 419 | 420 | # align the estimated trajectory to the groundtruth 421 | traj_est.align(gt_traj) 422 | 423 | label_letter = _get_filename_without_extension(file)[-1] 424 | label = f"{name}_{label_letter}" 425 | traj_by_label[label] = traj_est 426 | 427 | fig = plt.figure() 428 | plot_mode = evoplot.PlotMode.xy 429 | # turn off the background grid and legend 430 | evoplot.trajectories(fig, traj_by_label, plot_mode=plot_mode) 431 | if xlim is not None: 432 | plt.xlim(xlim) 433 | if ylim is not None: 434 | plt.ylim(ylim) 435 | 436 | # hide the legend 437 | plt.gca().get_legend().remove() 438 | 439 | # hide the grid 440 | plt.grid(False) 441 | 442 | # set the background to white 443 | plt.gca().set_facecolor("white") 444 | 445 | # set the background to transparent 446 | background_transparent = True 447 | 448 | if save_path is not None: 449 | if not os.path.isdir(os.path.dirname(save_path)): 450 | os.makedirs(os.path.dirname(save_path)) 451 | 452 | # save at a higher resolution 453 | plt.savefig( 454 | save_path, transparent=background_transparent, 
from typing import List, Optional
from os.path import isfile

import numpy as np
from attrs import define, field

from py_factor_graph.variables import PoseVariable3D, PoseVariable2D
from py_factor_graph.measurements import PoseMeasurement3D, PoseMeasurement2D
from py_factor_graph.factor_graph import (
    FactorGraphData,
)
from py_factor_graph.utils.matrix_utils import (
    get_rotation_matrix_from_quat,
    get_measurement_precisions_from_info_matrix,
    get_symmetric_matrix_from_list_column_major,
)
from py_factor_graph.utils.logging_utils import logger

# g2o record tags
SE3_VARIABLE = "VERTEX_SE3:QUAT"
SE2_VARIABLE = "VERTEX_SE2"
EDGE_SE3 = "EDGE_SE3:QUAT"
EDGE_SE2 = "EDGE_SE2"


@define
class Counter:
    """A simple mutable counter."""

    count: int = field(default=0)

    def increment(self):
        self.count += 1


# module-level tally of low-precision measurements seen while parsing
C = Counter()


def convert_se3_var_line_to_pose_variable(
    line_items: List[str],
) -> PoseVariable3D:
    """converts the g2o line items for a SE3 variable to a PoseVariable3D object.

    Args:
        line_items (List[str]): List of items in a line of a g2o file.

    Returns:
        PoseVariable3D: PoseVariable3D object corresponding to the line items.
    """
    assert (
        line_items[0] == SE3_VARIABLE
    ), f"Line type is not {SE3_VARIABLE}, it is {line_items[0]}"
    pose_num_idx = 1
    translation_idx_bounds = (2, 5)
    quat_start_idx = 5
    timestep = int(line_items[pose_num_idx])
    # single-robot file: every pose belongs to robot "A"
    pose_name = f"A{timestep}"

    # get the translation
    translation_vals = [
        float(v)
        for v in line_items[translation_idx_bounds[0] : translation_idx_bounds[1]]
    ]
    assert len(translation_vals) == 3
    translation = (translation_vals[0], translation_vals[1], translation_vals[2])

    # get the rotation (quaternion -> rotation matrix)
    quat = np.array([float(v) for v in line_items[quat_start_idx:]])
    rot = get_rotation_matrix_from_quat(quat)

    return PoseVariable3D(pose_name, translation, rot, float(timestep))


def convert_se3_measurement_line_to_pose_measurement(
    line_items: List[str],
) -> Optional[PoseMeasurement3D]:
    """converts the g2o line items for a SE3 measurement to a PoseMeasurement3D
    object.

    Args:
        line_items (List[str]): List of items in a line of a g2o file.

    Returns:
        Optional[PoseMeasurement3D]: PoseMeasurement3D object corresponding to
        the line. (A None return for zero-motion measurements is currently
        disabled -- see the note below -- but callers still handle it.)
    """
    assert (
        line_items[0] == EDGE_SE3
    ), f"Line type is not {EDGE_SE3}, it is {line_items[0]}"
    assert len(line_items) == 31, f"Line has {len(line_items)} items, not 31"

    # where the indices are in the line items
    pose_num_idx = 1
    pose_num_idx_2 = 2
    translation_idx_bounds = (3, 6)
    quat_idx_bounds = (6, 10)
    cov_idx_bounds = (10, len(line_items))

    # get the pose names (single robot "A")
    from_pose_name = f"A{line_items[pose_num_idx]}"
    to_pose_name = f"A{line_items[pose_num_idx_2]}"

    # timestamp is the same as the greatest pose number
    timestamp = max(int(line_items[pose_num_idx]), int(line_items[pose_num_idx_2]))

    # get the translation
    translation = np.array(
        [
            float(v)
            for v in line_items[translation_idx_bounds[0] : translation_idx_bounds[1]]
        ]
    )

    # get the rotation (quaternion -> rotation matrix)
    quat = np.array(
        [float(v) for v in line_items[quat_idx_bounds[0] : quat_idx_bounds[1]]]
    )
    rot = get_rotation_matrix_from_quat(quat)

    # NOTE(review): returning None for zero-motion measurements is
    # deliberately disabled (the commented return below); the None-handling
    # in the parsers is currently dead code.
    no_translation = np.allclose(translation, np.zeros(3))
    no_rotation = np.allclose(rot, np.eye(3))
    if no_translation and no_rotation:
        pass
        # return None

    # parse information matrix and extract scalar precisions
    info_mat_size = 6
    info_vals = [float(v) for v in line_items[cov_idx_bounds[0] : cov_idx_bounds[1]]]
    info_mat = get_symmetric_matrix_from_list_column_major(info_vals, info_mat_size)
    trans_precision, rot_precision = get_measurement_precisions_from_info_matrix(
        info_mat, matrix_dim=info_mat_size
    )

    if trans_precision < 0.5 or rot_precision < 0.5:
        # BUGFIX: advance the module-level counter so the warning reports the
        # running number of low-precision factors (it previously always
        # printed 0 because increment() was never called).
        C.increment()
        err = f"Low precisions! Trans: {trans_precision}, Rot: {rot_precision}"
        logger.warning(err + f" low-precision factor {C.count}")

    # form pose measurement
    pose_measurement = PoseMeasurement3D(
        from_pose_name,
        to_pose_name,
        translation,
        rot,
        trans_precision,
        rot_precision,
        float(timestamp),
    )
    return pose_measurement


def is_odom_measurement(line_items: List[str]) -> bool:
    """Determine if a line of a g2o file is an odometry measurement.

    Odometry connects consecutive pose indices (i -> i+1).

    Args:
        line_items: List of items in a line of a g2o file.

    Returns:
        True if the line is an odometry measurement, False otherwise.
    """
    assert (
        line_items[0] == EDGE_SE3 or line_items[0] == EDGE_SE2
    ), f"Line type is not {EDGE_SE3} or {EDGE_SE2}, it is {line_items[0]}"
    from_idx, to_idx = int(line_items[1]), int(line_items[2])
    return from_idx == to_idx - 1


def parse_3d_g2o_file(filepath: str) -> FactorGraphData:
    """Parse a 3D g2o file into a FactorGraphData object.

    Args:
        filepath (str): path to the .g2o file.

    Returns:
        FactorGraphData: the parsed factor graph.

    Raises:
        FileNotFoundError: if the file does not exist.
        ValueError: on an unsupported line type.
    """
    logger.info(f"Parsing 3D g2o file: {filepath}")

    if not isfile(filepath):
        raise FileNotFoundError(f"File not found: {filepath}")

    fg = FactorGraphData(dimension=3)

    with open(filepath, "r") as f:
        for line in f:
            items = line.split()
            line_type = items[0]
            if line_type == SE3_VARIABLE:
                new_pose_var = convert_se3_var_line_to_pose_variable(items)
                fg.add_pose_variable(new_pose_var)
            elif line_type == EDGE_SE3:
                new_pose_measurement = convert_se3_measurement_line_to_pose_measurement(
                    items
                )
                if new_pose_measurement is None:
                    pose_0 = items[1]
                    pose_1 = items[2]
                    logger.warning(
                        f"Skipping measurement between: {pose_0} and {pose_1}"
                    )
                    continue

                if is_odom_measurement(items):
                    robot_idx = 0  # only 1 robot in g2o files
                    fg.add_odom_measurement(robot_idx, new_pose_measurement)
                else:
                    fg.add_loop_closure(new_pose_measurement)
            else:
                raise ValueError(f"Unsupported line type for 3D: {line_type}")

    logger.info(f"Finished parsing 3D g2o file: {filepath}")
    return fg


def convert_se2_var_line_to_pose_variable(
    line_items: List[str],
) -> PoseVariable2D:
    """converts the g2o line items for a SE2 variable to a PoseVariable2D object.

    Args:
        line_items (List[str]): List of items in a line of a g2o file.

    Returns:
        PoseVariable2D: PoseVariable2D object corresponding to the line items.
    """
    assert (
        line_items[0] == SE2_VARIABLE
    ), f"Line type is not {SE2_VARIABLE}, it is {line_items[0]}"
    pose_num_idx = 1
    translation_idx_bounds = (2, 4)
    theta_idx = 4
    timestamp = int(line_items[pose_num_idx])
    # single-robot file: every pose belongs to robot "A"
    pose_name = f"A{timestamp}"

    # get the translation
    translation_vals = [
        float(v)
        for v in line_items[translation_idx_bounds[0] : translation_idx_bounds[1]]
    ]
    assert len(translation_vals) == 2
    translation = (translation_vals[0], translation_vals[1])

    # get the rotation
    theta = float(line_items[theta_idx])

    return PoseVariable2D(pose_name, translation, theta, float(timestamp))


def convert_se2_measurement_line_to_pose_measurement(
    line_items: List[str],
) -> Optional[PoseMeasurement2D]:
    """converts the g2o line items for a SE2 measurement to a PoseMeasurement2D
    object.

    Args:
        line_items (List[str]): List of items in a line of a g2o file.

    Returns:
        Optional[PoseMeasurement2D]: PoseMeasurement2D object corresponding to
        the line. (A None return for zero-motion measurements is currently
        disabled, mirroring the SE3 converter.)
    """
    assert (
        line_items[0] == EDGE_SE2
    ), f"Line type is not {EDGE_SE2}, it is {line_items[0]}"
    assert len(line_items) == 12, f"Line has {len(line_items)} items, not 12"

    # where the indices are in the line items
    pose_num_idx = 1
    pose_num_idx_2 = 2
    translation_idx_bounds = (3, 5)
    theta_idx = 5
    cov_idx_bounds = (6, len(line_items))

    # get the pose names (single robot "A")
    from_pose_name = f"A{line_items[pose_num_idx]}"
    to_pose_name = f"A{line_items[pose_num_idx_2]}"

    # timestamp is the same as the greatest pose number
    timestamp = max(int(line_items[pose_num_idx]), int(line_items[pose_num_idx_2]))

    # get the translation
    translation = np.array(
        [
            float(v)
            for v in line_items[translation_idx_bounds[0] : translation_idx_bounds[1]]
        ]
    )

    # get the rotation
    theta = float(line_items[theta_idx])

    # NOTE(review): zero-motion skip is deliberately disabled (see SE3 version)
    no_translation = np.allclose(translation, np.zeros(2))
    no_rotation = np.allclose(theta, 0)
    if no_translation and no_rotation:
        pass
        # return None

    # parse information matrix and extract scalar precisions
    info_mat_size = 3
    info_vals = [float(v) for v in line_items[cov_idx_bounds[0] : cov_idx_bounds[1]]]
    info_mat = get_symmetric_matrix_from_list_column_major(info_vals, info_mat_size)
    (
        translation_precision,
        theta_precision,
    ) = get_measurement_precisions_from_info_matrix(info_mat, matrix_dim=info_mat_size)

    rpm = PoseMeasurement2D(
        from_pose_name,
        to_pose_name,
        x=translation[0],
        y=translation[1],
        theta=theta,
        translation_precision=translation_precision,
        rotation_precision=theta_precision,
        timestamp=float(timestamp),
    )

    return rpm


def parse_2d_g2o_file(filepath: str) -> FactorGraphData:
    """Parse a 2D g2o file into a FactorGraphData object.

    Args:
        filepath (str): path to the .g2o file.

    Returns:
        FactorGraphData: the parsed factor graph.

    Raises:
        FileNotFoundError: if the file does not exist.
        ValueError: on an unsupported line type.
    """
    logger.info(f"Parsing 2D g2o file: {filepath}")

    if not isfile(filepath):
        raise FileNotFoundError(f"File not found: {filepath}")

    fg = FactorGraphData(dimension=2)

    with open(filepath, "r") as f:
        for line in f:
            items = line.split()
            line_type = items[0]
            if line_type == SE2_VARIABLE:
                new_pose_var = convert_se2_var_line_to_pose_variable(items)
                fg.add_pose_variable(new_pose_var)
            elif line_type == EDGE_SE2:
                new_pose_measurement = convert_se2_measurement_line_to_pose_measurement(
                    items
                )
                if new_pose_measurement is None:
                    pose_0 = items[1]
                    pose_1 = items[2]
                    logger.warning(
                        f"Skipping measurement between: {pose_0} and {pose_1}"
                    )
                    continue

                if is_odom_measurement(items):
                    robot_idx = 0  # only 1 robot in g2o files
                    fg.add_odom_measurement(robot_idx, new_pose_measurement)
                else:
                    fg.add_loop_closure(new_pose_measurement)
            else:
                raise ValueError(f"Unsupported line type for 2D: {line_type}")

    logger.info(f"Finished parsing 2D g2o file: {filepath}")
    return fg


if __name__ == "__main__":
    # Developer scratch code: parse a local g2o dataset, convert it to a
    # sensor-network localization problem, save it as PyFG text, and plot
    # the resulting range measurements.
    from py_factor_graph.modifiers import convert_to_sensor_network_localization

    # NOTE(review): the repeated dirpath/fname assignments are a quick way of
    # switching datasets -- only the last assignment in each group wins.
    dirpath = "/home/alan/cora/build/bin/data"
    dirpath = "/home/alan/Downloads"
    fname = "sphere_bignoise_vertex3.g2o"
    fname = "torus3D.g2o"
    # fname = "grid3D.g2o"
    fname = "input_M3500_g2o.g2o"
    fname = "input_INTEL_g2o.g2o"
    # fname = "input_MITb_g2o.g2o"

    fpath = f"{dirpath}/{fname}"

    dirpath = "/home/alan/mrg-mac/data"
    fname = "ais2klinik.g2o"
    fname = "kitti_05.g2o"
    fname = "sphere2500.g2o"
    # fname = "city10000.g2o"
    # fname = "intel.g2o"
    # fname = "kitti_02.g2o"

    fpath = f"{dirpath}/{fname}"

    # datasets known to be 2D; everything else is treated as 3D
    files_2d = [
        "input_M3500_g2o.g2o",
        "input_INTEL_g2o.g2o",
        "input_MITb_g2o.g2o",
        "ais2klinik.g2o",
        "intel_edges.g2o",
        "kitti_05.g2o",
        "city10000.g2o",
        "intel.g2o",
        "kitti_02.g2o",
    ]

    if fname in files_2d:
        fg = parse_2d_g2o_file(fpath)
        # fg.animate_odometry(show_gt=True, draw_range_lines=True)
    else:
        fg = parse_3d_g2o_file(fpath)
        # fg.animate_odometry_3d(show_gt=True, draw_range_lines=True)

    new_fg = convert_to_sensor_network_localization(fg)

    from py_factor_graph.io.pyfg_text import save_to_pyfg_text

    new_fg_path = f"/home/alan/cora-plus-plus/build/bin/data/{fname}".replace(
        ".g2o", ".pyfg"
    )
    new_fg_path = "/home/alan/cora-plus-plus/build/bin/data/test.pyfg"
    save_to_pyfg_text(new_fg, new_fg_path)

    # plot all of the landmarks and their ranges
    import matplotlib.pyplot as plt

    fig = plt.figure()
    ax = fig.add_subplot(111, projection="3d")

    num_vars_skip = 10
    for idx, landmark in enumerate(new_fg.landmark_variables):
        if idx % num_vars_skip != 0:
            continue
        # NOTE(review): the unconditional continue below disables the landmark
        # scatter entirely -- presumably left over from debugging; confirm.
        continue
        position = landmark.true_position
        ax.scatter(position[0], position[1], position[2], color="r")

    name_to_var_map = new_fg.landmark_variables_dict
    num_measures_skip = 2
    for idx, range_measurement in enumerate(new_fg.range_measurements):
        if idx % num_measures_skip != 0:
            continue
        var1, var2 = range_measurement.association
        loc1 = name_to_var_map[var1].true_position
        loc2 = name_to_var_map[var2].true_position
        xs = [loc1[0], loc2[0]]
        ys = [loc1[1], loc2[1]]
        if len(loc1) == len(loc2) == 3:
            zs = [loc1[2], loc2[2]]  # type: ignore
        else:
            zs = [0.0, 0.0]
        ax.plot(xs, ys, zs, color="b")

    plt.show()
idx, range_measurement in enumerate(new_fg.range_measurements): 437 | if idx % num_measures_skip != 0: 438 | continue 439 | var1, var2 = range_measurement.association 440 | loc1 = name_to_var_map[var1].true_position 441 | loc2 = name_to_var_map[var2].true_position 442 | xs = [loc1[0], loc2[0]] 443 | ys = [loc1[1], loc2[1]] 444 | if len(loc1) == len(loc2) == 3: 445 | zs = [loc1[2], loc2[2]] # type: ignore 446 | else: 447 | zs = [0.0, 0.0] 448 | 449 | # if len(loc1) == 2 and len(loc2) == 2: 450 | # zs = [0.0, 0.0] 451 | # elif len(loc1) == 3 and len(loc2) == 3: 452 | # zs = [loc1[2], loc2[2]] 453 | # else: 454 | # raise ValueError("Landmark dimensions not 2 or 3") 455 | ax.plot(xs, ys, zs, color="b") 456 | 457 | plt.show() 458 | -------------------------------------------------------------------------------- /py_factor_graph/calibrations/odom_measurement_calibration.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple, Dict 2 | from attrs import define, field 3 | import numpy as np 4 | from numpy import linalg as la 5 | from scipy.stats import linregress 6 | import matplotlib.pyplot as plt 7 | import copy 8 | 9 | from py_factor_graph.measurements import PoseMeasurement2D 10 | from py_factor_graph.factor_graph import FactorGraphData 11 | from py_factor_graph.utils.logging_utils import logger 12 | 13 | 14 | def _validate_is_2d_pose_matrix(inst, attribute, value): 15 | if value.shape != (3, 3): 16 | raise ValueError( 17 | f"Expected {attribute} to be of shape (3, 3), got {value.shape}" 18 | ) 19 | 20 | # check that the last row is [0, 0, 1] 21 | if not np.allclose(value[2, :], np.array([0, 0, 1])): 22 | raise ValueError( 23 | f"Expected {attribute} to have last row [0, 0, 1], got {value[2, :]}" 24 | ) 25 | 26 | # check R.T @ R = I 27 | R = value[:2, :2] 28 | if not np.allclose(R.T @ R, np.eye(2)): 29 | raise ValueError(f"Expected {attribute} to have R.T @ R = I, got {R.T @ R}") 30 | 31 | 32 | def 
_get_theta_from_rot_matrix(R: np.ndarray) -> float: 33 | """ 34 | Get the angle of rotation from a 2D rotation matrix. 35 | """ 36 | assert R.shape == (2, 2) 37 | theta = np.arctan2(R[1, 0], R[0, 0]) 38 | return theta 39 | 40 | 41 | @define 42 | class RelPose2DResiduals: 43 | dx: float = field() 44 | dy: float = field() 45 | dtheta: float = field() 46 | 47 | 48 | @define 49 | class Uncalibrated2DRelPoseMeasurement: 50 | association: Tuple[str, str] = field() 51 | measured_rel_pose: np.ndarray = field(validator=_validate_is_2d_pose_matrix) 52 | true_rel_pose: np.ndarray = field(validator=_validate_is_2d_pose_matrix) 53 | timestamp: float = field() 54 | 55 | @property 56 | def measured_dx(self): 57 | return self.measured_rel_pose[0, 2] 58 | 59 | @property 60 | def measured_dy(self): 61 | return self.measured_rel_pose[1, 2] 62 | 63 | @property 64 | def measured_dtheta(self): 65 | return _get_theta_from_rot_matrix(self.measured_rel_pose[:2, :2]) 66 | 67 | @property 68 | def true_dx(self): 69 | return self.true_rel_pose[0, 2] 70 | 71 | @property 72 | def true_dy(self): 73 | return self.true_rel_pose[1, 2] 74 | 75 | @property 76 | def true_dtheta(self): 77 | return _get_theta_from_rot_matrix(self.true_rel_pose[:2, :2]) 78 | 79 | 80 | @define 81 | class Linear2DRelPoseCalibrationModel: 82 | dx_slope: float = field() 83 | dx_intercept: float = field() 84 | dy_slope: float = field() 85 | dy_intercept: float = field() 86 | dtheta_slope: float = field() 87 | dtheta_intercept: float = field() 88 | 89 | def __call__( 90 | self, x: List[Uncalibrated2DRelPoseMeasurement] 91 | ) -> List[PoseMeasurement2D]: 92 | assert isinstance(x, list) 93 | assert all([isinstance(x, Uncalibrated2DRelPoseMeasurement) for x in x]) 94 | residuals = self.get_calibrated_residuals(x) 95 | residuals_arr = np.array([[x.dx, x.dy, x.dtheta] for x in residuals]) 96 | calibrated_stddevs = np.std(residuals_arr, axis=0) 97 | assert calibrated_stddevs.shape == (3,) 98 | trans_cov = 0.5 * (calibrated_stddevs[0] 
** 2 + calibrated_stddevs[1] ** 2) 99 | trans_precision = 1 / trans_cov 100 | theta_precision = 1 / calibrated_stddevs[2] ** 2 101 | logger.info(f"Calibrated trans stddev: {np.sqrt(trans_cov):.3f}") 102 | logger.info(f"Calibrated theta stddev: {calibrated_stddevs[2]:.3f}") 103 | calibrated_vals = self.apply_linear_calibration(x) 104 | calibrated_measurements = [ 105 | PoseMeasurement2D( 106 | base_pose=x.association[0], 107 | to_pose=x.association[1], 108 | x=calibrated_val[0], 109 | y=calibrated_val[1], 110 | theta=calibrated_val[2], 111 | translation_precision=trans_precision, 112 | rotation_precision=theta_precision, 113 | timestamp=x.timestamp, 114 | ) 115 | for x, calibrated_val in zip(x, calibrated_vals) 116 | ] 117 | return calibrated_measurements 118 | 119 | def apply_linear_calibration( 120 | self, uncalibrated_measurements: List[Uncalibrated2DRelPoseMeasurement] 121 | ) -> np.ndarray: 122 | measured_vals = np.array( 123 | [ 124 | [x.measured_dx, x.measured_dy, x.measured_dtheta] 125 | for x in uncalibrated_measurements 126 | ] 127 | ) 128 | slopes = np.array([self.dx_slope, self.dy_slope, self.dtheta_slope]) 129 | intercepts = np.array( 130 | [self.dx_intercept, self.dy_intercept, self.dtheta_intercept] 131 | ) 132 | predicted_vals = measured_vals * slopes + intercepts 133 | assert predicted_vals.shape == measured_vals.shape 134 | 135 | return predicted_vals 136 | 137 | def get_calibrated_residuals( 138 | self, 139 | uncalibrated_measurements: List[Uncalibrated2DRelPoseMeasurement], 140 | ) -> List[RelPose2DResiduals]: 141 | predicted_vals = self.apply_linear_calibration(uncalibrated_measurements) 142 | true_vals = np.array( 143 | [[x.true_dx, x.true_dy, x.true_dtheta] for x in uncalibrated_measurements] 144 | ) 145 | 146 | residuals = true_vals - predicted_vals 147 | 148 | residual_triplets = [ 149 | RelPose2DResiduals( 150 | dx=residual[0], 151 | dy=residual[1], 152 | dtheta=residual[2], 153 | ) 154 | for residual in residuals 155 | ] 156 | assert 
len(residual_triplets) == len(uncalibrated_measurements) 157 | 158 | return residual_triplets 159 | 160 | 161 | def fit_linear_calibration_model( 162 | uncalibrated_measurements: List[Uncalibrated2DRelPoseMeasurement], 163 | ) -> Linear2DRelPoseCalibrationModel: 164 | assert all( 165 | [ 166 | isinstance(x, Uncalibrated2DRelPoseMeasurement) 167 | for x in uncalibrated_measurements 168 | ] 169 | ) 170 | measured_dx = np.array([x.measured_dx for x in uncalibrated_measurements]) 171 | measured_dy = np.array([x.measured_dy for x in uncalibrated_measurements]) 172 | measured_dtheta = np.array([x.measured_dtheta for x in uncalibrated_measurements]) 173 | 174 | true_dx = np.array([x.true_dx for x in uncalibrated_measurements]) 175 | true_dy = np.array([x.true_dy for x in uncalibrated_measurements]) 176 | true_dtheta = np.array([x.true_dtheta for x in uncalibrated_measurements]) 177 | 178 | dx_slope, dx_intercept, _, _, _ = linregress(measured_dx, true_dx) 179 | dy_slope, dy_intercept, _, _, _ = linregress(measured_dy, true_dy) 180 | dtheta_slope, dtheta_intercept, _, _, _ = linregress(measured_dtheta, true_dtheta) 181 | 182 | calibration_model = Linear2DRelPoseCalibrationModel( 183 | dx_slope=dx_slope, 184 | dx_intercept=dx_intercept, 185 | dy_slope=dy_slope, 186 | dy_intercept=dy_intercept, 187 | dtheta_slope=dtheta_slope, 188 | dtheta_intercept=dtheta_intercept, 189 | ) 190 | 191 | return calibration_model 192 | 193 | 194 | def inspect_for_inliers_and_outliers( 195 | uncalibrated_measurements: List[Uncalibrated2DRelPoseMeasurement], 196 | inlier_stddev_threshold: float = 3.0, 197 | ) -> List[Uncalibrated2DRelPoseMeasurement]: 198 | """ 199 | We will fit a linear model to the range measurements and remove outliers. 
W 200 | """ 201 | assert all( 202 | [ 203 | isinstance(x, Uncalibrated2DRelPoseMeasurement) 204 | for x in uncalibrated_measurements 205 | ] 206 | ) 207 | assert len(uncalibrated_measurements) > 0 208 | assert inlier_stddev_threshold > 0.0 209 | 210 | def _plot_inliers_and_outliers( 211 | measured_vals: np.ndarray, 212 | true_vals: np.ndarray, 213 | outlier_mask: np.ndarray, 214 | slope: float, 215 | intercept: float, 216 | title: str, 217 | ): 218 | # inliers = [x for idx, x in enumerate(measurements) if idx not in outlier_mask] 219 | # outliers = [x for idx, x in enumerate(measurements) if idx in outlier_mask] 220 | inlier_measured_vals = measured_vals[~outlier_mask] 221 | inlier_true_vals = true_vals[~outlier_mask] 222 | outlier_measured_vals = measured_vals[outlier_mask] 223 | outlier_true_vals = true_vals[outlier_mask] 224 | 225 | plt.scatter( 226 | inlier_measured_vals, inlier_true_vals, color="blue", label="inliers" 227 | ) 228 | plt.scatter( 229 | outlier_measured_vals, outlier_true_vals, color="red", label="outliers" 230 | ) 231 | plt.title(title) 232 | plt.legend() 233 | plt.xlabel("Measured distance (m)") 234 | plt.ylabel("True distance (m)") 235 | 236 | # draw the linear model up to the largest measured distance 237 | x = np.linspace(0, np.max(inlier_measured_vals), 100) 238 | y = slope * x + intercept 239 | plt.plot(x, y, color="black", label="linear model") 240 | 241 | # make sure axis is square 242 | plt.gca().set_aspect("equal", adjustable="box") 243 | 244 | plt.show(block=True) 245 | 246 | inlier_measurements = copy.deepcopy(uncalibrated_measurements) 247 | # fit a linear model to the range measurements 248 | linear_calibration = fit_linear_calibration_model(inlier_measurements) 249 | 250 | measured_calibrated_vals = linear_calibration(inlier_measurements) 251 | measured_vals = np.array([[x.x, x.y, x.theta] for x in measured_calibrated_vals]) 252 | true_vals = np.array( 253 | [[x.true_dx, x.true_dy, x.true_dtheta] for x in inlier_measurements] 254 
| ) 255 | 256 | # compute the residuals and use them to find outliers 257 | residuals = true_vals - measured_vals 258 | 259 | # compute the standard deviation of each residual independently 260 | res_stddev = np.std(residuals, axis=0) 261 | assert res_stddev.shape == (3,) 262 | 263 | # titles for the plots 264 | titles = [ 265 | f"dx: {res_stddev[0]:.3f}", 266 | f"dy: {res_stddev[1]:.3f}", 267 | f"dtheta: {res_stddev[2]:.3f}", 268 | ] 269 | 270 | # calibration params 271 | calibration_slopes = [ 272 | linear_calibration.dx_slope, 273 | linear_calibration.dy_slope, 274 | linear_calibration.dtheta_slope, 275 | ] 276 | 277 | calibration_intercepts = [ 278 | linear_calibration.dx_intercept, 279 | linear_calibration.dy_intercept, 280 | linear_calibration.dtheta_intercept, 281 | ] 282 | 283 | for measured, true, residual, res_stddev, cl_slope, cl_intercept, title in zip( 284 | measured_vals.T, 285 | true_vals.T, 286 | residuals.T, 287 | res_stddev, 288 | calibration_slopes, 289 | calibration_intercepts, 290 | titles, 291 | ): 292 | # find the outliers 293 | outlier_mask = np.where( 294 | np.abs(residual) > inlier_stddev_threshold * res_stddev 295 | )[0] 296 | 297 | # visualize the inliers and outliers 298 | _plot_inliers_and_outliers( 299 | measured_vals=measured, 300 | true_vals=true, 301 | outlier_mask=outlier_mask, 302 | slope=cl_slope, 303 | intercept=cl_intercept, 304 | title=title, 305 | ) 306 | 307 | print( 308 | f"Found {len(outlier_mask)} outliers out of {len(inlier_measurements)} measurements" 309 | ) 310 | 311 | # if everything is an outlier, then some nonsense is going on 312 | if len(outlier_mask) == len(inlier_measurements): 313 | logger.warning( 314 | f"Everything is an outlier. This is probably a bug. Returning empty list." 
        )
        return []

    logger.warning("Not currently rejecting outliers -- just useful for inspection")
    return inlier_measurements


def get_linearly_calibrated_measurements(
    uncalibrated_measurements: List[Uncalibrated2DRelPoseMeasurement],
) -> List[PoseMeasurement2D]:
    """Apply a fitted linear calibration model to the uncalibrated 2D
    relative-pose measurements.

    NOTE: no outlier rejection is performed here -- a linear model is fit to
    all of the given measurements and then applied to them.

    Args:
        uncalibrated_measurements: the uncalibrated 2D relative-pose
            measurements (each must have its true relative pose set).

    Returns:
        The calibrated 2D pose measurements.
    """
    linear_calibration = fit_linear_calibration_model(uncalibrated_measurements)
    calibrated_measurements = linear_calibration(uncalibrated_measurements)
    return calibrated_measurements


def calibrate_odom_measures(
    pyfg: FactorGraphData,
) -> FactorGraphData:
    """Group the odometry measurements of a factor graph and inspect them for
    inliers/outliers.

    NOTE(review): despite the name, this currently performs no calibration and
    no outlier rejection (see the warnings logged below) -- it only visualizes
    and inspects the measurements, and returns the factor graph unmodified.

    Args:
        pyfg: the factor graph whose odometry measurements are inspected.

    Returns:
        The (currently unmodified) factor graph.
    """
    uncalibrated_measurements: List[PoseMeasurement2D] = []
    for odom_chain in pyfg.odom_measurements:
        uncalibrated_measurements.extend(odom_chain)
    # uncalibrated_measurements += pyfg.loop_closure_measurements
    true_poses = pyfg.pose_variables_dict

    # group the measurements by association
    # e.g.,
    # - (A1, L1) and (A15, L1) will be grouped together as (A, L1)
    # - (B23, L10) and (B138, L10) will be grouped together as (B, L10)
    # - (A5, L1) and (B23, L10) will not be grouped together

    # get valid variables: landmarks keep their full name, poses are grouped
    # by their leading robot character (e.g. "A1" -> "A")
    valid_variable_groups = set()
    for variable_name in true_poses.keys():
        if "L" in variable_name:
            valid_variable_groups.add(variable_name)
        else:
            valid_variable_groups.add(variable_name[0])

    # get valid associations as the cross product of valid variables
    valid_associations = set()
    for var1 in valid_variable_groups:
        for var2 in valid_variable_groups:
            # pair should be alphabetically sorted, except "L" should always be second
            if "L" in var1:
                valid_associations.add((var2, var1))
            elif "L" in var2:
valid_associations.add((var1, var2)) 368 | elif var1 < var2: 369 | valid_associations.add((var1, var2)) 370 | else: 371 | valid_associations.add((var2, var1)) 372 | 373 | def get_association_grouping(association: Tuple[str, str]) -> Tuple[str, str]: 374 | # assert at most one "L" in the association and must be second 375 | a1, a2 = association 376 | if "L" in a1 and "L" in a2: 377 | raise ValueError(f"Invalid association: {association}") 378 | elif "L" in a1: 379 | raise ValueError(f"Invalid association: {association}") 380 | 381 | if "L" in association[1]: 382 | return a1[0], a2 383 | 384 | a1_group = a1[0] 385 | a2_group = a2[0] 386 | if a1_group < a2_group: 387 | return a1_group, a2_group 388 | else: 389 | return a2_group, a1_group 390 | 391 | # group the range measurements by association 392 | association_to_measurements: Dict[ 393 | Tuple[str, str], List[Uncalibrated2DRelPoseMeasurement] 394 | ] = {pair: [] for pair in valid_associations} 395 | for measurement in uncalibrated_measurements: 396 | assert isinstance( 397 | measurement, PoseMeasurement2D 398 | ), f"Expected PoseMeasurement2D, got {type(measurement)}" 399 | association = (measurement.base_pose, measurement.to_pose) 400 | association_group = get_association_grouping(association) 401 | if association_group not in valid_associations: 402 | raise ValueError(f"Invalid association: {association}") 403 | 404 | v1_true_pose = true_poses[measurement.base_pose].transformation_matrix 405 | v2_true_pose = true_poses[measurement.to_pose].transformation_matrix 406 | true_rel_pose = la.inv(v1_true_pose) @ v2_true_pose 407 | 408 | assert measurement.timestamp is not None 409 | uncalibrated_measurement = Uncalibrated2DRelPoseMeasurement( 410 | association=association, 411 | measured_rel_pose=measurement.transformation_matrix, 412 | true_rel_pose=true_rel_pose, 413 | timestamp=measurement.timestamp, 414 | ) 415 | 416 | association_to_measurements[association_group].append(uncalibrated_measurement) 417 | 418 | # 
remove any associations that don't have any measurements 419 | for association, measurements in list(association_to_measurements.items()): 420 | if len(measurements) == 0: 421 | del association_to_measurements[association] 422 | 423 | # inspect the measurements for outliers 424 | for association, measurements in association_to_measurements.items(): 425 | inspect_for_inliers_and_outliers(measurements) 426 | 427 | # calibrate the measurements 428 | # calibrated_measurements = get_linearly_calibrated_measurements(inlier_measurements) 429 | logger.warning("Odometry: not currently rejecting outliers or calibrating") 430 | 431 | # update the factor graph data 432 | return pyfg 433 | -------------------------------------------------------------------------------- /py_factor_graph/measurements.py: -------------------------------------------------------------------------------- 1 | import attr 2 | from typing import Optional, Tuple, Union 3 | import numpy as np 4 | from py_factor_graph.utils.attrib_utils import ( 5 | positive_float_validator, 6 | make_variable_name_validator, 7 | make_rot_matrix_validator, 8 | optional_float_validator, 9 | ) 10 | from py_factor_graph.utils.matrix_utils import ( 11 | get_covariance_matrix_from_measurement_precisions, 12 | get_quat_from_rotation_matrix, 13 | ) 14 | import scipy.spatial as spatial 15 | 16 | 17 | @attr.s(frozen=False) 18 | class PoseMeasurement2D: 19 | """ 20 | An pose measurement 21 | 22 | Args: 23 | base_pose (str): the pose which the measurement is in the reference frame of 24 | to_pose (str): the name of the pose the measurement is to 25 | x (float): the measured change in x coordinate 26 | y (float): the measured change in y coordinate 27 | theta (float): the measured change in theta 28 | covariance (np.ndarray): a 3x3 covariance matrix from the measurement model 29 | timestamp (float): seconds since epoch 30 | """ 31 | 32 | base_pose: str = attr.ib(validator=make_variable_name_validator("pose")) 33 | to_pose: str = 
attr.ib(validator=make_variable_name_validator("pose")) 34 | x: float = attr.ib(validator=attr.validators.instance_of(float)) 35 | y: float = attr.ib(validator=attr.validators.instance_of(float)) 36 | theta: float = attr.ib(validator=attr.validators.instance_of(float)) 37 | translation_precision: float = attr.ib(validator=positive_float_validator) 38 | rotation_precision: float = attr.ib(validator=positive_float_validator) 39 | timestamp: Optional[float] = attr.ib( 40 | default=None, validator=optional_float_validator 41 | ) 42 | 43 | @property 44 | def rotation_matrix(self) -> np.ndarray: 45 | """ 46 | Get the rotation matrix for the measurement 47 | """ 48 | return np.array( 49 | [ 50 | [np.cos(self.theta), -np.sin(self.theta)], 51 | [np.sin(self.theta), np.cos(self.theta)], 52 | ] 53 | ) 54 | 55 | @property 56 | def transformation_matrix(self) -> np.ndarray: 57 | """ 58 | Get the transformation matrix 59 | """ 60 | return np.array( 61 | [ 62 | [np.cos(self.theta), -np.sin(self.theta), self.x], 63 | [np.sin(self.theta), np.cos(self.theta), self.y], 64 | [0, 0, 1], 65 | ] 66 | ) 67 | 68 | @property 69 | def translation_vector(self) -> np.ndarray: 70 | """ 71 | Get the translation vector for the measurement 72 | """ 73 | return np.array([self.x, self.y]) 74 | 75 | @property 76 | def covariance(self) -> np.ndarray: 77 | """ 78 | Get the covariance matrix 79 | """ 80 | return get_covariance_matrix_from_measurement_precisions( 81 | self.translation_precision, self.rotation_precision, mat_dim=3 82 | ) 83 | 84 | 85 | @attr.s(frozen=False) 86 | class PoseToLandmarkMeasurement2D: 87 | pose_name: str = attr.ib(validator=make_variable_name_validator("pose")) 88 | landmark_name: str = attr.ib(validator=make_variable_name_validator("landmark")) 89 | x: float = attr.ib(validator=attr.validators.instance_of(float)) 90 | y: float = attr.ib(validator=attr.validators.instance_of(float)) 91 | translation_precision: float = attr.ib(validator=positive_float_validator) 92 | timestamp: 
Optional[float] = attr.ib( 93 | default=None, validator=optional_float_validator 94 | ) 95 | 96 | @property 97 | def translation_vector(self) -> np.ndarray: 98 | """ 99 | Get the translation vector for the measurement 100 | """ 101 | return np.array([self.x, self.y]) 102 | 103 | @property 104 | def covariance(self) -> np.ndarray: 105 | """ 106 | Get the covariance matrix 107 | """ 108 | return np.diag([1 / self.translation_precision] * 2) 109 | 110 | 111 | @attr.s(frozen=False) 112 | class PoseToLandmarkMeasurement3D: 113 | pose_name: str = attr.ib(validator=make_variable_name_validator("pose")) 114 | landmark_name: str = attr.ib(validator=make_variable_name_validator("landmark")) 115 | x: float = attr.ib(validator=attr.validators.instance_of(float)) 116 | y: float = attr.ib(validator=attr.validators.instance_of(float)) 117 | z: float = attr.ib(validator=attr.validators.instance_of(float)) 118 | translation_precision: float = attr.ib(validator=positive_float_validator) 119 | timestamp: Optional[float] = attr.ib( 120 | default=None, validator=optional_float_validator 121 | ) 122 | 123 | @property 124 | def translation_vector(self) -> np.ndarray: 125 | """ 126 | Get the translation vector for the measurement 127 | """ 128 | return np.array([self.x, self.y, self.z]) 129 | 130 | @property 131 | def covariance(self) -> np.ndarray: 132 | """ 133 | Get the covariance matrix 134 | """ 135 | return np.diag([1 / self.translation_precision] * 3) 136 | 137 | 138 | @attr.s(frozen=False) 139 | class PoseMeasurement3D: 140 | """ 141 | An pose measurement 142 | 143 | Args: 144 | base_pose (str): the pose which the measurement is in the reference frame of 145 | to_pose (str): the name of the pose the measurement is to 146 | translation (np.ndarray): the measured change in x, y, z coordinates 147 | rotation (np.ndarray): the measured change in rotation 148 | translation_precision (float): the weight of the translation measurement 149 | rotation_precision (float): the weight of the 
rotation measurement 150 | timestamp (float): seconds since epoch 151 | """ 152 | 153 | base_pose: str = attr.ib(validator=make_variable_name_validator("pose")) 154 | to_pose: str = attr.ib(validator=make_variable_name_validator("pose")) 155 | translation: np.ndarray = attr.ib(validator=attr.validators.instance_of(np.ndarray)) 156 | rotation: np.ndarray = attr.ib(validator=make_rot_matrix_validator(3)) 157 | translation_precision: float = attr.ib(validator=positive_float_validator) 158 | rotation_precision: float = attr.ib(validator=positive_float_validator) 159 | timestamp: Optional[float] = attr.ib( 160 | default=None, validator=optional_float_validator 161 | ) 162 | 163 | def __attrs_post_init__(self): 164 | if self.base_pose == self.to_pose: 165 | raise ValueError( 166 | f"base_pose and to_pose cannot be the same: base: {self.base_pose}, to: {self.to_pose}" 167 | ) 168 | 169 | @property 170 | def rotation_matrix(self) -> np.ndarray: 171 | """ 172 | Get the rotation matrix for the measurement 173 | 174 | Returns: 175 | np.ndarray: the 3x3 rotation matrix 176 | """ 177 | return self.rotation 178 | 179 | @property 180 | def transformation_matrix(self) -> np.ndarray: 181 | """ 182 | Get the transformation matrix 183 | 184 | Returns: 185 | np.ndarray: the 4x4 transformation matrix 186 | """ 187 | T = np.eye(4) 188 | T[:3, :3] = self.rotation 189 | T[:3, 3] = self.translation 190 | return T 191 | 192 | @property 193 | def translation_vector(self) -> np.ndarray: 194 | """ 195 | Get the translation vector for the measurement 196 | 197 | Returns: 198 | np.ndarray: the 3x1 translation vector 199 | """ 200 | return self.translation 201 | 202 | @property 203 | def x(self) -> float: 204 | """ 205 | Get the x translation 206 | 207 | Returns: 208 | float: the x translation 209 | """ 210 | return self.translation[0] 211 | 212 | @property 213 | def y(self) -> float: 214 | """ 215 | Get the y translation 216 | 217 | Returns: 218 | float: the y translation 219 | """ 220 | return 
self.translation[1] 221 | 222 | @property 223 | def z(self) -> float: 224 | """ 225 | Get the z translation 226 | 227 | Returns: 228 | float: the z translation 229 | """ 230 | return self.translation[2] 231 | 232 | @property 233 | def quat(self) -> np.ndarray: 234 | """ 235 | Get the quaternion in the form [x, y, z, w] 236 | 237 | Returns: 238 | np.ndarray: the 4x1 quaternion 239 | """ 240 | return get_quat_from_rotation_matrix(self.rotation) 241 | 242 | @property 243 | def yaw(self) -> float: 244 | """ 245 | Get the yaw angle 246 | 247 | Returns: 248 | float: the yaw angle 249 | """ 250 | rot = spatial.transform.Rotation.from_matrix(self.rotation) 251 | return rot.as_euler("zyx")[0] 252 | 253 | @property 254 | def covariance(self): 255 | """ 256 | Get the 6x6 covariance matrix. Right now uses isotropic covariance 257 | for the translation and rotation respectively 258 | 259 | Returns: 260 | np.ndarray: the 6x6 covariance matrix 261 | """ 262 | return get_covariance_matrix_from_measurement_precisions( 263 | self.translation_precision, self.rotation_precision, mat_dim=6 264 | ) 265 | 266 | 267 | @attr.s(frozen=True) 268 | class AmbiguousPoseMeasurement2D: 269 | """ 270 | An ambiguous odom measurement 271 | 272 | base_pose (str): the name of the base pose which the measurement is in the 273 | reference frame of 274 | measured_to_pose (str): the name of the pose the measurement thinks it is to 275 | true_to_pose (str): the name of the pose the measurement is to 276 | x (float): the change in x 277 | y (float): the change in y 278 | theta (float): the change in theta 279 | covariance (np.ndarray): a 3x3 covariance matrix 280 | timestamp (float): seconds since epoch 281 | """ 282 | 283 | base_pose: str = attr.ib(validator=make_variable_name_validator("pose")) 284 | measured_to_pose: str = attr.ib(validator=make_variable_name_validator("pose")) 285 | true_to_pose: str = attr.ib(validator=make_variable_name_validator("pose")) 286 | x: float = 
attr.ib(validator=attr.validators.instance_of(float)) 287 | y: float = attr.ib(validator=attr.validators.instance_of(float)) 288 | theta: float = attr.ib(validator=attr.validators.instance_of(float)) 289 | translation_precision: float = attr.ib(validator=positive_float_validator) 290 | rotation_precision: float = attr.ib(validator=positive_float_validator) 291 | timestamp: Optional[float] = attr.ib( 292 | default=None, validator=optional_float_validator 293 | ) 294 | 295 | @property 296 | def rotation_matrix(self): 297 | """ 298 | Get the rotation matrix for the measurement 299 | """ 300 | return np.array( 301 | [ 302 | [np.cos(self.theta), -np.sin(self.theta)], 303 | [np.sin(self.theta), np.cos(self.theta)], 304 | ] 305 | ) 306 | 307 | @property 308 | def transformation_matrix(self): 309 | """ 310 | Get the transformation matrix 311 | """ 312 | return np.array( 313 | [ 314 | [np.cos(self.theta), -np.sin(self.theta), self.x], 315 | [np.sin(self.theta), np.cos(self.theta), self.y], 316 | [0, 0, 1], 317 | ] 318 | ) 319 | 320 | @property 321 | def translation_vector(self): 322 | """ 323 | Get the translation vector for the measurement 324 | """ 325 | return np.array([self.x, self.y]) 326 | 327 | @property 328 | def covariance(self): 329 | """ 330 | Get the covariance matrix 331 | """ 332 | return get_covariance_matrix_from_measurement_precisions( 333 | self.translation_precision, self.rotation_precision, mat_dim=3 334 | ) 335 | 336 | 337 | @attr.s(frozen=False) 338 | class FGRangeMeasurement: 339 | """A range measurement 340 | 341 | Arguments: 342 | association (Tuple[str, str]): the data associations of the measurement. 
        dist (float): The measured range
        stddev (float): The standard deviation
        timestamp (float): seconds since epoch
    """

    association: Tuple[str, str] = attr.ib()
    dist: float = attr.ib(validator=positive_float_validator)
    stddev: float = attr.ib(validator=positive_float_validator)
    timestamp: Optional[float] = attr.ib(default=None)

    @association.validator
    def check_association(self, attribute, value: Tuple[str, str]):
        """Validates the association attribute

        Args:
            attribute ([type]): the attrs attribute being validated
            value (Tuple[str, str]): the association attribute

        Raises:
            ValueError: is not a 2-tuple
            ValueError: the associations are identical
            ValueError: the associations are not valid pose or landmark keys
        """
        # NOTE(review): this assert runs before the length check below and is
        # stripped under `python -O`; raising ValueError would be more robust.
        assert all(isinstance(x, str) for x in value)
        if len(value) != 2:
            raise ValueError(
                "Range measurements must have exactly two variables associated with."
370 | ) 371 | if value[0] == value[1]: 372 | raise ValueError(f"Range measurements must have unique variables: {value}") 373 | 374 | association_1_is_uppercase_letter = ( 375 | value[0][0].isalpha() and value[0][0].isupper() 376 | ) 377 | association_1_ends_in_number = value[0][1:].isnumeric() 378 | if (not association_1_is_uppercase_letter) or ( 379 | not association_1_ends_in_number 380 | ): 381 | raise ValueError(f"First association is not a valid variable: {value[0]}") 382 | 383 | association_2_is_uppercase_letter = ( 384 | value[1][0].isalpha() and value[1][0].isupper() 385 | ) 386 | association_2_ends_in_number = value[1][1:].isnumeric() 387 | if (not association_2_is_uppercase_letter) or ( 388 | not association_2_ends_in_number 389 | ): 390 | raise ValueError(f"Second association is not a valid variable: {value[1]}") 391 | 392 | @property 393 | def weight(self) -> float: 394 | """ 395 | Get the weight of the measurement 396 | """ 397 | return 1 / (self.stddev**2) 398 | 399 | @property 400 | def first_key(self) -> str: 401 | """ 402 | Get the first key from the association 403 | """ 404 | return self.association[0] 405 | 406 | @property 407 | def second_key(self) -> str: 408 | """ 409 | Get the second key from the association 410 | """ 411 | return self.association[1] 412 | 413 | @property 414 | def variance(self) -> float: 415 | """ 416 | Get the variance of the measurement 417 | """ 418 | return self.stddev**2 419 | 420 | @property 421 | def precision(self) -> float: 422 | """ 423 | Get the precision of the measurement 424 | """ 425 | return 1 / self.variance 426 | 427 | 428 | @attr.s(frozen=True) 429 | class AmbiguousFGRangeMeasurement: 430 | """A range measurement 431 | 432 | Arguments: 433 | var1 (str): one variable the measurement is associated with 434 | var2 (str): the other variable the measurement is associated with 435 | dist (float): The measured range 436 | stddev (float): The standard deviation 437 | timestamp (float): seconds since epoch 438 | 
""" 439 | 440 | true_association: Tuple[str, str] = attr.ib() 441 | measured_association: Tuple[str, str] = attr.ib() 442 | dist: float = attr.ib() 443 | stddev: float = attr.ib() 444 | timestamp: Optional[float] = attr.ib(default=None) 445 | 446 | @true_association.validator 447 | def check_true_association(self, attribute, value: Tuple[str, str]): 448 | """Validates the true_association attribute 449 | 450 | Args: 451 | attribute ([type]): [description] 452 | value (Tuple[str, str]): the true_association attribute 453 | 454 | Raises: 455 | ValueError: is not a 2-tuple 456 | ValueError: the associations are identical 457 | """ 458 | if len(value) != 2: 459 | raise ValueError( 460 | "Range measurements must have exactly two variables associated with." 461 | ) 462 | if len(value) != len(set(value)): 463 | raise ValueError("Range measurements must have unique variables.") 464 | 465 | @measured_association.validator 466 | def check_measured_association(self, attribute, value): 467 | if len(value) != 2: 468 | raise ValueError( 469 | "Range measurements must have exactly two variables associated with." 
470 | ) 471 | if len(value) != len(set(value)): 472 | raise ValueError("Range measurements must have unique variables.") 473 | 474 | @property 475 | def weight(self): 476 | """ 477 | Get the weight of the measurement 478 | """ 479 | return 1 / (self.stddev**2) 480 | 481 | 482 | POSE_MEASUREMENT_TYPES = Union[PoseMeasurement2D, PoseMeasurement3D] 483 | POSE_LANDMARK_MEASUREMENT_TYPES = Union[ 484 | PoseToLandmarkMeasurement2D, PoseToLandmarkMeasurement3D 485 | ] 486 | -------------------------------------------------------------------------------- /py_factor_graph/calibrations/range_measurement_calibration.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple, Optional, Union, overload, Dict 2 | from attrs import define, field 3 | import numpy as np 4 | from scipy.stats import linregress # type: ignore 5 | from scipy.signal import savgol_filter # type: ignore 6 | from sklearn import linear_model # type: ignore 7 | import matplotlib.pyplot as plt 8 | 9 | from py_factor_graph.measurements import FGRangeMeasurement 10 | from py_factor_graph.factor_graph import FactorGraphData 11 | from py_factor_graph.utils.logging_utils import logger 12 | from py_factor_graph.variables import dist_between_variables 13 | 14 | 15 | @define 16 | class UncalibratedRangeMeasurement: 17 | association: Tuple[str, str] = field() 18 | dist: float = field() 19 | timestamp: float = field() 20 | true_dist: Optional[float] = field(default=None) 21 | 22 | def set_true_dist(self, true_dist: float): 23 | self.true_dist = true_dist 24 | 25 | 26 | @define 27 | class LinearCalibrationModel: 28 | slope: float = field() 29 | intercept: float = field() 30 | 31 | @overload 32 | def __call__(self, x: float) -> float: 33 | ... 34 | 35 | @overload 36 | def __call__( 37 | self, x: List[UncalibratedRangeMeasurement] 38 | ) -> List[FGRangeMeasurement]: 39 | ... 
40 | 41 | @overload 42 | def __call__( 43 | self, x: np.ndarray[np.dtype[np.float64]] # type: ignore 44 | ) -> np.ndarray[np.dtype[np.float64]]: # type: ignore 45 | ... 46 | 47 | def __call__( 48 | self, x: Union[float, np.ndarray, List[UncalibratedRangeMeasurement]] 49 | ) -> Union[float, np.ndarray, List[FGRangeMeasurement]]: 50 | if isinstance(x, float): 51 | return self.slope * x + self.intercept 52 | elif isinstance(x, np.ndarray): 53 | return self.slope * x + self.intercept 54 | elif isinstance(x, list): 55 | assert all([isinstance(x, UncalibratedRangeMeasurement) for x in x]) 56 | residuals = self.get_calibrated_residuals(x) 57 | calibrated_stddev = np.std(residuals) 58 | logger.info(f"Calibrated stddev: {calibrated_stddev}") 59 | calibrated_dists = self(np.array([x.dist for x in x])) 60 | calibrated_measurements = [ 61 | FGRangeMeasurement( 62 | x.association, 63 | dist=calibrated_dist, 64 | stddev=calibrated_stddev, 65 | timestamp=x.timestamp, 66 | ) 67 | for x, calibrated_dist in zip(x, calibrated_dists) 68 | ] 69 | return calibrated_measurements 70 | else: 71 | raise NotImplementedError(f"Unsupported type: {type(x)}") 72 | 73 | def get_calibrated_residuals( 74 | self, 75 | uncalibrated_measurements: List[UncalibratedRangeMeasurement], 76 | ) -> np.ndarray: 77 | """ 78 | We will fit a linear model to the range measurements and remove outliers. 
79 | """ 80 | # make sure that all true distances are set 81 | assert all([x.true_dist is not None for x in uncalibrated_measurements]) 82 | 83 | measured_distances = np.array([x.dist for x in uncalibrated_measurements]) 84 | true_distances = np.array([x.true_dist for x in uncalibrated_measurements]) 85 | predicted_true_distances = self(measured_distances) 86 | residuals = true_distances - predicted_true_distances 87 | return residuals 88 | 89 | 90 | def fit_linear_calibration_model( 91 | uncalibrated_measurements: List[UncalibratedRangeMeasurement], 92 | ) -> LinearCalibrationModel: 93 | """ 94 | We will fit a linear model to the range measurements and remove outliers. 95 | """ 96 | measured_dists = np.array([x.dist for x in uncalibrated_measurements]) 97 | true_dists = np.array([x.true_dist for x in uncalibrated_measurements]) 98 | slope, intercept, r_value, p_value, std_err = linregress(measured_dists, true_dists) 99 | return LinearCalibrationModel(slope=slope, intercept=intercept) 100 | 101 | 102 | def get_inlier_set_of_range_measurements( 103 | uncalibrated_measurements: List[UncalibratedRangeMeasurement], 104 | show_outlier_rejection: bool = False, 105 | ) -> List[UncalibratedRangeMeasurement]: 106 | """ 107 | We will fit a linear model to the range measurements and remove outliers. W 108 | """ 109 | if len(uncalibrated_measurements) == 0: 110 | return [] 111 | 112 | if len(uncalibrated_measurements) < 5: 113 | logger.warning( 114 | f"Only {len(uncalibrated_measurements)} range measurements. 
Discarding" 115 | ) 116 | return [] 117 | 118 | association = uncalibrated_measurements[0].association 119 | if "L" in association[1]: 120 | data_set_name = f"Robot {association[0][0]} - Landmark {association[1]}" 121 | else: 122 | first_char = association[0][0] 123 | second_char = association[1][0] 124 | assert first_char != second_char, f"Invalid association: {association}" 125 | assert ( 126 | first_char != "L" and second_char != "L" 127 | ), f"Invalid association: {association}" 128 | if first_char > second_char: 129 | first_char, second_char = second_char, first_char 130 | data_set_name = f"Range calibration: Robot {first_char} - Robot {second_char}" 131 | 132 | if len(uncalibrated_measurements) == 0: 133 | return [] 134 | 135 | def _plot_inliers_and_outliers( 136 | inliers: List[UncalibratedRangeMeasurement], 137 | outliers: List[UncalibratedRangeMeasurement], 138 | ransac: linear_model.RANSACRegressor, 139 | ): 140 | inlier_measured_dists = np.array([x.dist for x in inliers]) 141 | inlier_true_dists = np.array([x.true_dist for x in inliers]) 142 | outlier_measured_dists = np.array([x.dist for x in outliers]) 143 | outlier_true_dists = np.array([x.true_dist for x in outliers]) 144 | 145 | inlier_calibrated_dists = ( 146 | ransac.predict(inlier_measured_dists.reshape(-1, 1)) 147 | if len(inlier_measured_dists) > 0 148 | else np.array([[]]).reshape(-1, 1) 149 | ) 150 | outlier_calibrated_dists = ( 151 | ransac.predict(outlier_measured_dists.reshape(-1, 1)) 152 | if len(outlier_measured_dists) > 0 153 | else np.array([[]]).reshape(-1, 1) 154 | ) 155 | 156 | # two subplots, on left and one on the right 157 | fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6)) 158 | 159 | # title over both subplots 160 | slope = ransac.estimator_.coef_[0][0] 161 | intercept = ransac.estimator_.intercept_[0] 162 | fig.suptitle(f"{data_set_name}: slope={slope:.2f}, intercept={intercept:.2f}") 163 | 164 | # on the left plot show the measured vs true distances 165 | ax1.scatter( 166 
| inlier_measured_dists, 167 | inlier_true_dists, 168 | color="blue", 169 | label="inliers", 170 | ) 171 | ax1.scatter( 172 | outlier_measured_dists, 173 | outlier_true_dists, 174 | color="red", 175 | label="outliers", 176 | ) 177 | all_measured_dists = np.concatenate( 178 | [inlier_measured_dists, outlier_measured_dists] 179 | ) 180 | xmin = min(np.min(all_measured_dists), 0.0) 181 | xmax = np.max(all_measured_dists) 182 | x = np.linspace(xmin, xmax, 5) 183 | y = ransac.predict(x.reshape(-1, 1)) 184 | ax1.plot(x, y, color="black", label="linear model") 185 | ax1.set_xlabel("Measured distance (m)") 186 | ax1.set_ylabel("True distance (m)") 187 | ax1.legend() 188 | 189 | # on the right plot show the calibrated vs true distances 190 | ax2.scatter( 191 | inlier_calibrated_dists, 192 | inlier_true_dists, 193 | color="blue", 194 | label="inliers", 195 | ) 196 | ax2.scatter( 197 | outlier_calibrated_dists, 198 | outlier_true_dists, 199 | color="red", 200 | label="outliers", 201 | ) 202 | ax2.set_xlabel("Calibrated distance (m)") 203 | ax2.set_ylabel("True distance (m)") 204 | 205 | # draw the linear model up to the largest measured distance 206 | all_calibrated_dists = np.concatenate( 207 | [inlier_calibrated_dists, outlier_calibrated_dists] 208 | ) 209 | xmin = min(np.min(all_calibrated_dists), 0.0) 210 | xmax = np.max(all_calibrated_dists) 211 | x = np.linspace(xmin, xmax, 5) 212 | y = ransac.predict(x.reshape(-1, 1)) 213 | # plt.plot(x, y, color="black", label="linear model") 214 | 215 | # draw a line along the y-axis to indicate the left-half plane 216 | all_true_dists = np.concatenate([inlier_true_dists, outlier_true_dists]) 217 | ymin = 0.0 218 | ymax = np.max(all_true_dists) 219 | ax2.vlines(0.0, ymin, ymax, color="black", linestyle="--") 220 | 221 | # draw a 1-1 line 222 | ax2.plot(x, x, color="green", label="1-1 line") 223 | 224 | # make sure axis is square 225 | # plt.gca().set_aspect("equal", adjustable="box") 226 | ax1.set_aspect("equal", adjustable="box") 
227 | ax2.set_aspect("equal", adjustable="box") 228 | 229 | # show labels 230 | plt.legend() 231 | 232 | plt.show(block=True) 233 | 234 | measured_dists = np.array([x.dist for x in uncalibrated_measurements]) 235 | true_dists = np.array([x.true_dist for x in uncalibrated_measurements]) 236 | 237 | # get a quick linear fit 238 | slope, intercept, _, _, _ = linregress(measured_dists, true_dists) 239 | residuals = true_dists - (slope * measured_dists + intercept) 240 | residuals_stddev = np.std(residuals) 241 | assert not np.isnan(residuals_stddev), "Residuals stddev is NaN" 242 | 243 | # ransac model is invalid if slope is too far from 1 244 | def is_model_valid(model: linear_model.LinearRegression, x, y): 245 | slope = model.coef_[0][0] 246 | return abs(slope - 1) < 0.3 247 | 248 | min_sample_ratio = 0.35 249 | ransac = linear_model.RANSACRegressor( 250 | residual_threshold=2 * residuals_stddev, 251 | min_samples=min_sample_ratio, 252 | is_model_valid=is_model_valid, 253 | max_trials=1000, 254 | ) 255 | 256 | try: 257 | ransac.fit(measured_dists.reshape(-1, 1), true_dists.reshape(-1, 1)) 258 | except ValueError as e: 259 | logger.error( 260 | f"{data_set_name}: Discarding all {len(uncalibrated_measurements)} measurements.\n{e}" 261 | ) 262 | _plot_inliers_and_outliers([], [], ransac) 263 | return [] 264 | 265 | slope = ransac.estimator_.coef_[0][0] 266 | if abs(slope - 1) > 0.1: 267 | logger.warning( 268 | f"{data_set_name}: {len(uncalibrated_measurements)} measurements. Calibration slope of {slope:.2f} detected. This may be due to errors in the data." 
    )

    # split measurements into inliers/outliers according to the RANSAC mask
    inlier_mask = ransac.inlier_mask_
    inlier_measurements = []
    outlier_measurements = []
    for measurement, is_inlier in zip(uncalibrated_measurements, inlier_mask):
        if is_inlier:
            inlier_measurements.append(measurement)
        else:
            outlier_measurements.append(measurement)

    if show_outlier_rejection:
        _plot_inliers_and_outliers(inlier_measurements, outlier_measurements, ransac)

    logger.debug(
        f"{data_set_name}: {len(inlier_measurements)} inliers, {len(outlier_measurements)} outliers"
    )

    return inlier_measurements


def get_linearly_calibrated_measurements(
    uncalibrated_measurements: List[UncalibratedRangeMeasurement],
) -> List[FGRangeMeasurement]:
    """Apply a fitted linear calibration model to the range measurements.

    NOTE: no outlier rejection is performed here; see
    get_inlier_set_of_range_measurements for that.

    Args:
        uncalibrated_measurements: measurements with `true_dist` already set.

    Returns:
        The calibrated FGRangeMeasurement objects.
    """
    linear_calibration = fit_linear_calibration_model(uncalibrated_measurements)
    calibrated_measurements = linear_calibration(uncalibrated_measurements)
    return calibrated_measurements


def get_range_measurements_by_association(
    pyfg: FactorGraphData,
) -> Dict[Tuple[str, str], List[FGRangeMeasurement]]:
    """Group the factor graph's range measurements by their (robot, robot)
    or (robot, landmark) association group.

    Returns:
        A mapping from association group (e.g. ("A", "L1") or ("A", "B")) to
        the list of range measurements with that association.
    """
    uncalibrated_measurements = pyfg.range_measurements
    true_variable_positions = pyfg.variable_true_positions_dict

    # group the range measurements by association
    # e.g.,
    # - (A1, L1) and (A15, L1) will be grouped together as (A, L1)
    # - (B23, L10) and (B138, L10) will be grouped together as (B, L10)
    # - (A5, L1) and (B23, L10) will not be grouped together

    # get valid variables: landmarks keep their full name, poses are grouped
    # by their leading robot character (e.g. "A1" -> "A")
    valid_variable_groups = set()
    for variable_name in true_variable_positions.keys():
        if "L" in variable_name:
            valid_variable_groups.add(variable_name)
        else:
            valid_variable_groups.add(variable_name[0])

    # get valid associations as the cross product of valid variables
def get_range_measurements_by_association(
    pyfg: FactorGraphData,
) -> Dict[Tuple[str, str], List[FGRangeMeasurement]]:
    """Group the factor graph's range measurements by coarsened association.

    All poses of one robot collapse to a single group named by the leading
    character, while landmark names (containing "L") are kept whole, e.g.:
    - (A1, L1) and (A15, L1) both group under (A, L1)
    - (B23, L10) and (B138, L10) both group under (B, L10)
    - robot-robot pairs are alphabetically ordered, e.g. (B3, A7) -> (A, B)
    A landmark always appears second in a pair. Groups that receive no
    measurements are dropped, and each group's measurements are returned
    sorted by timestamp.

    Args:
        pyfg (FactorGraphData): the factor graph holding the measurements.

    Returns:
        Dict[Tuple[str, str], List[FGRangeMeasurement]]: measurements keyed by
            their coarsened association pair.

    Raises:
        ValueError: if a measurement's association is malformed (two
            landmarks, a landmark in first position, or two variables from
            the same robot).
    """
    raw_measurements = pyfg.range_measurements
    true_positions = pyfg.variable_true_positions_dict

    # coarsen variable names: landmarks keep their full name, poses collapse
    # to their leading (robot) character
    variable_groups = set()
    for name in true_positions.keys():
        variable_groups.add(name if "L" in name else name[0])

    # all valid pairs of distinct groups: a landmark is never first, and
    # robot-robot pairs are alphabetically sorted
    valid_associations = set()
    for g1 in variable_groups:
        for g2 in variable_groups:
            if g1 == g2 or ("L" in g1 and "L" in g2):
                continue
            if "L" in g1:
                valid_associations.add((g2, g1))
            elif "L" in g2:
                valid_associations.add((g1, g2))
            else:
                valid_associations.add((g1, g2) if g1 < g2 else (g2, g1))

    def _group_of(association: Tuple[str, str]) -> Tuple[str, str]:
        # a landmark may appear at most once and only in second position;
        # any "L" in the first slot (alone or alongside one in the second)
        # is rejected with the same error
        first, second = association
        if "L" in first:
            raise ValueError(f"Invalid association: {association}")

        if "L" in second:
            return first[0], second

        g1, g2 = first[0], second[0]
        if g1 == g2:
            raise ValueError(f"Invalid measurement association: {association}")
        return (g1, g2) if g1 < g2 else (g2, g1)

    grouped: Dict[Tuple[str, str], List[FGRangeMeasurement]] = {
        pair: [] for pair in valid_associations
    }
    for measurement in raw_measurements:
        key = _group_of(measurement.association)
        if key not in valid_associations:
            raise ValueError(f"Invalid association: {measurement.association}")
        grouped[key].append(measurement)

    # drop empty groups and order each surviving group's measurements by time
    return {
        key: sorted(members, key=lambda m: m.timestamp)
        for key, members in grouped.items()
        if len(members) > 0
    }
def calibrate_range_measures(
    pyfg: FactorGraphData,
    show_outlier_rejection: bool = False,
) -> FactorGraphData:
    """Calibrate the range measurements in a factor graph.

    For each coarsened measurement association we pair every measurement with
    the ground-truth distance between its variables, keep only the RANSAC
    inliers, fit a linear calibration model to the inliers, and replace the
    factor graph's range measurements with the calibrated set.

    Args:
        pyfg (FactorGraphData): the factor graph whose range measurements will
            be calibrated (modified in place).
        show_outlier_rejection (bool): whether to plot the inlier/outlier
            split for each association. Defaults to False.

    Returns:
        FactorGraphData: the same factor graph with calibrated measurements.
    """
    measurements_by_association = get_range_measurements_by_association(pyfg)
    variables_by_name = pyfg.pose_and_landmark_variables_dict

    # pair each raw measurement with its ground-truth distance
    uncalibrated_measures_by_association: Dict[
        Tuple[str, str], List[UncalibratedRangeMeasurement]
    ] = {association: [] for association in measurements_by_association.keys()}
    for radio_association, measurements in measurements_by_association.items():
        for measure in measurements:
            assert measure.timestamp is not None, "Timestamp must be set."

            var1, var2 = (
                variables_by_name[measure.association[0]],
                variables_by_name[measure.association[1]],
            )
            assert var1 is not None, f"Variable {measure.association[0]} not found"
            assert var2 is not None, f"Variable {measure.association[1]} not found"

            uncalibrated_measures_by_association[radio_association].append(
                UncalibratedRangeMeasurement(
                    association=measure.association,
                    dist=measure.dist,
                    timestamp=measure.timestamp,
                    true_dist=dist_between_variables(var1, var2),
                )
            )

    # for each measurement group, keep only the RANSAC inlier set
    # (renamed loop variable to avoid shadowing the per-group measurement
    # lists with a name suggesting the full measurement set)
    inlier_measurements: List[UncalibratedRangeMeasurement] = []
    for group_measurements in uncalibrated_measures_by_association.values():
        inlier_measurements += get_inlier_set_of_range_measurements(
            group_measurements, show_outlier_rejection=show_outlier_rejection
        )

    # fit + apply the linear calibration and write back to the factor graph
    calibrated_measurements = get_linearly_calibrated_measurements(inlier_measurements)
    pyfg.range_measurements = calibrated_measurements

    return pyfg
def reject_measurements_based_on_temporal_consistency(
    pyfg: FactorGraphData, show_outlier_rejection: bool = False
) -> FactorGraphData:
    """Discard range measurements that are temporally inconsistent.

    The idea here is that range measurements that are close to each other in
    time should have similar distances. If they don't, then we should discard
    them.

    Args:
        pyfg (FactorGraphData): the original data
        show_outlier_rejection (bool): whether to plot the rejected outliers.
            Defaults to False.

    Returns:
        FactorGraphData: the updated data

    Raises:
        ValueError: if any association group contains no measurements.
    """
    grouped_measures = get_range_measurements_by_association(pyfg)

    retained_measures: List[FGRangeMeasurement] = []
    for association, group in grouped_measures.items():
        if len(group) == 0:
            raise ValueError(f"No measurements for association: {association}")

        # smooth each group and keep only the temporally consistent samples
        retained_measures += apply_savgol_outlier_rejection(
            group,
            plot_title=str(association),
            show_outlier_rejection=show_outlier_rejection,
        )

    pyfg.range_measurements = retained_measures
    return pyfg
def apply_savgol_outlier_rejection(
    original_measurements: List[FGRangeMeasurement],
    plot_title: Optional[str] = None,
    show_outlier_rejection: bool = False,
) -> List[FGRangeMeasurement]:
    """Use the Savitzky-Golay filter to smooth the data and remove outliers.

    A quadratic Savitzky-Golay filter is run over the time-ordered distance
    series; measurements whose absolute residual from the smoothed curve
    exceeds a MAD-based threshold are rejected. Surviving measurements are
    returned with the smoothed distance substituted for the raw one.

    Args:
        original_measurements (List[FGRangeMeasurement]): the range
            measurements; assumed sorted by timestamp (nanoseconds).
        plot_title (Optional[str]): title for the diagnostic plot.
            Defaults to None.
        show_outlier_rejection (bool): whether to plot the inlier/outlier
            split. Defaults to False.

    Returns:
        List[FGRangeMeasurement]: the filtered range measurements with
            smoothed distances.
    """
    if plot_title is None:
        logger.debug(
            f"Applying Savitzky-Golay outlier rejection to {len(original_measurements)} measurements"
        )
    else:
        logger.debug(
            f"Applying Savitzky-Golay outlier rejection to {len(original_measurements)} measurements: {plot_title}"
        )

    # nothing to filter; also avoids indexing timestamps_ns[0] below
    if len(original_measurements) == 0:
        return []

    distances = np.array([x.dist for x in original_measurements])
    timestamps_ns = np.array([x.timestamp for x in original_measurements])

    # convert timestamps to seconds and subtract the first timestamp
    timestamps = (timestamps_ns - timestamps_ns[0]) / 1e9

    # median sample period (seconds between consecutive measurements); used
    # to translate the smoothing window from seconds into samples.
    # NOTE: the old name "update_freq_hz" was misleading -- this is a period,
    # not a frequency (the division below is seconds / seconds-per-sample).
    sample_period_s = np.median(np.diff(timestamps))
    if not np.isfinite(sample_period_s) or sample_period_s <= 0:
        # fewer than 2 samples or degenerate timestamps; previously this path
        # crashed (int(inf)/median-of-empty) -- reject the group instead
        logger.warning(
            f"Cannot determine a valid sample period ({sample_period_s}). Rejecting all measurements."
        )
        return []

    window_size_seconds = 2.0
    window_size_samples = int(window_size_seconds / sample_period_s)

    poly_degree = 2
    if window_size_samples <= poly_degree:
        logger.warning(
            f"Window size of {window_size_samples} samples is too small for polynomial degree of {poly_degree}. Rejecting all measurements."
        )
        return []
    if window_size_samples > len(distances):
        # savgol_filter raises if the window exceeds the data length
        logger.warning(
            f"Window size of {window_size_samples} samples exceeds the {len(distances)} available measurements. Rejecting all measurements."
        )
        return []

    smoothed_distances = savgol_filter(distances, window_size_samples, poly_degree)

    # Calculate residuals (difference between original and smoothed values)
    residuals = np.abs(distances - smoothed_distances)

    # Compute the threshold based on the median and MAD of residuals
    mad_residuals = np.median(np.abs(residuals - np.median(residuals)))
    median_abs_deviation_threshold = 5.0
    threshold_value = median_abs_deviation_threshold * mad_residuals

    # Detect outliers
    outliers = residuals > threshold_value
    inliers = ~outliers

    # plot the data, with outliers in red. Size of the point is 1
    if show_outlier_rejection:
        plt.scatter(timestamps[inliers], distances[inliers], label="Inliers", s=1)
        plt.scatter(
            timestamps[outliers],
            distances[outliers],
            label="Outliers",
            s=1,
            color="red",
        )

        # draw the smoothed data
        plt.plot(
            timestamps,
            smoothed_distances,
            label="Smoothed Distances",
            color="orange",
            linewidth=1.0,
            linestyle="--",
        )

        plt.xlabel("Timestamp")
        plt.ylabel("Distance")
        if plot_title is not None:
            plt.title(plot_title)
        plt.legend()
        plt.show()

    # return the inliers, with the smoothed distances in place of the original
    smoothed_inlier_ranges = [
        FGRangeMeasurement(
            x.association,
            dist=smoothed_dist,
            stddev=x.stddev,
            timestamp=x.timestamp,
        )
        for x, smoothed_dist, is_inlier in zip(
            original_measurements, smoothed_distances, inliers
        )
        if is_inlier
    ]
    return smoothed_inlier_ranges
20 |
21 |
22 |

Module py_factor_graph.parsing

23 |
24 |
25 |
26 | 27 | Expand source code 28 | 29 |
from typing import List
 30 | from os.path import isfile
 31 | import numpy as np
 32 | import pickle
 33 | 
 34 | from py_factor_graph.variables import PoseVariable, LandmarkVariable
 35 | from py_factor_graph.measurements import (
 36 |     PoseMeasurement2D,
 37 |     AmbiguousPoseMeasurement2D,
 38 |     FGRangeMeasurement,
 39 |     AmbiguousFGRangeMeasurement,
 40 | )
 41 | from py_factor_graph.priors import PosePrior, LandmarkPrior
 42 | from py_factor_graph.factor_graph import (
 43 |     FactorGraphData,
 44 | )
 45 | from py_factor_graph.utils.name_utils import (
 46 |     get_robot_idx_from_frame_name,
 47 |     get_time_idx_from_frame_name,
 48 | )
 49 | from py_factor_graph.utils.data_utils import get_covariance_matrix_from_list
 50 | 
 51 | 
 52 | def parse_efg_file(filepath: str) -> FactorGraphData:
 53 |     """
 54 |     Parse a factor graph file to extract the factors and variables. Requires
 55 |     that the file ends with .fg (e.g. "my_file.fg").
 56 | 
 57 |     Args:
 58 |         filepath: The path to the factor graph file.
 59 | 
 60 |     Returns:
 61 |         FactorGraphData: The factor graph data.
 62 |     """
 63 |     assert isfile(filepath), f"{filepath} is not a file"
 64 |     assert filepath.endswith(".fg"), f"{filepath} is not an efg file"
 65 | 
 66 |     pose_var_header = "Variable Pose SE2"
 67 |     landmark_var_header = "Variable Landmark R2"
 68 |     pose_measure_header = "Factor SE2RelativeGaussianLikelihoodFactor"
 69 |     amb_measure_header = "Factor AmbiguousDataAssociationFactor"
 70 |     range_measure_header = "Factor SE2R2RangeGaussianLikelihoodFactor"
 71 |     pose_prior_header = "Factor UnarySE2ApproximateGaussianPriorFactor"
 72 |     landmark_prior_header = "Landmark"  # don't have any of these yet
 73 | 
 74 |     new_fg_data = FactorGraphData(dimension=2)
 75 | 
 76 |     with open(filepath, "r") as f:
 77 |         for line in f:
 78 |             if line.startswith(pose_var_header):
 79 |                 line_items = line.split()
 80 |                 pose_name = line_items[3]
 81 |                 x = float(line_items[4])
 82 |                 y = float(line_items[5])
 83 |                 theta = float(line_items[6])
 84 |                 pose_var = PoseVariable(pose_name, (x, y), theta)
 85 |                 new_fg_data.add_pose_variable(pose_var)
 86 |             elif line.startswith(landmark_var_header):
 87 |                 line_items = line.split()
 88 |                 landmark_name = line_items[3]
 89 |                 x = float(line_items[4])
 90 |                 y = float(line_items[5])
 91 |                 landmark_var = LandmarkVariable(landmark_name, (x, y))
 92 |                 new_fg_data.add_landmark_variable(landmark_var)
 93 |             elif line.startswith(pose_measure_header):
 94 |                 line_items = line.split()
 95 |                 base_pose = line_items[2]
 96 |                 local_pose = line_items[3]
 97 |                 delta_x = float(line_items[4])
 98 |                 delta_y = float(line_items[5])
 99 |                 delta_theta = float(line_items[6])
100 |                 covar_list = [float(x) for x in line_items[8:]]
101 |                 covar = get_covariance_matrix_from_list(covar_list)
102 |                 # assert covar[0, 0] == covar[1, 1]
103 |                 trans_weight = 1 / (covar[0, 0])
104 |                 rot_weight = 1 / (covar[2, 2])
105 |                 measure = PoseMeasurement2D(
106 |                     base_pose,
107 |                     local_pose,
108 |                     delta_x,
109 |                     delta_y,
110 |                     delta_theta,
111 |                     trans_weight,
112 |                     rot_weight,
113 |                 )
114 | 
115 |                 base_pose_idx = get_robot_idx_from_frame_name(base_pose)
116 |                 local_pose_idx = get_robot_idx_from_frame_name(local_pose)
117 |                 base_time_idx = get_time_idx_from_frame_name(base_pose)
118 |                 local_time_idx = get_time_idx_from_frame_name(local_pose)
119 | 
120 |                 # if either the robot indices are different or the time indices
121 |                 # are not sequential then it is a loop closure
122 |                 if (
123 |                     base_pose_idx != local_pose_idx
124 |                     or local_time_idx != base_time_idx + 1
125 |                 ):
126 |                     new_fg_data.add_loop_closure(measure)
127 | 
128 |                 # otherwise it is an odometry measurement
129 |                 else:
130 |                     new_fg_data.add_odom_measurement(base_pose_idx, measure)
131 | 
132 |             elif line.startswith(range_measure_header):
133 |                 line_items = line.split()
134 |                 var1 = line_items[2]
135 |                 var2 = line_items[3]
136 |                 dist = float(line_items[4])
137 |                 stddev = float(line_items[5])
138 |                 range_measure = FGRangeMeasurement((var1, var2), dist, stddev)
139 |                 new_fg_data.add_range_measurement(range_measure)
140 | 
141 |             elif line.startswith(pose_prior_header):
142 |                 line_items = line.split()
143 |                 pose_name = line_items[2]
144 |                 x = float(line_items[3])
145 |                 y = float(line_items[4])
146 |                 theta = float(line_items[5])
147 |                 covar_list = [float(x) for x in line_items[7:]]
148 |                 covar = get_covariance_matrix_from_list(covar_list)
149 |                 pose_prior = PosePrior(pose_name, (x, y), theta, covar)
150 |                 new_fg_data.add_pose_prior(pose_prior)
151 | 
152 |             elif line.startswith(landmark_prior_header):
153 |                 raise NotImplementedError("Landmark priors not implemented yet")
154 |             elif line.startswith(amb_measure_header):
155 |                 line_items = line.split()
156 | 
157 |                 # if it is a range measurement then add to ambiguous range
158 |                 # measurements list
159 |                 if "SE2R2RangeGaussianLikelihoodFactor" in line:
160 |                     raise NotImplementedError(
161 |                         "Need to parse for ambiguous range measurements measurement"
162 |                     )
163 | 
164 |                 # if it is a pose measurement then add to ambiguous pose
165 |                 # measurements list
166 |                 elif "SE2RelativeGaussianLikelihoodFactor" in line:
167 |                     raise NotImplementedError(
168 |                         "Need to parse for ambiguous pose measurement"
169 |                     )
170 | 
171 |                 # this is a case that we haven't planned for yet
172 |                 else:
173 |                     raise NotImplementedError(
174 |                         f"Unknown measurement type in ambiguous measurement: {line}"
175 |                     )
176 | 
177 |     return new_fg_data
178 | 
179 | 
180 | def parse_pickle_file(filepath: str) -> FactorGraphData:
181 |     """
182 |     Retrieve a pickled FactorGraphData object. Requires that the
183 |     file ends with .pickle (e.g. "my_file.pickle").
184 | 
185 |     Args:
186 |         filepath: The path to the factor graph file.
187 | 
188 |     Returns:
189 |         FactorGraphData: The factor graph data.
190 |     """
191 |     assert isfile(filepath), f"{filepath} is not a file"
192 |     assert filepath.endswith(".pickle"), f"{filepath} is not a pickle file"
193 | 
194 |     with open(filepath, "rb") as f:
195 |         data = pickle.load(f)
196 |         return data
197 |
198 |
199 |
200 |
201 |
202 |
203 |
204 |

Functions

205 |
206 |
207 | def parse_efg_file(filepath: str) ‑> FactorGraphData 208 |
209 |
210 |

Parse a factor graph file to extract the factors and variables. Requires 211 | that the file ends with .fg (e.g. "my_file.fg").

212 |

Args

213 |
214 |
filepath
215 |
The path to the factor graph file.
216 |
217 |

Returns

218 |
219 |
FactorGraphData
220 |
The factor graph data.
221 |
222 |
223 | 224 | Expand source code 225 | 226 |
def parse_efg_file(filepath: str) -> FactorGraphData:
227 |     """
228 |     Parse a factor graph file to extract the factors and variables. Requires
229 |     that the file ends with .fg (e.g. "my_file.fg").
230 | 
231 |     Args:
232 |         filepath: The path to the factor graph file.
233 | 
234 |     Returns:
235 |         FactorGraphData: The factor graph data.
236 |     """
237 |     assert isfile(filepath), f"{filepath} is not a file"
238 |     assert filepath.endswith(".fg"), f"{filepath} is not an efg file"
239 | 
240 |     pose_var_header = "Variable Pose SE2"
241 |     landmark_var_header = "Variable Landmark R2"
242 |     pose_measure_header = "Factor SE2RelativeGaussianLikelihoodFactor"
243 |     amb_measure_header = "Factor AmbiguousDataAssociationFactor"
244 |     range_measure_header = "Factor SE2R2RangeGaussianLikelihoodFactor"
245 |     pose_prior_header = "Factor UnarySE2ApproximateGaussianPriorFactor"
246 |     landmark_prior_header = "Landmark"  # don't have any of these yet
247 | 
248 |     new_fg_data = FactorGraphData(dimension=2)
249 | 
250 |     with open(filepath, "r") as f:
251 |         for line in f:
252 |             if line.startswith(pose_var_header):
253 |                 line_items = line.split()
254 |                 pose_name = line_items[3]
255 |                 x = float(line_items[4])
256 |                 y = float(line_items[5])
257 |                 theta = float(line_items[6])
258 |                 pose_var = PoseVariable(pose_name, (x, y), theta)
259 |                 new_fg_data.add_pose_variable(pose_var)
260 |             elif line.startswith(landmark_var_header):
261 |                 line_items = line.split()
262 |                 landmark_name = line_items[3]
263 |                 x = float(line_items[4])
264 |                 y = float(line_items[5])
265 |                 landmark_var = LandmarkVariable(landmark_name, (x, y))
266 |                 new_fg_data.add_landmark_variable(landmark_var)
267 |             elif line.startswith(pose_measure_header):
268 |                 line_items = line.split()
269 |                 base_pose = line_items[2]
270 |                 local_pose = line_items[3]
271 |                 delta_x = float(line_items[4])
272 |                 delta_y = float(line_items[5])
273 |                 delta_theta = float(line_items[6])
274 |                 covar_list = [float(x) for x in line_items[8:]]
275 |                 covar = get_covariance_matrix_from_list(covar_list)
276 |                 # assert covar[0, 0] == covar[1, 1]
277 |                 trans_weight = 1 / (covar[0, 0])
278 |                 rot_weight = 1 / (covar[2, 2])
279 |                 measure = PoseMeasurement2D(
280 |                     base_pose,
281 |                     local_pose,
282 |                     delta_x,
283 |                     delta_y,
284 |                     delta_theta,
285 |                     trans_weight,
286 |                     rot_weight,
287 |                 )
288 | 
289 |                 base_pose_idx = get_robot_idx_from_frame_name(base_pose)
290 |                 local_pose_idx = get_robot_idx_from_frame_name(local_pose)
291 |                 base_time_idx = get_time_idx_from_frame_name(base_pose)
292 |                 local_time_idx = get_time_idx_from_frame_name(local_pose)
293 | 
294 |                 # if either the robot indices are different or the time indices
295 |                 # are not sequential then it is a loop closure
296 |                 if (
297 |                     base_pose_idx != local_pose_idx
298 |                     or local_time_idx != base_time_idx + 1
299 |                 ):
300 |                     new_fg_data.add_loop_closure(measure)
301 | 
302 |                 # otherwise it is an odometry measurement
303 |                 else:
304 |                     new_fg_data.add_odom_measurement(base_pose_idx, measure)
305 | 
306 |             elif line.startswith(range_measure_header):
307 |                 line_items = line.split()
308 |                 var1 = line_items[2]
309 |                 var2 = line_items[3]
310 |                 dist = float(line_items[4])
311 |                 stddev = float(line_items[5])
312 |                 range_measure = FGRangeMeasurement((var1, var2), dist, stddev)
313 |                 new_fg_data.add_range_measurement(range_measure)
314 | 
315 |             elif line.startswith(pose_prior_header):
316 |                 line_items = line.split()
317 |                 pose_name = line_items[2]
318 |                 x = float(line_items[3])
319 |                 y = float(line_items[4])
320 |                 theta = float(line_items[5])
321 |                 covar_list = [float(x) for x in line_items[7:]]
322 |                 covar = get_covariance_matrix_from_list(covar_list)
323 |                 pose_prior = PosePrior(pose_name, (x, y), theta, covar)
324 |                 new_fg_data.add_pose_prior(pose_prior)
325 | 
326 |             elif line.startswith(landmark_prior_header):
327 |                 raise NotImplementedError("Landmark priors not implemented yet")
328 |             elif line.startswith(amb_measure_header):
329 |                 line_items = line.split()
330 | 
331 |                 # if it is a range measurement then add to ambiguous range
332 |                 # measurements list
333 |                 if "SE2R2RangeGaussianLikelihoodFactor" in line:
334 |                     raise NotImplementedError(
335 |                         "Need to parse for ambiguous range measurements measurement"
336 |                     )
337 | 
338 |                 # if it is a pose measurement then add to ambiguous pose
339 |                 # measurements list
340 |                 elif "SE2RelativeGaussianLikelihoodFactor" in line:
341 |                     raise NotImplementedError(
342 |                         "Need to parse for ambiguous pose measurement"
343 |                     )
344 | 
345 |                 # this is a case that we haven't planned for yet
346 |                 else:
347 |                     raise NotImplementedError(
348 |                         f"Unknown measurement type in ambiguous measurement: {line}"
349 |                     )
350 | 
351 |     return new_fg_data
352 |
353 |
354 |
355 | def parse_pickle_file(filepath: str) ‑> FactorGraphData 356 |
357 |
358 |

Retrieve a pickled FactorGraphData object. Requires that the 359 | file ends with .pickle (e.g. "my_file.pickle").

360 |

Args

361 |
362 |
filepath
363 |
The path to the factor graph file.
364 |
365 |

Returns

366 |
367 |
FactorGraphData
368 |
The factor graph data.
369 |
370 |
371 | 372 | Expand source code 373 | 374 |
def parse_pickle_file(filepath: str) -> FactorGraphData:
375 |     """
376 |     Retrieve a pickled FactorGraphData object. Requires that the
377 |     file ends with .pickle (e.g. "my_file.pickle").
378 | 
379 |     Args:
380 |         filepath: The path to the factor graph file.
381 | 
382 |     Returns:
383 |         FactorGraphData: The factor graph data.
384 |     """
385 |     assert isfile(filepath), f"{filepath} is not a file"
386 |     assert filepath.endswith(".pickle"), f"{filepath} is not a pickle file"
387 | 
388 |     with open(filepath, "rb") as f:
389 |         data = pickle.load(f)
390 |         return data
391 |
392 |
393 |
394 |
395 |
396 |
397 |
398 | 417 |
418 | 421 | 422 | --------------------------------------------------------------------------------