├── framework ├── __init__.py ├── models │ ├── __init__.py │ └── score_detector │ │ └── crnn_ctc.py ├── ScoreDetectorConfig.py ├── ControlDevice.py ├── Actions.py ├── ControlDeviceCfg.py ├── ScreenDetectorFixed.py ├── Joystick.py ├── HIDDevice.py ├── Logger.py ├── CameraUtils.py ├── ScreenDetector.py ├── Keyboard.py ├── RoboTroller.py ├── ScoreDetector.py ├── MCCDAQDevice.py └── v4l2_defs.py ├── requirements-dev.txt ├── assets ├── models │ ├── atlantis_score.pt │ ├── krull_score_lives.pt │ ├── qbert_score_lives.pt │ ├── defender_score_lives.pt │ ├── battle_zone_score_lives.pt │ ├── centipede_score_lives.pt │ ├── ms_pacman_score_lives.pt │ └── up_n_down_score_lives.pt └── images │ └── screen │ ├── sil-krull.png │ ├── sil-qbert.png │ ├── sil-atlantis.png │ ├── sil-defender.png │ ├── sil-battle_zone.png │ ├── sil-centipede.png │ ├── sil-ms_pacman.png │ └── sil-up_n_down.png ├── docs ├── images │ └── score_detection │ │ ├── krull_lives_box.png │ │ ├── krull_score_box.png │ │ ├── qbert_lives_box.png │ │ ├── qbert_score_box.png │ │ ├── atlantis_score_box.png │ │ ├── defender_lives_box.png │ │ ├── defender_score_box.png │ │ ├── centipede_lives_box.png │ │ ├── centipede_score_box.png │ │ ├── ms_pacman_lives_box.png │ │ ├── ms_pacman_score_box.png │ │ ├── up_n_down_lives_box.png │ │ ├── up_n_down_score_box.png │ │ ├── battle_zone_lives_box.png │ │ └── battle_zone_score_box.png ├── profiling.md └── io_controller.md ├── train └── score_detector │ ├── requirements.txt │ ├── checkpoint_viewer.py │ ├── generate_dataset.py │ ├── dataset.py │ └── ale_ram_injection.py ├── configs ├── controllers │ ├── robotroller.json │ └── io_controller.json ├── screen_detection │ ├── april_tags.json │ └── fixed.json ├── games │ ├── atlantis.json │ ├── qbert.json │ ├── up_n_down.json │ ├── defender.json │ ├── battle_zone.json │ ├── krull.json │ ├── centipede.json │ └── ms_pacman.json └── cameras │ └── camera_kiyo_pro.json ├── .gitignore ├── requirements.txt ├── .editorconfig ├── .gitattributes ├── 
pyproject.toml ├── docker_build.sh ├── .pre-commit-config.yaml ├── env_base.py ├── scripts ├── format_code.sh ├── lambda-hold.sh ├── performance │ ├── nvidia-persistence.sh │ ├── cpu-governor.sh │ ├── nvidia-powerd.sh │ └── power-profile.sh └── plot_data.py ├── CONTRIBUTING.md ├── docker_run.sh ├── tests ├── test_controller.py └── test_camera.py ├── agent_random.py ├── Dockerfile ├── QUICKSTART.md ├── LICENSE └── README.md /framework/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /framework/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /requirements-dev.txt: -------------------------------------------------------------------------------- 1 | pre-commit==4.2.0 2 | ruff==0.1.15 3 | black==24.4.2 4 | isort==5.13.2 5 | pyupgrade==3.20.0 6 | -------------------------------------------------------------------------------- /assets/models/atlantis_score.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Keen-Technologies/physical_atari/HEAD/assets/models/atlantis_score.pt -------------------------------------------------------------------------------- /assets/images/screen/sil-krull.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Keen-Technologies/physical_atari/HEAD/assets/images/screen/sil-krull.png -------------------------------------------------------------------------------- /assets/images/screen/sil-qbert.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Keen-Technologies/physical_atari/HEAD/assets/images/screen/sil-qbert.png 
-------------------------------------------------------------------------------- /assets/models/krull_score_lives.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Keen-Technologies/physical_atari/HEAD/assets/models/krull_score_lives.pt -------------------------------------------------------------------------------- /assets/models/qbert_score_lives.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Keen-Technologies/physical_atari/HEAD/assets/models/qbert_score_lives.pt -------------------------------------------------------------------------------- /assets/images/screen/sil-atlantis.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Keen-Technologies/physical_atari/HEAD/assets/images/screen/sil-atlantis.png -------------------------------------------------------------------------------- /assets/images/screen/sil-defender.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Keen-Technologies/physical_atari/HEAD/assets/images/screen/sil-defender.png -------------------------------------------------------------------------------- /assets/models/defender_score_lives.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Keen-Technologies/physical_atari/HEAD/assets/models/defender_score_lives.pt -------------------------------------------------------------------------------- /assets/images/screen/sil-battle_zone.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Keen-Technologies/physical_atari/HEAD/assets/images/screen/sil-battle_zone.png -------------------------------------------------------------------------------- /assets/images/screen/sil-centipede.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Keen-Technologies/physical_atari/HEAD/assets/images/screen/sil-centipede.png -------------------------------------------------------------------------------- /assets/images/screen/sil-ms_pacman.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Keen-Technologies/physical_atari/HEAD/assets/images/screen/sil-ms_pacman.png -------------------------------------------------------------------------------- /assets/images/screen/sil-up_n_down.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Keen-Technologies/physical_atari/HEAD/assets/images/screen/sil-up_n_down.png -------------------------------------------------------------------------------- /assets/models/battle_zone_score_lives.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Keen-Technologies/physical_atari/HEAD/assets/models/battle_zone_score_lives.pt -------------------------------------------------------------------------------- /assets/models/centipede_score_lives.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Keen-Technologies/physical_atari/HEAD/assets/models/centipede_score_lives.pt -------------------------------------------------------------------------------- /assets/models/ms_pacman_score_lives.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Keen-Technologies/physical_atari/HEAD/assets/models/ms_pacman_score_lives.pt -------------------------------------------------------------------------------- /assets/models/up_n_down_score_lives.pt: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Keen-Technologies/physical_atari/HEAD/assets/models/up_n_down_score_lives.pt -------------------------------------------------------------------------------- /docs/images/score_detection/krull_lives_box.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Keen-Technologies/physical_atari/HEAD/docs/images/score_detection/krull_lives_box.png -------------------------------------------------------------------------------- /docs/images/score_detection/krull_score_box.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Keen-Technologies/physical_atari/HEAD/docs/images/score_detection/krull_score_box.png -------------------------------------------------------------------------------- /docs/images/score_detection/qbert_lives_box.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Keen-Technologies/physical_atari/HEAD/docs/images/score_detection/qbert_lives_box.png -------------------------------------------------------------------------------- /docs/images/score_detection/qbert_score_box.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Keen-Technologies/physical_atari/HEAD/docs/images/score_detection/qbert_score_box.png -------------------------------------------------------------------------------- /docs/images/score_detection/atlantis_score_box.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Keen-Technologies/physical_atari/HEAD/docs/images/score_detection/atlantis_score_box.png -------------------------------------------------------------------------------- /docs/images/score_detection/defender_lives_box.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Keen-Technologies/physical_atari/HEAD/docs/images/score_detection/defender_lives_box.png -------------------------------------------------------------------------------- /docs/images/score_detection/defender_score_box.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Keen-Technologies/physical_atari/HEAD/docs/images/score_detection/defender_score_box.png -------------------------------------------------------------------------------- /docs/images/score_detection/centipede_lives_box.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Keen-Technologies/physical_atari/HEAD/docs/images/score_detection/centipede_lives_box.png -------------------------------------------------------------------------------- /docs/images/score_detection/centipede_score_box.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Keen-Technologies/physical_atari/HEAD/docs/images/score_detection/centipede_score_box.png -------------------------------------------------------------------------------- /docs/images/score_detection/ms_pacman_lives_box.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Keen-Technologies/physical_atari/HEAD/docs/images/score_detection/ms_pacman_lives_box.png -------------------------------------------------------------------------------- /docs/images/score_detection/ms_pacman_score_box.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Keen-Technologies/physical_atari/HEAD/docs/images/score_detection/ms_pacman_score_box.png -------------------------------------------------------------------------------- /docs/images/score_detection/up_n_down_lives_box.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Keen-Technologies/physical_atari/HEAD/docs/images/score_detection/up_n_down_lives_box.png -------------------------------------------------------------------------------- /docs/images/score_detection/up_n_down_score_box.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Keen-Technologies/physical_atari/HEAD/docs/images/score_detection/up_n_down_score_box.png -------------------------------------------------------------------------------- /docs/images/score_detection/battle_zone_lives_box.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Keen-Technologies/physical_atari/HEAD/docs/images/score_detection/battle_zone_lives_box.png -------------------------------------------------------------------------------- /docs/images/score_detection/battle_zone_score_box.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Keen-Technologies/physical_atari/HEAD/docs/images/score_detection/battle_zone_score_box.png -------------------------------------------------------------------------------- /train/score_detector/requirements.txt: -------------------------------------------------------------------------------- 1 | # pip install -r requirements.txt 2 | # assumes top level requirements has been installed 3 | numpy>=1.24.4,<2 4 | albumentations==1.3.1 5 | editdistance==0.8.1 6 | pygame==2.6.1 7 | -------------------------------------------------------------------------------- /configs/controllers/robotroller.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "QinHeng Electronics USB Single Serial", 3 | "vendor_id": "0x1a86", 4 | "product_id": "0x55d3", 5 | "port_name": "/dev/ttyACM0", 6 | "baud_rate": 
2000000 7 | } 8 | -------------------------------------------------------------------------------- /configs/controllers/io_controller.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "MCC USB-1024LS", 3 | "vendor_id": "0x09db", 4 | "product_id": "0x0076", 5 | "pin_to_action_str_map": { 6 | "24": "UP", 7 | "25": "DOWN", 8 | "26": "RIGHT", 9 | "27": "LEFT", 10 | "28": "FIRE" 11 | } 12 | } 13 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # byte-compiled / cache files 2 | __pycache__/ 3 | *.py[cod] 4 | *.pyo 5 | *.pyd 6 | *.so 7 | 8 | # Jupyter/IPython 9 | .ipynb_checkpoints/ 10 | 11 | # Environment files 12 | .env 13 | .venv/ 14 | venv/ 15 | env/ 16 | pip-log.txt 17 | pip-delete-this-directory.txt 18 | 19 | # VS Code / PyCharm / JetBrains 20 | .vscode/ 21 | .idea/ 22 | 23 | # training artifacts 24 | .setup.cfg.json 25 | results/ 26 | local/ 27 | train_data/ 28 | -------------------------------------------------------------------------------- /configs/screen_detection/april_tags.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "dt_apriltags", 3 | "detection_config": { 4 | "family": "tag36h11", 5 | "quad_decimate": 2.0, 6 | "quad_sigma": 0.8, 7 | "refine_edges": 1, 8 | "decode_sharpening": 0.5 9 | }, 10 | "corners": { 11 | "TAG_ID_TOP_LEFT": 2, 12 | "TAG_ID_TOP_RIGHT": 3, 13 | "TAG_ID_BOTTOM_RIGHT": 0, 14 | "TAG_ID_BOTTOM_LEFT": 1 15 | } 16 | } 17 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | ale-py==0.10.1 2 | dearpygui==2.0.0 3 | dt_apriltags==3.1.7 4 | dynamixel_sdk==3.7.31 5 | libusb1==3.3.1 6 | msgpack==1.1.0 7 | numpy<2.0 8 | 9 | # use apt version of cv 10 | 
#opencv-python==4.11.0.86 11 | 12 | psutil==7.0.0 13 | # v1.8.0 has an api breaking change; refactor before updating 14 | pynput==1.7.8 15 | 16 | torch==2.7.0 17 | torchvision==0.22.0 18 | torchaudio==2.7.0 19 | 20 | # https://pypi.org/project/uldaq/ 21 | uldaq==1.2.3 22 | Xlib==0.21 23 | -------------------------------------------------------------------------------- /configs/screen_detection/fixed.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "fixed", 3 | "detection_config": { 4 | "screen_rect": [ 5 | [ 6 | 116.0, 7 | 25.0 8 | ], 9 | [ 10 | 523.0, 11 | 26.0 12 | ], 13 | [ 14 | 510.0, 15 | 311.0 16 | ], 17 | [ 18 | 111.0, 19 | 298.0 20 | ] 21 | ] 22 | } 23 | } 24 | -------------------------------------------------------------------------------- /.editorconfig: -------------------------------------------------------------------------------- 1 | # .editorconfig 2 | root = true 3 | 4 | [*] 5 | indent_style = space 6 | indent_size = 4 7 | end_of_line = lf 8 | charset = utf-8 9 | trim_trailing_whitespace = true 10 | insert_final_newline = true 11 | 12 | [*.py] 13 | max_line_length = 120 14 | 15 | [*.json] 16 | indent_size = 4 17 | 18 | [*.yml] 19 | indent_size = 4 20 | 21 | [*.yaml] 22 | indent_size = 4 23 | 24 | [*.toml] 25 | indent_size = 4 26 | 27 | [*.sh] 28 | indent_size = 4 29 | 30 | [*.md] 31 | trim_trailing_whitespace = false # allows trailing whitespace in markdown formatting 32 | -------------------------------------------------------------------------------- /configs/games/atlantis.json: -------------------------------------------------------------------------------- 1 | { 2 | "game_config": { 3 | "name": "atlantis", 4 | "lives": 6, 5 | "minimal_actions": [ 6 | "NOOP", 7 | "FIRE", 8 | "RIGHTFIRE", 9 | "LEFTFIRE" 10 | ] 11 | }, 12 | "score_config": { 13 | "checkpoint": "assets/models/atlantis_score.pt", 14 | "score_crop_info": { 15 | "x": 56, 16 | "y": 187, 17 | "w": 48, 18 | "h": 12, 19 | 
"num_digits": 6 20 | }, 21 | "lives_crop_info": {}, 22 | "valid_jumps": [ 23 | 0, 24 | 100, 25 | 200, 26 | 400, 27 | 500, 28 | 1000, 29 | 1500, 30 | 2000, 31 | 2500, 32 | 3000, 33 | 3500 34 | ] 35 | } 36 | } 37 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | # https://help.github.com/en/github/using-git/configuring-git-to-handle-line-endings 2 | 3 | # Force Unix-style (LF) line endings for text files 4 | * text=auto 5 | 6 | *.py text eol=lf 7 | *.ipynb text eol=lf 8 | 9 | # C/C++ source and headers 10 | *.c text eol=lf 11 | *.cpp text eol=lf 12 | *.h text eol=lf 13 | *.hpp text eol=lf 14 | 15 | # Shell and Make 16 | *.sh text eol=lf 17 | Makefile text eol=lf 18 | 19 | # Config/data files 20 | *.json text eol=lf 21 | *.yml text eol=lf 22 | *.yaml text eol=lf 23 | *.md text eol=lf 24 | *.toml text eol=lf 25 | *.log text eol=lf 26 | 27 | # Binary files 28 | *.pt binary 29 | *.model binary 30 | *.pkl binary 31 | *.png binary 32 | *.jpg binary 33 | *.gif binary 34 | *.pdf binary 35 | *.zip binary 36 | *.so binary 37 | *.dll binary 38 | *.exe binary 39 | 40 | # Special diff drivers 41 | *.json diff=json 42 | *.ipynb diff=jupyternotebook 43 | -------------------------------------------------------------------------------- /configs/games/qbert.json: -------------------------------------------------------------------------------- 1 | { 2 | "game_config": { 3 | "name": "qbert", 4 | "lives": 4, 5 | "minimal_actions": [ 6 | "NOOP", 7 | "FIRE", 8 | "UP", 9 | "RIGHT", 10 | "LEFT", 11 | "DOWN" 12 | ] 13 | }, 14 | "score_config": { 15 | "checkpoint": "assets/models/qbert_score_lives.pt", 16 | "score_crop_info": { 17 | "x": 26, 18 | "y": 5, 19 | "w": 56, 20 | "h": 10, 21 | "num_digits": 6 22 | }, 23 | "lives_crop_info": { 24 | "x": 32, 25 | "y": 14, 26 | "w": 27, 27 | "h": 16, 28 | "num_digits": 4 29 | }, 30 | "valid_jumps": [ 31 | 0, 32 | 25, 33 
| 100, 34 | 300, 35 | 500, 36 | 3100 37 | ] 38 | } 39 | } 40 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | line-length = 120 3 | target-version = ["py39"] 4 | skip-string-normalization = true 5 | 6 | [tool.isort] 7 | profile = "black" 8 | line_length = 120 9 | known_third_party = ["torch", "numpy", "ale_py"] 10 | multi_line_output = 3 11 | include_trailing_comma = true 12 | force_grid_wrap = 0 13 | use_parentheses = true 14 | 15 | [tool.ruff] 16 | line-length = 120 17 | target-version = "py39" 18 | fix = true 19 | 20 | [tool.ruff.lint] 21 | select = ["E", "F", "I", "UP", "C90"] 22 | ignore = [ 23 | "E501", # line too long (handled by black) 24 | "F403", # 'from module import *' used 25 | "F401", # unused imports (during dev) 26 | ] 27 | exclude = ["tests/", "build/", ".venv/"] 28 | 29 | [tool.ruff.lint.mccabe] 30 | max-complexity = 60 31 | 32 | [tool.ruff.format] 33 | indent-style = "space" 34 | quote-style = "preserve" 35 | line-ending = "lf" 36 | 37 | [tool.pyright] 38 | pythonVersion = "3.9" 39 | typeCheckingMode = "basic" 40 | -------------------------------------------------------------------------------- /configs/cameras/camera_kiyo_pro.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "Razer Kiyo Pro", 3 | "camera_config": { 4 | "width": 640, 5 | "height": 480, 6 | "fps": 60, 7 | "buffer_size": 4, 8 | "codec": "YUYV", 9 | "controls": { 10 | "exposure_dynamic_framerate": 0, 11 | "focus_automatic_continuous": 0, 12 | "focus_absolute": 300, 13 | "auto_exposure": 1, 14 | "exposure_time_absolute": 20, 15 | "backlight_compensation": 0, 16 | "white_balance_automatic": 0, 17 | "white_balance_temperature": 7500, 18 | "brightness": 128, 19 | "contrast": 84, 20 | "gain": 10, 21 | "saturation": 84, 22 | "power_line_frequency": 2, 23 | "sharpness": 128, 24 | 
"pan_absolute": 0, 25 | "tilt_absolute": 0, 26 | "zoom_absolute": 100 27 | } 28 | } 29 | } 30 | -------------------------------------------------------------------------------- /configs/games/up_n_down.json: -------------------------------------------------------------------------------- 1 | { 2 | "game_config": { 3 | "name": "up_n_down", 4 | "lives": 5, 5 | "minimal_actions": [ 6 | "NOOP", 7 | "FIRE", 8 | "UP", 9 | "DOWN", 10 | "UPFIRE", 11 | "DOWNFIRE" 12 | ] 13 | }, 14 | "score_config": { 15 | "checkpoint": "assets/models/up_n_down_score_lives.pt", 16 | "score_crop_info": { 17 | "x": 16, 18 | "y": 5, 19 | "w": 48, 20 | "h": 10, 21 | "num_digits": 6 22 | }, 23 | "lives_crop_info": { 24 | "x": 14, 25 | "y": 194, 26 | "w": 32, 27 | "h": 12, 28 | "num_digits": 4 29 | }, 30 | "valid_jumps": [ 31 | 0, 32 | 10, 33 | 100, 34 | 110, 35 | 400, 36 | 410, 37 | 600, 38 | 610, 39 | 650 40 | ] 41 | } 42 | } 43 | -------------------------------------------------------------------------------- /docker_build.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2025 Keen Technologies, Inc. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | 17 | # To run: 18 | # chmod +x docker_build.sh 19 | # 20 | # The image_name is optional, if not provided will default. 
21 | # ./docker_build.sh 22 | # ./docker_build.sh custom_image_name 23 | 24 | DEFAULT_IMAGE_NAME="keen_physical_gpu" 25 | 26 | DOCKER_IMAGE_NAME=${1:-$DEFAULT_IMAGE_NAME} 27 | 28 | # Run the docker build command 29 | docker build -f Dockerfile -t "$DOCKER_IMAGE_NAME" ./ 30 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v5.0.0 4 | hooks: 5 | - id: trailing-whitespace 6 | - id: end-of-file-fixer 7 | - id: check-added-large-files 8 | - id: check-merge-conflict 9 | - id: debug-statements 10 | 11 | - repo: https://github.com/psf/black 12 | rev: 24.4.2 13 | hooks: 14 | - id: black 15 | 16 | - repo: local 17 | hooks: 18 | - id: isort 19 | name: isort 20 | entry: isort 21 | language: system 22 | types: [python] 23 | args: ["--profile", "black"] 24 | pass_filenames: true 25 | 26 | - repo: https://github.com/asottile/pyupgrade 27 | rev: v3.20.0 28 | hooks: 29 | - id: pyupgrade 30 | args: ["--py39-plus"] 31 | 32 | - repo: local 33 | hooks: 34 | - id: ruff 35 | name: ruff 36 | entry: ruff 37 | language: system 38 | types: [python] 39 | args: ["check", "--fix", "--preview"] 40 | pass_filenames: true 41 | -------------------------------------------------------------------------------- /framework/ScoreDetectorConfig.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Keen Technologies, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # specifies the available options for score detection 16 | 17 | NETWORK_MODELS = {"crnn_ctc"} 18 | DIRECTED_MODELS = {} 19 | 20 | ALL_MODELS = sorted(NETWORK_MODELS.union(DIRECTED_MODELS)) 21 | DEFAULT_MODEL = "crnn_ctc" 22 | 23 | 24 | def get_model_type(name: str) -> str: 25 | name = name.lower() 26 | if name in NETWORK_MODELS: 27 | return "network" 28 | elif name in DIRECTED_MODELS: 29 | return "directed" 30 | else: 31 | raise ValueError(f"Invalid model name: {name}") 32 | -------------------------------------------------------------------------------- /env_base.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Keen Technologies, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | 16 | class BaseEnv: 17 | def close(self): 18 | raise NotImplementedError 19 | 20 | def get_name(self): 21 | raise NotImplementedError 22 | 23 | def get_action_set(self): 24 | raise NotImplementedError 25 | 26 | def reset(self): 27 | raise NotImplementedError 28 | 29 | def act(self, action): 30 | raise NotImplementedError 31 | 32 | def game_over(self): 33 | raise NotImplementedError 34 | 35 | def lives(self): 36 | raise NotImplementedError 37 | 38 | def get_observation(self): 39 | raise NotImplementedError 40 | -------------------------------------------------------------------------------- /framework/ControlDevice.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Keen Technologies, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from abc import ABC, abstractmethod 16 | 17 | from framework.Actions import Action 18 | 19 | 20 | class ControlDevice(ABC): 21 | @abstractmethod 22 | def apply_action(self, action: Action, state: int): 23 | # apply an action such as UP, FIRE, etc., optionally with a press/release state. 24 | pass 25 | 26 | @abstractmethod 27 | def shutdown(self): 28 | # shutdown or clean up the control device. 29 | pass 30 | 31 | def get_pins(self) -> list[int]: 32 | # optional: return list of active GPIO-like pins used by the device. 
33 | return [] 34 | -------------------------------------------------------------------------------- /configs/games/defender.json: -------------------------------------------------------------------------------- 1 | { 2 | "game_config": { 3 | "name": "defender", 4 | "lives": 3, 5 | "minimal_actions": [ 6 | "NOOP", 7 | "FIRE", 8 | "UP", 9 | "RIGHT", 10 | "LEFT", 11 | "DOWN", 12 | "UPRIGHT", 13 | "UPLEFT", 14 | "DOWNRIGHT", 15 | "DOWNLEFT", 16 | "UPFIRE", 17 | "RIGHTFIRE", 18 | "LEFTFIRE", 19 | "DOWNFIRE", 20 | "UPRIGHTFIRE", 21 | "UPLEFTFIRE", 22 | "DOWNRIGHTFIRE", 23 | "DOWNLEFTFIRE" 24 | ] 25 | }, 26 | "score_config": { 27 | "checkpoint": "assets/models/defender_score_lives.pt", 28 | "score_crop_info": { 29 | "x": 55, 30 | "y": 176, 31 | "w": 48, 32 | "h": 9, 33 | "num_digits": 6 34 | }, 35 | "lives_crop_info": { 36 | "x": 0, 37 | "y": 182, 38 | "w": 48, 39 | "h": 12, 40 | "num_digits": 3 41 | }, 42 | "valid_jumps": [ 43 | 0, 44 | 50, 45 | 100, 46 | 150, 47 | 1000 48 | ] 49 | } 50 | } 51 | -------------------------------------------------------------------------------- /configs/games/battle_zone.json: -------------------------------------------------------------------------------- 1 | { 2 | "game_config": { 3 | "name": "battle_zone", 4 | "lives": 5, 5 | "minimal_actions": [ 6 | "NOOP", 7 | "FIRE", 8 | "UP", 9 | "RIGHT", 10 | "LEFT", 11 | "DOWN", 12 | "UPRIGHT", 13 | "UPLEFT", 14 | "DOWNRIGHT", 15 | "DOWNLEFT", 16 | "UPFIRE", 17 | "RIGHTFIRE", 18 | "LEFTFIRE", 19 | "DOWNFIRE", 20 | "UPRIGHTFIRE", 21 | "UPLEFTFIRE", 22 | "DOWNRIGHTFIRE", 23 | "DOWNLEFTFIRE" 24 | ] 25 | }, 26 | "score_config": { 27 | "checkpoint": "assets/models/battle_zone_score_lives.pt", 28 | "score_crop_info": { 29 | "x": 64, 30 | "y": 179, 31 | "w": 48, 32 | "h": 10, 33 | "num_digits": 6 34 | }, 35 | "lives_crop_info": { 36 | "x": 64, 37 | "y": 188, 38 | "w": 40, 39 | "h": 10, 40 | "num_digits": 5 41 | }, 42 | "valid_jumps": [ 43 | 0, 44 | 1000, 45 | 2000, 46 | 5000, 47 | 6000 48 | ] 49 | } 50 | 
} 51 | -------------------------------------------------------------------------------- /framework/Actions.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Keen Technologies, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from enum import Enum 16 | 17 | 18 | class Action(Enum): 19 | NOOP = 0 20 | FIRE = 1 21 | UP = 2 22 | RIGHT = 3 23 | LEFT = 4 24 | DOWN = 5 25 | UPRIGHT = 6 26 | UPLEFT = 7 27 | DOWNRIGHT = 8 28 | DOWNLEFT = 9 29 | UPFIRE = 10 30 | RIGHTFIRE = 11 31 | LEFTFIRE = 12 32 | DOWNFIRE = 13 33 | UPRIGHTFIRE = 14 34 | UPLEFTFIRE = 15 35 | DOWNRIGHTFIRE = 16 36 | DOWNLEFTFIRE = 17 37 | 38 | @classmethod 39 | def has_key(cls, key): 40 | return key in cls.__members__ 41 | 42 | @classmethod 43 | def has_value(cls, value): 44 | return value in cls._value2member_map_ 45 | -------------------------------------------------------------------------------- /configs/games/krull.json: -------------------------------------------------------------------------------- 1 | { 2 | "game_config": { 3 | "name": "krull", 4 | "lives": 3, 5 | "minimal_actions": [ 6 | "NOOP", 7 | "FIRE", 8 | "UP", 9 | "RIGHT", 10 | "LEFT", 11 | "DOWN", 12 | "UPRIGHT", 13 | "UPLEFT", 14 | "DOWNRIGHT", 15 | "DOWNLEFT", 16 | "UPFIRE", 17 | "RIGHTFIRE", 18 | "LEFTFIRE", 19 | "DOWNFIRE", 20 | "UPRIGHTFIRE", 21 | "UPLEFTFIRE", 22 | "DOWNRIGHTFIRE", 23 | "DOWNLEFTFIRE" 24 | ] 25 | }, 26 | 
"score_config": { 27 | "checkpoint": "assets/models/krull_score_lives.pt", 28 | "score_crop_info": { 29 | "x": 56, 30 | "y": 175, 31 | "w": 46, 32 | "h": 10, 33 | "num_digits": 6 34 | }, 35 | "lives_crop_info": { 36 | "x": 47, 37 | "y": 186, 38 | "w": 24, 39 | "h": 12, 40 | "num_digits": 2 41 | }, 42 | "valid_jumps": [ 43 | 0, 44 | 10, 45 | 20, 46 | 30, 47 | 50, 48 | 500, 49 | 1000, 50 | 3000 51 | ] 52 | } 53 | } 54 | -------------------------------------------------------------------------------- /framework/ControlDeviceCfg.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Keen Technologies, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
def create_control_device_from_cfg(**kwargs):
    """Instantiate the control device selected by ``model_name`` in *kwargs*.

    The remaining keyword arguments are forwarded to the device constructor.
    Driver modules are imported lazily so only the selected backend is loaded.

    Raises:
        ValueError: if ``model_name`` is missing or names an unsupported model.
    """
    model_name = kwargs.pop("model_name", None)

    if model_name == "MCC USB-1024LS":
        import framework.MCCDAQDevice as DAQDevice

        logger.info(f"Initializing {model_name}")
        return DAQDevice.DAQDevice(model_name, **kwargs)

    if model_name == "QinHeng Electronics USB Single Serial":
        logger.info(f"Initializing {model_name} in position mode")
        import framework.RoboTroller as RoboTroller

        return RoboTroller.RoboTroller(model_name, **kwargs)

    raise ValueError(f"Joystick model not supported: {model_name}")
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # chmod +x scripts/format_code.sh 17 | set -e 18 | 19 | # run from repo root dir 20 | cd "$(dirname "$0")/.." 21 | 22 | # ensure ~/.local/bin is in PATH 23 | export PATH="$HOME/.local/bin:$PATH" 24 | 25 | REQUIRED_TOOLS=("pre-commit" "black" "isort" "ruff") 26 | 27 | # check for missing tools 28 | missing=() 29 | for tool in "${REQUIRED_TOOLS[@]}"; do 30 | if ! command -v "$tool" &>/dev/null; then 31 | missing+=("$tool") 32 | fi 33 | done 34 | 35 | if [ ${#missing[@]} -gt 0 ]; then 36 | echo "Missing tools: ${missing[*]}" 37 | echo "Installing dev requirements from requirements-dev.txt..." 38 | python3 -m pip install --user -r requirements-dev.txt 39 | fi 40 | 41 | echo "Running pre-commit on all files..." 42 | pre-commit run --all-files 43 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing Guidelines 2 | 3 | We welcome small, well-scoped contributions to this project. We're keeping the scope tight and focused. 4 | 5 | Please read the following notes before submitting a pull request or issue. 
6 | 7 | --- 8 | 9 | ## Issues 10 | 11 | - Tag your issue with one of: `bug`, `feature request`, or `question` 12 | - Keep issues concise and reproducible when possible 13 | - Maintainer response times may be slow depending on bandwidth 14 | 15 | --- 16 | 17 | ## Pull Requests 18 | 19 | We accept: 20 | 21 | - Bug fixes 22 | - Small, testable improvements 23 | - Minor documentation updates 24 | 25 | We do not accept: 26 | 27 | - API redesigns 28 | - Major refactors 29 | - Subsystem rewrites 30 | 31 | All contributions are reviewed via GitHub pull requests. 32 | See [GitHub Help](https://help.github.com/articles/about-pull-requests/) if you're unfamiliar with the process. 33 | 34 | --- 35 | 36 | ## Code Review 37 | 38 | - All submissions require review 39 | - Pre-commit checks (formatting, linting, etc.) must pass before merge 40 | - Run `scripts/format_code.sh` before committing. It checks formatting and installs required tools if missing. 41 | - You can install dev tools manually with: `pip install -r requirements-dev.txt` 42 | - Include or update tests if applicable 43 | 44 | --- 45 | 46 | ## Contributor License Agreement (CLA) 47 | 48 | We do **not** require a CLA. Contributions must be easy to verify and are accepted at our discretion. 
49 | -------------------------------------------------------------------------------- /configs/games/ms_pacman.json: -------------------------------------------------------------------------------- 1 | { 2 | "game_config": { 3 | "name": "ms_pacman", 4 | "lives": 3, 5 | "minimal_actions": [ 6 | "NOOP", 7 | "UP", 8 | "RIGHT", 9 | "LEFT", 10 | "DOWN", 11 | "UPRIGHT", 12 | "UPLEFT", 13 | "DOWNRIGHT", 14 | "DOWNLEFT" 15 | ] 16 | }, 17 | "score_config": { 18 | "checkpoint": "assets/models/ms_pacman_score_lives.pt", 19 | "score_crop_info": { 20 | "x": 55, 21 | "y": 184, 22 | "w": 48, 23 | "h": 12, 24 | "num_digits": 6 25 | }, 26 | "lives_crop_info": { 27 | "x": 11, 28 | "y": 171, 29 | "w": 25, 30 | "h": 14, 31 | "num_digits": 2 32 | }, 33 | "valid_jumps": [ 34 | 0, 35 | 10, 36 | 50, 37 | 60, 38 | 100, 39 | 110, 40 | 200, 41 | 210, 42 | 300, 43 | 310, 44 | 400, 45 | 410, 46 | 500, 47 | 510, 48 | 700, 49 | 710, 50 | 800, 51 | 810, 52 | 1000, 53 | 1010, 54 | 1600, 55 | 1610, 56 | 2000, 57 | 2010, 58 | 5000, 59 | 5010 60 | ] 61 | } 62 | } 63 | -------------------------------------------------------------------------------- /scripts/lambda-hold.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2025 Keen Technologies, Inc. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | # To run: 17 | # lambda-hold.sh — prevent updates to Lambda Stack packages 18 | # To unhold by package: sudo apt-mark unhold 19 | # or: 20 | # to unhold all: apt-mark showhold | xargs sudo apt-mark unhold 21 | 22 | set -e 23 | 24 | echo "Locking Lambda Stack packages..." 25 | 26 | # core packages to hold 27 | LAMBDA_PACKAGES=( 28 | lambda-stack-cuda 29 | lambda-stack 30 | libcudnn8 31 | libcudnn8-dev 32 | libnccl2 33 | libnccl-dev 34 | nvidia-cuda-toolkit 35 | nvidia-driver 36 | nvidia-headless 37 | nvidia-fabricmanager 38 | cuda 39 | ) 40 | 41 | # go through packages, and if installed, hold 42 | for pkg in "${LAMBDA_PACKAGES[@]}"; do 43 | MATCHED_PKGS=$(dpkg -l | awk '{print $2}' | grep -E "^${pkg}.*$" || true) 44 | for matched in $MATCHED_PKGS; do 45 | echo "-- Holding $matched" 46 | sudo apt-mark hold "$matched" 47 | done 48 | done 49 | 50 | echo "" 51 | echo "Hold complete. The following packages are currently on hold:" 52 | apt-mark showhold 53 | -------------------------------------------------------------------------------- /docker_run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2025 Keen Technologies, Inc. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | 17 | # To run: 18 | # chmod +x docker_run.sh 19 | # 20 | # Both the code folder and image_name are optional, if not provided the code folder will default to the current directory. 
21 | #./docker_run.sh /path/to/code 22 | #./docker_run.sh /path/to/code custom_image_name 23 | # 24 | # Inside the container, /path/to/code will be mounted as /workspaces/code 25 | 26 | DEFAULT_IMAGE_NAME="keen_physical_gpu" 27 | 28 | CODE_FOLDER_PATH=${1:-$(pwd)} 29 | DOCKER_IMAGE_NAME=${2:-$DEFAULT_IMAGE_NAME} 30 | 31 | # Get the last directory in the path for the mount point (basename of the code folder) 32 | MOUNT_POINT="/workspaces/$(basename "$CODE_FOLDER_PATH")" 33 | 34 | # Run the docker container with the specified image name and mount the code folder to the container 35 | docker run --rm -it \ 36 | -v "$CODE_FOLDER_PATH:$MOUNT_POINT" \ 37 | -w "$MOUNT_POINT" \ 38 | --mount source=/dev,target=/dev,type=bind \ 39 | -e DISPLAY=$DISPLAY \ 40 | -e XAUTHORITY=/tmp/.Xauthority \ 41 | --gpus=all \ 42 | --privileged \ 43 | --network=host \ 44 | --ipc=host \ 45 | --ulimit=memlock=-1 \ 46 | --ulimit=stack=67108864 \ 47 | --cap-add=SYS_PTRACE \ 48 | --volume=/tmp/.X11-unix:/tmp/.X11-unix \ 49 | --volume=$XAUTHORITY:/tmp/.Xauthority \ 50 | "$DOCKER_IMAGE_NAME" 51 | -------------------------------------------------------------------------------- /train/score_detector/checkpoint_viewer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Keen Technologies, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
def load_checkpoint(path):
    """Load a score-detector checkpoint and print a human-readable summary.

    Args:
        path: Path to a ``.pt`` checkpoint file produced by the trainer.

    Returns:
        The raw checkpoint dictionary, so callers can inspect it
        programmatically instead of re-loading the file. (Previously the
        loaded data was discarded after printing.)
    """
    # NOTE: weights_only=False runs the full pickle machinery; only open
    # checkpoints from trusted sources.
    checkpoint = torch.load(path, map_location="cpu", weights_only=False)
    print(f"\n[Checkpoint Loaded] {path}")

    # Config sections may be absent in older checkpoints; default to {}.
    print("\n[Model Config]")
    pprint.pprint(checkpoint.get("model_config", {}))

    print("\n[Train Config]")
    pprint.pprint(checkpoint.get("train_config", {}))

    print("\n[Game Config]")
    pprint.pprint(checkpoint.get("game_config", {}))

    if "train_summary" in checkpoint:
        print("\n[Train Summary]")
        history = checkpoint["train_summary"]
        print(f" Epochs Trained: {history['epochs']}")
        print(f" Final Train Loss: {history['train_losses'][-1]:.4f}")
        print(f" Final Test Loss: {history['test_losses'][-1]:.4f}")
        print(f" Final CER: {history['cer'][-1]:.2f}%")
        print(f" Final Accuracy: {history['accuracy'][-1] * 100:.2f}%")

    print("\n[Available Keys in Checkpoint]")
    print(list(checkpoint.keys()))

    return checkpoint
