├── transformermpc
│   ├── data
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-311.pyc
│   │   │   ├── dataset.cpython-311.pyc
│   │   │   └── qp_generator.cpython-311.pyc
│   │   ├── qp_generator.py
│   │   └── dataset.py
│   ├── demo
│   │   ├── __init__.py
│   │   └── demo.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-311.pyc
│   │   │   ├── constraint_predictor.cpython-311.pyc
│   │   │   └── warm_start_predictor.cpython-311.pyc
│   │   ├── constraint_predictor.py
│   │   └── warm_start_predictor.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-311.pyc
│   │   │   ├── osqp_wrapper.cpython-311.pyc
│   │   │   └── visualization.cpython-311.pyc
│   │   ├── metrics.py
│   │   ├── osqp_wrapper.py
│   │   └── visualization.py
│   ├── training
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-311.pyc
│   │   │   └── trainer.cpython-311.pyc
│   │   └── trainer.py
│   ├── __pycache__
│   │   └── __init__.cpython-311.pyc
│   └── __init__.py
├── transformermpc.egg-info
│   ├── not-zip-safe
│   ├── dependency_links.txt
│   ├── top_level.txt
│   ├── entry_points.txt
│   ├── requires.txt
│   ├── SOURCES.txt
│   └── PKG-INFO
├── summary_plot.png
├── training_curves.png
├── boxplot_comparison.png
├── demo_results
│   ├── models
│   │   ├── best_model.pt
│   │   └── constraint_predictor.pt
│   └── results
│       ├── speedup_barchart.png
│       ├── solve_time_boxplot.png
│       └── solve_time_violinplot.png
├── dist
│   ├── transformermpc-0.1.6.tar.gz
│   └── transformermpc-0.1.6-py3-none-any.whl
├── tests
│   ├── dist
│   │   ├── transformermpc-0.1.6.tar.gz
│   │   └── transformermpc-0.1.6-py3-none-any.whl
│   ├── demo_results
│   │   ├── models
│   │   │   ├── best_model.pt
│   │   │   └── constraint_predictor.pt
│   │   └── results
│   │       ├── speedup_barchart.png
│   │       ├── solve_time_boxplot.png
│   │       └── solve_time_violinplot.png
│   ├── test_import.py
│   ├── scripts
│   │   ├── run_demo.py
│   │   ├── verify_package.py
│   │   ├── verify_structure.py
│   │   └── boxplot_demo.py
│   ├── test_package_structure.py
│   ├── test_solve.py
│   ├── test_files.py
│   ├── test_new_models.py
│   ├── examples
│   │   └── example_usage.py
│   └── test_benchmark.py
├── setup.cfg
├── requirements.txt
├── MANIFEST.in
├── scripts
│   ├── run_demo.py
│   ├── verify_package.py
│   ├── verify_structure.py
│   ├── boxplot_demo.py
│   └── simple_demo.py
├── LICENSE
├── setup.py
├── pyproject.toml
├── examples
│   └── example_usage.py
└── README.md
/transformermpc/data/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/transformermpc/demo/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/transformermpc/models/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/transformermpc/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/transformermpc/training/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/transformermpc.egg-info/not-zip-safe:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/transformermpc.egg-info/dependency_links.txt:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/transformermpc.egg-info/top_level.txt:
--------------------------------------------------------------------------------
1 | transformermpc
2 |
--------------------------------------------------------------------------------
/summary_plot.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Vrushabh27/transformermpc/HEAD/summary_plot.png
--------------------------------------------------------------------------------
/training_curves.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Vrushabh27/transformermpc/HEAD/training_curves.png
--------------------------------------------------------------------------------
/boxplot_comparison.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Vrushabh27/transformermpc/HEAD/boxplot_comparison.png
--------------------------------------------------------------------------------
/transformermpc.egg-info/entry_points.txt:
--------------------------------------------------------------------------------
1 | [console_scripts]
2 | transformermpc-demo = transformermpc.demo:main
3 |
--------------------------------------------------------------------------------
/demo_results/models/best_model.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Vrushabh27/transformermpc/HEAD/demo_results/models/best_model.pt
--------------------------------------------------------------------------------
/dist/transformermpc-0.1.6.tar.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Vrushabh27/transformermpc/HEAD/dist/transformermpc-0.1.6.tar.gz
--------------------------------------------------------------------------------
/tests/dist/transformermpc-0.1.6.tar.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Vrushabh27/transformermpc/HEAD/tests/dist/transformermpc-0.1.6.tar.gz
--------------------------------------------------------------------------------
/demo_results/results/speedup_barchart.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Vrushabh27/transformermpc/HEAD/demo_results/results/speedup_barchart.png
--------------------------------------------------------------------------------
/tests/demo_results/models/best_model.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Vrushabh27/transformermpc/HEAD/tests/demo_results/models/best_model.pt
--------------------------------------------------------------------------------
/demo_results/models/constraint_predictor.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Vrushabh27/transformermpc/HEAD/demo_results/models/constraint_predictor.pt
--------------------------------------------------------------------------------
/demo_results/results/solve_time_boxplot.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Vrushabh27/transformermpc/HEAD/demo_results/results/solve_time_boxplot.png
--------------------------------------------------------------------------------
/dist/transformermpc-0.1.6-py3-none-any.whl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Vrushabh27/transformermpc/HEAD/dist/transformermpc-0.1.6-py3-none-any.whl
--------------------------------------------------------------------------------
/demo_results/results/solve_time_violinplot.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Vrushabh27/transformermpc/HEAD/demo_results/results/solve_time_violinplot.png
--------------------------------------------------------------------------------
/tests/demo_results/results/speedup_barchart.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Vrushabh27/transformermpc/HEAD/tests/demo_results/results/speedup_barchart.png
--------------------------------------------------------------------------------
/tests/dist/transformermpc-0.1.6-py3-none-any.whl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Vrushabh27/transformermpc/HEAD/tests/dist/transformermpc-0.1.6-py3-none-any.whl
--------------------------------------------------------------------------------
/tests/demo_results/models/constraint_predictor.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Vrushabh27/transformermpc/HEAD/tests/demo_results/models/constraint_predictor.pt
--------------------------------------------------------------------------------
/tests/demo_results/results/solve_time_boxplot.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Vrushabh27/transformermpc/HEAD/tests/demo_results/results/solve_time_boxplot.png
--------------------------------------------------------------------------------
/transformermpc/__pycache__/__init__.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Vrushabh27/transformermpc/HEAD/transformermpc/__pycache__/__init__.cpython-311.pyc
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [metadata]
2 | license_file = LICENSE
3 |
4 | [options]
5 | zip_safe = False
6 | include_package_data = True
7 |
8 | [options.package_data]
9 | * = *.py
--------------------------------------------------------------------------------
/tests/demo_results/results/solve_time_violinplot.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Vrushabh27/transformermpc/HEAD/tests/demo_results/results/solve_time_violinplot.png
--------------------------------------------------------------------------------
/transformermpc/data/__pycache__/__init__.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Vrushabh27/transformermpc/HEAD/transformermpc/data/__pycache__/__init__.cpython-311.pyc
--------------------------------------------------------------------------------
/transformermpc/data/__pycache__/dataset.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Vrushabh27/transformermpc/HEAD/transformermpc/data/__pycache__/dataset.cpython-311.pyc
--------------------------------------------------------------------------------
/transformermpc/models/__pycache__/__init__.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Vrushabh27/transformermpc/HEAD/transformermpc/models/__pycache__/__init__.cpython-311.pyc
--------------------------------------------------------------------------------
/transformermpc/utils/__pycache__/__init__.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Vrushabh27/transformermpc/HEAD/transformermpc/utils/__pycache__/__init__.cpython-311.pyc
--------------------------------------------------------------------------------
/transformermpc/data/__pycache__/qp_generator.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Vrushabh27/transformermpc/HEAD/transformermpc/data/__pycache__/qp_generator.cpython-311.pyc
--------------------------------------------------------------------------------
/transformermpc/training/__pycache__/__init__.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Vrushabh27/transformermpc/HEAD/transformermpc/training/__pycache__/__init__.cpython-311.pyc
--------------------------------------------------------------------------------
/transformermpc/training/__pycache__/trainer.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Vrushabh27/transformermpc/HEAD/transformermpc/training/__pycache__/trainer.cpython-311.pyc
--------------------------------------------------------------------------------
/transformermpc/utils/__pycache__/osqp_wrapper.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Vrushabh27/transformermpc/HEAD/transformermpc/utils/__pycache__/osqp_wrapper.cpython-311.pyc
--------------------------------------------------------------------------------
/transformermpc/utils/__pycache__/visualization.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Vrushabh27/transformermpc/HEAD/transformermpc/utils/__pycache__/visualization.cpython-311.pyc
--------------------------------------------------------------------------------
/transformermpc/models/__pycache__/constraint_predictor.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Vrushabh27/transformermpc/HEAD/transformermpc/models/__pycache__/constraint_predictor.cpython-311.pyc
--------------------------------------------------------------------------------
/transformermpc/models/__pycache__/warm_start_predictor.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Vrushabh27/transformermpc/HEAD/transformermpc/models/__pycache__/warm_start_predictor.cpython-311.pyc
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy>=1.20.0,<2.0.0
2 | scipy>=1.7.0
3 | torch>=1.9.0
4 | osqp>=0.6.2
5 | matplotlib>=3.4.0
6 | pandas>=1.3.0
7 | tqdm>=4.62.0
8 | scikit-learn>=0.24.0
9 | tensorboard>=2.7.0
10 | quadprog>=0.1.11
11 |
--------------------------------------------------------------------------------
/transformermpc.egg-info/requires.txt:
--------------------------------------------------------------------------------
1 | numpy>=1.20.0
2 | scipy>=1.7.0
3 | torch>=1.9.0
4 | osqp>=0.6.2
5 | matplotlib>=3.4.0
6 | pandas>=1.3.0
7 | tqdm>=4.62.0
8 | scikit-learn>=0.24.0
9 | tensorboard>=2.7.0
10 | quadprog>=0.1.11
11 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include LICENSE
2 | include README.md
3 | include requirements.txt
4 | recursive-include transformermpc *.py
5 | recursive-include transformermpc/data *.pkl *.npy
6 | recursive-include transformermpc/models *.pth *.pt
7 | recursive-include transformermpc/demo *.png *.jpg
--------------------------------------------------------------------------------
/tests/test_import.py:
--------------------------------------------------------------------------------
1 | import transformermpc
2 | print("Package imported successfully!")
3 |
4 | from transformermpc.models.constraint_predictor import ConstraintPredictor
5 | from transformermpc.models.warm_start_predictor import WarmStartPredictor
6 | print("Models imported successfully!")
7 |
8 | from transformermpc.demo.demo import parse_args
9 | print("Demo imported successfully!")
--------------------------------------------------------------------------------
/scripts/run_demo.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | import argparse
4 |
5 | # Add the project root (the parent of this script's directory) to the Python path
6 | sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
7 |
8 | def main():
9 | # Import the demo module
10 |     from transformermpc.demo.demo import main as demo_main
11 |
12 | # Run the demo
13 | demo_main()
14 |
15 | if __name__ == "__main__":
16 | main()
--------------------------------------------------------------------------------
/tests/scripts/run_demo.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | import argparse
4 |
5 | # Add the project root (the parent of this script's directory) to the Python path
6 | sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
7 |
8 | def main():
9 | # Import the demo module
10 |     from transformermpc.demo.demo import main as demo_main
11 |
12 | # Run the demo
13 | demo_main()
14 |
15 | if __name__ == "__main__":
16 | main()
--------------------------------------------------------------------------------
/scripts/verify_package.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | import os
4 |
5 | def list_py_files(start_dir):
6 | print(f"Scanning directory: {start_dir}")
7 |
8 | for root, dirs, files in os.walk(start_dir):
9 | py_files = [f for f in files if f.endswith('.py')]
10 | if py_files:
11 | rel_path = os.path.relpath(root, start_dir)
12 | if rel_path == '.':
13 | rel_path = ''
14 | print(f"\nIn {rel_path or 'root'}:")
15 | for py_file in sorted(py_files):
16 | print(f" - {py_file}")
17 |
18 | if __name__ == "__main__":
19 | list_py_files("transformermpc")
--------------------------------------------------------------------------------
/tests/scripts/verify_package.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | import os
4 |
5 | def list_py_files(start_dir):
6 | print(f"Scanning directory: {start_dir}")
7 |
8 | for root, dirs, files in os.walk(start_dir):
9 | py_files = [f for f in files if f.endswith('.py')]
10 | if py_files:
11 | rel_path = os.path.relpath(root, start_dir)
12 | if rel_path == '.':
13 | rel_path = ''
14 | print(f"\nIn {rel_path or 'root'}:")
15 | for py_file in sorted(py_files):
16 | print(f" - {py_file}")
17 |
18 | if __name__ == "__main__":
19 | list_py_files("transformermpc")
--------------------------------------------------------------------------------
/transformermpc.egg-info/SOURCES.txt:
--------------------------------------------------------------------------------
1 | LICENSE
2 | MANIFEST.in
3 | README.md
4 | pyproject.toml
5 | requirements.txt
6 | setup.cfg
7 | setup.py
8 | tests/test_benchmark.py
9 | tests/test_files.py
10 | tests/test_import.py
11 | tests/test_new_models.py
12 | tests/test_package_structure.py
13 | tests/test_solve.py
14 | transformermpc/__init__.py
15 | transformermpc.egg-info/PKG-INFO
16 | transformermpc.egg-info/SOURCES.txt
17 | transformermpc.egg-info/dependency_links.txt
18 | transformermpc.egg-info/entry_points.txt
19 | transformermpc.egg-info/not-zip-safe
20 | transformermpc.egg-info/requires.txt
21 | transformermpc.egg-info/top_level.txt
22 | transformermpc/data/__init__.py
23 | transformermpc/data/dataset.py
24 | transformermpc/data/qp_generator.py
25 | transformermpc/demo/__init__.py
26 | transformermpc/demo/demo.py
27 | transformermpc/models/__init__.py
28 | transformermpc/models/constraint_predictor.py
29 | transformermpc/models/warm_start_predictor.py
30 | transformermpc/training/__init__.py
31 | transformermpc/training/trainer.py
32 | transformermpc/utils/__init__.py
33 | transformermpc/utils/metrics.py
34 | transformermpc/utils/osqp_wrapper.py
35 | transformermpc/utils/visualization.py
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2025 Vrushabh Zinage, Ahmed Khalil, Efstathios Bakolas
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | with open("README.md", "r", encoding="utf-8") as fh:
4 | long_description = fh.read()
5 |
6 | with open("requirements.txt", "r", encoding="utf-8") as fh:
7 | requirements = fh.read().splitlines()
8 |
9 | setup(
10 | name="transformermpc",
11 | version="0.1.6",
12 | author="Vrushabh Zinage, Ahmed Khalil, Efstathios Bakolas",
13 | author_email="vrushabh.zinage@gmail.com",
14 | description="Accelerating Model Predictive Control via Neural Networks",
15 | long_description=long_description,
16 | long_description_content_type="text/markdown",
17 | url="https://github.com/Vrushabh27/transformermpc",
18 | packages=[
19 | 'transformermpc',
20 | 'transformermpc.data',
21 | 'transformermpc.models',
22 | 'transformermpc.utils',
23 | 'transformermpc.training',
24 | 'transformermpc.demo',
25 | ],
26 | package_dir={'transformermpc': 'transformermpc'},
27 | classifiers=[
28 | "Programming Language :: Python :: 3",
29 | "License :: OSI Approved :: MIT License",
30 | "Operating System :: OS Independent",
31 | ],
32 | python_requires=">=3.7",
33 | install_requires=requirements,
34 | include_package_data=True,
35 | )
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools>=61", "wheel"]
3 | build-backend = "setuptools.build_meta"
4 |
5 | [project]
6 | name = "transformermpc"
7 | version = "0.1.6"
8 | description = "Accelerating Model Predictive Control via Neural Networks"
9 | readme = "README.md"
10 | requires-python = ">=3.7"
11 | license = {text = "MIT"}
12 | authors = [
13 |     {name = "Vrushabh Zinage", email = "vrushabh.zinage@gmail.com"},
14 | {name = "Ahmed Khalil"},
15 | {name = "Efstathios Bakolas"}
16 | ]
17 | maintainers = [
18 |     {name = "Vrushabh Zinage", email = "vrushabh.zinage@gmail.com"}
19 | ]
20 | keywords = ["machine learning", "model predictive control", "quadratic programming"]
21 | classifiers = [
22 | "Programming Language :: Python :: 3",
23 | "Operating System :: OS Independent",
24 | ]
25 | dependencies = [
26 | "numpy>=1.20.0",
27 | "scipy>=1.7.0",
28 | "torch>=1.9.0",
29 | "osqp>=0.6.2",
30 | "matplotlib>=3.4.0",
31 | "pandas>=1.3.0",
32 | "tqdm>=4.62.0",
33 | "scikit-learn>=0.24.0",
34 | "tensorboard>=2.7.0",
35 | "quadprog>=0.1.11",
36 | ]
37 |
38 | [project.urls]
39 | "Homepage" = "https://github.com/vrushabh/transformermpc"
40 | "Bug Tracker" = "https://github.com/vrushabh/transformermpc/issues"
41 |
42 | [project.scripts]
43 | transformermpc-demo = "transformermpc.demo.demo:main"
44 |
45 | [tool.setuptools]
46 | packages = [
47 | "transformermpc",
48 | "transformermpc.data",
49 | "transformermpc.models",
50 | "transformermpc.utils",
51 | "transformermpc.training",
52 | "transformermpc.demo",
53 | ]
54 |
55 | [tool.setuptools.package-data]
56 | transformermpc = ["data/*.pkl", "data/*.npy", "models/*.pth", "models/*.pt", "demo/*.png", "demo/*.jpg"]
57 |
58 | [tool.black]
59 | line-length = 88
60 | target-version = ["py37", "py38", "py39", "py310"]
61 |
62 | [tool.isort]
63 | profile = "black"
64 | line_length = 88
65 |
66 | [tool.pytest.ini_options]
67 | testpaths = ["tests"]
68 | python_files = "test_*.py"
--------------------------------------------------------------------------------
/tests/test_package_structure.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | import os
4 | import importlib.util
5 | import pkgutil
6 | import transformermpc
7 |
8 | def check_module_exists(module_name):
9 | """Check if a module exists without actually importing it."""
10 | try:
11 | spec = importlib.util.find_spec(module_name)
12 | return spec is not None
13 | except (ModuleNotFoundError, AttributeError):
14 | return False
15 |
16 | def main():
17 | """Verify the transformermpc package structure."""
18 | print(f"Transformermpc version: {transformermpc.__version__}")
19 | print("\nPackage structure:")
20 |
21 | # Get the package path
22 | pkg_path = os.path.dirname(transformermpc.__file__)
23 | print(f"Package path: {pkg_path}")
24 |
25 | # List all subpackages and modules
26 | print("\nSubpackages and modules:")
27 | for _, name, is_pkg in pkgutil.iter_modules([pkg_path]):
28 | if is_pkg:
29 | print(f" - Subpackage: {name}")
30 | # Check subpackage modules
31 | subpkg_path = os.path.join(pkg_path, name)
32 | for _, submodule, _ in pkgutil.iter_modules([subpkg_path]):
33 | print(f" - Module: {name}.{submodule}")
34 | else:
35 | print(f" - Module: {name}")
36 |
37 | # Check expected module paths
38 | expected_modules = [
39 | "transformermpc.data.dataset",
40 | "transformermpc.data.qp_generator",
41 | "transformermpc.models.constraint_predictor",
42 | "transformermpc.models.warm_start_predictor",
43 | "transformermpc.utils.metrics",
44 | "transformermpc.utils.osqp_wrapper",
45 | "transformermpc.utils.visualization",
46 | "transformermpc.training.trainer",
47 | "transformermpc.demo.demo"
48 | ]
49 |
50 | print("\nChecking key modules:")
51 | for module in expected_modules:
52 | exists = check_module_exists(module)
53 | print(f" - {module}: {'✓' if exists else '✗'}")
54 |
55 | print("\nPackage structure verification complete!")
56 |
57 | if __name__ == "__main__":
58 | main()
--------------------------------------------------------------------------------
/tests/test_solve.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import time
3 | from transformermpc import TransformerMPC
4 |
5 | def main():
6 | print("Testing TransformerMPC solve method...")
7 |
8 | # Create a simple QP problem for testing
9 | n = 5 # dimension of the problem
10 | m = 3 # number of constraints
11 |
12 | # Create random matrices for the test problem
13 | np.random.seed(42)
14 | Q = np.random.rand(n, n)
15 | Q = Q.T @ Q # Make Q positive definite
16 | c = np.random.rand(n)
17 | A = np.random.rand(m, n)
18 | b = np.random.rand(m)
19 |
20 | # Create the TransformerMPC solver
21 | solver = TransformerMPC(
22 | use_constraint_predictor=False,
23 | use_warm_start_predictor=False
24 | )
25 |
26 | # Test the solve method
27 | start_time = time.time()
28 | x, solve_time = solver.solve(Q=Q, c=c, A=A, b=b)
29 | total_time = time.time() - start_time
30 |
31 | print(f"Solution vector: {x}")
32 | print(f"Solver time reported: {solve_time:.6f} seconds")
33 | print(f"Total time taken: {total_time:.6f} seconds")
34 |
35 | # Verify the solution with a simple objective function calculation
36 | objective = 0.5 * x.T @ Q @ x + c.T @ x
37 | print(f"Objective function value: {objective:.6f}")
38 |
39 | # Test the baseline solver
40 | baseline_start_time = time.time()
41 | x_baseline, baseline_time = solver.solve_baseline(Q=Q, c=c, A=A, b=b)
42 | baseline_total_time = time.time() - baseline_start_time
43 |
44 | print("\nBaseline solver results:")
45 | print(f"Solution vector: {x_baseline}")
46 | print(f"Solver time reported: {baseline_time:.6f} seconds")
47 | print(f"Total time taken: {baseline_total_time:.6f} seconds")
48 |
49 | # Calculate objective function for baseline solution
50 | baseline_objective = 0.5 * x_baseline.T @ Q @ x_baseline + c.T @ x_baseline
51 | print(f"Objective function value: {baseline_objective:.6f}")
52 |
53 | # Check if solutions are similar
54 | solution_diff = np.linalg.norm(x - x_baseline)
55 | print(f"\nNorm of solution difference: {solution_diff:.6f}")
56 |
57 | if __name__ == "__main__":
58 | main()
--------------------------------------------------------------------------------
/tests/test_files.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | import os
4 | import importlib.util
5 | import sys
6 |
7 | def print_file_exists(module_path):
8 | """Check if a Python module file exists and print the result."""
9 | exists = os.path.isfile(module_path)
10 | print(f"{module_path}: {'✓ Exists' if exists else '✗ Missing'}")
11 | return exists
12 |
13 | def main():
14 | """
15 | Check if key files in the transformermpc package exist
16 | without importing them.
17 | """
18 |     project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
19 |     transformermpc_dir = os.path.join(project_root, "transformermpc")
20 |
21 | print(f"\nChecking key files in {transformermpc_dir}...\n")
22 |
23 | # Check top-level files
24 | files_to_check = [
25 | os.path.join(transformermpc_dir, "__init__.py"),
26 | ]
27 |
28 | # Check data module files
29 | data_dir = os.path.join(transformermpc_dir, "data")
30 | for file in ["__init__.py", "dataset.py", "qp_generator.py"]:
31 | files_to_check.append(os.path.join(data_dir, file))
32 |
33 | # Check models module files
34 | models_dir = os.path.join(transformermpc_dir, "models")
35 | for file in ["__init__.py", "constraint_predictor.py", "warm_start_predictor.py"]:
36 | files_to_check.append(os.path.join(models_dir, file))
37 |
38 | # Check utils module files
39 | utils_dir = os.path.join(transformermpc_dir, "utils")
40 | for file in ["__init__.py", "metrics.py", "osqp_wrapper.py", "visualization.py"]:
41 | files_to_check.append(os.path.join(utils_dir, file))
42 |
43 | # Check training module files
44 | training_dir = os.path.join(transformermpc_dir, "training")
45 | for file in ["__init__.py", "trainer.py"]:
46 | files_to_check.append(os.path.join(training_dir, file))
47 |
48 | # Check demo module files
49 | demo_dir = os.path.join(transformermpc_dir, "demo")
50 | for file in ["__init__.py", "demo.py"]:
51 | files_to_check.append(os.path.join(demo_dir, file))
52 |
53 | # Check each file
54 | all_exist = True
55 | for file_path in files_to_check:
56 | exists = print_file_exists(file_path)
57 | all_exist = all_exist and exists
58 |
59 | # Print summary
60 | print("\nSummary:")
61 | if all_exist:
62 | print("✓ All key files found in the package!")
63 | else:
64 | print("✗ Some files are missing from the package.")
65 |
66 | return 0 if all_exist else 1
67 |
68 | if __name__ == "__main__":
69 | sys.exit(main())
--------------------------------------------------------------------------------
/tests/test_new_models.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | """
4 | Test script to verify the proper operation of the new vanilla transformer models.
5 | """
6 |
7 | import torch
8 | import numpy as np
9 |
10 | # Fix the import paths
11 | from transformermpc.models.constraint_predictor import ConstraintPredictor
12 | from transformermpc.models.warm_start_predictor import WarmStartPredictor
13 |
14 | def test_constraint_predictor():
15 | """Test the ConstraintPredictor model."""
16 | print("Testing ConstraintPredictor...")
17 |
18 | # Create model instance
19 | model = ConstraintPredictor(
20 | input_dim=20,
21 | hidden_dim=64,
22 | num_constraints=10,
23 | num_layers=2,
24 | num_heads=4,
25 | dropout=0.1
26 | )
27 |
28 | # Test forward pass with random data
29 | batch_size = 4
30 | x = torch.randn(batch_size, 20)
31 | output = model(x)
32 |
33 | print(f"Output shape: {output.shape}")
34 | assert output.shape == torch.Size([batch_size, 10]), f"Expected shape {torch.Size([batch_size, 10])}, got {output.shape}"
35 |
36 | # Test predict method
37 | predictions = model.predict(x)
38 | print(f"Prediction shape: {predictions.shape}")
39 | print(f"Sample prediction: {predictions[0, :5]}")
40 |
41 | print("ConstraintPredictor test passed!")
42 |
43 | def test_warm_start_predictor():
44 | """Test the WarmStartPredictor model."""
45 | print("Testing WarmStartPredictor...")
46 |
47 | # Create model instance
48 | model = WarmStartPredictor(
49 | input_dim=20,
50 | hidden_dim=64,
51 | output_dim=8,
52 | num_layers=2,
53 | num_heads=4,
54 | dropout=0.1
55 | )
56 |
57 | # Test forward pass with random data
58 | batch_size = 4
59 | x = torch.randn(batch_size, 20)
60 | output = model(x)
61 |
62 | print(f"Output shape: {output.shape}")
63 | assert output.shape == torch.Size([batch_size, 8]), f"Expected shape {torch.Size([batch_size, 8])}, got {output.shape}"
64 |
65 | # Test predict method
66 | predictions = model.predict(x)
67 | print(f"Prediction shape: {predictions.shape}")
68 | print(f"Sample prediction: {predictions[0, :5]}")
69 |
70 | print("WarmStartPredictor test passed!")
71 |
72 | def main():
73 | """Run the tests."""
74 | print("Starting tests for the new transformer models...")
75 |
76 | test_constraint_predictor()
77 | print("\n")
78 | test_warm_start_predictor()
79 |
80 | print("\nAll tests completed successfully!")
81 |
82 | if __name__ == "__main__":
83 | main()
--------------------------------------------------------------------------------
/scripts/verify_structure.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | """
3 | Verify the reorganized TransformerMPC package structure.
4 |
5 | This script checks that the package structure follows standard Python
6 | package conventions and that all key modules are present.
7 | """
8 |
9 | import os
10 | import sys
11 |
12 | def check_dir_exists(path, name):
13 | """Check if a directory exists and print the result."""
14 | exists = os.path.isdir(path)
15 | print(f"{name} directory: {'✓' if exists else '✗'}")
16 | return exists
17 |
18 | def check_file_exists(path, name):
19 | """Check if a file exists and print the result."""
20 | exists = os.path.isfile(path)
21 | print(f"{name} file: {'✓' if exists else '✗'}")
22 | return exists
23 |
24 | def main():
25 | """Verify the package structure."""
26 |     # Get the project root directory (the parent of this script's directory)
27 | script_dir = os.path.dirname(os.path.abspath(__file__))
28 | project_root = os.path.dirname(script_dir)
29 |
30 | print("\n" + "="*60)
31 | print(" TransformerMPC Package Structure Verification")
32 | print("="*60)
33 |
34 | # Check top-level directories
35 | print("\nChecking top-level directories:")
36 | dirs_ok = True
37 | for dirname in ["transformermpc", "tests", "scripts", "examples"]:
38 | path = os.path.join(project_root, dirname)
39 | dirs_ok = check_dir_exists(path, dirname) and dirs_ok
40 |
41 | # Check top-level files
42 | print("\nChecking top-level files:")
43 | files_ok = True
44 | for filename in ["setup.py", "pyproject.toml", "requirements.txt", "MANIFEST.in", "LICENSE", "README.md"]:
45 | path = os.path.join(project_root, filename)
46 | files_ok = check_file_exists(path, filename) and files_ok
47 |
48 | # Check package structure
49 | print("\nChecking package structure:")
50 | pkg_dir = os.path.join(project_root, "transformermpc")
51 | pkg_ok = True
52 |
53 | # Check subdirectories
54 | for dirname in ["data", "models", "utils", "training", "demo"]:
55 | path = os.path.join(pkg_dir, dirname)
56 | pkg_ok = check_dir_exists(path, f"transformermpc/{dirname}") and pkg_ok
57 |
58 | # Check key files
59 | pkg_ok = check_file_exists(os.path.join(pkg_dir, "__init__.py"), "transformermpc/__init__.py") and pkg_ok
60 |
61 | # Check for key implementation files
62 | key_files = [
63 | ("transformermpc/models/constraint_predictor.py", "Constraint Predictor"),
64 | ("transformermpc/models/warm_start_predictor.py", "Warm Start Predictor"),
65 | ("transformermpc/data/dataset.py", "Dataset"),
66 | ("transformermpc/data/qp_generator.py", "QP Generator"),
67 | ("transformermpc/utils/metrics.py", "Metrics"),
68 | ("transformermpc/utils/osqp_wrapper.py", "OSQP Wrapper"),
69 | ("transformermpc/demo/demo.py", "Demo")
70 | ]
71 |
72 | print("\nChecking key implementation files:")
73 | for filepath, desc in key_files:
74 | path = os.path.join(project_root, filepath)
75 | pkg_ok = check_file_exists(path, desc) and pkg_ok
76 |
77 | # Print summary
78 | print("\n" + "="*60)
79 | if dirs_ok and files_ok and pkg_ok:
80 | print("✓ Package structure verification PASSED!")
81 | print(" The TransformerMPC package has been properly reorganized.")
82 | else:
83 | print("✗ Package structure verification FAILED!")
84 | print(" Some expected files or directories are missing.")
85 | print("="*60 + "\n")
86 |
87 | return 0 if dirs_ok and files_ok and pkg_ok else 1
88 |
89 | if __name__ == "__main__":
90 | sys.exit(main())
--------------------------------------------------------------------------------
/tests/scripts/verify_structure.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | """
3 | Verify the reorganized TransformerMPC package structure.
4 |
5 | This script checks that the package structure follows standard Python
6 | package conventions and that all key modules are present.
7 | """
8 |
9 | import os
10 | import sys
11 |
12 | def check_dir_exists(path, name):
13 | """Check if a directory exists and print the result."""
14 | exists = os.path.isdir(path)
15 | print(f"{name} directory: {'✓' if exists else '✗'}")
16 | return exists
17 |
18 | def check_file_exists(path, name):
19 | """Check if a file exists and print the result."""
20 | exists = os.path.isfile(path)
21 | print(f"{name} file: {'✓' if exists else '✗'}")
22 | return exists
23 |
24 | def main():
25 | """Verify the package structure."""
26 |     # Get the project root directory (the parent of this script's directory)
27 | script_dir = os.path.dirname(os.path.abspath(__file__))
28 | project_root = os.path.dirname(script_dir)
29 |
30 | print("\n" + "="*60)
31 | print(" TransformerMPC Package Structure Verification")
32 | print("="*60)
33 |
34 | # Check top-level directories
35 | print("\nChecking top-level directories:")
36 | dirs_ok = True
37 | for dirname in ["transformermpc", "tests", "scripts", "examples"]:
38 | path = os.path.join(project_root, dirname)
39 | dirs_ok = check_dir_exists(path, dirname) and dirs_ok
40 |
41 | # Check top-level files
42 | print("\nChecking top-level files:")
43 | files_ok = True
44 | for filename in ["setup.py", "pyproject.toml", "requirements.txt", "MANIFEST.in", "LICENSE", "README.md"]:
45 | path = os.path.join(project_root, filename)
46 | files_ok = check_file_exists(path, filename) and files_ok
47 |
48 | # Check package structure
49 | print("\nChecking package structure:")
50 | pkg_dir = os.path.join(project_root, "transformermpc")
51 | pkg_ok = True
52 |
53 | # Check subdirectories
54 | for dirname in ["data", "models", "utils", "training", "demo"]:
55 | path = os.path.join(pkg_dir, dirname)
56 | pkg_ok = check_dir_exists(path, f"transformermpc/{dirname}") and pkg_ok
57 |
58 | # Check key files
59 | pkg_ok = check_file_exists(os.path.join(pkg_dir, "__init__.py"), "transformermpc/__init__.py") and pkg_ok
60 |
61 | # Check for key implementation files
62 | key_files = [
63 | ("transformermpc/models/constraint_predictor.py", "Constraint Predictor"),
64 | ("transformermpc/models/warm_start_predictor.py", "Warm Start Predictor"),
65 | ("transformermpc/data/dataset.py", "Dataset"),
66 | ("transformermpc/data/qp_generator.py", "QP Generator"),
67 | ("transformermpc/utils/metrics.py", "Metrics"),
68 | ("transformermpc/utils/osqp_wrapper.py", "OSQP Wrapper"),
69 | ("transformermpc/demo/demo.py", "Demo")
70 | ]
71 |
72 | print("\nChecking key implementation files:")
73 | for filepath, desc in key_files:
74 | path = os.path.join(project_root, filepath)
75 | pkg_ok = check_file_exists(path, desc) and pkg_ok
76 |
77 | # Print summary
78 | print("\n" + "="*60)
79 | if dirs_ok and files_ok and pkg_ok:
80 | print("✓ Package structure verification PASSED!")
81 | print(" The TransformerMPC package has been properly reorganized.")
82 | else:
83 | print("✗ Package structure verification FAILED!")
84 | print(" Some expected files or directories are missing.")
85 | print("="*60 + "\n")
86 |
87 | return 0 if dirs_ok and files_ok and pkg_ok else 1
88 |
89 | if __name__ == "__main__":
90 | sys.exit(main())
--------------------------------------------------------------------------------
/examples/example_usage.py:
--------------------------------------------------------------------------------
1 | """
2 | Example usage of TransformerMPC library.
3 |
4 | This script demonstrates how to use the TransformerMPC package to accelerate
5 | quadratic programming (QP) solvers for Model Predictive Control (MPC).
6 | """
7 |
8 | import torch
9 | import numpy as np
10 | import matplotlib.pyplot as plt
11 | from pathlib import Path
12 |
13 | # Import TransformerMPC modules
14 | from transformermpc.data.qp_generator import QPGenerator
15 | from transformermpc.data.dataset import QPDataset
16 | from transformermpc.models.constraint_predictor import ConstraintPredictor
17 | from transformermpc.models.warm_start_predictor import WarmStartPredictor
18 | from transformermpc.training.trainer import ModelTrainer
19 | from transformermpc.utils.osqp_wrapper import OSQPSolver
20 | from transformermpc.utils.metrics import compute_solve_time_metrics
21 | from transformermpc.utils.visualization import plot_solve_time_comparison
22 |
23 | def main():
24 | """Example of training and using TransformerMPC."""
25 | print("TransformerMPC Example Usage")
26 |
27 | # 1. Generate QP problems
28 | print("\n1. Generating QP problems...")
29 | state_dim = 4
30 | input_dim = 2
31 | horizon = 10
32 | num_samples = 500 # Small number for quick example
33 |
34 | qp_generator = QPGenerator(
35 | state_dim=state_dim,
36 | input_dim=input_dim,
37 | horizon=horizon,
38 | seed=42
39 | )
40 |
41 | qp_problems = qp_generator.generate_batch(num_samples)
42 | print(f"Generated {len(qp_problems)} QP problems")
43 |
44 | # 2. Create dataset
45 | print("\n2. Creating dataset...")
46 | dataset = QPDataset(
47 | qp_problems=qp_problems,
48 | precompute_solutions=True,
49 | feature_normalization=True
50 | )
51 |
52 | train_dataset, val_dataset = dataset.split(test_size=0.2)
53 | print(f"Training dataset size: {len(train_dataset)}")
54 | print(f"Validation dataset size: {len(val_dataset)}")
55 |
56 | # 3. Create models
57 | print("\n3. Creating models...")
58 | # Get dimensions from a sample
59 | sample = train_dataset[0]
60 | feature_dim = sample['features'].shape[0]
61 | num_constraints = sample['active_constraints'].shape[0]
62 | solution_dim = sample['solution'].shape[0]
63 |
64 | # Create constraint predictor
65 | cp_model = ConstraintPredictor(
66 | input_dim=feature_dim,
67 | hidden_dim=128,
68 | num_constraints=num_constraints
69 | )
70 |
71 | # Create warm start predictor
72 | ws_model = WarmStartPredictor(
73 | input_dim=feature_dim,
74 | hidden_dim=256,
75 | output_dim=solution_dim
76 | )
77 |
78 | # 4. Train models
79 | print("\n4. Training models...")
80 | # Train constraint predictor
81 | cp_trainer = ModelTrainer(
82 | model=cp_model,
83 | train_dataset=train_dataset,
84 | val_dataset=val_dataset,
85 | target_key='active_constraints',
86 | output_dir=Path('example_output/models')
87 | )
88 |
89 | cp_metrics = cp_trainer.train(
90 | num_epochs=10, # Small number for quick example
91 | batch_size=32,
92 | learning_rate=1e-3
93 | )
94 |
95 | # Train warm start predictor
96 | ws_trainer = ModelTrainer(
97 | model=ws_model,
98 | train_dataset=train_dataset,
99 | val_dataset=val_dataset,
100 | target_key='solution',
101 | output_dir=Path('example_output/models')
102 | )
103 |
104 | ws_metrics = ws_trainer.train(
105 | num_epochs=10, # Small number for quick example
106 | batch_size=32,
107 | learning_rate=1e-3
108 | )
109 |
110 | # 5. Evaluate performance
111 | print("\n5. Evaluating performance...")
112 | solver = OSQPSolver()
113 |
114 | # List to store timing results
115 | baseline_times = []
116 | transformer_times = []
117 |
118 | # Sample a few problems for demonstration
119 | test_indices = np.random.choice(len(val_dataset), size=10, replace=False)
120 |
121 | for idx in test_indices:
122 | # Get problem and features
123 | sample = val_dataset[idx]
124 | problem = val_dataset.get_problem(idx)
125 | features = sample['features']
126 |
127 | # Predict active constraints and warm start
128 | with torch.no_grad():
129 | pred_active = cp_model(features.unsqueeze(0)).squeeze(0) > 0.5
130 | pred_solution = ws_model(features.unsqueeze(0)).squeeze(0)
131 |
132 | # Baseline solve time (standard OSQP)
133 | _, baseline_time = solver.solve_with_time(
134 | Q=problem.Q,
135 | c=problem.c,
136 | A=problem.A,
137 | b=problem.b
138 | )
139 |
140 | # TransformerMPC solve time
141 | _, transformer_time, _ = solver.solve_pipeline(
142 | Q=problem.Q,
143 | c=problem.c,
144 | A=problem.A,
145 | b=problem.b,
146 | active_constraints=pred_active.numpy(),
147 | warm_start=pred_solution.numpy(),
148 | fallback_on_violation=True
149 | )
150 |
151 | baseline_times.append(baseline_time)
152 | transformer_times.append(transformer_time)
153 |
154 | # Calculate metrics
155 | baseline_times = np.array(baseline_times)
156 | transformer_times = np.array(transformer_times)
157 |
158 | metrics = compute_solve_time_metrics(baseline_times, transformer_times)
159 |
160 | print("\nPerformance Summary:")
161 | print(f"Mean baseline time: {metrics['mean_baseline_time']:.6f}s")
162 | print(f"Mean transformer time: {metrics['mean_transformer_time']:.6f}s")
163 | print(f"Mean speedup: {metrics['mean_speedup']:.2f}x")
164 | print(f"Median speedup: {metrics['median_speedup']:.2f}x")
165 |
166 | # Optional: Plot results
167 | output_dir = Path('example_output/results')
168 | output_dir.mkdir(parents=True, exist_ok=True)
169 |
170 | plot_solve_time_comparison(
171 | baseline_times=baseline_times,
172 | transformer_times=transformer_times,
173 | save_path=output_dir / 'solve_time_comparison.png'
174 | )
175 |
176 | print(f"\nResults saved to {output_dir}")
177 | print("\nExample completed successfully!")
178 |
179 | if __name__ == "__main__":
180 | main()
--------------------------------------------------------------------------------
/tests/examples/example_usage.py:
--------------------------------------------------------------------------------
1 | """
2 | Example usage of TransformerMPC library.
3 |
4 | This script demonstrates how to use the TransformerMPC package to accelerate
5 | quadratic programming (QP) solvers for Model Predictive Control (MPC).
6 | """
7 |
8 | import torch
9 | import numpy as np
10 | import matplotlib.pyplot as plt
11 | from pathlib import Path
12 |
13 | # Import TransformerMPC modules
14 | from transformermpc.data.qp_generator import QPGenerator
15 | from transformermpc.data.dataset import QPDataset
16 | from transformermpc.models.constraint_predictor import ConstraintPredictor
17 | from transformermpc.models.warm_start_predictor import WarmStartPredictor
18 | from transformermpc.training.trainer import ModelTrainer
19 | from transformermpc.utils.osqp_wrapper import OSQPSolver
20 | from transformermpc.utils.metrics import compute_solve_time_metrics
21 | from transformermpc.utils.visualization import plot_solve_time_comparison
22 |
23 | def main():
24 | """Example of training and using TransformerMPC."""
25 | print("TransformerMPC Example Usage")
26 |
27 | # 1. Generate QP problems
28 | print("\n1. Generating QP problems...")
29 | state_dim = 4
30 | input_dim = 2
31 | horizon = 10
32 | num_samples = 500 # Small number for quick example
33 |
34 | qp_generator = QPGenerator(
35 | state_dim=state_dim,
36 | input_dim=input_dim,
37 | horizon=horizon,
38 | seed=42
39 | )
40 |
41 | qp_problems = qp_generator.generate_batch(num_samples)
42 | print(f"Generated {len(qp_problems)} QP problems")
43 |
44 | # 2. Create dataset
45 | print("\n2. Creating dataset...")
46 | dataset = QPDataset(
47 | qp_problems=qp_problems,
48 | precompute_solutions=True,
49 | feature_normalization=True
50 | )
51 |
52 | train_dataset, val_dataset = dataset.split(test_size=0.2)
53 | print(f"Training dataset size: {len(train_dataset)}")
54 | print(f"Validation dataset size: {len(val_dataset)}")
55 |
56 | # 3. Create models
57 | print("\n3. Creating models...")
58 | # Get dimensions from a sample
59 | sample = train_dataset[0]
60 | feature_dim = sample['features'].shape[0]
61 | num_constraints = sample['active_constraints'].shape[0]
62 | solution_dim = sample['solution'].shape[0]
63 |
64 | # Create constraint predictor
65 | cp_model = ConstraintPredictor(
66 | input_dim=feature_dim,
67 | hidden_dim=128,
68 | num_constraints=num_constraints
69 | )
70 |
71 | # Create warm start predictor
72 | ws_model = WarmStartPredictor(
73 | input_dim=feature_dim,
74 | hidden_dim=256,
75 | output_dim=solution_dim
76 | )
77 |
78 | # 4. Train models
79 | print("\n4. Training models...")
80 | # Train constraint predictor
81 | cp_trainer = ModelTrainer(
82 | model=cp_model,
83 | train_dataset=train_dataset,
84 | val_dataset=val_dataset,
85 | target_key='active_constraints',
86 | output_dir=Path('example_output/models')
87 | )
88 |
89 | cp_metrics = cp_trainer.train(
90 | num_epochs=10, # Small number for quick example
91 | batch_size=32,
92 | learning_rate=1e-3
93 | )
94 |
95 | # Train warm start predictor
96 | ws_trainer = ModelTrainer(
97 | model=ws_model,
98 | train_dataset=train_dataset,
99 | val_dataset=val_dataset,
100 | target_key='solution',
101 | output_dir=Path('example_output/models')
102 | )
103 |
104 | ws_metrics = ws_trainer.train(
105 | num_epochs=10, # Small number for quick example
106 | batch_size=32,
107 | learning_rate=1e-3
108 | )
109 |
110 | # 5. Evaluate performance
111 | print("\n5. Evaluating performance...")
112 | solver = OSQPSolver()
113 |
114 | # List to store timing results
115 | baseline_times = []
116 | transformer_times = []
117 |
118 | # Sample a few problems for demonstration
119 | test_indices = np.random.choice(len(val_dataset), size=10, replace=False)
120 |
121 | for idx in test_indices:
122 | # Get problem and features
123 | sample = val_dataset[idx]
124 | problem = val_dataset.get_problem(idx)
125 | features = sample['features']
126 |
127 | # Predict active constraints and warm start
128 | with torch.no_grad():
129 | pred_active = cp_model(features.unsqueeze(0)).squeeze(0) > 0.5
130 | pred_solution = ws_model(features.unsqueeze(0)).squeeze(0)
131 |
132 | # Baseline solve time (standard OSQP)
133 | _, baseline_time = solver.solve_with_time(
134 | Q=problem.Q,
135 | c=problem.c,
136 | A=problem.A,
137 | b=problem.b
138 | )
139 |
140 | # TransformerMPC solve time
141 | _, transformer_time, _ = solver.solve_pipeline(
142 | Q=problem.Q,
143 | c=problem.c,
144 | A=problem.A,
145 | b=problem.b,
146 | active_constraints=pred_active.numpy(),
147 | warm_start=pred_solution.numpy(),
148 | fallback_on_violation=True
149 | )
150 |
151 | baseline_times.append(baseline_time)
152 | transformer_times.append(transformer_time)
153 |
154 | # Calculate metrics
155 | baseline_times = np.array(baseline_times)
156 | transformer_times = np.array(transformer_times)
157 |
158 | metrics = compute_solve_time_metrics(baseline_times, transformer_times)
159 |
160 | print("\nPerformance Summary:")
161 | print(f"Mean baseline time: {metrics['mean_baseline_time']:.6f}s")
162 | print(f"Mean transformer time: {metrics['mean_transformer_time']:.6f}s")
163 | print(f"Mean speedup: {metrics['mean_speedup']:.2f}x")
164 | print(f"Median speedup: {metrics['median_speedup']:.2f}x")
165 |
166 | # Optional: Plot results
167 | output_dir = Path('example_output/results')
168 | output_dir.mkdir(parents=True, exist_ok=True)
169 |
170 | plot_solve_time_comparison(
171 | baseline_times=baseline_times,
172 | transformer_times=transformer_times,
173 | save_path=output_dir / 'solve_time_comparison.png'
174 | )
175 |
176 | print(f"\nResults saved to {output_dir}")
177 | print("\nExample completed successfully!")
178 |
179 | if __name__ == "__main__":
180 | main()
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # TransformerMPC: Accelerating Model Predictive Control via Transformers [ICRA '25]
2 |
3 | Vrushabh Zinage¹ · Ahmed Khalil¹ · Efstathios Bakolas¹
4 |
5 | ¹University of Texas at Austin
6 |
7 | [arXiv](https://arxiv.org/abs/2409.09266) · [Project page](https://transformer-mpc.github.io/)
8 |
23 | ## Overview
24 |
25 | TransformerMPC improves the computational efficiency of solving Model Predictive Control (MPC) problems using transformer-based neural network models. It employs two prediction models:
26 |
27 | 1. **Constraint Predictor**: Identifies inactive constraints in MPC formulations
28 | 2. **Warm Start Predictor**: Generates better initial points for MPC solvers
29 |
30 | By combining these models, TransformerMPC significantly reduces computation time while maintaining solution quality.
31 |
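At solve time the two predictors run ahead of the QP solver: the constraint predictor flags constraints likely to be inactive so they can be dropped, the warm start predictor supplies the initial iterate, and the solver falls back to the full problem if the reduced solution violates a dropped constraint. Below is a minimal sketch of this pipeline, mirroring the `solve_pipeline` call exercised in `examples/example_usage.py`; the helper `accelerated_solve` is illustrative rather than part of the package API, and `problem`, `features`, and the trained models are assumed to come from a `QPDataset` as in that example:

```python
import torch
from transformermpc.utils.osqp_wrapper import OSQPSolver

def accelerated_solve(problem, features, cp_model, ws_model):
    """Illustrative two-stage solve: prune predicted-inactive constraints, then warm-start."""
    with torch.no_grad():
        pred_active = cp_model(features.unsqueeze(0)).squeeze(0) > 0.5
        pred_solution = ws_model(features.unsqueeze(0)).squeeze(0)
    # Reduced, warm-started solve; re-solves the full QP if a dropped constraint is violated
    return OSQPSolver().solve_pipeline(
        Q=problem.Q, c=problem.c, A=problem.A, b=problem.b,
        active_constraints=pred_active.numpy(),
        warm_start=pred_solution.numpy(),
        fallback_on_violation=True,
    )
```
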
32 | ## Package Structure
33 |
34 | The package is organized with a standard Python package structure:
35 |
36 | ```
37 | transformermpc/
38 | ├── transformermpc/ # Core package module
39 | │ ├── data/ # Data generation utilities
40 | │ ├── models/ # Model implementations
41 | │ ├── utils/ # Utility functions and metrics
42 | │ ├── training/ # Training infrastructure
43 | │ └── demo/ # Demo scripts
44 | ├── scripts/ # Demo and utility scripts
45 | ├── tests/ # Testing infrastructure
46 | ├── setup.py # Package installation script
47 | └── requirements.txt # Dependencies
48 | ```
49 |
50 | ## Installation
51 |
52 | Install directly from PyPI:
53 |
54 | ```bash
55 | pip install transformermpc
56 | ```
57 |
58 | Or install from source:
59 |
60 | ```bash
61 | git clone https://github.com/Vrushabh27/transformermpc.git
62 | cd transformermpc
63 | pip install -e .
64 | ```
65 |
66 | ## Dependencies
67 |
68 | - Python >= 3.7
69 | - PyTorch >= 1.9.0
70 | - OSQP >= 0.6.2
71 | - NumPy, SciPy, and other scientific computing libraries
72 | - Additional dependencies specified in requirements.txt
73 |
74 | ## Running the Demos
75 |
76 | The package includes several demo scripts to showcase its capabilities:
77 |
78 | ### Boxplot Demo (Recommended)
79 |
80 | ```bash
81 | python scripts/boxplot_demo.py
82 | ```
83 |
84 | This script provides a visual comparison of different QP solving strategies using randomly generated problems, without requiring model training. It demonstrates the core concepts behind TransformerMPC by showing the performance impact of:
85 | - Removing inactive constraints
86 | - Using warm starts with different qualities
87 | - Combining these strategies
88 |
89 | The visualizations include boxplots, violin plots, and bar charts comparing solve times across different strategies.
90 |
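The effects the demo measures can be reproduced with OSQP alone. The sketch below is an illustration under simplified assumptions, not the demo script itself: it times a random inequality-constrained QP solved cold, warm-started from the known solution, and with the constraints that are inactive at the optimum removed:

```python
import numpy as np
import scipy.sparse as sp
import osqp

rng = np.random.default_rng(0)
n, m = 50, 200
M = rng.standard_normal((n, n))
P = sp.csc_matrix(M @ M.T + np.eye(n))       # positive-definite Hessian
q = rng.standard_normal(n)
A = sp.csc_matrix(rng.standard_normal((m, n)))
u = 1.0 + rng.random(m)                      # inequality constraints A x <= u
l = -np.inf * np.ones(m)

def timed_solve(A_c, l_c, u_c, x0=None):
    prob = osqp.OSQP()
    prob.setup(P, q, A_c, l_c, u_c, verbose=False)
    if x0 is not None:
        prob.warm_start(x=x0)                # warm start from a given point
    res = prob.solve()
    return res.x, res.info.solve_time

x_cold, t_cold = timed_solve(A, l, u)        # baseline: cold start, all constraints
_, t_warm = timed_solve(A, l, u, x0=x_cold)  # warm start from the known optimum

# Keep only constraints tight at the optimum (loose tolerance; the real
# pipeline would fall back to the full QP if a dropped constraint is violated)
keep = np.flatnonzero(u - A @ x_cold < 1e-3)
A_red = sp.csc_matrix(A.tocsr()[keep])
_, t_reduced = timed_solve(A_red, l[keep], u[keep])

print(f"cold: {t_cold:.2e}s  warm: {t_warm:.2e}s  reduced: {t_reduced:.2e}s")
```
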
91 | ### Simple Demo
92 |
93 | ```bash
94 | python scripts/simple_demo.py
95 | ```
96 |
97 | This script demonstrates the complete pipeline: generating QP problems, training models, and evaluating performance. After completion, it saves performance comparison plots in the `demo_results/results` directory.
98 |
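The same pipeline can also be driven programmatically. A condensed sketch following `examples/example_usage.py` (dimensions and hyperparameters are illustrative):

```python
from pathlib import Path
from transformermpc.data.qp_generator import QPGenerator
from transformermpc.data.dataset import QPDataset
from transformermpc.models.warm_start_predictor import WarmStartPredictor
from transformermpc.training.trainer import ModelTrainer

# Generate a batch of MPC-style QPs and precompute their solutions
problems = QPGenerator(state_dim=4, input_dim=2, horizon=10, seed=42).generate_batch(500)
dataset = QPDataset(qp_problems=problems, precompute_solutions=True,
                    feature_normalization=True)
train_set, val_set = dataset.split(test_size=0.2)

# Train the warm start predictor on the precomputed solutions
sample = train_set[0]
model = WarmStartPredictor(input_dim=sample['features'].shape[0],
                           hidden_dim=256,
                           output_dim=sample['solution'].shape[0])
trainer = ModelTrainer(model=model, train_dataset=train_set, val_dataset=val_set,
                       target_key='solution', output_dir=Path('demo_results/models'))
trainer.train(num_epochs=10, batch_size=32, learning_rate=1e-3)
```
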
99 | ### Customizing Demo Parameters
100 |
101 | You can customize the boxplot demo by modifying parameters:
102 |
103 | ```bash
104 | # Generate more problems with different dimensions
105 | python scripts/boxplot_demo.py --num_samples 50 --state_dim 6 --input_dim 3 --horizon 10
106 |
107 | # Save results to a custom directory
108 | python scripts/boxplot_demo.py --output_dir my_results
109 | ```
110 |
111 | Similarly, for the simple demo:
112 |
113 | ```bash
114 | # Generate QP problems with specific parameters
115 | python scripts/simple_demo.py --num_samples 200 --state_dim 6 --input_dim 3 --horizon 10
116 |
117 | # Customize training parameters
118 | python scripts/simple_demo.py --epochs 20 --batch_size 32
119 |
120 | # Use GPU for training if available
121 | python scripts/simple_demo.py --use_gpu
122 | ```
123 |
124 | ## Usage in Projects
125 |
126 | ### Basic Example
127 |
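Here `Q`, `c`, `A`, and `b` define the QP `minimize 0.5·xᵀQx + cᵀx subject to Ax ≤ b`; the objective matches `tests/test_solve.py`, and the inequality convention `Ax ≤ b` is inferred from the constraint data below, which encodes `x ≥ 0`, `x₁ + x₂ ≥ 1`, and `x₁ + x₂ ≤ 2`.
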
128 | ```python
129 | from transformermpc import TransformerMPC
130 | import numpy as np
131 |
132 | # Define your QP problem parameters
133 | Q = np.array([[4.0, 1.0], [1.0, 2.0]])
134 | c = np.array([1.0, 1.0])
135 | A = np.array([[-1.0, 0.0], [0.0, -1.0], [-1.0, -1.0], [1.0, 1.0]])
136 | b = np.array([0.0, 0.0, -1.0, 2.0])
137 |
138 | # Initialize the TransformerMPC solver
139 | solver = TransformerMPC()
140 |
141 | # Solve with model acceleration
142 | solution, solve_time = solver.solve(Q, c, A, b)
143 |
144 | print(f"Solution: {solution}")
145 | print(f"Solve time: {solve_time} seconds")
146 | ```
147 |
148 | ### General Usage
149 |
150 | ```python
151 | from transformermpc import TransformerMPC, QPProblem
152 | import numpy as np
153 |
154 | # Define your QP problem parameters
155 | Q = np.array([[4.0, 1.0], [1.0, 2.0]])
156 | c = np.array([1.0, 1.0])
157 | A = np.array([[-1.0, 0.0], [0.0, -1.0], [-1.0, -1.0], [1.0, 1.0]])
158 | b = np.array([0.0, 0.0, -1.0, 2.0])
159 | initial_state = np.array([0.5, 0.5]) # Optional: initial state for MPC problems
160 |
161 | # Create a QP problem instance
162 | qp_problem = QPProblem(
163 | Q=Q,
164 | c=c,
165 | A=A,
166 | b=b,
167 | initial_state=initial_state # Optional
168 | )
169 |
170 | # Initialize with custom settings
171 | solver = TransformerMPC(
172 | use_constraint_predictor=True,
173 | use_warm_start_predictor=True,
174 | fallback_on_violation=True
175 | )
176 |
177 | # Solve the problem
178 | solution, solve_time = solver.solve(qp_problem=qp_problem)
179 | print(f"Solution: {solution}")
180 | print(f"Solve time: {solve_time} seconds")
181 |
182 | # Compare with baseline
183 | baseline_solution, baseline_time = solver.solve_baseline(qp_problem=qp_problem)
184 | print(f"Baseline time: {baseline_time} seconds")
185 | ```
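
A quick way to quantify the benefit on your own problem is the ratio of the two timings printed above (variables as defined in the example):

```python
# Speedup of the accelerated solve over plain OSQP.
print(f"Speedup: {baseline_time / solve_time:.2f}x")
```
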
186 | ## If you find our work useful, please cite us
187 | ```
188 | @article{zinage2024transformermpc,
189 | title={TransformerMPC: Accelerating Model Predictive Control via Transformers},
190 | author={Zinage, Vrushabh and Khalil, Ahmed and Bakolas, Efstathios},
191 | journal={arXiv preprint arXiv:2409.09266},
192 | year={2024}
193 | }
194 | ```
195 |
196 | ## License
197 |
198 | This project is licensed under the MIT License.
199 |
--------------------------------------------------------------------------------
/transformermpc/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | TransformerMPC: Accelerating Model Predictive Control via Neural Networks
3 | ====================================================================
4 |
5 | TransformerMPC is a Python package that enhances the efficiency of solving
6 | Quadratic Programming (QP) problems in Model Predictive Control (MPC)
7 | using neural network models.
8 |
9 | Authors: Vrushabh Zinage, Ahmed Khalil, Efstathios Bakolas
10 | """
11 |
12 | __version__ = "0.1.6"
13 |
14 | # Import core components for easy access
15 | from .models.constraint_predictor import ConstraintPredictor
16 | from .models.warm_start_predictor import WarmStartPredictor
17 | from .data.qp_generator import QPGenerator, QPProblem
18 | from .data.dataset import QPDataset
19 | from .utils.osqp_wrapper import OSQPSolver
20 | from .training.trainer import ModelTrainer
21 |
22 | # Define the main solver class
23 | class TransformerMPC:
24 | """
25 | TransformerMPC solver class that accelerates QP solving using transformer models.
26 |
27 | This class combines two transformer models:
28 | 1. Constraint Predictor: Identifies inactive constraints
29 | 2. Warm Start Predictor: Generates better initial points
30 |
31 | The combined pipeline significantly reduces computation time while maintaining solution quality.
32 | """
33 |
34 | def __init__(self,
35 | use_constraint_predictor=True,
36 | use_warm_start_predictor=True,
37 | constraint_model_path=None,
38 | warm_start_model_path=None,
39 | fallback_on_violation=True):
40 | """
41 | Initialize the TransformerMPC solver.
42 |
43 | Parameters:
44 | -----------
45 | use_constraint_predictor : bool
46 | Whether to use the constraint predictor model
47 | use_warm_start_predictor : bool
48 | Whether to use the warm start predictor model
49 | constraint_model_path : str or None
50 | Path to a pre-trained constraint predictor model
51 | warm_start_model_path : str or None
52 | Path to a pre-trained warm start predictor model
53 | fallback_on_violation : bool
54 | Whether to fallback to full QP if constraints are violated
55 | """
56 | self.use_constraint_predictor = use_constraint_predictor
57 | self.use_warm_start_predictor = use_warm_start_predictor
58 | self.fallback_on_violation = fallback_on_violation
59 |
60 |         # Initialize the predictor models when enabled
61 | if use_constraint_predictor:
62 | self.constraint_predictor = ConstraintPredictor.load(constraint_model_path)
63 | else:
64 | self.constraint_predictor = None
65 |
66 | if use_warm_start_predictor:
67 | self.warm_start_predictor = WarmStartPredictor.load(warm_start_model_path)
68 | else:
69 | self.warm_start_predictor = None
70 |
71 | # Initialize the OSQP solver
72 | self.solver = OSQPSolver()
73 |
74 | def solve(self, Q=None, c=None, A=None, b=None, qp_problem=None):
75 | """
76 | Solve a QP problem using the transformer-enhanced pipeline.
77 |
78 | Parameters:
79 | -----------
80 | Q : numpy.ndarray or None
81 | Quadratic cost matrix
82 | c : numpy.ndarray or None
83 | Linear cost vector
84 | A : numpy.ndarray or None
85 | Constraint matrix
86 | b : numpy.ndarray or None
87 | Constraint vector
88 | qp_problem : QPProblem or None
89 | QPProblem instance (alternative to specifying Q, c, A, b)
90 |
91 | Returns:
92 | --------
93 | solution : numpy.ndarray
94 | Optimal solution vector
95 | solve_time : float
96 | Computation time in seconds
97 | """
98 | import numpy as np
99 | import time
100 |
101 | # If qp_problem is provided, extract matrices from it
102 | if qp_problem is not None:
103 | Q = qp_problem.Q
104 | c = qp_problem.c
105 | A = qp_problem.A
106 | b = qp_problem.b
107 |
108 | # Check if all required matrices are provided
109 | if Q is None or c is None:
110 | raise ValueError("Q and c matrices must be provided")
111 |
112 |         # Extract features for the QP problem: flatten Q, c, and, when
113 |         # present, A and b into a single feature vector
114 | features = np.concatenate([
115 | Q.flatten(),
116 | c,
117 | A.flatten() if A is not None else np.zeros(0),
118 | b if b is not None else np.zeros(0)
119 | ])
120 |
121 | # Default values if no transformers are used
122 | active_constraints = np.ones(A.shape[0]) if A is not None else None
123 | warm_start = None
124 |
125 | # Predict active constraints if constraint predictor is enabled
126 | if self.use_constraint_predictor and self.constraint_predictor is not None and A is not None:
127 | try:
128 | active_constraints = self.constraint_predictor.predict(features)[0]
129 | except Exception as e:
130 | print(f"Warning: Constraint prediction failed: {e}")
131 | active_constraints = np.ones(A.shape[0])
132 |
133 | # Predict warm start if warm start predictor is enabled
134 | if self.use_warm_start_predictor and self.warm_start_predictor is not None:
135 | try:
136 | warm_start = self.warm_start_predictor.predict(features)[0]
137 | except Exception as e:
138 | print(f"Warning: Warm start prediction failed: {e}")
139 | warm_start = None
140 |
141 | # Solve the QP problem with the OSQP solver
142 | start_time = time.time()
143 |
144 | if A is not None and self.use_constraint_predictor:
145 | # Use transformer-enhanced pipeline
146 | solution, _, used_fallback = self.solver.solve_pipeline(
147 | Q=Q,
148 | c=c,
149 | A=A,
150 | b=b,
151 | active_constraints=active_constraints,
152 | warm_start=warm_start,
153 | fallback_on_violation=self.fallback_on_violation
154 | )
155 | else:
156 | # Use standard OSQP solver
157 | solution = self.solver.solve(
158 | Q=Q,
159 | c=c,
160 | A=A,
161 | b=b,
162 | warm_start=warm_start
163 | )
164 |
165 | solve_time = time.time() - start_time
166 |
167 | return solution, solve_time
168 |
169 | def solve_baseline(self, Q=None, c=None, A=None, b=None, qp_problem=None):
170 | """
171 | Solve a QP problem using standard OSQP without transformer enhancements.
172 |
173 | Parameters and returns same as solve().
174 | """
175 | import time
176 |
177 | # If qp_problem is provided, extract matrices from it
178 | if qp_problem is not None:
179 | Q = qp_problem.Q
180 | c = qp_problem.c
181 | A = qp_problem.A
182 | b = qp_problem.b
183 |
184 | # Check if all required matrices are provided
185 | if Q is None or c is None:
186 | raise ValueError("Q and c matrices must be provided")
187 |
188 | # Solve using standard OSQP
189 | start_time = time.time()
190 | solution, _ = self.solver.solve_with_time(Q, c, A, b)
191 | solve_time = time.time() - start_time
192 |
193 | return solution, solve_time
194 |
--------------------------------------------------------------------------------
/transformermpc.egg-info/PKG-INFO:
--------------------------------------------------------------------------------
1 | Metadata-Version: 2.4
2 | Name: transformermpc
3 | Version: 0.1.6
4 | Summary: Accelerating Model Predictive Control via Neural Networks
5 | Home-page: https://github.com/Vrushabh27/transformermpc
6 | Author: Ahmed Khalil, Efstathios Bakolas
7 | Author-email: Vrushabh Zinage
8 | Maintainer-email: Vrushabh Zinage
9 | License-Expression: MIT
10 | Project-URL: Homepage, https://github.com/vrushabh/transformermpc
11 | Project-URL: Bug Tracker, https://github.com/vrushabh/transformermpc/issues
12 | Keywords: machine learning,model predictive control,quadratic programming
13 | Classifier: Programming Language :: Python :: 3
14 | Classifier: Operating System :: OS Independent
15 | Requires-Python: >=3.7
16 | Description-Content-Type: text/markdown
17 | License-File: LICENSE
18 | Requires-Dist: numpy>=1.20.0
19 | Requires-Dist: scipy>=1.7.0
20 | Requires-Dist: torch>=1.9.0
21 | Requires-Dist: osqp>=0.6.2
22 | Requires-Dist: matplotlib>=3.4.0
23 | Requires-Dist: pandas>=1.3.0
24 | Requires-Dist: tqdm>=4.62.0
25 | Requires-Dist: scikit-learn>=0.24.0
26 | Requires-Dist: tensorboard>=2.7.0
27 | Requires-Dist: quadprog>=0.1.11
28 | Dynamic: home-page
29 | Dynamic: license-file
30 | Dynamic: requires-python
31 |
32 |
33 |
34 |
TransformerMPC: Accelerating Model Predictive Control via Transformers [ICRA '25]
35 |
36 | Vrushabh Zinage
37 | ·
38 | Ahmed Khalil
39 | ·
40 | Efstathios Bakolas
41 | 
42 | 
43 | 
44 | 
45 | 
46 | University of Texas at Austin
47 |
48 |
49 |
50 | [](https://arxiv.org/abs/2409.09266) [](https://transformer-mpc.github.io/)
51 |
52 |
53 |
54 | ## Overview
55 |
56 | TransformerMPC improves the computational efficiency of Model Predictive Control (MPC) problems using neural network models. It employs the following two prediction models:
57 |
58 | 1. **Constraint Predictor**: Identifies inactive constraints in MPC formulations
59 | 2. **Warm Start Predictor**: Generates better initial points for MPC solvers
60 |
61 | By combining these models, TransformerMPC significantly reduces computation time while maintaining solution quality.
62 |
63 | ## Package Structure
64 |
65 | The package is organized with a standard Python package structure:
66 |
67 | ```
68 | transformermpc/
69 | ├── transformermpc/ # Core package module
70 | │ ├── data/ # Data generation utilities
71 | │ ├── models/ # Model implementations
72 | │ ├── utils/ # Utility functions and metrics
73 | │ ├── training/ # Training infrastructure
74 | │ └── demo/ # Demo scripts
75 | ├── scripts/ # Demo and utility scripts
76 | ├── tests/ # Testing infrastructure
77 | ├── setup.py # Package installation script
78 | └── requirements.txt # Dependencies
79 | ```
80 |
81 | ## Installation
82 |
83 | Install directly from PyPI:
84 |
85 | ```bash
86 | pip install transformermpc
87 | ```
88 |
89 | Or install from source:
90 |
91 | ```bash
92 | git clone https://github.com/vrushabh/transformermpc.git
93 | cd transformermpc
94 | pip install -e .
95 | ```
96 |
97 | ## Dependencies
98 |
99 | - Python >= 3.7
100 | - PyTorch >= 1.9.0
101 | - OSQP >= 0.6.2
102 | - NumPy, SciPy, and other scientific computing libraries
103 | - Additional dependencies specified in requirements.txt
104 |
105 | ## Running the Demos
106 |
107 | The package includes several demo scripts to showcase its capabilities:
108 |
109 | ### Boxplot Demo (Recommended)
110 |
111 | ```bash
112 | python scripts/boxplot_demo.py
113 | ```
114 |
115 | This script provides a visual comparison of different QP solving strategies using randomly generated problems, without requiring model training. It demonstrates the core concepts behind TransformerMPC by showing the performance impact of:
116 | - Removing inactive constraints
117 | - Using warm starts with different qualities
118 | - Combining these strategies
119 |
120 | The visualizations include boxplots, violin plots, and bar charts comparing solve times across different strategies.
121 |
122 | ### Simple Demo
123 |
124 | ```bash
125 | python scripts/simple_demo.py
126 | ```
127 |
128 | This script demonstrates the complete pipeline: generating QP problems, training models, and evaluating performance. After completion, it saves performance comparison plots in the `demo_results/results` directory.
129 |
130 | ### Verify Package Structure
131 |
132 | To check that the package is installed correctly:
133 |
134 | ```bash
135 | python scripts/verify_structure.py
136 | ```
137 |
138 | ### Customizing Demo Parameters
139 |
140 | You can customize the boxplot demo by modifying parameters:
141 |
142 | ```bash
143 | # Generate more problems with different dimensions
144 | python scripts/boxplot_demo.py --num_samples 50 --state_dim 6 --input_dim 3 --horizon 10
145 |
146 | # Save results to a custom directory
147 | python scripts/boxplot_demo.py --output_dir my_results
148 | ```
149 |
150 | Similarly, for the simple demo:
151 |
152 | ```bash
153 | # Generate QP problems with specific parameters
154 | python scripts/simple_demo.py --num_samples 200 --state_dim 6 --input_dim 3 --horizon 10
155 |
156 | # Customize training parameters
157 | python scripts/simple_demo.py --epochs 20 --batch_size 32
158 |
159 | # Use GPU for training if available
160 | python scripts/simple_demo.py --use_gpu
161 | ```
162 |
163 | ## Usage in Projects
164 |
165 | ### Basic Example
166 |
167 | ```python
168 | from transformermpc import TransformerMPC
169 | import numpy as np
170 |
171 | # Define your QP problem parameters
172 | Q = np.array([[4.0, 1.0], [1.0, 2.0]])
173 | c = np.array([1.0, 1.0])
174 | A = np.array([[-1.0, 0.0], [0.0, -1.0], [-1.0, -1.0], [1.0, 1.0]])
175 | b = np.array([0.0, 0.0, -1.0, 2.0])
176 |
177 | # Initialize the TransformerMPC solver
178 | solver = TransformerMPC()
179 |
180 | # Solve with model acceleration
181 | solution, solve_time = solver.solve(Q, c, A, b)
182 |
183 | print(f"Solution: {solution}")
184 | print(f"Solve time: {solve_time} seconds")
185 | ```
186 |
187 | ### General Usage
188 |
189 | ```python
190 | from transformermpc import TransformerMPC, QPProblem
191 | import numpy as np
192 |
193 | # Define your QP problem parameters
194 | Q = np.array([[4.0, 1.0], [1.0, 2.0]])
195 | c = np.array([1.0, 1.0])
196 | A = np.array([[-1.0, 0.0], [0.0, -1.0], [-1.0, -1.0], [1.0, 1.0]])
197 | b = np.array([0.0, 0.0, -1.0, 2.0])
198 | initial_state = np.array([0.5, 0.5]) # Optional: initial state for MPC problems
199 |
200 | # Create a QP problem instance
201 | qp_problem = QPProblem(
202 | Q=Q,
203 | c=c,
204 | A=A,
205 | b=b,
206 | initial_state=initial_state # Optional
207 | )
208 |
209 | # Initialize with custom settings
210 | solver = TransformerMPC(
211 | use_constraint_predictor=True,
212 | use_warm_start_predictor=True,
213 | fallback_on_violation=True
214 | )
215 |
216 | # Solve the problem
217 | solution, solve_time = solver.solve(qp_problem=qp_problem)
218 | print(f"Solution: {solution}")
219 | print(f"Solve time: {solve_time} seconds")
220 |
221 | # Compare with baseline
222 | baseline_solution, baseline_time = solver.solve_baseline(qp_problem=qp_problem)
223 | print(f"Baseline time: {baseline_time} seconds")
224 | ```
225 | ## If you find our work useful, please cite us
226 | ```
227 | @article{zinage2024transformermpc,
228 | title={TransformerMPC: Accelerating Model Predictive Control via Transformers},
229 | author={Zinage, Vrushabh and Khalil, Ahmed and Bakolas, Efstathios},
230 | journal={arXiv preprint arXiv:2409.09266},
231 | year={2024}
232 | }
233 | ```
234 |
235 | ## License
236 |
237 | This project is licensed under the MIT License.
238 |
--------------------------------------------------------------------------------
/tests/test_benchmark.py:
--------------------------------------------------------------------------------
1 | """
2 | Benchmark script for TransformerMPC.
3 | """
4 |
5 | import os
6 | import torch
7 | import numpy as np
8 | import argparse
9 | from tqdm import tqdm
10 | import matplotlib.pyplot as plt
11 | from pathlib import Path
12 |
13 | # Fix the serialization issue (newer torch defaults to weights_only=True)
14 | import torch.serialization
15 | # Allow-list the numpy scalar constructor so the checkpoints unpickle cleanly
16 | torch.serialization.add_safe_globals([np.core.multiarray.scalar])
17 |
18 | # Import required modules
19 | from transformermpc.data.qp_generator import QPGenerator
20 | from transformermpc.data.dataset import QPDataset
21 | from transformermpc.models.constraint_predictor import ConstraintPredictor
22 | from transformermpc.models.warm_start_predictor import WarmStartPredictor
23 | from transformermpc.utils.osqp_wrapper import OSQPSolver
24 | from transformermpc.utils.metrics import compute_solve_time_metrics, compute_fallback_rate
25 | from transformermpc.utils.visualization import (
26 | plot_solve_time_comparison,
27 | plot_solve_time_boxplot
28 | )
29 |
30 | # Monkey patch torch.load to use weights_only=False by default
31 | original_torch_load = torch.load
32 | def patched_torch_load(f, *args, **kwargs):
33 | if 'weights_only' not in kwargs:
34 | kwargs['weights_only'] = False
35 | return original_torch_load(f, *args, **kwargs)
36 | torch.load = patched_torch_load
37 |
38 | def parse_args():
39 | """Parse command line arguments."""
40 | parser = argparse.ArgumentParser(description="TransformerMPC Benchmark")
41 |
42 | # Input/output parameters
43 | parser.add_argument(
44 | "--data_dir", type=str, default="demo_results/data",
45 | help="Directory containing QP problem data (default: demo_results/data)"
46 | )
47 | parser.add_argument(
48 | "--results_dir", type=str, default="demo_results/results",
49 | help="Directory to save benchmark results (default: demo_results/results)"
50 | )
51 |
52 | # Benchmark parameters
53 | parser.add_argument(
54 | "--test_size", type=float, default=0.2,
55 | help="Fraction of data to use for testing (default: 0.2)"
56 | )
57 | parser.add_argument(
58 | "--num_test_problems", type=int, default=20,
59 | help="Number of test problems to benchmark (default: 20)"
60 | )
61 | parser.add_argument(
62 | "--use_gpu", action="store_true",
63 | help="Use GPU if available"
64 | )
65 |
66 | return parser.parse_args()
67 |
68 | def main():
69 | """Run a benchmark test."""
70 | # Parse command line arguments
71 | args = parse_args()
72 |
73 | # Set up directories
74 | data_dir = Path(args.data_dir)
75 | results_dir = Path(args.results_dir)
76 |
77 | # Set device
78 | device = torch.device('cuda' if torch.cuda.is_available() and args.use_gpu else 'cpu')
79 | print(f"Using device: {device}")
80 |
81 | # 1. Load QP problems
82 | qp_problems_file = data_dir / "qp_problems.npy"
83 | if qp_problems_file.exists():
84 | print(f"Loading QP problems from {qp_problems_file}")
85 | qp_problems = QPGenerator.load(qp_problems_file)
86 | print(f"Loaded {len(qp_problems)} QP problems")
87 | else:
88 | print("Error: No QP problems found. Please run the demo first.")
89 | return
90 |
91 | # 2. Load dataset
92 | print("Loading dataset")
93 | dataset = QPDataset(
94 | qp_problems=qp_problems,
95 | precompute_solutions=True,
96 | feature_normalization=True,
97 | cache_dir=data_dir
98 | )
99 |
100 | _, test_dataset = dataset.split(test_size=args.test_size)
101 | print(f"Loaded test dataset with {len(test_dataset)} problems")
102 |
103 | # 3. Create models with default parameters (don't load from files)
104 | print("Creating models")
105 | sample_item = test_dataset[0]
106 | input_dim = sample_item['features'].shape[0]
107 | num_constraints = sample_item['active_constraints'].shape[0]
108 | output_dim = sample_item['solution'].shape[0]
109 |
110 | cp_model = ConstraintPredictor(
111 | input_dim=input_dim,
112 | hidden_dim=128,
113 | num_constraints=num_constraints
114 | )
115 |
116 | ws_model = WarmStartPredictor(
117 | input_dim=input_dim,
118 | hidden_dim=256,
119 | output_dim=output_dim
120 | )
121 |
122 | # 4. Test on a small subset
123 | print("Testing on a subset of problems")
124 | solver = OSQPSolver()
125 |
126 | # List to store results
127 | baseline_times = []
128 | warmstart_only_times = []
129 | constraint_only_times = []
130 | transformer_times = []
131 | fallback_flags = []
132 |
133 | # Test on a subset for visualization
134 | test_subset = np.random.choice(len(test_dataset), size=args.num_test_problems, replace=False)
135 |
136 | print(f"Benchmarking {args.num_test_problems} problems...")
137 | for idx in tqdm(test_subset):
138 | # Get problem
139 | sample = test_dataset[idx]
140 | problem = test_dataset.get_problem(idx)
141 |
142 | # Get features
143 | features = sample['features']
144 |
145 | # For demonstration, we'll use the solutions directly instead of predictions
146 | # since we're using untrained models
147 | true_active = sample['active_constraints'].numpy()
148 | true_solution = sample['solution'].numpy()
149 |
150 | # Baseline (OSQP without transformers)
151 | _, baseline_time = solver.solve_with_time(
152 | Q=problem.Q,
153 | c=problem.c,
154 | A=problem.A,
155 | b=problem.b
156 | )
157 | baseline_times.append(baseline_time)
158 |
159 | # Warm start only
160 | _, warmstart_time = solver.solve_with_time(
161 | Q=problem.Q,
162 | c=problem.c,
163 | A=problem.A,
164 | b=problem.b,
165 | warm_start=true_solution # Using true solution as warm start
166 | )
167 | warmstart_only_times.append(warmstart_time)
168 |
169 | # Constraint only
170 | _, is_feasible, constraint_time = solver.solve_reduced_with_time(
171 | Q=problem.Q,
172 | c=problem.c,
173 | A=problem.A,
174 | b=problem.b,
175 | active_constraints=true_active # Using true active constraints
176 | )
177 | constraint_only_times.append(constraint_time)
178 |
179 | # Full transformer pipeline (using true values for demo)
180 | _, transformer_time, used_fallback = solver.solve_pipeline(
181 | Q=problem.Q,
182 | c=problem.c,
183 | A=problem.A,
184 | b=problem.b,
185 | active_constraints=true_active,
186 | warm_start=true_solution,
187 | fallback_on_violation=True
188 | )
189 | transformer_times.append(transformer_time)
190 | fallback_flags.append(used_fallback)
191 |
192 | # Convert to numpy arrays
193 | baseline_times = np.array(baseline_times)
194 | warmstart_only_times = np.array(warmstart_only_times)
195 | constraint_only_times = np.array(constraint_only_times)
196 | transformer_times = np.array(transformer_times)
197 |
198 | # Compute and print metrics
199 | solve_metrics = compute_solve_time_metrics(baseline_times, transformer_times)
200 | fallback_rate = compute_fallback_rate(fallback_flags)
201 |
202 | print("\nSolve Time Metrics:")
203 | print(f"Mean baseline time: {solve_metrics['mean_baseline_time']:.6f}s")
204 | print(f"Mean transformer time: {solve_metrics['mean_transformer_time']:.6f}s")
205 | print(f"Mean speedup: {solve_metrics['mean_speedup']:.2f}x")
206 | print(f"Median speedup: {solve_metrics['median_speedup']:.2f}x")
207 | print(f"Fallback rate: {fallback_rate:.2f}%")
208 |
209 | # Create results directory if it doesn't exist
210 | results_dir.mkdir(parents=True, exist_ok=True)
211 |
212 | # Plot solve time comparison
213 | plot_solve_time_comparison(
214 | baseline_times=baseline_times,
215 | transformer_times=transformer_times,
216 | save_path=results_dir / "solve_time_comparison.png"
217 | )
218 |
219 | # Plot boxplot
220 | plot_solve_time_boxplot(
221 | baseline_times=baseline_times,
222 | transformer_times=transformer_times,
223 | constraint_only_times=constraint_only_times,
224 | warmstart_only_times=warmstart_only_times,
225 | save_path=results_dir / "solve_time_boxplot.png"
226 | )
227 |
228 | print(f"\nResults saved to {results_dir}")
229 |
230 | if __name__ == "__main__":
231 | main()
--------------------------------------------------------------------------------
/transformermpc/utils/metrics.py:
--------------------------------------------------------------------------------
1 | """
2 | Metrics utility module for TransformerMPC.
3 |
4 | This module provides functions for computing various metrics used to
5 | evaluate the performance of the transformer models and the overall
6 | solving pipeline.
7 | """
8 |
9 | import numpy as np
10 | import torch
11 | from typing import Dict, List, Tuple, Union, Optional, Any
12 | from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
13 |
14 |
15 | def compute_constraint_prediction_metrics(y_true: Union[np.ndarray, torch.Tensor],
16 | y_pred: Union[np.ndarray, torch.Tensor],
17 | threshold: float = 0.5) -> Dict[str, float]:
18 | """
19 | Compute classification metrics for constraint prediction.
20 |
21 | Parameters:
22 | -----------
23 | y_true : numpy.ndarray or torch.Tensor
24 | True binary labels
25 | y_pred : numpy.ndarray or torch.Tensor
26 | Predicted probability or scores
27 | threshold : float
28 | Threshold for converting probabilities to binary labels
29 |
30 | Returns:
31 | --------
32 | metrics : Dict[str, float]
33 | Dictionary containing precision, recall, F1 score, and accuracy
34 | """
35 | # Convert to numpy if tensors
36 | if isinstance(y_true, torch.Tensor):
37 | y_true = y_true.detach().cpu().numpy()
38 | if isinstance(y_pred, torch.Tensor):
39 | y_pred = y_pred.detach().cpu().numpy()
40 |
41 | # Convert predictions to binary labels using threshold
42 | y_pred_binary = (y_pred > threshold).astype(np.int32)
43 |
44 | # Compute metrics
45 | precision = precision_score(y_true, y_pred_binary, average='binary', zero_division=0)
46 | recall = recall_score(y_true, y_pred_binary, average='binary', zero_division=0)
47 | f1 = f1_score(y_true, y_pred_binary, average='binary', zero_division=0)
48 | accuracy = accuracy_score(y_true, y_pred_binary)
49 |
50 | return {
51 | 'precision': precision,
52 | 'recall': recall,
53 | 'f1': f1,
54 | 'accuracy': accuracy
55 | }
56 |
57 |
58 | def compute_warm_start_metrics(y_true: Union[np.ndarray, torch.Tensor],
59 | y_pred: Union[np.ndarray, torch.Tensor]) -> Dict[str, float]:
60 | """
61 | Compute regression metrics for warm start prediction.
62 |
63 | Parameters:
64 | -----------
65 | y_true : numpy.ndarray or torch.Tensor
66 | True solution values
67 | y_pred : numpy.ndarray or torch.Tensor
68 | Predicted solution values
69 |
70 | Returns:
71 | --------
72 | metrics : Dict[str, float]
73 | Dictionary containing MSE, MAE, and relative error
74 | """
75 | # Convert to numpy if tensors
76 | if isinstance(y_true, torch.Tensor):
77 | y_true = y_true.detach().cpu().numpy()
78 | if isinstance(y_pred, torch.Tensor):
79 | y_pred = y_pred.detach().cpu().numpy()
80 |
81 | # Compute metrics
82 | mse = np.mean((y_true - y_pred) ** 2)
83 | mae = np.mean(np.abs(y_true - y_pred))
84 |
85 | # Compute relative error (avoid division by zero)
86 | denom = np.linalg.norm(y_true, axis=1)
87 | denom = np.where(denom < 1e-8, 1.0, denom)
88 | rel_error = np.mean(np.linalg.norm(y_true - y_pred, axis=1) / denom)
89 |
90 | return {
91 | 'mse': mse,
92 | 'mae': mae,
93 | 'relative_error': rel_error
94 | }
95 |
96 |
97 | def compute_active_constraint_stats(active_constraints: List[np.ndarray]) -> Dict[str, float]:
98 | """
99 | Compute statistics about active constraints.
100 |
101 | Parameters:
102 | -----------
103 | active_constraints : List[numpy.ndarray]
104 | List of binary vectors indicating active constraints for each problem
105 |
106 | Returns:
107 | --------
108 | stats : Dict[str, float]
109 | Dictionary containing statistics about active constraints
110 | """
111 | # Calculate percentage of active constraints for each problem
112 | active_percentages = [np.mean(ac) * 100 for ac in active_constraints]
113 |
114 | # Compute statistics
115 | mean_percentage = np.mean(active_percentages)
116 | median_percentage = np.median(active_percentages)
117 | std_percentage = np.std(active_percentages)
118 | min_percentage = np.min(active_percentages)
119 | max_percentage = np.max(active_percentages)
120 |
121 | return {
122 | 'mean_active_percentage': mean_percentage,
123 | 'median_active_percentage': median_percentage,
124 | 'std_active_percentage': std_percentage,
125 | 'min_active_percentage': min_percentage,
126 | 'max_active_percentage': max_percentage,
127 | 'active_percentages': np.array(active_percentages)
128 | }
129 |
130 |
131 | def compute_solve_time_metrics(baseline_times: np.ndarray,
132 | transformer_times: np.ndarray) -> Dict[str, float]:
133 | """
134 | Compute metrics comparing solve times.
135 |
136 | Parameters:
137 | -----------
138 | baseline_times : numpy.ndarray
139 | Array of solve times for baseline method
140 | transformer_times : numpy.ndarray
141 | Array of solve times for transformer-enhanced method
142 |
143 | Returns:
144 | --------
145 | metrics : Dict[str, float]
146 | Dictionary containing solve time comparison metrics
147 | """
148 | # Compute speedup ratios
149 | speedup_ratios = baseline_times / transformer_times
150 |
151 | # Compute statistics
152 | mean_speedup = np.mean(speedup_ratios)
153 | median_speedup = np.median(speedup_ratios)
154 | min_speedup = np.min(speedup_ratios)
155 | max_speedup = np.max(speedup_ratios)
156 | std_speedup = np.std(speedup_ratios)
157 |
158 | # Compute percentiles
159 | percentiles = np.percentile(speedup_ratios, [10, 25, 75, 90])
160 |
161 | # Compute time statistics
162 | mean_baseline_time = np.mean(baseline_times)
163 | mean_transformer_time = np.mean(transformer_times)
164 | median_baseline_time = np.median(baseline_times)
165 | median_transformer_time = np.median(transformer_times)
166 |
167 | return {
168 | 'mean_speedup': mean_speedup,
169 | 'median_speedup': median_speedup,
170 | 'min_speedup': min_speedup,
171 | 'max_speedup': max_speedup,
172 | 'std_speedup': std_speedup,
173 | 'p10_speedup': percentiles[0],
174 | 'p25_speedup': percentiles[1],
175 | 'p75_speedup': percentiles[2],
176 | 'p90_speedup': percentiles[3],
177 | 'mean_baseline_time': mean_baseline_time,
178 | 'mean_transformer_time': mean_transformer_time,
179 | 'median_baseline_time': median_baseline_time,
180 | 'median_transformer_time': median_transformer_time,
181 | 'speedup_ratios': speedup_ratios
182 | }
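
# A quick worked example with hypothetical timings: baseline times of
# [0.010, 0.012] s against transformer times of [0.004, 0.006] s give
# per-problem speedups [2.5, 2.0], so mean_speedup == median_speedup == 2.25.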
183 |
184 |
185 | def compute_fallback_rate(fallback_flags: List[bool]) -> float:
186 | """
187 | Compute the rate of fallbacks to full QP solve.
188 |
189 | Parameters:
190 | -----------
191 | fallback_flags : List[bool]
192 | List of flags indicating whether fallback was used
193 |
194 | Returns:
195 | --------
196 | fallback_rate : float
197 | Percentage of problems that required fallback
198 | """
199 | return 100.0 * np.mean([1 if flag else 0 for flag in fallback_flags])
200 |
201 |
202 | def compute_comprehensive_metrics(
203 | constraint_pred_true: np.ndarray,
204 | constraint_pred: np.ndarray,
205 | warmstart_true: np.ndarray,
206 | warmstart_pred: np.ndarray,
207 | baseline_times: np.ndarray,
208 | transformer_times: np.ndarray,
209 | fallback_flags: List[bool]
210 | ) -> Dict[str, Any]:
211 | """
212 | Compute comprehensive set of metrics for the entire pipeline.
213 |
214 | Parameters:
215 | -----------
216 | constraint_pred_true : numpy.ndarray
217 | True binary labels for constraint prediction
218 | constraint_pred : numpy.ndarray
219 | Predicted probability or scores for constraint prediction
220 | warmstart_true : numpy.ndarray
221 | True solution values for warm start prediction
222 | warmstart_pred : numpy.ndarray
223 | Predicted solution values for warm start prediction
224 | baseline_times : numpy.ndarray
225 | Array of solve times for baseline method
226 | transformer_times : numpy.ndarray
227 | Array of solve times for transformer-enhanced method
228 | fallback_flags : List[bool]
229 | List of flags indicating whether fallback was used
230 |
231 | Returns:
232 | --------
233 | metrics : Dict[str, Any]
234 | Dictionary containing all metrics
235 | """
236 | # Compute individual metrics
237 | constraint_metrics = compute_constraint_prediction_metrics(
238 | constraint_pred_true, constraint_pred)
239 |
240 | warmstart_metrics = compute_warm_start_metrics(
241 | warmstart_true, warmstart_pred)
242 |
243 | solve_time_metrics = compute_solve_time_metrics(
244 | baseline_times, transformer_times)
245 |
246 | fallback_rate = compute_fallback_rate(fallback_flags)
247 |
248 | # Combine all metrics
249 | metrics = {
250 | 'constraint_prediction': constraint_metrics,
251 | 'warm_start_prediction': warmstart_metrics,
252 | 'solve_time': solve_time_metrics,
253 | 'fallback_rate': fallback_rate
254 | }
255 |
256 | return metrics
257 |
--------------------------------------------------------------------------------
/transformermpc/utils/osqp_wrapper.py:
--------------------------------------------------------------------------------
1 | """
2 | OSQP solver wrapper module.
3 |
4 | This module provides a wrapper for the OSQP solver to solve QP problems.
5 | """
6 |
7 | import numpy as np
8 | import osqp
9 | import scipy.sparse as sparse
10 | import time
11 | from typing import Dict, Optional, Tuple, Union, List, Any
12 |
13 | class OSQPSolver:
14 | """
15 | Wrapper class for the OSQP solver.
16 |
17 | This class provides a convenient interface to solve QP problems using OSQP.
18 | """
19 |
20 | def __init__(self,
21 | verbose: bool = False,
22 | max_iter: int = 4000,
23 | eps_abs: float = 1e-6,
24 | eps_rel: float = 1e-6,
25 | polish: bool = True):
26 | """
27 | Initialize the OSQP solver wrapper.
28 |
29 | Parameters:
30 | -----------
31 | verbose : bool
32 | Whether to print solver output
33 | max_iter : int
34 | Maximum number of iterations
35 | eps_abs : float
36 | Absolute tolerance
37 | eps_rel : float
38 | Relative tolerance
39 | polish : bool
40 | Whether to polish the solution
41 | """
42 | self.verbose = verbose
43 | self.max_iter = max_iter
44 | self.eps_abs = eps_abs
45 | self.eps_rel = eps_rel
46 | self.polish = polish
47 |
48 | def solve(self,
49 | Q: np.ndarray,
50 | c: np.ndarray,
51 | A: Optional[np.ndarray] = None,
52 | b: Optional[np.ndarray] = None,
53 | warm_start: Optional[np.ndarray] = None) -> np.ndarray:
54 | """
55 | Solve a QP problem using OSQP.
56 |
57 | Parameters:
58 | -----------
59 | Q : numpy.ndarray
60 | Quadratic cost matrix (n x n)
61 | c : numpy.ndarray
62 | Linear cost vector (n)
63 | A : numpy.ndarray or None
64 | Constraint matrix (m x n)
65 | b : numpy.ndarray or None
66 | Constraint vector (m)
67 | warm_start : numpy.ndarray or None
68 | Warm start vector for the solver
69 |
70 | Returns:
71 | --------
72 | solution : numpy.ndarray
73 | Optimal solution vector
74 | """
75 | # Convert to sparse matrices for OSQP
76 | P = sparse.csc_matrix(Q)
77 | q = c
78 |
79 | # Check if we have constraints
80 |         if A is not None and b is not None:
81 |             A_sparse = sparse.csc_matrix(A)
82 |             l = -np.inf * np.ones(A.shape[0])  # constraints are A*x <= b, i.e. -inf <= A*x <= b
83 |             u = b
84 |         else:
85 |             # No constraints: empty constraint matrix and bounds
86 |             A_sparse = sparse.csc_matrix((0, Q.shape[0]))
87 |             l = np.array([])
88 |             u = np.array([])
89 | 
90 |         # Create the OSQP solver
91 |         solver = osqp.OSQP()
92 | 
93 |         # Setup the problem
94 |         solver.setup(P=P, q=q, A=A_sparse, l=l, u=u,
95 |                      verbose=self.verbose, max_iter=self.max_iter,
96 |                      eps_abs=self.eps_abs, eps_rel=self.eps_rel,
97 |                      polish=self.polish)
98 |
99 | # Set warm start if provided
100 | if warm_start is not None:
101 | solver.warm_start(x=warm_start)
102 |
103 | # Solve the problem
104 | result = solver.solve()
105 |
106 | # Check if the solver was successful
107 | if result.info.status != 'solved':
108 | print(f"Warning: OSQP solver returned status {result.info.status}")
109 |
110 | # Return the solution
111 | return result.x
112 |
113 | def solve_with_time(self,
114 | Q: np.ndarray,
115 | c: np.ndarray,
116 | A: Optional[np.ndarray] = None,
117 | b: Optional[np.ndarray] = None,
118 | warm_start: Optional[np.ndarray] = None) -> Tuple[np.ndarray, float]:
119 | """
120 | Solve a QP problem using OSQP and return solution time.
121 |
122 | Parameters:
123 | -----------
124 | Same as solve method.
125 |
126 | Returns:
127 | --------
128 | solution : numpy.ndarray
129 | Optimal solution vector
130 | solve_time : float
131 | Solution time in seconds
132 | """
133 | # Measure the solution time
134 | start_time = time.time()
135 | solution = self.solve(Q, c, A, b, warm_start)
136 | solve_time = time.time() - start_time
137 |
138 | return solution, solve_time
139 |
140 | def solve_reduced(self,
141 | Q: np.ndarray,
142 | c: np.ndarray,
143 | A: np.ndarray,
144 | b: np.ndarray,
145 | active_constraints: np.ndarray,
146 | warm_start: Optional[np.ndarray] = None) -> Tuple[np.ndarray, bool]:
147 | """
148 | Solve a reduced QP problem with only active constraints.
149 |
150 | Parameters:
151 | -----------
152 | Q, c, A, b : Same as solve method
153 | active_constraints : numpy.ndarray
154 | Binary vector indicating which constraints are active
155 | warm_start : numpy.ndarray or None
156 | Warm start vector for the solver
157 |
158 | Returns:
159 | --------
160 | solution : numpy.ndarray
161 | Optimal solution vector
162 | is_feasible : bool
163 | Whether the solution is feasible for the original problem
164 | """
165 | # Get indices of active constraints
166 | active_indices = np.where(active_constraints > 0.5)[0]
167 |
168 | # Create reduced constraint matrices
169 | if len(active_indices) > 0:
170 | A_reduced = A[active_indices, :]
171 | b_reduced = b[active_indices]
172 | else:
173 | # No active constraints, solve unconstrained problem
174 | A_reduced = None
175 | b_reduced = None
176 |
177 | # Solve the reduced problem
178 | solution = self.solve(Q, c, A_reduced, b_reduced, warm_start)
179 |
180 | # Check if the solution satisfies the original constraints
181 | is_feasible = True
182 | if A is not None and b is not None:
183 | constraint_values = A @ solution - b
184 | is_feasible = np.all(constraint_values <= 1e-6)
185 |
186 | return solution, is_feasible
187 |
188 | def solve_reduced_with_time(self,
189 | Q: np.ndarray,
190 | c: np.ndarray,
191 | A: np.ndarray,
192 | b: np.ndarray,
193 | active_constraints: np.ndarray,
194 | warm_start: Optional[np.ndarray] = None) -> Tuple[np.ndarray, bool, float]:
195 | """
196 | Solve a reduced QP problem with only active constraints and return solution time.
197 |
198 | Parameters:
199 | -----------
200 | Same as solve_reduced method.
201 |
202 | Returns:
203 | --------
204 | solution : numpy.ndarray
205 | Optimal solution vector
206 | is_feasible : bool
207 | Whether the solution is feasible for the original problem
208 | solve_time : float
209 | Solution time in seconds
210 | """
211 | # Measure the solution time
212 | start_time = time.time()
213 | solution, is_feasible = self.solve_reduced(Q, c, A, b, active_constraints, warm_start)
214 | solve_time = time.time() - start_time
215 |
216 | return solution, is_feasible, solve_time
217 |
218 | def solve_pipeline(self,
219 | Q: np.ndarray,
220 | c: np.ndarray,
221 | A: np.ndarray,
222 | b: np.ndarray,
223 | active_constraints: np.ndarray,
224 | warm_start: Optional[np.ndarray] = None,
225 | fallback_on_violation: bool = True) -> Tuple[np.ndarray, float, bool]:
226 | """
227 | Solve a QP problem using the transformer-enhanced pipeline.
228 |
229 | This method first tries to solve the reduced problem with active constraints.
230 | If the solution isn't feasible for the original problem, it falls back to the full problem.
231 |
232 | Parameters:
233 | -----------
234 | Q, c, A, b : Same as solve method
235 | active_constraints : numpy.ndarray
236 | Binary vector indicating which constraints are active
237 | warm_start : numpy.ndarray or None
238 | Warm start vector for the solver
239 | fallback_on_violation : bool
240 | Whether to fall back to the full problem if constraints are violated
241 |
242 | Returns:
243 | --------
244 | solution : numpy.ndarray
245 | Optimal solution vector
246 | solve_time : float
247 | Solution time in seconds
248 | used_fallback : bool
249 | Whether the fallback solver was used
250 | """
251 | # Start timing
252 | start_time = time.time()
253 |
254 | # Try to solve the reduced problem
255 | solution, is_feasible = self.solve_reduced(Q, c, A, b, active_constraints, warm_start)
256 |
257 | # If the solution is not feasible and fallback is enabled, solve the full problem
258 | used_fallback = False
259 | if not is_feasible and fallback_on_violation:
260 | solution = self.solve(Q, c, A, b, warm_start)
261 | used_fallback = True
262 |
263 | # Compute total solution time
264 | solve_time = time.time() - start_time
265 |
266 | return solution, solve_time, used_fallback
267 |
--------------------------------------------------------------------------------
/scripts/boxplot_demo.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | """
3 | TransformerMPC Boxplot Demo
4 |
5 | This script demonstrates the performance comparison of different QP solving strategies
6 | using randomly generated quadratic programming (QP) problems.
7 |
8 | It focuses on showing the boxplot comparison without actual training of transformer models.
9 | """
10 |
11 | import os
12 | import time
13 | import argparse
14 | import numpy as np
15 | import matplotlib.pyplot as plt
16 | from pathlib import Path
17 | from tqdm import tqdm
18 |
19 | # Import TransformerMPC modules
20 | from transformermpc.data.qp_generator import QPGenerator
21 | from transformermpc.utils.osqp_wrapper import OSQPSolver
22 | from transformermpc.utils.metrics import compute_solve_time_metrics
23 |
24 | def parse_args():
25 | """Parse command line arguments."""
26 | parser = argparse.ArgumentParser(description="TransformerMPC Boxplot Demo")
27 |
28 | # Data generation parameters
29 | parser.add_argument("--num_samples", type=int, default=30,
30 | help="Number of QP problems to generate (default: 30)")
31 | parser.add_argument("--state_dim", type=int, default=4,
32 | help="State dimension for MPC problems (default: 4)")
33 | parser.add_argument("--input_dim", type=int, default=2,
34 | help="Input dimension for MPC problems (default: 2)")
35 | parser.add_argument("--horizon", type=int, default=5,
36 | help="Time horizon for MPC problems (default: 5)")
37 |
38 | # Other parameters
39 | parser.add_argument("--output_dir", type=str, default="demo_results",
40 | help="Directory to save results (default: demo_results)")
41 |
42 | return parser.parse_args()
43 |
44 | def main():
45 | """Run the boxplot demo workflow."""
46 | # Parse command line arguments
47 | args = parse_args()
48 |
49 | print("=" * 60)
50 | print("TransformerMPC Boxplot Demo".center(60))
51 | print("=" * 60)
52 |
53 | # Create output directory
54 | output_dir = Path(args.output_dir)
55 | output_dir.mkdir(parents=True, exist_ok=True)
56 |
57 | # Set up directory for results
58 | results_dir = output_dir / "results"
59 | results_dir.mkdir(exist_ok=True)
60 |
61 | # Step 1: Generate QP problems
62 | print("\nStep 1: Generating QP problems")
63 | print("-" * 60)
64 |
65 | print(f"Generating {args.num_samples} QP problems")
66 | generator = QPGenerator(
67 | state_dim=args.state_dim,
68 | input_dim=args.input_dim,
69 | horizon=args.horizon,
70 | num_samples=args.num_samples
71 | )
72 | qp_problems = generator.generate()
73 | print(f"Generated {len(qp_problems)} QP problems")
74 |
75 | # Step 2: Performance testing with different strategies
76 | print("\nStep 2: Performance testing")
77 | print("-" * 60)
78 |
79 | solver = OSQPSolver()
80 |
81 | # Lists to store results for different strategies
82 | baseline_times = [] # Standard OSQP
83 | reduced_constraint_times = [] # 50% random constraints removed
84 | warm_start_random_times = [] # Warm start with random initial point
85 | warm_start_perturbed_times = [] # Warm start with slightly perturbed solution
86 | combined_strategy_times = [] # Both constraints reduced and warm start
87 |
88 | print(f"Testing {len(qp_problems)} problems...")
89 | for i, problem in enumerate(tqdm(qp_problems)):
90 | # 1. Baseline (Standard OSQP)
91 | _, baseline_time = solver.solve_with_time(
92 | Q=problem.Q,
93 | c=problem.c,
94 | A=problem.A,
95 | b=problem.b
96 | )
97 | baseline_times.append(baseline_time)
98 |
99 | # Get the solution for perturbed warm start
100 | solution = solver.solve(
101 | Q=problem.Q,
102 | c=problem.c,
103 | A=problem.A,
104 | b=problem.b
105 | )
106 |
107 | # 2. Reduced constraints (randomly remove 50% of constraints)
108 | num_constraints = problem.A.shape[0]
109 | mask = np.random.choice([True, False], size=num_constraints, p=[0.5, 0.5])
110 |
111 | A_reduced = problem.A[mask]
112 | b_reduced = problem.b[mask]
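        # Note: random removal can drop *active* constraints, so the reduced
        # solution may differ from the true optimum; this strategy is a timing
        # illustration only (the real pipeline falls back on violation).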
113 |
114 | _, reduced_time = solver.solve_with_time(
115 | Q=problem.Q,
116 | c=problem.c,
117 | A=A_reduced,
118 | b=b_reduced
119 | )
120 | reduced_constraint_times.append(reduced_time)
121 |
122 | # 3. Warm start with random initial point
123 | warm_start_random = np.random.randn(problem.Q.shape[0])
124 | _, warm_start_random_time = solver.solve_with_time(
125 | Q=problem.Q,
126 | c=problem.c,
127 | A=problem.A,
128 | b=problem.b,
129 | warm_start=warm_start_random
130 | )
131 | warm_start_random_times.append(warm_start_random_time)
132 |
133 | # 4. Warm start with slightly perturbed solution (simulate good prediction)
134 | perturbation = np.random.randn(solution.shape[0]) * 0.1 # Small perturbation
135 | warm_start_perturbed = solution + perturbation
136 |
137 | _, warm_start_perturbed_time = solver.solve_with_time(
138 | Q=problem.Q,
139 | c=problem.c,
140 | A=problem.A,
141 | b=problem.b,
142 | warm_start=warm_start_perturbed
143 | )
144 | warm_start_perturbed_times.append(warm_start_perturbed_time)
145 |
146 | # 5. Combined strategy (reduced constraints + warm start)
147 | _, combined_time = solver.solve_with_time(
148 | Q=problem.Q,
149 | c=problem.c,
150 | A=A_reduced,
151 | b=b_reduced,
152 | warm_start=warm_start_perturbed
153 | )
154 | combined_strategy_times.append(combined_time)
155 |
156 | # Convert to numpy arrays
157 | baseline_times = np.array(baseline_times)
158 | reduced_constraint_times = np.array(reduced_constraint_times)
159 | warm_start_random_times = np.array(warm_start_random_times)
160 | warm_start_perturbed_times = np.array(warm_start_perturbed_times)
161 | combined_strategy_times = np.array(combined_strategy_times)
162 |
163 | # Compute and print metrics
164 | print("\nPerformance Results:")
165 | print("-" * 60)
166 | print(f"Mean baseline time: {np.mean(baseline_times):.6f}s")
167 | print(f"Mean reduced constraints time: {np.mean(reduced_constraint_times):.6f}s")
168 | print(f"Mean warm start (random) time: {np.mean(warm_start_random_times):.6f}s")
169 | print(f"Mean warm start (perturbed) time: {np.mean(warm_start_perturbed_times):.6f}s")
170 | print(f"Mean combined strategy time: {np.mean(combined_strategy_times):.6f}s")
171 |
172 | # Calculate speedups
173 | print("\nSpeedup Factors:")
174 | print(f"Reduced constraints: {np.mean(baseline_times) / np.mean(reduced_constraint_times):.2f}x")
175 | print(f"Warm start (random): {np.mean(baseline_times) / np.mean(warm_start_random_times):.2f}x")
176 | print(f"Warm start (perturbed): {np.mean(baseline_times) / np.mean(warm_start_perturbed_times):.2f}x")
177 | print(f"Combined strategy: {np.mean(baseline_times) / np.mean(combined_strategy_times):.2f}x")
178 |
179 | # Step 3: Generate visualizations
180 | print("\nStep 3: Generating visualizations")
181 | print("-" * 60)
182 |
183 | # Box plot of solve times
184 | print("Generating solve time boxplot...")
185 | plt.figure(figsize=(12, 8))
186 | box_data = [
187 | baseline_times,
188 | reduced_constraint_times,
189 | warm_start_random_times,
190 | warm_start_perturbed_times,
191 | combined_strategy_times
192 | ]
193 |
194 | box_labels = [
195 | 'Baseline',
196 | 'Reduced\nConstraints',
197 | 'Warm Start\n(Random)',
198 | 'Warm Start\n(Perturbed)',
199 | 'Combined\nStrategy'
200 | ]
201 |
202 | box_plot = plt.boxplot(
203 | box_data,
204 | labels=box_labels,
205 | patch_artist=True,
206 | showmeans=True
207 | )
208 |
209 | # Add colors to boxes
210 | colors = ['lightblue', 'lightgreen', 'lightpink', 'lightyellow', 'lightcyan']
211 | for patch, color in zip(box_plot['boxes'], colors):
212 | patch.set_facecolor(color)
213 |
214 | plt.ylabel('Solve Time (s)')
215 | plt.title('QP Solve Time Comparison')
216 | plt.grid(True, axis='y', alpha=0.3)
217 |
218 | # Add mean value annotations
219 | means = [np.mean(data) for data in box_data]
220 | for i, mean in enumerate(means):
221 | plt.text(i+1, mean, f'{mean:.6f}s',
222 | horizontalalignment='center',
223 | verticalalignment='bottom',
224 | fontweight='bold')
225 |
226 | plt.tight_layout()
227 | plt.savefig(results_dir / "solve_time_boxplot.png", dpi=300)
228 |
229 | # Create speedup bar chart
230 | print("Generating speedup bar chart...")
231 | plt.figure(figsize=(10, 6))
232 |
233 | speedups = [
234 | np.mean(baseline_times) / np.mean(reduced_constraint_times),
235 | np.mean(baseline_times) / np.mean(warm_start_random_times),
236 | np.mean(baseline_times) / np.mean(warm_start_perturbed_times),
237 | np.mean(baseline_times) / np.mean(combined_strategy_times)
238 | ]
239 |
240 | speedup_labels = [
241 | 'Reduced\nConstraints',
242 | 'Warm Start\n(Random)',
243 | 'Warm Start\n(Perturbed)',
244 | 'Combined\nStrategy'
245 | ]
246 |
247 | bars = plt.bar(speedup_labels, speedups, color=['lightgreen', 'lightpink', 'lightyellow', 'lightcyan'])
248 |
249 | plt.ylabel('Speedup Factor (×)')
250 | plt.title('Speedup Relative to Baseline')
251 | plt.grid(True, axis='y', alpha=0.3)
252 |
253 | # Add value annotations
254 | for bar, speedup in zip(bars, speedups):
255 | height = bar.get_height()
256 | plt.text(bar.get_x() + bar.get_width()/2., height + 0.1,
257 | f'{speedup:.2f}×',
258 | ha='center', va='bottom', fontweight='bold')
259 |
260 | plt.tight_layout()
261 | plt.savefig(results_dir / "speedup_barchart.png", dpi=300)
262 |
263 | # Violin plot of solve times
264 | print("Generating violin plot...")
265 | plt.figure(figsize=(12, 8))
266 |
267 | violin_plot = plt.violinplot(
268 | box_data,
269 | showmeans=True,
270 | showextrema=True
271 | )
272 |
273 | plt.xticks(range(1, 6), box_labels)
274 | plt.ylabel('Solve Time (s)')
275 | plt.title('QP Solve Time Distribution')
276 | plt.grid(True, axis='y', alpha=0.3)
277 |
278 | plt.tight_layout()
279 | plt.savefig(results_dir / "solve_time_violinplot.png", dpi=300)
280 |
281 | print(f"\nResults and visualizations saved to {output_dir}/results")
282 | print("\nDemo completed successfully!")
283 | print("=" * 60)
284 |
285 | if __name__ == "__main__":
286 | main()
--------------------------------------------------------------------------------
/tests/scripts/boxplot_demo.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | """
3 | TransformerMPC Boxplot Demo
4 |
5 | This script demonstrates the performance comparison of different QP solving strategies
6 | using randomly generated quadratic programming (QP) problems.
7 |
8 | It focuses on showing the boxplot comparison without actual training of transformer models.
9 | """
10 |
11 | import os
12 | import time
13 | import argparse
14 | import numpy as np
15 | import matplotlib.pyplot as plt
16 | from pathlib import Path
17 | from tqdm import tqdm
18 |
19 | # Import TransformerMPC modules
20 | from transformermpc.data.qp_generator import QPGenerator
21 | from transformermpc.utils.osqp_wrapper import OSQPSolver
22 | from transformermpc.utils.metrics import compute_solve_time_metrics
23 |
24 | def parse_args():
25 | """Parse command line arguments."""
26 | parser = argparse.ArgumentParser(description="TransformerMPC Boxplot Demo")
27 |
28 | # Data generation parameters
29 | parser.add_argument("--num_samples", type=int, default=30,
30 | help="Number of QP problems to generate (default: 30)")
31 | parser.add_argument("--state_dim", type=int, default=4,
32 | help="State dimension for MPC problems (default: 4)")
33 | parser.add_argument("--input_dim", type=int, default=2,
34 | help="Input dimension for MPC problems (default: 2)")
35 | parser.add_argument("--horizon", type=int, default=5,
36 | help="Time horizon for MPC problems (default: 5)")
37 |
38 | # Other parameters
39 | parser.add_argument("--output_dir", type=str, default="demo_results",
40 | help="Directory to save results (default: demo_results)")
41 |
42 | return parser.parse_args()
43 |
44 | def main():
45 | """Run the boxplot demo workflow."""
46 | # Parse command line arguments
47 | args = parse_args()
48 |
49 | print("=" * 60)
50 | print("TransformerMPC Boxplot Demo".center(60))
51 | print("=" * 60)
52 |
53 | # Create output directory
54 | output_dir = Path(args.output_dir)
55 | output_dir.mkdir(parents=True, exist_ok=True)
56 |
57 | # Set up directory for results
58 | results_dir = output_dir / "results"
59 | results_dir.mkdir(exist_ok=True)
60 |
61 | # Step 1: Generate QP problems
62 | print("\nStep 1: Generating QP problems")
63 | print("-" * 60)
64 |
65 | print(f"Generating {args.num_samples} QP problems")
66 | generator = QPGenerator(
67 | state_dim=args.state_dim,
68 | input_dim=args.input_dim,
69 | horizon=args.horizon,
70 | num_samples=args.num_samples
71 | )
72 | qp_problems = generator.generate()
73 | print(f"Generated {len(qp_problems)} QP problems")
74 |
75 | # Step 2: Performance testing with different strategies
76 | print("\nStep 2: Performance testing")
77 | print("-" * 60)
78 |
79 | solver = OSQPSolver()
80 |
81 | # Lists to store results for different strategies
82 | baseline_times = [] # Standard OSQP
83 | reduced_constraint_times = [] # 50% random constraints removed
84 | warm_start_random_times = [] # Warm start with random initial point
85 | warm_start_perturbed_times = [] # Warm start with slightly perturbed solution
86 | combined_strategy_times = [] # Both constraints reduced and warm start
87 |
88 | print(f"Testing {len(qp_problems)} problems...")
89 | for i, problem in enumerate(tqdm(qp_problems)):
90 | # 1. Baseline (Standard OSQP)
91 | _, baseline_time = solver.solve_with_time(
92 | Q=problem.Q,
93 | c=problem.c,
94 | A=problem.A,
95 | b=problem.b
96 | )
97 | baseline_times.append(baseline_time)
98 |
99 | # Get the solution for perturbed warm start
100 | solution = solver.solve(
101 | Q=problem.Q,
102 | c=problem.c,
103 | A=problem.A,
104 | b=problem.b
105 | )
106 |
107 | # 2. Reduced constraints (randomly remove 50% of constraints)
108 | num_constraints = problem.A.shape[0]
109 | mask = np.random.choice([True, False], size=num_constraints, p=[0.5, 0.5])
110 |
111 | A_reduced = problem.A[mask]
112 | b_reduced = problem.b[mask]
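        # Note: random removal can drop *active* constraints, so the reduced
        # solution may differ from the true optimum; this strategy is a timing
        # illustration only (the real pipeline falls back on violation).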
113 |
114 | _, reduced_time = solver.solve_with_time(
115 | Q=problem.Q,
116 | c=problem.c,
117 | A=A_reduced,
118 | b=b_reduced
119 | )
120 | reduced_constraint_times.append(reduced_time)
121 |
122 | # 3. Warm start with random initial point
123 | warm_start_random = np.random.randn(problem.Q.shape[0])
124 | _, warm_start_random_time = solver.solve_with_time(
125 | Q=problem.Q,
126 | c=problem.c,
127 | A=problem.A,
128 | b=problem.b,
129 | warm_start=warm_start_random
130 | )
131 | warm_start_random_times.append(warm_start_random_time)
132 |
133 | # 4. Warm start with slightly perturbed solution (simulate good prediction)
134 | perturbation = np.random.randn(solution.shape[0]) * 0.1 # Small perturbation
135 | warm_start_perturbed = solution + perturbation
136 |
137 | _, warm_start_perturbed_time = solver.solve_with_time(
138 | Q=problem.Q,
139 | c=problem.c,
140 | A=problem.A,
141 | b=problem.b,
142 | warm_start=warm_start_perturbed
143 | )
144 | warm_start_perturbed_times.append(warm_start_perturbed_time)
145 |
146 | # 5. Combined strategy (reduced constraints + warm start)
147 | _, combined_time = solver.solve_with_time(
148 | Q=problem.Q,
149 | c=problem.c,
150 | A=A_reduced,
151 | b=b_reduced,
152 | warm_start=warm_start_perturbed
153 | )
154 | combined_strategy_times.append(combined_time)
155 |
156 | # Convert to numpy arrays
157 | baseline_times = np.array(baseline_times)
158 | reduced_constraint_times = np.array(reduced_constraint_times)
159 | warm_start_random_times = np.array(warm_start_random_times)
160 | warm_start_perturbed_times = np.array(warm_start_perturbed_times)
161 | combined_strategy_times = np.array(combined_strategy_times)
162 |
163 | # Compute and print metrics
164 | print("\nPerformance Results:")
165 | print("-" * 60)
166 | print(f"Mean baseline time: {np.mean(baseline_times):.6f}s")
167 | print(f"Mean reduced constraints time: {np.mean(reduced_constraint_times):.6f}s")
168 | print(f"Mean warm start (random) time: {np.mean(warm_start_random_times):.6f}s")
169 | print(f"Mean warm start (perturbed) time: {np.mean(warm_start_perturbed_times):.6f}s")
170 | print(f"Mean combined strategy time: {np.mean(combined_strategy_times):.6f}s")
171 |
172 | # Calculate speedups
173 | print("\nSpeedup Factors:")
174 | print(f"Reduced constraints: {np.mean(baseline_times) / np.mean(reduced_constraint_times):.2f}x")
175 | print(f"Warm start (random): {np.mean(baseline_times) / np.mean(warm_start_random_times):.2f}x")
176 | print(f"Warm start (perturbed): {np.mean(baseline_times) / np.mean(warm_start_perturbed_times):.2f}x")
177 | print(f"Combined strategy: {np.mean(baseline_times) / np.mean(combined_strategy_times):.2f}x")
178 |
179 | # Step 3: Generate visualizations
180 | print("\nStep 3: Generating visualizations")
181 | print("-" * 60)
182 |
183 | # Box plot of solve times
184 | print("Generating solve time boxplot...")
185 | plt.figure(figsize=(12, 8))
186 | box_data = [
187 | baseline_times,
188 | reduced_constraint_times,
189 | warm_start_random_times,
190 | warm_start_perturbed_times,
191 | combined_strategy_times
192 | ]
193 |
194 | box_labels = [
195 | 'Baseline',
196 | 'Reduced\nConstraints',
197 | 'Warm Start\n(Random)',
198 | 'Warm Start\n(Perturbed)',
199 | 'Combined\nStrategy'
200 | ]
201 |
202 | box_plot = plt.boxplot(
203 | box_data,
204 | labels=box_labels,
205 | patch_artist=True,
206 | showmeans=True
207 | )
208 |
209 | # Add colors to boxes
210 | colors = ['lightblue', 'lightgreen', 'lightpink', 'lightyellow', 'lightcyan']
211 | for patch, color in zip(box_plot['boxes'], colors):
212 | patch.set_facecolor(color)
213 |
214 | plt.ylabel('Solve Time (s)')
215 | plt.title('QP Solve Time Comparison')
216 | plt.grid(True, axis='y', alpha=0.3)
217 |
218 | # Add mean value annotations
219 | means = [np.mean(data) for data in box_data]
220 | for i, mean in enumerate(means):
221 | plt.text(i+1, mean, f'{mean:.6f}s',
222 | horizontalalignment='center',
223 | verticalalignment='bottom',
224 | fontweight='bold')
225 |
226 | plt.tight_layout()
227 | plt.savefig(results_dir / "solve_time_boxplot.png", dpi=300)
228 |
229 | # Create speedup bar chart
230 | print("Generating speedup bar chart...")
231 | plt.figure(figsize=(10, 6))
232 |
233 | speedups = [
234 | np.mean(baseline_times) / np.mean(reduced_constraint_times),
235 | np.mean(baseline_times) / np.mean(warm_start_random_times),
236 | np.mean(baseline_times) / np.mean(warm_start_perturbed_times),
237 | np.mean(baseline_times) / np.mean(combined_strategy_times)
238 | ]
239 |
240 | speedup_labels = [
241 | 'Reduced\nConstraints',
242 | 'Warm Start\n(Random)',
243 | 'Warm Start\n(Perturbed)',
244 | 'Combined\nStrategy'
245 | ]
246 |
247 | bars = plt.bar(speedup_labels, speedups, color=['lightgreen', 'lightpink', 'lightyellow', 'lightcyan'])
248 |
249 | plt.ylabel('Speedup Factor (×)')
250 | plt.title('Speedup Relative to Baseline')
251 | plt.grid(True, axis='y', alpha=0.3)
252 |
253 | # Add value annotations
254 | for bar, speedup in zip(bars, speedups):
255 | height = bar.get_height()
256 | plt.text(bar.get_x() + bar.get_width()/2., height + 0.1,
257 | f'{speedup:.2f}×',
258 | ha='center', va='bottom', fontweight='bold')
259 |
260 | plt.tight_layout()
261 | plt.savefig(results_dir / "speedup_barchart.png", dpi=300)
262 |
263 | # Violin plot of solve times
264 | print("Generating violin plot...")
265 | plt.figure(figsize=(12, 8))
266 |
267 | violin_plot = plt.violinplot(
268 | box_data,
269 | showmeans=True,
270 | showextrema=True
271 | )
272 |
273 | plt.xticks(range(1, 6), box_labels)
274 | plt.ylabel('Solve Time (s)')
275 | plt.title('QP Solve Time Distribution')
276 | plt.grid(True, axis='y', alpha=0.3)
277 |
278 | plt.tight_layout()
279 | plt.savefig(results_dir / "solve_time_violinplot.png", dpi=300)
280 |
281 | print(f"\nResults and visualizations saved to {output_dir}/results")
282 | print("\nDemo completed successfully!")
283 | print("=" * 60)
284 |
285 | if __name__ == "__main__":
286 | main()
--------------------------------------------------------------------------------
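
The script above relies on OSQPSolver.solve_with_time from transformermpc/utils/osqp_wrapper.py, which is not reproduced in this dump. As a rough sketch of what such a wrapper has to do (assuming the inequality form A x <= b used throughout, and only the public osqp Python API), timing one solve with an optional warm start might look like:

    import time
    import numpy as np
    import osqp
    from scipy import sparse

    def solve_with_time(Q, c, A, b, warm_start=None):
        """Solve 0.5 x^T Q x + c^T x s.t. A x <= b; return (solution, seconds)."""
        prob = osqp.OSQP()
        # OSQP expects l <= A x <= u, so encode A x <= b with l = -inf
        prob.setup(P=sparse.csc_matrix(Q), q=c, A=sparse.csc_matrix(A),
                   l=np.full(b.shape[0], -np.inf), u=b, verbose=False)
        if warm_start is not None:
            prob.warm_start(x=warm_start)  # e.g. a predicted primal solution
        t0 = time.perf_counter()
        result = prob.solve()
        return result.x, time.perf_counter() - t0

The packaged wrapper may differ in its details (solver settings, sparse handling, fallback logic); this sketch is only meant to make the timing calls above concrete.
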
/transformermpc/demo/demo.py:
--------------------------------------------------------------------------------
1 | """
2 | Demo script for TransformerMPC.
3 |
4 | This script demonstrates the complete workflow of:
5 | 1. Generating QP problems
6 | 2. Training the constraint predictor and warm start predictor models
7 | 3. Testing the models and comparing performance against baseline
8 | 4. Visualizing the results
9 | """
10 |
11 | import os
12 | import torch
13 | import numpy as np
14 | import time
15 | from tqdm import tqdm
16 | import matplotlib.pyplot as plt
17 | import argparse
18 | from pathlib import Path
19 |
20 | from ..data.qp_generator import QPGenerator
21 | from ..data.dataset import QPDataset
22 | from ..models.constraint_predictor import ConstraintPredictor
23 | from ..models.warm_start_predictor import WarmStartPredictor
24 | from ..training.trainer import ModelTrainer
25 | from ..utils.osqp_wrapper import OSQPSolver
26 | from ..utils.metrics import compute_solve_time_metrics, compute_fallback_rate
27 | from ..utils.visualization import (
28 | plot_solve_time_comparison,
29 | plot_solve_time_boxplot,
30 | plot_active_constraints_histogram,
31 | plot_fallback_statistics
32 | )
33 |
34 |
35 | def parse_args():
36 | """Parse command line arguments."""
37 | parser = argparse.ArgumentParser(description="TransformerMPC Demo")
38 |
39 | parser.add_argument(
40 | "--num_samples", type=int, default=20000,
41 | help="Number of QP problems to generate (default: 20000)"
42 | )
43 | parser.add_argument(
44 | "--state_dim", type=int, default=4,
45 | help="State dimension for MPC problems (default: 4)"
46 | )
47 | parser.add_argument(
48 | "--input_dim", type=int, default=2,
49 | help="Input dimension for MPC problems (default: 2)"
50 | )
51 | parser.add_argument(
52 | "--horizon", type=int, default=10,
53 | help="Time horizon for MPC problems (default: 10)"
54 | )
55 | parser.add_argument(
56 | "--num_epochs", type=int, default=2000,
57 | help="Number of training epochs (default: 2000)"
58 | )
59 | parser.add_argument(
60 | "--batch_size", type=int, default=64,
61 | help="Batch size for training (default: 64)"
62 | )
63 | parser.add_argument(
64 | "--test_size", type=float, default=0.2,
65 | help="Fraction of data to use for testing (default: 0.2)"
66 | )
67 | parser.add_argument(
68 | "--output_dir", type=str, default="transformermpc_results",
69 | help="Directory to save results (default: transformermpc_results)"
70 | )
71 | parser.add_argument(
72 | "--skip_training", action="store_true",
73 | help="Skip training and use pretrained models if available"
74 | )
75 | parser.add_argument(
76 | "--cpu", action="store_true",
77 | help="Force using CPU even if GPU is available"
78 | )
79 |
80 | return parser.parse_args()
81 |
82 |
83 | def main():
84 | """Run the complete demo workflow."""
85 | # Parse arguments
86 | args = parse_args()
87 |
88 | # Create output directory
89 | output_dir = Path(args.output_dir)
90 | output_dir.mkdir(parents=True, exist_ok=True)
91 |
92 | # Set up directories
93 | data_dir = output_dir / "data"
94 | models_dir = output_dir / "models"
95 | logs_dir = output_dir / "logs"
96 | results_dir = output_dir / "results"
97 |
98 | data_dir.mkdir(exist_ok=True)
99 | models_dir.mkdir(exist_ok=True)
100 | logs_dir.mkdir(exist_ok=True)
101 | results_dir.mkdir(exist_ok=True)
102 |
103 | # Set device
104 | device = torch.device('cuda' if torch.cuda.is_available() and not args.cpu else 'cpu')
105 | print(f"Using device: {device}")
106 |
107 | # 1. Generate QP problems or load from cache
108 | qp_problems_file = data_dir / "qp_problems.npy"
109 |
110 | if qp_problems_file.exists():
111 | print(f"Loading QP problems from {qp_problems_file}")
112 | qp_problems = QPGenerator.load(qp_problems_file)
113 | print(f"Loaded {len(qp_problems)} QP problems")
114 | else:
115 | print(f"Generating {args.num_samples} QP problems")
116 | generator = QPGenerator(
117 | state_dim=args.state_dim,
118 | input_dim=args.input_dim,
119 | horizon=args.horizon,
120 | num_samples=args.num_samples
121 | )
122 | qp_problems = generator.generate()
123 |
124 | # Save problems for future use
125 | generator.save(qp_problems, qp_problems_file)
126 | print(f"Saved QP problems to {qp_problems_file}")
127 |
128 | # 2. Create dataset and split into train/test
129 | print("Creating dataset")
130 | dataset = QPDataset(
131 | qp_problems=qp_problems,
132 | precompute_solutions=True,
133 | feature_normalization=True,
134 | cache_dir=data_dir
135 | )
136 |
137 | train_dataset, test_dataset = dataset.split(test_size=args.test_size)
138 | print(f"Created datasets - Train: {len(train_dataset)}, Test: {len(test_dataset)}")
139 |
140 | # 3. Train or load constraint predictor
141 | cp_model_file = models_dir / "constraint_predictor.pt"
142 |
143 | if args.skip_training and cp_model_file.exists():
144 | print(f"Loading constraint predictor from {cp_model_file}")
145 | cp_model = ConstraintPredictor.load(cp_model_file)
146 | else:
147 | print("Training constraint predictor")
148 | # Get input dimension from the dataset
149 | sample_item = train_dataset[0]
150 | input_dim = sample_item['features'].shape[0]
151 | num_constraints = sample_item['active_constraints'].shape[0]
152 |
153 | cp_model = ConstraintPredictor(
154 | input_dim=input_dim,
155 | hidden_dim=128,
156 | num_constraints=num_constraints
157 | )
158 |
159 | cp_trainer = ModelTrainer(
160 | model=cp_model,
161 | train_dataset=train_dataset,
162 | val_dataset=test_dataset,
163 | batch_size=args.batch_size,
164 | learning_rate=1e-4,
165 | num_epochs=args.num_epochs,
166 | checkpoint_dir=models_dir,
167 | device=device
168 | )
169 |
170 | cp_history = cp_trainer.train(log_dir=logs_dir / "constraint_predictor")
171 |
172 | # Save model
173 | cp_model.save(cp_model_file)
174 | print(f"Saved constraint predictor to {cp_model_file}")
175 |
176 | # 4. Train or load warm start predictor
177 | ws_model_file = models_dir / "warm_start_predictor.pt"
178 |
179 | if args.skip_training and ws_model_file.exists():
180 | print(f"Loading warm start predictor from {ws_model_file}")
181 | ws_model = WarmStartPredictor.load(ws_model_file)
182 | else:
183 | print("Training warm start predictor")
184 | # Get input dimension from the dataset
185 | sample_item = train_dataset[0]
186 | input_dim = sample_item['features'].shape[0]
187 | output_dim = sample_item['solution'].shape[0]
188 |
189 | ws_model = WarmStartPredictor(
190 | input_dim=input_dim,
191 | hidden_dim=256,
192 | output_dim=output_dim
193 | )
194 |
195 | ws_trainer = ModelTrainer(
196 | model=ws_model,
197 | train_dataset=train_dataset,
198 | val_dataset=test_dataset,
199 | batch_size=args.batch_size,
200 | learning_rate=1e-4,
201 | num_epochs=args.num_epochs,
202 | checkpoint_dir=models_dir,
203 | device=device
204 | )
205 |
206 | ws_history = ws_trainer.train(log_dir=logs_dir / "warm_start_predictor")
207 |
208 | # Save model
209 | ws_model.save(ws_model_file)
210 | print(f"Saved warm start predictor to {ws_model_file}")
211 |
212 | # 5. Benchmark against baseline
213 | print("Benchmarking against baseline")
214 | solver = OSQPSolver()
215 |
216 | # List to store results
217 | baseline_times = []
218 | warmstart_only_times = []
219 | constraint_only_times = []
220 | transformer_times = []
221 | fallback_flags = []
222 |
223 |     # Test on a random subset (up to 100 problems) for visualization
224 |     subset_size = min(100, len(test_dataset))
225 |     test_subset = np.random.choice(len(test_dataset), size=subset_size, replace=False)
226 |     print(f"Testing on {subset_size} problems from the test set")
227 | for idx in tqdm(test_subset):
228 | # Get problem
229 | sample = test_dataset[idx]
230 | problem = test_dataset.get_problem(idx)
231 |
232 | # Get features
233 | features = sample['features']
234 |
235 | # Get ground truth
236 | true_active = sample['active_constraints'].numpy()
237 | true_solution = sample['solution'].numpy()
238 |
239 | # Predict active constraints
240 | pred_active = cp_model.predict(features)[0]
241 |
242 | # Predict warm start
243 | pred_solution = ws_model.predict(features)[0]
244 |
245 | # Baseline (OSQP without transformers)
246 | _, baseline_time = solver.solve_with_time(
247 | Q=problem.Q,
248 | c=problem.c,
249 | A=problem.A,
250 | b=problem.b
251 | )
252 | baseline_times.append(baseline_time)
253 |
254 | # Warm start only
255 | _, warmstart_time = solver.solve_with_time(
256 | Q=problem.Q,
257 | c=problem.c,
258 | A=problem.A,
259 | b=problem.b,
260 | warm_start=pred_solution
261 | )
262 | warmstart_only_times.append(warmstart_time)
263 |
264 | # Constraint only
265 | _, is_feasible, constraint_time = solver.solve_reduced_with_time(
266 | Q=problem.Q,
267 | c=problem.c,
268 | A=problem.A,
269 | b=problem.b,
270 | active_constraints=pred_active
271 | )
272 | constraint_only_times.append(constraint_time)
273 |
274 | # Full transformer pipeline
275 | _, transformer_time, used_fallback = solver.solve_pipeline(
276 | Q=problem.Q,
277 | c=problem.c,
278 | A=problem.A,
279 | b=problem.b,
280 | active_constraints=pred_active,
281 | warm_start=pred_solution,
282 | fallback_on_violation=True
283 | )
284 | transformer_times.append(transformer_time)
285 | fallback_flags.append(used_fallback)
286 |
287 | # Convert to numpy arrays
288 | baseline_times = np.array(baseline_times)
289 | warmstart_only_times = np.array(warmstart_only_times)
290 | constraint_only_times = np.array(constraint_only_times)
291 | transformer_times = np.array(transformer_times)
292 |
293 | # 6. Compute and print metrics
294 | solve_metrics = compute_solve_time_metrics(baseline_times, transformer_times)
295 | fallback_rate = compute_fallback_rate(fallback_flags)
296 |
297 | print("\nSolve Time Metrics:")
298 | print(f"Mean baseline time: {solve_metrics['mean_baseline_time']:.6f}s")
299 | print(f"Mean transformer time: {solve_metrics['mean_transformer_time']:.6f}s")
300 | print(f"Mean speedup: {solve_metrics['mean_speedup']:.2f}x")
301 | print(f"Median speedup: {solve_metrics['median_speedup']:.2f}x")
302 | print(f"Fallback rate: {fallback_rate:.2f}%")
303 |
304 | # 7. Plot results
305 | print("\nPlotting results")
306 |
307 | # Plot solve time comparison
308 | plot_solve_time_comparison(
309 | baseline_times=baseline_times,
310 | transformer_times=transformer_times,
311 | save_path=results_dir / "solve_time_comparison.png"
312 | )
313 |
314 | # Plot boxplot
315 | plot_solve_time_boxplot(
316 | baseline_times=baseline_times,
317 | transformer_times=transformer_times,
318 | constraint_only_times=constraint_only_times,
319 | warmstart_only_times=warmstart_only_times,
320 | save_path=results_dir / "solve_time_boxplot.png"
321 | )
322 |
323 | # Plot fallback statistics
324 | fallback_rates = {
325 | "Transformer-MPC": fallback_rate
326 | }
327 | plot_fallback_statistics(
328 | fallback_rates=fallback_rates,
329 | save_path=results_dir / "fallback_rates.png"
330 | )
331 |
332 | print(f"\nResults saved to {results_dir}")
333 |
334 | # Save main result plot for README
335 | plot_solve_time_comparison(
336 | baseline_times=baseline_times,
337 | transformer_times=transformer_times,
338 | save_path=output_dir.parent / "benchmarking_results.png"
339 | )
340 | print(f"Main benchmarking result saved to {output_dir.parent / 'benchmarking_results.png'}")
341 |
342 |
343 | if __name__ == "__main__":
344 | main()
345 |
--------------------------------------------------------------------------------
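
Because demo.py uses package-relative imports, it has to be run as a module (or through scripts/run_demo.py from the tree above). Assuming the package is installed, a quick low-budget invocation, with sizes cut well below the defaults purely to keep runtime short, might be:

    # Small smoke-test run (defaults are 20000 samples and 2000 epochs)
    python -m transformermpc.demo.demo \
        --num_samples 1000 \
        --num_epochs 50 \
        --batch_size 64 \
        --output_dir transformermpc_results \
        --cpu

    # Reuse checkpoints from a previous run instead of retraining
    python -m transformermpc.demo.demo --skip_training --output_dir transformermpc_results

All flags shown come from parse_args above; no other options are assumed.
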
/transformermpc/data/qp_generator.py:
--------------------------------------------------------------------------------
1 | """
2 | QP problem generator module.
3 |
4 | This module provides classes for generating and representing QP problems
5 | for training and testing the transformer models.
6 | """
7 |
8 | import numpy as np
9 | import pandas as pd
10 | import os
11 | import time
12 | from dataclasses import dataclass
13 | from typing import List, Tuple, Dict, Optional, Union
14 |
15 |
16 | @dataclass
17 | class QPProblem:
18 | """
19 | A class representing a Quadratic Programming problem.
20 |
21 | A QP problem is defined as:
22 | minimize 0.5 * x^T Q x + c^T x
23 | subject to A x <= b
24 |
25 | Attributes:
26 | -----------
27 | Q : numpy.ndarray
28 | Quadratic cost matrix (n x n)
29 | c : numpy.ndarray
30 | Linear cost vector (n)
31 | A : numpy.ndarray
32 | Constraint matrix (m x n)
33 | b : numpy.ndarray
34 | Constraint vector (m)
35 | initial_state : numpy.ndarray, optional
36 | Initial state for MPC problems
37 | reference : numpy.ndarray, optional
38 | Reference trajectory for MPC problems
39 | metadata : dict, optional
40 | Additional problem-specific information
41 | """
42 | Q: np.ndarray
43 | c: np.ndarray
44 | A: np.ndarray
45 | b: np.ndarray
46 | initial_state: Optional[np.ndarray] = None
47 | reference: Optional[np.ndarray] = None
48 | metadata: Optional[Dict] = None
49 |
50 | @property
51 | def n_variables(self) -> int:
52 | """Number of decision variables."""
53 | return self.Q.shape[0]
54 |
55 | @property
56 | def n_constraints(self) -> int:
57 | """Number of constraints."""
58 | return self.A.shape[0]
59 |
60 | def to_dict(self) -> Dict:
61 | """Convert to dictionary representation."""
62 | return {
63 | "Q": self.Q,
64 | "c": self.c,
65 | "A": self.A,
66 | "b": self.b,
67 | "initial_state": self.initial_state,
68 | "reference": self.reference,
69 | "metadata": self.metadata
70 | }
71 |
72 | @classmethod
73 | def from_dict(cls, data: Dict) -> 'QPProblem':
74 | """Create QPProblem from dictionary."""
75 | return cls(**data)
76 |
77 |
78 | class QPGenerator:
79 | """
80 | Generator class for creating QP problem instances.
81 |
82 | This class generates synthetic QP problems for training and testing
83 | the transformer models.
84 | """
85 |
86 | def __init__(self,
87 | state_dim: int = 4,
88 | input_dim: int = 2,
89 | horizon: int = 10,
90 | num_samples: int = 20000,
91 | seed: Optional[int] = None):
92 | """
93 | Initialize the QP generator.
94 |
95 | Parameters:
96 | -----------
97 | state_dim : int
98 | Dimension of state space for MPC problems
99 | input_dim : int
100 | Dimension of input space for MPC problems
101 | horizon : int
102 | Time horizon for MPC problems
103 | num_samples : int
104 | Number of QP problems to generate
105 | seed : int or None
106 | Random seed for reproducibility
107 | """
108 | self.state_dim = state_dim
109 | self.input_dim = input_dim
110 | self.horizon = horizon
111 | self.num_samples = num_samples
112 |
113 | # Set random seed if provided
114 | if seed is not None:
115 | np.random.seed(seed)
116 |
117 | def generate(self) -> List[QPProblem]:
118 | """
119 | Generate QP problems.
120 |
121 | Returns:
122 | --------
123 | problems : List[QPProblem]
124 | List of generated QP problems
125 | """
126 | problems = []
127 |
128 | for i in range(self.num_samples):
129 | # Generate random system matrices for an MPC problem
130 | A_dyn, B_dyn = self._generate_dynamics()
131 |
132 | # Generate cost matrices
133 | Q_state = self._generate_state_cost()
134 | R_input = self._generate_input_cost()
135 |
136 | # Generate constraints
137 | state_constraints = self._generate_state_constraints()
138 | input_constraints = self._generate_input_constraints()
139 |
140 | # Generate random initial state and reference
141 | initial_state = np.random.randn(self.state_dim)
142 | reference = np.random.randn(self.horizon * self.state_dim)
143 |
144 | # Create the QP matrices for the MPC problem
145 | Q, c, A, b = self._create_mpc_matrices(
146 | A_dyn, B_dyn, Q_state, R_input,
147 | state_constraints, input_constraints,
148 | initial_state, reference
149 | )
150 |
151 | # Create the QPProblem instance
152 | problem = QPProblem(
153 | Q=Q,
154 | c=c,
155 | A=A,
156 | b=b,
157 | initial_state=initial_state,
158 | reference=reference,
159 | metadata={
160 | "type": "mpc",
161 | "state_dim": self.state_dim,
162 | "input_dim": self.input_dim,
163 | "horizon": self.horizon,
164 | "A_dynamics": A_dyn,
165 | "B_dynamics": B_dyn
166 | }
167 | )
168 |
169 | problems.append(problem)
170 |
171 | return problems
172 |
173 | def _generate_dynamics(self) -> Tuple[np.ndarray, np.ndarray]:
174 | """
175 | Generate random discrete-time system dynamics matrices.
176 |
177 | Returns:
178 | --------
179 | A : numpy.ndarray
180 | State transition matrix (state_dim x state_dim)
181 | B : numpy.ndarray
182 | Input matrix (state_dim x input_dim)
183 | """
184 | # Generate a random discrete-time system
185 | A = np.random.randn(self.state_dim, self.state_dim)
186 | # Scale to make it stable
187 | eigenvalues, _ = np.linalg.eig(A)
188 | max_eig = np.max(np.abs(eigenvalues))
189 |         A = A / (max_eig * 1.1)  # Scale the spectral radius to ~0.91 (< 1) to ensure stability
190 |
191 | B = np.random.randn(self.state_dim, self.input_dim)
192 |
193 | return A, B
194 |
195 | def _generate_state_cost(self) -> np.ndarray:
196 | """
197 | Generate state cost matrix.
198 |
199 | Returns:
200 | --------
201 | Q : numpy.ndarray
202 | State cost matrix (state_dim x state_dim)
203 | """
204 | Q_diag = np.abs(np.random.randn(self.state_dim))
205 | Q = np.diag(Q_diag)
206 | return Q
207 |
208 | def _generate_input_cost(self) -> np.ndarray:
209 | """
210 | Generate input cost matrix.
211 |
212 | Returns:
213 | --------
214 | R : numpy.ndarray
215 | Input cost matrix (input_dim x input_dim)
216 | """
217 | R_diag = np.abs(np.random.randn(self.input_dim))
218 | R = np.diag(R_diag)
219 | return R
220 |
221 | def _generate_state_constraints(self) -> Tuple[np.ndarray, np.ndarray]:
222 | """
223 | Generate state constraints.
224 |
225 | Returns:
226 | --------
227 | A_state : numpy.ndarray
228 | State constraint matrix (2*state_dim x state_dim)
229 | b_state : numpy.ndarray
230 | State constraint vector (2*state_dim)
231 | """
232 | # Simple box constraints on states
233 | A_state = np.vstack([np.eye(self.state_dim), -np.eye(self.state_dim)])
234 |
235 | # Random upper and lower bounds
236 | upper_bounds = np.abs(np.random.rand(self.state_dim) * 5 + 5) # Random bounds between 5 and 10
237 | lower_bounds = np.abs(np.random.rand(self.state_dim) * 5 + 5) # Random bounds between 5 and 10
238 |
239 | b_state = np.concatenate([upper_bounds, lower_bounds])
240 |
241 | return A_state, b_state
242 |
243 | def _generate_input_constraints(self) -> Tuple[np.ndarray, np.ndarray]:
244 | """
245 | Generate input constraints.
246 |
247 | Returns:
248 | --------
249 | A_input : numpy.ndarray
250 | Input constraint matrix (2*input_dim x input_dim)
251 | b_input : numpy.ndarray
252 | Input constraint vector (2*input_dim)
253 | """
254 | # Simple box constraints on inputs
255 | A_input = np.vstack([np.eye(self.input_dim), -np.eye(self.input_dim)])
256 |
257 | # Random upper and lower bounds
258 | upper_bounds = np.abs(np.random.rand(self.input_dim) * 2 + 1) # Random bounds between 1 and 3
259 | lower_bounds = np.abs(np.random.rand(self.input_dim) * 2 + 1) # Random bounds between 1 and 3
260 |
261 | b_input = np.concatenate([upper_bounds, lower_bounds])
262 |
263 | return A_input, b_input
264 |
265 | def _create_mpc_matrices(self,
266 | A_dyn: np.ndarray,
267 | B_dyn: np.ndarray,
268 | Q_state: np.ndarray,
269 | R_input: np.ndarray,
270 | state_constraints: Tuple[np.ndarray, np.ndarray],
271 | input_constraints: Tuple[np.ndarray, np.ndarray],
272 | initial_state: np.ndarray,
273 | reference: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
274 | """
275 | Create QP matrices for an MPC problem.
276 |
277 | Parameters:
278 | -----------
279 | A_dyn : numpy.ndarray
280 | State transition matrix
281 | B_dyn : numpy.ndarray
282 | Input matrix
283 | Q_state : numpy.ndarray
284 | State cost matrix
285 | R_input : numpy.ndarray
286 | Input cost matrix
287 | state_constraints : Tuple[numpy.ndarray, numpy.ndarray]
288 | State constraint matrices (A_state, b_state)
289 | input_constraints : Tuple[numpy.ndarray, numpy.ndarray]
290 | Input constraint matrices (A_input, b_input)
291 | initial_state : numpy.ndarray
292 | Initial state
293 | reference : numpy.ndarray
294 | Reference trajectory
295 |
296 | Returns:
297 | --------
298 | Q : numpy.ndarray
299 | QP cost matrix
300 | c : numpy.ndarray
301 | QP linear cost vector
302 | A : numpy.ndarray
303 | QP constraint matrix
304 | b : numpy.ndarray
305 | QP constraint vector
306 | """
307 | # Extract constraint matrices
308 | A_state, b_state = state_constraints
309 | A_input, b_input = input_constraints
310 |
311 | # Problem dimensions
312 | nx = self.state_dim
313 | nu = self.input_dim
314 | N = self.horizon
315 |
316 | # Total number of decision variables
317 | n_vars = N * nu
318 |
319 | # Construct the cost matrices
320 | Q_block = np.zeros((n_vars, n_vars))
321 | for i in range(N):
322 | idx = i * nu
323 | Q_block[idx:idx+nu, idx:idx+nu] = R_input
324 |
325 |         # Nominal free state response under A_dyn (a simplification:
326 |         # inputs are not propagated, and x_pred is currently unused below)
327 |         x_pred = [initial_state]
328 |         for k in range(N):
329 |             x_next = A_dyn @ x_pred[-1]
330 |             x_pred.append(x_next)
331 | 
332 |         # Linear cost term (zero: the reference does not enter this simplified cost)
333 |         c = np.zeros(n_vars)
334 | 
335 |         # Only input box constraints are stacked below; the generated state constraints are unused here
336 |
337 | # Input constraints for each time step
338 | A_in_list = []
339 | b_in_list = []
340 |
341 | for i in range(N):
342 | # Input constraints: A_input * u <= b_input
343 | A_i = np.zeros((A_input.shape[0], n_vars))
344 | idx = i * nu
345 | A_i[:, idx:idx+nu] = A_input
346 | A_in_list.append(A_i)
347 | b_in_list.append(b_input)
348 |
349 | # Combine all constraints
350 | A = np.vstack(A_in_list) if A_in_list else np.zeros((0, n_vars))
351 | b = np.concatenate(b_in_list) if b_in_list else np.zeros(0)
352 |
353 | return Q_block, c, A, b
354 |
355 | def save(self, problems: List[QPProblem], filepath: str) -> None:
356 | """
357 | Save generated problems to a file.
358 |
359 | Parameters:
360 | -----------
361 | problems : List[QPProblem]
362 | List of QP problems to save
363 | filepath : str
364 | Path to save the problems
365 | """
366 | # Convert problems to dictionaries
367 | data = [problem.to_dict() for problem in problems]
368 |
369 | # Save using numpy
370 | np.save(filepath, data, allow_pickle=True)
371 |
372 | @staticmethod
373 | def load(filepath: str) -> List[QPProblem]:
374 | """
375 | Load problems from a file.
376 |
377 | Parameters:
378 | -----------
379 | filepath : str
380 | Path to load the problems from
381 |
382 | Returns:
383 | --------
384 | problems : List[QPProblem]
385 | List of loaded QP problems
386 | """
387 | # Load data from file
388 | data = np.load(filepath, allow_pickle=True)
389 |
390 | # Convert dictionaries to QPProblem instances
391 | problems = [QPProblem.from_dict(item) for item in data]
392 |
393 | return problems
394 |
--------------------------------------------------------------------------------
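
A minimal usage sketch for the generator above, exercising only the API defined in this file (the small sizes are illustrative, not the training defaults):

    from transformermpc.data.qp_generator import QPGenerator

    generator = QPGenerator(state_dim=2, input_dim=1, horizon=5,
                            num_samples=10, seed=42)
    problems = generator.generate()

    p = problems[0]
    print(p.n_variables)       # horizon * input_dim = 5 decision variables
    print(p.n_constraints)     # 2 * input_dim * horizon = 10 input box constraints
    print(p.metadata["type"])  # "mpc"

    # Round-trip through the save/load helpers
    generator.save(problems, "qp_problems.npy")
    reloaded = QPGenerator.load("qp_problems.npy")
    assert len(reloaded) == len(problems)
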
/transformermpc/models/constraint_predictor.py:
--------------------------------------------------------------------------------
1 | """
2 | Constraint predictor model.
3 |
4 | This module defines the transformer-based model for predicting
5 | which constraints are active in QP problems.
6 | """
7 |
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 | import numpy as np
12 | import os
13 | from typing import Dict, Optional, Tuple, Union, List, Any
14 |
15 |
16 | class TransformerEncoder(nn.Module):
17 | """Vanilla Transformer Encoder implementation."""
18 |
19 | def __init__(self,
20 | hidden_dim: int = 128,
21 | num_layers: int = 4,
22 | num_heads: int = 8,
23 | dropout: float = 0.1):
24 | """
25 | Initialize the transformer encoder.
26 |
27 | Parameters:
28 | -----------
29 | hidden_dim : int
30 | Dimension of hidden layers
31 | num_layers : int
32 | Number of transformer layers
33 | num_heads : int
34 | Number of attention heads
35 | dropout : float
36 | Dropout probability
37 | """
38 | super().__init__()
39 |
40 | # Make sure hidden_dim is divisible by num_heads
41 | assert hidden_dim % num_heads == 0, "hidden_dim must be divisible by num_heads"
42 |
43 | # Create encoder layer
44 | encoder_layer = nn.TransformerEncoderLayer(
45 | d_model=hidden_dim,
46 | nhead=num_heads,
47 | dim_feedforward=4 * hidden_dim,
48 | dropout=dropout,
49 | activation="relu",
50 | batch_first=True
51 | )
52 |
53 | # Create encoder
54 | self.encoder = nn.TransformerEncoder(
55 | encoder_layer=encoder_layer,
56 | num_layers=num_layers
57 | )
58 |
59 | # Positional encoding
60 | self.pos_encoding = PositionalEncoding(
61 | d_model=hidden_dim,
62 | dropout=dropout,
63 | max_len=100
64 | )
65 |
66 | def forward(self, x: torch.Tensor) -> torch.Tensor:
67 | """
68 | Forward pass.
69 |
70 | Parameters:
71 | -----------
72 | x : torch.Tensor
73 | Input tensor of shape (batch_size, seq_len, hidden_dim)
74 |
75 | Returns:
76 | --------
77 | output : torch.Tensor
78 | Output tensor of shape (batch_size, seq_len, hidden_dim)
79 | """
80 | # Add positional encoding
81 | x = self.pos_encoding(x)
82 |
83 | # Pass through encoder
84 | output = self.encoder(x)
85 |
86 | return output
87 |
88 |
89 | class PositionalEncoding(nn.Module):
90 | """Positional encoding for Transformer models."""
91 |
92 | def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 100):
93 | """
94 | Initialize the positional encoding.
95 |
96 | Parameters:
97 | -----------
98 | d_model : int
99 | Dimension of the model
100 | dropout : float
101 | Dropout probability
102 | max_len : int
103 | Maximum sequence length
104 | """
105 | super().__init__()
106 | self.dropout = nn.Dropout(p=dropout)
107 |
108 | # Create positional encoding
109 | position = torch.arange(max_len).unsqueeze(1)
110 | div_term = torch.exp(torch.arange(0, d_model, 2) * (-np.log(10000.0) / d_model))
111 | pe = torch.zeros(1, max_len, d_model)
112 | pe[0, :, 0::2] = torch.sin(position * div_term)
113 | pe[0, :, 1::2] = torch.cos(position * div_term)
114 | self.register_buffer('pe', pe)
115 |
116 | def forward(self, x: torch.Tensor) -> torch.Tensor:
117 | """
118 | Forward pass.
119 |
120 | Parameters:
121 | -----------
122 | x : torch.Tensor
123 | Input tensor of shape (batch_size, seq_len, d_model)
124 |
125 | Returns:
126 | --------
127 | output : torch.Tensor
128 | Output tensor with positional encoding added
129 | """
130 | x = x + self.pe[:, :x.size(1), :]
131 | return self.dropout(x)
132 |
133 |
134 | class ConstraintPredictor(nn.Module):
135 | """
136 | Transformer-based model for predicting active constraints.
137 |
138 | This model takes QP problem features as input and outputs a binary vector
139 | indicating which constraints are active.
140 | """
141 |
142 | def __init__(self,
143 | input_dim: int = 50,
144 | hidden_dim: int = 128,
145 | num_constraints: int = 100,
146 | num_layers: int = 4,
147 | num_heads: int = 8,
148 | dropout: float = 0.1):
149 | """
150 | Initialize the constraint predictor model.
151 |
152 | Parameters:
153 | -----------
154 | input_dim : int
155 | Dimension of input features
156 | hidden_dim : int
157 | Dimension of hidden layers
158 | num_constraints : int
159 | Maximum number of constraints to predict
160 | num_layers : int
161 | Number of transformer layers
162 | num_heads : int
163 | Number of attention heads
164 | dropout : float
165 | Dropout probability
166 | """
167 | super().__init__()
168 |
169 | self.input_dim = input_dim
170 | self.hidden_dim = hidden_dim
171 | self.num_constraints = num_constraints
172 | self.num_layers = num_layers
173 | self.num_heads = num_heads
174 | self.dropout = dropout
175 |
176 | # Input projection
177 | self.input_projection = nn.Linear(input_dim, hidden_dim)
178 |
179 | # Transformer encoder
180 | self.transformer = TransformerEncoder(
181 | hidden_dim=hidden_dim,
182 | num_layers=num_layers,
183 | num_heads=num_heads,
184 | dropout=dropout
185 | )
186 |
187 | # Output projection
188 | self.output_projection = nn.Linear(hidden_dim, num_constraints)
189 |
190 | def forward(self, x: torch.Tensor) -> torch.Tensor:
191 | """
192 | Forward pass.
193 |
194 | Parameters:
195 | -----------
196 | x : torch.Tensor
197 | Input tensor of shape (batch_size, input_dim)
198 |
199 | Returns:
200 | --------
201 | output : torch.Tensor
202 | Output tensor of shape (batch_size, num_constraints)
203 | containing probabilities of constraints being active
204 | """
205 | # Input projection
206 | x = self.input_projection(x)
207 |
208 | # Reshape for transformer: [batch_size, seq_len, hidden_dim]
209 | # Here we use a sequence length of 1
210 | x = x.unsqueeze(1)
211 |
212 | # Pass through transformer
213 | transformer_output = self.transformer(x)
214 |
215 | # Extract first token output
216 | first_token = transformer_output[:, 0, :]
217 |
218 | # Output projection
219 | logits = self.output_projection(first_token)
220 |
221 | # Apply sigmoid to get probabilities
222 | probs = torch.sigmoid(logits)
223 |
224 | return probs
225 |
226 | def predict(self, x: Union[torch.Tensor, np.ndarray]) -> np.ndarray:
227 | """
228 | Make prediction on input data.
229 |
230 | Parameters:
231 | -----------
232 | x : torch.Tensor or numpy.ndarray
233 | Input features
234 |
235 | Returns:
236 | --------
237 | prediction : numpy.ndarray
238 | Binary prediction of active constraints
239 | """
240 | # Convert to tensor if numpy array
241 | if isinstance(x, np.ndarray):
242 | x = torch.tensor(x, dtype=torch.float32)
243 |
244 | # Make sure the input has the right shape
245 | if x.dim() == 1:
246 | x = x.unsqueeze(0) # Add batch dimension
247 |
248 | # Set model to evaluation mode
249 | self.eval()
250 |
251 | # Disable gradient computation
252 | with torch.no_grad():
253 | # Forward pass
254 | probs = self(x)
255 |
256 | # Convert to binary predictions
257 | binary_pred = (probs > 0.5).float()
258 |
259 | # Convert to numpy
260 | binary_pred = binary_pred.cpu().numpy()
261 |
262 | return binary_pred
263 |
264 | def save(self, filepath: str) -> None:
265 | """
266 | Save model to file.
267 |
268 | Parameters:
269 | -----------
270 | filepath : str
271 | Path to save the model
272 | """
273 |         # Create the parent directory if needed ('or "."' guards bare filenames)
274 |         os.makedirs(os.path.dirname(filepath) or ".", exist_ok=True)
275 |
276 | # Save model state
277 | torch.save({
278 | 'state_dict': self.state_dict(),
279 | 'input_dim': self.input_dim,
280 | 'hidden_dim': self.hidden_dim,
281 | 'num_constraints': self.num_constraints,
282 | 'num_layers': self.num_layers,
283 | 'num_heads': self.num_heads,
284 | 'dropout': self.dropout
285 | }, filepath)
286 |
287 | @classmethod
288 | def load(cls, filepath: Optional[str] = None) -> 'ConstraintPredictor':
289 | """
290 | Load model from file.
291 |
292 | Parameters:
293 | -----------
294 | filepath : str or None
295 | Path to load the model from, or None to create a new model
296 |
297 | Returns:
298 | --------
299 | model : ConstraintPredictor
300 | Loaded model
301 | """
302 | if filepath is None or not os.path.exists(filepath):
303 | # Return a new model with default parameters
304 | return cls()
305 |
306 | # Load model state
307 | checkpoint = torch.load(filepath, map_location=torch.device('cpu'))
308 |
309 | # Create model with saved parameters
310 | model = cls(
311 | input_dim=checkpoint['input_dim'],
312 | hidden_dim=checkpoint['hidden_dim'],
313 | num_constraints=checkpoint['num_constraints'],
314 | num_layers=checkpoint.get('num_layers', 4),
315 | num_heads=checkpoint.get('num_heads', 8),
316 | dropout=checkpoint.get('dropout', 0.1)
317 | )
318 |
319 | # Load state dictionary
320 | model.load_state_dict(checkpoint['state_dict'])
321 |
322 | return model
323 |
324 | def train_step(self,
325 | x: torch.Tensor,
326 | y: torch.Tensor,
327 | optimizer: torch.optim.Optimizer) -> Dict[str, float]:
328 | """
329 | Perform a single training step.
330 |
331 | Parameters:
332 | -----------
333 | x : torch.Tensor
334 | Input tensor of shape (batch_size, input_dim)
335 | y : torch.Tensor
336 | Target tensor of shape (batch_size, num_constraints)
337 | optimizer : torch.optim.Optimizer
338 | Optimizer to use for the step
339 |
340 | Returns:
341 | --------
342 | metrics : Dict[str, float]
343 | Dictionary containing loss and accuracy
344 | """
345 | # Set model to training mode
346 | self.train()
347 |
348 | # Zero the gradients
349 | optimizer.zero_grad()
350 |
351 | # Forward pass
352 | probs = self(x)
353 |
354 | # Compute loss (binary cross entropy)
355 | loss = F.binary_cross_entropy(probs, y)
356 |
357 | # Backward pass
358 | loss.backward()
359 |
360 | # Update parameters
361 | optimizer.step()
362 |
363 | # Compute accuracy
364 | binary_pred = (probs > 0.5).float()
365 | accuracy = (binary_pred == y).float().mean().item()
366 |
367 | return {
368 | 'loss': loss.item(),
369 | 'accuracy': accuracy
370 | }
371 |
372 | def validate(self,
373 | x: torch.Tensor,
374 | y: torch.Tensor) -> Dict[str, float]:
375 | """
376 | Validate the model on validation data.
377 |
378 | Parameters:
379 | -----------
380 | x : torch.Tensor
381 | Input tensor of shape (batch_size, input_dim)
382 | y : torch.Tensor
383 | Target tensor of shape (batch_size, num_constraints)
384 |
385 | Returns:
386 | --------
387 | metrics : Dict[str, float]
388 | Dictionary containing loss and accuracy
389 | """
390 | # Set model to evaluation mode
391 | self.eval()
392 |
393 | # Disable gradient computation
394 | with torch.no_grad():
395 | # Forward pass
396 | probs = self(x)
397 |
398 | # Compute loss (binary cross entropy)
399 | loss = F.binary_cross_entropy(probs, y)
400 |
401 | # Compute accuracy
402 | binary_pred = (probs > 0.5).float()
403 | accuracy = (binary_pred == y).float().mean().item()
404 |
405 | # Compute precision, recall, and F1 score
406 | binary_pred_flat = binary_pred.flatten().cpu().numpy()
407 | y_flat = y.flatten().cpu().numpy()
408 |
409 | # Compute TP, FP, TN, FN
410 | tp = ((binary_pred_flat == 1) & (y_flat == 1)).sum()
411 | fp = ((binary_pred_flat == 1) & (y_flat == 0)).sum()
412 | fn = ((binary_pred_flat == 0) & (y_flat == 1)).sum()
413 |
414 | # Compute precision, recall, and F1
415 | precision = tp / (tp + fp) if tp + fp > 0 else 0.0
416 | recall = tp / (tp + fn) if tp + fn > 0 else 0.0
417 | f1 = 2 * precision * recall / (precision + recall) if precision + recall > 0 else 0.0
418 |
419 | return {
420 | 'loss': loss.item(),
421 | 'accuracy': accuracy,
422 | 'precision': precision,
423 | 'recall': recall,
424 | 'f1': f1
425 | }
426 |
--------------------------------------------------------------------------------
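
A shape-level sketch of the predictor's API with random inputs; the binary output is meaningless until the model is trained, and the dimensions simply mirror the class defaults:

    import numpy as np
    import torch
    from transformermpc.models.constraint_predictor import ConstraintPredictor

    model = ConstraintPredictor(input_dim=50, hidden_dim=128, num_constraints=100)

    # Inference: a 1-D feature vector is batched automatically by predict()
    features = np.random.randn(50).astype(np.float32)
    active = model.predict(features)[0]      # shape (100,), entries in {0.0, 1.0}
    print(int(active.sum()), "constraints predicted active")

    # One optimizer step on a random batch (targets are random here too)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    x = torch.randn(8, 50)
    y = (torch.rand(8, 100) > 0.8).float()
    print(model.train_step(x, y, optimizer))  # {'loss': ..., 'accuracy': ...}
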
/transformermpc/utils/visualization.py:
--------------------------------------------------------------------------------
1 | """
2 | Visualization module for TransformerMPC.
3 |
4 | This module provides functions for visualizing training progress,
5 | evaluation metrics, and benchmarking results.
6 | """
7 |
8 | import matplotlib.pyplot as plt
9 | import numpy as np
10 | from typing import List, Dict, Tuple, Optional, Union, Any
11 | import os
12 |
13 |
14 | def plot_training_history(history: Dict[str, List[float]],
15 | save_path: Optional[str] = None,
16 | show: bool = True) -> None:
17 | """
18 | Plot training history metrics.
19 |
20 | Parameters:
21 | -----------
22 | history : Dict[str, List[float]]
23 | Dictionary containing training metrics
24 | save_path : str or None
25 | Path to save the plot
26 | show : bool
27 | Whether to display the plot
28 | """
29 | fig, axes = plt.subplots(nrows=len(history), figsize=(10, 3*len(history)))
30 |
31 | # Handle case with only one metric
32 | if len(history) == 1:
33 | axes = [axes]
34 |
35 | for i, (metric_name, values) in enumerate(history.items()):
36 | axes[i].plot(values)
37 | axes[i].set_title(metric_name)
38 | axes[i].set_xlabel('Epoch')
39 | axes[i].set_ylabel(metric_name)
40 | axes[i].grid(True)
41 |
42 | plt.tight_layout()
43 |
44 | if save_path is not None:
45 | plt.savefig(save_path)
46 |
47 | if show:
48 | plt.show()
49 | else:
50 | plt.close()
51 |
52 |
53 | def plot_constraint_prediction_metrics(precision: List[float],
54 | recall: List[float],
55 | f1: List[float],
56 | accuracy: List[float],
57 | epochs: List[int],
58 | save_path: Optional[str] = None,
59 | show: bool = True) -> None:
60 | """
61 | Plot constraint prediction model metrics.
62 |
63 | Parameters:
64 | -----------
65 | precision, recall, f1, accuracy : List[float]
66 | Lists of metric values
67 | epochs : List[int]
68 | List of epoch numbers
69 | save_path : str or None
70 | Path to save the plot
71 | show : bool
72 | Whether to display the plot
73 | """
74 | fig, ax = plt.subplots(figsize=(10, 6))
75 |
76 | ax.plot(epochs, precision, 'b-', label='Precision')
77 | ax.plot(epochs, recall, 'g-', label='Recall')
78 | ax.plot(epochs, f1, 'r-', label='F1 Score')
79 | ax.plot(epochs, accuracy, 'k-', label='Accuracy')
80 |
81 | ax.set_xlabel('Epoch')
82 | ax.set_ylabel('Metric Value')
83 | ax.set_title('Constraint Prediction Performance Metrics')
84 | ax.legend()
85 | ax.grid(True)
86 |
87 | plt.tight_layout()
88 |
89 | if save_path is not None:
90 | plt.savefig(save_path)
91 |
92 | if show:
93 | plt.show()
94 | else:
95 | plt.close()
96 |
97 |
98 | def plot_warm_start_metrics(mse: List[float],
99 | mae: List[float],
100 | rel_error: List[float],
101 | epochs: List[int],
102 | save_path: Optional[str] = None,
103 | show: bool = True) -> None:
104 | """
105 | Plot warm start prediction model metrics.
106 |
107 | Parameters:
108 | -----------
109 | mse, mae, rel_error : List[float]
110 | Lists of metric values
111 | epochs : List[int]
112 | List of epoch numbers
113 | save_path : str or None
114 | Path to save the plot
115 | show : bool
116 | Whether to display the plot
117 | """
118 | fig, ax = plt.subplots(figsize=(10, 6))
119 |
120 | ax.plot(epochs, mse, 'b-', label='MSE')
121 | ax.plot(epochs, mae, 'g-', label='MAE')
122 | ax.plot(epochs, rel_error, 'r-', label='Relative Error')
123 |
124 | ax.set_xlabel('Epoch')
125 | ax.set_ylabel('Error')
126 | ax.set_title('Warm Start Prediction Performance Metrics')
127 | ax.legend()
128 | ax.grid(True)
129 |
130 | plt.tight_layout()
131 |
132 | if save_path is not None:
133 | plt.savefig(save_path)
134 |
135 | if show:
136 | plt.show()
137 | else:
138 | plt.close()
139 |
140 |
141 | def plot_solve_time_comparison(baseline_times: np.ndarray,
142 | transformer_times: np.ndarray,
143 | problem_sizes: Optional[np.ndarray] = None,
144 | save_path: Optional[str] = None,
145 | show: bool = True,
146 | log_scale: bool = False) -> None:
147 | """
148 | Plot comparison of solve times between baseline and transformer-enhanced approach.
149 |
150 | Parameters:
151 | -----------
152 | baseline_times : numpy.ndarray
153 | Array of solve times for baseline approach
154 | transformer_times : numpy.ndarray
155 | Array of solve times for transformer-enhanced approach
156 | problem_sizes : numpy.ndarray or None
157 | Array of problem sizes (e.g., number of variables or constraints)
158 | save_path : str or None
159 | Path to save the plot
160 | show : bool
161 | Whether to display the plot
162 | log_scale : bool
163 | Whether to use log scale for the y-axis
164 | """
165 | fig, ax = plt.subplots(figsize=(10, 6))
166 |
167 | if problem_sizes is not None:
168 | # Plot against problem sizes
169 | ax.plot(problem_sizes, baseline_times, 'bo-', label='OSQP Baseline')
170 | ax.plot(problem_sizes, transformer_times, 'ro-', label='TransformerMPC')
171 | ax.set_xlabel('Problem Size')
172 | else:
173 | # Plot histograms
174 | ax.hist(baseline_times, bins=30, alpha=0.5, label='OSQP Baseline')
175 | ax.hist(transformer_times, bins=30, alpha=0.5, label='TransformerMPC')
176 | ax.set_xlabel('Solve Time (seconds)')
177 |
178 | if log_scale:
179 | ax.set_yscale('log')
180 |
181 | ax.set_ylabel('Solve Time (seconds)' if problem_sizes is not None else 'Frequency')
182 | ax.set_title('Comparison of Solve Times: OSQP vs TransformerMPC')
183 | ax.legend()
184 | ax.grid(True)
185 |
186 | # Add speedup statistics
187 | mean_speedup = np.mean(baseline_times / transformer_times)
188 | median_speedup = np.median(baseline_times / transformer_times)
189 | max_speedup = np.max(baseline_times / transformer_times)
190 |
191 | stats_text = f"Mean Speedup: {mean_speedup:.2f}x\n"
192 | stats_text += f"Median Speedup: {median_speedup:.2f}x\n"
193 | stats_text += f"Max Speedup: {max_speedup:.2f}x"
194 |
195 | plt.figtext(0.02, 0.02, stats_text, fontsize=10, bbox=dict(facecolor='white', alpha=0.8))
196 |
197 | plt.tight_layout()
198 |
199 | if save_path is not None:
200 | plt.savefig(save_path)
201 |
202 | if show:
203 | plt.show()
204 | else:
205 | plt.close()
206 |
207 |
208 | def plot_solve_time_boxplot(baseline_times: np.ndarray,
209 | transformer_times: np.ndarray,
210 | constraint_only_times: Optional[np.ndarray] = None,
211 | warmstart_only_times: Optional[np.ndarray] = None,
212 | save_path: Optional[str] = None,
213 | show: bool = True) -> None:
214 | """
215 | Create boxplot comparison of solve times for different methods.
216 |
217 | Parameters:
218 | -----------
219 | baseline_times : numpy.ndarray
220 | Array of solve times for baseline approach
221 | transformer_times : numpy.ndarray
222 | Array of solve times for full transformer-enhanced approach
223 | constraint_only_times : numpy.ndarray or None
224 | Array of solve times using only constraint prediction
225 | warmstart_only_times : numpy.ndarray or None
226 | Array of solve times using only warm start prediction
227 | save_path : str or None
228 | Path to save the plot
229 | show : bool
230 | Whether to display the plot
231 | """
232 | fig, ax = plt.subplots(figsize=(10, 6))
233 |
234 | data = []
235 | labels = []
236 |
237 | # Always include baseline and full transformer
238 | data.append(baseline_times)
239 | labels.append('OSQP Baseline')
240 |
241 | # Include constraint-only if provided
242 | if constraint_only_times is not None:
243 | data.append(constraint_only_times)
244 | labels.append('Constraint Prediction Only')
245 |
246 | # Include warmstart-only if provided
247 | if warmstart_only_times is not None:
248 | data.append(warmstart_only_times)
249 | labels.append('Warm Start Only')
250 |
251 | # Always include full transformer
252 | data.append(transformer_times)
253 | labels.append('Full TransformerMPC')
254 |
255 | # Create the boxplot
256 | ax.boxplot(data, labels=labels, showfliers=True)
257 |
258 | ax.set_ylabel('Solve Time (seconds)')
259 | ax.set_title('Comparison of Solve Times Across Methods')
260 | ax.grid(True, axis='y')
261 |
262 | plt.tight_layout()
263 |
264 | if save_path is not None:
265 | plt.savefig(save_path)
266 |
267 | if show:
268 | plt.show()
269 | else:
270 | plt.close()
271 |
272 |
273 | def plot_active_constraints_histogram(active_percentages: np.ndarray,
274 | save_path: Optional[str] = None,
275 | show: bool = True) -> None:
276 | """
277 | Plot histogram of percentage of active constraints.
278 |
279 | Parameters:
280 | -----------
281 | active_percentages : numpy.ndarray
282 | Array of percentages of active constraints for each problem
283 | save_path : str or None
284 | Path to save the plot
285 | show : bool
286 | Whether to display the plot
287 | """
288 | fig, ax = plt.subplots(figsize=(10, 6))
289 |
290 | ax.hist(active_percentages, bins=50, alpha=0.75)
291 | ax.axvline(np.mean(active_percentages), color='r', linestyle='--',
292 | label=f'Mean: {np.mean(active_percentages):.2f}%')
293 | ax.axvline(np.median(active_percentages), color='g', linestyle='--',
294 | label=f'Median: {np.median(active_percentages):.2f}%')
295 |
296 | ax.set_xlabel('Percentage of Active Constraints')
297 | ax.set_ylabel('Frequency')
298 | ax.set_title('Distribution of Active Constraints Percentage')
299 | ax.legend()
300 | ax.grid(True)
301 |
302 | plt.tight_layout()
303 |
304 | if save_path is not None:
305 | plt.savefig(save_path)
306 |
307 | if show:
308 | plt.show()
309 | else:
310 | plt.close()
311 |
312 |
313 | def plot_fallback_statistics(fallback_rates: Dict[str, float],
314 | save_path: Optional[str] = None,
315 | show: bool = True) -> None:
316 | """
317 | Plot statistics about fallback to full QP solve.
318 |
319 | Parameters:
320 | -----------
321 | fallback_rates : Dict[str, float]
322 | Dictionary of fallback rates for different methods
323 | save_path : str or None
324 | Path to save the plot
325 | show : bool
326 | Whether to display the plot
327 | """
328 | fig, ax = plt.subplots(figsize=(10, 6))
329 |
330 | methods = list(fallback_rates.keys())
331 | rates = list(fallback_rates.values())
332 |
333 | # Create bar plot
334 | ax.bar(methods, rates, color='skyblue')
335 |
336 | ax.set_xlabel('Method')
337 | ax.set_ylabel('Fallback Rate (%)')
338 | ax.set_title('Fallback Rates for Different Methods')
339 | ax.grid(True, axis='y')
340 |
341 | # Add value labels on bars
342 | for i, v in enumerate(rates):
343 | ax.text(i, v + 1, f"{v:.1f}%", ha='center')
344 |
345 | plt.tight_layout()
346 |
347 | if save_path is not None:
348 | plt.savefig(save_path)
349 |
350 | if show:
351 | plt.show()
352 | else:
353 | plt.close()
354 |
355 |
356 | def plot_benchmarking_results(results: Dict[str, Dict[str, Union[float, np.ndarray]]],
357 | save_path: Optional[str] = None,
358 | show: bool = True) -> None:
359 | """
360 | Plot comprehensive benchmarking results.
361 |
362 | Parameters:
363 | -----------
364 | results : Dict
365 | Dictionary containing benchmarking results
366 | save_path : str or None
367 | Path to save the plot
368 | show : bool
369 | Whether to display the plot
370 | """
371 | fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(15, 12))
372 |
373 | # Plot 1: Solve time comparison (boxplot)
374 | methods = list(results.keys())
375 | solve_times = [results[method]['solve_times'] for method in methods]
376 |
377 | axes[0, 0].boxplot(solve_times, labels=methods)
378 | axes[0, 0].set_ylabel('Solve Time (seconds)')
379 | axes[0, 0].set_title('Solve Times by Method')
380 | axes[0, 0].grid(True, axis='y')
381 |
382 | # Plot 2: Speedup factors
383 | baseline_method = methods[0] # Assume first method is baseline
384 | speedups = {}
385 |
386 | for method in methods[1:]:
387 | speedup = results[baseline_method]['mean_solve_time'] / results[method]['mean_solve_time']
388 | speedups[method] = speedup
389 |
390 | speedup_methods = list(speedups.keys())
391 | speedup_values = list(speedups.values())
392 |
393 | axes[0, 1].bar(speedup_methods, speedup_values, color='green')
394 | axes[0, 1].set_ylabel('Speedup Factor (x)')
395 | axes[0, 1].set_title(f'Speedup Relative to {baseline_method}')
396 | axes[0, 1].grid(True, axis='y')
397 |
398 | # Add value labels on bars
399 | for i, v in enumerate(speedup_values):
400 | axes[0, 1].text(i, v + 0.1, f"{v:.2f}x", ha='center')
401 |
402 | # Plot 3: Fallback rates
403 | fallback_rates = {m: results[m].get('fallback_rate', 0) for m in methods[1:]}
404 | fallback_methods = list(fallback_rates.keys())
405 | fallback_values = list(fallback_rates.values())
406 |
407 | axes[1, 0].bar(fallback_methods, fallback_values, color='orange')
408 | axes[1, 0].set_ylabel('Fallback Rate (%)')
409 | axes[1, 0].set_title('Fallback Rates by Method')
410 | axes[1, 0].grid(True, axis='y')
411 |
412 | # Add value labels on bars
413 | for i, v in enumerate(fallback_values):
414 | axes[1, 0].text(i, v + 1, f"{v:.1f}%", ha='center')
415 |
416 | # Plot 4: Accuracy metrics for constraint predictor
417 | if 'constraint_predictor' in results and 'metrics' in results['constraint_predictor']:
418 | metrics = results['constraint_predictor']['metrics']
419 | metric_names = list(metrics.keys())
420 | metric_values = list(metrics.values())
421 |
422 | axes[1, 1].bar(metric_names, metric_values, color='purple')
423 | axes[1, 1].set_ylabel('Value')
424 | axes[1, 1].set_title('Constraint Predictor Metrics')
425 | axes[1, 1].grid(True, axis='y')
426 |
427 | # Add value labels on bars
428 | for i, v in enumerate(metric_values):
429 | axes[1, 1].text(i, v + 0.02, f"{v:.3f}", ha='center')
430 | else:
431 | axes[1, 1].set_visible(False)
432 |
433 | plt.tight_layout()
434 |
435 | if save_path is not None:
436 | plt.savefig(save_path)
437 |
438 | if show:
439 | plt.show()
440 | else:
441 | plt.close()
442 |
--------------------------------------------------------------------------------
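
The plotting helpers above are independent of the solver, so they can be exercised with synthetic timings. A sketch (all numbers below are made up, chosen only to produce a plausible-looking figure):

    import numpy as np
    from transformermpc.utils.visualization import (
        plot_solve_time_comparison,
        plot_fallback_statistics,
    )

    rng = np.random.default_rng(0)
    baseline = rng.uniform(0.010, 0.050, size=200)            # seconds
    transformer = baseline * rng.uniform(0.3, 0.9, size=200)  # synthetic speedup

    plot_solve_time_comparison(baseline, transformer,
                               save_path="solve_time_comparison.png", show=False)
    plot_fallback_statistics({"Transformer-MPC": 4.0},
                             save_path="fallback_rates.png", show=False)
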
/transformermpc/models/warm_start_predictor.py:
--------------------------------------------------------------------------------
1 | """
2 | Warm start predictor model.
3 |
4 | This module defines the transformer-based model for predicting
5 | warm start solutions for QP problems.
6 | """
7 |
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 | import numpy as np
12 | import os
13 | from typing import Dict, Optional, Tuple, Union, List, Any
14 |
15 |
16 | class TransformerEncoderDecoder(nn.Module):
17 | """Vanilla Transformer Encoder-Decoder implementation."""
18 |
19 | def __init__(self,
20 | hidden_dim: int = 256,
21 | output_dim: int = 20,
22 | num_layers: int = 6,
23 | num_heads: int = 8,
24 | dropout: float = 0.1):
25 | """
26 | Initialize the transformer encoder-decoder.
27 |
28 | Parameters:
29 | -----------
30 | hidden_dim : int
31 | Dimension of hidden layers
32 | output_dim : int
33 | Dimension of output
34 | num_layers : int
35 | Number of transformer layers
36 | num_heads : int
37 | Number of attention heads
38 | dropout : float
39 | Dropout probability
40 | """
41 | super().__init__()
42 |
43 | # Make sure hidden_dim is divisible by num_heads
44 | assert hidden_dim % num_heads == 0, "hidden_dim must be divisible by num_heads"
45 |
46 | # Create encoder layer
47 | encoder_layer = nn.TransformerEncoderLayer(
48 | d_model=hidden_dim,
49 | nhead=num_heads,
50 | dim_feedforward=4 * hidden_dim,
51 | dropout=dropout,
52 | activation="relu",
53 | batch_first=True
54 | )
55 |
56 | # Create encoder
57 | self.encoder = nn.TransformerEncoder(
58 | encoder_layer=encoder_layer,
59 | num_layers=num_layers
60 | )
61 |
62 | # Create decoder layer
63 | decoder_layer = nn.TransformerDecoderLayer(
64 | d_model=hidden_dim,
65 | nhead=num_heads,
66 | dim_feedforward=4 * hidden_dim,
67 | dropout=dropout,
68 | activation="relu",
69 | batch_first=True
70 | )
71 |
72 | # Create decoder
73 | self.decoder = nn.TransformerDecoder(
74 | decoder_layer=decoder_layer,
75 | num_layers=num_layers
76 | )
77 |
78 | # Positional encoding
79 | self.pos_encoding = PositionalEncoding(
80 | d_model=hidden_dim,
81 | dropout=dropout,
82 | max_len=100
83 | )
84 |
85 | # Output projection
86 | self.output_projection = nn.Linear(hidden_dim, output_dim)
87 |
88 | def forward(self, src: torch.Tensor, tgt: Optional[torch.Tensor] = None) -> torch.Tensor:
89 | """
90 | Forward pass.
91 |
92 | Parameters:
93 | -----------
94 | src : torch.Tensor
95 | Source sequence tensor of shape (batch_size, src_seq_len, hidden_dim)
96 | tgt : torch.Tensor, optional
97 | Target sequence tensor of shape (batch_size, tgt_seq_len, hidden_dim)
98 | If None, the source is used as target
99 |
100 | Returns:
101 | --------
102 | output : torch.Tensor
103 |             Output tensor of shape (batch_size, tgt_seq_len, output_dim)
104 | """
105 | # If no target is provided, use source
106 | if tgt is None:
107 | tgt = src
108 |
109 | # Add positional encoding
110 | src = self.pos_encoding(src)
111 | tgt = self.pos_encoding(tgt)
112 |
113 | # Pass through encoder
114 | memory = self.encoder(src)
115 |
116 | # Pass through decoder
117 | output = self.decoder(tgt, memory)
118 |
119 | # Project to output dimension
120 | output = self.output_projection(output)
121 |
122 | return output
123 |
124 |
125 | class PositionalEncoding(nn.Module):
126 | """Positional encoding for Transformer models."""
127 |
128 | def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 100):
129 | """
130 | Initialize the positional encoding.
131 |
132 | Parameters:
133 | -----------
134 | d_model : int
135 | Dimension of the model
136 | dropout : float
137 | Dropout probability
138 | max_len : int
139 | Maximum sequence length
140 | """
141 | super().__init__()
142 | self.dropout = nn.Dropout(p=dropout)
143 |
144 | # Create positional encoding
145 | position = torch.arange(max_len).unsqueeze(1)
146 | div_term = torch.exp(torch.arange(0, d_model, 2) * (-np.log(10000.0) / d_model))
147 | pe = torch.zeros(1, max_len, d_model)
148 | pe[0, :, 0::2] = torch.sin(position * div_term)
149 | pe[0, :, 1::2] = torch.cos(position * div_term)
150 | self.register_buffer('pe', pe)
151 |
152 | def forward(self, x: torch.Tensor) -> torch.Tensor:
153 | """
154 | Forward pass.
155 |
156 | Parameters:
157 | -----------
158 | x : torch.Tensor
159 | Input tensor of shape (batch_size, seq_len, d_model)
160 |
161 | Returns:
162 | --------
163 | output : torch.Tensor
164 | Output tensor with positional encoding added
165 | """
166 | x = x + self.pe[:, :x.size(1), :]
167 | return self.dropout(x)
168 |
169 |
170 | class WarmStartPredictor(nn.Module):
171 | """
172 | Transformer-based model for predicting warm start solutions.
173 |
174 | This model takes QP problem features as input and outputs an approximate
175 | solution that can be used to warm start the QP solver.
176 | """
177 |
178 | def __init__(self,
179 | input_dim: int = 50,
180 | hidden_dim: int = 256,
181 | output_dim: int = 20,
182 | num_layers: int = 6,
183 | num_heads: int = 8,
184 | dropout: float = 0.1):
185 | """
186 | Initialize the warm start predictor model.
187 |
188 | Parameters:
189 | -----------
190 | input_dim : int
191 | Dimension of input features
192 | hidden_dim : int
193 | Dimension of hidden layers
194 | output_dim : int
195 | Dimension of output solution vector
196 | num_layers : int
197 | Number of transformer layers
198 | num_heads : int
199 | Number of attention heads
200 | dropout : float
201 | Dropout probability
202 | """
203 | super().__init__()
204 |
205 | self.input_dim = input_dim
206 | self.hidden_dim = hidden_dim
207 | self.output_dim = output_dim
208 | self.num_layers = num_layers
209 | self.num_heads = num_heads
210 | self.dropout = dropout
211 |
212 | # Input projection
213 | self.input_projection = nn.Linear(input_dim, hidden_dim)
214 |
215 | # Transformer model
216 | self.transformer = TransformerEncoderDecoder(
217 | hidden_dim=hidden_dim,
218 |             output_dim=hidden_dim,  # keep the hidden width; the projection to output_dim happens below
219 | num_layers=num_layers,
220 | num_heads=num_heads,
221 | dropout=dropout
222 | )
223 |
224 | # Output layers with residual connections
225 | self.output_layer1 = nn.Linear(hidden_dim, hidden_dim)
226 | self.output_layer2 = nn.Linear(hidden_dim, hidden_dim)
227 | self.output_projection = nn.Linear(hidden_dim, output_dim)
228 |
229 | def forward(self, x: torch.Tensor) -> torch.Tensor:
230 | """
231 | Forward pass.
232 |
233 | Parameters:
234 | -----------
235 | x : torch.Tensor
236 | Input tensor of shape (batch_size, input_dim)
237 |
238 | Returns:
239 | --------
240 | output : torch.Tensor
241 | Output tensor of shape (batch_size, output_dim)
242 | containing approximate solution vector
243 | """
244 | # Input projection
245 | x = self.input_projection(x)
246 |
247 | # Reshape for transformer: [batch_size, seq_len, hidden_dim]
248 | # Here we use a sequence length of 1
249 | x = x.unsqueeze(1)
250 |
251 | # Pass through transformer
252 | transformer_output = self.transformer(x)
253 |
254 | # Extract first token of the output sequence
255 | first_token = transformer_output[:, 0, :]
256 |
257 | # Output layers with residual connections
258 | output = F.relu(self.output_layer1(first_token))
259 | output = output + first_token # Residual connection
260 | output = F.relu(self.output_layer2(output))
261 | output = output + first_token # Residual connection
262 |
263 | # Final output projection
264 | solution = self.output_projection(output)
265 |
266 | return solution
267 |
268 | def predict(self, x: Union[torch.Tensor, np.ndarray]) -> np.ndarray:
269 | """
270 | Make prediction on input data.
271 |
272 | Parameters:
273 | -----------
274 | x : torch.Tensor or numpy.ndarray
275 | Input features
276 |
277 | Returns:
278 | --------
279 | prediction : numpy.ndarray
280 |             Predicted solution array of shape (batch_size, output_dim)
281 | """
282 | # Convert to tensor if numpy array
283 | if isinstance(x, np.ndarray):
284 | x = torch.tensor(x, dtype=torch.float32)
285 |
286 | # Make sure the input has the right shape
287 | if x.dim() == 1:
288 | x = x.unsqueeze(0) # Add batch dimension
289 |
290 | # Set model to evaluation mode
291 | self.eval()
292 |
293 | # Disable gradient computation
294 | with torch.no_grad():
295 | # Forward pass
296 | solution = self(x)
297 |
298 | # Convert to numpy
299 | solution = solution.cpu().numpy()
300 |
301 | return solution
302 |
303 | def save(self, filepath: str) -> None:
304 | """
305 | Save model to file.
306 |
307 | Parameters:
308 | -----------
309 | filepath : str
310 | Path to save the model
311 | """
312 |         # Create the parent directory if the path includes one
313 |         os.makedirs(os.path.dirname(filepath) or ".", exist_ok=True)
314 |
315 | # Save model state
316 | torch.save({
317 | 'state_dict': self.state_dict(),
318 | 'input_dim': self.input_dim,
319 | 'hidden_dim': self.hidden_dim,
320 | 'output_dim': self.output_dim,
321 | 'num_layers': self.num_layers,
322 | 'num_heads': self.num_heads,
323 | 'dropout': self.dropout
324 | }, filepath)
325 |
326 | @classmethod
327 | def load(cls, filepath: Optional[str] = None) -> 'WarmStartPredictor':
328 | """
329 | Load model from file.
330 |
331 | Parameters:
332 | -----------
333 | filepath : str or None
334 | Path to load the model from, or None to create a new model
335 |
336 | Returns:
337 | --------
338 | model : WarmStartPredictor
339 | Loaded model
340 | """
341 | if filepath is None or not os.path.exists(filepath):
342 | # Return a new model with default parameters
343 | return cls()
344 |
345 | # Load model state
346 | checkpoint = torch.load(filepath, map_location=torch.device('cpu'))
347 |
348 | # Create model with saved parameters
349 | model = cls(
350 | input_dim=checkpoint['input_dim'],
351 | hidden_dim=checkpoint['hidden_dim'],
352 | output_dim=checkpoint['output_dim'],
353 | num_layers=checkpoint.get('num_layers', 6),
354 | num_heads=checkpoint.get('num_heads', 8),
355 | dropout=checkpoint.get('dropout', 0.1)
356 | )
357 |
358 | # Load state dictionary
359 | model.load_state_dict(checkpoint['state_dict'])
360 |
361 | return model
362 |
363 | def train_step(self,
364 | x: torch.Tensor,
365 | y: torch.Tensor,
366 | optimizer: torch.optim.Optimizer) -> Dict[str, float]:
367 | """
368 | Perform a single training step.
369 |
370 | Parameters:
371 | -----------
372 | x : torch.Tensor
373 | Input tensor of shape (batch_size, input_dim)
374 | y : torch.Tensor
375 | Target tensor of shape (batch_size, output_dim)
376 | optimizer : torch.optim.Optimizer
377 | Optimizer to use for the step
378 |
379 | Returns:
380 | --------
381 | metrics : Dict[str, float]
382 | Dictionary containing loss and error metrics
383 | """
384 | # Set model to training mode
385 | self.train()
386 |
387 | # Zero the gradients
388 | optimizer.zero_grad()
389 |
390 | # Forward pass
391 | solution = self(x)
392 |
393 | # Compute loss (mean squared error)
394 | mse_loss = F.mse_loss(solution, y)
395 |
396 | # Add L1 regularization for sparsity
397 | l1_loss = torch.mean(torch.abs(solution))
398 |
399 | # Combined loss
400 | loss = mse_loss + 0.01 * l1_loss
401 |
402 | # Backward pass
403 | loss.backward()
404 |
405 | # Update parameters
406 | optimizer.step()
407 |
408 | # Compute additional metrics
409 | mae = F.l1_loss(solution, y).item()
410 |
411 | # Compute relative error
412 | rel_error = torch.norm(solution - y, dim=1) / (torch.norm(y, dim=1) + 1e-8)
413 | rel_error = torch.mean(rel_error).item()
414 |
415 | return {
416 | 'loss': loss.item(),
417 | 'mse': mse_loss.item(),
418 | 'mae': mae,
419 | 'relative_error': rel_error
420 | }
421 |
422 | def validate(self,
423 | x: torch.Tensor,
424 | y: torch.Tensor) -> Dict[str, float]:
425 | """
426 | Validate the model on validation data.
427 |
428 | Parameters:
429 | -----------
430 | x : torch.Tensor
431 | Input tensor of shape (batch_size, input_dim)
432 | y : torch.Tensor
433 | Target tensor of shape (batch_size, output_dim)
434 |
435 | Returns:
436 | --------
437 | metrics : Dict[str, float]
438 | Dictionary containing loss and error metrics
439 | """
440 | # Set model to evaluation mode
441 | self.eval()
442 |
443 | # Disable gradient computation
444 | with torch.no_grad():
445 | # Forward pass
446 | solution = self(x)
447 |
448 | # Compute MSE loss
449 | mse_loss = F.mse_loss(solution, y)
450 |
451 | # Compute additional metrics
452 | mae = F.l1_loss(solution, y).item()
453 |
454 | # Compute relative error
455 | rel_error = torch.norm(solution - y, dim=1) / (torch.norm(y, dim=1) + 1e-8)
456 | rel_error = torch.mean(rel_error).item()
457 |
458 | return {
459 | 'loss': mse_loss.item(),
460 | 'mse': mse_loss.item(),
461 | 'mae': mae,
462 | 'relative_error': rel_error
463 | }
464 |
--------------------------------------------------------------------------------
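
Usage note: a minimal sketch of driving WarmStartPredictor end to end. The dimensions and the checkpoint path below are illustrative stand-ins, not values shipped with the package:

import numpy as np
from transformermpc.models.warm_start_predictor import WarmStartPredictor

# Build a predictor whose output matches a 20-dimensional QP decision vector.
model = WarmStartPredictor(input_dim=50, hidden_dim=256, output_dim=20)

# predict() accepts a single unbatched feature vector and returns a (1, 20) array.
features = np.random.randn(50).astype(np.float32)
warm_start = model.predict(features)

# Round-trip through save()/load(); load() falls back to a fresh default model
# when the file does not exist.
model.save("checkpoints/warm_start.pt")  # hypothetical path
restored = WarmStartPredictor.load("checkpoints/warm_start.pt")
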
/transformermpc/data/dataset.py:
--------------------------------------------------------------------------------
1 | """
2 | Dataset module for TransformerMPC.
3 |
4 | This module provides classes for creating, processing, and managing datasets
5 | for training and evaluating the transformer models.
6 | """
7 |
8 | import numpy as np
9 | import torch
10 | from torch.utils.data import Dataset, DataLoader, random_split
11 | from typing import List, Tuple, Dict, Optional, Union, Any
12 | import os
13 | import pickle
14 | import osqp
15 | import scipy.sparse as sparse
16 |
17 | from .qp_generator import QPProblem
18 | from ..utils.osqp_wrapper import OSQPSolver
19 |
20 |
21 | class QPDataset(Dataset):
22 | """
23 | Dataset class for QP problems.
24 |
25 | This class manages the datasets used for training and evaluating
26 | the transformer models. It handles data preprocessing, feature extraction,
27 | and creating input/target pairs for both transformer models.
28 | """
29 |
30 | def __init__(self,
31 | qp_problems: List[QPProblem],
32 | precompute_solutions: bool = True,
33 | max_constraints: Optional[int] = None,
34 | feature_normalization: bool = True,
35 | cache_dir: Optional[str] = None):
36 | """
37 | Initialize the QP dataset.
38 |
39 | Parameters:
40 | -----------
41 | qp_problems : List[QPProblem]
42 | List of QP problems
43 | precompute_solutions : bool
44 | Whether to precompute solutions and active constraints
45 | max_constraints : int or None
46 | Maximum number of constraints to consider (for padding)
47 | feature_normalization : bool
48 | Whether to normalize features
49 | cache_dir : str or None
50 | Directory to cache precomputed solutions
51 | """
52 | self.qp_problems = qp_problems
53 | self.precompute_solutions = precompute_solutions
54 | self.max_constraints = max_constraints or self._get_max_constraints()
55 | self.feature_normalization = feature_normalization
56 | self.cache_dir = cache_dir
57 |
58 | # Create cache directory if it doesn't exist
59 | if cache_dir is not None:
60 | os.makedirs(cache_dir, exist_ok=True)
61 |
62 | # Solver for computing solutions
63 | self.solver = OSQPSolver()
64 |
65 | # Precompute solutions if requested
66 | if precompute_solutions:
67 | self._precompute_all_solutions()
68 |
69 | # Compute feature statistics for normalization
70 | if feature_normalization:
71 | self._compute_feature_statistics()
72 |
73 | def _get_max_constraints(self) -> int:
74 | """
75 | Get maximum number of constraints across all problems.
76 |
77 | Returns:
78 | --------
79 | max_constraints : int
80 | Maximum number of constraints
81 | """
82 | return max(problem.n_constraints for problem in self.qp_problems)
83 |
84 | def _compute_feature_statistics(self) -> None:
85 | """
86 | Compute statistics for feature normalization.
87 | """
88 | # Extract raw features from each problem
89 | all_features = []
90 | for i in range(min(1000, len(self.qp_problems))): # Use subset for efficiency
91 | features = self._extract_raw_features(self.qp_problems[i])
92 | all_features.append(features)
93 |
94 | # Concatenate features
95 | all_features = np.vstack(all_features)
96 |
97 | # Compute mean and standard deviation
98 | self.feature_mean = np.mean(all_features, axis=0)
99 | self.feature_std = np.std(all_features, axis=0)
100 |
101 | # Replace zeros in std to avoid division by zero
102 | self.feature_std = np.where(self.feature_std < 1e-8, 1.0, self.feature_std)
103 |
104 | def _normalize_features(self, features: np.ndarray) -> np.ndarray:
105 | """
106 | Normalize features using precomputed statistics.
107 |
108 | Parameters:
109 | -----------
110 | features : numpy.ndarray
111 | Raw features
112 |
113 | Returns:
114 | --------
115 | normalized_features : numpy.ndarray
116 | Normalized features
117 | """
118 | if not self.feature_normalization:
119 | return features
120 |
121 | return (features - self.feature_mean) / self.feature_std
122 |
123 | def _extract_raw_features(self, problem: QPProblem) -> np.ndarray:
124 | """
125 | Extract raw features from a QP problem.
126 |
127 | Parameters:
128 | -----------
129 | problem : QPProblem
130 | QP problem instance
131 |
132 | Returns:
133 | --------
134 | features : numpy.ndarray
135 | Raw features
136 | """
137 | # Basic features: initial state and reference if available
138 | features = []
139 |
140 | if problem.initial_state is not None:
141 | features.append(problem.initial_state)
142 |
143 | if problem.reference is not None:
144 | features.append(problem.reference)
145 |
146 | # Add problem dimensions as features
147 | features.append(np.array([problem.n_variables, problem.n_constraints]))
148 |
149 | # Flatten and concatenate all features
150 | return np.concatenate([f.flatten() for f in features])
151 |
152 | def _precompute_all_solutions(self) -> None:
153 | """
154 | Precompute solutions and active constraints for all problems.
155 | """
156 | self.solutions = []
157 | self.active_constraints = []
158 |
159 | # Define cache file path if using caching
160 | cache_file = None
161 | if self.cache_dir is not None:
162 | cache_file = os.path.join(self.cache_dir, "qp_solutions_cache.pkl")
163 |
164 | # Load from cache if it exists
165 |         if cache_file is not None and os.path.exists(cache_file):
166 | with open(cache_file, 'rb') as f:
167 | cache_data = pickle.load(f)
168 | self.solutions = cache_data['solutions']
169 | self.active_constraints = cache_data['active_constraints']
170 | return
171 |
172 | # Solve all problems and identify active constraints
173 | for i, problem in enumerate(self.qp_problems):
174 | # Solve the QP problem
175 | solution = self.solver.solve(
176 | Q=problem.Q,
177 | c=problem.c,
178 | A=problem.A,
179 | b=problem.b
180 | )
181 |
182 | # Store the solution
183 | self.solutions.append(solution)
184 |
185 | # Identify active constraints
186 | active = self._identify_active_constraints(problem, solution)
187 | self.active_constraints.append(active)
188 |
189 | # Save to cache if using caching
190 | if cache_file is not None:
191 | with open(cache_file, 'wb') as f:
192 | pickle.dump(
193 | {
194 | 'solutions': self.solutions,
195 | 'active_constraints': self.active_constraints
196 | },
197 | f
198 | )
199 |
200 | def _identify_active_constraints(self,
201 | problem: QPProblem,
202 | solution: np.ndarray,
203 | tol: float = 1e-6) -> np.ndarray:
204 | """
205 | Identify active constraints in a QP solution.
206 |
207 | Parameters:
208 | -----------
209 | problem : QPProblem
210 | QP problem instance
211 | solution : numpy.ndarray
212 | Solution vector
213 | tol : float
214 | Tolerance for identifying active constraints
215 |
216 | Returns:
217 | --------
218 | active : numpy.ndarray
219 | Binary vector indicating active constraints
220 | """
221 | # Compute constraint values: A * x - b
222 | constraint_values = problem.A @ solution - problem.b
223 |
224 | # Identify active constraints (those within tolerance of the boundary)
225 | active = np.abs(constraint_values) < tol
226 |
227 | # Pad or truncate to max_constraints
228 | if self.max_constraints is not None:
229 | if len(active) > self.max_constraints:
230 | active = active[:self.max_constraints]
231 | elif len(active) < self.max_constraints:
232 | padding = np.zeros(self.max_constraints - len(active), dtype=bool)
233 | active = np.concatenate([active, padding])
234 |
235 | return active.astype(np.float32)
236 |
237 | def __len__(self) -> int:
238 | """
239 | Get the number of problems in the dataset.
240 |
241 | Returns:
242 | --------
243 | length : int
244 | Number of problems
245 | """
246 | return len(self.qp_problems)
247 |
248 | def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
249 | """
250 | Get a single item from the dataset.
251 |
252 | Parameters:
253 | -----------
254 | idx : int
255 | Index of the item
256 |
257 | Returns:
258 | --------
259 | item : Dict[str, torch.Tensor]
260 | Dictionary containing:
261 | - 'features': Input features for the models
262 | - 'active_constraints': Target for constraint predictor
263 | - 'solution': Target for warm start predictor
264 | """
265 | # Get the problem
266 | problem = self.qp_problems[idx]
267 |
268 | # Extract and normalize features
269 | features = self._extract_raw_features(problem)
270 | features = self._normalize_features(features)
271 |
272 | # Get solution and active constraints
273 | if self.precompute_solutions:
274 | solution = self.solutions[idx]
275 | active_constraints = self.active_constraints[idx]
276 | else:
277 | # Solve on-the-fly
278 | solution = self.solver.solve(
279 | Q=problem.Q,
280 | c=problem.c,
281 | A=problem.A,
282 | b=problem.b
283 | )
284 | active_constraints = self._identify_active_constraints(problem, solution)
285 |
286 | # Convert to torch tensors
287 | features_tensor = torch.tensor(features, dtype=torch.float32)
288 | active_constraints_tensor = torch.tensor(active_constraints, dtype=torch.float32)
289 | solution_tensor = torch.tensor(solution, dtype=torch.float32)
290 |
291 | return {
292 | 'features': features_tensor,
293 | 'active_constraints': active_constraints_tensor,
294 | 'solution': solution_tensor,
295 | 'problem_idx': idx # Include the problem index for reference
296 | }
297 |
298 | def get_problem(self, idx: int) -> QPProblem:
299 | """
300 | Get the original QP problem at the given index.
301 |
302 | Parameters:
303 | -----------
304 | idx : int
305 | Index of the problem
306 |
307 | Returns:
308 | --------
309 | problem : QPProblem
310 | QP problem instance
311 | """
312 | return self.qp_problems[idx]
313 |
314 | def get_dataloaders(self,
315 | batch_size: int = 32,
316 | val_split: float = 0.2,
317 | test_split: float = 0.1,
318 | shuffle: bool = True,
319 | num_workers: int = 4) -> Tuple[DataLoader, DataLoader, DataLoader]:
320 | """
321 | Create train, validation, and test dataloaders.
322 |
323 | Parameters:
324 | -----------
325 | batch_size : int
326 | Batch size
327 | val_split : float
328 | Fraction of data to use for validation
329 | test_split : float
330 | Fraction of data to use for testing
331 | shuffle : bool
332 | Whether to shuffle the data
333 | num_workers : int
334 | Number of workers for data loading
335 |
336 | Returns:
337 | --------
338 | train_loader : DataLoader
339 | Training data loader
340 | val_loader : DataLoader
341 | Validation data loader
342 | test_loader : DataLoader
343 | Test data loader
344 | """
345 | # Calculate dataset sizes
346 | dataset_size = len(self)
347 | test_size = int(dataset_size * test_split)
348 | val_size = int(dataset_size * val_split)
349 | train_size = dataset_size - val_size - test_size
350 |
351 | # Split the dataset
352 | train_dataset, val_dataset, test_dataset = random_split(
353 | self, [train_size, val_size, test_size],
354 | generator=torch.Generator().manual_seed(42)
355 | )
356 |
357 | # Create dataloaders
358 | train_loader = DataLoader(
359 | train_dataset,
360 | batch_size=batch_size,
361 | shuffle=shuffle,
362 | num_workers=num_workers
363 | )
364 |
365 | val_loader = DataLoader(
366 | val_dataset,
367 | batch_size=batch_size,
368 | shuffle=False,
369 | num_workers=num_workers
370 | )
371 |
372 | test_loader = DataLoader(
373 | test_dataset,
374 | batch_size=batch_size,
375 | shuffle=False,
376 | num_workers=num_workers
377 | )
378 |
379 | return train_loader, val_loader, test_loader
380 |
381 | def split(self, test_size: float = 0.2, seed: int = 42) -> Tuple['QPDataset', 'QPDataset']:
382 | """
383 | Split the dataset into training and test sets.
384 |
385 | Parameters:
386 | -----------
387 | test_size : float
388 | Fraction of data to use for testing
389 | seed : int
390 | Random seed
391 |
392 | Returns:
393 | --------
394 | train_dataset : QPDataset
395 | Training dataset
396 | test_dataset : QPDataset
397 | Test dataset
398 | """
399 | # Set random seed for reproducibility
400 | np.random.seed(seed)
401 |
402 | # Shuffle indices
403 | indices = np.arange(len(self.qp_problems))
404 | np.random.shuffle(indices)
405 |
406 |         # Compute the index separating training indices from test indices
407 |         split_idx = int(len(indices) * (1 - test_size))
408 | 
409 |         # Split indices
410 |         train_indices = indices[:split_idx]
411 |         test_indices = indices[split_idx:]
412 |
413 | # Create new datasets
414 | train_problems = [self.qp_problems[i] for i in train_indices]
415 | test_problems = [self.qp_problems[i] for i in test_indices]
416 |
417 | # Create new dataset objects
418 | train_dataset = QPDataset(
419 | train_problems,
420 | precompute_solutions=self.precompute_solutions,
421 | max_constraints=self.max_constraints,
422 | feature_normalization=self.feature_normalization,
423 |             cache_dir=None  # the full-dataset cache would misalign with a subset
424 | )
425 |
426 | test_dataset = QPDataset(
427 | test_problems,
428 | precompute_solutions=self.precompute_solutions,
429 | max_constraints=self.max_constraints,
430 | feature_normalization=self.feature_normalization,
431 |             cache_dir=None  # the full-dataset cache would misalign with a subset
432 | )
433 |
434 | # Copy feature statistics to ensure consistent normalization
435 | if self.feature_normalization:
436 | test_dataset.feature_mean = self.feature_mean
437 | test_dataset.feature_std = self.feature_std
438 |
439 | return train_dataset, test_dataset
440 |
--------------------------------------------------------------------------------
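
Usage note: a minimal sketch of building a QPDataset and pulling batches out of it. The QPGenerator parameters mirror the defaults used in scripts/simple_demo.py below; treat the values as illustrative:

from transformermpc.data.qp_generator import QPGenerator
from transformermpc.data.dataset import QPDataset

# Generate a small family of MPC-style QP problems and wrap them in a dataset;
# precomputation solves every QP once so targets are available immediately.
problems = QPGenerator(state_dim=4, input_dim=2, horizon=5, num_samples=100).generate()
dataset = QPDataset(problems, precompute_solutions=True, feature_normalization=True)

# Either request ready-made train/val/test loaders ...
train_loader, val_loader, test_loader = dataset.get_dataloaders(batch_size=32)

# ... or split into two QPDataset objects that share normalization statistics.
train_set, test_set = dataset.split(test_size=0.2)

batch = next(iter(train_loader))  # keys: 'features', 'active_constraints', 'solution', 'problem_idx'
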
/scripts/simple_demo.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | """
3 | TransformerMPC Simple Demo
4 |
5 | This script provides a complete end-to-end demonstration of the TransformerMPC package:
6 | 1. Generates a set of quadratic programming (QP) problems
7 | 2. Creates training and test datasets
8 | 3. Trains both transformer models (constraint predictor and warm start predictor)
9 | 4. Tests the models on the test set
10 | 5. Plots performance comparisons including box plots
11 |
12 | The script is designed to run quickly with a small number of samples and epochs,
13 | but can be modified for more comprehensive training.
14 | """
15 |
16 | import os
17 | import time
18 | import argparse
19 | import numpy as np
20 | import matplotlib.pyplot as plt
21 | from pathlib import Path
22 | import torch
23 | from tqdm import tqdm
24 |
25 | # Allowlist the numpy scalar global for torch.load (the API exists only in torch >= 2.4)
26 | if hasattr(torch.serialization, 'add_safe_globals'):
27 |     torch.serialization.add_safe_globals([np.core.multiarray.scalar])
28 |
29 | # Patch torch.load to default weights_only=False (newer torch versions default it to True)
30 | original_torch_load = torch.load
31 | def patched_torch_load(f, *args, **kwargs):
32 | if 'weights_only' not in kwargs:
33 | kwargs['weights_only'] = False
34 | return original_torch_load(f, *args, **kwargs)
35 | torch.load = patched_torch_load
36 |
37 | # Import TransformerMPC modules
38 | from transformermpc.data.qp_generator import QPGenerator
39 | from transformermpc.data.dataset import QPDataset
40 | from transformermpc.models.constraint_predictor import ConstraintPredictor
41 | from transformermpc.models.warm_start_predictor import WarmStartPredictor
42 | from transformermpc.utils.osqp_wrapper import OSQPSolver
43 | from transformermpc.utils.metrics import compute_solve_time_metrics, compute_fallback_rate
44 | from transformermpc.utils.visualization import (
45 | plot_solve_time_comparison,
46 | plot_solve_time_boxplot
47 | )
48 |
49 | def parse_args():
50 | """Parse command line arguments."""
51 | parser = argparse.ArgumentParser(description="TransformerMPC Simple Demo")
52 |
53 | # Data generation parameters
54 | parser.add_argument("--num_samples", type=int, default=100,
55 | help="Number of QP problems to generate (default: 100)")
56 | parser.add_argument("--state_dim", type=int, default=4,
57 | help="State dimension for MPC problems (default: 4)")
58 | parser.add_argument("--input_dim", type=int, default=2,
59 | help="Input dimension for MPC problems (default: 2)")
60 | parser.add_argument("--horizon", type=int, default=5,
61 | help="Time horizon for MPC problems (default: 5)")
62 |
63 | # Training parameters
64 | parser.add_argument("--epochs", type=int, default=5,
65 | help="Number of epochs for training (default: 5)")
66 | parser.add_argument("--batch_size", type=int, default=16,
67 | help="Batch size for training (default: 16)")
68 | parser.add_argument("--test_size", type=float, default=0.2,
69 | help="Fraction of data to use for testing (default: 0.2)")
70 |
71 | # Other parameters
72 | parser.add_argument("--output_dir", type=str, default="demo_results",
73 | help="Directory to save results (default: demo_results)")
74 | parser.add_argument("--use_gpu", action="store_true",
75 | help="Use GPU if available")
76 | parser.add_argument("--test_problems", type=int, default=10,
77 | help="Number of test problems for evaluation (default: 10)")
78 |
79 | return parser.parse_args()
80 |
81 | def train_model(model, train_data, val_data, num_epochs, batch_size, lr=1e-3, device=None):
82 |     """Simple training loop shared by both predictor models."""
83 |     device = device or torch.device('cuda' if torch.cuda.is_available() else 'cpu')
84 | model = model.to(device)
85 | optimizer = torch.optim.Adam(model.parameters(), lr=lr)
86 |
87 | train_losses = []
88 | val_losses = []
89 |
90 | # Create data loaders
91 | train_loader = torch.utils.data.DataLoader(
92 | train_data, batch_size=batch_size, shuffle=True
93 | )
94 | val_loader = torch.utils.data.DataLoader(
95 | val_data, batch_size=batch_size, shuffle=False
96 | )
97 |
98 | print(f"Training for {num_epochs} epochs...")
99 | for epoch in range(num_epochs):
100 | # Training
101 | model.train()
102 | epoch_loss = 0
103 | for batch in train_loader:
104 | features = batch['features'].to(device)
105 |
106 | if isinstance(model, ConstraintPredictor):
107 | targets = batch['active_constraints'].to(device)
108 | optimizer.zero_grad()
109 | outputs = model(features)
110 | loss = torch.nn.functional.binary_cross_entropy_with_logits(outputs, targets)
111 | else:
112 | targets = batch['solution'].to(device)
113 | optimizer.zero_grad()
114 | outputs = model(features)
115 | loss = torch.nn.functional.mse_loss(outputs, targets)
116 |
117 | loss.backward()
118 | optimizer.step()
119 | epoch_loss += loss.item()
120 |
121 | avg_epoch_loss = epoch_loss / len(train_loader)
122 | train_losses.append(avg_epoch_loss)
123 |
124 | # Validation
125 | model.eval()
126 | val_loss = 0
127 | with torch.no_grad():
128 | for batch in val_loader:
129 | features = batch['features'].to(device)
130 |
131 | if isinstance(model, ConstraintPredictor):
132 | targets = batch['active_constraints'].to(device)
133 | outputs = model(features)
134 | loss = torch.nn.functional.binary_cross_entropy_with_logits(outputs, targets)
135 | else:
136 | targets = batch['solution'].to(device)
137 | outputs = model(features)
138 | loss = torch.nn.functional.mse_loss(outputs, targets)
139 |
140 | val_loss += loss.item()
141 |
142 | avg_val_loss = val_loss / len(val_loader)
143 | val_losses.append(avg_val_loss)
144 |
145 | print(f"Epoch {epoch+1}/{num_epochs}: Train Loss = {avg_epoch_loss:.6f}, Val Loss = {avg_val_loss:.6f}")
146 |
147 | return {
148 | 'train_loss': train_losses,
149 | 'val_loss': val_losses
150 | }
151 |
152 | def main():
153 | """Run the demo workflow."""
154 | # Parse command line arguments
155 | args = parse_args()
156 |
157 | print("=" * 60)
158 | print("TransformerMPC Simple Demo".center(60))
159 | print("=" * 60)
160 |
161 | # Create output directory
162 | output_dir = Path(args.output_dir)
163 | output_dir.mkdir(parents=True, exist_ok=True)
164 |
165 | # Set up directories
166 | models_dir = output_dir / "models"
167 | results_dir = output_dir / "results"
168 |
169 | for directory in [models_dir, results_dir]:
170 | directory.mkdir(exist_ok=True)
171 |
172 | # Set device
173 | device = torch.device('cuda' if torch.cuda.is_available() and args.use_gpu else 'cpu')
174 | print(f"Using device: {device}")
175 |
176 | # Step 1: Generate QP problems
177 | print("\nStep 1: Generating QP problems")
178 | print("-" * 60)
179 |
180 | print(f"Generating {args.num_samples} QP problems")
181 | generator = QPGenerator(
182 | state_dim=args.state_dim,
183 | input_dim=args.input_dim,
184 | horizon=args.horizon,
185 | num_samples=args.num_samples
186 | )
187 | qp_problems = generator.generate()
188 | print(f"Generated {len(qp_problems)} QP problems")
189 |
190 | # Step 2: Create dataset and split into train/test
191 | print("\nStep 2: Creating datasets")
192 | print("-" * 60)
193 |
194 | dataset = QPDataset(
195 | qp_problems=qp_problems,
196 | precompute_solutions=True,
197 | feature_normalization=True
198 | )
199 |
200 | train_dataset, test_dataset = dataset.split(test_size=args.test_size)
201 | print(f"Created datasets - Train: {len(train_dataset)}, Test: {len(test_dataset)}")
202 |
203 | # Step 3: Train constraint predictor
204 | print("\nStep 3: Training constraint predictor")
205 | print("-" * 60)
206 |
207 | # Get input dimension from the dataset
208 | sample_item = train_dataset[0]
209 | input_dim = sample_item['features'].shape[0]
210 | num_constraints = sample_item['active_constraints'].shape[0]
211 |
212 | cp_model = ConstraintPredictor(
213 | input_dim=input_dim,
214 | hidden_dim=64, # Smaller model for faster training
215 | num_constraints=num_constraints,
216 | num_layers=2 # Fewer layers for faster training
217 | )
218 |
219 | # Train constraint predictor with simplified training loop
220 | cp_history = train_model(
221 | model=cp_model,
222 | train_data=train_dataset,
223 | val_data=test_dataset,
224 | num_epochs=args.epochs,
225 |         batch_size=args.batch_size, device=device
226 | )
227 |
228 | # Save model
229 | cp_model_file = models_dir / "constraint_predictor.pt"
230 | torch.save(cp_model.state_dict(), cp_model_file)
231 | print(f"Saved constraint predictor to {cp_model_file}")
232 |
233 | # Step 4: Train warm start predictor
234 | print("\nStep 4: Training warm start predictor")
235 | print("-" * 60)
236 |
237 | # Get input dimension from the dataset
238 | output_dim = sample_item['solution'].shape[0]
239 |
240 | ws_model = WarmStartPredictor(
241 | input_dim=input_dim,
242 | hidden_dim=128, # Smaller model for faster training
243 | output_dim=output_dim,
244 | num_layers=2 # Fewer layers for faster training
245 | )
246 |
247 | # Train warm start predictor with simplified training loop
248 | ws_history = train_model(
249 | model=ws_model,
250 | train_data=train_dataset,
251 | val_data=test_dataset,
252 | num_epochs=args.epochs,
253 |         batch_size=args.batch_size, device=device
254 | )
255 |
256 | # Save model
257 | ws_model_file = models_dir / "warm_start_predictor.pt"
258 | torch.save(ws_model.state_dict(), ws_model_file)
259 | print(f"Saved warm start predictor to {ws_model_file}")
260 |
261 | # Step 5: Test on a subset of problems
262 | print("\nStep 5: Performance testing")
263 | print("-" * 60)
264 |
265 | solver = OSQPSolver()
266 |
267 | # Lists to store results
268 | baseline_times = []
269 | transformer_times = []
270 | constraint_only_times = []
271 | warmstart_only_times = []
272 | fallback_flags = []
273 |
274 | # Test on a small subset for demonstration
275 | num_test_problems = min(args.test_problems, len(test_dataset))
276 | test_subset = np.random.choice(len(test_dataset), size=num_test_problems, replace=False)
277 |
278 | print(f"Testing on {num_test_problems} problems...")
279 | for idx in tqdm(test_subset):
280 | # Get problem
281 | sample = test_dataset[idx]
282 | problem = test_dataset.get_problem(idx)
283 |
284 | # Get features
285 | features = sample['features']
286 |
287 | # Predict active constraints and warm start
288 | with torch.no_grad():
289 |             cp_output = cp_model(features.unsqueeze(0).to(device))
290 |             pred_active = (torch.sigmoid(cp_output) > 0.5).float().squeeze(0).cpu().numpy()
291 | 
292 |             ws_output = ws_model(features.unsqueeze(0).to(device))
293 |             pred_solution = ws_output.squeeze(0).cpu().numpy()
294 |
295 | # 1. Baseline (OSQP without transformers)
296 | _, baseline_time = solver.solve_with_time(
297 | Q=problem.Q,
298 | c=problem.c,
299 | A=problem.A,
300 | b=problem.b
301 | )
302 | baseline_times.append(baseline_time)
303 |
304 | # 2. Constraint-only
305 | _, constraint_time, _ = solver.solve_pipeline(
306 | Q=problem.Q,
307 | c=problem.c,
308 | A=problem.A,
309 | b=problem.b,
310 | active_constraints=pred_active,
311 | warm_start=None,
312 | fallback_on_violation=True
313 | )
314 | constraint_only_times.append(constraint_time)
315 |
316 | # 3. Warm-start-only
317 | _, warmstart_time = solver.solve_with_time(
318 | Q=problem.Q,
319 | c=problem.c,
320 | A=problem.A,
321 | b=problem.b,
322 | warm_start=pred_solution
323 | )
324 | warmstart_only_times.append(warmstart_time)
325 |
326 | # 4. Full transformer pipeline
327 | _, transformer_time, used_fallback = solver.solve_pipeline(
328 | Q=problem.Q,
329 | c=problem.c,
330 | A=problem.A,
331 | b=problem.b,
332 | active_constraints=pred_active,
333 | warm_start=pred_solution,
334 | fallback_on_violation=True
335 | )
336 | transformer_times.append(transformer_time)
337 | fallback_flags.append(used_fallback)
338 |
339 | # Convert to numpy arrays
340 | baseline_times = np.array(baseline_times)
341 | transformer_times = np.array(transformer_times)
342 | constraint_only_times = np.array(constraint_only_times)
343 | warmstart_only_times = np.array(warmstart_only_times)
344 |
345 | # Compute and print metrics
346 | solve_metrics = compute_solve_time_metrics(baseline_times, transformer_times)
347 | fallback_rate = compute_fallback_rate(fallback_flags)
348 |
349 | print("\nPerformance Results:")
350 | print("-" * 60)
351 | print(f"Mean baseline time: {solve_metrics['mean_baseline_time']:.6f}s")
352 | print(f"Mean transformer time: {solve_metrics['mean_transformer_time']:.6f}s")
353 | print(f"Mean constraint-only time: {np.mean(constraint_only_times):.6f}s")
354 | print(f"Mean warm-start-only time: {np.mean(warmstart_only_times):.6f}s")
355 | print(f"Mean speedup: {solve_metrics['mean_speedup']:.2f}x")
356 | print(f"Median speedup: {solve_metrics['median_speedup']:.2f}x")
357 | print(f"Fallback rate: {fallback_rate:.2f}%")
358 |
359 | # Step 6: Generate visualizations
360 | print("\nStep 6: Generating visualizations")
361 | print("-" * 60)
362 |
363 | # Plot solve time comparison
364 | print("Generating solve time comparison plot...")
365 | plot_solve_time_comparison(
366 | baseline_times=baseline_times,
367 | transformer_times=transformer_times,
368 | save_path=results_dir / "solve_time_comparison.png"
369 | )
370 |
371 | # Plot solve time boxplot
372 | print("Generating solve time boxplot...")
373 | plot_solve_time_boxplot(
374 | baseline_times=baseline_times,
375 | transformer_times=transformer_times,
376 | constraint_only_times=constraint_only_times,
377 | warmstart_only_times=warmstart_only_times,
378 | save_path=results_dir / "solve_time_boxplot.png"
379 | )
380 |
381 | # Plot performance violin plot
382 | print("Generating performance violin plot...")
383 | plt.figure(figsize=(10, 6))
384 | plt.violinplot(
385 | [baseline_times, constraint_only_times, warmstart_only_times, transformer_times],
386 | showmeans=True
387 | )
388 | plt.xticks([1, 2, 3, 4], ['Baseline', 'Constraint-only', 'WarmStart-only', 'Full Pipeline'])
389 | plt.ylabel('Solve Time (s)')
390 | plt.title('QP Solve Time Comparison')
391 | plt.grid(True, alpha=0.3)
392 | plt.savefig(results_dir / "performance_violin.png", dpi=300)
393 |
394 | # Plot training history
395 | print("Generating training history plots...")
396 |
397 | # Plot training loss
398 | plt.figure(figsize=(10, 5))
399 | plt.subplot(1, 2, 1)
400 | plt.plot(cp_history['train_loss'], label='Train Loss')
401 | plt.plot(cp_history['val_loss'], label='Validation Loss')
402 | plt.xlabel('Epoch')
403 | plt.ylabel('Loss')
404 | plt.title('Constraint Predictor Training Loss')
405 | plt.legend()
406 | plt.grid(alpha=0.3)
407 |
408 | plt.subplot(1, 2, 2)
409 | plt.plot(ws_history['train_loss'], label='Train Loss')
410 | plt.plot(ws_history['val_loss'], label='Validation Loss')
411 | plt.xlabel('Epoch')
412 | plt.ylabel('Loss')
413 | plt.title('Warm Start Predictor Training Loss')
414 | plt.legend()
415 | plt.grid(alpha=0.3)
416 |
417 | plt.tight_layout()
418 | plt.savefig(results_dir / "training_history.png", dpi=300)
419 |
420 | # Create a summary plot with box plots
421 | print("Generating summary box plot...")
422 | plt.figure(figsize=(12, 8))
423 |
424 | # Create box plot
425 | plt.boxplot(
426 | [baseline_times, transformer_times, constraint_only_times, warmstart_only_times],
427 | labels=['Baseline', 'Full Pipeline', 'Constraint-only', 'Warm Start-only'],
428 | showmeans=True
429 | )
430 |
431 | plt.ylabel('Solve Time (s)')
432 | plt.title('QP Solve Time Comparison (Box Plot)')
433 | plt.grid(True, axis='y', alpha=0.3)
434 |
435 | # Add mean value annotations
436 | means = [
437 | np.mean(baseline_times),
438 | np.mean(transformer_times),
439 | np.mean(constraint_only_times),
440 | np.mean(warmstart_only_times)
441 | ]
442 |
443 | for i, mean in enumerate(means):
444 | plt.text(i+1, mean, f'{mean:.6f}s',
445 | horizontalalignment='center',
446 | verticalalignment='bottom',
447 | fontweight='bold')
448 |
449 | plt.savefig(results_dir / "summary_boxplot.png", dpi=300)
450 |
451 | print(f"\nResults and visualizations saved to {output_dir}")
452 | print("\nDemo completed successfully!")
453 | print("=" * 60)
454 |
455 | if __name__ == "__main__":
456 | main()
--------------------------------------------------------------------------------
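
Usage note: the demo accepts the flags defined in parse_args() above; a quick run on synthetic data looks like the following, with output landing in demo_results/ by default:

python scripts/simple_demo.py --num_samples 100 --epochs 5 --batch_size 16 --use_gpu

Dropping --use_gpu keeps everything on the CPU, and --test_problems controls how many held-out QPs are timed in Step 5.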