├── transformermpc ├── data │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-311.pyc │ │ ├── dataset.cpython-311.pyc │ │ └── qp_generator.cpython-311.pyc │ ├── qp_generator.py │ └── dataset.py ├── demo │ ├── __init__.py │ └── demo.py ├── models │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-311.pyc │ │ ├── constraint_predictor.cpython-311.pyc │ │ └── warm_start_predictor.cpython-311.pyc │ ├── constraint_predictor.py │ └── warm_start_predictor.py ├── utils │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-311.pyc │ │ ├── osqp_wrapper.cpython-311.pyc │ │ └── visualization.cpython-311.pyc │ ├── metrics.py │ ├── osqp_wrapper.py │ └── visualization.py ├── training │ ├── __init__.py │ └── __pycache__ │ │ ├── __init__.cpython-311.pyc │ │ └── trainer.cpython-311.pyc ├── __pycache__ │ └── __init__.cpython-311.pyc └── __init__.py ├── transformermpc.egg-info ├── not-zip-safe ├── dependency_links.txt ├── top_level.txt ├── entry_points.txt ├── requires.txt ├── SOURCES.txt └── PKG-INFO ├── summary_plot.png ├── training_curves.png ├── boxplot_comparison.png ├── demo_results ├── models │ ├── best_model.pt │ └── constraint_predictor.pt └── results │ ├── speedup_barchart.png │ ├── solve_time_boxplot.png │ └── solve_time_violinplot.png ├── dist ├── transformermpc-0.1.6.tar.gz └── transformermpc-0.1.6-py3-none-any.whl ├── tests ├── dist │ ├── transformermpc-0.1.6.tar.gz │ └── transformermpc-0.1.6-py3-none-any.whl ├── demo_results │ ├── models │ │ ├── best_model.pt │ │ └── constraint_predictor.pt │ └── results │ │ ├── speedup_barchart.png │ │ ├── solve_time_boxplot.png │ │ └── solve_time_violinplot.png ├── test_import.py ├── scripts │ ├── run_demo.py │ ├── verify_package.py │ ├── verify_structure.py │ └── boxplot_demo.py ├── test_package_structure.py ├── test_solve.py ├── test_files.py ├── test_new_models.py ├── examples │ └── example_usage.py └── test_benchmark.py ├── setup.cfg ├── requirements.txt ├── MANIFEST.in ├── scripts ├── run_demo.py ├── verify_package.py ├── verify_structure.py ├── boxplot_demo.py └── simple_demo.py ├── LICENSE ├── setup.py ├── pyproject.toml ├── examples └── example_usage.py └── README.md /transformermpc/data/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /transformermpc/demo/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /transformermpc/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /transformermpc/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /transformermpc/training/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /transformermpc.egg-info/not-zip-safe: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /transformermpc.egg-info/dependency_links.txt: -------------------------------------------------------------------------------- 1 | 2 | 
-------------------------------------------------------------------------------- /transformermpc.egg-info/top_level.txt: -------------------------------------------------------------------------------- 1 | transformermpc 2 | -------------------------------------------------------------------------------- /summary_plot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vrushabh27/transformermpc/HEAD/summary_plot.png -------------------------------------------------------------------------------- /training_curves.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vrushabh27/transformermpc/HEAD/training_curves.png -------------------------------------------------------------------------------- /boxplot_comparison.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vrushabh27/transformermpc/HEAD/boxplot_comparison.png -------------------------------------------------------------------------------- /transformermpc.egg-info/entry_points.txt: -------------------------------------------------------------------------------- 1 | [console_scripts] 2 | transformermpc-demo = transformermpc.demo:main 3 | -------------------------------------------------------------------------------- /demo_results/models/best_model.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vrushabh27/transformermpc/HEAD/demo_results/models/best_model.pt -------------------------------------------------------------------------------- /dist/transformermpc-0.1.6.tar.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vrushabh27/transformermpc/HEAD/dist/transformermpc-0.1.6.tar.gz -------------------------------------------------------------------------------- /tests/dist/transformermpc-0.1.6.tar.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vrushabh27/transformermpc/HEAD/tests/dist/transformermpc-0.1.6.tar.gz -------------------------------------------------------------------------------- /demo_results/results/speedup_barchart.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vrushabh27/transformermpc/HEAD/demo_results/results/speedup_barchart.png -------------------------------------------------------------------------------- /tests/demo_results/models/best_model.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vrushabh27/transformermpc/HEAD/tests/demo_results/models/best_model.pt -------------------------------------------------------------------------------- /demo_results/models/constraint_predictor.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vrushabh27/transformermpc/HEAD/demo_results/models/constraint_predictor.pt -------------------------------------------------------------------------------- /demo_results/results/solve_time_boxplot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vrushabh27/transformermpc/HEAD/demo_results/results/solve_time_boxplot.png -------------------------------------------------------------------------------- 
/dist/transformermpc-0.1.6-py3-none-any.whl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vrushabh27/transformermpc/HEAD/dist/transformermpc-0.1.6-py3-none-any.whl -------------------------------------------------------------------------------- /demo_results/results/solve_time_violinplot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vrushabh27/transformermpc/HEAD/demo_results/results/solve_time_violinplot.png -------------------------------------------------------------------------------- /tests/demo_results/results/speedup_barchart.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vrushabh27/transformermpc/HEAD/tests/demo_results/results/speedup_barchart.png -------------------------------------------------------------------------------- /tests/dist/transformermpc-0.1.6-py3-none-any.whl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vrushabh27/transformermpc/HEAD/tests/dist/transformermpc-0.1.6-py3-none-any.whl -------------------------------------------------------------------------------- /tests/demo_results/models/constraint_predictor.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vrushabh27/transformermpc/HEAD/tests/demo_results/models/constraint_predictor.pt -------------------------------------------------------------------------------- /tests/demo_results/results/solve_time_boxplot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vrushabh27/transformermpc/HEAD/tests/demo_results/results/solve_time_boxplot.png -------------------------------------------------------------------------------- /transformermpc/__pycache__/__init__.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vrushabh27/transformermpc/HEAD/transformermpc/__pycache__/__init__.cpython-311.pyc -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | license_file = LICENSE 3 | 4 | [options] 5 | zip_safe = False 6 | include_package_data = True 7 | 8 | [options.package_data] 9 | * = *.py -------------------------------------------------------------------------------- /tests/demo_results/results/solve_time_violinplot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vrushabh27/transformermpc/HEAD/tests/demo_results/results/solve_time_violinplot.png -------------------------------------------------------------------------------- /transformermpc/data/__pycache__/__init__.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vrushabh27/transformermpc/HEAD/transformermpc/data/__pycache__/__init__.cpython-311.pyc -------------------------------------------------------------------------------- /transformermpc/data/__pycache__/dataset.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vrushabh27/transformermpc/HEAD/transformermpc/data/__pycache__/dataset.cpython-311.pyc 
-------------------------------------------------------------------------------- /transformermpc/models/__pycache__/__init__.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vrushabh27/transformermpc/HEAD/transformermpc/models/__pycache__/__init__.cpython-311.pyc -------------------------------------------------------------------------------- /transformermpc/utils/__pycache__/__init__.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vrushabh27/transformermpc/HEAD/transformermpc/utils/__pycache__/__init__.cpython-311.pyc -------------------------------------------------------------------------------- /transformermpc/data/__pycache__/qp_generator.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vrushabh27/transformermpc/HEAD/transformermpc/data/__pycache__/qp_generator.cpython-311.pyc -------------------------------------------------------------------------------- /transformermpc/training/__pycache__/__init__.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vrushabh27/transformermpc/HEAD/transformermpc/training/__pycache__/__init__.cpython-311.pyc -------------------------------------------------------------------------------- /transformermpc/training/__pycache__/trainer.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vrushabh27/transformermpc/HEAD/transformermpc/training/__pycache__/trainer.cpython-311.pyc -------------------------------------------------------------------------------- /transformermpc/utils/__pycache__/osqp_wrapper.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vrushabh27/transformermpc/HEAD/transformermpc/utils/__pycache__/osqp_wrapper.cpython-311.pyc -------------------------------------------------------------------------------- /transformermpc/utils/__pycache__/visualization.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vrushabh27/transformermpc/HEAD/transformermpc/utils/__pycache__/visualization.cpython-311.pyc -------------------------------------------------------------------------------- /transformermpc/models/__pycache__/constraint_predictor.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vrushabh27/transformermpc/HEAD/transformermpc/models/__pycache__/constraint_predictor.cpython-311.pyc -------------------------------------------------------------------------------- /transformermpc/models/__pycache__/warm_start_predictor.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vrushabh27/transformermpc/HEAD/transformermpc/models/__pycache__/warm_start_predictor.cpython-311.pyc -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy>=1.20.0,<2.0.0 2 | scipy>=1.7.0 3 | torch>=1.9.0 4 | osqp>=0.6.2 5 | matplotlib>=3.4.0 6 | pandas>=1.3.0 7 | tqdm>=4.62.0 8 | scikit-learn>=0.24.0 9 | tensorboard>=2.7.0 
10 | quadprog>=0.1.11
11 | 
--------------------------------------------------------------------------------
/transformermpc.egg-info/requires.txt:
--------------------------------------------------------------------------------
1 | numpy>=1.20.0
2 | scipy>=1.7.0
3 | torch>=1.9.0
4 | osqp>=0.6.2
5 | matplotlib>=3.4.0
6 | pandas>=1.3.0
7 | tqdm>=4.62.0
8 | scikit-learn>=0.24.0
9 | tensorboard>=2.7.0
10 | quadprog>=0.1.11
11 | 
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include LICENSE
2 | include README.md
3 | include requirements.txt
4 | recursive-include transformermpc *.py
5 | recursive-include transformermpc/data *.pkl *.npy
6 | recursive-include transformermpc/models *.pth *.pt
7 | recursive-include transformermpc/demo *.png *.jpg
--------------------------------------------------------------------------------
/tests/test_import.py:
--------------------------------------------------------------------------------
1 | import transformermpc
2 | print("Package imported successfully!")
3 | 
4 | from transformermpc.models.constraint_predictor import ConstraintPredictor
5 | from transformermpc.models.warm_start_predictor import WarmStartPredictor
6 | print("Models imported successfully!")
7 | 
8 | from transformermpc.demo.demo import parse_args
9 | print("Demo imported successfully!")
--------------------------------------------------------------------------------
/scripts/run_demo.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | 
4 | # Add the project root (the parent of this scripts/ directory) to the Python
5 | # path so the transformermpc package is importable from a source checkout
6 | sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir)))
7 | 
8 | def main():
9 |     # Import the demo module
10 |     from transformermpc.demo.demo import main as demo_main
11 | 
12 |     # Run the demo
13 |     demo_main()
14 | 
15 | if __name__ == "__main__":
16 |     main()
--------------------------------------------------------------------------------
/tests/scripts/run_demo.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | 
4 | # Add the project root (the parent of this scripts/ directory) to the Python
5 | # path so the transformermpc package is importable from a source checkout
6 | sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir)))
7 | 
8 | def main():
9 |     # Import the demo module
10 |     from transformermpc.demo.demo import main as demo_main
11 | 
12 |     # Run the demo
13 |     demo_main()
14 | 
15 | if __name__ == "__main__":
16 |     main()
--------------------------------------------------------------------------------
/scripts/verify_package.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | 
3 | import os
4 | 
5 | def list_py_files(start_dir):
6 |     print(f"Scanning directory: {start_dir}")
7 | 
8 |     for root, dirs, files in os.walk(start_dir):
9 |         py_files = [f for f in files if f.endswith('.py')]
10 |         if py_files:
11 |             rel_path = os.path.relpath(root, start_dir)
12 |             if rel_path == '.':
13 |                 rel_path = ''
14 |             print(f"\nIn {rel_path or 'root'}:")
15 |             for py_file in sorted(py_files):
16 |                 print(f" - {py_file}")
17 | 
18 | if __name__ == "__main__":
19 |     list_py_files("transformermpc")
--------------------------------------------------------------------------------
/tests/scripts/verify_package.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | 
3 | import os
4 | 
5 | def list_py_files(start_dir):
6 |     print(f"Scanning directory: {start_dir}")
7 | 
8 |     for root, dirs, files in os.walk(start_dir):
9 |         py_files = [f for f in files if f.endswith('.py')]
10 |         if py_files:
11 |             rel_path = os.path.relpath(root, start_dir)
12 |             if rel_path == '.':
13 |                 rel_path = ''
14 |             print(f"\nIn {rel_path or 'root'}:")
15 |             for py_file in sorted(py_files):
16 |                 print(f" - {py_file}")
17 | 
18 | if __name__ == "__main__":
19 |     list_py_files("transformermpc")
--------------------------------------------------------------------------------
/transformermpc.egg-info/SOURCES.txt:
--------------------------------------------------------------------------------
1 | LICENSE
2 | MANIFEST.in
3 | README.md
4 | pyproject.toml
5 | requirements.txt
6 | setup.cfg
7 | setup.py
8 | tests/test_benchmark.py
9 | tests/test_files.py
10 | tests/test_import.py
11 | tests/test_new_models.py
12 | tests/test_package_structure.py
13 | tests/test_solve.py
14 | transformermpc/__init__.py
15 | transformermpc.egg-info/PKG-INFO
16 | transformermpc.egg-info/SOURCES.txt
17 | transformermpc.egg-info/dependency_links.txt
18 | transformermpc.egg-info/entry_points.txt
19 | transformermpc.egg-info/not-zip-safe
20 | transformermpc.egg-info/requires.txt
21 | transformermpc.egg-info/top_level.txt
22 | transformermpc/data/__init__.py
23 | transformermpc/data/dataset.py
24 | transformermpc/data/qp_generator.py
25 | transformermpc/demo/__init__.py
26 | transformermpc/demo/demo.py
27 | transformermpc/models/__init__.py
28 | transformermpc/models/constraint_predictor.py
29 | transformermpc/models/warm_start_predictor.py
30 | transformermpc/training/__init__.py
31 | transformermpc/training/trainer.py
32 | transformermpc/utils/__init__.py
33 | transformermpc/utils/metrics.py
34 | transformermpc/utils/osqp_wrapper.py
35 | transformermpc/utils/visualization.py
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2025 Vrushabh Zinage, Ahmed Khalil, Efstathios Bakolas
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 | 
3 | with open("README.md", "r", encoding="utf-8") as fh:
4 |     long_description = fh.read()
5 | 
6 | with open("requirements.txt", "r", encoding="utf-8") as fh:
7 |     requirements = fh.read().splitlines()
8 | 
9 | setup(
10 |     name="transformermpc",
11 |     version="0.1.6",
12 |     author="Vrushabh Zinage, Ahmed Khalil, Efstathios Bakolas",
13 |     author_email="vrushabh.zinage@gmail.com",
14 |     description="Accelerating Model Predictive Control via Neural Networks",
15 |     long_description=long_description,
16 |     long_description_content_type="text/markdown",
17 |     url="https://github.com/Vrushabh27/transformermpc",
18 |     packages=[
19 |         'transformermpc',
20 |         'transformermpc.data',
21 |         'transformermpc.models',
22 |         'transformermpc.utils',
23 |         'transformermpc.training',
24 |         'transformermpc.demo',
25 |     ],
26 |     package_dir={'transformermpc': 'transformermpc'},
27 |     classifiers=[
28 |         "Programming Language :: Python :: 3",
29 |         "License :: OSI Approved :: MIT License",
30 |         "Operating System :: OS Independent",
31 |     ],
32 |     python_requires=">=3.7",
33 |     install_requires=requirements,
34 |     include_package_data=True,
35 | )
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools>=42", "wheel"]
3 | build-backend = "setuptools.build_meta"
4 | 
5 | [project]
6 | name = "transformermpc"
7 | version = "0.1.6"
8 | description = "Accelerating Model Predictive Control via Neural Networks"
9 | readme = "README.md"
10 | requires-python = ">=3.7"
11 | license = {text = "MIT"}
12 | authors = [
13 |     {name = "Vrushabh Zinage", email = "vrushabh.zinage@gmail.com"},
14 |     {name = "Ahmed Khalil"},
15 |     {name = "Efstathios Bakolas"}
16 | ]
17 | maintainers = [
18 |     {name = "Vrushabh Zinage", email = "vrushabh.zinage@gmail.com"}
19 | ]
20 | keywords = ["machine learning", "model predictive control", "quadratic programming"]
21 | classifiers = [
22 |     "Programming Language :: Python :: 3",
23 |     "Operating System :: OS Independent",
24 | ]
25 | dependencies = [
26 |     "numpy>=1.20.0",
27 |     "scipy>=1.7.0",
28 |     "torch>=1.9.0",
29 |     "osqp>=0.6.2",
30 |     "matplotlib>=3.4.0",
31 |     "pandas>=1.3.0",
32 |     "tqdm>=4.62.0",
33 |     "scikit-learn>=0.24.0",
34 |     "tensorboard>=2.7.0",
35 |     "quadprog>=0.1.11",
36 | ]
37 | 
38 | [project.urls]
39 | "Homepage" = "https://github.com/Vrushabh27/transformermpc"
40 | "Bug Tracker" = "https://github.com/Vrushabh27/transformermpc/issues"
41 | 
42 | [project.scripts]
43 | transformermpc-demo = "transformermpc.demo.demo:main"
44 | 
45 | [tool.setuptools]
46 | packages = [
47 |     "transformermpc",
48 |     "transformermpc.data",
49 |     "transformermpc.models",
50 |     "transformermpc.utils",
51 |     "transformermpc.training",
52 |     "transformermpc.demo",
53 | ]
54 | 
55 | [tool.setuptools.package-data]
56 | transformermpc = ["data/*.pkl", "data/*.npy", "models/*.pth", "models/*.pt", "demo/*.png", "demo/*.jpg"]
57 | 
58 | [tool.black]
59 | line-length = 88
60 | target-version = ["py37", "py38", "py39", "py310"]
61 | 
62 | [tool.isort]
63 | profile = "black"
64 | line_length = 88
65 | 
66 | [tool.pytest.ini_options]
67 | testpaths = ["tests"]
68 | python_files = "test_*.py"
--------------------------------------------------------------------------------
/tests/test_package_structure.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import os 4 | import importlib 5 | import pkgutil 6 | import transformermpc 7 | 8 | def check_module_exists(module_name): 9 | """Check if a module exists without actually importing it.""" 10 | try: 11 | spec = importlib.util.find_spec(module_name) 12 | return spec is not None 13 | except (ModuleNotFoundError, AttributeError): 14 | return False 15 | 16 | def main(): 17 | """Verify the transformermpc package structure.""" 18 | print(f"Transformermpc version: {transformermpc.__version__}") 19 | print("\nPackage structure:") 20 | 21 | # Get the package path 22 | pkg_path = os.path.dirname(transformermpc.__file__) 23 | print(f"Package path: {pkg_path}") 24 | 25 | # List all subpackages and modules 26 | print("\nSubpackages and modules:") 27 | for _, name, is_pkg in pkgutil.iter_modules([pkg_path]): 28 | if is_pkg: 29 | print(f" - Subpackage: {name}") 30 | # Check subpackage modules 31 | subpkg_path = os.path.join(pkg_path, name) 32 | for _, submodule, _ in pkgutil.iter_modules([subpkg_path]): 33 | print(f" - Module: {name}.{submodule}") 34 | else: 35 | print(f" - Module: {name}") 36 | 37 | # Check expected module paths 38 | expected_modules = [ 39 | "transformermpc.data.dataset", 40 | "transformermpc.data.qp_generator", 41 | "transformermpc.models.constraint_predictor", 42 | "transformermpc.models.warm_start_predictor", 43 | "transformermpc.utils.metrics", 44 | "transformermpc.utils.osqp_wrapper", 45 | "transformermpc.utils.visualization", 46 | "transformermpc.training.trainer", 47 | "transformermpc.demo.demo" 48 | ] 49 | 50 | print("\nChecking key modules:") 51 | for module in expected_modules: 52 | exists = check_module_exists(module) 53 | print(f" - {module}: {'✓' if exists else '✗'}") 54 | 55 | print("\nPackage structure verification complete!") 56 | 57 | if __name__ == "__main__": 58 | main() -------------------------------------------------------------------------------- /tests/test_solve.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import time 3 | from transformermpc import TransformerMPC 4 | 5 | def main(): 6 | print("Testing TransformerMPC solve method...") 7 | 8 | # Create a simple QP problem for testing 9 | n = 5 # dimension of the problem 10 | m = 3 # number of constraints 11 | 12 | # Create random matrices for the test problem 13 | np.random.seed(42) 14 | Q = np.random.rand(n, n) 15 | Q = Q.T @ Q # Make Q positive definite 16 | c = np.random.rand(n) 17 | A = np.random.rand(m, n) 18 | b = np.random.rand(m) 19 | 20 | # Create the TransformerMPC solver 21 | solver = TransformerMPC( 22 | use_constraint_predictor=False, 23 | use_warm_start_predictor=False 24 | ) 25 | 26 | # Test the solve method 27 | start_time = time.time() 28 | x, solve_time = solver.solve(Q=Q, c=c, A=A, b=b) 29 | total_time = time.time() - start_time 30 | 31 | print(f"Solution vector: {x}") 32 | print(f"Solver time reported: {solve_time:.6f} seconds") 33 | print(f"Total time taken: {total_time:.6f} seconds") 34 | 35 | # Verify the solution with a simple objective function calculation 36 | objective = 0.5 * x.T @ Q @ x + c.T @ x 37 | print(f"Objective function value: {objective:.6f}") 38 | 39 | # Test the baseline solver 40 | baseline_start_time = time.time() 41 | x_baseline, baseline_time = solver.solve_baseline(Q=Q, c=c, A=A, b=b) 42 | baseline_total_time = time.time() - baseline_start_time 43 | 44 | 
print("\nBaseline solver results:") 45 | print(f"Solution vector: {x_baseline}") 46 | print(f"Solver time reported: {baseline_time:.6f} seconds") 47 | print(f"Total time taken: {baseline_total_time:.6f} seconds") 48 | 49 | # Calculate objective function for baseline solution 50 | baseline_objective = 0.5 * x_baseline.T @ Q @ x_baseline + c.T @ x_baseline 51 | print(f"Objective function value: {baseline_objective:.6f}") 52 | 53 | # Check if solutions are similar 54 | solution_diff = np.linalg.norm(x - x_baseline) 55 | print(f"\nNorm of solution difference: {solution_diff:.6f}") 56 | 57 | if __name__ == "__main__": 58 | main() -------------------------------------------------------------------------------- /tests/test_files.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import os 4 | import importlib.util 5 | import sys 6 | 7 | def print_file_exists(module_path): 8 | """Check if a Python module file exists and print the result.""" 9 | exists = os.path.isfile(module_path) 10 | print(f"{module_path}: {'✓ Exists' if exists else '✗ Missing'}") 11 | return exists 12 | 13 | def main(): 14 | """ 15 | Check if key files in the transformermpc package exist 16 | without importing them. 17 | """ 18 | package_dir = os.path.dirname(os.path.abspath(__file__)) 19 | transformermpc_dir = os.path.join(package_dir, "transformermpc") 20 | 21 | print(f"\nChecking key files in {transformermpc_dir}...\n") 22 | 23 | # Check top-level files 24 | files_to_check = [ 25 | os.path.join(transformermpc_dir, "__init__.py"), 26 | ] 27 | 28 | # Check data module files 29 | data_dir = os.path.join(transformermpc_dir, "data") 30 | for file in ["__init__.py", "dataset.py", "qp_generator.py"]: 31 | files_to_check.append(os.path.join(data_dir, file)) 32 | 33 | # Check models module files 34 | models_dir = os.path.join(transformermpc_dir, "models") 35 | for file in ["__init__.py", "constraint_predictor.py", "warm_start_predictor.py"]: 36 | files_to_check.append(os.path.join(models_dir, file)) 37 | 38 | # Check utils module files 39 | utils_dir = os.path.join(transformermpc_dir, "utils") 40 | for file in ["__init__.py", "metrics.py", "osqp_wrapper.py", "visualization.py"]: 41 | files_to_check.append(os.path.join(utils_dir, file)) 42 | 43 | # Check training module files 44 | training_dir = os.path.join(transformermpc_dir, "training") 45 | for file in ["__init__.py", "trainer.py"]: 46 | files_to_check.append(os.path.join(training_dir, file)) 47 | 48 | # Check demo module files 49 | demo_dir = os.path.join(transformermpc_dir, "demo") 50 | for file in ["__init__.py", "demo.py"]: 51 | files_to_check.append(os.path.join(demo_dir, file)) 52 | 53 | # Check each file 54 | all_exist = True 55 | for file_path in files_to_check: 56 | exists = print_file_exists(file_path) 57 | all_exist = all_exist and exists 58 | 59 | # Print summary 60 | print("\nSummary:") 61 | if all_exist: 62 | print("✓ All key files found in the package!") 63 | else: 64 | print("✗ Some files are missing from the package.") 65 | 66 | return 0 if all_exist else 1 67 | 68 | if __name__ == "__main__": 69 | sys.exit(main()) -------------------------------------------------------------------------------- /tests/test_new_models.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | """ 4 | Test script to verify the proper operation of the new vanilla transformer models. 
5 | """ 6 | 7 | import torch 8 | import numpy as np 9 | 10 | # Fix the import paths 11 | from transformermpc.models.constraint_predictor import ConstraintPredictor 12 | from transformermpc.models.warm_start_predictor import WarmStartPredictor 13 | 14 | def test_constraint_predictor(): 15 | """Test the ConstraintPredictor model.""" 16 | print("Testing ConstraintPredictor...") 17 | 18 | # Create model instance 19 | model = ConstraintPredictor( 20 | input_dim=20, 21 | hidden_dim=64, 22 | num_constraints=10, 23 | num_layers=2, 24 | num_heads=4, 25 | dropout=0.1 26 | ) 27 | 28 | # Test forward pass with random data 29 | batch_size = 4 30 | x = torch.randn(batch_size, 20) 31 | output = model(x) 32 | 33 | print(f"Output shape: {output.shape}") 34 | assert output.shape == torch.Size([batch_size, 10]), f"Expected shape {torch.Size([batch_size, 10])}, got {output.shape}" 35 | 36 | # Test predict method 37 | predictions = model.predict(x) 38 | print(f"Prediction shape: {predictions.shape}") 39 | print(f"Sample prediction: {predictions[0, :5]}") 40 | 41 | print("ConstraintPredictor test passed!") 42 | 43 | def test_warm_start_predictor(): 44 | """Test the WarmStartPredictor model.""" 45 | print("Testing WarmStartPredictor...") 46 | 47 | # Create model instance 48 | model = WarmStartPredictor( 49 | input_dim=20, 50 | hidden_dim=64, 51 | output_dim=8, 52 | num_layers=2, 53 | num_heads=4, 54 | dropout=0.1 55 | ) 56 | 57 | # Test forward pass with random data 58 | batch_size = 4 59 | x = torch.randn(batch_size, 20) 60 | output = model(x) 61 | 62 | print(f"Output shape: {output.shape}") 63 | assert output.shape == torch.Size([batch_size, 8]), f"Expected shape {torch.Size([batch_size, 8])}, got {output.shape}" 64 | 65 | # Test predict method 66 | predictions = model.predict(x) 67 | print(f"Prediction shape: {predictions.shape}") 68 | print(f"Sample prediction: {predictions[0, :5]}") 69 | 70 | print("WarmStartPredictor test passed!") 71 | 72 | def main(): 73 | """Run the tests.""" 74 | print("Starting tests for the new transformer models...") 75 | 76 | test_constraint_predictor() 77 | print("\n") 78 | test_warm_start_predictor() 79 | 80 | print("\nAll tests completed successfully!") 81 | 82 | if __name__ == "__main__": 83 | main() -------------------------------------------------------------------------------- /scripts/verify_structure.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | Verify the reorganized TransformerMPC package structure. 4 | 5 | This script checks that the package structure follows standard Python 6 | package conventions and that all key modules are present. 
7 | """ 8 | 9 | import os 10 | import sys 11 | 12 | def check_dir_exists(path, name): 13 | """Check if a directory exists and print the result.""" 14 | exists = os.path.isdir(path) 15 | print(f"{name} directory: {'✓' if exists else '✗'}") 16 | return exists 17 | 18 | def check_file_exists(path, name): 19 | """Check if a file exists and print the result.""" 20 | exists = os.path.isfile(path) 21 | print(f"{name} file: {'✓' if exists else '✗'}") 22 | return exists 23 | 24 | def main(): 25 | """Verify the package structure.""" 26 | # Get project root directory (where this script is located) 27 | script_dir = os.path.dirname(os.path.abspath(__file__)) 28 | project_root = os.path.dirname(script_dir) 29 | 30 | print("\n" + "="*60) 31 | print(" TransformerMPC Package Structure Verification") 32 | print("="*60) 33 | 34 | # Check top-level directories 35 | print("\nChecking top-level directories:") 36 | dirs_ok = True 37 | for dirname in ["transformermpc", "tests", "scripts", "examples"]: 38 | path = os.path.join(project_root, dirname) 39 | dirs_ok = check_dir_exists(path, dirname) and dirs_ok 40 | 41 | # Check top-level files 42 | print("\nChecking top-level files:") 43 | files_ok = True 44 | for filename in ["setup.py", "pyproject.toml", "requirements.txt", "MANIFEST.in", "LICENSE", "README.md"]: 45 | path = os.path.join(project_root, filename) 46 | files_ok = check_file_exists(path, filename) and files_ok 47 | 48 | # Check package structure 49 | print("\nChecking package structure:") 50 | pkg_dir = os.path.join(project_root, "transformermpc") 51 | pkg_ok = True 52 | 53 | # Check subdirectories 54 | for dirname in ["data", "models", "utils", "training", "demo"]: 55 | path = os.path.join(pkg_dir, dirname) 56 | pkg_ok = check_dir_exists(path, f"transformermpc/{dirname}") and pkg_ok 57 | 58 | # Check key files 59 | pkg_ok = check_file_exists(os.path.join(pkg_dir, "__init__.py"), "transformermpc/__init__.py") and pkg_ok 60 | 61 | # Check for key implementation files 62 | key_files = [ 63 | ("transformermpc/models/constraint_predictor.py", "Constraint Predictor"), 64 | ("transformermpc/models/warm_start_predictor.py", "Warm Start Predictor"), 65 | ("transformermpc/data/dataset.py", "Dataset"), 66 | ("transformermpc/data/qp_generator.py", "QP Generator"), 67 | ("transformermpc/utils/metrics.py", "Metrics"), 68 | ("transformermpc/utils/osqp_wrapper.py", "OSQP Wrapper"), 69 | ("transformermpc/demo/demo.py", "Demo") 70 | ] 71 | 72 | print("\nChecking key implementation files:") 73 | for filepath, desc in key_files: 74 | path = os.path.join(project_root, filepath) 75 | pkg_ok = check_file_exists(path, desc) and pkg_ok 76 | 77 | # Print summary 78 | print("\n" + "="*60) 79 | if dirs_ok and files_ok and pkg_ok: 80 | print("✓ Package structure verification PASSED!") 81 | print(" The TransformerMPC package has been properly reorganized.") 82 | else: 83 | print("✗ Package structure verification FAILED!") 84 | print(" Some expected files or directories are missing.") 85 | print("="*60 + "\n") 86 | 87 | return 0 if dirs_ok and files_ok and pkg_ok else 1 88 | 89 | if __name__ == "__main__": 90 | sys.exit(main()) -------------------------------------------------------------------------------- /tests/scripts/verify_structure.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | Verify the reorganized TransformerMPC package structure. 
4 | 5 | This script checks that the package structure follows standard Python 6 | package conventions and that all key modules are present. 7 | """ 8 | 9 | import os 10 | import sys 11 | 12 | def check_dir_exists(path, name): 13 | """Check if a directory exists and print the result.""" 14 | exists = os.path.isdir(path) 15 | print(f"{name} directory: {'✓' if exists else '✗'}") 16 | return exists 17 | 18 | def check_file_exists(path, name): 19 | """Check if a file exists and print the result.""" 20 | exists = os.path.isfile(path) 21 | print(f"{name} file: {'✓' if exists else '✗'}") 22 | return exists 23 | 24 | def main(): 25 | """Verify the package structure.""" 26 | # Get project root directory (where this script is located) 27 | script_dir = os.path.dirname(os.path.abspath(__file__)) 28 | project_root = os.path.dirname(script_dir) 29 | 30 | print("\n" + "="*60) 31 | print(" TransformerMPC Package Structure Verification") 32 | print("="*60) 33 | 34 | # Check top-level directories 35 | print("\nChecking top-level directories:") 36 | dirs_ok = True 37 | for dirname in ["transformermpc", "tests", "scripts", "examples"]: 38 | path = os.path.join(project_root, dirname) 39 | dirs_ok = check_dir_exists(path, dirname) and dirs_ok 40 | 41 | # Check top-level files 42 | print("\nChecking top-level files:") 43 | files_ok = True 44 | for filename in ["setup.py", "pyproject.toml", "requirements.txt", "MANIFEST.in", "LICENSE", "README.md"]: 45 | path = os.path.join(project_root, filename) 46 | files_ok = check_file_exists(path, filename) and files_ok 47 | 48 | # Check package structure 49 | print("\nChecking package structure:") 50 | pkg_dir = os.path.join(project_root, "transformermpc") 51 | pkg_ok = True 52 | 53 | # Check subdirectories 54 | for dirname in ["data", "models", "utils", "training", "demo"]: 55 | path = os.path.join(pkg_dir, dirname) 56 | pkg_ok = check_dir_exists(path, f"transformermpc/{dirname}") and pkg_ok 57 | 58 | # Check key files 59 | pkg_ok = check_file_exists(os.path.join(pkg_dir, "__init__.py"), "transformermpc/__init__.py") and pkg_ok 60 | 61 | # Check for key implementation files 62 | key_files = [ 63 | ("transformermpc/models/constraint_predictor.py", "Constraint Predictor"), 64 | ("transformermpc/models/warm_start_predictor.py", "Warm Start Predictor"), 65 | ("transformermpc/data/dataset.py", "Dataset"), 66 | ("transformermpc/data/qp_generator.py", "QP Generator"), 67 | ("transformermpc/utils/metrics.py", "Metrics"), 68 | ("transformermpc/utils/osqp_wrapper.py", "OSQP Wrapper"), 69 | ("transformermpc/demo/demo.py", "Demo") 70 | ] 71 | 72 | print("\nChecking key implementation files:") 73 | for filepath, desc in key_files: 74 | path = os.path.join(project_root, filepath) 75 | pkg_ok = check_file_exists(path, desc) and pkg_ok 76 | 77 | # Print summary 78 | print("\n" + "="*60) 79 | if dirs_ok and files_ok and pkg_ok: 80 | print("✓ Package structure verification PASSED!") 81 | print(" The TransformerMPC package has been properly reorganized.") 82 | else: 83 | print("✗ Package structure verification FAILED!") 84 | print(" Some expected files or directories are missing.") 85 | print("="*60 + "\n") 86 | 87 | return 0 if dirs_ok and files_ok and pkg_ok else 1 88 | 89 | if __name__ == "__main__": 90 | sys.exit(main()) -------------------------------------------------------------------------------- /examples/example_usage.py: -------------------------------------------------------------------------------- 1 | """ 2 | Example usage of TransformerMPC library. 
3 | 
4 | This script demonstrates how to use the TransformerMPC package to accelerate
5 | quadratic programming (QP) solvers for Model Predictive Control (MPC).
6 | """
7 | 
8 | import torch
9 | import numpy as np
10 | import matplotlib.pyplot as plt
11 | from pathlib import Path
12 | 
13 | # Import TransformerMPC modules
14 | from transformermpc.data.qp_generator import QPGenerator
15 | from transformermpc.data.dataset import QPDataset
16 | from transformermpc.models.constraint_predictor import ConstraintPredictor
17 | from transformermpc.models.warm_start_predictor import WarmStartPredictor
18 | from transformermpc.training.trainer import ModelTrainer
19 | from transformermpc.utils.osqp_wrapper import OSQPSolver
20 | from transformermpc.utils.metrics import compute_solve_time_metrics
21 | from transformermpc.utils.visualization import plot_solve_time_comparison
22 | 
23 | def main():
24 |     """Example of training and using TransformerMPC."""
25 |     print("TransformerMPC Example Usage")
26 | 
27 |     # 1. Generate QP problems
28 |     print("\n1. Generating QP problems...")
29 |     state_dim = 4
30 |     input_dim = 2
31 |     horizon = 10
32 |     num_samples = 500  # Small number for quick example
33 | 
34 |     qp_generator = QPGenerator(
35 |         state_dim=state_dim,
36 |         input_dim=input_dim,
37 |         horizon=horizon,
38 |         seed=42
39 |     )
40 | 
41 |     qp_problems = qp_generator.generate_batch(num_samples)
42 |     print(f"Generated {len(qp_problems)} QP problems")
43 | 
44 |     # 2. Create dataset
45 |     print("\n2. Creating dataset...")
46 |     dataset = QPDataset(
47 |         qp_problems=qp_problems,
48 |         precompute_solutions=True,
49 |         feature_normalization=True
50 |     )
51 | 
52 |     train_dataset, val_dataset = dataset.split(test_size=0.2)
53 |     print(f"Training dataset size: {len(train_dataset)}")
54 |     print(f"Validation dataset size: {len(val_dataset)}")
55 | 
56 |     # 3. Create models
57 |     print("\n3. Creating models...")
58 |     # Get dimensions from a sample
59 |     sample = train_dataset[0]
60 |     feature_dim = sample['features'].shape[0]
61 |     num_constraints = sample['active_constraints'].shape[0]
62 |     solution_dim = sample['solution'].shape[0]
63 | 
64 |     # Create constraint predictor
65 |     cp_model = ConstraintPredictor(
66 |         input_dim=feature_dim,
67 |         hidden_dim=128,
68 |         num_constraints=num_constraints
69 |     )
70 | 
71 |     # Create warm start predictor
72 |     ws_model = WarmStartPredictor(
73 |         input_dim=feature_dim,
74 |         hidden_dim=256,
75 |         output_dim=solution_dim
76 |     )
77 | 
78 |     # 4. Train models
79 |     print("\n4. Training models...")
80 |     # Train constraint predictor
81 |     cp_trainer = ModelTrainer(
82 |         model=cp_model,
83 |         train_dataset=train_dataset,
84 |         val_dataset=val_dataset,
85 |         target_key='active_constraints',
86 |         output_dir=Path('example_output/models')
87 |     )
88 | 
89 |     cp_metrics = cp_trainer.train(
90 |         num_epochs=10,  # Small number for quick example
91 |         batch_size=32,
92 |         learning_rate=1e-3
93 |     )
94 | 
95 |     # Train warm start predictor
96 |     ws_trainer = ModelTrainer(
97 |         model=ws_model,
98 |         train_dataset=train_dataset,
99 |         val_dataset=val_dataset,
100 |         target_key='solution',
101 |         output_dir=Path('example_output/models')
102 |     )
103 | 
104 |     ws_metrics = ws_trainer.train(
105 |         num_epochs=10,  # Small number for quick example
106 |         batch_size=32,
107 |         learning_rate=1e-3
108 |     )
109 | 
110 |     # 5. Evaluate performance
111 |     print("\n5. Evaluating performance...")
112 |     solver = OSQPSolver()
113 | 
114 |     # List to store timing results
115 |     baseline_times = []
116 |     transformer_times = []
117 | 
118 |     # Sample a few problems for demonstration
119 |     test_indices = np.random.choice(len(val_dataset), size=10, replace=False)
120 | 
121 |     for idx in test_indices:
122 |         # Get problem and features
123 |         sample = val_dataset[idx]
124 |         problem = val_dataset.get_problem(idx)
125 |         features = sample['features']
126 | 
127 |         # Predict active constraints and warm start
128 |         with torch.no_grad():
129 |             pred_active = cp_model(features.unsqueeze(0)).squeeze(0) > 0.5
130 |             pred_solution = ws_model(features.unsqueeze(0)).squeeze(0)
131 | 
132 |         # Baseline solve time (standard OSQP)
133 |         _, baseline_time = solver.solve_with_time(
134 |             Q=problem.Q,
135 |             c=problem.c,
136 |             A=problem.A,
137 |             b=problem.b
138 |         )
139 | 
140 |         # TransformerMPC solve time
141 |         _, transformer_time, _ = solver.solve_pipeline(
142 |             Q=problem.Q,
143 |             c=problem.c,
144 |             A=problem.A,
145 |             b=problem.b,
146 |             active_constraints=pred_active.numpy(),
147 |             warm_start=pred_solution.numpy(),
148 |             fallback_on_violation=True
149 |         )
150 | 
151 |         baseline_times.append(baseline_time)
152 |         transformer_times.append(transformer_time)
153 | 
154 |     # Calculate metrics
155 |     baseline_times = np.array(baseline_times)
156 |     transformer_times = np.array(transformer_times)
157 | 
158 |     metrics = compute_solve_time_metrics(baseline_times, transformer_times)
159 | 
160 |     print("\nPerformance Summary:")
161 |     print(f"Mean baseline time: {metrics['mean_baseline_time']:.6f}s")
162 |     print(f"Mean transformer time: {metrics['mean_transformer_time']:.6f}s")
163 |     print(f"Mean speedup: {metrics['mean_speedup']:.2f}x")
164 |     print(f"Median speedup: {metrics['median_speedup']:.2f}x")
165 | 
166 |     # Optional: Plot results
167 |     output_dir = Path('example_output/results')
168 |     output_dir.mkdir(parents=True, exist_ok=True)
169 | 
170 |     plot_solve_time_comparison(
171 |         baseline_times=baseline_times,
172 |         transformer_times=transformer_times,
173 |         save_path=output_dir / 'solve_time_comparison.png'
174 |     )
175 | 
176 |     print(f"\nResults saved to {output_dir}")
177 |     print("\nExample completed successfully!")
178 | 
179 | if __name__ == "__main__":
180 |     main()
--------------------------------------------------------------------------------
/tests/examples/example_usage.py:
--------------------------------------------------------------------------------
1 | """
2 | Example usage of TransformerMPC library.
3 | 
4 | This script demonstrates how to use the TransformerMPC package to accelerate
5 | quadratic programming (QP) solvers for Model Predictive Control (MPC).
6 | """
7 | 
8 | import torch
9 | import numpy as np
10 | import matplotlib.pyplot as plt
11 | from pathlib import Path
12 | 
13 | # Import TransformerMPC modules
14 | from transformermpc.data.qp_generator import QPGenerator
15 | from transformermpc.data.dataset import QPDataset
16 | from transformermpc.models.constraint_predictor import ConstraintPredictor
17 | from transformermpc.models.warm_start_predictor import WarmStartPredictor
18 | from transformermpc.training.trainer import ModelTrainer
19 | from transformermpc.utils.osqp_wrapper import OSQPSolver
20 | from transformermpc.utils.metrics import compute_solve_time_metrics
21 | from transformermpc.utils.visualization import plot_solve_time_comparison
22 | 
23 | def main():
24 |     """Example of training and using TransformerMPC."""
25 |     print("TransformerMPC Example Usage")
26 | 
27 |     # 1. 
Generate QP problems 28 | print("\n1. Generating QP problems...") 29 | state_dim = 4 30 | input_dim = 2 31 | horizon = 10 32 | num_samples = 500 # Small number for quick example 33 | 34 | qp_generator = QPGenerator( 35 | state_dim=state_dim, 36 | input_dim=input_dim, 37 | horizon=horizon, 38 | seed=42 39 | ) 40 | 41 | qp_problems = qp_generator.generate_batch(num_samples) 42 | print(f"Generated {len(qp_problems)} QP problems") 43 | 44 | # 2. Create dataset 45 | print("\n2. Creating dataset...") 46 | dataset = QPDataset( 47 | qp_problems=qp_problems, 48 | precompute_solutions=True, 49 | feature_normalization=True 50 | ) 51 | 52 | train_dataset, val_dataset = dataset.split(test_size=0.2) 53 | print(f"Training dataset size: {len(train_dataset)}") 54 | print(f"Validation dataset size: {len(val_dataset)}") 55 | 56 | # 3. Create models 57 | print("\n3. Creating models...") 58 | # Get dimensions from a sample 59 | sample = train_dataset[0] 60 | feature_dim = sample['features'].shape[0] 61 | num_constraints = sample['active_constraints'].shape[0] 62 | solution_dim = sample['solution'].shape[0] 63 | 64 | # Create constraint predictor 65 | cp_model = ConstraintPredictor( 66 | input_dim=feature_dim, 67 | hidden_dim=128, 68 | num_constraints=num_constraints 69 | ) 70 | 71 | # Create warm start predictor 72 | ws_model = WarmStartPredictor( 73 | input_dim=feature_dim, 74 | hidden_dim=256, 75 | output_dim=solution_dim 76 | ) 77 | 78 | # 4. Train models 79 | print("\n4. Training models...") 80 | # Train constraint predictor 81 | cp_trainer = ModelTrainer( 82 | model=cp_model, 83 | train_dataset=train_dataset, 84 | val_dataset=val_dataset, 85 | target_key='active_constraints', 86 | output_dir=Path('example_output/models') 87 | ) 88 | 89 | cp_metrics = cp_trainer.train( 90 | num_epochs=10, # Small number for quick example 91 | batch_size=32, 92 | learning_rate=1e-3 93 | ) 94 | 95 | # Train warm start predictor 96 | ws_trainer = ModelTrainer( 97 | model=ws_model, 98 | train_dataset=train_dataset, 99 | val_dataset=val_dataset, 100 | target_key='solution', 101 | output_dir=Path('example_output/models') 102 | ) 103 | 104 | ws_metrics = ws_trainer.train( 105 | num_epochs=10, # Small number for quick example 106 | batch_size=32, 107 | learning_rate=1e-3 108 | ) 109 | 110 | # 5. Evaluate performance 111 | print("\n5. 
Evaluating performance...") 112 | solver = OSQPSolver() 113 | 114 | # List to store timing results 115 | baseline_times = [] 116 | transformer_times = [] 117 | 118 | # Sample a few problems for demonstration 119 | test_indices = np.random.choice(len(val_dataset), size=10, replace=False) 120 | 121 | for idx in test_indices: 122 | # Get problem and features 123 | sample = val_dataset[idx] 124 | problem = val_dataset.get_problem(idx) 125 | features = sample['features'] 126 | 127 | # Predict active constraints and warm start 128 | with torch.no_grad(): 129 | pred_active = cp_model(features.unsqueeze(0)).squeeze(0) > 0.5 130 | pred_solution = ws_model(features.unsqueeze(0)).squeeze(0) 131 | 132 | # Baseline solve time (standard OSQP) 133 | _, baseline_time = solver.solve_with_time( 134 | Q=problem.Q, 135 | c=problem.c, 136 | A=problem.A, 137 | b=problem.b 138 | ) 139 | 140 | # TransformerMPC solve time 141 | _, transformer_time, _ = solver.solve_pipeline( 142 | Q=problem.Q, 143 | c=problem.c, 144 | A=problem.A, 145 | b=problem.b, 146 | active_constraints=pred_active.numpy(), 147 | warm_start=pred_solution.numpy(), 148 | fallback_on_violation=True 149 | ) 150 | 151 | baseline_times.append(baseline_time) 152 | transformer_times.append(transformer_time) 153 | 154 | # Calculate metrics 155 | baseline_times = np.array(baseline_times) 156 | transformer_times = np.array(transformer_times) 157 | 158 | metrics = compute_solve_time_metrics(baseline_times, transformer_times) 159 | 160 | print("\nPerformance Summary:") 161 | print(f"Mean baseline time: {metrics['mean_baseline_time']:.6f}s") 162 | print(f"Mean transformer time: {metrics['mean_transformer_time']:.6f}s") 163 | print(f"Mean speedup: {metrics['mean_speedup']:.2f}x") 164 | print(f"Median speedup: {metrics['median_speedup']:.2f}x") 165 | 166 | # Optional: Plot results 167 | output_dir = Path('example_output/results') 168 | output_dir.mkdir(parents=True, exist_ok=True) 169 | 170 | plot_solve_time_comparison( 171 | baseline_times=baseline_times, 172 | transformer_times=transformer_times, 173 | save_path=output_dir / 'solve_time_comparison.png' 174 | ) 175 | 176 | print(f"\nResults saved to {output_dir}") 177 | print("\nExample completed successfully!") 178 | 179 | if __name__ == "__main__": 180 | main() -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 
3 | # TransformerMPC: Accelerating Model Predictive Control via Transformers [ICRA '25]
4 | 
5 | Vrushabh Zinage<sup>1</sup>
6 | ·
7 | Ahmed Khalil<sup>1</sup>
8 | ·
9 | Efstathios Bakolas<sup>1</sup>
10 | 
11 | 
12 | 
13 | 
14 | 
15 | <sup>1</sup>University of Texas at Austin
16 | 
17 | 
18 | 
19 | [![arXiv](https://img.shields.io/badge/arXiv-2409.09266-blue?logo=arxiv&color=%23B31B1B)](https://arxiv.org/abs/2409.09266) [![ProjectPage](https://img.shields.io/badge/Project_Page-TransformerMPC-blue)](https://transformer-mpc.github.io/)
20 | 
21 | 
22 | 
23 | ## Overview
24 | 
25 | TransformerMPC improves the computational efficiency of solving Model Predictive Control (MPC) problems using transformer-based neural network models. It employs the following two prediction models:
26 | 
27 | 1. **Constraint Predictor**: Identifies inactive constraints in MPC formulations
28 | 2. **Warm Start Predictor**: Generates better initial points for MPC solvers
29 | 
30 | By combining these models, TransformerMPC significantly reduces computation time while maintaining solution quality.
31 | 
32 | ## Package Structure
33 | 
34 | The package is organized with a standard Python package structure:
35 | 
36 | ```
37 | transformermpc/
38 | ├── transformermpc/      # Core package module
39 | │   ├── data/            # Data generation utilities
40 | │   ├── models/          # Model implementations
41 | │   ├── utils/           # Utility functions and metrics
42 | │   ├── training/        # Training infrastructure
43 | │   └── demo/            # Demo scripts
44 | ├── scripts/             # Demo and utility scripts
45 | ├── tests/               # Testing infrastructure
46 | ├── setup.py             # Package installation script
47 | └── requirements.txt     # Dependencies
48 | ```
49 | 
50 | ## Installation
51 | 
52 | Install directly from PyPI:
53 | 
54 | ```bash
55 | pip install transformermpc
56 | ```
57 | 
58 | Or install from source:
59 | 
60 | ```bash
61 | git clone https://github.com/Vrushabh27/transformermpc.git
62 | cd transformermpc
63 | pip install -e .
64 | ```
65 | 
66 | ## Dependencies
67 | 
68 | - Python >= 3.7
69 | - PyTorch >= 1.9.0
70 | - OSQP >= 0.6.2
71 | - NumPy, SciPy, and other scientific computing libraries
72 | - Additional dependencies specified in requirements.txt
73 | 
74 | ## Running the Demos
75 | 
76 | The package includes several demo scripts to showcase its capabilities:
77 | 
78 | ### Boxplot Demo (Recommended)
79 | 
80 | ```bash
81 | python scripts/boxplot_demo.py
82 | ```
83 | 
84 | This script provides a visual comparison of different QP solving strategies using randomly generated problems, without requiring model training. It demonstrates the core concepts behind TransformerMPC by showing the performance impact of:
85 | - Removing inactive constraints
86 | - Using warm starts with different qualities
87 | - Combining these strategies
88 | 
89 | The visualizations include boxplots, violin plots, and bar charts comparing solve times across different strategies.
90 | 
91 | ### Simple Demo
92 | 
93 | ```bash
94 | python scripts/simple_demo.py
95 | ```
96 | 
97 | This script demonstrates the complete pipeline: generating QP problems, training models, and evaluating performance. After completion, it saves performance comparison plots in the demo_results/results directory. A condensed sketch of that pipeline is shown below.
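The sketch strings together the same public API that `examples/example_usage.py` uses; the problem sizes, sample count, epoch count, and output directory here are illustrative placeholders rather than tuned values:

```python
from pathlib import Path

from transformermpc.data.qp_generator import QPGenerator
from transformermpc.data.dataset import QPDataset
from transformermpc.models.constraint_predictor import ConstraintPredictor
from transformermpc.training.trainer import ModelTrainer

# Generate a small batch of random MPC-style QP problems
generator = QPGenerator(state_dim=4, input_dim=2, horizon=10, seed=42)
problems = generator.generate_batch(100)

# Build a dataset with precomputed solutions and split off a validation set
dataset = QPDataset(qp_problems=problems,
                    precompute_solutions=True,
                    feature_normalization=True)
train_set, val_set = dataset.split(test_size=0.2)

# Size the constraint predictor from one sample, then train it briefly
sample = train_set[0]
model = ConstraintPredictor(input_dim=sample['features'].shape[0],
                            hidden_dim=128,
                            num_constraints=sample['active_constraints'].shape[0])
trainer = ModelTrainer(model=model,
                       train_dataset=train_set,
                       val_dataset=val_set,
                       target_key='active_constraints',
                       output_dir=Path('demo_results/models'))
trainer.train(num_epochs=10, batch_size=32, learning_rate=1e-3)
```

The warm start predictor is trained the same way with `target_key='solution'`; the full version, including the solve-time comparison against plain OSQP, is in `examples/example_usage.py`.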
98 | 99 | ### Customizing Demo Parameters 100 | 101 | You can customize the boxplot demo by modifying parameters: 102 | 103 | ```bash 104 | # Generate more problems with different dimensions 105 | python scripts/boxplot_demo.py --num_samples 50 --state_dim 6 --input_dim 3 --horizon 10 106 | 107 | # Save results to a custom directory 108 | python scripts/boxplot_demo.py --output_dir my_results 109 | ``` 110 | 111 | Similarly, for the simple demo: 112 | 113 | ```bash 114 | # Generate QP problems with specific parameters 115 | python scripts/simple_demo.py --num_samples 200 --state_dim 6 --input_dim 3 --horizon 10 116 | 117 | # Customize training parameters 118 | python scripts/simple_demo.py --epochs 20 --batch_size 32 119 | 120 | # Use GPU for training if available 121 | python scripts/simple_demo.py --use_gpu 122 | ``` 123 | 124 | ## Usage in Projects 125 | 126 | ### Basic Example 127 | 128 | ```python 129 | from transformermpc import TransformerMPC 130 | import numpy as np 131 | 132 | # Define your QP problem parameters 133 | Q = np.array([[4.0, 1.0], [1.0, 2.0]]) 134 | c = np.array([1.0, 1.0]) 135 | A = np.array([[-1.0, 0.0], [0.0, -1.0], [-1.0, -1.0], [1.0, 1.0]]) 136 | b = np.array([0.0, 0.0, -1.0, 2.0]) 137 | 138 | # Initialize the TransformerMPC solver 139 | solver = TransformerMPC() 140 | 141 | # Solve with model acceleration 142 | solution, solve_time = solver.solve(Q, c, A, b) 143 | 144 | print(f"Solution: {solution}") 145 | print(f"Solve time: {solve_time} seconds") 146 | ``` 147 | 148 | ### General Usage 149 | 150 | ```python 151 | from transformermpc import TransformerMPC, QPProblem 152 | import numpy as np 153 | 154 | # Define your QP problem parameters 155 | Q = np.array([[4.0, 1.0], [1.0, 2.0]]) 156 | c = np.array([1.0, 1.0]) 157 | A = np.array([[-1.0, 0.0], [0.0, -1.0], [-1.0, -1.0], [1.0, 1.0]]) 158 | b = np.array([0.0, 0.0, -1.0, 2.0]) 159 | initial_state = np.array([0.5, 0.5]) # Optional: initial state for MPC problems 160 | 161 | # Create a QP problem instance 162 | qp_problem = QPProblem( 163 | Q=Q, 164 | c=c, 165 | A=A, 166 | b=b, 167 | initial_state=initial_state # Optional 168 | ) 169 | 170 | # Initialize with custom settings 171 | solver = TransformerMPC( 172 | use_constraint_predictor=True, 173 | use_warm_start_predictor=True, 174 | fallback_on_violation=True 175 | ) 176 | 177 | # Solve the problem 178 | solution, solve_time = solver.solve(qp_problem=qp_problem) 179 | print(f"Solution: {solution}") 180 | print(f"Solve time: {solve_time} seconds") 181 | 182 | # Compare with baseline 183 | baseline_solution, baseline_time = solver.solve_baseline(qp_problem=qp_problem) 184 | print(f"Baseline time: {baseline_time} seconds") 185 | ``` 186 | ## If you find our work useful, please cite us 187 | ``` 188 | @article{zinage2024transformermpc, 189 | title={TransformerMPC: Accelerating Model Predictive Control via Transformers}, 190 | author={Zinage, Vrushabh and Khalil, Ahmed and Bakolas, Efstathios}, 191 | journal={arXiv preprint arXiv:2409.09266}, 192 | year={2024} 193 | } 194 | ``` 195 | 196 | ## License 197 | 198 | This project is licensed under the MIT License. 
199 | 
--------------------------------------------------------------------------------
/transformermpc/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | TransformerMPC: Accelerating Model Predictive Control via Neural Networks
3 | ==========================================================================
4 | 
5 | TransformerMPC is a Python package that enhances the efficiency of solving
6 | Quadratic Programming (QP) problems in Model Predictive Control (MPC)
7 | using neural network models.
8 | 
9 | Authors: Vrushabh Zinage, Ahmed Khalil, Efstathios Bakolas
10 | """
11 | 
12 | __version__ = "0.1.6"
13 | 
14 | # Import core components for easy access
15 | from .models.constraint_predictor import ConstraintPredictor
16 | from .models.warm_start_predictor import WarmStartPredictor
17 | from .data.qp_generator import QPGenerator, QPProblem
18 | from .data.dataset import QPDataset
19 | from .utils.osqp_wrapper import OSQPSolver
20 | from .training.trainer import ModelTrainer
21 | 
22 | # Define the main solver class
23 | class TransformerMPC:
24 |     """
25 |     TransformerMPC solver class that accelerates QP solving using transformer models.
26 | 
27 |     This class combines two transformer models:
28 |     1. Constraint Predictor: Identifies inactive constraints
29 |     2. Warm Start Predictor: Generates better initial points
30 | 
31 |     The combined pipeline significantly reduces computation time while maintaining solution quality.
32 |     """
33 | 
34 |     def __init__(self,
35 |                  use_constraint_predictor=True,
36 |                  use_warm_start_predictor=True,
37 |                  constraint_model_path=None,
38 |                  warm_start_model_path=None,
39 |                  fallback_on_violation=True):
40 |         """
41 |         Initialize the TransformerMPC solver.
42 | 
43 |         Parameters:
44 |         -----------
45 |         use_constraint_predictor : bool
46 |             Whether to use the constraint predictor model
47 |         use_warm_start_predictor : bool
48 |             Whether to use the warm start predictor model
49 |         constraint_model_path : str or None
50 |             Path to a pre-trained constraint predictor model
51 |         warm_start_model_path : str or None
52 |             Path to a pre-trained warm start predictor model
53 |         fallback_on_violation : bool
54 |             Whether to fall back to the full QP if constraints are violated
55 |         """
56 |         self.use_constraint_predictor = use_constraint_predictor
57 |         self.use_warm_start_predictor = use_warm_start_predictor
58 |         self.fallback_on_violation = fallback_on_violation
59 | 
60 |         # Initialize the predictor models when enabled (the given path, possibly None, is passed to load())
61 |         if use_constraint_predictor:
62 |             self.constraint_predictor = ConstraintPredictor.load(constraint_model_path)
63 |         else:
64 |             self.constraint_predictor = None
65 | 
66 |         if use_warm_start_predictor:
67 |             self.warm_start_predictor = WarmStartPredictor.load(warm_start_model_path)
68 |         else:
69 |             self.warm_start_predictor = None
70 | 
71 |         # Initialize the OSQP solver
72 |         self.solver = OSQPSolver()
73 | 
74 |     def solve(self, Q=None, c=None, A=None, b=None, qp_problem=None):
75 |         """
76 |         Solve a QP problem using the transformer-enhanced pipeline.
77 | 78 | Parameters: 79 | ----------- 80 | Q : numpy.ndarray or None 81 | Quadratic cost matrix 82 | c : numpy.ndarray or None 83 | Linear cost vector 84 | A : numpy.ndarray or None 85 | Constraint matrix 86 | b : numpy.ndarray or None 87 | Constraint vector 88 | qp_problem : QPProblem or None 89 | QPProblem instance (alternative to specifying Q, c, A, b) 90 | 91 | Returns: 92 | -------- 93 | solution : numpy.ndarray 94 | Optimal solution vector 95 | solve_time : float 96 | Computation time in seconds 97 | """ 98 | import numpy as np 99 | import time 100 | 101 | # If qp_problem is provided, extract matrices from it 102 | if qp_problem is not None: 103 | Q = qp_problem.Q 104 | c = qp_problem.c 105 | A = qp_problem.A 106 | b = qp_problem.b 107 | 108 | # Check if all required matrices are provided 109 | if Q is None or c is None: 110 | raise ValueError("Q and c matrices must be provided") 111 | 112 | # Extract features for the QP problem 113 | n_vars = Q.shape[0] 114 | features = np.concatenate([ 115 | Q.flatten(), 116 | c, 117 | A.flatten() if A is not None else np.zeros(0), 118 | b if b is not None else np.zeros(0) 119 | ]) 120 | 121 | # Default values if no transformers are used 122 | active_constraints = np.ones(A.shape[0]) if A is not None else None 123 | warm_start = None 124 | 125 | # Predict active constraints if constraint predictor is enabled 126 | if self.use_constraint_predictor and self.constraint_predictor is not None and A is not None: 127 | try: 128 | active_constraints = self.constraint_predictor.predict(features)[0] 129 | except Exception as e: 130 | print(f"Warning: Constraint prediction failed: {e}") 131 | active_constraints = np.ones(A.shape[0]) 132 | 133 | # Predict warm start if warm start predictor is enabled 134 | if self.use_warm_start_predictor and self.warm_start_predictor is not None: 135 | try: 136 | warm_start = self.warm_start_predictor.predict(features)[0] 137 | except Exception as e: 138 | print(f"Warning: Warm start prediction failed: {e}") 139 | warm_start = None 140 | 141 | # Solve the QP problem with the OSQP solver 142 | start_time = time.time() 143 | 144 | if A is not None and self.use_constraint_predictor: 145 | # Use transformer-enhanced pipeline 146 | solution, _, used_fallback = self.solver.solve_pipeline( 147 | Q=Q, 148 | c=c, 149 | A=A, 150 | b=b, 151 | active_constraints=active_constraints, 152 | warm_start=warm_start, 153 | fallback_on_violation=self.fallback_on_violation 154 | ) 155 | else: 156 | # Use standard OSQP solver 157 | solution = self.solver.solve( 158 | Q=Q, 159 | c=c, 160 | A=A, 161 | b=b, 162 | warm_start=warm_start 163 | ) 164 | 165 | solve_time = time.time() - start_time 166 | 167 | return solution, solve_time 168 | 169 | def solve_baseline(self, Q=None, c=None, A=None, b=None, qp_problem=None): 170 | """ 171 | Solve a QP problem using standard OSQP without transformer enhancements. 172 | 173 | Parameters and returns same as solve(). 
174 | """ 175 | import time 176 | 177 | # If qp_problem is provided, extract matrices from it 178 | if qp_problem is not None: 179 | Q = qp_problem.Q 180 | c = qp_problem.c 181 | A = qp_problem.A 182 | b = qp_problem.b 183 | 184 | # Check if all required matrices are provided 185 | if Q is None or c is None: 186 | raise ValueError("Q and c matrices must be provided") 187 | 188 | # Solve using standard OSQP 189 | start_time = time.time() 190 | solution, _ = self.solver.solve_with_time(Q, c, A, b) 191 | solve_time = time.time() - start_time 192 | 193 | return solution, solve_time 194 | -------------------------------------------------------------------------------- /transformermpc.egg-info/PKG-INFO: -------------------------------------------------------------------------------- 1 | Metadata-Version: 2.4 2 | Name: transformermpc 3 | Version: 0.1.6 4 | Summary: Accelerating Model Predictive Control via Neural Networks 5 | Home-page: https://github.com/Vrushabh27/transformermpc 6 | Author: Ahmed Khalil, Efstathios Bakolas 7 | Author-email: Vrushabh Zinage 8 | Maintainer-email: Vrushabh Zinage 9 | License-Expression: MIT 10 | Project-URL: Homepage, https://github.com/vrushabh/transformermpc 11 | Project-URL: Bug Tracker, https://github.com/vrushabh/transformermpc/issues 12 | Keywords: machine learning,model predictive control,quadratic programming 13 | Classifier: Programming Language :: Python :: 3 14 | Classifier: Operating System :: OS Independent 15 | Requires-Python: >=3.7 16 | Description-Content-Type: text/markdown 17 | License-File: LICENSE 18 | Requires-Dist: numpy>=1.20.0 19 | Requires-Dist: scipy>=1.7.0 20 | Requires-Dist: torch>=1.9.0 21 | Requires-Dist: osqp>=0.6.2 22 | Requires-Dist: matplotlib>=3.4.0 23 | Requires-Dist: pandas>=1.3.0 24 | Requires-Dist: tqdm>=4.62.0 25 | Requires-Dist: scikit-learn>=0.24.0 26 | Requires-Dist: tensorboard>=2.7.0 27 | Requires-Dist: quadprog>=0.1.11 28 | Dynamic: home-page 29 | Dynamic: license-file 30 | Dynamic: requires-python 31 | 32 |

33 | # TransformerMPC: Accelerating Model Predictive Control via Transformers [ICRA '25]
34 | 
35 | Vrushabh Zinage¹ · Ahmed Khalil¹ · Efstathios Bakolas¹
36 | 
37 | ¹University of Texas at Austin
38 | 
39 | [![arXiv](https://img.shields.io/badge/arXiv-2409.09266-blue?logo=arxiv&color=%23B31B1B)](https://arxiv.org/abs/2409.09266) [![ProjectPage](https://img.shields.io/badge/Project_Page-TransformerMPC-blue)](https://transformer-mpc.github.io/)
53 | 54 | ## Overview 55 | 56 | TransformerMPC improves the computational efficiency of Model Predictive Control (MPC) problems using neural network models. It employs the following two prediction models: 57 | 58 | 1. **Constraint Predictor**: Identifies inactive constraints in MPC formulations 59 | 2. **Warm Start Predictor**: Generates better initial points for MPC solvers 60 | 61 | By combining these models, TransformerMPC significantly reduces computation time while maintaining solution quality. 62 | 63 | ## Package Structure 64 | 65 | The package is organized with a standard Python package structure: 66 | 67 | ``` 68 | transformermpc/ 69 | ├── transformermpc/ # Core package module 70 | │ ├── data/ # Data generation utilities 71 | │ ├── models/ # Model implementations 72 | │ ├── utils/ # Utility functions and metrics 73 | │ ├── training/ # Training infrastructure 74 | │ └── demo/ # Demo scripts 75 | ├── scripts/ # Demo and utility scripts 76 | ├── tests/ # Testing infrastructure 77 | ├── setup.py # Package installation script 78 | └── requirements.txt # Dependencies 79 | ``` 80 | 81 | ## Installation 82 | 83 | Install directly from PyPI: 84 | 85 | ```bash 86 | pip install transformermpc 87 | ``` 88 | 89 | Or install from source: 90 | 91 | ```bash 92 | git clone https://github.com/vrushabh/transformermpc.git 93 | cd transformermpc 94 | pip install -e . 95 | ``` 96 | 97 | ## Dependencies 98 | 99 | - Python >= 3.7 100 | - PyTorch >= 1.9.0 101 | - OSQP >= 0.6.2 102 | - NumPy, SciPy, and other scientific computing libraries 103 | - Additional dependencies specified in requirements.txt 104 | 105 | ## Running the Demos 106 | 107 | The package includes several demo scripts to showcase its capabilities: 108 | 109 | ### Boxplot Demo (Recommended) 110 | 111 | ```bash 112 | python scripts/boxplot_demo.py 113 | ``` 114 | 115 | This script provides a visual comparison of different QP solving strategies using randomly generated problems, without requiring model training. It demonstrates the core concepts behind TransformerMPC by showing the performance impact of: 116 | - Removing inactive constraints 117 | - Using warm starts with different qualities 118 | - Combining these strategies 119 | 120 | The visualizations include boxplots, violin plots, and bar charts comparing solve times across different strategies. 121 | 122 | ### Simple Demo 123 | 124 | ```bash 125 | python scripts/simple_demo.py 126 | ``` 127 | 128 | This script demonstrates the complete pipeline: generating QP problems, training models, and evaluating performance. After completion, it saves performance comparison plots in the `demo_results/results` directory. 
129 | 130 | ### Verify Package Structure 131 | 132 | To check that the package is installed correctly: 133 | 134 | ```bash 135 | python scripts/verify_structure.py 136 | ``` 137 | 138 | ### Customizing Demo Parameters 139 | 140 | You can customize the boxplot demo by modifying parameters: 141 | 142 | ```bash 143 | # Generate more problems with different dimensions 144 | python scripts/boxplot_demo.py --num_samples 50 --state_dim 6 --input_dim 3 --horizon 10 145 | 146 | # Save results to a custom directory 147 | python scripts/boxplot_demo.py --output_dir my_results 148 | ``` 149 | 150 | Similarly, for the simple demo: 151 | 152 | ```bash 153 | # Generate QP problems with specific parameters 154 | python scripts/simple_demo.py --num_samples 200 --state_dim 6 --input_dim 3 --horizon 10 155 | 156 | # Customize training parameters 157 | python scripts/simple_demo.py --epochs 20 --batch_size 32 158 | 159 | # Use GPU for training if available 160 | python scripts/simple_demo.py --use_gpu 161 | ``` 162 | 163 | ## Usage in Projects 164 | 165 | ### Basic Example 166 | 167 | ```python 168 | from transformermpc import TransformerMPC 169 | import numpy as np 170 | 171 | # Define your QP problem parameters 172 | Q = np.array([[4.0, 1.0], [1.0, 2.0]]) 173 | c = np.array([1.0, 1.0]) 174 | A = np.array([[-1.0, 0.0], [0.0, -1.0], [-1.0, -1.0], [1.0, 1.0]]) 175 | b = np.array([0.0, 0.0, -1.0, 2.0]) 176 | 177 | # Initialize the TransformerMPC solver 178 | solver = TransformerMPC() 179 | 180 | # Solve with model acceleration 181 | solution, solve_time = solver.solve(Q, c, A, b) 182 | 183 | print(f"Solution: {solution}") 184 | print(f"Solve time: {solve_time} seconds") 185 | ``` 186 | 187 | ### General Usage 188 | 189 | ```python 190 | from transformermpc import TransformerMPC, QPProblem 191 | import numpy as np 192 | 193 | # Define your QP problem parameters 194 | Q = np.array([[4.0, 1.0], [1.0, 2.0]]) 195 | c = np.array([1.0, 1.0]) 196 | A = np.array([[-1.0, 0.0], [0.0, -1.0], [-1.0, -1.0], [1.0, 1.0]]) 197 | b = np.array([0.0, 0.0, -1.0, 2.0]) 198 | initial_state = np.array([0.5, 0.5]) # Optional: initial state for MPC problems 199 | 200 | # Create a QP problem instance 201 | qp_problem = QPProblem( 202 | Q=Q, 203 | c=c, 204 | A=A, 205 | b=b, 206 | initial_state=initial_state # Optional 207 | ) 208 | 209 | # Initialize with custom settings 210 | solver = TransformerMPC( 211 | use_constraint_predictor=True, 212 | use_warm_start_predictor=True, 213 | fallback_on_violation=True 214 | ) 215 | 216 | # Solve the problem 217 | solution, solve_time = solver.solve(qp_problem=qp_problem) 218 | print(f"Solution: {solution}") 219 | print(f"Solve time: {solve_time} seconds") 220 | 221 | # Compare with baseline 222 | baseline_solution, baseline_time = solver.solve_baseline(qp_problem=qp_problem) 223 | print(f"Baseline time: {baseline_time} seconds") 224 | ``` 225 | ## If you find our work useful, please cite us 226 | ``` 227 | @article{zinage2024transformermpc, 228 | title={TransformerMPC: Accelerating Model Predictive Control via Transformers}, 229 | author={Zinage, Vrushabh and Khalil, Ahmed and Bakolas, Efstathios}, 230 | journal={arXiv preprint arXiv:2409.09266}, 231 | year={2024} 232 | } 233 | ``` 234 | 235 | ## License 236 | 237 | This project is licensed under the MIT License. 
238 | -------------------------------------------------------------------------------- /tests/test_benchmark.py: -------------------------------------------------------------------------------- 1 | """ 2 | Benchmark script for TransformerMPC. 3 | """ 4 | 5 | import os 6 | import torch 7 | import numpy as np 8 | import argparse 9 | from tqdm import tqdm 10 | import matplotlib.pyplot as plt 11 | from pathlib import Path 12 | 13 | # Fix the serialization issue 14 | import torch.serialization 15 | # Add numpy.core.multiarray.scalar to safe globals 16 | torch.serialization.add_safe_globals(['numpy.core.multiarray.scalar']) 17 | 18 | # Import required modules 19 | from transformermpc.data.qp_generator import QPGenerator 20 | from transformermpc.data.dataset import QPDataset 21 | from transformermpc.models.constraint_predictor import ConstraintPredictor 22 | from transformermpc.models.warm_start_predictor import WarmStartPredictor 23 | from transformermpc.utils.osqp_wrapper import OSQPSolver 24 | from transformermpc.utils.metrics import compute_solve_time_metrics, compute_fallback_rate 25 | from transformermpc.utils.visualization import ( 26 | plot_solve_time_comparison, 27 | plot_solve_time_boxplot 28 | ) 29 | 30 | # Monkey patch torch.load to use weights_only=False by default 31 | original_torch_load = torch.load 32 | def patched_torch_load(f, *args, **kwargs): 33 | if 'weights_only' not in kwargs: 34 | kwargs['weights_only'] = False 35 | return original_torch_load(f, *args, **kwargs) 36 | torch.load = patched_torch_load 37 | 38 | def parse_args(): 39 | """Parse command line arguments.""" 40 | parser = argparse.ArgumentParser(description="TransformerMPC Benchmark") 41 | 42 | # Input/output parameters 43 | parser.add_argument( 44 | "--data_dir", type=str, default="demo_results/data", 45 | help="Directory containing QP problem data (default: demo_results/data)" 46 | ) 47 | parser.add_argument( 48 | "--results_dir", type=str, default="demo_results/results", 49 | help="Directory to save benchmark results (default: demo_results/results)" 50 | ) 51 | 52 | # Benchmark parameters 53 | parser.add_argument( 54 | "--test_size", type=float, default=0.2, 55 | help="Fraction of data to use for testing (default: 0.2)" 56 | ) 57 | parser.add_argument( 58 | "--num_test_problems", type=int, default=20, 59 | help="Number of test problems to benchmark (default: 20)" 60 | ) 61 | parser.add_argument( 62 | "--use_gpu", action="store_true", 63 | help="Use GPU if available" 64 | ) 65 | 66 | return parser.parse_args() 67 | 68 | def main(): 69 | """Run a benchmark test.""" 70 | # Parse command line arguments 71 | args = parse_args() 72 | 73 | # Set up directories 74 | data_dir = Path(args.data_dir) 75 | results_dir = Path(args.results_dir) 76 | 77 | # Set device 78 | device = torch.device('cuda' if torch.cuda.is_available() and args.use_gpu else 'cpu') 79 | print(f"Using device: {device}") 80 | 81 | # 1. Load QP problems 82 | qp_problems_file = data_dir / "qp_problems.npy" 83 | if qp_problems_file.exists(): 84 | print(f"Loading QP problems from {qp_problems_file}") 85 | qp_problems = QPGenerator.load(qp_problems_file) 86 | print(f"Loaded {len(qp_problems)} QP problems") 87 | else: 88 | print("Error: No QP problems found. Please run the demo first.") 89 | return 90 | 91 | # 2. 
Load dataset 92 | print("Loading dataset") 93 | dataset = QPDataset( 94 | qp_problems=qp_problems, 95 | precompute_solutions=True, 96 | feature_normalization=True, 97 | cache_dir=data_dir 98 | ) 99 | 100 | _, test_dataset = dataset.split(test_size=args.test_size) 101 | print(f"Loaded test dataset with {len(test_dataset)} problems") 102 | 103 | # 3. Create models with default parameters (don't load from files) 104 | print("Creating models") 105 | sample_item = test_dataset[0] 106 | input_dim = sample_item['features'].shape[0] 107 | num_constraints = sample_item['active_constraints'].shape[0] 108 | output_dim = sample_item['solution'].shape[0] 109 | 110 | cp_model = ConstraintPredictor( 111 | input_dim=input_dim, 112 | hidden_dim=128, 113 | num_constraints=num_constraints 114 | ) 115 | 116 | ws_model = WarmStartPredictor( 117 | input_dim=input_dim, 118 | hidden_dim=256, 119 | output_dim=output_dim 120 | ) 121 | 122 | # 4. Test on a small subset 123 | print("Testing on a subset of problems") 124 | solver = OSQPSolver() 125 | 126 | # List to store results 127 | baseline_times = [] 128 | warmstart_only_times = [] 129 | constraint_only_times = [] 130 | transformer_times = [] 131 | fallback_flags = [] 132 | 133 | # Test on a subset for visualization 134 | test_subset = np.random.choice(len(test_dataset), size=args.num_test_problems, replace=False) 135 | 136 | print(f"Benchmarking {args.num_test_problems} problems...") 137 | for idx in tqdm(test_subset): 138 | # Get problem 139 | sample = test_dataset[idx] 140 | problem = test_dataset.get_problem(idx) 141 | 142 | # Get features 143 | features = sample['features'] 144 | 145 | # For demonstration, we'll use the solutions directly instead of predictions 146 | # since we're using untrained models 147 | true_active = sample['active_constraints'].numpy() 148 | true_solution = sample['solution'].numpy() 149 | 150 | # Baseline (OSQP without transformers) 151 | _, baseline_time = solver.solve_with_time( 152 | Q=problem.Q, 153 | c=problem.c, 154 | A=problem.A, 155 | b=problem.b 156 | ) 157 | baseline_times.append(baseline_time) 158 | 159 | # Warm start only 160 | _, warmstart_time = solver.solve_with_time( 161 | Q=problem.Q, 162 | c=problem.c, 163 | A=problem.A, 164 | b=problem.b, 165 | warm_start=true_solution # Using true solution as warm start 166 | ) 167 | warmstart_only_times.append(warmstart_time) 168 | 169 | # Constraint only 170 | _, is_feasible, constraint_time = solver.solve_reduced_with_time( 171 | Q=problem.Q, 172 | c=problem.c, 173 | A=problem.A, 174 | b=problem.b, 175 | active_constraints=true_active # Using true active constraints 176 | ) 177 | constraint_only_times.append(constraint_time) 178 | 179 | # Full transformer pipeline (using true values for demo) 180 | _, transformer_time, used_fallback = solver.solve_pipeline( 181 | Q=problem.Q, 182 | c=problem.c, 183 | A=problem.A, 184 | b=problem.b, 185 | active_constraints=true_active, 186 | warm_start=true_solution, 187 | fallback_on_violation=True 188 | ) 189 | transformer_times.append(transformer_time) 190 | fallback_flags.append(used_fallback) 191 | 192 | # Convert to numpy arrays 193 | baseline_times = np.array(baseline_times) 194 | warmstart_only_times = np.array(warmstart_only_times) 195 | constraint_only_times = np.array(constraint_only_times) 196 | transformer_times = np.array(transformer_times) 197 | 198 | # Compute and print metrics 199 | solve_metrics = compute_solve_time_metrics(baseline_times, transformer_times) 200 | fallback_rate = compute_fallback_rate(fallback_flags) 201 | 202 
| print("\nSolve Time Metrics:") 203 | print(f"Mean baseline time: {solve_metrics['mean_baseline_time']:.6f}s") 204 | print(f"Mean transformer time: {solve_metrics['mean_transformer_time']:.6f}s") 205 | print(f"Mean speedup: {solve_metrics['mean_speedup']:.2f}x") 206 | print(f"Median speedup: {solve_metrics['median_speedup']:.2f}x") 207 | print(f"Fallback rate: {fallback_rate:.2f}%") 208 | 209 | # Create results directory if it doesn't exist 210 | results_dir.mkdir(parents=True, exist_ok=True) 211 | 212 | # Plot solve time comparison 213 | plot_solve_time_comparison( 214 | baseline_times=baseline_times, 215 | transformer_times=transformer_times, 216 | save_path=results_dir / "solve_time_comparison.png" 217 | ) 218 | 219 | # Plot boxplot 220 | plot_solve_time_boxplot( 221 | baseline_times=baseline_times, 222 | transformer_times=transformer_times, 223 | constraint_only_times=constraint_only_times, 224 | warmstart_only_times=warmstart_only_times, 225 | save_path=results_dir / "solve_time_boxplot.png" 226 | ) 227 | 228 | print(f"\nResults saved to {results_dir}") 229 | 230 | if __name__ == "__main__": 231 | main() -------------------------------------------------------------------------------- /transformermpc/utils/metrics.py: -------------------------------------------------------------------------------- 1 | """ 2 | Metrics utility module for TransformerMPC. 3 | 4 | This module provides functions for computing various metrics used to 5 | evaluate the performance of the transformer models and the overall 6 | solving pipeline. 7 | """ 8 | 9 | import numpy as np 10 | import torch 11 | from typing import Dict, List, Tuple, Union, Optional, Any 12 | from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score 13 | 14 | 15 | def compute_constraint_prediction_metrics(y_true: Union[np.ndarray, torch.Tensor], 16 | y_pred: Union[np.ndarray, torch.Tensor], 17 | threshold: float = 0.5) -> Dict[str, float]: 18 | """ 19 | Compute classification metrics for constraint prediction. 20 | 21 | Parameters: 22 | ----------- 23 | y_true : numpy.ndarray or torch.Tensor 24 | True binary labels 25 | y_pred : numpy.ndarray or torch.Tensor 26 | Predicted probability or scores 27 | threshold : float 28 | Threshold for converting probabilities to binary labels 29 | 30 | Returns: 31 | -------- 32 | metrics : Dict[str, float] 33 | Dictionary containing precision, recall, F1 score, and accuracy 34 | """ 35 | # Convert to numpy if tensors 36 | if isinstance(y_true, torch.Tensor): 37 | y_true = y_true.detach().cpu().numpy() 38 | if isinstance(y_pred, torch.Tensor): 39 | y_pred = y_pred.detach().cpu().numpy() 40 | 41 | # Convert predictions to binary labels using threshold 42 | y_pred_binary = (y_pred > threshold).astype(np.int32) 43 | 44 | # Compute metrics 45 | precision = precision_score(y_true, y_pred_binary, average='binary', zero_division=0) 46 | recall = recall_score(y_true, y_pred_binary, average='binary', zero_division=0) 47 | f1 = f1_score(y_true, y_pred_binary, average='binary', zero_division=0) 48 | accuracy = accuracy_score(y_true, y_pred_binary) 49 | 50 | return { 51 | 'precision': precision, 52 | 'recall': recall, 53 | 'f1': f1, 54 | 'accuracy': accuracy 55 | } 56 | 57 | 58 | def compute_warm_start_metrics(y_true: Union[np.ndarray, torch.Tensor], 59 | y_pred: Union[np.ndarray, torch.Tensor]) -> Dict[str, float]: 60 | """ 61 | Compute regression metrics for warm start prediction. 
62 | 63 | Parameters: 64 | ----------- 65 | y_true : numpy.ndarray or torch.Tensor 66 | True solution values 67 | y_pred : numpy.ndarray or torch.Tensor 68 | Predicted solution values 69 | 70 | Returns: 71 | -------- 72 | metrics : Dict[str, float] 73 | Dictionary containing MSE, MAE, and relative error 74 | """ 75 | # Convert to numpy if tensors 76 | if isinstance(y_true, torch.Tensor): 77 | y_true = y_true.detach().cpu().numpy() 78 | if isinstance(y_pred, torch.Tensor): 79 | y_pred = y_pred.detach().cpu().numpy() 80 | 81 | # Compute metrics 82 | mse = np.mean((y_true - y_pred) ** 2) 83 | mae = np.mean(np.abs(y_true - y_pred)) 84 | 85 | # Compute relative error (avoid division by zero) 86 | denom = np.linalg.norm(y_true, axis=1) 87 | denom = np.where(denom < 1e-8, 1.0, denom) 88 | rel_error = np.mean(np.linalg.norm(y_true - y_pred, axis=1) / denom) 89 | 90 | return { 91 | 'mse': mse, 92 | 'mae': mae, 93 | 'relative_error': rel_error 94 | } 95 | 96 | 97 | def compute_active_constraint_stats(active_constraints: List[np.ndarray]) -> Dict[str, float]: 98 | """ 99 | Compute statistics about active constraints. 100 | 101 | Parameters: 102 | ----------- 103 | active_constraints : List[numpy.ndarray] 104 | List of binary vectors indicating active constraints for each problem 105 | 106 | Returns: 107 | -------- 108 | stats : Dict[str, float] 109 | Dictionary containing statistics about active constraints 110 | """ 111 | # Calculate percentage of active constraints for each problem 112 | active_percentages = [np.mean(ac) * 100 for ac in active_constraints] 113 | 114 | # Compute statistics 115 | mean_percentage = np.mean(active_percentages) 116 | median_percentage = np.median(active_percentages) 117 | std_percentage = np.std(active_percentages) 118 | min_percentage = np.min(active_percentages) 119 | max_percentage = np.max(active_percentages) 120 | 121 | return { 122 | 'mean_active_percentage': mean_percentage, 123 | 'median_active_percentage': median_percentage, 124 | 'std_active_percentage': std_percentage, 125 | 'min_active_percentage': min_percentage, 126 | 'max_active_percentage': max_percentage, 127 | 'active_percentages': np.array(active_percentages) 128 | } 129 | 130 | 131 | def compute_solve_time_metrics(baseline_times: np.ndarray, 132 | transformer_times: np.ndarray) -> Dict[str, float]: 133 | """ 134 | Compute metrics comparing solve times. 
135 | 136 | Parameters: 137 | ----------- 138 | baseline_times : numpy.ndarray 139 | Array of solve times for baseline method 140 | transformer_times : numpy.ndarray 141 | Array of solve times for transformer-enhanced method 142 | 143 | Returns: 144 | -------- 145 | metrics : Dict[str, float] 146 | Dictionary containing solve time comparison metrics 147 | """ 148 | # Compute speedup ratios 149 | speedup_ratios = baseline_times / transformer_times 150 | 151 | # Compute statistics 152 | mean_speedup = np.mean(speedup_ratios) 153 | median_speedup = np.median(speedup_ratios) 154 | min_speedup = np.min(speedup_ratios) 155 | max_speedup = np.max(speedup_ratios) 156 | std_speedup = np.std(speedup_ratios) 157 | 158 | # Compute percentiles 159 | percentiles = np.percentile(speedup_ratios, [10, 25, 75, 90]) 160 | 161 | # Compute time statistics 162 | mean_baseline_time = np.mean(baseline_times) 163 | mean_transformer_time = np.mean(transformer_times) 164 | median_baseline_time = np.median(baseline_times) 165 | median_transformer_time = np.median(transformer_times) 166 | 167 | return { 168 | 'mean_speedup': mean_speedup, 169 | 'median_speedup': median_speedup, 170 | 'min_speedup': min_speedup, 171 | 'max_speedup': max_speedup, 172 | 'std_speedup': std_speedup, 173 | 'p10_speedup': percentiles[0], 174 | 'p25_speedup': percentiles[1], 175 | 'p75_speedup': percentiles[2], 176 | 'p90_speedup': percentiles[3], 177 | 'mean_baseline_time': mean_baseline_time, 178 | 'mean_transformer_time': mean_transformer_time, 179 | 'median_baseline_time': median_baseline_time, 180 | 'median_transformer_time': median_transformer_time, 181 | 'speedup_ratios': speedup_ratios 182 | } 183 | 184 | 185 | def compute_fallback_rate(fallback_flags: List[bool]) -> float: 186 | """ 187 | Compute the rate of fallbacks to full QP solve. 188 | 189 | Parameters: 190 | ----------- 191 | fallback_flags : List[bool] 192 | List of flags indicating whether fallback was used 193 | 194 | Returns: 195 | -------- 196 | fallback_rate : float 197 | Percentage of problems that required fallback 198 | """ 199 | return 100.0 * np.mean([1 if flag else 0 for flag in fallback_flags]) 200 | 201 | 202 | def compute_comprehensive_metrics( 203 | constraint_pred_true: np.ndarray, 204 | constraint_pred: np.ndarray, 205 | warmstart_true: np.ndarray, 206 | warmstart_pred: np.ndarray, 207 | baseline_times: np.ndarray, 208 | transformer_times: np.ndarray, 209 | fallback_flags: List[bool] 210 | ) -> Dict[str, Any]: 211 | """ 212 | Compute comprehensive set of metrics for the entire pipeline. 
213 | 214 | Parameters: 215 | ----------- 216 | constraint_pred_true : numpy.ndarray 217 | True binary labels for constraint prediction 218 | constraint_pred : numpy.ndarray 219 | Predicted probability or scores for constraint prediction 220 | warmstart_true : numpy.ndarray 221 | True solution values for warm start prediction 222 | warmstart_pred : numpy.ndarray 223 | Predicted solution values for warm start prediction 224 | baseline_times : numpy.ndarray 225 | Array of solve times for baseline method 226 | transformer_times : numpy.ndarray 227 | Array of solve times for transformer-enhanced method 228 | fallback_flags : List[bool] 229 | List of flags indicating whether fallback was used 230 | 231 | Returns: 232 | -------- 233 | metrics : Dict[str, Any] 234 | Dictionary containing all metrics 235 | """ 236 | # Compute individual metrics 237 | constraint_metrics = compute_constraint_prediction_metrics( 238 | constraint_pred_true, constraint_pred) 239 | 240 | warmstart_metrics = compute_warm_start_metrics( 241 | warmstart_true, warmstart_pred) 242 | 243 | solve_time_metrics = compute_solve_time_metrics( 244 | baseline_times, transformer_times) 245 | 246 | fallback_rate = compute_fallback_rate(fallback_flags) 247 | 248 | # Combine all metrics 249 | metrics = { 250 | 'constraint_prediction': constraint_metrics, 251 | 'warm_start_prediction': warmstart_metrics, 252 | 'solve_time': solve_time_metrics, 253 | 'fallback_rate': fallback_rate 254 | } 255 | 256 | return metrics 257 | -------------------------------------------------------------------------------- /transformermpc/utils/osqp_wrapper.py: -------------------------------------------------------------------------------- 1 | """ 2 | OSQP solver wrapper module. 3 | 4 | This module provides a wrapper for the OSQP solver to solve QP problems. 5 | """ 6 | 7 | import numpy as np 8 | import osqp 9 | import scipy.sparse as sparse 10 | import time 11 | from typing import Dict, Optional, Tuple, Union, List, Any 12 | 13 | class OSQPSolver: 14 | """ 15 | Wrapper class for the OSQP solver. 16 | 17 | This class provides a convenient interface to solve QP problems using OSQP. 18 | """ 19 | 20 | def __init__(self, 21 | verbose: bool = False, 22 | max_iter: int = 4000, 23 | eps_abs: float = 1e-6, 24 | eps_rel: float = 1e-6, 25 | polish: bool = True): 26 | """ 27 | Initialize the OSQP solver wrapper. 28 | 29 | Parameters: 30 | ----------- 31 | verbose : bool 32 | Whether to print solver output 33 | max_iter : int 34 | Maximum number of iterations 35 | eps_abs : float 36 | Absolute tolerance 37 | eps_rel : float 38 | Relative tolerance 39 | polish : bool 40 | Whether to polish the solution 41 | """ 42 | self.verbose = verbose 43 | self.max_iter = max_iter 44 | self.eps_abs = eps_abs 45 | self.eps_rel = eps_rel 46 | self.polish = polish 47 | 48 | def solve(self, 49 | Q: np.ndarray, 50 | c: np.ndarray, 51 | A: Optional[np.ndarray] = None, 52 | b: Optional[np.ndarray] = None, 53 | warm_start: Optional[np.ndarray] = None) -> np.ndarray: 54 | """ 55 | Solve a QP problem using OSQP. 
56 | 
57 |         Parameters:
58 |         -----------
59 |         Q : numpy.ndarray
60 |             Quadratic cost matrix (n x n)
61 |         c : numpy.ndarray
62 |             Linear cost vector (n)
63 |         A : numpy.ndarray or None
64 |             Constraint matrix (m x n)
65 |         b : numpy.ndarray or None
66 |             Constraint vector (m)
67 |         warm_start : numpy.ndarray or None
68 |             Warm start vector for the solver
69 | 
70 |         Returns:
71 |         --------
72 |         solution : numpy.ndarray
73 |             Optimal solution vector
74 |         """
75 |         # Convert to sparse matrices for OSQP
76 |         P = sparse.csc_matrix(Q)
77 |         q = c
78 | 
79 |         # Check if we have constraints
80 |         if A is not None and b is not None:
81 |             A_sparse = sparse.csc_matrix(A)
82 |             l = -np.inf * np.ones(A.shape[0])  # A*x <= b is encoded as l <= A*x <= u with l = -inf, u = b
83 |             u = b
84 |         else:
85 |             # No constraints
86 |             A_sparse = sparse.csc_matrix((0, Q.shape[0]))
87 |             l = np.array([])
88 |             u = np.array([])
89 | 
90 |         # Create the OSQP solver
91 |         solver = osqp.OSQP()
92 | 
93 |         # Setup the problem
94 |         solver.setup(P=P, q=q, A=A_sparse, l=l, u=u,
95 |                      verbose=self.verbose, max_iter=self.max_iter,
96 |                      eps_abs=self.eps_abs, eps_rel=self.eps_rel,
97 |                      polish=self.polish)
98 | 
99 |         # Set warm start if provided
100 |         if warm_start is not None:
101 |             solver.warm_start(x=warm_start)
102 | 
103 |         # Solve the problem
104 |         result = solver.solve()
105 | 
106 |         # Check if the solver was successful
107 |         if result.info.status != 'solved':
108 |             print(f"Warning: OSQP solver returned status {result.info.status}")
109 | 
110 |         # Return the solution
111 |         return result.x
112 | 
113 |     def solve_with_time(self,
114 |                         Q: np.ndarray,
115 |                         c: np.ndarray,
116 |                         A: Optional[np.ndarray] = None,
117 |                         b: Optional[np.ndarray] = None,
118 |                         warm_start: Optional[np.ndarray] = None) -> Tuple[np.ndarray, float]:
119 |         """
120 |         Solve a QP problem using OSQP and return solution time.
121 | 
122 |         Parameters:
123 |         -----------
124 |         Same as solve method.
125 | 
126 |         Returns:
127 |         --------
128 |         solution : numpy.ndarray
129 |             Optimal solution vector
130 |         solve_time : float
131 |             Solution time in seconds
132 |         """
133 |         # Measure the solution time
134 |         start_time = time.time()
135 |         solution = self.solve(Q, c, A, b, warm_start)
136 |         solve_time = time.time() - start_time
137 | 
138 |         return solution, solve_time
139 | 
140 |     def solve_reduced(self,
141 |                       Q: np.ndarray,
142 |                       c: np.ndarray,
143 |                       A: np.ndarray,
144 |                       b: np.ndarray,
145 |                       active_constraints: np.ndarray,
146 |                       warm_start: Optional[np.ndarray] = None) -> Tuple[np.ndarray, bool]:
147 |         """
148 |         Solve a reduced QP problem with only active constraints.
149 | 150 | Parameters: 151 | ----------- 152 | Q, c, A, b : Same as solve method 153 | active_constraints : numpy.ndarray 154 | Binary vector indicating which constraints are active 155 | warm_start : numpy.ndarray or None 156 | Warm start vector for the solver 157 | 158 | Returns: 159 | -------- 160 | solution : numpy.ndarray 161 | Optimal solution vector 162 | is_feasible : bool 163 | Whether the solution is feasible for the original problem 164 | """ 165 | # Get indices of active constraints 166 | active_indices = np.where(active_constraints > 0.5)[0] 167 | 168 | # Create reduced constraint matrices 169 | if len(active_indices) > 0: 170 | A_reduced = A[active_indices, :] 171 | b_reduced = b[active_indices] 172 | else: 173 | # No active constraints, solve unconstrained problem 174 | A_reduced = None 175 | b_reduced = None 176 | 177 | # Solve the reduced problem 178 | solution = self.solve(Q, c, A_reduced, b_reduced, warm_start) 179 | 180 | # Check if the solution satisfies the original constraints 181 | is_feasible = True 182 | if A is not None and b is not None: 183 | constraint_values = A @ solution - b 184 | is_feasible = np.all(constraint_values <= 1e-6) 185 | 186 | return solution, is_feasible 187 | 188 | def solve_reduced_with_time(self, 189 | Q: np.ndarray, 190 | c: np.ndarray, 191 | A: np.ndarray, 192 | b: np.ndarray, 193 | active_constraints: np.ndarray, 194 | warm_start: Optional[np.ndarray] = None) -> Tuple[np.ndarray, bool, float]: 195 | """ 196 | Solve a reduced QP problem with only active constraints and return solution time. 197 | 198 | Parameters: 199 | ----------- 200 | Same as solve_reduced method. 201 | 202 | Returns: 203 | -------- 204 | solution : numpy.ndarray 205 | Optimal solution vector 206 | is_feasible : bool 207 | Whether the solution is feasible for the original problem 208 | solve_time : float 209 | Solution time in seconds 210 | """ 211 | # Measure the solution time 212 | start_time = time.time() 213 | solution, is_feasible = self.solve_reduced(Q, c, A, b, active_constraints, warm_start) 214 | solve_time = time.time() - start_time 215 | 216 | return solution, is_feasible, solve_time 217 | 218 | def solve_pipeline(self, 219 | Q: np.ndarray, 220 | c: np.ndarray, 221 | A: np.ndarray, 222 | b: np.ndarray, 223 | active_constraints: np.ndarray, 224 | warm_start: Optional[np.ndarray] = None, 225 | fallback_on_violation: bool = True) -> Tuple[np.ndarray, float, bool]: 226 | """ 227 | Solve a QP problem using the transformer-enhanced pipeline. 228 | 229 | This method first tries to solve the reduced problem with active constraints. 230 | If the solution isn't feasible for the original problem, it falls back to the full problem. 
231 | 232 | Parameters: 233 | ----------- 234 | Q, c, A, b : Same as solve method 235 | active_constraints : numpy.ndarray 236 | Binary vector indicating which constraints are active 237 | warm_start : numpy.ndarray or None 238 | Warm start vector for the solver 239 | fallback_on_violation : bool 240 | Whether to fall back to the full problem if constraints are violated 241 | 242 | Returns: 243 | -------- 244 | solution : numpy.ndarray 245 | Optimal solution vector 246 | solve_time : float 247 | Solution time in seconds 248 | used_fallback : bool 249 | Whether the fallback solver was used 250 | """ 251 | # Start timing 252 | start_time = time.time() 253 | 254 | # Try to solve the reduced problem 255 | solution, is_feasible = self.solve_reduced(Q, c, A, b, active_constraints, warm_start) 256 | 257 | # If the solution is not feasible and fallback is enabled, solve the full problem 258 | used_fallback = False 259 | if not is_feasible and fallback_on_violation: 260 | solution = self.solve(Q, c, A, b, warm_start) 261 | used_fallback = True 262 | 263 | # Compute total solution time 264 | solve_time = time.time() - start_time 265 | 266 | return solution, solve_time, used_fallback 267 | -------------------------------------------------------------------------------- /scripts/boxplot_demo.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | TransformerMPC Boxplot Demo 4 | 5 | This script demonstrates the performance comparison of different QP solving strategies 6 | using randomly generated quadratic programming (QP) problems. 7 | 8 | It focuses on showing the boxplot comparison without actual training of transformer models. 9 | """ 10 | 11 | import os 12 | import time 13 | import argparse 14 | import numpy as np 15 | import matplotlib.pyplot as plt 16 | from pathlib import Path 17 | from tqdm import tqdm 18 | 19 | # Import TransformerMPC modules 20 | from transformermpc.data.qp_generator import QPGenerator 21 | from transformermpc.utils.osqp_wrapper import OSQPSolver 22 | from transformermpc.utils.metrics import compute_solve_time_metrics 23 | 24 | def parse_args(): 25 | """Parse command line arguments.""" 26 | parser = argparse.ArgumentParser(description="TransformerMPC Boxplot Demo") 27 | 28 | # Data generation parameters 29 | parser.add_argument("--num_samples", type=int, default=30, 30 | help="Number of QP problems to generate (default: 30)") 31 | parser.add_argument("--state_dim", type=int, default=4, 32 | help="State dimension for MPC problems (default: 4)") 33 | parser.add_argument("--input_dim", type=int, default=2, 34 | help="Input dimension for MPC problems (default: 2)") 35 | parser.add_argument("--horizon", type=int, default=5, 36 | help="Time horizon for MPC problems (default: 5)") 37 | 38 | # Other parameters 39 | parser.add_argument("--output_dir", type=str, default="demo_results", 40 | help="Directory to save results (default: demo_results)") 41 | 42 | return parser.parse_args() 43 | 44 | def main(): 45 | """Run the boxplot demo workflow.""" 46 | # Parse command line arguments 47 | args = parse_args() 48 | 49 | print("=" * 60) 50 | print("TransformerMPC Boxplot Demo".center(60)) 51 | print("=" * 60) 52 | 53 | # Create output directory 54 | output_dir = Path(args.output_dir) 55 | output_dir.mkdir(parents=True, exist_ok=True) 56 | 57 | # Set up directory for results 58 | results_dir = output_dir / "results" 59 | results_dir.mkdir(exist_ok=True) 60 | 61 | # Step 1: Generate QP problems 62 | print("\nStep 1: 
Generating QP problems") 63 | print("-" * 60) 64 | 65 | print(f"Generating {args.num_samples} QP problems") 66 | generator = QPGenerator( 67 | state_dim=args.state_dim, 68 | input_dim=args.input_dim, 69 | horizon=args.horizon, 70 | num_samples=args.num_samples 71 | ) 72 | qp_problems = generator.generate() 73 | print(f"Generated {len(qp_problems)} QP problems") 74 | 75 | # Step 2: Performance testing with different strategies 76 | print("\nStep 2: Performance testing") 77 | print("-" * 60) 78 | 79 | solver = OSQPSolver() 80 | 81 | # Lists to store results for different strategies 82 | baseline_times = [] # Standard OSQP 83 | reduced_constraint_times = [] # 50% random constraints removed 84 | warm_start_random_times = [] # Warm start with random initial point 85 | warm_start_perturbed_times = [] # Warm start with slightly perturbed solution 86 | combined_strategy_times = [] # Both constraints reduced and warm start 87 | 88 | print(f"Testing {len(qp_problems)} problems...") 89 | for i, problem in enumerate(tqdm(qp_problems)): 90 | # 1. Baseline (Standard OSQP) 91 | _, baseline_time = solver.solve_with_time( 92 | Q=problem.Q, 93 | c=problem.c, 94 | A=problem.A, 95 | b=problem.b 96 | ) 97 | baseline_times.append(baseline_time) 98 | 99 | # Get the solution for perturbed warm start 100 | solution = solver.solve( 101 | Q=problem.Q, 102 | c=problem.c, 103 | A=problem.A, 104 | b=problem.b 105 | ) 106 | 107 | # 2. Reduced constraints (randomly remove 50% of constraints) 108 | num_constraints = problem.A.shape[0] 109 | mask = np.random.choice([True, False], size=num_constraints, p=[0.5, 0.5]) 110 | 111 | A_reduced = problem.A[mask] 112 | b_reduced = problem.b[mask] 113 | 114 | _, reduced_time = solver.solve_with_time( 115 | Q=problem.Q, 116 | c=problem.c, 117 | A=A_reduced, 118 | b=b_reduced 119 | ) 120 | reduced_constraint_times.append(reduced_time) 121 | 122 | # 3. Warm start with random initial point 123 | warm_start_random = np.random.randn(problem.Q.shape[0]) 124 | _, warm_start_random_time = solver.solve_with_time( 125 | Q=problem.Q, 126 | c=problem.c, 127 | A=problem.A, 128 | b=problem.b, 129 | warm_start=warm_start_random 130 | ) 131 | warm_start_random_times.append(warm_start_random_time) 132 | 133 | # 4. Warm start with slightly perturbed solution (simulate good prediction) 134 | perturbation = np.random.randn(solution.shape[0]) * 0.1 # Small perturbation 135 | warm_start_perturbed = solution + perturbation 136 | 137 | _, warm_start_perturbed_time = solver.solve_with_time( 138 | Q=problem.Q, 139 | c=problem.c, 140 | A=problem.A, 141 | b=problem.b, 142 | warm_start=warm_start_perturbed 143 | ) 144 | warm_start_perturbed_times.append(warm_start_perturbed_time) 145 | 146 | # 5. 
Combined strategy (reduced constraints + warm start) 147 | _, combined_time = solver.solve_with_time( 148 | Q=problem.Q, 149 | c=problem.c, 150 | A=A_reduced, 151 | b=b_reduced, 152 | warm_start=warm_start_perturbed 153 | ) 154 | combined_strategy_times.append(combined_time) 155 | 156 | # Convert to numpy arrays 157 | baseline_times = np.array(baseline_times) 158 | reduced_constraint_times = np.array(reduced_constraint_times) 159 | warm_start_random_times = np.array(warm_start_random_times) 160 | warm_start_perturbed_times = np.array(warm_start_perturbed_times) 161 | combined_strategy_times = np.array(combined_strategy_times) 162 | 163 | # Compute and print metrics 164 | print("\nPerformance Results:") 165 | print("-" * 60) 166 | print(f"Mean baseline time: {np.mean(baseline_times):.6f}s") 167 | print(f"Mean reduced constraints time: {np.mean(reduced_constraint_times):.6f}s") 168 | print(f"Mean warm start (random) time: {np.mean(warm_start_random_times):.6f}s") 169 | print(f"Mean warm start (perturbed) time: {np.mean(warm_start_perturbed_times):.6f}s") 170 | print(f"Mean combined strategy time: {np.mean(combined_strategy_times):.6f}s") 171 | 172 | # Calculate speedups 173 | print("\nSpeedup Factors:") 174 | print(f"Reduced constraints: {np.mean(baseline_times) / np.mean(reduced_constraint_times):.2f}x") 175 | print(f"Warm start (random): {np.mean(baseline_times) / np.mean(warm_start_random_times):.2f}x") 176 | print(f"Warm start (perturbed): {np.mean(baseline_times) / np.mean(warm_start_perturbed_times):.2f}x") 177 | print(f"Combined strategy: {np.mean(baseline_times) / np.mean(combined_strategy_times):.2f}x") 178 | 179 | # Step 3: Generate visualizations 180 | print("\nStep 3: Generating visualizations") 181 | print("-" * 60) 182 | 183 | # Box plot of solve times 184 | print("Generating solve time boxplot...") 185 | plt.figure(figsize=(12, 8)) 186 | box_data = [ 187 | baseline_times, 188 | reduced_constraint_times, 189 | warm_start_random_times, 190 | warm_start_perturbed_times, 191 | combined_strategy_times 192 | ] 193 | 194 | box_labels = [ 195 | 'Baseline', 196 | 'Reduced\nConstraints', 197 | 'Warm Start\n(Random)', 198 | 'Warm Start\n(Perturbed)', 199 | 'Combined\nStrategy' 200 | ] 201 | 202 | box_plot = plt.boxplot( 203 | box_data, 204 | labels=box_labels, 205 | patch_artist=True, 206 | showmeans=True 207 | ) 208 | 209 | # Add colors to boxes 210 | colors = ['lightblue', 'lightgreen', 'lightpink', 'lightyellow', 'lightcyan'] 211 | for patch, color in zip(box_plot['boxes'], colors): 212 | patch.set_facecolor(color) 213 | 214 | plt.ylabel('Solve Time (s)') 215 | plt.title('QP Solve Time Comparison') 216 | plt.grid(True, axis='y', alpha=0.3) 217 | 218 | # Add mean value annotations 219 | means = [np.mean(data) for data in box_data] 220 | for i, mean in enumerate(means): 221 | plt.text(i+1, mean, f'{mean:.6f}s', 222 | horizontalalignment='center', 223 | verticalalignment='bottom', 224 | fontweight='bold') 225 | 226 | plt.tight_layout() 227 | plt.savefig(results_dir / "solve_time_boxplot.png", dpi=300) 228 | 229 | # Create speedup bar chart 230 | print("Generating speedup bar chart...") 231 | plt.figure(figsize=(10, 6)) 232 | 233 | speedups = [ 234 | np.mean(baseline_times) / np.mean(reduced_constraint_times), 235 | np.mean(baseline_times) / np.mean(warm_start_random_times), 236 | np.mean(baseline_times) / np.mean(warm_start_perturbed_times), 237 | np.mean(baseline_times) / np.mean(combined_strategy_times) 238 | ] 239 | 240 | speedup_labels = [ 241 | 'Reduced\nConstraints', 242 | 'Warm 
Start\n(Random)', 243 | 'Warm Start\n(Perturbed)', 244 | 'Combined\nStrategy' 245 | ] 246 | 247 | bars = plt.bar(speedup_labels, speedups, color=['lightgreen', 'lightpink', 'lightyellow', 'lightcyan']) 248 | 249 | plt.ylabel('Speedup Factor (×)') 250 | plt.title('Speedup Relative to Baseline') 251 | plt.grid(True, axis='y', alpha=0.3) 252 | 253 | # Add value annotations 254 | for bar, speedup in zip(bars, speedups): 255 | height = bar.get_height() 256 | plt.text(bar.get_x() + bar.get_width()/2., height + 0.1, 257 | f'{speedup:.2f}×', 258 | ha='center', va='bottom', fontweight='bold') 259 | 260 | plt.tight_layout() 261 | plt.savefig(results_dir / "speedup_barchart.png", dpi=300) 262 | 263 | # Violin plot of solve times 264 | print("Generating violin plot...") 265 | plt.figure(figsize=(12, 8)) 266 | 267 | violin_plot = plt.violinplot( 268 | box_data, 269 | showmeans=True, 270 | showextrema=True 271 | ) 272 | 273 | plt.xticks(range(1, 6), box_labels) 274 | plt.ylabel('Solve Time (s)') 275 | plt.title('QP Solve Time Distribution') 276 | plt.grid(True, axis='y', alpha=0.3) 277 | 278 | plt.tight_layout() 279 | plt.savefig(results_dir / "solve_time_violinplot.png", dpi=300) 280 | 281 | print(f"\nResults and visualizations saved to {output_dir}/results") 282 | print("\nDemo completed successfully!") 283 | print("=" * 60) 284 | 285 | if __name__ == "__main__": 286 | main() -------------------------------------------------------------------------------- /tests/scripts/boxplot_demo.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | TransformerMPC Boxplot Demo 4 | 5 | This script demonstrates the performance comparison of different QP solving strategies 6 | using randomly generated quadratic programming (QP) problems. 7 | 8 | It focuses on showing the boxplot comparison without actual training of transformer models. 
9 | """ 10 | 11 | import os 12 | import time 13 | import argparse 14 | import numpy as np 15 | import matplotlib.pyplot as plt 16 | from pathlib import Path 17 | from tqdm import tqdm 18 | 19 | # Import TransformerMPC modules 20 | from transformermpc.data.qp_generator import QPGenerator 21 | from transformermpc.utils.osqp_wrapper import OSQPSolver 22 | from transformermpc.utils.metrics import compute_solve_time_metrics 23 | 24 | def parse_args(): 25 | """Parse command line arguments.""" 26 | parser = argparse.ArgumentParser(description="TransformerMPC Boxplot Demo") 27 | 28 | # Data generation parameters 29 | parser.add_argument("--num_samples", type=int, default=30, 30 | help="Number of QP problems to generate (default: 30)") 31 | parser.add_argument("--state_dim", type=int, default=4, 32 | help="State dimension for MPC problems (default: 4)") 33 | parser.add_argument("--input_dim", type=int, default=2, 34 | help="Input dimension for MPC problems (default: 2)") 35 | parser.add_argument("--horizon", type=int, default=5, 36 | help="Time horizon for MPC problems (default: 5)") 37 | 38 | # Other parameters 39 | parser.add_argument("--output_dir", type=str, default="demo_results", 40 | help="Directory to save results (default: demo_results)") 41 | 42 | return parser.parse_args() 43 | 44 | def main(): 45 | """Run the boxplot demo workflow.""" 46 | # Parse command line arguments 47 | args = parse_args() 48 | 49 | print("=" * 60) 50 | print("TransformerMPC Boxplot Demo".center(60)) 51 | print("=" * 60) 52 | 53 | # Create output directory 54 | output_dir = Path(args.output_dir) 55 | output_dir.mkdir(parents=True, exist_ok=True) 56 | 57 | # Set up directory for results 58 | results_dir = output_dir / "results" 59 | results_dir.mkdir(exist_ok=True) 60 | 61 | # Step 1: Generate QP problems 62 | print("\nStep 1: Generating QP problems") 63 | print("-" * 60) 64 | 65 | print(f"Generating {args.num_samples} QP problems") 66 | generator = QPGenerator( 67 | state_dim=args.state_dim, 68 | input_dim=args.input_dim, 69 | horizon=args.horizon, 70 | num_samples=args.num_samples 71 | ) 72 | qp_problems = generator.generate() 73 | print(f"Generated {len(qp_problems)} QP problems") 74 | 75 | # Step 2: Performance testing with different strategies 76 | print("\nStep 2: Performance testing") 77 | print("-" * 60) 78 | 79 | solver = OSQPSolver() 80 | 81 | # Lists to store results for different strategies 82 | baseline_times = [] # Standard OSQP 83 | reduced_constraint_times = [] # 50% random constraints removed 84 | warm_start_random_times = [] # Warm start with random initial point 85 | warm_start_perturbed_times = [] # Warm start with slightly perturbed solution 86 | combined_strategy_times = [] # Both constraints reduced and warm start 87 | 88 | print(f"Testing {len(qp_problems)} problems...") 89 | for i, problem in enumerate(tqdm(qp_problems)): 90 | # 1. Baseline (Standard OSQP) 91 | _, baseline_time = solver.solve_with_time( 92 | Q=problem.Q, 93 | c=problem.c, 94 | A=problem.A, 95 | b=problem.b 96 | ) 97 | baseline_times.append(baseline_time) 98 | 99 | # Get the solution for perturbed warm start 100 | solution = solver.solve( 101 | Q=problem.Q, 102 | c=problem.c, 103 | A=problem.A, 104 | b=problem.b 105 | ) 106 | 107 | # 2. 
Reduced constraints (randomly remove 50% of constraints) 108 | num_constraints = problem.A.shape[0] 109 | mask = np.random.choice([True, False], size=num_constraints, p=[0.5, 0.5]) 110 | 111 | A_reduced = problem.A[mask] 112 | b_reduced = problem.b[mask] 113 | 114 | _, reduced_time = solver.solve_with_time( 115 | Q=problem.Q, 116 | c=problem.c, 117 | A=A_reduced, 118 | b=b_reduced 119 | ) 120 | reduced_constraint_times.append(reduced_time) 121 | 122 | # 3. Warm start with random initial point 123 | warm_start_random = np.random.randn(problem.Q.shape[0]) 124 | _, warm_start_random_time = solver.solve_with_time( 125 | Q=problem.Q, 126 | c=problem.c, 127 | A=problem.A, 128 | b=problem.b, 129 | warm_start=warm_start_random 130 | ) 131 | warm_start_random_times.append(warm_start_random_time) 132 | 133 | # 4. Warm start with slightly perturbed solution (simulate good prediction) 134 | perturbation = np.random.randn(solution.shape[0]) * 0.1 # Small perturbation 135 | warm_start_perturbed = solution + perturbation 136 | 137 | _, warm_start_perturbed_time = solver.solve_with_time( 138 | Q=problem.Q, 139 | c=problem.c, 140 | A=problem.A, 141 | b=problem.b, 142 | warm_start=warm_start_perturbed 143 | ) 144 | warm_start_perturbed_times.append(warm_start_perturbed_time) 145 | 146 | # 5. Combined strategy (reduced constraints + warm start) 147 | _, combined_time = solver.solve_with_time( 148 | Q=problem.Q, 149 | c=problem.c, 150 | A=A_reduced, 151 | b=b_reduced, 152 | warm_start=warm_start_perturbed 153 | ) 154 | combined_strategy_times.append(combined_time) 155 | 156 | # Convert to numpy arrays 157 | baseline_times = np.array(baseline_times) 158 | reduced_constraint_times = np.array(reduced_constraint_times) 159 | warm_start_random_times = np.array(warm_start_random_times) 160 | warm_start_perturbed_times = np.array(warm_start_perturbed_times) 161 | combined_strategy_times = np.array(combined_strategy_times) 162 | 163 | # Compute and print metrics 164 | print("\nPerformance Results:") 165 | print("-" * 60) 166 | print(f"Mean baseline time: {np.mean(baseline_times):.6f}s") 167 | print(f"Mean reduced constraints time: {np.mean(reduced_constraint_times):.6f}s") 168 | print(f"Mean warm start (random) time: {np.mean(warm_start_random_times):.6f}s") 169 | print(f"Mean warm start (perturbed) time: {np.mean(warm_start_perturbed_times):.6f}s") 170 | print(f"Mean combined strategy time: {np.mean(combined_strategy_times):.6f}s") 171 | 172 | # Calculate speedups 173 | print("\nSpeedup Factors:") 174 | print(f"Reduced constraints: {np.mean(baseline_times) / np.mean(reduced_constraint_times):.2f}x") 175 | print(f"Warm start (random): {np.mean(baseline_times) / np.mean(warm_start_random_times):.2f}x") 176 | print(f"Warm start (perturbed): {np.mean(baseline_times) / np.mean(warm_start_perturbed_times):.2f}x") 177 | print(f"Combined strategy: {np.mean(baseline_times) / np.mean(combined_strategy_times):.2f}x") 178 | 179 | # Step 3: Generate visualizations 180 | print("\nStep 3: Generating visualizations") 181 | print("-" * 60) 182 | 183 | # Box plot of solve times 184 | print("Generating solve time boxplot...") 185 | plt.figure(figsize=(12, 8)) 186 | box_data = [ 187 | baseline_times, 188 | reduced_constraint_times, 189 | warm_start_random_times, 190 | warm_start_perturbed_times, 191 | combined_strategy_times 192 | ] 193 | 194 | box_labels = [ 195 | 'Baseline', 196 | 'Reduced\nConstraints', 197 | 'Warm Start\n(Random)', 198 | 'Warm Start\n(Perturbed)', 199 | 'Combined\nStrategy' 200 | ] 201 | 202 | box_plot = 
plt.boxplot( 203 | box_data, 204 | labels=box_labels, 205 | patch_artist=True, 206 | showmeans=True 207 | ) 208 | 209 | # Add colors to boxes 210 | colors = ['lightblue', 'lightgreen', 'lightpink', 'lightyellow', 'lightcyan'] 211 | for patch, color in zip(box_plot['boxes'], colors): 212 | patch.set_facecolor(color) 213 | 214 | plt.ylabel('Solve Time (s)') 215 | plt.title('QP Solve Time Comparison') 216 | plt.grid(True, axis='y', alpha=0.3) 217 | 218 | # Add mean value annotations 219 | means = [np.mean(data) for data in box_data] 220 | for i, mean in enumerate(means): 221 | plt.text(i+1, mean, f'{mean:.6f}s', 222 | horizontalalignment='center', 223 | verticalalignment='bottom', 224 | fontweight='bold') 225 | 226 | plt.tight_layout() 227 | plt.savefig(results_dir / "solve_time_boxplot.png", dpi=300) 228 | 229 | # Create speedup bar chart 230 | print("Generating speedup bar chart...") 231 | plt.figure(figsize=(10, 6)) 232 | 233 | speedups = [ 234 | np.mean(baseline_times) / np.mean(reduced_constraint_times), 235 | np.mean(baseline_times) / np.mean(warm_start_random_times), 236 | np.mean(baseline_times) / np.mean(warm_start_perturbed_times), 237 | np.mean(baseline_times) / np.mean(combined_strategy_times) 238 | ] 239 | 240 | speedup_labels = [ 241 | 'Reduced\nConstraints', 242 | 'Warm Start\n(Random)', 243 | 'Warm Start\n(Perturbed)', 244 | 'Combined\nStrategy' 245 | ] 246 | 247 | bars = plt.bar(speedup_labels, speedups, color=['lightgreen', 'lightpink', 'lightyellow', 'lightcyan']) 248 | 249 | plt.ylabel('Speedup Factor (×)') 250 | plt.title('Speedup Relative to Baseline') 251 | plt.grid(True, axis='y', alpha=0.3) 252 | 253 | # Add value annotations 254 | for bar, speedup in zip(bars, speedups): 255 | height = bar.get_height() 256 | plt.text(bar.get_x() + bar.get_width()/2., height + 0.1, 257 | f'{speedup:.2f}×', 258 | ha='center', va='bottom', fontweight='bold') 259 | 260 | plt.tight_layout() 261 | plt.savefig(results_dir / "speedup_barchart.png", dpi=300) 262 | 263 | # Violin plot of solve times 264 | print("Generating violin plot...") 265 | plt.figure(figsize=(12, 8)) 266 | 267 | violin_plot = plt.violinplot( 268 | box_data, 269 | showmeans=True, 270 | showextrema=True 271 | ) 272 | 273 | plt.xticks(range(1, 6), box_labels) 274 | plt.ylabel('Solve Time (s)') 275 | plt.title('QP Solve Time Distribution') 276 | plt.grid(True, axis='y', alpha=0.3) 277 | 278 | plt.tight_layout() 279 | plt.savefig(results_dir / "solve_time_violinplot.png", dpi=300) 280 | 281 | print(f"\nResults and visualizations saved to {output_dir}/results") 282 | print("\nDemo completed successfully!") 283 | print("=" * 60) 284 | 285 | if __name__ == "__main__": 286 | main() -------------------------------------------------------------------------------- /transformermpc/demo/demo.py: -------------------------------------------------------------------------------- 1 | """ 2 | Demo script for TransformerMPC. 3 | 4 | This script demonstrates the complete workflow of: 5 | 1. Generating QP problems 6 | 2. Training the constraint predictor and warm start predictor models 7 | 3. Testing the models and comparing performance against baseline 8 | 4. 
Visualizing the results 9 | """ 10 | 11 | import os 12 | import torch 13 | import numpy as np 14 | import time 15 | from tqdm import tqdm 16 | import matplotlib.pyplot as plt 17 | import argparse 18 | from pathlib import Path 19 | 20 | from ..data.qp_generator import QPGenerator 21 | from ..data.dataset import QPDataset 22 | from ..models.constraint_predictor import ConstraintPredictor 23 | from ..models.warm_start_predictor import WarmStartPredictor 24 | from ..training.trainer import ModelTrainer 25 | from ..utils.osqp_wrapper import OSQPSolver 26 | from ..utils.metrics import compute_solve_time_metrics, compute_fallback_rate 27 | from ..utils.visualization import ( 28 | plot_solve_time_comparison, 29 | plot_solve_time_boxplot, 30 | plot_active_constraints_histogram, 31 | plot_fallback_statistics 32 | ) 33 | 34 | 35 | def parse_args(): 36 | """Parse command line arguments.""" 37 | parser = argparse.ArgumentParser(description="TransformerMPC Demo") 38 | 39 | parser.add_argument( 40 | "--num_samples", type=int, default=20000, 41 | help="Number of QP problems to generate (default: 20000)" 42 | ) 43 | parser.add_argument( 44 | "--state_dim", type=int, default=4, 45 | help="State dimension for MPC problems (default: 4)" 46 | ) 47 | parser.add_argument( 48 | "--input_dim", type=int, default=2, 49 | help="Input dimension for MPC problems (default: 2)" 50 | ) 51 | parser.add_argument( 52 | "--horizon", type=int, default=10, 53 | help="Time horizon for MPC problems (default: 10)" 54 | ) 55 | parser.add_argument( 56 | "--num_epochs", type=int, default=2000, 57 | help="Number of training epochs (default: 2000)" 58 | ) 59 | parser.add_argument( 60 | "--batch_size", type=int, default=64, 61 | help="Batch size for training (default: 64)" 62 | ) 63 | parser.add_argument( 64 | "--test_size", type=float, default=0.2, 65 | help="Fraction of data to use for testing (default: 0.2)" 66 | ) 67 | parser.add_argument( 68 | "--output_dir", type=str, default="transformermpc_results", 69 | help="Directory to save results (default: transformermpc_results)" 70 | ) 71 | parser.add_argument( 72 | "--skip_training", action="store_true", 73 | help="Skip training and use pretrained models if available" 74 | ) 75 | parser.add_argument( 76 | "--cpu", action="store_true", 77 | help="Force using CPU even if GPU is available" 78 | ) 79 | 80 | return parser.parse_args() 81 | 82 | 83 | def main(): 84 | """Run the complete demo workflow.""" 85 | # Parse arguments 86 | args = parse_args() 87 | 88 | # Create output directory 89 | output_dir = Path(args.output_dir) 90 | output_dir.mkdir(parents=True, exist_ok=True) 91 | 92 | # Set up directories 93 | data_dir = output_dir / "data" 94 | models_dir = output_dir / "models" 95 | logs_dir = output_dir / "logs" 96 | results_dir = output_dir / "results" 97 | 98 | data_dir.mkdir(exist_ok=True) 99 | models_dir.mkdir(exist_ok=True) 100 | logs_dir.mkdir(exist_ok=True) 101 | results_dir.mkdir(exist_ok=True) 102 | 103 | # Set device 104 | device = torch.device('cuda' if torch.cuda.is_available() and not args.cpu else 'cpu') 105 | print(f"Using device: {device}") 106 | 107 | # 1. 
Generate QP problems or load from cache 108 | qp_problems_file = data_dir / "qp_problems.npy" 109 | 110 | if qp_problems_file.exists(): 111 | print(f"Loading QP problems from {qp_problems_file}") 112 | qp_problems = QPGenerator.load(qp_problems_file) 113 | print(f"Loaded {len(qp_problems)} QP problems") 114 | else: 115 | print(f"Generating {args.num_samples} QP problems") 116 | generator = QPGenerator( 117 | state_dim=args.state_dim, 118 | input_dim=args.input_dim, 119 | horizon=args.horizon, 120 | num_samples=args.num_samples 121 | ) 122 | qp_problems = generator.generate() 123 | 124 | # Save problems for future use 125 | generator.save(qp_problems, qp_problems_file) 126 | print(f"Saved QP problems to {qp_problems_file}") 127 | 128 | # 2. Create dataset and split into train/test 129 | print("Creating dataset") 130 | dataset = QPDataset( 131 | qp_problems=qp_problems, 132 | precompute_solutions=True, 133 | feature_normalization=True, 134 | cache_dir=data_dir 135 | ) 136 | 137 | train_dataset, test_dataset = dataset.split(test_size=args.test_size) 138 | print(f"Created datasets - Train: {len(train_dataset)}, Test: {len(test_dataset)}") 139 | 140 | # 3. Train or load constraint predictor 141 | cp_model_file = models_dir / "constraint_predictor.pt" 142 | 143 | if args.skip_training and cp_model_file.exists(): 144 | print(f"Loading constraint predictor from {cp_model_file}") 145 | cp_model = ConstraintPredictor.load(cp_model_file) 146 | else: 147 | print("Training constraint predictor") 148 | # Get input dimension from the dataset 149 | sample_item = train_dataset[0] 150 | input_dim = sample_item['features'].shape[0] 151 | num_constraints = sample_item['active_constraints'].shape[0] 152 | 153 | cp_model = ConstraintPredictor( 154 | input_dim=input_dim, 155 | hidden_dim=128, 156 | num_constraints=num_constraints 157 | ) 158 | 159 | cp_trainer = ModelTrainer( 160 | model=cp_model, 161 | train_dataset=train_dataset, 162 | val_dataset=test_dataset, 163 | batch_size=args.batch_size, 164 | learning_rate=1e-4, 165 | num_epochs=args.num_epochs, 166 | checkpoint_dir=models_dir, 167 | device=device 168 | ) 169 | 170 | cp_history = cp_trainer.train(log_dir=logs_dir / "constraint_predictor") 171 | 172 | # Save model 173 | cp_model.save(cp_model_file) 174 | print(f"Saved constraint predictor to {cp_model_file}") 175 | 176 | # 4. Train or load warm start predictor 177 | ws_model_file = models_dir / "warm_start_predictor.pt" 178 | 179 | if args.skip_training and ws_model_file.exists(): 180 | print(f"Loading warm start predictor from {ws_model_file}") 181 | ws_model = WarmStartPredictor.load(ws_model_file) 182 | else: 183 | print("Training warm start predictor") 184 | # Get input dimension from the dataset 185 | sample_item = train_dataset[0] 186 | input_dim = sample_item['features'].shape[0] 187 | output_dim = sample_item['solution'].shape[0] 188 | 189 | ws_model = WarmStartPredictor( 190 | input_dim=input_dim, 191 | hidden_dim=256, 192 | output_dim=output_dim 193 | ) 194 | 195 | ws_trainer = ModelTrainer( 196 | model=ws_model, 197 | train_dataset=train_dataset, 198 | val_dataset=test_dataset, 199 | batch_size=args.batch_size, 200 | learning_rate=1e-4, 201 | num_epochs=args.num_epochs, 202 | checkpoint_dir=models_dir, 203 | device=device 204 | ) 205 | 206 | ws_history = ws_trainer.train(log_dir=logs_dir / "warm_start_predictor") 207 | 208 | # Save model 209 | ws_model.save(ws_model_file) 210 | print(f"Saved warm start predictor to {ws_model_file}") 211 | 212 | # 5. 
Benchmark against baseline 213 | print("Benchmarking against baseline") 214 | solver = OSQPSolver() 215 | 216 | # List to store results 217 | baseline_times = [] 218 | warmstart_only_times = [] 219 | constraint_only_times = [] 220 | transformer_times = [] 221 | fallback_flags = [] 222 | 223 | # Test on a subset (100 problems) for visualization 224 | test_subset = np.random.choice(len(test_dataset), size=100, replace=False) 225 | 226 | print("Testing on 100 problems from the test set") 227 | for idx in tqdm(test_subset): 228 | # Get problem 229 | sample = test_dataset[idx] 230 | problem = test_dataset.get_problem(idx) 231 | 232 | # Get features 233 | features = sample['features'] 234 | 235 | # Get ground truth 236 | true_active = sample['active_constraints'].numpy() 237 | true_solution = sample['solution'].numpy() 238 | 239 | # Predict active constraints 240 | pred_active = cp_model.predict(features)[0] 241 | 242 | # Predict warm start 243 | pred_solution = ws_model.predict(features)[0] 244 | 245 | # Baseline (OSQP without transformers) 246 | _, baseline_time = solver.solve_with_time( 247 | Q=problem.Q, 248 | c=problem.c, 249 | A=problem.A, 250 | b=problem.b 251 | ) 252 | baseline_times.append(baseline_time) 253 | 254 | # Warm start only 255 | _, warmstart_time = solver.solve_with_time( 256 | Q=problem.Q, 257 | c=problem.c, 258 | A=problem.A, 259 | b=problem.b, 260 | warm_start=pred_solution 261 | ) 262 | warmstart_only_times.append(warmstart_time) 263 | 264 | # Constraint only 265 | _, is_feasible, constraint_time = solver.solve_reduced_with_time( 266 | Q=problem.Q, 267 | c=problem.c, 268 | A=problem.A, 269 | b=problem.b, 270 | active_constraints=pred_active 271 | ) 272 | constraint_only_times.append(constraint_time) 273 | 274 | # Full transformer pipeline 275 | _, transformer_time, used_fallback = solver.solve_pipeline( 276 | Q=problem.Q, 277 | c=problem.c, 278 | A=problem.A, 279 | b=problem.b, 280 | active_constraints=pred_active, 281 | warm_start=pred_solution, 282 | fallback_on_violation=True 283 | ) 284 | transformer_times.append(transformer_time) 285 | fallback_flags.append(used_fallback) 286 | 287 | # Convert to numpy arrays 288 | baseline_times = np.array(baseline_times) 289 | warmstart_only_times = np.array(warmstart_only_times) 290 | constraint_only_times = np.array(constraint_only_times) 291 | transformer_times = np.array(transformer_times) 292 | 293 | # 6. Compute and print metrics 294 | solve_metrics = compute_solve_time_metrics(baseline_times, transformer_times) 295 | fallback_rate = compute_fallback_rate(fallback_flags) 296 | 297 | print("\nSolve Time Metrics:") 298 | print(f"Mean baseline time: {solve_metrics['mean_baseline_time']:.6f}s") 299 | print(f"Mean transformer time: {solve_metrics['mean_transformer_time']:.6f}s") 300 | print(f"Mean speedup: {solve_metrics['mean_speedup']:.2f}x") 301 | print(f"Median speedup: {solve_metrics['median_speedup']:.2f}x") 302 | print(f"Fallback rate: {fallback_rate:.2f}%") 303 | 304 | # 7. 
Plot results 305 | print("\nPlotting results") 306 | 307 | # Plot solve time comparison 308 | plot_solve_time_comparison( 309 | baseline_times=baseline_times, 310 | transformer_times=transformer_times, 311 | save_path=results_dir / "solve_time_comparison.png" 312 | ) 313 | 314 | # Plot boxplot 315 | plot_solve_time_boxplot( 316 | baseline_times=baseline_times, 317 | transformer_times=transformer_times, 318 | constraint_only_times=constraint_only_times, 319 | warmstart_only_times=warmstart_only_times, 320 | save_path=results_dir / "solve_time_boxplot.png" 321 | ) 322 | 323 | # Plot fallback statistics 324 | fallback_rates = { 325 | "Transformer-MPC": fallback_rate 326 | } 327 | plot_fallback_statistics( 328 | fallback_rates=fallback_rates, 329 | save_path=results_dir / "fallback_rates.png" 330 | ) 331 | 332 | print(f"\nResults saved to {results_dir}") 333 | 334 | # Save main result plot for README 335 | plot_solve_time_comparison( 336 | baseline_times=baseline_times, 337 | transformer_times=transformer_times, 338 | save_path=output_dir.parent / "benchmarking_results.png" 339 | ) 340 | print(f"Main benchmarking result saved to {output_dir.parent / 'benchmarking_results.png'}") 341 | 342 | 343 | if __name__ == "__main__": 344 | main() 345 | -------------------------------------------------------------------------------- /transformermpc/data/qp_generator.py: -------------------------------------------------------------------------------- 1 | """ 2 | QP problem generator module. 3 | 4 | This module provides classes for generating and representing QP problems 5 | for training and testing the transformer models. 6 | """ 7 | 8 | import numpy as np 9 | import pandas as pd 10 | import os 11 | import time 12 | from dataclasses import dataclass 13 | from typing import List, Tuple, Dict, Optional, Union 14 | 15 | 16 | @dataclass 17 | class QPProblem: 18 | """ 19 | A class representing a Quadratic Programming problem. 20 | 21 | A QP problem is defined as: 22 | minimize 0.5 * x^T Q x + c^T x 23 | subject to A x <= b 24 | 25 | Attributes: 26 | ----------- 27 | Q : numpy.ndarray 28 | Quadratic cost matrix (n x n) 29 | c : numpy.ndarray 30 | Linear cost vector (n) 31 | A : numpy.ndarray 32 | Constraint matrix (m x n) 33 | b : numpy.ndarray 34 | Constraint vector (m) 35 | initial_state : numpy.ndarray, optional 36 | Initial state for MPC problems 37 | reference : numpy.ndarray, optional 38 | Reference trajectory for MPC problems 39 | metadata : dict, optional 40 | Additional problem-specific information 41 | """ 42 | Q: np.ndarray 43 | c: np.ndarray 44 | A: np.ndarray 45 | b: np.ndarray 46 | initial_state: Optional[np.ndarray] = None 47 | reference: Optional[np.ndarray] = None 48 | metadata: Optional[Dict] = None 49 | 50 | @property 51 | def n_variables(self) -> int: 52 | """Number of decision variables.""" 53 | return self.Q.shape[0] 54 | 55 | @property 56 | def n_constraints(self) -> int: 57 | """Number of constraints.""" 58 | return self.A.shape[0] 59 | 60 | def to_dict(self) -> Dict: 61 | """Convert to dictionary representation.""" 62 | return { 63 | "Q": self.Q, 64 | "c": self.c, 65 | "A": self.A, 66 | "b": self.b, 67 | "initial_state": self.initial_state, 68 | "reference": self.reference, 69 | "metadata": self.metadata 70 | } 71 | 72 | @classmethod 73 | def from_dict(cls, data: Dict) -> 'QPProblem': 74 | """Create QPProblem from dictionary.""" 75 | return cls(**data) 76 | 77 | 78 | class QPGenerator: 79 | """ 80 | Generator class for creating QP problem instances. 
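    A minimal usage sketch (parameter values are illustrative, not tied to
    any particular experiment):

        >>> generator = QPGenerator(state_dim=4, input_dim=2, horizon=10,
        ...                         num_samples=100, seed=0)
        >>> problems = generator.generate()
        >>> len(problems), problems[0].n_variables
        (100, 20)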
81 | 82 | This class generates synthetic QP problems for training and testing 83 | the transformer models. 84 | """ 85 | 86 | def __init__(self, 87 | state_dim: int = 4, 88 | input_dim: int = 2, 89 | horizon: int = 10, 90 | num_samples: int = 20000, 91 | seed: Optional[int] = None): 92 | """ 93 | Initialize the QP generator. 94 | 95 | Parameters: 96 | ----------- 97 | state_dim : int 98 | Dimension of state space for MPC problems 99 | input_dim : int 100 | Dimension of input space for MPC problems 101 | horizon : int 102 | Time horizon for MPC problems 103 | num_samples : int 104 | Number of QP problems to generate 105 | seed : int or None 106 | Random seed for reproducibility 107 | """ 108 | self.state_dim = state_dim 109 | self.input_dim = input_dim 110 | self.horizon = horizon 111 | self.num_samples = num_samples 112 | 113 | # Set random seed if provided 114 | if seed is not None: 115 | np.random.seed(seed) 116 | 117 | def generate(self) -> List[QPProblem]: 118 | """ 119 | Generate QP problems. 120 | 121 | Returns: 122 | -------- 123 | problems : List[QPProblem] 124 | List of generated QP problems 125 | """ 126 | problems = [] 127 | 128 | for i in range(self.num_samples): 129 | # Generate random system matrices for an MPC problem 130 | A_dyn, B_dyn = self._generate_dynamics() 131 | 132 | # Generate cost matrices 133 | Q_state = self._generate_state_cost() 134 | R_input = self._generate_input_cost() 135 | 136 | # Generate constraints 137 | state_constraints = self._generate_state_constraints() 138 | input_constraints = self._generate_input_constraints() 139 | 140 | # Generate random initial state and reference 141 | initial_state = np.random.randn(self.state_dim) 142 | reference = np.random.randn(self.horizon * self.state_dim) 143 | 144 | # Create the QP matrices for the MPC problem 145 | Q, c, A, b = self._create_mpc_matrices( 146 | A_dyn, B_dyn, Q_state, R_input, 147 | state_constraints, input_constraints, 148 | initial_state, reference 149 | ) 150 | 151 | # Create the QPProblem instance 152 | problem = QPProblem( 153 | Q=Q, 154 | c=c, 155 | A=A, 156 | b=b, 157 | initial_state=initial_state, 158 | reference=reference, 159 | metadata={ 160 | "type": "mpc", 161 | "state_dim": self.state_dim, 162 | "input_dim": self.input_dim, 163 | "horizon": self.horizon, 164 | "A_dynamics": A_dyn, 165 | "B_dynamics": B_dyn 166 | } 167 | ) 168 | 169 | problems.append(problem) 170 | 171 | return problems 172 | 173 | def _generate_dynamics(self) -> Tuple[np.ndarray, np.ndarray]: 174 | """ 175 | Generate random discrete-time system dynamics matrices. 176 | 177 | Returns: 178 | -------- 179 | A : numpy.ndarray 180 | State transition matrix (state_dim x state_dim) 181 | B : numpy.ndarray 182 | Input matrix (state_dim x input_dim) 183 | """ 184 | # Generate a random discrete-time system 185 | A = np.random.randn(self.state_dim, self.state_dim) 186 | # Scale to make it stable 187 | eigenvalues, _ = np.linalg.eig(A) 188 | max_eig = np.max(np.abs(eigenvalues)) 189 | A = A / (max_eig * 1.1) # Scale to ensure stability 190 | 191 | B = np.random.randn(self.state_dim, self.input_dim) 192 | 193 | return A, B 194 | 195 | def _generate_state_cost(self) -> np.ndarray: 196 | """ 197 | Generate state cost matrix. 
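    For example, with state_dim=3 this returns diag(|z_1|, |z_2|, |z_3|)
    with z_i ~ N(0, 1); the output below is illustrative and depends on the
    global NumPy seed:

        >>> np.random.seed(0)
        >>> Q = QPGenerator(state_dim=3)._generate_state_cost()
        >>> np.round(np.diag(Q), 2)
        array([1.76, 0.4 , 0.98])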
198 | 199 | Returns: 200 | -------- 201 | Q : numpy.ndarray 202 | State cost matrix (state_dim x state_dim) 203 | """ 204 | Q_diag = np.abs(np.random.randn(self.state_dim)) 205 | Q = np.diag(Q_diag) 206 | return Q 207 | 208 | def _generate_input_cost(self) -> np.ndarray: 209 | """ 210 | Generate input cost matrix. 211 | 212 | Returns: 213 | -------- 214 | R : numpy.ndarray 215 | Input cost matrix (input_dim x input_dim) 216 | """ 217 | R_diag = np.abs(np.random.randn(self.input_dim)) 218 | R = np.diag(R_diag) 219 | return R 220 | 221 | def _generate_state_constraints(self) -> Tuple[np.ndarray, np.ndarray]: 222 | """ 223 | Generate state constraints. 224 | 225 | Returns: 226 | -------- 227 | A_state : numpy.ndarray 228 | State constraint matrix (2*state_dim x state_dim) 229 | b_state : numpy.ndarray 230 | State constraint vector (2*state_dim) 231 | """ 232 | # Simple box constraints on states 233 | A_state = np.vstack([np.eye(self.state_dim), -np.eye(self.state_dim)]) 234 | 235 | # Random upper and lower bounds 236 | upper_bounds = np.abs(np.random.rand(self.state_dim) * 5 + 5) # Random bounds between 5 and 10 237 | lower_bounds = np.abs(np.random.rand(self.state_dim) * 5 + 5) # Random bounds between 5 and 10 238 | 239 | b_state = np.concatenate([upper_bounds, lower_bounds]) 240 | 241 | return A_state, b_state 242 | 243 | def _generate_input_constraints(self) -> Tuple[np.ndarray, np.ndarray]: 244 | """ 245 | Generate input constraints. 246 | 247 | Returns: 248 | -------- 249 | A_input : numpy.ndarray 250 | Input constraint matrix (2*input_dim x input_dim) 251 | b_input : numpy.ndarray 252 | Input constraint vector (2*input_dim) 253 | """ 254 | # Simple box constraints on inputs 255 | A_input = np.vstack([np.eye(self.input_dim), -np.eye(self.input_dim)]) 256 | 257 | # Random upper and lower bounds 258 | upper_bounds = np.abs(np.random.rand(self.input_dim) * 2 + 1) # Random bounds between 1 and 3 259 | lower_bounds = np.abs(np.random.rand(self.input_dim) * 2 + 1) # Random bounds between 1 and 3 260 | 261 | b_input = np.concatenate([upper_bounds, lower_bounds]) 262 | 263 | return A_input, b_input 264 | 265 | def _create_mpc_matrices(self, 266 | A_dyn: np.ndarray, 267 | B_dyn: np.ndarray, 268 | Q_state: np.ndarray, 269 | R_input: np.ndarray, 270 | state_constraints: Tuple[np.ndarray, np.ndarray], 271 | input_constraints: Tuple[np.ndarray, np.ndarray], 272 | initial_state: np.ndarray, 273 | reference: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: 274 | """ 275 | Create QP matrices for an MPC problem. 
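    In this simplified condensed form the decision vector stacks only the
    inputs, u = (u_0, ..., u_{N-1}), so the QP reduces to

        minimize    0.5 * u^T blkdiag(R, ..., R) u
        subject to  (I_N kron A_input) u <= (b_input, ..., b_input)

    i.e. the cost is block-diagonal in the per-step input costs, the linear
    term c is zero, and only the per-step input box constraints are imposed
    (the predicted state trajectory and the reference do not enter the
    QP matrices in this construction).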
276 | 277 | Parameters: 278 | ----------- 279 | A_dyn : numpy.ndarray 280 | State transition matrix 281 | B_dyn : numpy.ndarray 282 | Input matrix 283 | Q_state : numpy.ndarray 284 | State cost matrix 285 | R_input : numpy.ndarray 286 | Input cost matrix 287 | state_constraints : Tuple[numpy.ndarray, numpy.ndarray] 288 | State constraint matrices (A_state, b_state) 289 | input_constraints : Tuple[numpy.ndarray, numpy.ndarray] 290 | Input constraint matrices (A_input, b_input) 291 | initial_state : numpy.ndarray 292 | Initial state 293 | reference : numpy.ndarray 294 | Reference trajectory 295 | 296 | Returns: 297 | -------- 298 | Q : numpy.ndarray 299 | QP cost matrix 300 | c : numpy.ndarray 301 | QP linear cost vector 302 | A : numpy.ndarray 303 | QP constraint matrix 304 | b : numpy.ndarray 305 | QP constraint vector 306 | """ 307 | # Extract constraint matrices 308 | A_state, b_state = state_constraints 309 | A_input, b_input = input_constraints 310 | 311 | # Problem dimensions 312 | nx = self.state_dim 313 | nu = self.input_dim 314 | N = self.horizon 315 | 316 | # Total number of decision variables 317 | n_vars = N * nu 318 | 319 | # Construct the cost matrices 320 | Q_block = np.zeros((n_vars, n_vars)) 321 | for i in range(N): 322 | idx = i * nu 323 | Q_block[idx:idx+nu, idx:idx+nu] = R_input 324 | 325 | # Compute the prediction matrices 326 | x_pred = [initial_state] 327 | for k in range(N): 328 | x_next = A_dyn @ x_pred[-1] 329 | x_pred.append(x_next) 330 | 331 | # Construct the linear cost term 332 | c = np.zeros(n_vars) 333 | 334 | # Construct constraint matrices 335 | # Initial state constraint is already handled in prediction 336 | 337 | # Input constraints for each time step 338 | A_in_list = [] 339 | b_in_list = [] 340 | 341 | for i in range(N): 342 | # Input constraints: A_input * u <= b_input 343 | A_i = np.zeros((A_input.shape[0], n_vars)) 344 | idx = i * nu 345 | A_i[:, idx:idx+nu] = A_input 346 | A_in_list.append(A_i) 347 | b_in_list.append(b_input) 348 | 349 | # Combine all constraints 350 | A = np.vstack(A_in_list) if A_in_list else np.zeros((0, n_vars)) 351 | b = np.concatenate(b_in_list) if b_in_list else np.zeros(0) 352 | 353 | return Q_block, c, A, b 354 | 355 | def save(self, problems: List[QPProblem], filepath: str) -> None: 356 | """ 357 | Save generated problems to a file. 358 | 359 | Parameters: 360 | ----------- 361 | problems : List[QPProblem] 362 | List of QP problems to save 363 | filepath : str 364 | Path to save the problems 365 | """ 366 | # Convert problems to dictionaries 367 | data = [problem.to_dict() for problem in problems] 368 | 369 | # Save using numpy 370 | np.save(filepath, data, allow_pickle=True) 371 | 372 | @staticmethod 373 | def load(filepath: str) -> List[QPProblem]: 374 | """ 375 | Load problems from a file. 376 | 377 | Parameters: 378 | ----------- 379 | filepath : str 380 | Path to load the problems from 381 | 382 | Returns: 383 | -------- 384 | problems : List[QPProblem] 385 | List of loaded QP problems 386 | """ 387 | # Load data from file 388 | data = np.load(filepath, allow_pickle=True) 389 | 390 | # Convert dictionaries to QPProblem instances 391 | problems = [QPProblem.from_dict(item) for item in data] 392 | 393 | return problems 394 | -------------------------------------------------------------------------------- /transformermpc/models/constraint_predictor.py: -------------------------------------------------------------------------------- 1 | """ 2 | Constraint predictor model. 
3 | 4 | This module defines the transformer-based model for predicting 5 | which constraints are active in QP problems. 6 | """ 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | import numpy as np 12 | import os 13 | from typing import Dict, Optional, Tuple, Union, List, Any 14 | 15 | 16 | class TransformerEncoder(nn.Module): 17 | """Vanilla Transformer Encoder implementation.""" 18 | 19 | def __init__(self, 20 | hidden_dim: int = 128, 21 | num_layers: int = 4, 22 | num_heads: int = 8, 23 | dropout: float = 0.1): 24 | """ 25 | Initialize the transformer encoder. 26 | 27 | Parameters: 28 | ----------- 29 | hidden_dim : int 30 | Dimension of hidden layers 31 | num_layers : int 32 | Number of transformer layers 33 | num_heads : int 34 | Number of attention heads 35 | dropout : float 36 | Dropout probability 37 | """ 38 | super().__init__() 39 | 40 | # Make sure hidden_dim is divisible by num_heads 41 | assert hidden_dim % num_heads == 0, "hidden_dim must be divisible by num_heads" 42 | 43 | # Create encoder layer 44 | encoder_layer = nn.TransformerEncoderLayer( 45 | d_model=hidden_dim, 46 | nhead=num_heads, 47 | dim_feedforward=4 * hidden_dim, 48 | dropout=dropout, 49 | activation="relu", 50 | batch_first=True 51 | ) 52 | 53 | # Create encoder 54 | self.encoder = nn.TransformerEncoder( 55 | encoder_layer=encoder_layer, 56 | num_layers=num_layers 57 | ) 58 | 59 | # Positional encoding 60 | self.pos_encoding = PositionalEncoding( 61 | d_model=hidden_dim, 62 | dropout=dropout, 63 | max_len=100 64 | ) 65 | 66 | def forward(self, x: torch.Tensor) -> torch.Tensor: 67 | """ 68 | Forward pass. 69 | 70 | Parameters: 71 | ----------- 72 | x : torch.Tensor 73 | Input tensor of shape (batch_size, seq_len, hidden_dim) 74 | 75 | Returns: 76 | -------- 77 | output : torch.Tensor 78 | Output tensor of shape (batch_size, seq_len, hidden_dim) 79 | """ 80 | # Add positional encoding 81 | x = self.pos_encoding(x) 82 | 83 | # Pass through encoder 84 | output = self.encoder(x) 85 | 86 | return output 87 | 88 | 89 | class PositionalEncoding(nn.Module): 90 | """Positional encoding for Transformer models.""" 91 | 92 | def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 100): 93 | """ 94 | Initialize the positional encoding. 95 | 96 | Parameters: 97 | ----------- 98 | d_model : int 99 | Dimension of the model 100 | dropout : float 101 | Dropout probability 102 | max_len : int 103 | Maximum sequence length 104 | """ 105 | super().__init__() 106 | self.dropout = nn.Dropout(p=dropout) 107 | 108 | # Create positional encoding 109 | position = torch.arange(max_len).unsqueeze(1) 110 | div_term = torch.exp(torch.arange(0, d_model, 2) * (-np.log(10000.0) / d_model)) 111 | pe = torch.zeros(1, max_len, d_model) 112 | pe[0, :, 0::2] = torch.sin(position * div_term) 113 | pe[0, :, 1::2] = torch.cos(position * div_term) 114 | self.register_buffer('pe', pe) 115 | 116 | def forward(self, x: torch.Tensor) -> torch.Tensor: 117 | """ 118 | Forward pass. 119 | 120 | Parameters: 121 | ----------- 122 | x : torch.Tensor 123 | Input tensor of shape (batch_size, seq_len, d_model) 124 | 125 | Returns: 126 | -------- 127 | output : torch.Tensor 128 | Output tensor with positional encoding added 129 | """ 130 | x = x + self.pe[:, :x.size(1), :] 131 | return self.dropout(x) 132 | 133 | 134 | class ConstraintPredictor(nn.Module): 135 | """ 136 | Transformer-based model for predicting active constraints. 
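    A minimal usage sketch (dimensions are illustrative):

        >>> model = ConstraintPredictor(input_dim=50, hidden_dim=128,
        ...                             num_constraints=100)
        >>> features = np.random.randn(50).astype(np.float32)
        >>> active_mask = model.predict(features)[0]  # 0/1 vector
        >>> active_mask.shape
        (100,)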
137 | 138 | This model takes QP problem features as input and outputs a binary vector 139 | indicating which constraints are active. 140 | """ 141 | 142 | def __init__(self, 143 | input_dim: int = 50, 144 | hidden_dim: int = 128, 145 | num_constraints: int = 100, 146 | num_layers: int = 4, 147 | num_heads: int = 8, 148 | dropout: float = 0.1): 149 | """ 150 | Initialize the constraint predictor model. 151 | 152 | Parameters: 153 | ----------- 154 | input_dim : int 155 | Dimension of input features 156 | hidden_dim : int 157 | Dimension of hidden layers 158 | num_constraints : int 159 | Maximum number of constraints to predict 160 | num_layers : int 161 | Number of transformer layers 162 | num_heads : int 163 | Number of attention heads 164 | dropout : float 165 | Dropout probability 166 | """ 167 | super().__init__() 168 | 169 | self.input_dim = input_dim 170 | self.hidden_dim = hidden_dim 171 | self.num_constraints = num_constraints 172 | self.num_layers = num_layers 173 | self.num_heads = num_heads 174 | self.dropout = dropout 175 | 176 | # Input projection 177 | self.input_projection = nn.Linear(input_dim, hidden_dim) 178 | 179 | # Transformer encoder 180 | self.transformer = TransformerEncoder( 181 | hidden_dim=hidden_dim, 182 | num_layers=num_layers, 183 | num_heads=num_heads, 184 | dropout=dropout 185 | ) 186 | 187 | # Output projection 188 | self.output_projection = nn.Linear(hidden_dim, num_constraints) 189 | 190 | def forward(self, x: torch.Tensor) -> torch.Tensor: 191 | """ 192 | Forward pass. 193 | 194 | Parameters: 195 | ----------- 196 | x : torch.Tensor 197 | Input tensor of shape (batch_size, input_dim) 198 | 199 | Returns: 200 | -------- 201 | output : torch.Tensor 202 | Output tensor of shape (batch_size, num_constraints) 203 | containing probabilities of constraints being active 204 | """ 205 | # Input projection 206 | x = self.input_projection(x) 207 | 208 | # Reshape for transformer: [batch_size, seq_len, hidden_dim] 209 | # Here we use a sequence length of 1 210 | x = x.unsqueeze(1) 211 | 212 | # Pass through transformer 213 | transformer_output = self.transformer(x) 214 | 215 | # Extract first token output 216 | first_token = transformer_output[:, 0, :] 217 | 218 | # Output projection 219 | logits = self.output_projection(first_token) 220 | 221 | # Apply sigmoid to get probabilities 222 | probs = torch.sigmoid(logits) 223 | 224 | return probs 225 | 226 | def predict(self, x: Union[torch.Tensor, np.ndarray]) -> np.ndarray: 227 | """ 228 | Make prediction on input data. 229 | 230 | Parameters: 231 | ----------- 232 | x : torch.Tensor or numpy.ndarray 233 | Input features 234 | 235 | Returns: 236 | -------- 237 | prediction : numpy.ndarray 238 | Binary prediction of active constraints 239 | """ 240 | # Convert to tensor if numpy array 241 | if isinstance(x, np.ndarray): 242 | x = torch.tensor(x, dtype=torch.float32) 243 | 244 | # Make sure the input has the right shape 245 | if x.dim() == 1: 246 | x = x.unsqueeze(0) # Add batch dimension 247 | 248 | # Set model to evaluation mode 249 | self.eval() 250 | 251 | # Disable gradient computation 252 | with torch.no_grad(): 253 | # Forward pass 254 | probs = self(x) 255 | 256 | # Convert to binary predictions 257 | binary_pred = (probs > 0.5).float() 258 | 259 | # Convert to numpy 260 | binary_pred = binary_pred.cpu().numpy() 261 | 262 | return binary_pred 263 | 264 | def save(self, filepath: str) -> None: 265 | """ 266 | Save model to file. 
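    A save/load round trip might look like this (the path is illustrative;
    the directory component of the path is created if missing):

        >>> model.save("checkpoints/constraint_predictor.pt")
        >>> restored = ConstraintPredictor.load("checkpoints/constraint_predictor.pt")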
267 | 268 | Parameters: 269 | ----------- 270 | filepath : str 271 | Path to save the model 272 | """ 273 | # Create directory if it doesn't exist 274 | os.makedirs(os.path.dirname(filepath), exist_ok=True) 275 | 276 | # Save model state 277 | torch.save({ 278 | 'state_dict': self.state_dict(), 279 | 'input_dim': self.input_dim, 280 | 'hidden_dim': self.hidden_dim, 281 | 'num_constraints': self.num_constraints, 282 | 'num_layers': self.num_layers, 283 | 'num_heads': self.num_heads, 284 | 'dropout': self.dropout 285 | }, filepath) 286 | 287 | @classmethod 288 | def load(cls, filepath: Optional[str] = None) -> 'ConstraintPredictor': 289 | """ 290 | Load model from file. 291 | 292 | Parameters: 293 | ----------- 294 | filepath : str or None 295 | Path to load the model from, or None to create a new model 296 | 297 | Returns: 298 | -------- 299 | model : ConstraintPredictor 300 | Loaded model 301 | """ 302 | if filepath is None or not os.path.exists(filepath): 303 | # Return a new model with default parameters 304 | return cls() 305 | 306 | # Load model state 307 | checkpoint = torch.load(filepath, map_location=torch.device('cpu')) 308 | 309 | # Create model with saved parameters 310 | model = cls( 311 | input_dim=checkpoint['input_dim'], 312 | hidden_dim=checkpoint['hidden_dim'], 313 | num_constraints=checkpoint['num_constraints'], 314 | num_layers=checkpoint.get('num_layers', 4), 315 | num_heads=checkpoint.get('num_heads', 8), 316 | dropout=checkpoint.get('dropout', 0.1) 317 | ) 318 | 319 | # Load state dictionary 320 | model.load_state_dict(checkpoint['state_dict']) 321 | 322 | return model 323 | 324 | def train_step(self, 325 | x: torch.Tensor, 326 | y: torch.Tensor, 327 | optimizer: torch.optim.Optimizer) -> Dict[str, float]: 328 | """ 329 | Perform a single training step. 330 | 331 | Parameters: 332 | ----------- 333 | x : torch.Tensor 334 | Input tensor of shape (batch_size, input_dim) 335 | y : torch.Tensor 336 | Target tensor of shape (batch_size, num_constraints) 337 | optimizer : torch.optim.Optimizer 338 | Optimizer to use for the step 339 | 340 | Returns: 341 | -------- 342 | metrics : Dict[str, float] 343 | Dictionary containing loss and accuracy 344 | """ 345 | # Set model to training mode 346 | self.train() 347 | 348 | # Zero the gradients 349 | optimizer.zero_grad() 350 | 351 | # Forward pass 352 | probs = self(x) 353 | 354 | # Compute loss (binary cross entropy) 355 | loss = F.binary_cross_entropy(probs, y) 356 | 357 | # Backward pass 358 | loss.backward() 359 | 360 | # Update parameters 361 | optimizer.step() 362 | 363 | # Compute accuracy 364 | binary_pred = (probs > 0.5).float() 365 | accuracy = (binary_pred == y).float().mean().item() 366 | 367 | return { 368 | 'loss': loss.item(), 369 | 'accuracy': accuracy 370 | } 371 | 372 | def validate(self, 373 | x: torch.Tensor, 374 | y: torch.Tensor) -> Dict[str, float]: 375 | """ 376 | Validate the model on validation data. 
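    A quick sketch on random data (shapes and labels are illustrative):

        >>> cp = ConstraintPredictor(input_dim=50, num_constraints=100)
        >>> x = torch.randn(32, 50)
        >>> y = (torch.rand(32, 100) > 0.8).float()
        >>> sorted(cp.validate(x, y).keys())
        ['accuracy', 'f1', 'loss', 'precision', 'recall']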
377 | 378 | Parameters: 379 | ----------- 380 | x : torch.Tensor 381 | Input tensor of shape (batch_size, input_dim) 382 | y : torch.Tensor 383 | Target tensor of shape (batch_size, num_constraints) 384 | 385 | Returns: 386 | -------- 387 | metrics : Dict[str, float] 388 | Dictionary containing loss and accuracy 389 | """ 390 | # Set model to evaluation mode 391 | self.eval() 392 | 393 | # Disable gradient computation 394 | with torch.no_grad(): 395 | # Forward pass 396 | probs = self(x) 397 | 398 | # Compute loss (binary cross entropy) 399 | loss = F.binary_cross_entropy(probs, y) 400 | 401 | # Compute accuracy 402 | binary_pred = (probs > 0.5).float() 403 | accuracy = (binary_pred == y).float().mean().item() 404 | 405 | # Compute precision, recall, and F1 score 406 | binary_pred_flat = binary_pred.flatten().cpu().numpy() 407 | y_flat = y.flatten().cpu().numpy() 408 | 409 | # Compute TP, FP, TN, FN 410 | tp = ((binary_pred_flat == 1) & (y_flat == 1)).sum() 411 | fp = ((binary_pred_flat == 1) & (y_flat == 0)).sum() 412 | fn = ((binary_pred_flat == 0) & (y_flat == 1)).sum() 413 | 414 | # Compute precision, recall, and F1 415 | precision = tp / (tp + fp) if tp + fp > 0 else 0.0 416 | recall = tp / (tp + fn) if tp + fn > 0 else 0.0 417 | f1 = 2 * precision * recall / (precision + recall) if precision + recall > 0 else 0.0 418 | 419 | return { 420 | 'loss': loss.item(), 421 | 'accuracy': accuracy, 422 | 'precision': precision, 423 | 'recall': recall, 424 | 'f1': f1 425 | } 426 | -------------------------------------------------------------------------------- /transformermpc/utils/visualization.py: -------------------------------------------------------------------------------- 1 | """ 2 | Visualization module for TransformerMPC. 3 | 4 | This module provides functions for visualizing training progress, 5 | evaluation metrics, and benchmarking results. 6 | """ 7 | 8 | import matplotlib.pyplot as plt 9 | import numpy as np 10 | from typing import List, Dict, Tuple, Optional, Union, Any 11 | import os 12 | 13 | 14 | def plot_training_history(history: Dict[str, List[float]], 15 | save_path: Optional[str] = None, 16 | show: bool = True) -> None: 17 | """ 18 | Plot training history metrics. 19 | 20 | Parameters: 21 | ----------- 22 | history : Dict[str, List[float]] 23 | Dictionary containing training metrics 24 | save_path : str or None 25 | Path to save the plot 26 | show : bool 27 | Whether to display the plot 28 | """ 29 | fig, axes = plt.subplots(nrows=len(history), figsize=(10, 3*len(history))) 30 | 31 | # Handle case with only one metric 32 | if len(history) == 1: 33 | axes = [axes] 34 | 35 | for i, (metric_name, values) in enumerate(history.items()): 36 | axes[i].plot(values) 37 | axes[i].set_title(metric_name) 38 | axes[i].set_xlabel('Epoch') 39 | axes[i].set_ylabel(metric_name) 40 | axes[i].grid(True) 41 | 42 | plt.tight_layout() 43 | 44 | if save_path is not None: 45 | plt.savefig(save_path) 46 | 47 | if show: 48 | plt.show() 49 | else: 50 | plt.close() 51 | 52 | 53 | def plot_constraint_prediction_metrics(precision: List[float], 54 | recall: List[float], 55 | f1: List[float], 56 | accuracy: List[float], 57 | epochs: List[int], 58 | save_path: Optional[str] = None, 59 | show: bool = True) -> None: 60 | """ 61 | Plot constraint prediction model metrics. 
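    A minimal call sketch (the metric values are made up for illustration):

        >>> plot_constraint_prediction_metrics(
        ...     precision=[0.70, 0.80, 0.85],
        ...     recall=[0.65, 0.75, 0.82],
        ...     f1=[0.67, 0.77, 0.83],
        ...     accuracy=[0.90, 0.93, 0.95],
        ...     epochs=[1, 2, 3],
        ...     save_path="cp_metrics.png",
        ...     show=False)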
62 | 63 | Parameters: 64 | ----------- 65 | precision, recall, f1, accuracy : List[float] 66 | Lists of metric values 67 | epochs : List[int] 68 | List of epoch numbers 69 | save_path : str or None 70 | Path to save the plot 71 | show : bool 72 | Whether to display the plot 73 | """ 74 | fig, ax = plt.subplots(figsize=(10, 6)) 75 | 76 | ax.plot(epochs, precision, 'b-', label='Precision') 77 | ax.plot(epochs, recall, 'g-', label='Recall') 78 | ax.plot(epochs, f1, 'r-', label='F1 Score') 79 | ax.plot(epochs, accuracy, 'k-', label='Accuracy') 80 | 81 | ax.set_xlabel('Epoch') 82 | ax.set_ylabel('Metric Value') 83 | ax.set_title('Constraint Prediction Performance Metrics') 84 | ax.legend() 85 | ax.grid(True) 86 | 87 | plt.tight_layout() 88 | 89 | if save_path is not None: 90 | plt.savefig(save_path) 91 | 92 | if show: 93 | plt.show() 94 | else: 95 | plt.close() 96 | 97 | 98 | def plot_warm_start_metrics(mse: List[float], 99 | mae: List[float], 100 | rel_error: List[float], 101 | epochs: List[int], 102 | save_path: Optional[str] = None, 103 | show: bool = True) -> None: 104 | """ 105 | Plot warm start prediction model metrics. 106 | 107 | Parameters: 108 | ----------- 109 | mse, mae, rel_error : List[float] 110 | Lists of metric values 111 | epochs : List[int] 112 | List of epoch numbers 113 | save_path : str or None 114 | Path to save the plot 115 | show : bool 116 | Whether to display the plot 117 | """ 118 | fig, ax = plt.subplots(figsize=(10, 6)) 119 | 120 | ax.plot(epochs, mse, 'b-', label='MSE') 121 | ax.plot(epochs, mae, 'g-', label='MAE') 122 | ax.plot(epochs, rel_error, 'r-', label='Relative Error') 123 | 124 | ax.set_xlabel('Epoch') 125 | ax.set_ylabel('Error') 126 | ax.set_title('Warm Start Prediction Performance Metrics') 127 | ax.legend() 128 | ax.grid(True) 129 | 130 | plt.tight_layout() 131 | 132 | if save_path is not None: 133 | plt.savefig(save_path) 134 | 135 | if show: 136 | plt.show() 137 | else: 138 | plt.close() 139 | 140 | 141 | def plot_solve_time_comparison(baseline_times: np.ndarray, 142 | transformer_times: np.ndarray, 143 | problem_sizes: Optional[np.ndarray] = None, 144 | save_path: Optional[str] = None, 145 | show: bool = True, 146 | log_scale: bool = False) -> None: 147 | """ 148 | Plot comparison of solve times between baseline and transformer-enhanced approach. 
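    A minimal call sketch with synthetic timings (values are illustrative):

        >>> rng = np.random.default_rng(0)
        >>> baseline = rng.uniform(5e-3, 2e-2, size=100)
        >>> enhanced = baseline * rng.uniform(0.3, 0.8, size=100)
        >>> plot_solve_time_comparison(baseline, enhanced,
        ...                            save_path="comparison.png", show=False)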
149 | 150 | Parameters: 151 | ----------- 152 | baseline_times : numpy.ndarray 153 | Array of solve times for baseline approach 154 | transformer_times : numpy.ndarray 155 | Array of solve times for transformer-enhanced approach 156 | problem_sizes : numpy.ndarray or None 157 | Array of problem sizes (e.g., number of variables or constraints) 158 | save_path : str or None 159 | Path to save the plot 160 | show : bool 161 | Whether to display the plot 162 | log_scale : bool 163 | Whether to use log scale for the y-axis 164 | """ 165 | fig, ax = plt.subplots(figsize=(10, 6)) 166 | 167 | if problem_sizes is not None: 168 | # Plot against problem sizes 169 | ax.plot(problem_sizes, baseline_times, 'bo-', label='OSQP Baseline') 170 | ax.plot(problem_sizes, transformer_times, 'ro-', label='TransformerMPC') 171 | ax.set_xlabel('Problem Size') 172 | else: 173 | # Plot histograms 174 | ax.hist(baseline_times, bins=30, alpha=0.5, label='OSQP Baseline') 175 | ax.hist(transformer_times, bins=30, alpha=0.5, label='TransformerMPC') 176 | ax.set_xlabel('Solve Time (seconds)') 177 | 178 | if log_scale: 179 | ax.set_yscale('log') 180 | 181 | ax.set_ylabel('Solve Time (seconds)' if problem_sizes is not None else 'Frequency') 182 | ax.set_title('Comparison of Solve Times: OSQP vs TransformerMPC') 183 | ax.legend() 184 | ax.grid(True) 185 | 186 | # Add speedup statistics 187 | mean_speedup = np.mean(baseline_times / transformer_times) 188 | median_speedup = np.median(baseline_times / transformer_times) 189 | max_speedup = np.max(baseline_times / transformer_times) 190 | 191 | stats_text = f"Mean Speedup: {mean_speedup:.2f}x\n" 192 | stats_text += f"Median Speedup: {median_speedup:.2f}x\n" 193 | stats_text += f"Max Speedup: {max_speedup:.2f}x" 194 | 195 | plt.figtext(0.02, 0.02, stats_text, fontsize=10, bbox=dict(facecolor='white', alpha=0.8)) 196 | 197 | plt.tight_layout() 198 | 199 | if save_path is not None: 200 | plt.savefig(save_path) 201 | 202 | if show: 203 | plt.show() 204 | else: 205 | plt.close() 206 | 207 | 208 | def plot_solve_time_boxplot(baseline_times: np.ndarray, 209 | transformer_times: np.ndarray, 210 | constraint_only_times: Optional[np.ndarray] = None, 211 | warmstart_only_times: Optional[np.ndarray] = None, 212 | save_path: Optional[str] = None, 213 | show: bool = True) -> None: 214 | """ 215 | Create boxplot comparison of solve times for different methods. 
216 | 217 | Parameters: 218 | ----------- 219 | baseline_times : numpy.ndarray 220 | Array of solve times for baseline approach 221 | transformer_times : numpy.ndarray 222 | Array of solve times for full transformer-enhanced approach 223 | constraint_only_times : numpy.ndarray or None 224 | Array of solve times using only constraint prediction 225 | warmstart_only_times : numpy.ndarray or None 226 | Array of solve times using only warm start prediction 227 | save_path : str or None 228 | Path to save the plot 229 | show : bool 230 | Whether to display the plot 231 | """ 232 | fig, ax = plt.subplots(figsize=(10, 6)) 233 | 234 | data = [] 235 | labels = [] 236 | 237 | # Always include baseline and full transformer 238 | data.append(baseline_times) 239 | labels.append('OSQP Baseline') 240 | 241 | # Include constraint-only if provided 242 | if constraint_only_times is not None: 243 | data.append(constraint_only_times) 244 | labels.append('Constraint Prediction Only') 245 | 246 | # Include warmstart-only if provided 247 | if warmstart_only_times is not None: 248 | data.append(warmstart_only_times) 249 | labels.append('Warm Start Only') 250 | 251 | # Always include full transformer 252 | data.append(transformer_times) 253 | labels.append('Full TransformerMPC') 254 | 255 | # Create the boxplot 256 | ax.boxplot(data, labels=labels, showfliers=True) 257 | 258 | ax.set_ylabel('Solve Time (seconds)') 259 | ax.set_title('Comparison of Solve Times Across Methods') 260 | ax.grid(True, axis='y') 261 | 262 | plt.tight_layout() 263 | 264 | if save_path is not None: 265 | plt.savefig(save_path) 266 | 267 | if show: 268 | plt.show() 269 | else: 270 | plt.close() 271 | 272 | 273 | def plot_active_constraints_histogram(active_percentages: np.ndarray, 274 | save_path: Optional[str] = None, 275 | show: bool = True) -> None: 276 | """ 277 | Plot histogram of percentage of active constraints. 278 | 279 | Parameters: 280 | ----------- 281 | active_percentages : numpy.ndarray 282 | Array of percentages of active constraints for each problem 283 | save_path : str or None 284 | Path to save the plot 285 | show : bool 286 | Whether to display the plot 287 | """ 288 | fig, ax = plt.subplots(figsize=(10, 6)) 289 | 290 | ax.hist(active_percentages, bins=50, alpha=0.75) 291 | ax.axvline(np.mean(active_percentages), color='r', linestyle='--', 292 | label=f'Mean: {np.mean(active_percentages):.2f}%') 293 | ax.axvline(np.median(active_percentages), color='g', linestyle='--', 294 | label=f'Median: {np.median(active_percentages):.2f}%') 295 | 296 | ax.set_xlabel('Percentage of Active Constraints') 297 | ax.set_ylabel('Frequency') 298 | ax.set_title('Distribution of Active Constraints Percentage') 299 | ax.legend() 300 | ax.grid(True) 301 | 302 | plt.tight_layout() 303 | 304 | if save_path is not None: 305 | plt.savefig(save_path) 306 | 307 | if show: 308 | plt.show() 309 | else: 310 | plt.close() 311 | 312 | 313 | def plot_fallback_statistics(fallback_rates: Dict[str, float], 314 | save_path: Optional[str] = None, 315 | show: bool = True) -> None: 316 | """ 317 | Plot statistics about fallback to full QP solve. 
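    A minimal call sketch (the rate is made up for illustration):

        >>> plot_fallback_statistics({"Transformer-MPC": 4.0},
        ...                          save_path="fallback.png", show=False)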
318 | 319 | Parameters: 320 | ----------- 321 | fallback_rates : Dict[str, float] 322 | Dictionary of fallback rates for different methods 323 | save_path : str or None 324 | Path to save the plot 325 | show : bool 326 | Whether to display the plot 327 | """ 328 | fig, ax = plt.subplots(figsize=(10, 6)) 329 | 330 | methods = list(fallback_rates.keys()) 331 | rates = list(fallback_rates.values()) 332 | 333 | # Create bar plot 334 | ax.bar(methods, rates, color='skyblue') 335 | 336 | ax.set_xlabel('Method') 337 | ax.set_ylabel('Fallback Rate (%)') 338 | ax.set_title('Fallback Rates for Different Methods') 339 | ax.grid(True, axis='y') 340 | 341 | # Add value labels on bars 342 | for i, v in enumerate(rates): 343 | ax.text(i, v + 1, f"{v:.1f}%", ha='center') 344 | 345 | plt.tight_layout() 346 | 347 | if save_path is not None: 348 | plt.savefig(save_path) 349 | 350 | if show: 351 | plt.show() 352 | else: 353 | plt.close() 354 | 355 | 356 | def plot_benchmarking_results(results: Dict[str, Dict[str, Union[float, np.ndarray]]], 357 | save_path: Optional[str] = None, 358 | show: bool = True) -> None: 359 | """ 360 | Plot comprehensive benchmarking results. 361 | 362 | Parameters: 363 | ----------- 364 | results : Dict 365 | Dictionary containing benchmarking results 366 | save_path : str or None 367 | Path to save the plot 368 | show : bool 369 | Whether to display the plot 370 | """ 371 | fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(15, 12)) 372 | 373 | # Plot 1: Solve time comparison (boxplot) 374 | methods = list(results.keys()) 375 | solve_times = [results[method]['solve_times'] for method in methods] 376 | 377 | axes[0, 0].boxplot(solve_times, labels=methods) 378 | axes[0, 0].set_ylabel('Solve Time (seconds)') 379 | axes[0, 0].set_title('Solve Times by Method') 380 | axes[0, 0].grid(True, axis='y') 381 | 382 | # Plot 2: Speedup factors 383 | baseline_method = methods[0] # Assume first method is baseline 384 | speedups = {} 385 | 386 | for method in methods[1:]: 387 | speedup = results[baseline_method]['mean_solve_time'] / results[method]['mean_solve_time'] 388 | speedups[method] = speedup 389 | 390 | speedup_methods = list(speedups.keys()) 391 | speedup_values = list(speedups.values()) 392 | 393 | axes[0, 1].bar(speedup_methods, speedup_values, color='green') 394 | axes[0, 1].set_ylabel('Speedup Factor (x)') 395 | axes[0, 1].set_title(f'Speedup Relative to {baseline_method}') 396 | axes[0, 1].grid(True, axis='y') 397 | 398 | # Add value labels on bars 399 | for i, v in enumerate(speedup_values): 400 | axes[0, 1].text(i, v + 0.1, f"{v:.2f}x", ha='center') 401 | 402 | # Plot 3: Fallback rates 403 | fallback_rates = {m: results[m].get('fallback_rate', 0) for m in methods[1:]} 404 | fallback_methods = list(fallback_rates.keys()) 405 | fallback_values = list(fallback_rates.values()) 406 | 407 | axes[1, 0].bar(fallback_methods, fallback_values, color='orange') 408 | axes[1, 0].set_ylabel('Fallback Rate (%)') 409 | axes[1, 0].set_title('Fallback Rates by Method') 410 | axes[1, 0].grid(True, axis='y') 411 | 412 | # Add value labels on bars 413 | for i, v in enumerate(fallback_values): 414 | axes[1, 0].text(i, v + 1, f"{v:.1f}%", ha='center') 415 | 416 | # Plot 4: Accuracy metrics for constraint predictor 417 | if 'constraint_predictor' in results and 'metrics' in results['constraint_predictor']: 418 | metrics = results['constraint_predictor']['metrics'] 419 | metric_names = list(metrics.keys()) 420 | metric_values = list(metrics.values()) 421 | 422 | axes[1, 1].bar(metric_names, metric_values, 
color='purple') 423 | axes[1, 1].set_ylabel('Value') 424 | axes[1, 1].set_title('Constraint Predictor Metrics') 425 | axes[1, 1].grid(True, axis='y') 426 | 427 | # Add value labels on bars 428 | for i, v in enumerate(metric_values): 429 | axes[1, 1].text(i, v + 0.02, f"{v:.3f}", ha='center') 430 | else: 431 | axes[1, 1].set_visible(False) 432 | 433 | plt.tight_layout() 434 | 435 | if save_path is not None: 436 | plt.savefig(save_path) 437 | 438 | if show: 439 | plt.show() 440 | else: 441 | plt.close() 442 | -------------------------------------------------------------------------------- /transformermpc/models/warm_start_predictor.py: -------------------------------------------------------------------------------- 1 | """ 2 | Warm start predictor model. 3 | 4 | This module defines the transformer-based model for predicting 5 | warm start solutions for QP problems. 6 | """ 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | import numpy as np 12 | import os 13 | from typing import Dict, Optional, Tuple, Union, List, Any 14 | 15 | 16 | class TransformerEncoderDecoder(nn.Module): 17 | """Vanilla Transformer Encoder-Decoder implementation.""" 18 | 19 | def __init__(self, 20 | hidden_dim: int = 256, 21 | output_dim: int = 20, 22 | num_layers: int = 6, 23 | num_heads: int = 8, 24 | dropout: float = 0.1): 25 | """ 26 | Initialize the transformer encoder-decoder. 27 | 28 | Parameters: 29 | ----------- 30 | hidden_dim : int 31 | Dimension of hidden layers 32 | output_dim : int 33 | Dimension of output 34 | num_layers : int 35 | Number of transformer layers 36 | num_heads : int 37 | Number of attention heads 38 | dropout : float 39 | Dropout probability 40 | """ 41 | super().__init__() 42 | 43 | # Make sure hidden_dim is divisible by num_heads 44 | assert hidden_dim % num_heads == 0, "hidden_dim must be divisible by num_heads" 45 | 46 | # Create encoder layer 47 | encoder_layer = nn.TransformerEncoderLayer( 48 | d_model=hidden_dim, 49 | nhead=num_heads, 50 | dim_feedforward=4 * hidden_dim, 51 | dropout=dropout, 52 | activation="relu", 53 | batch_first=True 54 | ) 55 | 56 | # Create encoder 57 | self.encoder = nn.TransformerEncoder( 58 | encoder_layer=encoder_layer, 59 | num_layers=num_layers 60 | ) 61 | 62 | # Create decoder layer 63 | decoder_layer = nn.TransformerDecoderLayer( 64 | d_model=hidden_dim, 65 | nhead=num_heads, 66 | dim_feedforward=4 * hidden_dim, 67 | dropout=dropout, 68 | activation="relu", 69 | batch_first=True 70 | ) 71 | 72 | # Create decoder 73 | self.decoder = nn.TransformerDecoder( 74 | decoder_layer=decoder_layer, 75 | num_layers=num_layers 76 | ) 77 | 78 | # Positional encoding 79 | self.pos_encoding = PositionalEncoding( 80 | d_model=hidden_dim, 81 | dropout=dropout, 82 | max_len=100 83 | ) 84 | 85 | # Output projection 86 | self.output_projection = nn.Linear(hidden_dim, output_dim) 87 | 88 | def forward(self, src: torch.Tensor, tgt: Optional[torch.Tensor] = None) -> torch.Tensor: 89 | """ 90 | Forward pass. 
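    Shape sketch (dimensions are illustrative):

        >>> ted = TransformerEncoderDecoder(hidden_dim=256, output_dim=20)
        >>> src = torch.randn(8, 1, 256)
        >>> ted(src).shape
        torch.Size([8, 1, 20])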
91 | 
92 |         Parameters:
93 |         -----------
94 |         src : torch.Tensor
95 |             Source sequence tensor of shape (batch_size, src_seq_len, hidden_dim)
96 |         tgt : torch.Tensor, optional
97 |             Target sequence tensor of shape (batch_size, tgt_seq_len, hidden_dim)
98 |             If None, the source is used as target
99 | 
100 |         Returns:
101 |         --------
102 |         output : torch.Tensor
103 |             Output tensor of shape (batch_size, tgt_seq_len, output_dim)
104 |         """
105 |         # If no target is provided, use source
106 |         if tgt is None:
107 |             tgt = src
108 | 
109 |         # Add positional encoding
110 |         src = self.pos_encoding(src)
111 |         tgt = self.pos_encoding(tgt)
112 | 
113 |         # Pass through encoder
114 |         memory = self.encoder(src)
115 | 
116 |         # Pass through decoder
117 |         output = self.decoder(tgt, memory)
118 | 
119 |         # Project to output dimension
120 |         output = self.output_projection(output)
121 | 
122 |         return output
123 | 
124 | 
125 | class PositionalEncoding(nn.Module):
126 |     """Positional encoding for Transformer models."""
127 | 
128 |     def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 100):
129 |         """
130 |         Initialize the positional encoding.
131 | 
132 |         Parameters:
133 |         -----------
134 |         d_model : int
135 |             Dimension of the model
136 |         dropout : float
137 |             Dropout probability
138 |         max_len : int
139 |             Maximum sequence length
140 |         """
141 |         super().__init__()
142 |         self.dropout = nn.Dropout(p=dropout)
143 | 
144 |         # Create positional encoding
145 |         position = torch.arange(max_len).unsqueeze(1)
146 |         div_term = torch.exp(torch.arange(0, d_model, 2) * (-np.log(10000.0) / d_model))
147 |         pe = torch.zeros(1, max_len, d_model)
148 |         pe[0, :, 0::2] = torch.sin(position * div_term)
149 |         pe[0, :, 1::2] = torch.cos(position * div_term)
150 |         self.register_buffer('pe', pe)
151 | 
152 |     def forward(self, x: torch.Tensor) -> torch.Tensor:
153 |         """
154 |         Forward pass.
155 | 
156 |         Parameters:
157 |         -----------
158 |         x : torch.Tensor
159 |             Input tensor of shape (batch_size, seq_len, d_model)
160 | 
161 |         Returns:
162 |         --------
163 |         output : torch.Tensor
164 |             Output tensor with positional encoding added
165 |         """
166 |         x = x + self.pe[:, :x.size(1), :]
167 |         return self.dropout(x)
168 | 
169 | 
170 | class WarmStartPredictor(nn.Module):
171 |     """
172 |     Transformer-based model for predicting warm start solutions.
173 | 
174 |     This model takes QP problem features as input and outputs an approximate
175 |     solution that can be used to warm start the QP solver.
176 |     """
177 | 
178 |     def __init__(self,
179 |                  input_dim: int = 50,
180 |                  hidden_dim: int = 256,
181 |                  output_dim: int = 20,
182 |                  num_layers: int = 6,
183 |                  num_heads: int = 8,
184 |                  dropout: float = 0.1):
185 |         """
186 |         Initialize the warm start predictor model.
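    A minimal usage sketch (dimensions are illustrative):

        >>> ws = WarmStartPredictor(input_dim=50, hidden_dim=256, output_dim=20)
        >>> x0 = np.random.randn(50).astype(np.float32)
        >>> ws.predict(x0)[0].shape
        (20,)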
187 | 
188 |         Parameters:
189 |         -----------
190 |         input_dim : int
191 |             Dimension of input features
192 |         hidden_dim : int
193 |             Dimension of hidden layers
194 |         output_dim : int
195 |             Dimension of output solution vector
196 |         num_layers : int
197 |             Number of transformer layers
198 |         num_heads : int
199 |             Number of attention heads
200 |         dropout : float
201 |             Dropout probability
202 |         """
203 |         super().__init__()
204 | 
205 |         self.input_dim = input_dim
206 |         self.hidden_dim = hidden_dim
207 |         self.output_dim = output_dim
208 |         self.num_layers = num_layers
209 |         self.num_heads = num_heads
210 |         self.dropout = dropout
211 | 
212 |         # Input projection
213 |         self.input_projection = nn.Linear(input_dim, hidden_dim)
214 | 
215 |         # Transformer model
216 |         self.transformer = TransformerEncoderDecoder(
217 |             hidden_dim=hidden_dim,
218 |             output_dim=hidden_dim,  # keep the hidden width here so the residual/output layers below see hidden_dim features; the final projection to output_dim happens separately
219 |             num_layers=num_layers,
220 |             num_heads=num_heads,
221 |             dropout=dropout
222 |         )
223 | 
224 |         # Output layers with residual connections
225 |         self.output_layer1 = nn.Linear(hidden_dim, hidden_dim)
226 |         self.output_layer2 = nn.Linear(hidden_dim, hidden_dim)
227 |         self.output_projection = nn.Linear(hidden_dim, output_dim)
228 | 
229 |     def forward(self, x: torch.Tensor) -> torch.Tensor:
230 |         """
231 |         Forward pass.
232 | 
233 |         Parameters:
234 |         -----------
235 |         x : torch.Tensor
236 |             Input tensor of shape (batch_size, input_dim)
237 | 
238 |         Returns:
239 |         --------
240 |         output : torch.Tensor
241 |             Output tensor of shape (batch_size, output_dim)
242 |             containing approximate solution vector
243 |         """
244 |         # Input projection
245 |         x = self.input_projection(x)
246 | 
247 |         # Reshape for transformer: [batch_size, seq_len, hidden_dim]
248 |         # Here we use a sequence length of 1
249 |         x = x.unsqueeze(1)
250 | 
251 |         # Pass through transformer
252 |         transformer_output = self.transformer(x)
253 | 
254 |         # Extract first token of the output sequence
255 |         first_token = transformer_output[:, 0, :]
256 | 
257 |         # Output layers with residual connections
258 |         output = F.relu(self.output_layer1(first_token))
259 |         output = output + first_token  # Residual connection
260 |         output = F.relu(self.output_layer2(output))
261 |         output = output + first_token  # Residual connection
262 | 
263 |         # Final output projection
264 |         solution = self.output_projection(output)
265 | 
266 |         return solution
267 | 
268 |     def predict(self, x: Union[torch.Tensor, np.ndarray]) -> np.ndarray:
269 |         """
270 |         Make prediction on input data.
271 | 
272 |         Parameters:
273 |         -----------
274 |         x : torch.Tensor or numpy.ndarray
275 |             Input features
276 | 
277 |         Returns:
278 |         --------
279 |         prediction : numpy.ndarray
280 |             Predicted solution vector
281 |         """
282 |         # Convert to tensor if numpy array
283 |         if isinstance(x, np.ndarray):
284 |             x = torch.tensor(x, dtype=torch.float32)
285 | 
286 |         # Make sure the input has the right shape
287 |         if x.dim() == 1:
288 |             x = x.unsqueeze(0)  # Add batch dimension
289 | 
290 |         # Set model to evaluation mode
291 |         self.eval()
292 | 
293 |         # Disable gradient computation
294 |         with torch.no_grad():
295 |             # Forward pass
296 |             solution = self(x)
297 | 
298 |             # Convert to numpy
299 |             solution = solution.cpu().numpy()
300 | 
301 |         return solution
302 | 
303 |     def save(self, filepath: str) -> None:
304 |         """
305 |         Save model to file.
306 | 307 | Parameters: 308 | ----------- 309 | filepath : str 310 | Path to save the model 311 | """ 312 | # Create directory if it doesn't exist 313 | os.makedirs(os.path.dirname(filepath), exist_ok=True) 314 | 315 | # Save model state 316 | torch.save({ 317 | 'state_dict': self.state_dict(), 318 | 'input_dim': self.input_dim, 319 | 'hidden_dim': self.hidden_dim, 320 | 'output_dim': self.output_dim, 321 | 'num_layers': self.num_layers, 322 | 'num_heads': self.num_heads, 323 | 'dropout': self.dropout 324 | }, filepath) 325 | 326 | @classmethod 327 | def load(cls, filepath: Optional[str] = None) -> 'WarmStartPredictor': 328 | """ 329 | Load model from file. 330 | 331 | Parameters: 332 | ----------- 333 | filepath : str or None 334 | Path to load the model from, or None to create a new model 335 | 336 | Returns: 337 | -------- 338 | model : WarmStartPredictor 339 | Loaded model 340 | """ 341 | if filepath is None or not os.path.exists(filepath): 342 | # Return a new model with default parameters 343 | return cls() 344 | 345 | # Load model state 346 | checkpoint = torch.load(filepath, map_location=torch.device('cpu')) 347 | 348 | # Create model with saved parameters 349 | model = cls( 350 | input_dim=checkpoint['input_dim'], 351 | hidden_dim=checkpoint['hidden_dim'], 352 | output_dim=checkpoint['output_dim'], 353 | num_layers=checkpoint.get('num_layers', 6), 354 | num_heads=checkpoint.get('num_heads', 8), 355 | dropout=checkpoint.get('dropout', 0.1) 356 | ) 357 | 358 | # Load state dictionary 359 | model.load_state_dict(checkpoint['state_dict']) 360 | 361 | return model 362 | 363 | def train_step(self, 364 | x: torch.Tensor, 365 | y: torch.Tensor, 366 | optimizer: torch.optim.Optimizer) -> Dict[str, float]: 367 | """ 368 | Perform a single training step. 369 | 370 | Parameters: 371 | ----------- 372 | x : torch.Tensor 373 | Input tensor of shape (batch_size, input_dim) 374 | y : torch.Tensor 375 | Target tensor of shape (batch_size, output_dim) 376 | optimizer : torch.optim.Optimizer 377 | Optimizer to use for the step 378 | 379 | Returns: 380 | -------- 381 | metrics : Dict[str, float] 382 | Dictionary containing loss and error metrics 383 | """ 384 | # Set model to training mode 385 | self.train() 386 | 387 | # Zero the gradients 388 | optimizer.zero_grad() 389 | 390 | # Forward pass 391 | solution = self(x) 392 | 393 | # Compute loss (mean squared error) 394 | mse_loss = F.mse_loss(solution, y) 395 | 396 | # Add L1 regularization for sparsity 397 | l1_loss = torch.mean(torch.abs(solution)) 398 | 399 | # Combined loss 400 | loss = mse_loss + 0.01 * l1_loss 401 | 402 | # Backward pass 403 | loss.backward() 404 | 405 | # Update parameters 406 | optimizer.step() 407 | 408 | # Compute additional metrics 409 | mae = F.l1_loss(solution, y).item() 410 | 411 | # Compute relative error 412 | rel_error = torch.norm(solution - y, dim=1) / (torch.norm(y, dim=1) + 1e-8) 413 | rel_error = torch.mean(rel_error).item() 414 | 415 | return { 416 | 'loss': loss.item(), 417 | 'mse': mse_loss.item(), 418 | 'mae': mae, 419 | 'relative_error': rel_error 420 | } 421 | 422 | def validate(self, 423 | x: torch.Tensor, 424 | y: torch.Tensor) -> Dict[str, float]: 425 | """ 426 | Validate the model on validation data. 
427 | 428 | Parameters: 429 | ----------- 430 | x : torch.Tensor 431 | Input tensor of shape (batch_size, input_dim) 432 | y : torch.Tensor 433 | Target tensor of shape (batch_size, output_dim) 434 | 435 | Returns: 436 | -------- 437 | metrics : Dict[str, float] 438 | Dictionary containing loss and error metrics 439 | """ 440 | # Set model to evaluation mode 441 | self.eval() 442 | 443 | # Disable gradient computation 444 | with torch.no_grad(): 445 | # Forward pass 446 | solution = self(x) 447 | 448 | # Compute MSE loss 449 | mse_loss = F.mse_loss(solution, y) 450 | 451 | # Compute additional metrics 452 | mae = F.l1_loss(solution, y).item() 453 | 454 | # Compute relative error 455 | rel_error = torch.norm(solution - y, dim=1) / (torch.norm(y, dim=1) + 1e-8) 456 | rel_error = torch.mean(rel_error).item() 457 | 458 | return { 459 | 'loss': mse_loss.item(), 460 | 'mse': mse_loss.item(), 461 | 'mae': mae, 462 | 'relative_error': rel_error 463 | } 464 | -------------------------------------------------------------------------------- /transformermpc/data/dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | Dataset module for TransformerMPC. 3 | 4 | This module provides classes for creating, processing, and managing datasets 5 | for training and evaluating the transformer models. 6 | """ 7 | 8 | import numpy as np 9 | import torch 10 | from torch.utils.data import Dataset, DataLoader, random_split 11 | from typing import List, Tuple, Dict, Optional, Union, Any 12 | import os 13 | import pickle 14 | import osqp 15 | import scipy.sparse as sparse 16 | 17 | from .qp_generator import QPProblem 18 | from ..utils.osqp_wrapper import OSQPSolver 19 | 20 | 21 | class QPDataset(Dataset): 22 | """ 23 | Dataset class for QP problems. 24 | 25 | This class manages the datasets used for training and evaluating 26 | the transformer models. It handles data preprocessing, feature extraction, 27 | and creating input/target pairs for both transformer models. 28 | """ 29 | 30 | def __init__(self, 31 | qp_problems: List[QPProblem], 32 | precompute_solutions: bool = True, 33 | max_constraints: Optional[int] = None, 34 | feature_normalization: bool = True, 35 | cache_dir: Optional[str] = None): 36 | """ 37 | Initialize the QP dataset. 
38 | 39 | Parameters: 40 | ----------- 41 | qp_problems : List[QPProblem] 42 | List of QP problems 43 | precompute_solutions : bool 44 | Whether to precompute solutions and active constraints 45 | max_constraints : int or None 46 | Maximum number of constraints to consider (for padding) 47 | feature_normalization : bool 48 | Whether to normalize features 49 | cache_dir : str or None 50 | Directory to cache precomputed solutions 51 | """ 52 | self.qp_problems = qp_problems 53 | self.precompute_solutions = precompute_solutions 54 | self.max_constraints = max_constraints or self._get_max_constraints() 55 | self.feature_normalization = feature_normalization 56 | self.cache_dir = cache_dir 57 | 58 | # Create cache directory if it doesn't exist 59 | if cache_dir is not None: 60 | os.makedirs(cache_dir, exist_ok=True) 61 | 62 | # Solver for computing solutions 63 | self.solver = OSQPSolver() 64 | 65 | # Precompute solutions if requested 66 | if precompute_solutions: 67 | self._precompute_all_solutions() 68 | 69 | # Compute feature statistics for normalization 70 | if feature_normalization: 71 | self._compute_feature_statistics() 72 | 73 | def _get_max_constraints(self) -> int: 74 | """ 75 | Get maximum number of constraints across all problems. 76 | 77 | Returns: 78 | -------- 79 | max_constraints : int 80 | Maximum number of constraints 81 | """ 82 | return max(problem.n_constraints for problem in self.qp_problems) 83 | 84 | def _compute_feature_statistics(self) -> None: 85 | """ 86 | Compute statistics for feature normalization. 87 | """ 88 | # Extract raw features from each problem 89 | all_features = [] 90 | for i in range(min(1000, len(self.qp_problems))): # Use subset for efficiency 91 | features = self._extract_raw_features(self.qp_problems[i]) 92 | all_features.append(features) 93 | 94 | # Concatenate features 95 | all_features = np.vstack(all_features) 96 | 97 | # Compute mean and standard deviation 98 | self.feature_mean = np.mean(all_features, axis=0) 99 | self.feature_std = np.std(all_features, axis=0) 100 | 101 | # Replace zeros in std to avoid division by zero 102 | self.feature_std = np.where(self.feature_std < 1e-8, 1.0, self.feature_std) 103 | 104 | def _normalize_features(self, features: np.ndarray) -> np.ndarray: 105 | """ 106 | Normalize features using precomputed statistics. 107 | 108 | Parameters: 109 | ----------- 110 | features : numpy.ndarray 111 | Raw features 112 | 113 | Returns: 114 | -------- 115 | normalized_features : numpy.ndarray 116 | Normalized features 117 | """ 118 | if not self.feature_normalization: 119 | return features 120 | 121 | return (features - self.feature_mean) / self.feature_std 122 | 123 | def _extract_raw_features(self, problem: QPProblem) -> np.ndarray: 124 | """ 125 | Extract raw features from a QP problem. 
126 | 127 | Parameters: 128 | ----------- 129 | problem : QPProblem 130 | QP problem instance 131 | 132 | Returns: 133 | -------- 134 | features : numpy.ndarray 135 | Raw features 136 | """ 137 | # Basic features: initial state and reference if available 138 | features = [] 139 | 140 | if problem.initial_state is not None: 141 | features.append(problem.initial_state) 142 | 143 | if problem.reference is not None: 144 | features.append(problem.reference) 145 | 146 | # Add problem dimensions as features 147 | features.append(np.array([problem.n_variables, problem.n_constraints])) 148 | 149 | # Flatten and concatenate all features 150 | return np.concatenate([f.flatten() for f in features]) 151 | 152 | def _precompute_all_solutions(self) -> None: 153 | """ 154 | Precompute solutions and active constraints for all problems. 155 | """ 156 | self.solutions = [] 157 | self.active_constraints = [] 158 | 159 | # Define cache file path if using caching 160 | cache_file = None 161 | if self.cache_dir is not None: 162 | cache_file = os.path.join(self.cache_dir, "qp_solutions_cache.pkl") 163 | 164 | # Load from cache if it exists 165 | if os.path.exists(cache_file): 166 | with open(cache_file, 'rb') as f: 167 | cache_data = pickle.load(f) 168 | self.solutions = cache_data['solutions'] 169 | self.active_constraints = cache_data['active_constraints'] 170 | return 171 | 172 | # Solve all problems and identify active constraints 173 | for i, problem in enumerate(self.qp_problems): 174 | # Solve the QP problem 175 | solution = self.solver.solve( 176 | Q=problem.Q, 177 | c=problem.c, 178 | A=problem.A, 179 | b=problem.b 180 | ) 181 | 182 | # Store the solution 183 | self.solutions.append(solution) 184 | 185 | # Identify active constraints 186 | active = self._identify_active_constraints(problem, solution) 187 | self.active_constraints.append(active) 188 | 189 | # Save to cache if using caching 190 | if cache_file is not None: 191 | with open(cache_file, 'wb') as f: 192 | pickle.dump( 193 | { 194 | 'solutions': self.solutions, 195 | 'active_constraints': self.active_constraints 196 | }, 197 | f 198 | ) 199 | 200 | def _identify_active_constraints(self, 201 | problem: QPProblem, 202 | solution: np.ndarray, 203 | tol: float = 1e-6) -> np.ndarray: 204 | """ 205 | Identify active constraints in a QP solution. 206 | 207 | Parameters: 208 | ----------- 209 | problem : QPProblem 210 | QP problem instance 211 | solution : numpy.ndarray 212 | Solution vector 213 | tol : float 214 | Tolerance for identifying active constraints 215 | 216 | Returns: 217 | -------- 218 | active : numpy.ndarray 219 | Binary vector indicating active constraints 220 | """ 221 | # Compute constraint values: A * x - b 222 | constraint_values = problem.A @ solution - problem.b 223 | 224 | # Identify active constraints (those within tolerance of the boundary) 225 | active = np.abs(constraint_values) < tol 226 | 227 | # Pad or truncate to max_constraints 228 | if self.max_constraints is not None: 229 | if len(active) > self.max_constraints: 230 | active = active[:self.max_constraints] 231 | elif len(active) < self.max_constraints: 232 | padding = np.zeros(self.max_constraints - len(active), dtype=bool) 233 | active = np.concatenate([active, padding]) 234 | 235 | return active.astype(np.float32) 236 | 237 | def __len__(self) -> int: 238 | """ 239 | Get the number of problems in the dataset. 
240 | 241 | Returns: 242 | -------- 243 | length : int 244 | Number of problems 245 | """ 246 | return len(self.qp_problems) 247 | 248 | def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]: 249 | """ 250 | Get a single item from the dataset. 251 | 252 | Parameters: 253 | ----------- 254 | idx : int 255 | Index of the item 256 | 257 | Returns: 258 | -------- 259 | item : Dict[str, torch.Tensor] 260 | Dictionary containing: 261 | - 'features': Input features for the models 262 | - 'active_constraints': Target for constraint predictor 263 | - 'solution': Target for warm start predictor 264 | """ 265 | # Get the problem 266 | problem = self.qp_problems[idx] 267 | 268 | # Extract and normalize features 269 | features = self._extract_raw_features(problem) 270 | features = self._normalize_features(features) 271 | 272 | # Get solution and active constraints 273 | if self.precompute_solutions: 274 | solution = self.solutions[idx] 275 | active_constraints = self.active_constraints[idx] 276 | else: 277 | # Solve on-the-fly 278 | solution = self.solver.solve( 279 | Q=problem.Q, 280 | c=problem.c, 281 | A=problem.A, 282 | b=problem.b 283 | ) 284 | active_constraints = self._identify_active_constraints(problem, solution) 285 | 286 | # Convert to torch tensors 287 | features_tensor = torch.tensor(features, dtype=torch.float32) 288 | active_constraints_tensor = torch.tensor(active_constraints, dtype=torch.float32) 289 | solution_tensor = torch.tensor(solution, dtype=torch.float32) 290 | 291 | return { 292 | 'features': features_tensor, 293 | 'active_constraints': active_constraints_tensor, 294 | 'solution': solution_tensor, 295 | 'problem_idx': idx # Include the problem index for reference 296 | } 297 | 298 | def get_problem(self, idx: int) -> QPProblem: 299 | """ 300 | Get the original QP problem at the given index. 301 | 302 | Parameters: 303 | ----------- 304 | idx : int 305 | Index of the problem 306 | 307 | Returns: 308 | -------- 309 | problem : QPProblem 310 | QP problem instance 311 | """ 312 | return self.qp_problems[idx] 313 | 314 | def get_dataloaders(self, 315 | batch_size: int = 32, 316 | val_split: float = 0.2, 317 | test_split: float = 0.1, 318 | shuffle: bool = True, 319 | num_workers: int = 4) -> Tuple[DataLoader, DataLoader, DataLoader]: 320 | """ 321 | Create train, validation, and test dataloaders. 
322 | 323 | Parameters: 324 | ----------- 325 | batch_size : int 326 | Batch size 327 | val_split : float 328 | Fraction of data to use for validation 329 | test_split : float 330 | Fraction of data to use for testing 331 | shuffle : bool 332 | Whether to shuffle the data 333 | num_workers : int 334 | Number of workers for data loading 335 | 336 | Returns: 337 | -------- 338 | train_loader : DataLoader 339 | Training data loader 340 | val_loader : DataLoader 341 | Validation data loader 342 | test_loader : DataLoader 343 | Test data loader 344 | """ 345 | # Calculate dataset sizes 346 | dataset_size = len(self) 347 | test_size = int(dataset_size * test_split) 348 | val_size = int(dataset_size * val_split) 349 | train_size = dataset_size - val_size - test_size 350 | 351 | # Split the dataset 352 | train_dataset, val_dataset, test_dataset = random_split( 353 | self, [train_size, val_size, test_size], 354 | generator=torch.Generator().manual_seed(42) 355 | ) 356 | 357 | # Create dataloaders 358 | train_loader = DataLoader( 359 | train_dataset, 360 | batch_size=batch_size, 361 | shuffle=shuffle, 362 | num_workers=num_workers 363 | ) 364 | 365 | val_loader = DataLoader( 366 | val_dataset, 367 | batch_size=batch_size, 368 | shuffle=False, 369 | num_workers=num_workers 370 | ) 371 | 372 | test_loader = DataLoader( 373 | test_dataset, 374 | batch_size=batch_size, 375 | shuffle=False, 376 | num_workers=num_workers 377 | ) 378 | 379 | return train_loader, val_loader, test_loader 380 | 381 | def split(self, test_size: float = 0.2, seed: int = 42) -> Tuple['QPDataset', 'QPDataset']: 382 | """ 383 | Split the dataset into training and test sets. 384 | 385 | Parameters: 386 | ----------- 387 | test_size : float 388 | Fraction of data to use for testing 389 | seed : int 390 | Random seed 391 | 392 | Returns: 393 | -------- 394 | train_dataset : QPDataset 395 | Training dataset 396 | test_dataset : QPDataset 397 | Test dataset 398 | """ 399 | # Set random seed for reproducibility 400 | np.random.seed(seed) 401 | 402 | # Shuffle indices 403 | indices = np.arange(len(self.qp_problems)) 404 | np.random.shuffle(indices) 405 | 406 | # Calculate the split point (number of training problems) 407 | split_point = int(len(indices) * (1 - test_size)) 408 | 409 | # Split indices 410 | train_indices = indices[:split_point] 411 | test_indices = indices[split_point:] 412 | 413 | # Create new datasets 414 | train_problems = [self.qp_problems[i] for i in train_indices] 415 | test_problems = [self.qp_problems[i] for i in test_indices] 416 | 417 | # Create new dataset objects 418 | train_dataset = QPDataset( 419 | train_problems, 420 | precompute_solutions=self.precompute_solutions, 421 | max_constraints=self.max_constraints, 422 | feature_normalization=self.feature_normalization, 423 | cache_dir=None # the parent's solution cache indexes the full shuffled problem list, so it must not be reused here 424 | ) 425 | 426 | test_dataset = QPDataset( 427 | test_problems, 428 | precompute_solutions=self.precompute_solutions, 429 | max_constraints=self.max_constraints, 430 | feature_normalization=self.feature_normalization, 431 | cache_dir=None # likewise, avoid loading the parent's cached solutions 432 | ) 433 | 434 | # Copy feature statistics so both splits use the parent's normalization 435 | if self.feature_normalization: 436 | for ds in (train_dataset, test_dataset): 437 | ds.feature_mean, ds.feature_std = self.feature_mean, self.feature_std 438 | 439 | return train_dataset, test_dataset 440 | -------------------------------------------------------------------------------- /scripts/simple_demo.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | TransformerMPC
Simple Demo 4 | 5 | This script provides a complete end-to-end demonstration of the TransformerMPC package: 6 | 1. Generates a set of quadratic programming (QP) problems 7 | 2. Creates training and test datasets 8 | 3. Trains both transformer models (constraint predictor and warm start predictor) 9 | 4. Tests the models on the test set 10 | 5. Plots performance comparisons including box plots 11 | 12 | The script is designed to run quickly with a small number of samples and epochs, 13 | but can be modified for more comprehensive training. 14 | """ 15 | 16 | import os 17 | import time 18 | import argparse 19 | import numpy as np 20 | import matplotlib.pyplot as plt 21 | from pathlib import Path 22 | import torch 23 | from tqdm import tqdm 24 | 25 | # Fix potential serialization issues with numpy scalars 26 | import torch.serialization 27 | torch.serialization.add_safe_globals([np.core.multiarray.scalar]) # pass the object itself, not its dotted name 28 | 29 | # Patch torch.load to default to weights_only=False (newer PyTorch versions flipped the default) 30 | original_torch_load = torch.load 31 | def patched_torch_load(f, *args, **kwargs): 32 | if 'weights_only' not in kwargs: 33 | kwargs['weights_only'] = False 34 | return original_torch_load(f, *args, **kwargs) 35 | torch.load = patched_torch_load 36 | 37 | # Import TransformerMPC modules 38 | from transformermpc.data.qp_generator import QPGenerator 39 | from transformermpc.data.dataset import QPDataset 40 | from transformermpc.models.constraint_predictor import ConstraintPredictor 41 | from transformermpc.models.warm_start_predictor import WarmStartPredictor 42 | from transformermpc.utils.osqp_wrapper import OSQPSolver 43 | from transformermpc.utils.metrics import compute_solve_time_metrics, compute_fallback_rate 44 | from transformermpc.utils.visualization import ( 45 | plot_solve_time_comparison, 46 | plot_solve_time_boxplot 47 | ) 48 | 49 | def parse_args(): 50 | """Parse command line arguments.""" 51 | parser = argparse.ArgumentParser(description="TransformerMPC Simple Demo") 52 | 53 | # Data generation parameters 54 | parser.add_argument("--num_samples", type=int, default=100, 55 | help="Number of QP problems to generate (default: 100)") 56 | parser.add_argument("--state_dim", type=int, default=4, 57 | help="State dimension for MPC problems (default: 4)") 58 | parser.add_argument("--input_dim", type=int, default=2, 59 | help="Input dimension for MPC problems (default: 2)") 60 | parser.add_argument("--horizon", type=int, default=5, 61 | help="Time horizon for MPC problems (default: 5)") 62 | 63 | # Training parameters 64 | parser.add_argument("--epochs", type=int, default=5, 65 | help="Number of epochs for training (default: 5)") 66 | parser.add_argument("--batch_size", type=int, default=16, 67 | help="Batch size for training (default: 16)") 68 | parser.add_argument("--test_size", type=float, default=0.2, 69 | help="Fraction of data to use for testing (default: 0.2)") 70 | 71 | # Other parameters 72 | parser.add_argument("--output_dir", type=str, default="demo_results", 73 | help="Directory to save results (default: demo_results)") 74 | parser.add_argument("--use_gpu", action="store_true", 75 | help="Use GPU if available") 76 | parser.add_argument("--test_problems", type=int, default=10, 77 | help="Number of test problems for evaluation (default: 10)") 78 | 79 | return parser.parse_args() 80 | 81 | def train_model(model, train_data, val_data, num_epochs, batch_size, lr=1e-3, device=None): 82 | """Simple training loop for the models; device=None falls back to CUDA when available.""" 83 | device = device or torch.device('cuda' if torch.cuda.is_available() else 'cpu') 84 | model = 
model.to(device) 85 | optimizer = torch.optim.Adam(model.parameters(), lr=lr) 86 | 87 | train_losses = [] 88 | val_losses = [] 89 | 90 | # Create data loaders 91 | train_loader = torch.utils.data.DataLoader( 92 | train_data, batch_size=batch_size, shuffle=True 93 | ) 94 | val_loader = torch.utils.data.DataLoader( 95 | val_data, batch_size=batch_size, shuffle=False 96 | ) 97 | 98 | print(f"Training for {num_epochs} epochs...") 99 | for epoch in range(num_epochs): 100 | # Training 101 | model.train() 102 | epoch_loss = 0 103 | for batch in train_loader: 104 | features = batch['features'].to(device) 105 | 106 | if isinstance(model, ConstraintPredictor): 107 | targets = batch['active_constraints'].to(device) 108 | optimizer.zero_grad() 109 | outputs = model(features) 110 | loss = torch.nn.functional.binary_cross_entropy_with_logits(outputs, targets) 111 | else: 112 | targets = batch['solution'].to(device) 113 | optimizer.zero_grad() 114 | outputs = model(features) 115 | loss = torch.nn.functional.mse_loss(outputs, targets) 116 | 117 | loss.backward() 118 | optimizer.step() 119 | epoch_loss += loss.item() 120 | 121 | avg_epoch_loss = epoch_loss / len(train_loader) 122 | train_losses.append(avg_epoch_loss) 123 | 124 | # Validation 125 | model.eval() 126 | val_loss = 0 127 | with torch.no_grad(): 128 | for batch in val_loader: 129 | features = batch['features'].to(device) 130 | 131 | if isinstance(model, ConstraintPredictor): 132 | targets = batch['active_constraints'].to(device) 133 | outputs = model(features) 134 | loss = torch.nn.functional.binary_cross_entropy_with_logits(outputs, targets) 135 | else: 136 | targets = batch['solution'].to(device) 137 | outputs = model(features) 138 | loss = torch.nn.functional.mse_loss(outputs, targets) 139 | 140 | val_loss += loss.item() 141 | 142 | avg_val_loss = val_loss / len(val_loader) 143 | val_losses.append(avg_val_loss) 144 | 145 | print(f"Epoch {epoch+1}/{num_epochs}: Train Loss = {avg_epoch_loss:.6f}, Val Loss = {avg_val_loss:.6f}") 146 | 147 | return { 148 | 'train_loss': train_losses, 149 | 'val_loss': val_losses 150 | } 151 | 152 | def main(): 153 | """Run the demo workflow.""" 154 | # Parse command line arguments 155 | args = parse_args() 156 | 157 | print("=" * 60) 158 | print("TransformerMPC Simple Demo".center(60)) 159 | print("=" * 60) 160 | 161 | # Create output directory 162 | output_dir = Path(args.output_dir) 163 | output_dir.mkdir(parents=True, exist_ok=True) 164 | 165 | # Set up directories 166 | models_dir = output_dir / "models" 167 | results_dir = output_dir / "results" 168 | 169 | for directory in [models_dir, results_dir]: 170 | directory.mkdir(exist_ok=True) 171 | 172 | # Set device 173 | device = torch.device('cuda' if torch.cuda.is_available() and args.use_gpu else 'cpu') 174 | print(f"Using device: {device}") 175 | 176 | # Step 1: Generate QP problems 177 | print("\nStep 1: Generating QP problems") 178 | print("-" * 60) 179 | 180 | print(f"Generating {args.num_samples} QP problems") 181 | generator = QPGenerator( 182 | state_dim=args.state_dim, 183 | input_dim=args.input_dim, 184 | horizon=args.horizon, 185 | num_samples=args.num_samples 186 | ) 187 | qp_problems = generator.generate() 188 | print(f"Generated {len(qp_problems)} QP problems") 189 | 190 | # Step 2: Create dataset and split into train/test 191 | print("\nStep 2: Creating datasets") 192 | print("-" * 60) 193 | 194 | dataset = QPDataset( 195 | qp_problems=qp_problems, 196 | precompute_solutions=True, 197 | feature_normalization=True 198 | ) 199 | 200 | 
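# Note: QPDataset.split (defined in dataset.py above) shuffles the problem
# indices with a fixed seed, rebuilds two QPDataset objects, and copies the
# parent's feature mean/std onto both splits, so train and test features are
# normalized with identical statistics.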
train_dataset, test_dataset = dataset.split(test_size=args.test_size) 201 | print(f"Created datasets - Train: {len(train_dataset)}, Test: {len(test_dataset)}") 202 | 203 | # Step 3: Train constraint predictor 204 | print("\nStep 3: Training constraint predictor") 205 | print("-" * 60) 206 | 207 | # Get input dimension from the dataset 208 | sample_item = train_dataset[0] 209 | input_dim = sample_item['features'].shape[0] 210 | num_constraints = sample_item['active_constraints'].shape[0] 211 | 212 | cp_model = ConstraintPredictor( 213 | input_dim=input_dim, 214 | hidden_dim=64, # Smaller model for faster training 215 | num_constraints=num_constraints, 216 | num_layers=2 # Fewer layers for faster training 217 | ) 218 | 219 | # Train constraint predictor with simplified training loop 220 | cp_history = train_model( 221 | model=cp_model, 222 | train_data=train_dataset, 223 | val_data=test_dataset, 224 | num_epochs=args.epochs, 225 | batch_size=args.batch_size, device=device 226 | ) 227 | 228 | # Save model (move weights back to CPU so the CPU-side evaluation below also works after GPU training) 229 | cp_model_file = models_dir / "constraint_predictor.pt" 230 | torch.save(cp_model.cpu().state_dict(), cp_model_file) 231 | print(f"Saved constraint predictor to {cp_model_file}") 232 | 233 | # Step 4: Train warm start predictor 234 | print("\nStep 4: Training warm start predictor") 235 | print("-" * 60) 236 | 237 | # Get output dimension from the dataset 238 | output_dim = sample_item['solution'].shape[0] 239 | 240 | ws_model = WarmStartPredictor( 241 | input_dim=input_dim, 242 | hidden_dim=128, # Smaller model for faster training 243 | output_dim=output_dim, 244 | num_layers=2 # Fewer layers for faster training 245 | ) 246 | 247 | # Train warm start predictor with simplified training loop 248 | ws_history = train_model( 249 | model=ws_model, 250 | train_data=train_dataset, 251 | val_data=test_dataset, 252 | num_epochs=args.epochs, 253 | batch_size=args.batch_size, device=device 254 | ) 255 | 256 | # Save model (again moved to CPU for the evaluation below) 257 | ws_model_file = models_dir / "warm_start_predictor.pt" 258 | torch.save(ws_model.cpu().state_dict(), ws_model_file) 259 | print(f"Saved warm start predictor to {ws_model_file}") 260 | 261 | # Step 5: Test on a subset of problems 262 | print("\nStep 5: Performance testing") 263 | print("-" * 60) 264 | 265 | solver = OSQPSolver() 266 | 267 | # Lists to store results 268 | baseline_times = [] 269 | transformer_times = [] 270 | constraint_only_times = [] 271 | warmstart_only_times = [] 272 | fallback_flags = [] 273 | 274 | # Test on a small subset for demonstration 275 | num_test_problems = min(args.test_problems, len(test_dataset)) 276 | test_subset = np.random.choice(len(test_dataset), size=num_test_problems, replace=False) 277 | 278 | print(f"Testing on {num_test_problems} problems...") 279 | for idx in tqdm(test_subset): 280 | # Get problem 281 | sample = test_dataset[idx] 282 | problem = test_dataset.get_problem(idx) 283 | 284 | # Get features 285 | features = sample['features'] 286 | 287 | # Predict active constraints and warm start (features is already a float32 tensor from the dataset) 288 | with torch.no_grad(): 289 | cp_output = cp_model(features.unsqueeze(0)) 290 | pred_active = (torch.sigmoid(cp_output) > 0.5).float().squeeze(0).numpy() 291 | 292 | ws_output = ws_model(features.unsqueeze(0)) 293 | pred_solution = ws_output.squeeze(0).numpy() 294 | 295 | # 1. Baseline (OSQP without transformers) 296 | _, baseline_time = solver.solve_with_time( 297 | Q=problem.Q, 298 | c=problem.c, 299 | A=problem.A, 300 | b=problem.b 301 | ) 302 | baseline_times.append(baseline_time) 303 | 304 | # 2.
Constraint-only 305 | _, constraint_time, _ = solver.solve_pipeline( 306 | Q=problem.Q, 307 | c=problem.c, 308 | A=problem.A, 309 | b=problem.b, 310 | active_constraints=pred_active, 311 | warm_start=None, 312 | fallback_on_violation=True 313 | ) 314 | constraint_only_times.append(constraint_time) 315 | 316 | # 3. Warm-start-only 317 | _, warmstart_time = solver.solve_with_time( 318 | Q=problem.Q, 319 | c=problem.c, 320 | A=problem.A, 321 | b=problem.b, 322 | warm_start=pred_solution 323 | ) 324 | warmstart_only_times.append(warmstart_time) 325 | 326 | # 4. Full transformer pipeline 327 | _, transformer_time, used_fallback = solver.solve_pipeline( 328 | Q=problem.Q, 329 | c=problem.c, 330 | A=problem.A, 331 | b=problem.b, 332 | active_constraints=pred_active, 333 | warm_start=pred_solution, 334 | fallback_on_violation=True 335 | ) 336 | transformer_times.append(transformer_time) 337 | fallback_flags.append(used_fallback) 338 | 339 | # Convert to numpy arrays 340 | baseline_times = np.array(baseline_times) 341 | transformer_times = np.array(transformer_times) 342 | constraint_only_times = np.array(constraint_only_times) 343 | warmstart_only_times = np.array(warmstart_only_times) 344 | 345 | # Compute and print metrics 346 | solve_metrics = compute_solve_time_metrics(baseline_times, transformer_times) 347 | fallback_rate = compute_fallback_rate(fallback_flags) 348 | 349 | print("\nPerformance Results:") 350 | print("-" * 60) 351 | print(f"Mean baseline time: {solve_metrics['mean_baseline_time']:.6f}s") 352 | print(f"Mean transformer time: {solve_metrics['mean_transformer_time']:.6f}s") 353 | print(f"Mean constraint-only time: {np.mean(constraint_only_times):.6f}s") 354 | print(f"Mean warm-start-only time: {np.mean(warmstart_only_times):.6f}s") 355 | print(f"Mean speedup: {solve_metrics['mean_speedup']:.2f}x") 356 | print(f"Median speedup: {solve_metrics['median_speedup']:.2f}x") 357 | print(f"Fallback rate: {fallback_rate:.2f}%") 358 | 359 | # Step 6: Generate visualizations 360 | print("\nStep 6: Generating visualizations") 361 | print("-" * 60) 362 | 363 | # Plot solve time comparison 364 | print("Generating solve time comparison plot...") 365 | plot_solve_time_comparison( 366 | baseline_times=baseline_times, 367 | transformer_times=transformer_times, 368 | save_path=results_dir / "solve_time_comparison.png" 369 | ) 370 | 371 | # Plot solve time boxplot 372 | print("Generating solve time boxplot...") 373 | plot_solve_time_boxplot( 374 | baseline_times=baseline_times, 375 | transformer_times=transformer_times, 376 | constraint_only_times=constraint_only_times, 377 | warmstart_only_times=warmstart_only_times, 378 | save_path=results_dir / "solve_time_boxplot.png" 379 | ) 380 | 381 | # Plot performance violin plot 382 | print("Generating performance violin plot...") 383 | plt.figure(figsize=(10, 6)) 384 | plt.violinplot( 385 | [baseline_times, constraint_only_times, warmstart_only_times, transformer_times], 386 | showmeans=True 387 | ) 388 | plt.xticks([1, 2, 3, 4], ['Baseline', 'Constraint-only', 'WarmStart-only', 'Full Pipeline']) 389 | plt.ylabel('Solve Time (s)') 390 | plt.title('QP Solve Time Comparison') 391 | plt.grid(True, alpha=0.3) 392 | plt.savefig(results_dir / "performance_violin.png", dpi=300) 393 | 394 | # Plot training history 395 | print("Generating training history plots...") 396 | 397 | # Plot training loss 398 | plt.figure(figsize=(10, 5)) 399 | plt.subplot(1, 2, 1) 400 | plt.plot(cp_history['train_loss'], label='Train Loss') 401 | plt.plot(cp_history['val_loss'], 
label='Validation Loss') 402 | plt.xlabel('Epoch') 403 | plt.ylabel('Loss') 404 | plt.title('Constraint Predictor Training Loss') 405 | plt.legend() 406 | plt.grid(alpha=0.3) 407 | 408 | plt.subplot(1, 2, 2) 409 | plt.plot(ws_history['train_loss'], label='Train Loss') 410 | plt.plot(ws_history['val_loss'], label='Validation Loss') 411 | plt.xlabel('Epoch') 412 | plt.ylabel('Loss') 413 | plt.title('Warm Start Predictor Training Loss') 414 | plt.legend() 415 | plt.grid(alpha=0.3) 416 | 417 | plt.tight_layout() 418 | plt.savefig(results_dir / "training_history.png", dpi=300) 419 | 420 | # Create a summary plot with box plots 421 | print("Generating summary box plot...") 422 | plt.figure(figsize=(12, 8)) 423 | 424 | # Create box plot 425 | plt.boxplot( 426 | [baseline_times, transformer_times, constraint_only_times, warmstart_only_times], 427 | labels=['Baseline', 'Full Pipeline', 'Constraint-only', 'Warm Start-only'], 428 | showmeans=True 429 | ) 430 | 431 | plt.ylabel('Solve Time (s)') 432 | plt.title('QP Solve Time Comparison (Box Plot)') 433 | plt.grid(True, axis='y', alpha=0.3) 434 | 435 | # Add mean value annotations 436 | means = [ 437 | np.mean(baseline_times), 438 | np.mean(transformer_times), 439 | np.mean(constraint_only_times), 440 | np.mean(warmstart_only_times) 441 | ] 442 | 443 | for i, mean in enumerate(means): 444 | plt.text(i+1, mean, f'{mean:.6f}s', 445 | horizontalalignment='center', 446 | verticalalignment='bottom', 447 | fontweight='bold') 448 | 449 | plt.savefig(results_dir / "summary_boxplot.png", dpi=300) 450 | 451 | print(f"\nResults and visualizations saved to {output_dir}") 452 | print("\nDemo completed successfully!") 453 | print("=" * 60) 454 | 455 | if __name__ == "__main__": 456 | main() --------------------------------------------------------------------------------
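For context on what `OSQPSolver.solve_pipeline` in the demo above is doing, here is a minimal sketch of the reduce-then-fallback pattern written against the raw `osqp` API. The QP data, the predicted active set, and the warm start vector below are made-up stand-ins for the transformer outputs; the wrapper's actual implementation is not shown in this dump, so treat this as an illustration of the pattern rather than the package's exact code.

# illustrative sketch, not part of the package
import numpy as np
import scipy.sparse as sparse
import osqp

# Tiny illustrative QP: minimize 0.5*x'Qx + c'x subject to A x <= b
rng = np.random.default_rng(0)
n, m = 4, 8
Q = np.eye(n)
c = rng.standard_normal(n)
A = rng.standard_normal((m, n))
b = np.abs(rng.standard_normal(m)) + 1.0  # keeps x = 0 feasible

# Hypothetical stand-ins for the two transformer outputs
pred_active = np.zeros(m, dtype=bool)
pred_active[:3] = True            # constraint predictor: predicted active set
warm_start = np.zeros(n)          # warm start predictor: predicted solution

# Reduced problem: keep only the constraints predicted to be active
A_red, b_red = A[pred_active], b[pred_active]
reduced = osqp.OSQP()
reduced.setup(P=sparse.csc_matrix(Q), q=c, A=sparse.csc_matrix(A_red),
              l=-np.inf * np.ones(len(b_red)), u=b_red, verbose=False)
reduced.warm_start(x=warm_start)
x = reduced.solve().x

# Fallback: if the reduced solution violates any dropped constraint,
# re-solve the full problem (still warm-started)
if not np.all(np.isfinite(x)) or np.any(A @ x - b > 1e-6):
    full = osqp.OSQP()
    full.setup(P=sparse.csc_matrix(Q), q=c, A=sparse.csc_matrix(A),
               l=-np.inf * np.ones(m), u=b, verbose=False)
    full.warm_start(x=warm_start)
    x = full.solve().x

print("solution:", x)

Dropping the predicted-inactive rows shrinks the KKT system OSQP factorizes, which is where the speedups in the box plots come from; the feasibility check preserves correctness when the constraint predictor is wrong, at the cost of one extra full solve (the fallback rate reported by the demo).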