# chmod +x nvidia-persistence.sh
# sudo ./nvidia-persistence.sh

set -e

echo "Starting NVIDIA persistence mode setup..."

# check if nvidia-smi exists
if ! command -v nvidia-smi &> /dev/null; then
    echo "[ERROR] nvidia-smi not found. Please install NVIDIA drivers first."
    exit 1
fi

# verify script is run as root
if [[ $EUID -ne 0 ]]; then
    echo "[ERROR] Script must be run as root (e.g., sudo $0)"
    exit 1
fi

echo "Enabling persistence mode now..."
nvidia-smi -pm 1

# create systemd service to enable persistence mode on boot
SERVICE_PATH="/etc/systemd/system/nvidia-persistence.service"

echo "Creating systemd service at $SERVICE_PATH..."

# FIX: 'cat < "$SERVICE_PATH"' read FROM the service file; writing the unit
# requires a heredoc redirected INTO it.
cat <<EOF > "$SERVICE_PATH"
[Unit]
Description=NVIDIA Persistence Mode
After=multi-user.target

[Service]
Type=oneshot
ExecStart=/usr/bin/nvidia-smi -pm 1
RemainAfterExit=true

[Install]
WantedBy=multi-user.target
EOF

# reload systemd and enable the service
echo "Reloading systemd daemon and enabling nvidia-persistence.service..."
systemctl daemon-reload
systemctl enable nvidia-persistence.service
systemctl start nvidia-persistence.service

# verify status
echo "Verifying persistence mode status..."
nvidia-smi -q | grep "Persistence Mode"

echo "NVIDIA persistence mode enabled and service installed."
import json
import os
import sys
import time

# Make the repo root importable when this file is run directly from tests/.
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../")))
from framework.ControlDeviceCfg import create_control_device_from_cfg
from framework.Keyboard import Keyboard

# Use the keyboard to send command to the device communicating actions to the Atari.


def main(args):
    """Drive the configured control device interactively from the keyboard.

    Loads the device config JSON, constructs the device, then polls the
    keyboard until the quit key is pressed or the process is interrupted.
    """
    with open(args.device_config) as kf:
        device_data = kf.read()

    device_data = json.loads(device_data)
    device = create_control_device_from_cfg(**device_data)
    keyboard = Keyboard(device)

    try:
        while True:
            should_quit, _ = keyboard.update()
            if should_quit:
                print("Exiting...")
                break
            # Short sleep keeps the poll loop from pinning a CPU core.
            time.sleep(0.001)

    except KeyboardInterrupt:
        pass

    finally:
        # Always release the device, even on Ctrl-C or quit.
        keyboard.shutdown()


def get_argument_parser():
    """Build the CLI parser; the default config targets the RoboTroller."""
    from argparse import ArgumentParser

    parser = ArgumentParser(description="keyboard_test.py arguments")
    # parser.add_argument('--device_config', type=str, default="configs/controllers/io_controller.json")
    parser.add_argument('--device_config', type=str, default="configs/controllers/robotroller.json")
    return parser


if __name__ == '__main__':
    arg_parser = get_argument_parser()
    args = arg_parser.parse_args()

    main(args)
    def __init__(self, method_name, screen_rect=None, mspacman_rect=None):
        """Fixed (pre-calibrated) screen locator.

        Args:
            method_name: Must be "fixed"; guards against config mismatch.
            screen_rect: Corners of the full 160x210 ALE screen in camera
                pixels, clockwise from the upper left.
            mspacman_rect: Corners of the visible Ms. Pac-Man play box in
                camera pixels, clockwise from the upper left. When given it
                overrides screen_rect: the full screen corners are derived
                from it via a perspective transform.

        Raises:
            ValueError: if neither rectangle is provided.
        """
        assert method_name == "fixed"
        self.last_detected_tags = {}
        # A mspacman_rect replaces any screen_rect
        if mspacman_rect:  # mspacman_rect in camera pixels, clockwise from upperleft
            # The ALE coordinates of the visible box in mspacman
            mx, my = 160, 171
            mspacman_rect_ale = np.array([(0, 0), (mx, 0), (mx, my), (0, my)], dtype=np.float32)
            mspacman_rect = np.array(mspacman_rect, dtype=np.float32)
            # Homography mapping ALE coordinates -> camera pixels, fitted on
            # the four visible-box corner correspondences.
            transform = cv2.getPerspectiveTransform(mspacman_rect_ale, mspacman_rect)

            # Project the full ALE screen corners (homogeneous coordinates)
            # through that homography to recover the whole screen rectangle.
            sx, sy = 160, 210
            source_rect = np.array([(0, 0, 1), (sx, 0, 1), (sx, sy, 1), (0, sy, 1)], dtype=np.float32)
            screen_rect = []
            for point in source_rect:
                q = np.matmul(transform, point)
                # Perspective divide back to 2-D pixel coordinates.
                screen_rect.append((q[0] / q[2], q[1] / q[2]))

            self.screen_rect = np.array(screen_rect, dtype=np.float32)

        elif screen_rect is not None:
            self.screen_rect = np.array(screen_rect, dtype=np.float32)
        else:
            raise ValueError("ScreenDetectorFixed requires either mspacman_rect or screen_rect")
class Agent:
    """Baseline agent that selects uniformly random actions.

    Serves as a lower-bound reference and as a smoke test for the rest of
    the pipeline; observations and rewards are accepted but ignored.
    """

    def __init__(self, data_dir, seed, num_actions, total_frames, **kwargs):
        """Create the agent.

        Args:
            data_dir: Accepted for interface compatibility; unused here.
            seed: Seed for the action RNG (reproducible action sequences).
            num_actions: Size of the (possibly reduced) action set.
            total_frames: Total frame budget; stored but unused by this agent.
            **kwargs: Optional overrides for any default attribute below;
                unknown keys are logged and ignored.
        """
        # defaults that might be overridden by explicit experiment runs
        self.num_actions = num_actions  # many games can use a reduced action set for faster learning
        self.gpu = -1
        self.total_frames = total_frames
        self.frame_skip = 4
        self.seed = seed
        self.training_model = None
        self.ring_buffer_size = 0
        self.train_losses = 0
        self.use_model = 0

        # dynamically override configuration.
        # FIX: the previous assert-based check disappeared under `python -O`,
        # which would silently set arbitrary unknown attributes; validate
        # explicitly instead.
        for key, value in kwargs.items():
            if not hasattr(self, key):
                logger.error(f"agent_random: Request to set unknown property: {key}")
                continue
            setattr(self, key, value)

        # variables used by policy
        self.step = 0
        self.rng = np.random.default_rng(self.seed)
        self.selected_action_index = 0

    # --------------------------------
    # Returns the selected action index
    # --------------------------------
    def frame(self, observation_rgb8, reward, end_of_episode):
        """Advance one frame and return the action index to apply.

        A fresh random action is drawn every ``frame_skip`` frames and
        repeated in between; the observation, reward and episode flag are
        accepted only for interface compatibility.
        """
        if self.step % self.frame_skip == 0:
            self.selected_action_index = self.rng.integers(self.num_actions)
        self.step += 1
        return self.selected_action_index

    def save_model(self, filename):
        """No-op: the random agent has no model to persist."""
        pass
class Joystick:
    """Forwards Action values to a ControlDevice, optionally via a worker thread.

    In threaded mode actions are queued (bounded at 2 entries) and applied by
    a daemon thread; when the queue is full the oldest entry is dropped so
    the most recent action always gets through.
    """

    def __init__(self, device: ControlDevice, threaded: bool = False):
        self.device = device
        assert self.device is not None
        self.threaded = threaded
        self.running = True
        self.action_queue = queue.Queue(maxsize=2)
        self.thread = None

        if self.threaded:
            self.thread = threading.Thread(target=self._process_action_queue, daemon=True)
            self.thread.start()

    def shutdown(self):
        """Stop the worker (if any), shut down the device, and drop queued actions."""
        self.running = False
        if self.threaded and self.thread is not None:
            self.thread.join()
        if self.device:
            self.device.shutdown()
            self.device = None
        # Clear any leftover entries under the queue's own lock.
        with self.action_queue.mutex:
            self.action_queue.queue.clear()

    def _process_action_queue(self):
        # Worker loop: apply queued actions until shutdown() flips running.
        while self.running:
            try:
                self.device.apply_action(self.action_queue.get(timeout=0.1), 1)
            except queue.Empty:
                pass
            time.sleep(0.001)

    def apply_action(self, action: Action) -> None:
        """Apply *action* immediately (unthreaded) or hand it to the worker."""
        if not self.threaded:
            self.device.apply_action(action, 1)
            return
        if self.action_queue.full():
            # Drop the stalest entry to make room for the newest action.
            try:
                _ = self.action_queue.get_nowait()
            except queue.Empty:
                pass
        self.action_queue.put(action)

    def __repr__(self):
        # NOTE(review): the f-string appears truncated in the source;
        # preserved byte-for-byte to keep behavior identical.
        return f""
# chmod +x cpu-governor.sh
# sudo ./cpu-governor.sh

# Set CPU governor to performance.

set -e

GOV="performance"

echo "Starting CPU governor setup..."

# verify script is run as root
if [[ $EUID -ne 0 ]]; then
    echo "[ERROR] Script must be run as root (e.g., sudo $0)"
    exit 1
fi

# install required tools
echo "Installing cpupower tools..."
apt update
apt install -y linux-tools-common linux-tools-$(uname -r)

# disable interfering services
if systemctl is-active --quiet power-profiles-daemon.service; then
    echo "Disabling power-profiles-daemon to avoid conflicts..."
    systemctl mask power-profiles-daemon.service
    systemctl stop power-profiles-daemon.service
fi

# Set governor temporarily on all CPU cores
echo "Setting CPU governor to $GOV temporarily..."
for gov_file in /sys/devices/system/cpu/cpu[0-9]*/cpufreq/scaling_governor; do
    if [ -w "$gov_file" ]; then
        echo "$GOV" > "$gov_file" || echo "[WARN] Failed to set $gov_file"
    else
        echo "[WARN] No write permission for $gov_file"
    fi
done

# create cpupower systemd service for persistence
echo "Creating cpupower systemd service..."
# FIX: 'cat < file' read FROM the unit file instead of writing it; the unit
# body must be fed through a heredoc redirected INTO the destination.
cat <<EOF > /etc/systemd/system/cpupower.service
[Unit]
Description=Set CPU frequency scaling governor
After=multi-user.target

[Service]
Type=oneshot
RemainAfterExit=true
ExecStart=/usr/bin/cpupower frequency-set -g $GOV

[Install]
WantedBy=multi-user.target
EOF

# reload systemd, enable and start the service
echo "Enabling and starting cpupower.service..."
systemctl daemon-reexec
systemctl daemon-reload
systemctl enable cpupower.service
systemctl start cpupower.service

# confirm status
echo "Verifying current governor:"
cpupower frequency-info | grep "governor"

echo "Success. CPU governor set to '$GOV' and service installed for persistence."
22 | 23 | # This must be run on host after every driver update. 24 | # Verify power limits are no longer limited with 'nvidia-smi --query-gpu=power.draw,power.limit --format=csv' 25 | 26 | set -e 27 | 28 | echo "Starting NVIDIA Dynamic Boost setup..." 29 | 30 | # get the major driver version number (e.g., "535" from "535.54.03") 31 | DRIVER_VER=$(nvidia-smi --query-gpu=driver_version --format=csv,noheader | head -1 | cut -d '.' -f1) 32 | 33 | if [[ -z "$DRIVER_VER" ]]; then 34 | echo "[ERROR] Could not determine NVIDIA driver version. Verify driver installed with nvidia-smi." 35 | exit 1 36 | fi 37 | 38 | echo "Detected NVIDIA driver version: $DRIVER_VER" 39 | 40 | # paths to the source files in the driver docs 41 | DBUS_CONF_SRC="/usr/share/doc/nvidia-driver-$DRIVER_VER/nvidia-dbus.conf" 42 | POWERD_SERVICE_SRC="/usr/share/doc/nvidia-kernel-common-$DRIVER_VER/nvidia-powerd.service" 43 | 44 | # Verify source files exist 45 | if [[ ! -f "$DBUS_CONF_SRC" ]]; then 46 | echo "[ERROR] $DBUS_CONF_SRC not found." 47 | exit 1 48 | fi 49 | 50 | if [[ ! -f "$POWERD_SERVICE_SRC" ]]; then 51 | echo "[ERROR] $POWERD_SERVICE_SRC not found." 52 | exit 1 53 | fi 54 | 55 | DBUS_CONF_DST="/etc/dbus-1/system.d/nvidia-dbus.conf" 56 | POWERD_SERVICE_DST="/etc/systemd/system/nvidia-powerd.service" 57 | 58 | echo "Copying $DBUS_CONF_SRC to $DBUS_CONF_DST" 59 | cp "$DBUS_CONF_SRC" "$DBUS_CONF_DST" 60 | 61 | echo "Copying $POWERD_SERVICE_SRC to $POWERD_SERVICE_DST" 62 | cp "$POWERD_SERVICE_SRC" "$POWERD_SERVICE_DST" 63 | 64 | echo "Reloading systemd daemon..." 65 | systemctl daemon-reload 66 | 67 | echo "Enabling and starting nvidia-powerd.service..." 
FROM nvcr.io/nvidia/pytorch:24.11-py3

# https://github.com/openucx/ucc/issues/476 - 'ImportError: /opt/hpcx/ucx/lib/libucs.so.0: undefined symbol: ucm_set_global_opts'
# Workaround: Error happens because compiler picks up libucm required by libucs from a different directory,
# i.e. libucs is taken from $UCX_HOME while libucm comes from HPCX installed in /opt. Proper container environment
# config should resolve the issue.
ENV LD_LIBRARY_PATH="/opt/hpcx/ucx/lib:$LD_LIBRARY_PATH"

# allow more efficient management of memory segments
ENV PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True

WORKDIR /workspaces

COPY requirements.txt /workspaces/requirements.txt

# NOTE: Added libgl1 for cv2 'ImportError: libGL.so.1: cannot open shared object file: No such file or directory'
# libxkbfile1 is needed for nsys-ui
RUN apt-get update \
    && DEBIAN_FRONTEND=noninteractive \
    apt-get install --no-install-recommends --assume-yes \
    build-essential make gcc g++ gdb strace valgrind git clang-format \
    xauth libgl1 ffmpeg v4l-utils udev usbutils libusb-1.0-0-dev wget \
    x11-utils x11-xserver-utils \
    python3-opencv libxkbfile1 nvtop tlp

# mcc daq https://github.com/mccdaq/uldaq?_ga=2.85905500.479671302.1736441555-1860231292.1736441555
# Required dependency for python library 'uldaq'
# Single RUN keeps download/extract/build in one image layer instead of four.
RUN mkdir mccdaq \
    && wget -N https://github.com/mccdaq/uldaq/releases/download/v1.2.1/libuldaq-1.2.1.tar.bz2 --directory-prefix=mccdaq/ \
    && tar -xvjf mccdaq/libuldaq-1.2.1.tar.bz2 -C mccdaq/ \
    && cd mccdaq/libuldaq-1.2.1 && ./configure && make && make install

RUN python -m pip install --upgrade pip

# ngc.nvidia pulls in opencv=4.7.0 which is not compatible with running graphical
# user interfaces within a docker container. An appropriate version of opencv is
# installed in 'requirements.txt'. However, if the container library is not explicitly removed
# imshow will fail with 'error: (-2:Unspecified error) The function is not implemented. Rebuild the library with ...'
RUN pip uninstall --yes opencv

RUN pip install -r requirements.txt

#----------------------
# profiling tools
# https://docs.nvidia.com/nsight-systems/InstallationGuide/index.html
#----------------------

ARG NSYS_URL=https://developer.nvidia.com/downloads/assets/tools/secure/nsight-systems/2025_2/
# cli-only
#ARG NSYS_PKG=NsightSystems-linux-cli-public-2025.2.1.130-3569061.deb
# cli + ui
ARG NSYS_PKG=nsight-systems-2025.2.1_2025.2.1.130-1_amd64.deb
#RUN apt-get update && apt install -y libglib2.0-0 libxkbfile1
#RUN wget ${NSYS_URL}${NSYS_PKG} && dpkg -i $NSYS_PKG && rm $NSYS_PKG
RUN wget ${NSYS_URL}${NSYS_PKG} && apt install -y ./$NSYS_PKG && rm $NSYS_PKG
import json
import os
import sys
import time

import cv2

# Allow running directly from the tests/ directory without installing the package.
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../")))
from framework.CameraDevice_v4l2 import CameraDevice_v4l2 as CameraDevice


def main(args):
    """Measure camera capture throughput.

    Reads the camera config JSON, opens the device, times a 600-frame warmup
    burst, then reports the effective capture FPS roughly once per second
    until Escape is pressed or the process is interrupted with Ctrl+C.
    """
    # Parse the config directly from the file stream instead of read()+loads.
    with open(args.camera_config, encoding="utf-8") as cf:
        camera_data = json.load(cf)

    camera_name = camera_data["model_name"]
    camera_config = camera_data["camera_config"]
    camera = CameraDevice(camera_name, **camera_config)

    try:
        print("Starting performance test...")
        start_time = time.time()
        for _ in range(600):
            _ = camera.get_frame()
        print(f"Total_time to read 600 frames={(time.time() - start_time)}s")

        target_fps = camera.get_fps()
        frames = 0
        total_time = 0.0
        while True:
            start_time = time.time()
            _frame_data = camera.get_frame()
            total_time += time.time() - start_time
            frames += 1

            # After one nominal second worth of frames, report the measured rate.
            if frames == target_fps:
                print(f"Camera FPS={(target_fps / total_time):.2f}")
                frames = 0
                total_time = 0.0

            # NOTE(review): with no cv2 window open, pollKey may never observe a
            # keypress — Ctrl+C (KeyboardInterrupt below) is the reliable exit.
            keycode = cv2.pollKey()
            if keycode == 27:  # Escape Key
                break

    except KeyboardInterrupt:
        pass

    finally:
        # Always release the device, even on interrupt.
        camera.shutdown()


def get_argument_parser():
    """Build the CLI parser for this test script."""
    from argparse import ArgumentParser

    parser = ArgumentParser(description="camera_test.py arguments")
    parser.add_argument('--camera_config', type=str, default="configs/cameras/camera_kiyo_pro.json")
    return parser


if __name__ == '__main__':
    arg_parser = get_argument_parser()
    args = arg_parser.parse_args()

    main(args)
arg_parser.parse_args() 84 | 85 | main(args) 86 | -------------------------------------------------------------------------------- /scripts/performance/power-profile.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2025 Keen Technologies, Inc. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # chmod +x power-profile.sh 17 | # sudo ./power-profile.sh 18 | 19 | # Set platform profile to performance using ACPI interface. 20 | set -e 21 | 22 | PROFILE_PATH="/sys/firmware/acpi/platform_profile" 23 | PROFILE="performance" 24 | 25 | echo "Starting power profile setup..." 26 | 27 | # verify script is run as root 28 | if [[ $EUID -ne 0 ]]; then 29 | echo "[ERROR] Script must be run as root (e.g., sudo $0)" 30 | exit 1 31 | fi 32 | 33 | # check if the platform profile interface exists 34 | if [[ ! -w "$PROFILE_PATH" ]]; then 35 | echo "[ERROR] ACPI platform profile interface not available or not writable: $PROFILE_PATH" 36 | echo "This system may not support ACPI performance profiles." 37 | exit 1 38 | fi 39 | 40 | # disable interfering services 41 | if systemctl is-active --quiet power-profiles-daemon.service; then 42 | echo "Disabling power-profiles-daemon to avoid conflicts..." 
43 | systemctl mask power-profiles-daemon.service 44 | systemctl stop power-profiles-daemon.service 45 | fi 46 | 47 | # apply the setting immediately 48 | echo "Setting platform profile to '$PROFILE'..." 49 | echo "$PROFILE" > "$PROFILE_PATH" 50 | 51 | # create a systemd service to enforce setting at boot 52 | echo "Creating systemd service for platform profile persistence..." 53 | 54 | cat < /etc/systemd/system/acpi-performance-profile.service 55 | [Unit] 56 | Description=Set ACPI platform profile to performance 57 | After=multi-user.target 58 | 59 | [Service] 60 | Type=oneshot 61 | ExecStart=/bin/bash -c 'echo performance > /sys/firmware/acpi/platform_profile' 62 | RemainAfterExit=true 63 | 64 | [Install] 65 | WantedBy=multi-user.target 66 | EOF 67 | 68 | # enable and start the service 69 | echo "Enabling and starting acpi-performance-profile.service..." 70 | systemctl daemon-reexec 71 | systemctl daemon-reload 72 | systemctl enable acpi-performance-profile.service 73 | systemctl start acpi-performance-profile.service 74 | 75 | # confirm result 76 | current=$(cat "$PROFILE_PATH") 77 | echo "Current platform profile: $current" 78 | 79 | if [[ "$current" == "$PROFILE" ]]; then 80 | echo "Success. Platform profile set to '$PROFILE' and will persist on reboot." 81 | else 82 | echo "[WARN] Failed to apply platform profile. Current value: $current" 83 | fi 84 | -------------------------------------------------------------------------------- /framework/HIDDevice.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Keen Technologies, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
class HIDDevice:
    """Minimal HID output device backed by libusb (python-libusb1).

    Opens the device by vendor/product ID, claims interface 0, and starts a
    daemon thread that pumps libusb events so asynchronous interrupt writes
    can complete in the background.
    """

    def __init__(self, vendor_id_str: str, product_id_str: str, endpoint_out: int = 0x01):
        """Open and claim the device.

        vendor_id_str / product_id_str: hexadecimal ID strings (parsed base-16).
        endpoint_out: interrupt OUT endpoint address used for all writes.
        Raises ValueError if no matching device is attached.
        """
        self.vendor_id = int(vendor_id_str, 16)
        self.product_id = int(product_id_str, 16)
        self.endpoint_out = endpoint_out

        self.context = usb1.USBContext()
        self.handle = self.context.openByVendorIDAndProductID(self.vendor_id, self.product_id)
        if self.handle is None:
            raise ValueError(f"Device {vendor_id_str}:{product_id_str} not found.")
        self.handle.claimInterface(0)

        # Background event loop required for write_async transfers to complete.
        self._running = True
        self._event_thread = threading.Thread(target=self._handle_events, daemon=True)
        self._event_thread.start()

    def shutdown(self):
        """Stop the event thread, release the interface, and close the context.

        NOTE(review): join() uses a 1 s timeout — if the event thread is still
        inside a libusb call when context.close() runs, there is a small race;
        confirm the timeout always suffices in practice.
        """
        self._running = False
        self._event_thread.join(timeout=1.0)
        try:
            self.handle.releaseInterface(0)
        except usb1.USBError as e:
            logger.warning(f"shutdown: failed to release interface: {e}")
        self.context.close()

    def write_sync(self, data: bytes, timeout: int = 1000):
        # blocking write to the HID device
        # Errors are logged, not raised: a dropped report is treated as best-effort.
        try:
            self.handle.interruptWrite(self.endpoint_out, data, timeout)
        except usb1.USBError as e:
            logger.warning(f"write_sync: USB write failed: {e}")

    def write_async(self, data: bytes):
        # submit non-blocking write to HID device using interrupt endpoint
        # The callback runs on the _handle_events thread and must close the
        # transfer to free its resources, success or failure.
        def _async_callback(transfer):
            status = transfer.getStatus()
            if status != usb1.TRANSFER_COMPLETED:
                logger.warning(f"write_async: usb transfer failed with status {status}")
            transfer.close()

        transfer = self.handle.getTransfer()
        transfer.setInterrupt(endpoint=self.endpoint_out, data=data, callback=_async_callback, timeout=1000)
        try:
            transfer.submit()
        except usb1.USBError as e:
            # Submission failed, so the callback will never fire; close here.
            logger.warning(f"write_async: async transfer failed to submit: {e}")
            transfer.close()

    def _handle_events(self):
        # Handle async USB events in a background thread
        # Short timeout keeps the loop responsive to the _running flag.
        while self._running:
            try:
                self.context.handleEventsTimeout(tv=0.01)
            except usb1.USBErrorInterrupted:
                continue  # expected on shutdown
            except Exception as e:
                logger.warning(f"_handle_events: error: {e}")
import logging
import os
import re
from enum import Enum
from threading import local

# thread-local storage for frame count
_frame_count_storage = local()


def get_frame_count():
    """Return the calling thread's frame count, or -1 if never set."""
    # -1 as a sentinel value for unset frame_count
    return getattr(_frame_count_storage, 'frame_count', -1)


def set_frame_count(frame_count):
    """Record the current frame count for the calling thread."""
    _frame_count_storage.frame_count = frame_count


class AnsiColor(str, Enum):
    """ANSI terminal escape sequences used to colorize console log levels."""

    RED = "\033[91m"
    GREEN = "\033[92m"
    YELLOW = "\033[93m"
    BLUE = "\033[94m"
    CYAN = "\033[96m"
    RESET = "\033[0m"


class ColorFormatter(logging.Formatter):
    """Console formatter: colors the level name and injects frame_count."""

    LEVEL_COLORS = {
        'DEBUG': AnsiColor.CYAN,
        'INFO': AnsiColor.GREEN,
        'WARNING': AnsiColor.YELLOW,
        'ERROR': AnsiColor.RED,
        'CRITICAL': AnsiColor.RED,
    }

    def format(self, record):
        color = self.LEVEL_COLORS.get(record.levelname, AnsiColor.RESET)
        record.colored_levelname = f"{color.value}{record.levelname}{AnsiColor.RESET.value}"
        # Records logged via the raw logger (not the adapter) lack frame_count.
        if not hasattr(record, 'frame_count'):
            record.frame_count = get_frame_count()
        return super().format(record)


class NoColorFormatter(logging.Formatter):
    """File formatter: strips ANSI escapes and injects frame_count."""

    ANSI_ESCAPE = re.compile(r'\x1B\[[0-?]*[ -/]*[@-~]')

    def format(self, record):
        if not hasattr(record, 'frame_count'):
            record.frame_count = get_frame_count()
        original = super().format(record)
        return self.ANSI_ESCAPE.sub('', original)


class FrameCountAdapter(logging.LoggerAdapter):
    """Adapter that attaches the thread-local frame count to every record."""

    def __init__(self, logger):
        super().__init__(logger, {})

    def process(self, msg, kwargs):
        frame_count = get_frame_count()
        kwargs.setdefault('extra', {})['frame_count'] = frame_count
        return msg, kwargs


def create_logger():
    """Create the shared 'frame_logger' with a colored console handler."""
    logger = logging.getLogger("frame_logger")
    # Re-creating the logger (e.g. on module reload) must not stack handlers.
    if logger.hasHandlers():
        logger.handlers.clear()

    console_handler = logging.StreamHandler()
    formatter = ColorFormatter('[%(asctime)s]: frame:%(frame_count)s: %(colored_levelname)s %(message)s')
    console_handler.setFormatter(formatter)
    logger.addHandler(console_handler)
    logger.setLevel(logging.DEBUG)
    return FrameCountAdapter(logger)


# global logger
logger = create_logger()


# dynamically add the file handler when we know the experiment directory
# alternatively, we can create the file handler at init and move the log
# file to the experiment dir on exit.
def add_file_handler_to_logger(log_file_path):
    """Attach an ANSI-free file handler writing to log_file_path.

    Creates the parent directory if needed. No-op when log_file_path is falsy.
    """
    if log_file_path:
        # BUGFIX: dirname() is '' for a bare filename; os.makedirs('') raises
        # FileNotFoundError, so only create the directory when one is present.
        log_dir = os.path.dirname(log_file_path)
        if log_dir:
            os.makedirs(log_dir, exist_ok=True)

        file_handler = logging.FileHandler(log_file_path)
        formatter = NoColorFormatter(
            '[%(asctime)s]: %(process)d:%(thread)d: frame:%(frame_count)s: %(levelname)s %(message)s'
        )
        file_handler.setFormatter(formatter)
        logging.getLogger("frame_logger").addHandler(file_handler)
4 | 5 | The main performance focus is on: 6 | 7 | - CPU–GPU memory transfer latency 8 | - USB/IO contention, especially during input polling 9 | - Device-level scheduling issues 10 | - X11 GUI rendering overhead via containerized forwarding 11 | - Multi-threaded behavior within the control loop 12 | 13 | --- 14 | 15 | ## Tool Overview 16 | 17 | ### [Nsight Systems](https://developer.nvidia.com/nsight-systems) 18 | 19 | A timeline-based system-wide profiler for analyzing: 20 | 21 | - CPU/GPU concurrency 22 | - OS thread scheduling 23 | - CUDA kernel launches 24 | - Blocking I/O behavior (USB, GUI) 25 | 26 | ### [Nsight Compute](https://developer.nvidia.com/nsight-compute) 27 | 28 | A low-level kernel analysis tool for examining: 29 | 30 | - GPU kernel throughput 31 | - Memory access patterns 32 | - Warp occupancy and stall reasons 33 | 34 | > Nsight Compute is **not yet included** in the Docker image. 35 | 36 | --- 37 | 38 | ## Docker Environment 39 | 40 | The profiling takes place inside a Docker container that includes: 41 | 42 | - CUDA runtime 43 | - Nsight Systems CLI and UI tools 44 | - PyTorch with CUDA + NVTX support 45 | - X11 forwarding for GUI rendering 46 | 47 | --- 48 | 49 | ## NVTX Annotations (via PyTorch) 50 | 51 | torch.cuda.nvtx annotations are used to instrument critical sections of code, enabling detailed performance analysis in Nsight Systems by: 52 | 53 | - Visualizing the timing and concurrency of individual components within the Nsight timeline. 54 | - Identifying CPU or GPU stalls and resource contention, such as between USB polling and GPU tasks. 55 | - Measuring latency and synchronization between key stages like data capture, inference, and control signal output. 
56 | 57 | ```python 58 | import torch.cuda.nvtx as nvtx 59 | 60 | for i in range(num_frames): 61 | nvtx.range_push("env.act") 62 | env.act() 63 | nvtx.range_pop() 64 | 65 | nvtx.range_push("env.get_observation") 66 | observation_rgb8[i] = env.get_observation() 67 | nvtx.range_pop() 68 | 69 | nvtx.range_push("agent.accept_observation") 70 | taken_action = agent.accept_observations(observation_rgb8, rewards, end_of_episodes) 71 | nvtx.range_pop() 72 | ``` 73 | 74 | --- 75 | 76 | ## Running Nsight Systems 77 | 78 | Use the following command to launch a 60-second profiling session: 79 | 80 | ```bash 81 | nsys profile \ 82 | --stats=true \ 83 | --sample=cpu \ 84 | --trace=cuda,cudnn,cublas,nvtx,osrt,oshmem \ 85 | --cudabacktrace=kernel:1000000,sync:1000000,memory:1000000 \ 86 | --delay=1 \ 87 | --duration=60 \ 88 | --wait=all \ 89 | --force-overwrite=true \ 90 | --output="/tmp/nsys_profile" \ 91 | python harness_physical.py --use_gui=0 92 | ``` 93 | 94 | - '--delay=1': Skips early initialization overhead 95 | - '--trace': Includes CUDA, NVTX, and OS runtime events 96 | - '--output': Stores the profile as '/tmp/nsys_profile.nsys-rep' 97 | - '--use_gui': Set to 1 to include the GUI. 98 | 99 | --- 100 | 101 | ### Open Results in nsys GUI 102 | 103 | Outside the container, copy the result and run: 104 | 105 | ```bash 106 | nsys-ui /tmp/nsys_profile.nsys-rep 107 | ``` 108 | 109 | Or run the GUI directly from the container as X11 forwarding is supported. 
110 | 111 | --- 112 | 113 | ## What to Look for in Profiling 114 | 115 | | Aspect | Indicators to Watch For | 116 | | --------------------- | ----------------------------------------------------- | 117 | | USB Camera Capture | Long CPU threads polling USB, USB bandwidth stalls | 118 | | USB Control Signaling | Delays or blocking on USB writes | 119 | | GPU Inference | GPU kernel stalls, inefficient memory transfers | 120 | | PyTorch Training | CPU/GPU imbalance, memory bottlenecks | 121 | | Inter-component Sync | CPU waits between USB, GPU, and training steps | 122 | | (If GUI enabled) | Thread stalls during rendering, X11 forwarding delays | 123 | -------------------------------------------------------------------------------- /docs/io_controller.md: -------------------------------------------------------------------------------- 1 | # Building a Digital I/O Controller for Atari 2600+ 2 | 3 | This document walks you through building a **digital I/O controller module** that connects to the **Atari 2600 joystick port** and simulates joystick inputs (up/down/left/right/fire). It uses a digital I/O device such as the **MCC USB-1024LS**, wired to a **DB9 Atari joystick cable** to send digital signals directly to the Atari console. 4 | 5 | The controller works by pulling specific pins on the DB9 joystick port **low (to ground)**, mimicking how the real joystick functions. This setup is useful for experimenting with Atari hardware from modern control systems. 6 | 7 | --- 8 | 9 | ## Hardware Overview 10 | 11 | ### Digital I/O Devices 12 | 13 | You will need a USB digital I/O module with TTL-level (5V) outputs. 
Recommended options: 14 | 15 | | Device | Description | Notes | 16 | |--------|-------------|-------| 17 | | **[MCC USB-1024LS](https://microdaq.com/usb-1024ls-24-bit-digital-input-output-i-o-module.php)** | 24 digital I/O lines, 5V TTL | Well-documented and supported | 18 | | **Arduino Nano + Screw Terminal Shield** | 14 digital I/O lines | Requires custom sketch | 19 | 20 | --- 21 | 22 | ## Atari Joystick Port Pinout (DB9) 23 | 24 | Refer to the [Atari joystick port pinout](https://en.wikipedia.org/wiki/Atari_joystick_port), which shows the connector **as seen from the front (on the Atari console)**: 25 | 26 | | Pin | Function | Notes | 27 | |-----|--------------|-------| 28 | | 1 | Up | Active LOW | 29 | | 2 | Down | Active LOW | 30 | | 3 | Left | Active LOW | 31 | | 4 | Right | Active LOW | 32 | | 5 | Paddle B | **Not used** | 33 | | 6 | Fire Button | Active LOW | 34 | | 7 | +5V Power | **Unused** | 35 | | 8 | Ground (GND) | Shared return path | 36 | | 9 | Paddle A | **Not used** | 37 | 38 | > If you're using a **male DB9-to-bare wire cable**, the pinout may appear mirrored. Use continuity testing to confirm pin mapping. 39 | 40 | --- 41 | 42 | ## Identifying Wires via Continuity Testing 43 | 44 | To build the controller, you'll need to match each wire in the DB9 cable to its pin. Wire colors are often non-standard. 45 | 46 | ### Steps: 47 | 48 | 1. Insert **paper clips into the front of the DB9 connector** to make contact with the pins. 49 | 2. Set your **multimeter to continuity mode** (beep or ohm check). 50 | 3. Touch one multimeter probe to a paper clip in **Pin 1 (Up)**. 51 | 4. Touch the other probe to each wire end until you hear a beep. 52 | 5. Repeat for Pins 2–4, 6, and 8 (GND). 53 | 54 | Write down the wire color associated with each pin for reference during wiring. 55 | 56 | > Avoid touching adjacent paper clips at the same time — you may short pins during testing. 
57 | 58 | --- 59 | 60 | ## Wiring the Controller Module 61 | 62 | Once you have mapped the wires, connect them to the I/O device. 63 | 64 | ### Signal Logic 65 | 66 | - The Atari expects **active LOW signals**: pulling the signal to **GND = pressed**. 67 | - Set digital output pins **LOW (0V)** to press, **HIGH (5V)** to release. 68 | 69 | ### MCC USB-1024LS: Port A Wiring 70 | 71 | If using the MCC USB-1024LS, connect the joystick wires to **Port A**, using pins **24–28 for control**, and **29 for GND**: 72 | 73 | | Function | DB9 Pin | DAQ Port | Terminal Pin # | Bit | Notes | 74 | |----------|---------|----------|----------------|------|-------| 75 | | Up | 1 | Port A | 24 | P0.0 | Press = LOW | 76 | | Down | 2 | Port A | 25 | P0.1 | Press = LOW | 77 | | Left | 3 | Port A | 26 | P0.2 | Press = LOW | 78 | | Right | 4 | Port A | 27 | P0.3 | Press = LOW | 79 | | Fire | 6 | Port A | 28 | P0.4 | Press = LOW | 80 | | GND | 8 | GND | 29 | – | Required for signal return | 81 | 82 | **Important**: Connect the GND wire (pin 8) from the joystick cable to **terminal pin 29 (GND)** on the DAQ to complete the circuit. 83 | 84 | --- 85 | 86 | ## Cable Prep & Strain Relief 87 | 88 | When handling joystick cables: 89 | 90 | 1. **Strip each wire** ~5mm from the end. Be careful with cables that have **nylon/tinsel** insulation inside - trim or **singe** with a lighter. 91 | 2. Insert stripped wires into **screw terminals** on the I/O module. 92 | 3. Double-check wire mapping before applying power. 93 | 94 | ### Strain Relief 95 | 96 | Use a **cable tie and cable tie gun** to secure the joystick cable to the screw terminal block. This prevents stress on the connections and improves durability. 97 | 98 | --- 99 | -------------------------------------------------------------------------------- /scripts/plot_data.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Keen Technologies, Inc. 
def remove_outliers_zscore(x, threshold=3.0):
    """Return a boolean keep-mask over x using a z-score cutoff.

    Elements farther than `threshold` standard deviations from the mean are
    marked False (outliers); everything else is True.
    """
    # Empty input: nothing to filter. Early return avoids the RuntimeWarning
    # np.mean/np.std emit on zero-length arrays.
    if x.size == 0:
        return np.ones_like(x, dtype=bool)
    mean = np.mean(x)
    std = np.std(x)
    if std == 0:
        # Constant data: z-scores are undefined, keep every element.
        return np.ones_like(x, dtype=bool)
    return np.abs(x - mean) < threshold * std
def parse_arguments():
    """Parse the command line for plot_data.py."""
    arg_parser = argparse.ArgumentParser(description="plot_data.py arguments.")
    arg_parser.add_argument('root_dir', type=str, help="path to experiment result directory")
    arg_parser.add_argument('--title', type=str, default=None, help="plot title")
    arg_parser.add_argument('--remove-outliers', action='store_true', help="remove statistical outliers using z-score.")
    return arg_parser.parse_args()


if __name__ == '__main__':
    args = parse_arguments()

    # Per-extension plot configuration: axis labels, output file, and style.
    extensions = {
        '.loss': {
            'ylabel': 'Loss',
            'xlabel': 'Training Step',
            'outfile': 'loss_plot.jpg',
            'is_scatter': False,
            'color': 'red',
        },
        '.score': {
            'ylabel': 'Average Score',
            'xlabel': 'Episode',
            'outfile': 'score_plot.jpg',
            'is_scatter': False,
            'color': None,
        },
        '.scatter': {
            'ylabel': 'Episode Score',
            'xlabel': 'Episode End',
            'outfile': 'scatter_plot.jpg',
            'is_scatter': True,
            'color': None,
        },
    }

    # Walk the result tree once, bucketing files by recognized extension.
    file_map = {ext: [] for ext in extensions}
    for dirpath, _, filenames in os.walk(args.root_dir):
        for fname in filenames:
            ext = os.path.splitext(fname)[1]
            if ext in file_map:
                file_map[ext].append(os.path.join(dirpath, fname))

    # Emit one plot per extension that actually has data files.
    for ext, cfg in extensions.items():
        matched = file_map[ext]
        if not matched:
            continue
        plot_data(
            files=matched,
            title=args.title or f"{cfg['ylabel']} Plot",
            remove_outliers=args.remove_outliers,
            output_path=os.path.join(args.root_dir, cfg['outfile']),
            ylabel=cfg['ylabel'],
            xlabel=cfg['xlabel'],
            is_scatter=cfg['is_scatter'],
            color=cfg['color'],
        )

    print("Complete.")
-------------------------------------------------------------------------------- 1 | # Copyright 2025 Keen Technologies, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import re 16 | import shlex 17 | import subprocess 18 | 19 | from framework.Logger import logger 20 | 21 | 22 | def set_control(device_idx, ctrl_name, ctrl_value): 23 | cmdline = f'v4l2-ctl --device /dev/video{device_idx} --set-ctrl={ctrl_name}={ctrl_value}' 24 | cmd_list = shlex.split(cmdline, posix=False) 25 | process = subprocess.Popen(cmd_list, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) 26 | output, _ = process.communicate() 27 | output = output.decode('utf-8') 28 | status = process.returncode 29 | if status: 30 | logger.warning( 31 | f"set_control: Failed to set {ctrl_name}={ctrl_value} for device {device_idx}. Return code={status}." 
def parse_control(line):
    """Parse one line of 'v4l2-ctl --list-ctrls' output into a dict.

    Returns a dict with the control name plus any key=value attributes
    (values converted to int when possible) and an optional 'desc' taken
    from a parenthesized token. Returns None for lines that do not look
    like control descriptions.
    """
    parts = line.split(':')
    if len(parts) < 2:
        return None

    name_part = parts[0].strip()
    # Malformed line such as ': min=0' has no name token; reject it instead
    # of raising IndexError.
    if not name_part:
        return None
    control_name = name_part.split()[0]
    control_data = parts[1].strip()
    data = control_data.split()

    control = {'name': control_name}

    for d in data:
        if '=' in d:
            # Split on the first '=' only so values containing '=' survive
            # intact (the old split('=') raised ValueError on them).
            key, value = d.split('=', 1)
            try:
                value = int(value)
            except ValueError:
                pass  # non-numeric values (e.g. hex or flag strings) stay as str
            control[key] = value
        elif '(' in d:  # description is optional
            control['desc'] = d.strip('()')

    return control
def get_index_from_model_name(model_name):
    """Return the /dev/video index for the first device whose 'v4l2-ctl
    --list-devices' header contains model_name (case-insensitive), or -1.

    The v4l2-ctl output format is: an unindented header line per device,
    followed by indented /dev/* node lines belonging to that device.
    """
    cmd_list = shlex.split('v4l2-ctl --list-devices', posix=False)
    process = subprocess.Popen(cmd_list, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
    output, _ = process.communicate()
    output = output.decode('utf-8')
    status = process.returncode  # ignore unless parsing fails

    lines = output.splitlines()
    if not lines:
        if status:
            logger.warning(f"get_index_from_model_name: empty output ({status})")
        else:
            logger.warning("get_index_from_model_name: empty output")
        return -1

    # Manual index loop: the inner scan below advances i past a header's
    # indented device lines, so a plain for-loop would re-visit them.
    i = 0
    while i < len(lines):
        line = lines[i]
        if not line.strip():
            i += 1
            continue

        # Unindented line == device header.
        if not line.startswith((" ", "\t")):
            header = line.strip()
            if model_name.lower() in header.lower():
                i += 1
                # find /dev/video* lines
                while i < len(lines) and lines[i].startswith((" ", "\t")):
                    device_line = lines[i].strip()
                    match = re.search(r'/dev/video(\d+)', device_line)
                    if match:
                        device_idx = int(match.group(1))
                        logger.debug(f"Found {model_name} at {device_line} idx {device_idx}")
                        return device_idx
                    i += 1
                logger.warning(f"get_index_from_model_name: matched header '{header}' but found no /dev/video* entries")
                return -1
        # Advance past non-matching headers and any stray indented lines.
        i += 1

    logger.warning(f"get_index_from_model_name: No device header matched '{model_name}'")
    return -1
class TagID(Enum):
    # Expected physical placement of each tag id, clockwise from the
    # top-left of the monitor (see module docstring).
    TAG_ID_TOP_LEFT = 0
    TAG_ID_TOP_RIGHT = 1
    TAG_ID_BOTTOM_RIGHT = 2
    TAG_ID_BOTTOM_LEFT = 3


class ScreenDetector:
    """Locates the monitor's screen rect in grayscale camera frames using
    AprilTags placed at its four corners.

    In threaded mode, frames are pushed onto a small queue and a daemon
    thread runs detection; callers then receive the most recent valid
    result instead of blocking on detection.
    """

    def __init__(
        self, method_name, corners, family, quad_decimate, quad_sigma, refine_edges, decode_sharpening, threaded=True
    ):
        # Only the dt_apriltags backend is supported.
        assert method_name == "dt_apriltags"
        self.detector = dt_apriltags.Detector(
            families=family,
            quad_decimate=quad_decimate,
            quad_sigma=quad_sigma,
            refine_edges=refine_edges,
            decode_sharpening=decode_sharpening,
        )

        # Mapping of which corner of the tag should be used
        # to define the screen_rect. Keys in `corners` are TagID names;
        # values index into dt_apriltags' CCW corner array.
        self.tag_id_corner_idx = {TagID[key].value: idx for key, idx in corners.items()}
        self.threaded = threaded

        # Last successfully detected rect (np.float32 4x2) and raw tag corners.
        self.screen_rect = None
        self.last_detected_tags = {}

        if self.threaded:
            self.running = True
            self.frame_queue = queue.Queue(maxsize=4)
            self.shutdown_cond = threading.Condition()
            self.lock = threading.Lock()
            self.detection_thread = threading.Thread(target=self._process_frames, daemon=True)
            self.detection_thread.start()

    def shutdown(self):
        """Stop the worker thread (if any) and release the detector."""
        if self.threaded:
            self.running = False
            # wake the worker out of its condition wait so join() is prompt
            with self.shutdown_cond:
                self.shutdown_cond.notify()
            self.detection_thread.join()
        self.detector = None

    # Expects grayscale np.ndarray
    def get_screen_rect_info(self, frame_g):
        """Return (screen_rect, tag_data) for `frame_g`.

        Threaded: enqueue the frame (dropping the oldest if full) and return
        the previous valid result. Non-threaded: detect synchronously.
        """
        if self.threaded:
            # add a frame to the queue for processing and return the previous valid
            # screen info
            try:
                self.frame_queue.put_nowait(frame_g)
            except queue.Full:
                # drop the oldest queued frame to make room for the newest
                try:
                    self.frame_queue.get_nowait()
                    self.frame_queue.put_nowait(frame_g)
                except (queue.Full, queue.Empty):
                    pass  # still full or already drained
            except Exception:
                # logger.warning(f"ScreenDetector: Unexpected queue error: {e}")
                pass

            with self.lock:
                return self.screen_rect, self.last_detected_tags
        else:
            screen_rect, tag_data = self._detect_screen(frame_g)
            self.screen_rect = screen_rect
            self.last_detected_tags = tag_data
            return screen_rect, tag_data

    def _detect_screen(self, frame_g):
        """Detect tags and assemble the 4-point screen rect.

        Returns (rect, tag_data); rect is None unless all four tags were found.
        """
        if frame_g is None:
            logger.warning("ScreenDetector::_detect_screen: invalid frame")
            return None, None

        tags = self._detect_tags(frame_g)
        tag_data = {tag.tag_id: tag.corners for tag in tags}

        if len(tags) == 4:
            # pick the configured corner of each tag as that screen corner
            sr_pt_dict = {
                tag.tag_id: (
                    tag.corners[self.tag_id_corner_idx[tag.tag_id]][0],
                    tag.corners[self.tag_id_corner_idx[tag.tag_id]][1],
                )
                for tag in tags
            }
            # ordered TL, TR, BR, BL to match TagID values 0..3
            screen_rect = np.float32([sr_pt_dict[i] for i in range(4)])
        else:
            screen_rect = None
            # logger.debug(f"ScreenDetector:_detect_screen: Only detected {len(tags)} tags. Verify camera position and lighting")

        return screen_rect, tag_data

    def _detect_tags(self, frame_g):
        """Run the detector and drop any tags whose id is not one of ours."""
        tags = self.detector.detect(frame_g)

        # verify the correct tags have been identified
        # in low-lighting conditions, tags can be mis-identified.
        i = 0
        while i < len(tags):
            if tags[i].tag_id not in self.tag_id_corner_idx:
                logger.warning(f"ScreenDetector: invalid tag_id found = {tags[i].tag_id}")
                tags.pop(i)
            else:
                i += 1

        return tags

    def _process_frames(self):
        # Worker loop. NOTE(review): the condition wait(1) runs every
        # iteration, so detection is paced to ~1 Hz — presumably intentional
        # for a fixed camera/monitor setup; confirm before changing.
        while self.running:
            if not self.frame_queue.empty():
                frame = self.frame_queue.get()
                screen_rect, tag_data = self._detect_screen(frame)

                with self.lock:
                    # only update screen_rect with valid data
                    if screen_rect is not None:
                        self.screen_rect = screen_rect

                    # always update tag info as it can be used to determine why screen detection failed.
                    self.last_detected_tags = tag_data

            with self.shutdown_cond:
                self.shutdown_cond.wait(1)
15 | - **Monitor** 16 | - Connected to the Atari 2600+ console 17 | - Refer to [setup.md](docs/setup.md) for recommended brightness, color mode, and refresh rate 18 | - **Camera** 19 | - Recommended: Razer Kiyo Pro (1080p60) 20 | - If using another camera, create a config in `configs/cameras/` (see `camera_kiyo_pro.json`) 21 | - **Controller** 22 | - Connected to the **Left Controller Port** 23 | - Either: 24 | - [RoboTroller](https://robotroller.keenagi.com) (mechanically actuates a CX40+ joystick) 25 | - MCC USB-1024LS (sends directional + fire actions via USB I/O) 26 | - A custom solution (requires custom device class and a matching config under `configs/controllers/`) 27 | - **Linux System** 28 | - Ubuntu 24.04 LTS 29 | - NVIDIA GPU with **≥16GB VRAM** if running the provided agent 30 | - Docker + NVIDIA Container Toolkit 31 | 32 | --- 33 | 34 | ## 1. Physical Console and Camera Setup 35 | 36 | - Mount the camera directly in front of the monitor 37 | - Frame the Atari screen so it fills the view horizontally 38 | - Ideal pixel mapping: **~2 camera pixels per Atari pixel** 39 | 40 | If you're using a non-Razer camera, define your camera settings in a new config JSON under `configs/cameras/`. 41 | 42 | --- 43 | 44 | ## 2. Controller 45 | 46 | - **RoboTroller**: 47 | - Follow build instructions at [robotroller.keenagi.com](https://robotroller.keenagi.com) 48 | - Use `configs/controllers/robotroller.json` 49 | - **Digital I/O (MCC USB-1024LS)**: 50 | - Follow build instructions at [io_controller.md](docs/io_controller.md) 51 | - Use `configs/controllers/io_controller.json` 52 | - If you're using another I/O board, you must write a custom device class and define your pin map config under `configs/controllers/`. 53 | 54 | --- 55 | 56 | ## 3. Install Software Stack 57 | 58 | Follow the full instructions in [setup.md](docs/setup.md). 
By default, this launches the GUI, where you can configure the setup before training begins.
108 | 109 | There are two options: 110 | 111 | ### Fixed Corner Selection (Default) 112 | 113 | - In the GUI, click the **Configuration▶Screen Detection** config view to give it focus 114 | - Use: 115 | - `Shift+Tab` to cycle between points 116 | - `WASD` to move the selected point 117 | - Position corners precisely around the 4:3 screen content (exclude pillarbox bars) 118 | - Click **Save** to persist the region config 119 | 120 | ### April Tag Detection (Advanced) 121 | 122 | - Requires physical AprilTags printed and placed at the four corners of the screen 123 | - Provides automatic detection but is sensitive to lighting, reflections, and tag size 124 | - Best used with a ring light or consistent ambient illumination 125 | - Configuration and tag details are in [setup.md](docs/setup.md) 126 | - To use, specify: `--detection_config=configs/screen_detection/april_tags.json` 127 | 128 | --- 129 | 130 | ## 6. Score and Lives Region Configuration 131 | 132 | - In the GUI, click the **Configuration▶Score Detection** config view to give it focus 133 | - Use: 134 | - `Shift+Tab` to toggle between score and lives box 135 | - `WASD` to adjust box position 136 | - Click **Save** to commit changes 137 | 138 | > Reference images for correct placement are available in [docs/setup.md](docs/setup.md#score-and-lives-box-placement) 139 | Note: When using the default models shipped with this framework, incorrect crop placement may degrade accuracy. 140 | The default models were trained with tight bounds for certain games to exclude nearby HUD graphics. Some games may 141 | benefit from looser bounds, and in those cases, retraining with larger crops can improve robustness. 142 | 143 | Click the **Game Frame** view to give it focus, where you can use your keyboard to control the game, and verify that score and lives are correctly parsed in real time. 144 | 145 | --- 146 | 147 | ## 7. 
Start Training 148 | 149 | Once the screen, score, and lives regions are configured, you can: 150 | 151 | - Start training runs from within the GUI 152 | - Monitor training progress in real time via the displayed graphs 153 | - View per-frame details including score, lives, episode count, frame number, action, and other data 154 | 155 | Results will be written to the configured `results/` path. 156 | 157 | --- 158 | 159 | ## 8. Troubleshooting Performance 160 | 161 | Even with setup complete, your system may throttle GPU or CPU performance. If framerate or latency is unstable: 162 | 163 | - Run: 164 | ```bash 165 | sudo python3 scripts/check_performance.py 166 | ``` 167 | - Apply fixes or reboot if prompted 168 | 169 | If performance still degrades after reboot, power down and **fully disconnect** all devices for 30 seconds (flea drain) to reset transient hardware state. 170 | 171 | --- 172 | 173 | Refer to [setup.md](docs/setup.md) for detailed installation instructions, hardware tuning, and system diagnostics. 174 | -------------------------------------------------------------------------------- /framework/Keyboard.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Keen Technologies, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
class Keys(Enum):
    """Physical keyboard keys used to drive the joystick actions."""

    UP = "w"
    LEFT = "a"
    DOWN = "s"
    RIGHT = "d"
    FIRE = " "
    NOOP = "e"


def get_keys_to_action() -> dict[tuple[int, ...], Action]:
    """Build the lookup from a sorted tuple of held key chars to the Action it encodes."""
    mapping = {
        Action.NOOP: (Keys.NOOP.value,),
        Action.UP: (Keys.UP.value,),
        Action.FIRE: (Keys.FIRE.value,),
        Action.DOWN: (Keys.DOWN.value,),
        Action.LEFT: (Keys.LEFT.value,),
        Action.RIGHT: (Keys.RIGHT.value,),
        Action.UPFIRE: (Keys.UP.value, Keys.FIRE.value),
        Action.DOWNFIRE: (Keys.DOWN.value, Keys.FIRE.value),
        Action.LEFTFIRE: (Keys.LEFT.value, Keys.FIRE.value),
        Action.RIGHTFIRE: (Keys.RIGHT.value, Keys.FIRE.value),
        Action.UPLEFT: (Keys.UP.value, Keys.LEFT.value),
        Action.UPRIGHT: (Keys.UP.value, Keys.RIGHT.value),
        Action.DOWNLEFT: (Keys.DOWN.value, Keys.LEFT.value),
        Action.DOWNRIGHT: (Keys.DOWN.value, Keys.RIGHT.value),
        Action.UPLEFTFIRE: (Keys.UP.value, Keys.LEFT.value, Keys.FIRE.value),
        Action.UPRIGHTFIRE: (Keys.UP.value, Keys.RIGHT.value, Keys.FIRE.value),
        Action.DOWNLEFTFIRE: (Keys.DOWN.value, Keys.LEFT.value, Keys.FIRE.value),
        Action.DOWNRIGHTFIRE: (Keys.DOWN.value, Keys.RIGHT.value, Keys.FIRE.value),
    }

    full_action_set = [act for act in Action]

    return {tuple(sorted(mapping[act_idx])): act_idx for act_idx in full_action_set}


class Keyboard:
    """Translates pynput keyboard events into ControlDevice actions.

    The pynput listener thread enqueues (press/release, Action) pairs;
    update() — called by the host loop, or by an internal worker thread when
    threaded=True — drains the queue and forwards the action to the device.
    """

    def __init__(self, device: ControlDevice, threaded: bool = False):
        self.device = device
        assert self.device is not None

        self.keys_to_action = get_keys_to_action()
        # every individual key char that participates in some combo
        self.relevant_keys = {k for combo in self.keys_to_action for k in combo}
        self.pressed_keys = set()
        self.input_focus = True
        self.exit_requested = False

        self.threaded = threaded
        self.running = False
        self.action_queue = queue.Queue()
        self.thread = None

        self.start()

        help_text = (
            ", ".join(f"{key.name}:{'space' if key.value == ' ' else key.value}" for key in Keys) + ", QUIT: esc"
        )
        logger.info(f"Keyboard: {help_text}")

    def shutdown(self):
        """Stop the listener/worker and release the control device."""
        self.stop()

        self.device.shutdown()
        self.device = None

    def _parse_key(self, key) -> str | None:
        """Map a pynput key event to its character, or None for special keys."""
        if key == Key.space:
            return ' '
        try:
            return key.char
        except AttributeError:
            # special keys (esc, shift, ...) carry no .char
            return None

    def on_press(self, key):
        """pynput press callback: update held keys and queue the new combined action."""
        if not self.input_focus:
            return
        keycode = self._parse_key(key)
        if keycode and keycode in self.relevant_keys and keycode not in self.pressed_keys:
            self.pressed_keys.add(keycode)
            new_action = self._get_action_from_keys(self.pressed_keys)
            self.action_queue.put(('press', new_action))

    def on_release(self, key):
        """pynput release callback; esc requests exit regardless of focus."""
        keycode = self._parse_key(key)
        if keycode and keycode in self.relevant_keys and keycode in self.pressed_keys:
            self.pressed_keys.remove(keycode)
            new_action = self._get_action_from_keys(self.pressed_keys)
            self.action_queue.put(('release', new_action))

        if key == Key.esc:
            self.exit_requested = True

    def start(self):
        """Start the optional worker thread and the pynput listener."""
        if self.threaded:
            self.running = True
            self.thread = threading.Thread(target=self._process_actions, daemon=True)
            self.thread.start()

        self.listener = Listener(on_press=self.on_press, on_release=self.on_release)
        self.listener.start()

    def stop(self):
        # stop listening for events, this should join the listener thread
        self.listener.stop()

        self.running = False
        if self.threaded and self.thread is not None:
            self.thread.join()

    def set_input_focus(self, focus: bool):
        """Enable/disable consumption of key events (e.g. when the GUI loses focus)."""
        logger.info(f"Keyboard: input_focus={focus}")
        self.input_focus = focus

    def _get_action_from_keys(self, keys):
        # unknown/partial combos fall back to NOOP
        return self.keys_to_action.get(tuple(sorted(keys)), Action.NOOP)

    # when running non-threaded, expects the calling program
    # to update at a regular frequency
    def update(self):
        """Drain one queued event, apply it to the device, and return
        (exit_requested, action)."""
        if not self.input_focus:
            # NOTE(review): returns False rather than self.exit_requested while
            # unfocused — an esc pressed without focus is not reported; confirm intended.
            return False, Action.NOOP

        try:
            state, action = self.action_queue.get_nowait()
            # logger.debug(action)
            signal_state = 1 if state == "press" else 0
            # print(f"action={action} state={signal_state}")
            self.device.apply_action(action, signal_state)
            return self.exit_requested, action
        except queue.Empty:
            return self.exit_requested, Action.NOOP

    def _process_actions(self):
        # worker loop for threaded mode: poll the queue ~100 Hz
        while self.running:
            self.update()
            time.sleep(0.01)

    def __repr__(self):
        # BUGFIX: previously returned an empty string, making debug output useless.
        return f"Keyboard(threaded={self.threaded}, input_focus={self.input_focus}, pressed={sorted(self.pressed_keys)})"
def generate_data(game, data_dir, debug=False):
    """Render ground-truth score/lives frames for `game` into `data_dir`.

    Uses RAM injection to force every valid score value (then every valid
    lives value) and saves the resulting emulator frame as
    img_score_<zero-padded score>.png / img_lives_<lives>.png.

    Raises KeyError if `game` has no entry in GAME_RAM_CONFIG.
    """
    if os.path.exists(data_dir):
        shutil.rmtree(data_dir)
    os.makedirs(data_dir, exist_ok=True)

    print(f"Generating data for {game} at {data_dir}")

    config = GAME_RAM_CONFIG[game]

    rom_path = roms.get_rom_path(game)
    ale = ALEInterface()
    ale.setLoggerMode(LoggerMode.Error)
    ale.loadROM(rom_path)

    score_range = range(0, config["max_score"] + 1, config["score_step"][0])
    lives_range = range(1, config["total_lives"] + 1)

    # Generate all valid score combos
    # defender/battle_zone need at least one life on screen to draw the score
    dummy_lives = 1 if game in ('defender', 'battle_zone') else 0
    for i, score in enumerate(score_range):
        if debug:
            print(f"--- {game} | Score: {score} | Lives: dummy (0) ---")

        write_score(config, score, ale)
        write_lives(config, dummy_lives, ale)

        if game == 'qbert':
            ale.reset_game()
            for _ in range(5):
                ale.act(Action.FIRE)
            for _ in range(30):
                ale.act(Action.NOOP)

            # step until the score HUD is actually drawn (it can blink in).
            # BUGFIX: loop variable renamed from `i` to `step` — it previously
            # shadowed the outer enumerate index and corrupted progress reporting.
            for step in range(1000):
                write_score(config, score, ale)
                ale.act(ale.getLegalActionSet()[random.randint(0, 3)])
                obs = ale.getScreenRGB()

                # TODO: get this from the game config instead of hardcoding
                score_region = obs[5 : 10 + 5, 26 : 56 + 26]
                score_intensity = score_region.mean()
                if score_intensity > 10:
                    if debug:
                        print(f"[INFO] score frame found at step {step} | score_intensity={score_intensity:.1f}")
                    break
        else:
            # let the RAM write take effect on screen
            for _ in range(2):
                ale.act(Action.NOOP)

        if debug:
            ram = ale.getRAM()
            decoded_score = decode_score_bcd(ram, config["score_addr"], config)
            print(f"score={score} decoded={decoded_score}")

        score_str = str(score).zfill(config["score_digits"])
        img = Image.fromarray(ale.getScreenRGB())
        img.save(os.path.join(data_dir, f"img_score_{score_str}.png"))

        if i % 100 == 0 or i == len(score_range) - 1:
            print(f"[INFO] Generated {i + 1} / {len(score_range)} score images")

    ale.reset_game()
    for _ in range(5):
        ale.act(Action.NOOP)

    dummy_score = 0

    # generate all valid lives combos
    for i, lives in enumerate(lives_range):
        if debug:
            print(f"--- {game} | Score: dummy (0) | Lives: {lives} ---")

        write_score(config, dummy_score, ale)
        write_lives(config, lives, ale)

        if game == 'qbert':
            ale.reset_game()
            for _ in range(5):
                ale.act(Action.FIRE)
            for _ in range(30):
                ale.act(Action.NOOP)

            # BUGFIX: `step` (was `i`) — same shadowing issue as the score loop.
            for step in range(1000):
                write_lives(config, lives, ale)
                ale.act(ale.getLegalActionSet()[random.randint(0, 3)])
                obs = ale.getScreenRGB()

                # TODO: get this from the game config instead of hardcoding
                lives_region = obs[14 : 16 + 14, 33 : 40 + 33]
                lives_intensity = lives_region.mean()
                if lives_intensity > 10:
                    if debug:
                        print(f"[INFO] lives frame found at step {step} | lives_intensity={lives_intensity:.1f}")
                    break
        else:
            for _ in range(2):
                ale.act(Action.NOOP)

        if debug:
            ram = ale.getRAM()
            decoded_lives = decode_lives(ram, config['lives_addr'], config)
            print(f"lives={lives} decoded={decoded_lives}")

        img = Image.fromarray(ale.getScreenRGB())
        img.save(os.path.join(data_dir, f"img_lives_{lives}.png"))

        if i % 2 == 0 or i == len(lives_range) - 1:
            print(f"[INFO] Generated {i + 1} / {len(lives_range)} lives images")

    print(f"Finished generating data: num_scores={len(score_range)} num_lives={len(lives_range)}.")


# python3 generate_dataset.py ms_pacman --output_dir 'frames/ms_pacman'
def get_argument_parser():
    """Build the CLI parser: positional game name, --output_dir, --debug."""
    from argparse import ArgumentParser

    parser = ArgumentParser(description="generate_data.py arguments")
    parser.add_argument('game', type=str, default=None)
    parser.add_argument('--output_dir', type=str, default=os.path.join(os.getcwd(), 'results'))
    parser.add_argument('--debug', action='store_true')
    return parser


if __name__ == '__main__':
    arg_parser = get_argument_parser()
    args = arg_parser.parse_args()

    try:
        generate_data(args.game, args.output_dir, debug=args.debug)
    except KeyboardInterrupt:
        print("KeyboardInterrupt received")

    exit(0)
def get_class_weights(dataset, num_classes):
    """Compute a per-sample weight list for balanced sampling.

    Score samples are weighted by the summed inverse frequency of their
    digits; lives samples by the inverse frequency of the whole label
    string. (`num_classes` is accepted for interface compatibility.)
    """
    per_digit = [0] * 10  # occurrences of each digit 0-9 in score labels
    lives_counts = Counter()

    for text, lives_flag in zip(dataset.labels, dataset.is_lives):
        if lives_flag:
            lives_counts[text] += 1
        else:
            for ch in text:
                per_digit[int(ch)] += 1

    # print("Top 5 most common lives labels:")
    # for label, count in lives_counts.most_common(5):
    #     print(f"  {label}: {count}")

    n_score = sum(1 for flag in dataset.is_lives if not flag)
    n_lives = sum(lives_counts.values())

    digit_weights = []
    for digit, count in enumerate(per_digit):
        if count == 0:
            print(f"Digit class {digit} has zero samples in training set. Using weight=1.0")
            digit_weights.append(1.0)
        else:
            # clamp to avoid explosion
            digit_weights.append(min(n_score / count, 100.0))

    weights = []
    for text, lives_flag in zip(dataset.labels, dataset.is_lives):
        if lives_flag:
            weights.append(n_lives / (lives_counts[text] + 1e-6))
        else:
            weights.append(sum(digit_weights[int(ch)] for ch in text))

    return weights
    def _print_stats(self, number_counter, digit_counter, digit_pos_counter):
        """Dump label/digit frequency statistics collected in __init__ (debug aid)."""
        print("\n=== Dataset Label Analysis ===")

        print("\n Full Numbers:")
        for num, count in sorted(number_counter.items()):
            print(f"  {num}: {count}")

        print("\n Overall Digit Frequency:")
        for d in map(str, range(10)):
            print(f"  {d}: {digit_counter[d]}")

        print("\n Digit Frequency by Position (right-to-left):")
        for pos in sorted(digit_pos_counter):
            print(f"  Position {pos} (10^{pos} place):")
            for d in map(str, range(10)):
                print(f"    {d}: {digit_pos_counter[pos][d]}")

    def __len__(self) -> int:
        # one sample per discovered image file
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load sample `idx`.

        Returns:
            (img_tensor, label_tensor): CHW float tensor in [0,1] (optionally
            transformed/normalized) and a LongTensor of digits padded to
            max_digits with padding_value.
        """
        img_path = self.image_paths[idx]
        label_str = self.labels[idx]
        is_life = self.is_lives[idx]

        image = Image.open(img_path).convert('RGB')
        image_np = np.array(image)

        # lives and score crops get separate augmentation pipelines
        augment_transform = self.transform_lives if is_life else self.transform_score
        if augment_transform:
            if image_np.ndim == 2:
                image_np = np.expand_dims(image_np, axis=-1)  # H, W -> H, W, 1
            try:
                # assumes an albumentations-style callable (keyword `image`,
                # dict result) — TODO confirm against the training script
                image_np = augment_transform(image=image_np)["image"]
            except Exception as e:
                print(f"Augment failed for idx={idx} with shape={image_np.shape}: {e}")
                raise

        img_tensor = torch.from_numpy(image_np).float().permute(2, 0, 1) / 255.0  # CHW, float [0,1]

        if self.transform:
            img_tensor = self.transform(img_tensor)

        if self.normalizer:
            # subset 0 = score, 1 = lives (see CustomNormalize)
            img_tensor = self.normalizer(img_tensor, int(is_life))

        # right-pad the digit sequence so all labels share max_digits length
        padded_label = [int(d) for d in label_str] + [self.padding_value] * (self.max_digits - len(label_str))

        return img_tensor, torch.tensor(padded_label, dtype=torch.long)
# Dynamixel control-table addresses (Protocol 2.0)
ADDR_TORQUE_ENABLE = 64
ADDR_GOAL_POSITION = 116
ADDR_OPERATING_MODE = 11
ADDR_GOAL_CURRENT = 102
TORQUE_ENABLE = 1
TORQUE_DISABLE = 0

BUTTON_SERVO_DEFAULT = 2000  # center position
DPAD_SERVO_DEFAULT = 2048  # unpressed position


# tuple: (left_right_servo, up_down_servo, button_servo)
def get_positions_from_action_mapping() -> dict[Action, tuple[int, ...]]:
    """Return the servo goal positions (left_right, up_down, button) for every Action."""
    DPAD_SERVO_STRENGTH = 300  # movement strength from reference
    BUTTON_DEFLECTION = 2048 - 200
    DPAD_SERVO_UP = DPAD_SERVO_DEFAULT + DPAD_SERVO_STRENGTH
    DPAD_SERVO_RIGHT = DPAD_SERVO_DEFAULT + DPAD_SERVO_STRENGTH
    DPAD_SERVO_DOWN = DPAD_SERVO_DEFAULT - DPAD_SERVO_STRENGTH
    DPAD_SERVO_LEFT = DPAD_SERVO_DEFAULT - DPAD_SERVO_STRENGTH
    mapping = {
        Action.NOOP: (
            DPAD_SERVO_DEFAULT,
            DPAD_SERVO_DEFAULT,
            BUTTON_SERVO_DEFAULT,
        ),
        Action.UP: (
            DPAD_SERVO_DEFAULT,
            DPAD_SERVO_UP,
            BUTTON_SERVO_DEFAULT,
        ),
        Action.FIRE: (
            DPAD_SERVO_DEFAULT,
            DPAD_SERVO_DEFAULT,
            BUTTON_DEFLECTION,
        ),
        Action.DOWN: (
            DPAD_SERVO_DEFAULT,
            DPAD_SERVO_DOWN,
            BUTTON_SERVO_DEFAULT,
        ),
        Action.LEFT: (
            DPAD_SERVO_LEFT,
            DPAD_SERVO_DEFAULT,
            BUTTON_SERVO_DEFAULT,
        ),
        Action.RIGHT: (
            DPAD_SERVO_RIGHT,
            DPAD_SERVO_DEFAULT,
            BUTTON_SERVO_DEFAULT,
        ),
        Action.UPFIRE: (DPAD_SERVO_DEFAULT, DPAD_SERVO_UP, BUTTON_DEFLECTION),
        Action.DOWNFIRE: (DPAD_SERVO_DEFAULT, DPAD_SERVO_DOWN, BUTTON_DEFLECTION),
        Action.LEFTFIRE: (DPAD_SERVO_LEFT, DPAD_SERVO_DEFAULT, BUTTON_DEFLECTION),
        Action.RIGHTFIRE: (DPAD_SERVO_RIGHT, DPAD_SERVO_DEFAULT, BUTTON_DEFLECTION),
        Action.UPLEFT: (
            DPAD_SERVO_LEFT,
            DPAD_SERVO_UP,
            BUTTON_SERVO_DEFAULT,
        ),
        Action.UPRIGHT: (
            DPAD_SERVO_RIGHT,
            DPAD_SERVO_UP,
            BUTTON_SERVO_DEFAULT,
        ),
        Action.DOWNLEFT: (
            DPAD_SERVO_LEFT,
            DPAD_SERVO_DOWN,
            BUTTON_SERVO_DEFAULT,
        ),
        Action.DOWNRIGHT: (
            DPAD_SERVO_RIGHT,
            DPAD_SERVO_DOWN,
            BUTTON_SERVO_DEFAULT,
        ),
        Action.UPLEFTFIRE: (DPAD_SERVO_LEFT, DPAD_SERVO_UP, BUTTON_DEFLECTION),
        Action.UPRIGHTFIRE: (DPAD_SERVO_RIGHT, DPAD_SERVO_UP, BUTTON_DEFLECTION),
        Action.DOWNLEFTFIRE: (DPAD_SERVO_LEFT, DPAD_SERVO_DOWN, BUTTON_DEFLECTION),
        Action.DOWNRIGHTFIRE: (DPAD_SERVO_RIGHT, DPAD_SERVO_DOWN, BUTTON_DEFLECTION),
    }
    return mapping


class RoboTroller(ControlDevice):
    """Drives a physical CX40+ joystick via three Dynamixel servos
    (left/right, up/down, fire button) over a serial port."""

    def __init__(self, model_name, vendor_id, product_id, port_name, baud_rate=15200, current_limit=200):
        # NOTE(review): default baud_rate=15200 looks unusual for Dynamixel
        # (57600/115200 are common) — confirm; the controller config likely overrides it.
        super().__init__()
        self.vendor_id = vendor_id
        self.product_id = product_id
        # TODO: auto-discover port_name via pyudev or serial attributes
        if not os.path.exists(port_name):
            raise ValueError(f"RoboTroller: {port_name} does not exist. Is Robotroller connected?")

        self.portHandler = PortHandler(port_name)
        self.packetHandler = PacketHandler(2.0)

        if not self.portHandler.openPort():
            raise ValueError(f"RoboTroller: Failed to open port {port_name}.")

        if not self.portHandler.setBaudRate(baud_rate):
            raise ValueError("RoboTroller: Failed to set baudrate.")

        self.position_from_action_mapping = get_positions_from_action_mapping()
        self.current_limit = current_limit

        # servo bus ids as assembled on the RoboTroller rig
        self.left_right_servo_id = 51
        self.up_down_servo_id = 52
        self.button_servo_id = 50
        self.servo_ids = [self.left_right_servo_id, self.up_down_servo_id, self.button_servo_id]

        self.prev_positions = (DPAD_SERVO_DEFAULT, DPAD_SERVO_DEFAULT, BUTTON_SERVO_DEFAULT)

        for servo_id in self.servo_ids:
            self._initialize_servo(servo_id)

        # set to default positions
        self.update_positions(self.prev_positions, force=True)

    def _initialize_servo(self, servo_id):
        """Configure one servo: current-based position control with a current cap."""
        # Disable torque. Mode can only be changed when torque is disabled.
        self._write_byte(servo_id, ADDR_TORQUE_ENABLE, TORQUE_DISABLE, "disable torque")
        # Change mode to current based position control. This mode allows us to limit the maximum current draw.
        self._write_byte(servo_id, ADDR_OPERATING_MODE, 5, "set mode")
        # The ADDR_GOAL_CURRENT is treated as a limit on the current draw in mA.
        # In current based position control mode, the current is limited to [-current_limit, current_limit].
        self._write_word(servo_id, ADDR_GOAL_CURRENT, self.current_limit, "set current")
        self._write_byte(servo_id, ADDR_TORQUE_ENABLE, TORQUE_ENABLE, "enable torque")

    def _write_byte(self, servo_id, addr, val, label):
        """1-byte register write; logs (but does not raise) on comm/packet errors."""
        result, err = self.packetHandler.write1ByteTxRx(self.portHandler, servo_id, addr, val)
        if result != COMM_SUCCESS:
            logger.warning(f"{label} servo={servo_id}: {self.packetHandler.getTxRxResult(result)}")
        elif err != 0:
            logger.warning(f"{label} error servo={servo_id}: {self.packetHandler.getRxPacketError(err)}")

    def _write_word(self, servo_id, addr, val, label):
        """2-byte register write; logs (but does not raise) on comm/packet errors."""
        result, err = self.packetHandler.write2ByteTxRx(self.portHandler, servo_id, addr, val)
        if result != COMM_SUCCESS:
            logger.warning(f"{label} servo={servo_id}: {self.packetHandler.getTxRxResult(result)}")
        elif err != 0:
            logger.warning(f"{label} error servo={servo_id}: {self.packetHandler.getRxPacketError(err)}")

    def shutdown(self):
        """Relax all servos (torque off) and close the serial port."""
        for servo_id in self.servo_ids:
            self._write_byte(servo_id, ADDR_TORQUE_ENABLE, TORQUE_DISABLE, "disable torque")
        self.portHandler.closePort()

    # positions -> (left_right_servo, up_down_servo, button_servo)
    def update_positions(self, positions, force=False):
        """Send goal positions to the servos, skipping unchanged ones unless `force`."""
        for i, pos in enumerate(positions):
            if force or self.prev_positions[i] != pos:
                result, err = self.packetHandler.write4ByteTxRx(
                    self.portHandler, self.servo_ids[i], ADDR_GOAL_POSITION, pos
                )
                if result != COMM_SUCCESS:
                    logger.warning(
                        f"set position servo={self.servo_ids[i]}: {self.packetHandler.getTxRxResult(result)}"
                    )
                elif err != 0:
                    logger.warning(
                        f"error setting position servo={self.servo_ids[i]}: {self.packetHandler.getRxPacketError(err)}"
                    )

        self.prev_positions = positions

    def apply_action(self, action, state):
        # `state` is unused here: the servo positions fully encode the action,
        # and force=True re-sends all three goals on every call.
        positions = self.position_from_action_mapping[Action(action)]
        self.update_positions(positions, True)
37 | self.reference_width = 160 38 | self.reference_height = 210 39 | 40 | def __str__(self): 41 | return f"crop=(x={self.x}, y={self.y}, w={self.w}, h={self.h})" 42 | 43 | 44 | class ScoreDetector: 45 | def __init__( 46 | self, 47 | env_name, 48 | model_type, 49 | total_lives, 50 | checkpoint, 51 | score_crop_info, 52 | lives_crop_info=None, 53 | valid_jumps=[], 54 | score_offsets={}, 55 | lives_offsets={}, 56 | device="cpu", 57 | data_dir=None, 58 | ): 59 | 60 | assert os.path.exists(checkpoint) 61 | 62 | logger.info(f"ScoreDetector: loading score model: {checkpoint} to device={device}") 63 | 64 | self.data_dir = os.getcwd() if data_dir is None else data_dir 65 | self.env_name = env_name 66 | self.device = device 67 | 68 | self.score_crop_info = CropInfo(**score_crop_info) 69 | if score_offsets: 70 | self.score_crop_info.x += score_offsets["offset_x"] 71 | self.score_crop_info.y += score_offsets["offset_y"] 72 | 73 | # lives is optional 74 | if lives_crop_info: 75 | self.lives_crop_info = CropInfo(**lives_crop_info) 76 | if lives_offsets: 77 | self.lives_crop_info.x += lives_offsets["offset_x"] 78 | self.lives_crop_info.y += lives_offsets["offset_y"] 79 | else: 80 | self.lives_crop_info = None 81 | 82 | self.score_validator = None 83 | 84 | self.input_memory_format = torch.contiguous_format 85 | self.use_mixed_precision = False 86 | 87 | self.model_type = model_type 88 | 89 | load_start_time = time.time() 90 | if self.model_type == "crnn_ctc": 91 | from framework.models.score_detector.crnn_ctc import load_model 92 | 93 | self.input_memory_format = torch.channels_last 94 | self.use_mixed_precision = False if self.device == 'cpu' else True 95 | self.model = load_model( 96 | checkpoint, 97 | self.env_name, 98 | device=self.device, 99 | mixed_precision=self.use_mixed_precision, 100 | memory_format=self.input_memory_format, 101 | ) 102 | 103 | from framework.ScoreValidator import ScoreValidator 104 | 105 | self.validator = ScoreValidator( 106 | env_name, 107 | 
valid_jumps, 108 | displayed_lives=self.lives_crop_info.num_digits if self.lives_crop_info is not None else total_lives, 109 | entropy_threshold=self.model.entropy_threshold, 110 | entropy_ceiling=self.model.entropy_ceiling, 111 | ) 112 | else: 113 | raise ValueError(f"Invalid score model type={self.model_type}") 114 | 115 | print(f"score model load: {(time.time() - load_start_time) * 1000.0:.2f}ms") 116 | # print(self.model) 117 | 118 | # track changes to the regions to avoid unnecessary invocations of the score model 119 | self.last_score_region = None 120 | self.last_lives_region = None 121 | self.region_changed_threshold = 0.8 122 | 123 | self.tmp_score_crop_img = None 124 | 125 | self.ave_time_model = 0 126 | self.ave_time_total = 0 127 | self.frames = 0 128 | 129 | def supports_lives(self): 130 | return self.lives_crop_info is not None 131 | 132 | def get_score_crop_info(self, frame_width, frame_height): 133 | return self.__adjust_crop(frame_width, frame_height, self.score_crop_info) 134 | 135 | def get_lives_crop_info(self, frame_width, frame_height): 136 | if self.lives_crop_info: 137 | return self.__adjust_crop(frame_width, frame_height, self.lives_crop_info) 138 | else: 139 | return None 140 | 141 | # np.ndarray: h,w,c 142 | def get_score_and_lives(self, frame) -> tuple[int, int | None]: 143 | # total_start = time.time() 144 | has_lives = self.lives_crop_info is not None 145 | frame_w, frame_h = frame.shape[1], frame.shape[0] 146 | 147 | score_crop = self.__crop_region(frame, self.get_score_crop_info(frame_w, frame_h), channels_first=False) 148 | lives_crop = ( 149 | self.__crop_region(frame, self.get_lives_crop_info(frame_w, frame_h), channels_first=False) 150 | if has_lives 151 | else None 152 | ) 153 | 154 | self.tmp_score_crop_img = score_crop 155 | 156 | score_crop = self.model.preprocess(score_crop) 157 | if lives_crop is not None: 158 | lives_crop = self.model.preprocess(lives_crop, is_lives=True) 159 | 160 | combined_crop = ( 161 | 
torch.stack([score_crop, lives_crop], dim=0) if lives_crop is not None else torch.stack([score_crop], dim=0) 162 | ) 163 | combined_crop = combined_crop.contiguous(memory_format=self.input_memory_format) 164 | 165 | # model_start = time.time() 166 | with torch.no_grad(): 167 | autocast_ctx = ( 168 | torch.amp.autocast(device_type=self.device, dtype=torch.float16) 169 | if self.use_mixed_precision 170 | else contextlib.nullcontext() 171 | ) 172 | with autocast_ctx: 173 | decoded, confidences = self.model.predict(combined_crop) 174 | 175 | # self.ave_time_model += (time.time() - model_start) 176 | 177 | decoded = decoded.cpu().numpy() 178 | confidences = confidences.cpu().numpy() 179 | 180 | score, score_confidences, lives, lives_confidences = self.model.convert( 181 | decoded[0], confidences[0], decoded[1] if has_lives else None, confidences[1] if has_lives else None 182 | ) 183 | 184 | if self.validator is not None: 185 | valid_score, valid_lives = self.validator.validate( 186 | score, score_confidences, pred_lives=lives, lives_confidences=lives_confidences 187 | ) 188 | else: 189 | valid_score = score 190 | valid_lives = lives 191 | 192 | # logger.debug(f"{score_preds} -> {valid_score}, {lives_preds} -> {valid_lives}") 193 | 194 | # self.ave_time_total += (time.time()-total_start) 195 | # self.frames += 1 196 | # if (self.frames % 10000) == 0: 197 | # print(f"model ave: {(self.ave_time_model/self.frames)*1000.0:.2f}/ {(self.ave_time_total/self.frames)*1000.0:.2f}ms") 198 | # self.ave_time = 0 199 | # self.ave_time_total = 0 200 | # self.frames = 0 201 | 202 | return valid_score, valid_lives if has_lives else None 203 | 204 | def __adjust_crop(self, width, height, crop): 205 | if width == crop.reference_width and height == crop.reference_height: 206 | return crop 207 | 208 | scale_x = width / crop.reference_width 209 | scale_y = height / crop.reference_height 210 | 211 | x = int(crop.x * scale_x) 212 | y = int(crop.y * scale_y) 213 | x_off = int(crop.w * scale_x) 
214 | y_off = int(crop.h * scale_y) 215 | 216 | return CropInfo(x, y, x_off, y_off, crop.num_digits) 217 | 218 | def __crop_region(self, x, crop_info, channels_first=False): 219 | if channels_first: 220 | cropped_image = x[:, crop_info.y : crop_info.y + crop_info.h, crop_info.x : crop_info.x + crop_info.w] 221 | else: 222 | cropped_image = x[crop_info.y : crop_info.y + crop_info.h, crop_info.x : crop_info.x + crop_info.w, :] 223 | return cropped_image 224 | 225 | def _crops_changed(self, score_crop, lives_crop): 226 | # check for changes to the score and lives regions, if no changes (within a threshold) are detected, return 227 | # previous values 228 | score_changed = True 229 | if self.last_score_region is not None: 230 | error = np.mean((score_crop - self.last_score_region) ** 2) 231 | # print(f"score_error={error}") 232 | score_changed = error >= self.region_changed_threshold 233 | 234 | lives_changed = lives_crop is not None 235 | if not score_changed and self.last_lives_region is not None and lives_crop is not None: 236 | error = np.mean((lives_crop - self.last_lives_region) ** 2) 237 | # print(f"lives_error={error}") 238 | lives_changed = error >= self.region_changed_threshold 239 | 240 | self.last_score_region = score_crop 241 | self.last_lives_region = lives_crop 242 | 243 | return score_changed or lives_changed 244 | -------------------------------------------------------------------------------- /train/score_detector/ale_ram_injection.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Keen Technologies, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np

"""
Game Ram Config
--------

To add a new game:

Each ale game in src/games/supported defines a 'step' function
where score and lives handling is found. For score, look for:

    int score = getDecimalScore(ADDR1, ADDR2, ..., &system);
    "score_addr": [ADDR1, ADDR2, ...],
    "score_type": "packed_bcd",  # most common

Byte order depends on param order in getDecimalScore(...), often
MSB -> LSB, but the function reads RAM in that order, so this becomes
the "score_addr" order.

    "bcd_order":   how RAM addresses are ordered (lsb or msb)
    "digit_order": how digits are packed inside each byte (lsb or msb)

If ALE has:
    getDecimalScore(ADDR_HI, ADDR_MID, ADDR_LO, &system);

Then:
    "score_addr": [ADDR_HI, ADDR_MID, ADDR_LO],
    "bcd_order": "lsb",   # RAM goes low-to-high
    "digit_order": "msb"  # digits packed left-to-right (normal)

If rendered score is off by x10 or x100, "digit_order" is probably wrong.
If digits are reversed, "bcd_order" is wrong.

If ALE has:
    int lives = readRam(&system, ADDR);

    "lives_addr": ADDR,
    "lives_nibble": "high" or "low",  # based on shift
    "lives_offset": 1                 # if +1 is applied

if its raw:
    "lives_nibble": "low",
    "lives_offset": 0
"""

GAME_RAM_CONFIG = {
    "atlantis": {
        "score_addr": [0xA1, 0xA3, 0xA2],
        "lives_addr": 0xF1,
        "lives_nibble": "low",
        "score_type": "packed_bcd",
        "bcd_order": "msb",
        "digit_order": "msb",
        "score_multiplier": 100,
        "total_lives": 6,
        "display_lives": 6,
        "score_step": [100],
        "max_score": 84300,  # capped at 84400 with ram injection (84500 but isn't correct); is this the max?
        "score_digits": 6,
    },
    "battle_zone": {
        "score_addr": [0x9E, 0x9D],
        "lives_addr": 0xBA,
        "lives_nibble": "low",
        "score_type": "custom_battlezone",
        "score_multiplier": 1000,
        "score_step": [1000],
        "max_score": 501000,  # capped
        "total_lives": 5,
        "display_lives": 5,
        "score_digits": 6,
    },
    "centipede": {
        "score_addr": [0xF6, 0xF5, 0xF4],
        "lives_addr": 0xED,
        "score_type": "packed_bcd",
        "bcd_order": "lsb",
        "digit_order": "msb",
        "score_multiplier": 1,
        "lives_nibble": "high",
        "total_lives": 3,
        "display_lives": 2,
        "score_step": [
            1,
        ],
        "max_score": 9999,
        "score_digits": 6,
    },
    "defender": {
        "score_addr": [0x9C, 0x9D, 0x9E, 0x9F, 0xA0, 0xA1],
        "lives_addr": 0xC2,
        "score_type": "custom_defender",
        "lives_nibble": "low",
        "total_lives": 3,
        "display_lives": 3,
        "score_step": [50],
        "max_score": 73000,  # capped at 73000 with ram injection; is this the max?
        "score_multiplier": 1,
        "score_digits": 6,
    },
    "krull": {
        "score_addr": [0x9E, 0x9D, 0x9C],
        "lives_addr": 0x9F,
        "lives_nibble": "low",
        "score_type": "packed_bcd",
        "bcd_order": "lsb",
        "digit_order": "msb",
        "total_lives": 3,
        "display_lives": 2,
        "score_step": [10],
        "max_score": 99990,
        "score_digits": 6,
    },
    "ms_pacman": {
        "score_addr": [0xF8, 0xF9, 0xFA],
        "lives_addr": 0xFB,
        "score_type": "packed_bcd",
        "bcd_order": "lsb",
        "lives_nibble": "low",
        "total_lives": 3,
        "display_lives": 2,
        "score_step": [
            10,
        ],
        "max_score": 99990,
        "score_digits": 6,
    },
    "qbert": {
        "score_addr": [0xD9, 0xDA, 0xDB],
        "lives_addr": 0x88,
        "score_type": "packed_bcd",
        "bcd_order": "msb",
        "lives_signed": True,
        "total_lives": 4,
        "display_lives": 3,
        "score_step": [25],
        "max_score": 99950,
        "score_digits": 5,
    },
    "up_n_down": {
        "score_addr": [0x82, 0x81, 0x80],
        "lives_addr": 0x86,
        "score_type": "packed_bcd",
        "bcd_order": "lsb",
        "digit_order": "msb",
        "lives_offset": 1,
        "total_lives": 5,
        "display_lives": 4,
        "score_step": [10],
        "max_score": 99990,
        "score_digits": 6,
    },
}

# Atari 2600 has 128 bytes of RAM total, located from $80 to $FF
REAL_RAM_START = 0x80


def encode_bcd(score: int, byte_count: int, bcd_order: str = "msb", digit_order: str = "msb") -> list[int]:
    """Pack `score` into `byte_count` BCD bytes (two decimal digits per byte).

    bcd_order:   byte ordering of the result ("lsb" reverses the byte list).
    digit_order: digit packing within the number ("lsb" reverses the digits
                 before pairing them into bytes).
    """
    digits = [int(d) for d in f"{score:0{byte_count * 2}d}"]

    if digit_order == "lsb":
        digits = digits[::-1]

    bcd = []
    for i in range(0, len(digits), 2):
        hi, lo = digits[i], digits[i + 1]
        byte = (hi << 4) | lo
        bcd.append(byte)

    if bcd_order == "lsb":
        bcd = bcd[::-1]

    return bcd


def decode_score_bcd(ram, addr_list, config):
    """Decode the displayed score from a 128-byte RAM snapshot.

    `ram` is indexed relative to REAL_RAM_START; `addr_list` holds absolute
    RAM addresses; `config` is a GAME_RAM_CONFIG entry. Returns the score
    already scaled by "score_multiplier" (battlezone uses its hardcoded
    x1000 scaling instead).
    """
    base = REAL_RAM_START
    score_type = config["score_type"]
    bcd_order = config.get("bcd_order", "msb")
    digit_order = config.get("digit_order", "msb")
    multiplier = config.get("score_multiplier", 1)

    if score_type == "custom_battlezone":
        val1 = ram[0x9D - REAL_RAM_START]
        val2 = ram[0x9E - REAL_RAM_START]

        ones = (val1 >> 4) & 0xF  # left nibble of 0x9D
        tens = val2 & 0xF  # right nibble of 0x9E
        hundreds = (val2 >> 4) & 0xF  # left nibble of 0x9E

        # Battlezone uses 0xA as a blank/unused digit. BUG FIX: the original
        # looped `for digit in (...)` and rebound the loop variable, which
        # never changed ones/tens/hundreds — sanitize each digit explicitly.
        ones = 0 if ones == 0xA else ones
        tens = 0 if tens == 0xA else tens
        hundreds = 0 if hundreds == 0xA else hundreds

        score = (hundreds * 100 + tens * 10 + ones) * 1000
        return score

    if score_type == "custom_defender":
        # each address holds one digit, LSB first
        score = 0
        mult = 1
        for addr in addr_list:
            val = ram[addr - base] & 0xF
            if val == 0xA:
                val = 0
            score += val * mult
            mult *= 10
        return score * multiplier

    # decode BCD packed bytes
    ram_bytes = addr_list if bcd_order == "msb" else list(reversed(addr_list))
    digits = []
    for addr in ram_bytes:
        byte_val = ram[addr - base]
        digits.append((byte_val >> 4) & 0xF)  # left
        digits.append(byte_val & 0xF)  # right
    
    if digit_order == "lsb":
        digits = digits[::-1]

    score = sum(d * (10**i) for i, d in enumerate(reversed(digits)))
    return score * multiplier


def decode_lives(ram, addr, config):
    """Return the raw lives byte at absolute RAM address `addr`.

    NOTE(review): this intentionally(?) ignores "lives_nibble"/"lives_offset"
    from `config` and returns the raw byte — confirm callers expect the raw
    value rather than the decoded lives count.
    """
    idx = addr - REAL_RAM_START
    # print(f"RAM[0x{addr:X}] = {ram[idx]}")  # fixed: original debug line had mismatched brackets
    return ram[idx]


def write_score(config, score, ale):
    """Inject `score` into the emulator's RAM via ale.setRAM.

    Raises NotImplementedError for an unsupported "score_type".
    """
    addr = config["score_addr"]
    score_type = config["score_type"]
    multiplier = config.get("score_multiplier", 1)

    if score_type == "packed_bcd":
        score = score // multiplier
        byte_count = len(addr)
        bcd = encode_bcd(score, byte_count, config.get("bcd_order", "msb"), config.get("digit_order", "msb"))
encode_bcd(score, byte_count, config.get("bcd_order", "msb"), config.get("digit_order", "msb")) 251 | 252 | for a, val in zip(addr, bcd): 253 | ale.setRAM(a - REAL_RAM_START, val) 254 | 255 | elif score_type == "custom_battlezone": 256 | digits = [int(d) for d in f"{score // multiplier:03d}"] # 3-digit string 257 | hundreds, tens, ones = digits 258 | 259 | byte1 = ones << 4 # 0x9D = O0 (left nibble = ones) 260 | byte2 = (hundreds << 4) | tens # 0x9E = HT 261 | 262 | ale.setRAM(0x9D - REAL_RAM_START, byte1) 263 | ale.setRAM(0x9E - REAL_RAM_START, byte2) 264 | 265 | elif score_type == "custom_defender": 266 | digits = [int(d) for d in f"{score:06d}"][::-1] # reverse for LSB-first 267 | for a, d in zip(addr, digits): 268 | val = d & 0xF # ensure upper nibble is 0 269 | ale.setRAM(a - REAL_RAM_START, val) 270 | 271 | # for addr in config["score_addr"]: 272 | # print(f"RAM[{hex(addr)}] = {ale.getRAM()[addr - REAL_RAM_START]}") 273 | 274 | elif score_type == "nibble_lsb_first": 275 | # each digit is one nibble, LSB at lowest addr 276 | digits = [int(d) for d in f"{score:06d}"] 277 | for i, a in enumerate(addr): 278 | ale.setRAM(a - REAL_RAM_START, digits[i]) 279 | 280 | else: 281 | raise NotImplementedError(f"Unsupported score_type: {score_type}") 282 | 283 | 284 | def write_lives(config, lives, ale): 285 | addr = config["lives_addr"] 286 | idx = addr - REAL_RAM_START 287 | ram = ale.getRAM() 288 | total = config.get("total_lives", lives) 289 | display = config.get("display_lives", total) 290 | offset = config.get("lives_offset", 1 if display < total else 0) 291 | 292 | if config.get("lives_signed"): 293 | val = np.uint8(np.int8(lives - offset)) 294 | ale.setRAM(idx, val) 295 | return 296 | 297 | val = (lives - offset) & 0xF 298 | 299 | nibble = config.get("lives_nibble", "low") 300 | current_val = ram[idx] 301 | 302 | if nibble == "high": 303 | new_val = (val << 4) | (current_val & 0x0F) 304 | ale.setRAM(idx, new_val) 305 | elif nibble == "low": 306 | new_val = 
(current_val & 0xF0) | val 307 | ale.setRAM(idx, new_val) 308 | else: 309 | raise ValueError(f"Unknown lives_nibble value: {nibble}") 310 | -------------------------------------------------------------------------------- /framework/MCCDAQDevice.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Keen Technologies, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from enum import Enum, IntEnum 16 | 17 | from framework.Actions import Action 18 | from framework.ControlDevice import ControlDevice 19 | from framework.Logger import logger 20 | 21 | 22 | class Signal(Enum): 23 | LOW = 0 24 | HIGH = 1 25 | 26 | 27 | # Layout for the MCC USB-1024LS 28 | class MCCDAQ1024LSLayout: 29 | """ 30 | The MCC DAQ has 4 ports: 31 | - PORTA - A0-A7 mapped to pins 21-28 32 | - PORTB - B0-B7 mapped to pins 32-39 33 | - PORTCL - C0-C3 mapped to pins 1-4 34 | - PORTCH - C4-C7 mapped to pins 5-8 35 | """ 36 | 37 | # uldaq/ul_enums.py 38 | class ULDAQPortEnum(IntEnum): 39 | FIRSTPORTA = 10 40 | FIRSTPORTB = 11 41 | FIRSTPORTCL = 12 42 | FIRSTPORTCH = 13 43 | 44 | # from: https://github.com/mccdaq/uldaq/blob/1d8404159c0fb6d2665461b80acca5bbef5c610a/src/hid/dio/DioUsbDio24.cpp#L177 45 | class HIDPortEnum(IntEnum): 46 | FIRSTPORTA = 1 47 | FIRSTPORTB = 4 48 | FIRSTPORTCL = 8 49 | FIRSTPORTCH = 2 50 | 51 | # 
https://github.com/mccdaq/uldaq/blob/1d8404159c0fb6d2665461b80acca5bbef5c610a/src/hid/dio/DioUsbDio24.h 52 | class ReportCommand(IntEnum): 53 | DIN = 0x00 # Read all pins on a port 54 | DOUT = 0x01 # Write to all pins on a port 55 | BITIN = 0x02 # Read a single pin 56 | BITOUT = 0x03 # Write a single pin 57 | DCONFIG = 0x0D # Configure direction of a port 58 | 59 | def __init__(self, use_hid: bool): 60 | self.use_hid = use_hid 61 | self.PortEnum = self.HIDPortEnum if use_hid else self.ULDAQPortEnum 62 | 63 | self.port_ranges = { 64 | self.PortEnum.FIRSTPORTA: (21, 29), 65 | self.PortEnum.FIRSTPORTB: (32, 40), 66 | self.PortEnum.FIRSTPORTCL: (1, 5), 67 | self.PortEnum.FIRSTPORTCH: (5, 9), 68 | } 69 | 70 | def get_port_for_pin(self, pin): 71 | for port, (start, end) in self.port_ranges.items(): 72 | if start <= pin <= end: 73 | return port 74 | return None 75 | 76 | # Convert the pin to the corresponding bit value for the port. 77 | def get_bit_for_pin(self, port, pin): 78 | # Pins are 1-based 79 | start_range, _ = self.port_ranges[port] 80 | 81 | # Find the port the pin belongs to, and offset the port range 82 | # such that it is in the range [1,number_of_bits]. 83 | # CL = (1,4) 84 | # CH = (1,4) 85 | # A = (1,8) 86 | # B = (1,8) 87 | # And, then offset by 1 to convert to 0-based range for bit. 
88 | bit = pin - (start_range - 1) - 1 89 | return bit 90 | 91 | 92 | def get_pins_from_action_mapping(action_to_pin_map) -> dict[Action, tuple[int, ...]]: 93 | PIN_UP = action_to_pin_map[Action.UP] 94 | PIN_DOWN = action_to_pin_map[Action.DOWN] 95 | PIN_RIGHT = action_to_pin_map[Action.RIGHT] 96 | PIN_LEFT = action_to_pin_map[Action.LEFT] 97 | PIN_FIRE = action_to_pin_map[Action.FIRE] 98 | 99 | mapping = { 100 | Action.NOOP: (-1,), 101 | Action.UP: (PIN_UP,), 102 | Action.FIRE: (PIN_FIRE,), 103 | Action.DOWN: (PIN_DOWN,), 104 | Action.LEFT: (PIN_LEFT,), 105 | Action.RIGHT: (PIN_RIGHT,), 106 | Action.UPFIRE: (PIN_UP, PIN_FIRE), 107 | Action.DOWNFIRE: (PIN_DOWN, PIN_FIRE), 108 | Action.LEFTFIRE: (PIN_LEFT, PIN_FIRE), 109 | Action.RIGHTFIRE: (PIN_RIGHT, PIN_FIRE), 110 | Action.UPLEFT: (PIN_UP, PIN_LEFT), 111 | Action.UPRIGHT: (PIN_UP, PIN_RIGHT), 112 | Action.DOWNLEFT: (PIN_DOWN, PIN_LEFT), 113 | Action.DOWNRIGHT: (PIN_DOWN, PIN_RIGHT), 114 | Action.UPLEFTFIRE: (PIN_UP, PIN_LEFT, PIN_FIRE), 115 | Action.UPRIGHTFIRE: (PIN_UP, PIN_RIGHT, PIN_FIRE), 116 | Action.DOWNLEFTFIRE: (PIN_DOWN, PIN_LEFT, PIN_FIRE), 117 | Action.DOWNRIGHTFIRE: (PIN_DOWN, PIN_RIGHT, PIN_FIRE), 118 | } 119 | return mapping 120 | 121 | 122 | # NOTE: Implementation is for MCCDAQ USB-1024LS I/O device board 123 | # From https://forums.atariage.com/topic/266868-joystick-pinout-question/#comment-3788375: 124 | # 'SIGNAL_LOW' will trigger the action on Atari, and 'SIGNAL_HIGH' is off. 
125 | class DAQDevice(ControlDevice): 126 | def __init__( 127 | self, 128 | model_name: str, 129 | vendor_id: str, 130 | product_id: str, 131 | pin_to_action_str_map, 132 | active_low=True, 133 | use_hid_backend=False, 134 | ): 135 | super().__init__() 136 | self.vendor_id = vendor_id 137 | self.product_id = product_id 138 | self.use_hid = use_hid_backend 139 | self.signal_active = Signal.LOW.value if active_low else Signal.HIGH.value 140 | self.signal_inactive = Signal.HIGH.value if active_low else Signal.LOW.value 141 | 142 | self.layout = MCCDAQ1024LSLayout(self.use_hid) 143 | self.action_to_pin_map = {} 144 | for pin, action_str in pin_to_action_str_map.items(): 145 | if Action.has_key(action_str): 146 | self.action_to_pin_map[Action[action_str]] = int(pin) 147 | else: 148 | logger.warning(f"'{action_str}' is not a valid action.") 149 | self.action_pins = tuple(int(pin) for pin in pin_to_action_str_map.keys()) 150 | self.pins_from_action_map = get_pins_from_action_mapping(self.action_to_pin_map) 151 | 152 | ports = [self.layout.get_port_for_pin(p) for p in self.action_pins] 153 | assert all(p == ports[0] for p in ports), "Only one port supported at a time" 154 | 155 | self.port = ports[0] 156 | self.port_state = 0x00 157 | 158 | if self.use_hid: 159 | from framework.HIDDevice import HIDDevice 160 | 161 | # from: https://github.com/mccdaq/uldaq/blob/1d8404159c0fb6d2665461b80acca5bbef5c610a/src/uldaq.h#L945 162 | class DigitalDirection(IntEnum): 163 | INPUT = 1 164 | OUTPUT = 2 165 | 166 | self._ReportCommand = MCCDAQ1024LSLayout.ReportCommand 167 | 168 | # When using split-port C, we must write all 8 bits at once, but we want to update CL and CH 169 | # individually. Cache the last known values and combine before writing to device. 
170 | self.port_cl_val = 0 171 | self.port_ch_val = 0 172 | self.backend = HIDDevice(vendor_id, product_id) 173 | 174 | # configure the port 175 | report = bytearray([self._ReportCommand.DCONFIG, self.port, DigitalDirection.OUTPUT]) 176 | self.backend.write_sync(report) 177 | else: 178 | import uldaq 179 | 180 | assert self.layout.PortEnum.FIRSTPORTA == uldaq.DigitalPortType.FIRSTPORTA 181 | assert self.layout.PortEnum.FIRSTPORTB == uldaq.DigitalPortType.FIRSTPORTB 182 | assert self.layout.PortEnum.FIRSTPORTCL == uldaq.DigitalPortType.FIRSTPORTCL 183 | assert self.layout.PortEnum.FIRSTPORTCH == uldaq.DigitalPortType.FIRSTPORTCH 184 | 185 | devices = uldaq.get_daq_device_inventory(uldaq.InterfaceType.USB) 186 | if not devices: 187 | raise RuntimeError("No DAQ device found.") 188 | 189 | self.device = uldaq.DaqDevice(devices[0]) # TODO: match model_name properly 190 | assert self.device is not None, "Failed to get first DAQ device" 191 | self.device.connect() 192 | 193 | self.dio_device = self.device.get_dio_device() 194 | assert self.dio_device is not None, "Failed to get DIO device" 195 | 196 | self.dio_device.d_config_port(self.port, uldaq.DigitalDirection.OUTPUT) 197 | 198 | self.default_port_state = self._build_default_state() 199 | 200 | # Initialize action pin values to off. 201 | # From https://forums.atariage.com/topic/266868-joystick-pinout-question/#comment-3788375: 202 | # 'SIGNAL_LOW' will trigger the action on Atari, and 'SIGNAL_HIGH' is off. 203 | self._set_pins(self.port, self.action_pins, self.signal_inactive) 204 | 205 | def shutdown(self): 206 | # Set the action pins to off 207 | # From https://forums.atariage.com/topic/266868-joystick-pinout-question/#comment-3788375: 208 | # 'SIGNAL_LOW' will trigger the action on Atari, and 'SIGNAL_HIGH' is off. 
209 | self._set_pins(self.port, self.action_pins, self.signal_inactive) 210 | 211 | if self.use_hid: 212 | if self.backend is not None: 213 | self.backend.shutdown() 214 | else: 215 | if self.device is not None: 216 | self.device.disconnect() 217 | 218 | def _build_default_state(self): 219 | bits = self._get_bits_for_pins(self.port, self.action_pins) 220 | state = 0 221 | for bit in bits: 222 | state |= 1 << bit 223 | return state 224 | 225 | def _send_signal(self, port, state): 226 | if self.use_hid: 227 | signal_val = state & 0xFF 228 | # when using the split port-C, only update the specified nibble 229 | if port == self.layout.PortEnum.FIRSTPORTCL: 230 | self.port_cl_val = signal_val & 0x0F 231 | signal_val = signal_val | (self.port_ch_val << 4) 232 | elif port == self.layout.PortEnum.FIRSTPORTCH: 233 | self.port_ch_val = signal_val & 0x0F 234 | signal_val = (signal_val << 4) | self.port_cl_val 235 | 236 | signal_data = bytearray([self._ReportCommand.DOUT, port, signal_val]) 237 | self.backend.write_sync(signal_data) 238 | else: 239 | self.dio_device.d_out(port, state) 240 | 241 | def _set_pins(self, port, pins, value, force=True): 242 | bits = self._get_bits_for_pins(port, pins) 243 | updated_state = self.default_port_state 244 | 245 | # construct the data mask for the new set of pins 246 | # d_in should only be used for debug; ports are configured for OUTPUT, 247 | # so reads are slow, up to ~15ms. 
248 | # data_mask = self.dio_device.d_in(ports[0]) 249 | 250 | for bit in bits: 251 | if value == self.signal_active: 252 | updated_state &= ~(1 << bit) 253 | else: 254 | updated_state |= 1 << bit 255 | 256 | pins_low = self.port_state & ~updated_state 257 | pins_high = ~self.port_state & updated_state 258 | port_state = self.port_state & ~pins_low | pins_high 259 | 260 | if force or self.port_state != port_state: 261 | self.port_state = port_state 262 | self._send_signal(port, self.port_state) 263 | 264 | # logger.debug(f"port_state = {bin(self.port_state):08}") 265 | 266 | def _get_bits_for_pins(self, port_enum, pins): 267 | return [self.layout.get_bit_for_pin(port_enum, p) for p in pins] 268 | 269 | def apply_action(self, action, state): 270 | pins = self.pins_from_action_map.get(Action(action), (-1,)) 271 | # handle NOOP 272 | if pins == (-1,): 273 | pins = self.action_pins 274 | state = 0 275 | 276 | # To simulate press, pull it low; to simulate release pull it high. 277 | signal = self.signal_active if state else self.signal_inactive 278 | self._set_pins(self.port, pins, signal) 279 | 280 | def get_pins(self) -> list[int]: 281 | return list(self.action_pins) 282 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 
15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 
48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. 
Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 
123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. 
In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. 
We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /framework/v4l2_defs.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Keen Technologies, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import ctypes 16 | 17 | # Excerpts from v4l2.py and https://www.kernel.org/doc/html/v4.9/media/uapi/v4l/videodev.html 18 | 19 | # =================== 20 | # ioctl utils 21 | # =================== 22 | 23 | _IOC_NONE = 0 24 | _IOC_WRITE = 1 25 | _IOC_READ = 2 26 | 27 | _IOC_NRBITS = 8 28 | _IOC_TYPEBITS = 8 29 | _IOC_SIZEBITS = 14 30 | _IOC_DIRBITS = 2 31 | 32 | _IOC_NRSHIFT = 0 33 | _IOC_TYPESHIFT = _IOC_NRSHIFT + _IOC_NRBITS 34 | _IOC_SIZESHIFT = _IOC_TYPESHIFT + _IOC_TYPEBITS 35 | _IOC_DIRSHIFT = _IOC_SIZESHIFT + _IOC_SIZEBITS 36 | 37 | 38 | def _IOC(dir_, type_, nr, size): 39 | return ( 40 | ctypes.c_int32(dir_ << _IOC_DIRSHIFT).value 41 | | ctypes.c_int32(ord(type_) << _IOC_TYPESHIFT).value 42 | | ctypes.c_int32(nr << _IOC_NRSHIFT).value 43 | | ctypes.c_int32(size << _IOC_SIZESHIFT).value 44 | ) 45 | 46 | 47 | def _IOC_TYPECHECK(t): 48 | return ctypes.sizeof(t) 49 | 50 | 51 | def _IO(type_, nr): 52 | return _IOC(_IOC_NONE, type_, nr, 0) 53 | 54 | 55 | def _IOW(type_, nr, size): 56 | return _IOC(_IOC_WRITE, type_, nr, _IOC_TYPECHECK(size)) 57 | 58 | 59 | def _IOR(type_, nr, size): 60 | return _IOC(_IOC_READ, type_, nr, _IOC_TYPECHECK(size)) 61 | 62 | 63 | def _IOWR(type_, nr, size): 64 | return _IOC(_IOC_READ | _IOC_WRITE, type_, nr, _IOC_TYPECHECK(size)) 65 | 66 | 67 | # =================== 68 | # v4l2 constants 69 | # =================== 70 | 71 | V4L2_MEMORY_MMAP = 1 72 | 73 | V4L2_CAP_VIDEO_CAPTURE = 0x00000001 74 | 75 | 76 | def fourcc(a, b, c, d): 77 | return (ord(a)) | (ord(b) << 8) | (ord(c) << 16) | (ord(d) << 24) 78 | 79 | 80 | def decode_fourcc(fmt): 81 | return ''.join([chr((fmt >> 8 * i) & 0xFF) for i in range(4)]) 82 | 83 | 84 | PIXEL_FORMATS = { 85 | "YUYV": fourcc('Y', 'U', 'Y', 'V'), 86 | "NV12": fourcc('N', 'V', '1', '2'), 87 | } 88 | 89 | # =================== 90 | # v4l2 structs 91 | # 92 | # NOTE: These haven't been tested for ABI compatibiity. 
93 | # =================== 94 | 95 | # https://www.kernel.org/doc/html/v4.9/media/uapi/v4l/vidioc-querycap.html#c.v4l2_capability 96 | 97 | 98 | class v4l2_capability(ctypes.Structure): 99 | _fields_ = [ 100 | ('driver', ctypes.c_char * 16), 101 | ('card', ctypes.c_char * 32), 102 | ('bus_info', ctypes.c_char * 32), 103 | ('version', ctypes.c_uint32), 104 | ('capabilities', ctypes.c_uint32), 105 | ('reserved', ctypes.c_uint32 * 4), 106 | ] 107 | 108 | 109 | # https://www.kernel.org/doc/html/v4.9/media/uapi/v4l/pixfmt-002.html#c.v4l2_pix_format 110 | 111 | v4l2_field = ctypes.c_uint 112 | ( 113 | V4L2_FIELD_ANY, 114 | V4L2_FIELD_NONE, 115 | V4L2_FIELD_TOP, 116 | V4L2_FIELD_BOTTOM, 117 | V4L2_FIELD_INTERLACED, 118 | V4L2_FIELD_SEQ_TB, 119 | V4L2_FIELD_SEQ_BT, 120 | V4L2_FIELD_ALTERNATE, 121 | V4L2_FIELD_INTERLACED_TB, 122 | V4L2_FIELD_INTERLACED_BT, 123 | ) = range(10) 124 | 125 | v4l2_colorspace = ctypes.c_uint 126 | ( 127 | V4L2_COLORSPACE_SMPTE170M, 128 | V4L2_COLORSPACE_SMPTE240M, 129 | V4L2_COLORSPACE_REC709, 130 | V4L2_COLORSPACE_BT878, 131 | V4L2_COLORSPACE_470_SYSTEM_M, 132 | V4L2_COLORSPACE_470_SYSTEM_BG, 133 | V4L2_COLORSPACE_JPEG, 134 | V4L2_COLORSPACE_SRGB, 135 | ) = range(1, 9) 136 | 137 | 138 | class v4l2_pix_format(ctypes.Structure): 139 | _fields_ = [ 140 | ('width', ctypes.c_uint32), 141 | ('height', ctypes.c_uint32), 142 | ('pixelformat', ctypes.c_uint32), 143 | ('field', v4l2_field), 144 | ('bytesperline', ctypes.c_uint32), 145 | ('sizeimage', ctypes.c_uint32), 146 | ('colorspace', v4l2_colorspace), 147 | ('flags', ctypes.c_uint32), 148 | ('priv', ctypes.c_uint32), 149 | ] 150 | 151 | 152 | # https://www.kernel.org/doc/html/v4.9/media/uapi/v4l/vidioc-g-fmt.html#c.v4l2_format 153 | 154 | v4l2_buf_type = ctypes.c_uint 155 | ( 156 | V4L2_BUF_TYPE_VIDEO_CAPTURE, 157 | V4L2_BUF_TYPE_VIDEO_OUTPUT, 158 | V4L2_BUF_TYPE_VIDEO_OVERLAY, 159 | V4L2_BUF_TYPE_VBI_CAPTURE, 160 | V4L2_BUF_TYPE_VBI_OUTPUT, 161 | V4L2_BUF_TYPE_SLICED_VBI_CAPTURE, 162 | 
V4L2_BUF_TYPE_SLICED_VBI_OUTPUT, 163 | V4L2_BUF_TYPE_VIDEO_OUTPUT_OVERLAY, 164 | V4L2_BUF_TYPE_PRIVATE, 165 | ) = list(range(1, 9)) + [0x80] 166 | 167 | 168 | class v4l2_rect(ctypes.Structure): 169 | _fields_ = [ 170 | ('left', ctypes.c_int32), 171 | ('top', ctypes.c_int32), 172 | ('width', ctypes.c_int32), 173 | ('height', ctypes.c_int32), 174 | ] 175 | 176 | 177 | class v4l2_clip(ctypes.Structure): 178 | pass 179 | 180 | 181 | v4l2_clip._fields_ = [ 182 | ('c', v4l2_rect), 183 | ('next', ctypes.POINTER(v4l2_clip)), 184 | ] 185 | 186 | 187 | class v4l2_window(ctypes.Structure): 188 | _fields_ = [ 189 | ('w', v4l2_rect), 190 | ('field', v4l2_field), 191 | ('chromakey', ctypes.c_uint32), 192 | ('clips', ctypes.POINTER(v4l2_clip)), 193 | ('clipcount', ctypes.c_uint32), 194 | ('bitmap', ctypes.c_void_p), 195 | ('global_alpha', ctypes.c_uint8), 196 | ] 197 | 198 | 199 | class v4l2_vbi_format(ctypes.Structure): 200 | _fields_ = [ 201 | ('sampling_rate', ctypes.c_uint32), 202 | ('offset', ctypes.c_uint32), 203 | ('samples_per_line', ctypes.c_uint32), 204 | ('sample_format', ctypes.c_uint32), 205 | ('start', ctypes.c_int32 * 2), 206 | ('count', ctypes.c_uint32 * 2), 207 | ('flags', ctypes.c_uint32), 208 | ('reserved', ctypes.c_uint32 * 2), 209 | ] 210 | 211 | 212 | class v4l2_sliced_vbi_format(ctypes.Structure): 213 | _fields_ = [ 214 | ('service_set', ctypes.c_uint16), 215 | ('service_lines', ctypes.c_uint16 * 2 * 24), 216 | ('io_size', ctypes.c_uint32), 217 | ('reserved', ctypes.c_uint32 * 2), 218 | ] 219 | 220 | 221 | class v4l2_format(ctypes.Structure): 222 | class _u(ctypes.Union): 223 | _fields_ = [ 224 | ('pix', v4l2_pix_format), 225 | ('win', v4l2_window), 226 | ('vbi', v4l2_vbi_format), 227 | ('sliced', v4l2_sliced_vbi_format), 228 | ('raw_data', ctypes.c_char * 200), 229 | ] 230 | 231 | _fields_ = [ 232 | ('type', v4l2_buf_type), 233 | ('fmt', _u), 234 | ] 235 | 236 | 237 | # 
https://www.kernel.org/doc/html/v4.9/media/uapi/v4l/vidioc-reqbufs.html#c.v4l2_requestbuffers 238 | 239 | v4l2_memory = ctypes.c_uint 240 | ( 241 | V4L2_MEMORY_MMAP, 242 | V4L2_MEMORY_USERPTR, 243 | V4L2_MEMORY_OVERLAY, 244 | ) = range(1, 4) 245 | 246 | 247 | class v4l2_requestbuffers(ctypes.Structure): 248 | _fields_ = [ 249 | ('count', ctypes.c_uint32), 250 | ('type', v4l2_buf_type), 251 | ('memory', v4l2_memory), 252 | ('reserved', ctypes.c_uint32 * 2), 253 | ] 254 | # NOTE: For V4L2_MEMORY_DMABUF, we'd also need: 255 | # - v4l2_exportbuffer for sharing buffers via fd 256 | # - v4l2_plane structures for multi-plane formats 257 | 258 | 259 | # https://www.kernel.org/doc/html/v4.9/media/uapi/v4l/buffer.html#c.v4l2_buffer 260 | 261 | 262 | class timeval(ctypes.Structure): 263 | _fields_ = [ 264 | ('secs', ctypes.c_long), 265 | ('usecs', ctypes.c_long), 266 | ] 267 | 268 | 269 | class v4l2_timecode(ctypes.Structure): 270 | _fields_ = [ 271 | ('type', ctypes.c_uint32), 272 | ('flags', ctypes.c_uint32), 273 | ('frames', ctypes.c_uint8), 274 | ('seconds', ctypes.c_uint8), 275 | ('minutes', ctypes.c_uint8), 276 | ('hours', ctypes.c_uint8), 277 | ('userbits', ctypes.c_uint8 * 4), 278 | ] 279 | 280 | 281 | class v4l2_buffer(ctypes.Structure): 282 | class _u(ctypes.Union): 283 | _fields_ = [ 284 | ('offset', ctypes.c_uint32), # for V4L2_MEMORY_MMAP 285 | ('userptr', ctypes.c_ulong), # for V4L2_MEMORY_USERPTR 286 | # NOTE: For V4L2_MEMORY_DMABUF, an explicit fd field is required. 
287 | # If needed, expand this union to include: ('fd', ctypes.c_int) 288 | ] 289 | 290 | _fields_ = [ 291 | ('index', ctypes.c_uint32), 292 | ('type', v4l2_buf_type), 293 | ('bytesused', ctypes.c_uint32), 294 | ('flags', ctypes.c_uint32), 295 | ('field', v4l2_field), 296 | ('timestamp', timeval), 297 | ('timecode', v4l2_timecode), 298 | ('sequence', ctypes.c_uint32), 299 | ('memory', v4l2_memory), 300 | ('m', _u), 301 | ('length', ctypes.c_uint32), 302 | ('input', ctypes.c_uint32), 303 | ('reserved', ctypes.c_uint32), 304 | ] 305 | 306 | 307 | # https://www.kernel.org/doc/html/v4.9/media/uapi/v4l/vidioc-g-parm.html#c.v4l2_captureparm 308 | 309 | 310 | class v4l2_fract(ctypes.Structure): 311 | _fields_ = [ 312 | ('numerator', ctypes.c_uint32), 313 | ('denominator', ctypes.c_uint32), 314 | ] 315 | 316 | 317 | class v4l2_captureparm(ctypes.Structure): 318 | _fields_ = [ 319 | ('capability', ctypes.c_uint32), 320 | ('capturemode', ctypes.c_uint32), 321 | ('timeperframe', v4l2_fract), 322 | ('extendedmode', ctypes.c_uint32), 323 | ('readbuffers', ctypes.c_uint32), 324 | ('reserved', ctypes.c_uint32 * 4), 325 | ] 326 | 327 | 328 | # https://www.kernel.org/doc/html/v4.9/media/uapi/v4l/vidioc-g-parm.html#c.v4l2_streamparm 329 | 330 | 331 | class v4l2_outputparm(ctypes.Structure): 332 | _fields_ = [ 333 | ('capability', ctypes.c_uint32), 334 | ('outputmode', ctypes.c_uint32), 335 | ('timeperframe', v4l2_fract), 336 | ('extendedmode', ctypes.c_uint32), 337 | ('writebuffers', ctypes.c_uint32), 338 | ('reserved', ctypes.c_uint32 * 4), 339 | ] 340 | 341 | 342 | class v4l2_streamparm(ctypes.Structure): 343 | class _u(ctypes.Union): 344 | _fields_ = [ 345 | ('capture', v4l2_captureparm), 346 | ('output', v4l2_outputparm), 347 | ('raw_data', ctypes.c_char * 200), 348 | ] 349 | 350 | _fields_ = [('type', v4l2_buf_type), ('parm', _u)] 351 | 352 | 353 | # =================== 354 | # ioctl commands 355 | # =================== 356 | 357 | VIDIOC_QUERYCAP = _IOR('V', 0, 
v4l2_capability) 358 | VIDIOC_G_FMT = _IOWR('V', 4, v4l2_format) 359 | VIDIOC_S_FMT = _IOWR('V', 5, v4l2_format) 360 | VIDIOC_REQBUFS = _IOWR('V', 8, v4l2_requestbuffers) 361 | VIDIOC_QUERYBUF = _IOWR('V', 9, v4l2_buffer) 362 | VIDIOC_QBUF = _IOWR('V', 15, v4l2_buffer) 363 | VIDIOC_DQBUF = _IOWR('V', 17, v4l2_buffer) 364 | VIDIOC_STREAMON = _IOW('V', 18, ctypes.c_int) 365 | VIDIOC_STREAMOFF = _IOW('V', 19, ctypes.c_int) 366 | VIDIOC_G_PARM = _IOWR('V', 21, v4l2_streamparm) 367 | VIDIOC_S_PARM = _IOWR('V', 22, v4l2_streamparm) 368 | 369 | # =================== 370 | # export list 371 | # =================== 372 | 373 | __all__ = [ 374 | 'v4l2_capability', 375 | 'v4l2_format', 376 | 'v4l2_pix_format', 377 | 'v4l2_buffer', 378 | 'v4l2_requestbuffers', 379 | 'v4l2_streamparm', 380 | 'v4l2_captureparm', 381 | 'v4l2_fract', 382 | 'PIXEL_FORMATS', 383 | 'VIDIOC_QUERYCAP', 384 | 'VIDIOC_G_FMT', 385 | 'VIDIOC_S_FMT', 386 | 'VIDIOC_REQBUFS', 387 | 'VIDIOC_QUERYBUF', 388 | 'VIDIOC_QBUF', 389 | 'VIDIOC_DQBUF', 390 | 'VIDIOC_STREAMON', 391 | 'VIDIOC_STREAMOFF', 392 | 'VIDIOC_G_PARM', 393 | 'VIDIOC_S_PARM', 394 | 'fourcc', 395 | 'decode_fourcc', 396 | ] 397 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Physical Atari 2 | 3 | **Physical Atari** is a platform for evaluating reinforcement learning (RL) algorithms. 4 | 5 | Many RL algorithms are evaluated primarily with simulations, driven by the ease of running experiments that can be replicated by other researchers. Although this is a successful research approach, it is widely recognized in science and engineering that a simulation can only capture part of the complexity of the real system. With so much RL research using simulators, there is the danger that improvements observed in a simulator does not translate to their real-world counterparts. 
This is especially true for RL research developed with the Atari Learning Environment (ALE), as the corresponding physical systems were not easily accessible. Some of the algorithms developed with the ALE have been deployed to other real-world settings, but many RL algorithms have only been tested in the ALE. 6 | 7 | The **Physical Atari** platform provides a software and hardware interface between reinforcement learning agents and a modern version of a physical Atari. This interface enables the evaluation of RL algorithms developed for the ALE with a real-world instantiation. The physical platform exposes several timing concerns that are not present in the ALE. The physical Atari system operates in real-time and is not turn based (it does not wait for an agent action). Physical systems have non-negligible latency and physical systems have unmodelled dynamics (sensor and actuation noise). Unlike traditional environments that use pixel-perfect emulators (e.g., ALE), this setup integrates a physical **Atari 2600+** console, a camera-based observation pipeline, and a real-time control system using physical actuators. 8 | 9 | 10 | This platform provides three contributions for the RL research community. 11 | - A physical platform for evaluating RL algorithms that have been primarily developed with the ALE 12 | - An implementation of a game-independent RL algorithm on this platform that learns under real-time constraints to reliably surpass standard benchmark performance within five hours (1 million frames) on multiple Atari games. 13 | - The platform provides insight into the limitations of our simulators. Discrepancies in performance between the simulator and reality suggest changes to our simulated environments and the metrics for evaluating RL algorithms. 
14 | 15 | --- 16 | 17 | ## Overview 18 | 19 | The system consists of three main components: 20 | 21 | - **Environment**: A modern [Atari 2600+](https://www.amazon.com/Atari-2600/dp/B0CG7LMFKY) console, outputting real 4:3 video over HDMI. The console is pin-compatible with original Atari game cartridges and joysticks. 22 | - **Agent**: The learning algorithms and supporting control logic run on a gaming laptop or workstation. 23 | - **Interface**: 24 | - **Observation**: Video is captured by a USB camera at 60 frames per second 25 | - **Action**: Agent selected actions are sent to the console by one of the following. 26 | - A **mechanical actuator** (RoboTroller) that physically moves the CX40+ joystick 27 | - A **digital I/O module** that bypasses the joystick and sends signals directly to the controller port via the DB9 cable 28 | 29 | This setup enables the study of RL algorithms in the physical world, in the presence of many real-world concerns (domain shifts, latency, and noise). 30 | 31 | --- 32 | 33 | ## System Setup 34 | 35 | A complete hardware/software setup guide is available here: 36 | [**System Setup**](docs/setup.md) 37 | 38 | --- 39 | 40 | ## Components 41 | 42 | | Component | Description | 43 | |---------------------|-------------| 44 | | **Console** | [Atari 2600+](https://www.amazon.com/Atari-2600/dp/B0CG7LMFKY) with CX40+ joystick | 45 | | **Monitor** | Any gaming monitor with native 16:9 resolution and 60Hz refresh rate | 46 | | **Camera** | [Razer Kiyo Pro (1080p)](https://www.amazon.com/dp/B08T1MWX6J) — supports 60FPS uncompressed | 47 | | **Control (Option 1)** | Mechanical joystick control via servo-based actuator [RoboTroller](https://robotroller.keenagi.com) | 48 | | **Control (Option 2)** | Digital I/O module — e.g., [MCC USB-1024LS](https://microdaq.com/usb-1024ls-24-bit-digital-input-output-i-o-module.php) | 49 | 50 | > See [setup.md](docs/setup.md) for placement, lighting, USB bandwidth, tag positioning, and system setup. 
51 | 52 | --- 53 | 54 | ## Design of the RL software interface 55 | 56 | In the textbook picture, an RL agent interacts with its environment by exchanging signals for reward, observation, and action. In the episodic domains there is also a signal for the end of episode. In most Atari/Gym interfaces this is extended to support signals for end-of-life, sequence truncation, and supporting a minimal action set in a game. These additional signals are useful for accelerating early performance in Atari games, and ease of experimentation is a significant factor for the physical Atari platform. We have chosen to expose the additional signals to our learning agents. 57 | 58 | We want the RL algorithms to have an interface that supports real-time interaction with the real world. We have changed the agent/environment calling conventions from a common choice (where the agent directly calls the environment) to an interface where the experiment infrastructure sends the signals to the agent and to the environment. 59 | 60 | The primary agent/environment interface operates at 60fps. Observations are received from the video camera (with some internal buffers) and sent to the agent. Actions selected by the agent are converted into commands for the robotroller to move the joystick (which also has latencies). More effort is required to extract signals for rewards, lives, and the end of episode from the observed video, and these are described below. 61 | 62 | ## Games 63 | 64 | We restricted our attention to console games that only require a **fire button press** to start gameplay. Many Atari games require toggling the physical **reset switch** on the Atari console to restart the game, and so were not used. 65 | 66 | 67 | The following games are known to work with this setup and cover a range of visual styles and control demands: 68 | 69 | _Recommended_ 70 | 71 | - **Ms. 
Pac-Man** 72 | - **Centipede** 73 | - **Up 'n Down** 74 | - **Krull** 75 | 76 | _Less Tested_ 77 | - **Q*Bert** 78 | - **Battle Zone** 79 | - **Atlantis** 80 | - **Defender** 81 | 82 | ## Detecting Score, Lives, and the End of Game 83 | 84 | 85 | Detection of the game score, lives, and the end of game requires custom logic for the Physical Atari. In the ALE, bits from the internal game state are used to compute these signals in a custom manner for each game. For the physical Atari, these signals must be computed from the video screen. Multiple steps are required for extracting these signals, and it is the most brittle part of the physical atari platform. 86 | 87 | For the first step, camera images are rectified by identifying the corners of the Atari game screen in the camera image and applying a standard linear transformation. We have tried multiple approaches here. One reliable approach is to manually identify the four screen corners, as these do not vary by Atari game. This approach breaks down if the system is subject to jostling or vibration. Two other approaches we have examined are the use of April tags, and whole screen recognition. 88 | 89 | For the second step we manually identify boxes around the score and lives for each game. 90 | 91 | For the third step, the score is read from the video. The digits used in each atari game differs substantially, and they are also distinct from digits in standard online datasets such as MNIST. We again tried multiple approaches. The most reliable was to collect a dataset for each game of images and numerical scores, and then train a supervised learning classifier for each. For training the classifier, the captured images are augmented with several transformations, to account for small variations in calibration, lighting, and geometry. A similar process is used to detect the number of lives in the game. 92 | 93 | Some additional custom logic is used on top of the neural net classifiers to extract reliable signals. 
Score differences are validated for consistency with the game scores observed in the simulator, so scores can't change by an unrealizable amount between frames. Additional logic is present to recover from transient errors, and to detect a plausible end of the game. When the game is presumed to be over, a FIRE action is sent to restart, and the end of game is sent to the learning agent. 94 | 95 | // A per-game CRNN model is used to extract the score directly from screen pixels. These models are trained on ALE-rendered frames, using known score regions for supervision. 96 | 97 | // For some games, a more targeted model may be used instead. 98 | 99 | ### Additional Considerations for ROM and Hardware Variability 100 | 101 | Several additional ROM and hardware factors can impact signal extraction accuracy. Different ROM revisions may introduce subtle or significant variations, including difficulty adjustments, bug fixes, or changes in visual indicators such as the displayed number of lives. Region-specific differences between NTSC and PAL cartridges could also potentially result in rendering variations, timing discrepancies, or different memory layouts, although the extent to which these factors necessitate separate processing pipelines remains to be fully tested. 102 | 103 | Additionally, prototype versions of games may differ notably from retail releases, leading to mismatches if the ALE emulator is based on a retail ROM and the physical cartridge represents an earlier or alternative version. Furthermore, variations in bankswitching techniques, used by certain cartridges to extend memory addressing, could also explain inconsistencies between the simulated ALE environment and physical hardware, particularly if RAM addresses differ for critical indicators like lives or score. These factors collectively underscore the importance of careful validation and customization when transitioning from simulation to the physical Atari platform. 
104 | 105 | --- 106 | 107 | ## Research Challenges 108 | 109 | This project focuses on bridging the **reality gap** in empirical reinforcement learning research. 110 | 111 | **Differences between the emulator and reality include:** 112 | - No turn-taking in a real-time environment 113 | - Visual statistics (emulator vs real camera) 114 | - Latency in video capture and actuator response 115 | - Imperfect lighting, reflections, and image noise 116 | - Score detection errors under variable lighting and resolution 117 | 118 | Both trained policies and reinforcement learning algorithms can degrade significantly when exposed to these real-world conditions, even if they perform well in simulation. 119 | 120 | --- 121 | 122 | ## Launching 123 | 124 | To run the physical setup, build the docker environment with: 125 | 126 | ``` 127 | ./docker_build.sh 128 | ``` 129 | 130 | Run the docker environment with: 131 | 132 | ``` 133 | ./docker_run.sh 134 | ``` 135 | 136 | Launch the physical harness: 137 | 138 | ``` 139 | python physical_harness.py 140 | ``` 141 | 142 | An example run for the Ms Pacman game, for a custom configuration 143 | 144 | ``` 145 | python3 harness_physical.py \ 146 | --detection_config=configs/screen_detection/fixed.json \ 147 | --game_config=configs/games/ms_pacman.json \ 148 | --agent_type=agent_delay_target \ 149 | --reduce_action_set=2 --gpu=0 \ 150 | --joystick_config=configs/controllers/robotroller.json \ 151 | --total_frames=1_000_000 152 | ``` 153 | 154 | 155 | --- 156 | 157 | ## System Performance and Profiling 158 | 159 | Running the physical setup reliably requires that the system meets strict performance requirements — especially with regard to CPU and GPU power settings, thermal limits, and scheduling behavior. 160 | 161 | Modern systems often default to power-saving configurations that can cause unexpected latency, frame delays, or jitter. These issues are especially problematic in real-time or hardware-in-the-loop setups. 
162 | 163 | - See [Performance Setup](./docs/setup.md#system-performance-validation) for validating system configuration and fixing system-level performance issues. 164 | - See [Profiling Guide](./docs/profiling.md) for details on collecting and analyzing performance data using NVIDIA Nsight Systems and NVTX annotations. 165 | 166 | --- 167 | 168 | ## License 169 | 170 | This project is licensed under the Apache 2.0 License. 171 | 172 | Unless otherwise noted, this license applies to all source code and pre-trained model files included in the repository. 173 | -------------------------------------------------------------------------------- /framework/models/score_detector/crnn_ctc.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Keen Technologies, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | 16 | import math 17 | 18 | import numpy as np 19 | import torch 20 | import torch.nn as nn 21 | import torch.nn.functional as F 22 | from torch.cuda.amp import autocast 23 | 24 | from framework.Logger import logger 25 | 26 | # Based on the CRNN described here: https://arxiv.org/pdf/1507.05717 27 | # CTC-based (need to collapse repeats and filter blank tokens) where 28 | # the output is decoded as a sequence of class predictions up to a 29 | # max length (or until padding is encountered) 30 | 31 | 32 | class CRNN(nn.Module): 33 | def __init__(self, num_classes=11, hidden_size=128, cnn_dropout_rate=0.3, rnn_dropout_rate=0.5): 34 | super().__init__() 35 | self.cnn = nn.Sequential( 36 | nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1), 37 | nn.ReLU(inplace=True), 38 | nn.MaxPool2d(2, 2), 39 | nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1), 40 | nn.ReLU(inplace=True), 41 | nn.MaxPool2d(2, 2), 42 | nn.Conv2d(64, 96, kernel_size=3, stride=1, padding=1), 43 | nn.ReLU(inplace=True), 44 | nn.Conv2d(96, 96, kernel_size=3, stride=1, padding=1), 45 | nn.ReLU(inplace=True), 46 | nn.MaxPool2d(kernel_size=(1, 2), stride=2), 47 | nn.Conv2d(96, 128, kernel_size=3, stride=1, padding=1), 48 | nn.ReLU(inplace=True), 49 | nn.Conv2d(128, 128, kernel_size=2, stride=1), 50 | nn.ReLU(inplace=True), 51 | nn.Dropout(cnn_dropout_rate), 52 | ) 53 | 54 | self.rnn = nn.GRU(input_size=128, hidden_size=hidden_size, num_layers=1, bidirectional=True, batch_first=True) 55 | 56 | self.dropout = nn.Dropout(rnn_dropout_rate) 57 | self.fc = nn.Linear(hidden_size * 2, num_classes) 58 | 59 | def forward(self, x): 60 | x = self.cnn(x) # (B, C, H, W) 61 | x = self.dropout(x) 62 | 63 | B, C, H, W = x.size() 64 | x = x.view(B, C, H * W).permute(0, 2, 1) # (B, S, C) 65 | 66 | x, _ = self.rnn(x) # (B, S, 2*hidden) 67 | x = self.fc(x) # (B, S, num_classes) 68 | return x.permute(0, 2, 1) # (B, num_classes, S) 69 | 70 | 71 | def greedy_decode_ctc( 72 | logits: torch.Tensor, 73 | max_length: int, 
74 | padding_index=-1, 75 | blank_index=1, 76 | temperature=1.0, 77 | decoded_buf=None, 78 | confidences_buf=None, 79 | ): 80 | 81 | # Apply temperature scaling before softmax 82 | logits = logits / temperature 83 | 84 | # logits: [B, num_classes, S] 85 | logprobs = F.log_softmax(logits, dim=1) 86 | preds = torch.argmax(logprobs, dim=1) # [B, S] 87 | 88 | # remove repeats and blanks 89 | prev_preds = torch.cat( 90 | [torch.full((preds.size(0), 1), fill_value=-1, dtype=preds.dtype, device=preds.device), preds[:, :-1]], dim=1 91 | ) 92 | mask = (preds != prev_preds) & (preds != blank_index) # [B, S] 93 | 94 | if decoded_buf is not None: 95 | decoded = decoded_buf 96 | else: 97 | decoded = torch.full((preds.shape[0], max_length), padding_index, dtype=torch.int32, device=logprobs.device) 98 | 99 | probs = F.softmax(logits, dim=1) # [B, num_classes, S] 100 | entropy = -torch.sum(probs * torch.log(probs + 1e-6), dim=1) # [B, S] 101 | 102 | if confidences_buf is not None: 103 | confidences = confidences_buf 104 | else: 105 | confidences = torch.zeros((preds.shape[0], max_length), dtype=torch.float32, device=logprobs.device) 106 | 107 | for b in range(preds.size(0)): 108 | valid_indices = torch.nonzero(mask[b], as_tuple=False).squeeze(1) 109 | if valid_indices.numel() == 0: 110 | continue 111 | seq = preds[b, valid_indices] 112 | n = min(seq.numel(), max_length) 113 | decoded[b, :n] = seq[:n] 114 | confidences[b, :n] = entropy[b, valid_indices[:n]] 115 | 116 | return decoded, confidences 117 | 118 | 119 | class CRNNDecoder(CRNN): 120 | def __init__( 121 | self, 122 | num_classes=12, 123 | hidden_size=128, 124 | max_digits=6, 125 | score_mean=0.1307, 126 | score_std=0.113, 127 | lives_mean=0.1307, 128 | lives_std=0.113, 129 | blank_index=11, 130 | padding_index=-1, 131 | life_symbol=10, 132 | image_size=32, 133 | entropy_threshold=1.0, 134 | entropy_ceiling=1.0, 135 | temperature=1.0, 136 | device='cpu', 137 | use_mixed_precision=False, 138 | ): 139 | 
super().__init__(num_classes=num_classes, hidden_size=hidden_size) 140 | 141 | self.use_mixed_precision = use_mixed_precision 142 | self.device = device 143 | 144 | self.score_mean = torch.tensor([score_mean], dtype=torch.float32, device=self.device) 145 | self.score_std = torch.tensor([score_std], dtype=torch.float32, device=self.device) 146 | self.lives_mean = torch.tensor([lives_mean], dtype=torch.float32, device=self.device) 147 | self.lives_std = torch.tensor([lives_std], dtype=torch.float32, device=self.device) 148 | 149 | self.blank_idx = blank_index 150 | self.padding_idx = padding_index 151 | self.life_symbol = life_symbol 152 | self.life_symbol_str = str(self.life_symbol) 153 | 154 | self.entropy_threshold = entropy_threshold 155 | self.entropy_ceiling = entropy_ceiling 156 | self.temperature = temperature 157 | 158 | self.image_size = image_size 159 | self.decode_max_length = max_digits 160 | self.input_dims = (self.image_size, self.image_size * self.decode_max_length) 161 | 162 | self._decoded_buf = None 163 | self._confidences_buf = None 164 | 165 | def preprocess(self, x: np.ndarray, is_lives=False): 166 | crop_t = torch.from_numpy(x).permute(2, 0, 1).float().div(255.0) 167 | 168 | if crop_t.shape[0] == 3: 169 | crop_t = 0.2989 * crop_t[0] + 0.5870 * crop_t[1] + 0.1140 * crop_t[2] 170 | crop_t = crop_t.unsqueeze(0) 171 | 172 | crop_t = F.interpolate(crop_t.unsqueeze(0), size=self.input_dims, mode='bilinear', align_corners=False).squeeze( 173 | 0 174 | ) 175 | 176 | mean = self.score_mean if not is_lives else self.lives_mean 177 | std = self.score_std if not is_lives else self.lives_std 178 | crop_t = crop_t.to(self.device) 179 | crop_t = (crop_t - mean) / std 180 | 181 | if self.use_mixed_precision: 182 | crop_t = crop_t.half() 183 | 184 | return crop_t 185 | 186 | def decode_ctc(self, logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: 187 | max_length = self.decode_max_length 188 | padding_index = self.padding_idx 189 | blank_index = 
self.blank_idx 190 | 191 | if self._decoded_buf is None: 192 | self._decoded_buf = torch.empty((logits.shape[0], max_length), dtype=torch.int32, device=logits.device) 193 | 194 | decoded = self._decoded_buf 195 | decoded.fill_(padding_index) 196 | 197 | if self._confidences_buf is None: 198 | self._confidences_buf = torch.empty( 199 | (logits.shape[0], max_length), dtype=torch.float32, device=logits.device 200 | ) 201 | 202 | confidences = self._confidences_buf 203 | confidences.zero_() 204 | 205 | return greedy_decode_ctc( 206 | logits, 207 | max_length, 208 | padding_index, 209 | blank_index, 210 | decoded_buf=decoded, 211 | confidences_buf=confidences, 212 | temperature=self.temperature, 213 | ) 214 | 215 | def predict( 216 | self, 217 | x: torch.Tensor, 218 | ) -> tuple[torch.Tensor, torch.Tensor]: 219 | logits = self.forward(x) 220 | return self.decode_ctc(logits) 221 | 222 | def convert(self, score_pred, score_confidence, lives_pred=None, lives_confidence=None): 223 | """ 224 | Some games may not display score or lives for periods of time, so we need to distinguish 225 | between true-zero and not-visible states. True zero is mostly an issue with lives, where 226 | the absence of life symbols means no lives remain in several supported games. 227 | This distinction is more complex in games like Q*bert, where both cases occur. 228 | 229 | With CTC decoding, a blank prediction should indicate no score or lives displayed. However, 230 | the model may sometimes be overconfident and predict noise instead. 231 | 232 | Game-specific quirks: 233 | Centipede: The score strobe intermittently; when off, a blank prediction means not-visible. 234 | Q*bert: Both lives and score flash off together for extended periods; blank predictions for both indicate not-visible. 235 | Q*bert (and others): Absence of life symbols means true zero; if the score prediction is not blank, return zero. 236 | Atlantis: Lives detection is non-trivial; lives prediction is expected to be None. 
237 | 238 | To handle both cases for lives, return None for not-visible and zero for true zero. 239 | """ 240 | score_pred = score_pred[score_pred != self.padding_idx] 241 | blank_score = len(score_pred) == 0 242 | score = None if blank_score else int(''.join(map(str, score_pred))) 243 | score_conf = score_confidence[score_confidence != self.padding_idx] 244 | 245 | lives = None 246 | if lives_pred is not None: 247 | lives_pred = lives_pred[lives_pred != self.padding_idx] 248 | blank_lives = len(lives_pred) == 0 249 | 250 | if blank_lives and blank_score: 251 | lives = None 252 | else: 253 | # lives prediction will be 1 symbol = 10; 2 symbols = 1010; and so on. 254 | # find the number of times the symbol is repeated 255 | lives_str = ''.join(lives_pred.astype(str)) 256 | lives = lives_str.count(self.life_symbol_str) 257 | 258 | lives_conf = None 259 | if lives_confidence is not None: 260 | lives_conf = lives_confidence[lives_confidence != self.padding_idx] 261 | return score, score_conf, lives, lives_conf 262 | 263 | 264 | def fuse_cnn_layers(model): 265 | fused = [] 266 | for i in range(0, len(model.cnn) - 1): 267 | m1 = model.cnn[i] 268 | m2 = model.cnn[i + 1] 269 | if isinstance(m1, nn.Conv2d) and isinstance(m2, nn.ReLU): 270 | fused.append([str(i), str(i + 1)]) 271 | 272 | torch.quantization.fuse_modules(model.cnn, fused, inplace=True) 273 | return model 274 | 275 | 276 | def load_model( 277 | checkpoint: str, 278 | env_name: str, 279 | mixed_precision: bool = True, 280 | device: str = 'cpu', 281 | memory_format=torch.contiguous_format, 282 | ) -> CRNNDecoder: 283 | 284 | checkpoint_data = torch.load(checkpoint, map_location='cpu', weights_only=False) 285 | weights = checkpoint_data['state_dict'] 286 | model_config = checkpoint_data['model_config'] 287 | game_config = checkpoint_data['game_config'] 288 | train_config = checkpoint_data['train_config'] 289 | 290 | score_mean, score_std = game_config["score_mean_std"] 291 | lives_mean, lives_std = 
game_config["lives_mean_std"] 292 | max_digits = game_config["max_digits"] 293 | 294 | num_classes = model_config["num_classes"] 295 | hidden_size = model_config["hidden_size"] 296 | blank_idx = model_config["blank_idx"] 297 | padding_idx = model_config["padding_idx"] 298 | life_symbol = model_config["life_symbol"] 299 | image_size = model_config["image_size"] 300 | # entropy_threshold = model_config["balanced_entropy_threshold"] 301 | entropy_threshold = model_config["entropy_threshold"] 302 | entropy_ceiling = ( 303 | model_config["entropy_ceiling"] if "entropy_ceiling" in model_config else math.log(num_classes) * 0.75 304 | ) 305 | temperature = train_config["temperature"] 306 | 307 | model = CRNNDecoder( 308 | num_classes=num_classes, 309 | hidden_size=hidden_size, 310 | max_digits=max_digits, 311 | blank_index=blank_idx, 312 | padding_index=padding_idx, 313 | image_size=image_size, 314 | life_symbol=life_symbol, 315 | score_mean=score_mean, 316 | score_std=score_std, 317 | lives_mean=lives_mean, 318 | lives_std=lives_std, 319 | entropy_threshold=entropy_threshold, 320 | entropy_ceiling=entropy_ceiling, 321 | temperature=temperature, 322 | device=device, 323 | use_mixed_precision=mixed_precision, 324 | ) 325 | 326 | model.load_state_dict(weights) 327 | model = model.to(memory_format=memory_format).to(device) 328 | model = fuse_cnn_layers(model) 329 | model.eval() 330 | if mixed_precision: 331 | model = model.half() 332 | model = torch.compile(model, mode='reduce-overhead', fullgraph=False, dynamic=False) 333 | # print(model) 334 | 335 | model_param_count = 0 336 | for p in model.parameters(): 337 | model_param_count += p.numel() 338 | logger.debug(f"CRNN_CTC: model param count = {model_param_count}") 339 | 340 | return model 341 | --------------------------------------------------------------------------------