├── playground ├── games │ ├── chess │ │ ├── common │ │ │ ├── __init__.py │ │ │ ├── consts.py │ │ │ ├── flood_fill.py │ │ │ ├── attack_tables.py │ │ │ └── movegen.py │ │ ├── assets │ │ │ ├── pieces │ │ │ │ ├── bb.png │ │ │ │ ├── bk.png │ │ │ │ ├── bn.png │ │ │ │ ├── bp.png │ │ │ │ ├── bq.png │ │ │ │ ├── br.png │ │ │ │ ├── wb.png │ │ │ │ ├── wk.png │ │ │ │ ├── wn.png │ │ │ │ ├── wp.png │ │ │ │ ├── wq.png │ │ │ │ └── wr.png │ │ │ └── icons │ │ │ │ ├── crown.png │ │ │ │ ├── reset.png │ │ │ │ ├── computer.png │ │ │ │ ├── player.png │ │ │ │ ├── left_arrow.png │ │ │ │ ├── lightbulb.png │ │ │ │ ├── pawn_icon.png │ │ │ │ ├── right_arrow.png │ │ │ │ ├── checkbox_checked.png │ │ │ │ └── checkbox_unchecked.png │ │ ├── __init__.py │ │ ├── chess_ui.py │ │ └── chess.py │ ├── sudoku │ │ ├── images │ │ │ ├── icon.png │ │ │ ├── menu.png │ │ │ ├── pause.png │ │ │ ├── title.PNG │ │ │ └── resume.jpg │ │ ├── __init__.py │ │ ├── sudoku_generator.py │ │ ├── sudoku_ui.py │ │ └── sudoku.py │ ├── minesweeper │ │ ├── images │ │ │ ├── bomb.png │ │ │ ├── cross.png │ │ │ ├── flag.png │ │ │ ├── rocket.png │ │ │ ├── smiley.png │ │ │ └── clock-select.png │ │ ├── __init__.py │ │ ├── game_cfg.py │ │ ├── minesweeper_ui.py │ │ └── minesweeper.py │ ├── gomoku │ │ ├── designer │ │ │ ├── image │ │ │ │ ├── black.png │ │ │ │ ├── white.png │ │ │ │ ├── chessboard.png │ │ │ │ ├── win-removebg-preview.png │ │ │ │ └── lost-removebg-preview.png │ │ │ ├── gobang_qrc.qrc │ │ │ └── gobang_ui.ui │ │ ├── __init__.py │ │ ├── gomoku_ui.py │ │ └── AI.py │ ├── reversi │ │ ├── __init__.py │ │ ├── reversi_ui.py │ │ ├── AI.py │ │ └── reversi.py │ ├── tictactoe │ │ ├── __init__.py │ │ ├── tictactoe_ui.py │ │ ├── AI.py │ │ └── tictactoe.py │ ├── __init__.py │ └── base.py ├── experiment │ ├── __init__.py │ └── recipe.py ├── simulator │ └── __init__.py ├── utils │ ├── __init__.py │ └── utils.py ├── registry.py ├── __init__.py ├── evaluator │ ├── __init__.py │ ├── base_qa.py │ └── evaluator.py ├── agents │ ├── __init__.py │ ├── 
base.py │ └── single_step_agents.py ├── state_code.py └── benchmark.py ├── assets ├── radar_chart.jpg └── LVLM-Playground.jpg ├── configs ├── agents │ ├── google │ │ ├── gemini-1.5-pro.py │ │ ├── gemini-1.5-flash.py │ │ └── gemini-1.0-pro-vision.py │ ├── 01-ai │ │ ├── yi-vl-34b.py │ │ └── yi-vl-6b.py │ ├── internvl │ │ ├── internvl2-1b.py │ │ ├── internvl2-8b.py │ │ ├── internvl2-26b.py │ │ ├── internvl2-40b.py │ │ ├── internvl2-4b.py │ │ └── internvl2-2b.py │ ├── anhthropic │ │ ├── claude-3-opus.py │ │ ├── claude-3-haiku.py │ │ ├── claude-3-sonnet.py │ │ └── claude-3.5-sonnet.py │ ├── llava │ │ ├── llava1.6-vicuna34b.py │ │ ├── llava1.6-mistral7b.py │ │ ├── llava1.6-vicuna13b.py │ │ └── llava1.6-vicuna7b.py │ ├── openai │ │ ├── gpt-4-turbo-240409.py │ │ ├── gpt-4o-mini-240718.py │ │ └── gpt-4o-240806.py │ ├── qwen │ │ └── qwen2-vl-7b.py │ ├── microsoft │ │ └── phi3.5-vl.py │ └── deepseek │ │ └── deepseek-vl-7b.py ├── recipe │ └── base.py ├── base.py └── games │ ├── tictactoe.py │ ├── sudoku.py │ ├── minesweeper.py │ ├── gomoku.py │ ├── chess.py │ └── reversi.py ├── setup.py ├── requirements.txt ├── generate_benchmark.py ├── run.py ├── .pre-commit-config.yaml ├── evaluate.py ├── plot_radar.py ├── .gitignore └── README.md /playground/games/chess/common/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /playground/experiment/__init__.py: -------------------------------------------------------------------------------- 1 | from .recipe import Recipe 2 | 3 | __all__ = ['Recipe'] 4 | -------------------------------------------------------------------------------- /assets/radar_chart.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinke-wang/LVLM-Playground/HEAD/assets/radar_chart.jpg -------------------------------------------------------------------------------- 
/playground/simulator/__init__.py: -------------------------------------------------------------------------------- 1 | from .simulator import GameSimulator 2 | 3 | __all__ = ['GameSimulator'] 4 | -------------------------------------------------------------------------------- /assets/LVLM-Playground.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinke-wang/LVLM-Playground/HEAD/assets/LVLM-Playground.jpg -------------------------------------------------------------------------------- /playground/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .utils import encode_image, set_random_seed 2 | 3 | __all__ = ['set_random_seed', 'encode_image'] 4 | -------------------------------------------------------------------------------- /playground/games/sudoku/images/icon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinke-wang/LVLM-Playground/HEAD/playground/games/sudoku/images/icon.png -------------------------------------------------------------------------------- /playground/games/sudoku/images/menu.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinke-wang/LVLM-Playground/HEAD/playground/games/sudoku/images/menu.png -------------------------------------------------------------------------------- /playground/games/sudoku/images/pause.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinke-wang/LVLM-Playground/HEAD/playground/games/sudoku/images/pause.png -------------------------------------------------------------------------------- /playground/games/sudoku/images/title.PNG: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/xinke-wang/LVLM-Playground/HEAD/playground/games/sudoku/images/title.PNG -------------------------------------------------------------------------------- /playground/games/chess/assets/pieces/bb.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinke-wang/LVLM-Playground/HEAD/playground/games/chess/assets/pieces/bb.png -------------------------------------------------------------------------------- /playground/games/chess/assets/pieces/bk.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinke-wang/LVLM-Playground/HEAD/playground/games/chess/assets/pieces/bk.png -------------------------------------------------------------------------------- /playground/games/chess/assets/pieces/bn.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinke-wang/LVLM-Playground/HEAD/playground/games/chess/assets/pieces/bn.png -------------------------------------------------------------------------------- /playground/games/chess/assets/pieces/bp.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinke-wang/LVLM-Playground/HEAD/playground/games/chess/assets/pieces/bp.png -------------------------------------------------------------------------------- /playground/games/chess/assets/pieces/bq.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinke-wang/LVLM-Playground/HEAD/playground/games/chess/assets/pieces/bq.png -------------------------------------------------------------------------------- /playground/games/chess/assets/pieces/br.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/xinke-wang/LVLM-Playground/HEAD/playground/games/chess/assets/pieces/br.png -------------------------------------------------------------------------------- /playground/games/chess/assets/pieces/wb.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinke-wang/LVLM-Playground/HEAD/playground/games/chess/assets/pieces/wb.png -------------------------------------------------------------------------------- /playground/games/chess/assets/pieces/wk.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinke-wang/LVLM-Playground/HEAD/playground/games/chess/assets/pieces/wk.png -------------------------------------------------------------------------------- /playground/games/chess/assets/pieces/wn.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinke-wang/LVLM-Playground/HEAD/playground/games/chess/assets/pieces/wn.png -------------------------------------------------------------------------------- /playground/games/chess/assets/pieces/wp.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinke-wang/LVLM-Playground/HEAD/playground/games/chess/assets/pieces/wp.png -------------------------------------------------------------------------------- /playground/games/chess/assets/pieces/wq.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinke-wang/LVLM-Playground/HEAD/playground/games/chess/assets/pieces/wq.png -------------------------------------------------------------------------------- /playground/games/chess/assets/pieces/wr.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/xinke-wang/LVLM-Playground/HEAD/playground/games/chess/assets/pieces/wr.png -------------------------------------------------------------------------------- /playground/games/sudoku/images/resume.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinke-wang/LVLM-Playground/HEAD/playground/games/sudoku/images/resume.jpg -------------------------------------------------------------------------------- /playground/registry.py: -------------------------------------------------------------------------------- 1 | from pjtools.registry import Registry 2 | 3 | AGENT_REGISTRY = Registry('agent') 4 | GAME_REGISTRY = Registry('game') 5 | -------------------------------------------------------------------------------- /playground/games/chess/assets/icons/crown.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinke-wang/LVLM-Playground/HEAD/playground/games/chess/assets/icons/crown.png -------------------------------------------------------------------------------- /playground/games/chess/assets/icons/reset.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinke-wang/LVLM-Playground/HEAD/playground/games/chess/assets/icons/reset.png -------------------------------------------------------------------------------- /playground/games/minesweeper/images/bomb.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinke-wang/LVLM-Playground/HEAD/playground/games/minesweeper/images/bomb.png -------------------------------------------------------------------------------- /playground/games/minesweeper/images/cross.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/xinke-wang/LVLM-Playground/HEAD/playground/games/minesweeper/images/cross.png -------------------------------------------------------------------------------- /playground/games/minesweeper/images/flag.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinke-wang/LVLM-Playground/HEAD/playground/games/minesweeper/images/flag.png -------------------------------------------------------------------------------- /playground/__init__.py: -------------------------------------------------------------------------------- 1 | from playground.agents import * # noqa 2 | from playground.experiment import * # noqa 3 | from playground.games import * # noqa 4 | -------------------------------------------------------------------------------- /playground/games/chess/assets/icons/computer.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinke-wang/LVLM-Playground/HEAD/playground/games/chess/assets/icons/computer.png -------------------------------------------------------------------------------- /playground/games/chess/assets/icons/player.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinke-wang/LVLM-Playground/HEAD/playground/games/chess/assets/icons/player.png -------------------------------------------------------------------------------- /playground/games/gomoku/designer/image/black.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinke-wang/LVLM-Playground/HEAD/playground/games/gomoku/designer/image/black.png -------------------------------------------------------------------------------- /playground/games/gomoku/designer/image/white.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/xinke-wang/LVLM-Playground/HEAD/playground/games/gomoku/designer/image/white.png -------------------------------------------------------------------------------- /playground/games/minesweeper/images/rocket.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinke-wang/LVLM-Playground/HEAD/playground/games/minesweeper/images/rocket.png -------------------------------------------------------------------------------- /playground/games/minesweeper/images/smiley.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinke-wang/LVLM-Playground/HEAD/playground/games/minesweeper/images/smiley.png -------------------------------------------------------------------------------- /configs/agents/google/gemini-1.5-pro.py: -------------------------------------------------------------------------------- 1 | lmm_agent = dict( 2 | name='gemini1.5-pro', 3 | agent='google_single', 4 | model='gemini-1.5-pro' 5 | ) 6 | -------------------------------------------------------------------------------- /playground/games/chess/assets/icons/left_arrow.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinke-wang/LVLM-Playground/HEAD/playground/games/chess/assets/icons/left_arrow.png -------------------------------------------------------------------------------- /playground/games/chess/assets/icons/lightbulb.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinke-wang/LVLM-Playground/HEAD/playground/games/chess/assets/icons/lightbulb.png -------------------------------------------------------------------------------- /playground/games/chess/assets/icons/pawn_icon.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/xinke-wang/LVLM-Playground/HEAD/playground/games/chess/assets/icons/pawn_icon.png -------------------------------------------------------------------------------- /playground/games/chess/__init__.py: -------------------------------------------------------------------------------- 1 | from .chess import Chess 2 | from .chess_qa import ChessQuestionAnswering 3 | 4 | __all__ = ['Chess', 'ChessQuestionAnswering'] 5 | -------------------------------------------------------------------------------- /playground/games/chess/assets/icons/right_arrow.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinke-wang/LVLM-Playground/HEAD/playground/games/chess/assets/icons/right_arrow.png -------------------------------------------------------------------------------- /playground/games/gomoku/designer/image/chessboard.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinke-wang/LVLM-Playground/HEAD/playground/games/gomoku/designer/image/chessboard.png -------------------------------------------------------------------------------- /playground/games/minesweeper/images/clock-select.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinke-wang/LVLM-Playground/HEAD/playground/games/minesweeper/images/clock-select.png -------------------------------------------------------------------------------- /configs/agents/01-ai/yi-vl-34b.py: -------------------------------------------------------------------------------- 1 | _base_ = ['configs/agents/01-ai/yi-vl-6b.py'] 2 | 3 | lmm_agent = dict( 4 | name='yi-vl-34b', 5 | model='01-ai/Yi-VL-34B', 6 | ) 7 | -------------------------------------------------------------------------------- /playground/games/chess/assets/icons/checkbox_checked.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinke-wang/LVLM-Playground/HEAD/playground/games/chess/assets/icons/checkbox_checked.png -------------------------------------------------------------------------------- /playground/games/chess/assets/icons/checkbox_unchecked.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinke-wang/LVLM-Playground/HEAD/playground/games/chess/assets/icons/checkbox_unchecked.png -------------------------------------------------------------------------------- /playground/games/gomoku/__init__.py: -------------------------------------------------------------------------------- 1 | from .gomoku import Gomoku 2 | from .gomoku_qa import GomokuQuestionAnswering 3 | 4 | __all__ = ['Gomoku', 'GomokuQuestionAnswering'] 5 | -------------------------------------------------------------------------------- /playground/games/sudoku/__init__.py: -------------------------------------------------------------------------------- 1 | from .sudoku import Sudoku 2 | from .sudoku_qa import SudokuQuestionAnswering 3 | 4 | __all__ = ['Sudoku', 'SudokuQuestionAnswering'] 5 | -------------------------------------------------------------------------------- /playground/games/reversi/__init__.py: -------------------------------------------------------------------------------- 1 | from .reversi import Reversi 2 | from .reversi_qa import ReversiQuestionAnswering 3 | 4 | __all__ = ['Reversi', 'ReversiQuestionAnswering'] 5 | -------------------------------------------------------------------------------- /playground/games/gomoku/designer/image/win-removebg-preview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinke-wang/LVLM-Playground/HEAD/playground/games/gomoku/designer/image/win-removebg-preview.png 
-------------------------------------------------------------------------------- /playground/games/gomoku/designer/image/lost-removebg-preview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinke-wang/LVLM-Playground/HEAD/playground/games/gomoku/designer/image/lost-removebg-preview.png -------------------------------------------------------------------------------- /configs/agents/google/gemini-1.5-flash.py: -------------------------------------------------------------------------------- 1 | _base_ = ['configs/agents/google/gemini-1.5-pro.py'] 2 | 3 | lmm_agent = dict( 4 | name='gemini1.5-flash', 5 | model='gemini-1.5-flash', 6 | ) 7 | -------------------------------------------------------------------------------- /configs/agents/internvl/internvl2-1b.py: -------------------------------------------------------------------------------- 1 | _base_ = ['configs/agents/internvl/internvl2-2b.py'] 2 | 3 | lmm_agent = dict( 4 | name='internvl2-1b', 5 | model='OpenGVLab/InternVL2-1B', 6 | ) 7 | -------------------------------------------------------------------------------- /configs/agents/internvl/internvl2-8b.py: -------------------------------------------------------------------------------- 1 | _base_ = ['configs/agents/internvl/internvl2-2b.py'] 2 | 3 | lmm_agent = dict( 4 | name='internvl2-8b', 5 | model='OpenGVLab/InternVL2-8B', 6 | ) 7 | -------------------------------------------------------------------------------- /playground/games/tictactoe/__init__.py: -------------------------------------------------------------------------------- 1 | from .tictactoe import TicTacToe 2 | from .tictactoe_qa import TicTacToeQuestionAnswering 3 | 4 | __all__ = ['TicTacToeQuestionAnswering', 'TicTacToe'] 5 | -------------------------------------------------------------------------------- /configs/agents/internvl/internvl2-26b.py: 
-------------------------------------------------------------------------------- 1 | _base_ = ['configs/agents/internvl/internvl2-2b.py'] 2 | 3 | lmm_agent = dict( 4 | name='internvl2-26b', 5 | model='OpenGVLab/InternVL2-26B', 6 | ) 7 | -------------------------------------------------------------------------------- /configs/recipe/base.py: -------------------------------------------------------------------------------- 1 | name = 'standard' 2 | save_path = 'experiments' 3 | tasks = ['perceive', 'qa', 'rule', 'e2e'] 4 | games = ['tictactoe', 'reversi', 'gomoku', 'minesweeper', 'sudoku', 'chess'] 5 | -------------------------------------------------------------------------------- /configs/agents/anhthropic/claude-3-opus.py: -------------------------------------------------------------------------------- 1 | _base_ = ['configs/agents/anhthropic/claude-3.5-sonnet.py'] 2 | 3 | lmm_agent = dict( 4 | name='claude3-opus', 5 | model='claude-3-opus-20240229', 6 | ) 7 | -------------------------------------------------------------------------------- /configs/agents/google/gemini-1.0-pro-vision.py: -------------------------------------------------------------------------------- 1 | _base_ = ['configs/agents/google/gemini-1.5-pro.py'] 2 | 3 | lmm_agent = dict( 4 | name='gemini1.0-pro-vision', 5 | model='gemini-pro-vision', 6 | ) 7 | -------------------------------------------------------------------------------- /configs/agents/llava/llava1.6-vicuna34b.py: -------------------------------------------------------------------------------- 1 | _base_ = ['configs/agents/llava/llava1.6-vicuna7b.py'] 2 | 3 | lmm_agent = dict( 4 | name='llava1.6-yi34b', 5 | model='liuhaotian/llava-v1.6-34b', 6 | ) 7 | -------------------------------------------------------------------------------- /configs/agents/openai/gpt-4-turbo-240409.py: -------------------------------------------------------------------------------- 1 | _base_ = ['configs/agents/openai/gpt-4o-240513.py'] 2 | 3 | lmm_agent = 
dict( 4 | name='gpt4turbo-240409', 5 | model='gpt-4-turbo-2024-04-09', 6 | ) 7 | -------------------------------------------------------------------------------- /configs/agents/openai/gpt-4o-mini-240718.py: -------------------------------------------------------------------------------- 1 | _base_ = ['configs/agents/openai/gpt-4o-240513.py'] 2 | 3 | lmm_agent = dict( 4 | name='gpt4o-mini-240718', 5 | model='gpt-4o-mini-2024-07-18', 6 | ) 7 | -------------------------------------------------------------------------------- /configs/agents/anhthropic/claude-3-haiku.py: -------------------------------------------------------------------------------- 1 | _base_ = ['configs/agents/anhthropic/claude-3.5-sonnet.py'] 2 | 3 | lmm_agent = dict( 4 | name='claude3-haiku', 5 | model='claude-3-haiku-20240307', 6 | ) 7 | -------------------------------------------------------------------------------- /configs/agents/openai/gpt-4o-240806.py: -------------------------------------------------------------------------------- 1 | lmm_agent = dict( 2 | name='gpt4o', 3 | agent='openai_single', 4 | model='gpt-4o-2024-08-06', 5 | max_tokens=812, 6 | image_size=(1000, 1000), 7 | ) 8 | -------------------------------------------------------------------------------- /playground/games/minesweeper/__init__.py: -------------------------------------------------------------------------------- 1 | from .minesweeper import MineSweeper 2 | from .minesweeper_qa import MinesweeperQuestionAnswering 3 | 4 | __all__ = ['MineSweeper', 'MinesweeperQuestionAnswering'] 5 | -------------------------------------------------------------------------------- /configs/agents/anhthropic/claude-3-sonnet.py: -------------------------------------------------------------------------------- 1 | _base_ = ['configs/agents/anhthropic/claude-3.5-sonnet.py'] 2 | 3 | lmm_agent = dict( 4 | name='claude3-sonnet', 5 | model='claude-3-sonnet-20240229', 6 | ) 7 | 
-------------------------------------------------------------------------------- /playground/evaluator/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_qa import BaseQuestionAnswering 2 | from .evaluator import Evaluator 3 | from .metric import Metric 4 | 5 | __all__ = ['Evaluator', 'BaseQuestionAnswering', 'Metric'] 6 | -------------------------------------------------------------------------------- /configs/agents/llava/llava1.6-mistral7b.py: -------------------------------------------------------------------------------- 1 | _base_ = ['configs/agents/llava/llava1.6-vicuna7b.py'] 2 | 3 | lmm_agent = dict( 4 | name='llava1.6-mistral7b', 5 | model='liuhaotian/llava-v1.6-mistral-7b', 6 | ) 7 | -------------------------------------------------------------------------------- /configs/agents/llava/llava1.6-vicuna13b.py: -------------------------------------------------------------------------------- 1 | _base_ = ['configs/agents/llava/llava1.6-vicuna7b.py'] 2 | 3 | lmm_agent = dict( 4 | name='llava1.6-vicuna13b', 5 | model='liuhaotian/llava-v1.6-vicuna-13b', 6 | ) 7 | -------------------------------------------------------------------------------- /configs/agents/anhthropic/claude-3.5-sonnet.py: -------------------------------------------------------------------------------- 1 | lmm_agent = dict( 2 | name='claude-3.5-sonnet', 3 | agent='anhthropic_single', 4 | model='claude-3-5-sonnet-20240620', 5 | max_tokens=812, 6 | image_size=(1000, 1000), 7 | ) 8 | -------------------------------------------------------------------------------- /configs/agents/internvl/internvl2-40b.py: -------------------------------------------------------------------------------- 1 | _base_ = ['configs/agents/internvl/internvl2-2b.py'] 2 | 3 | from lmdeploy import ChatTemplateConfig 4 | 5 | lmm_agent = dict( 6 | name='internvl2-40b', 7 | model='OpenGVLab/InternVL2-40B', 8 | chat_template=ChatTemplateConfig('internvl-zh-hermes2'), 9 | ) 
10 | -------------------------------------------------------------------------------- /playground/games/gomoku/designer/gobang_qrc.qrc: -------------------------------------------------------------------------------- 1 | 2 | 3 | image/lost-removebg-preview.png 4 | image/win-removebg-preview.png 5 | image/black.png 6 | image/white.png 7 | image/chessboard.png 8 | 9 | 10 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | 3 | with open('requirements.txt') as f: 4 | required = f.read().splitlines() 5 | 6 | setup(name='playground', 7 | version='0.0.1', 8 | author='Xinyu Wang', 9 | author_email='xinyu.wang02@adelaide.edu.au', 10 | packages=find_packages(), 11 | install_requires=required) 12 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate 2 | anthropic 3 | attrdict 4 | chess 5 | gmpy2 6 | google-generativeai 7 | imageio 8 | imageio-ffmpeg 9 | lmdeploy 10 | matplotlib 11 | openai 12 | pillow 13 | pjtools 14 | protobuf 15 | pyqt5 16 | pyqt5-tools 17 | qwen_vl_utils 18 | sentencepiece 19 | timm 20 | torch 21 | torchvision 22 | transformers 23 | transformers_stream_generator 24 | -------------------------------------------------------------------------------- /configs/agents/internvl/internvl2-4b.py: -------------------------------------------------------------------------------- 1 | from lmdeploy import ChatTemplateConfig, PytorchEngineConfig 2 | 3 | _base_ = ['configs/agents/internvl/internvl2-2b.py'] 4 | 5 | lmm_agent = dict( 6 | name='internvl2-4b', 7 | chat_template=ChatTemplateConfig('internvl-phi3'), 8 | model='OpenGVLab/InternVL2-4B', 9 | backend_config=PytorchEngineConfig(session_len=8192) 10 | ) 11 | 
-------------------------------------------------------------------------------- /configs/base.py: -------------------------------------------------------------------------------- 1 | display = False 2 | save_path = 'experiments/game_history' 3 | maximum_trials = 3 4 | device = 'cuda:0' 5 | make_video = True 6 | 7 | benchmark_setting = dict( 8 | games=['tictactoe', 'gomoku', 'minesweeper', 'reversi', 'sudoku', 'chess'], 9 | sample_size=2000, 10 | e2e_round=100, 11 | offline_task=['perceive', 'qa', 'rule'], 12 | benchmark_path='benchmark' 13 | ) 14 | -------------------------------------------------------------------------------- /configs/agents/01-ai/yi-vl-6b.py: -------------------------------------------------------------------------------- 1 | from lmdeploy import (ChatTemplateConfig, GenerationConfig, 2 | TurbomindEngineConfig) 3 | 4 | lmm_agent = dict( 5 | name='yi-vl-6b', 6 | agent='lmdeploy_single', 7 | chat_template=ChatTemplateConfig('yi-vl'), 8 | model='01-ai/Yi-VL-6B', 9 | backend_config=TurbomindEngineConfig(session_len=8192), 10 | general_config=GenerationConfig(max_new_tokens=1024) 11 | ) 12 | -------------------------------------------------------------------------------- /configs/agents/qwen/qwen2-vl-7b.py: -------------------------------------------------------------------------------- 1 | from lmdeploy import (ChatTemplateConfig, GenerationConfig, 2 | TurbomindEngineConfig) 3 | 4 | lmm_agent = dict( 5 | name='qwen7b', 6 | agent='lmdeploy_single', 7 | chat_template=ChatTemplateConfig('qwen-7b'), 8 | model='Qwen/Qwen2-VL-7B-Instruct', 9 | backend_config=TurbomindEngineConfig(session_len=8192), 10 | general_config=GenerationConfig(max_new_tokens=1024) 11 | ) 12 | -------------------------------------------------------------------------------- /configs/agents/microsoft/phi3.5-vl.py: -------------------------------------------------------------------------------- 1 | from lmdeploy import (ChatTemplateConfig, GenerationConfig, 2 | TurbomindEngineConfig) 
3 | 4 | lmm_agent = dict( 5 | name='phi3.5-vl', 6 | agent='lmdeploy_single', 7 | model='microsoft/Phi-3.5-vision-instruct', 8 | hat_template=ChatTemplateConfig('phi-3'), 9 | backend_config=TurbomindEngineConfig(session_len=8192), 10 | general_config=GenerationConfig(max_new_tokens=1024) 11 | ) 12 | -------------------------------------------------------------------------------- /configs/agents/llava/llava1.6-vicuna7b.py: -------------------------------------------------------------------------------- 1 | from lmdeploy import (ChatTemplateConfig, GenerationConfig, 2 | TurbomindEngineConfig) 3 | 4 | lmm_agent = dict( 5 | name='llava1.6-vicuna7b', 6 | agent='lmdeploy_single', 7 | model='liuhaotian/llava-v1.6-vicuna-7b', 8 | hat_template=ChatTemplateConfig('llava-v1'), 9 | backend_config=TurbomindEngineConfig(session_len=8192), 10 | general_config=GenerationConfig(max_new_tokens=1024) 11 | ) 12 | -------------------------------------------------------------------------------- /configs/agents/deepseek/deepseek-vl-7b.py: -------------------------------------------------------------------------------- 1 | from lmdeploy import (ChatTemplateConfig, GenerationConfig, 2 | TurbomindEngineConfig) 3 | 4 | lmm_agent = dict( 5 | name='deepseek-vl-7b', 6 | agent='lmdeploy_single', 7 | chat_template=ChatTemplateConfig('deepseek-chat'), 8 | model='deepseek-ai/deepseek-vl-7b-chat', 9 | backend_config=TurbomindEngineConfig(session_len=8192), 10 | general_config=GenerationConfig(max_new_tokens=1024) 11 | ) 12 | -------------------------------------------------------------------------------- /configs/agents/internvl/internvl2-2b.py: -------------------------------------------------------------------------------- 1 | from lmdeploy import (ChatTemplateConfig, GenerationConfig, 2 | TurbomindEngineConfig) 3 | 4 | lmm_agent = dict( 5 | name='internvl2-2b', 6 | agent='lmdeploy_single', 7 | chat_template=ChatTemplateConfig('internvl-internlm2'), 8 | model='OpenGVLab/InternVL2-2B', 9 | 
backend_config=TurbomindEngineConfig(session_len=8192), 10 | general_config=GenerationConfig(max_new_tokens=1024, top_p=0.8), 11 | ) 12 | -------------------------------------------------------------------------------- /playground/agents/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import BaseAgent 2 | from .single_step_agents import (AnthropicAgentSingleStep, 3 | GoogleAIAgentSingleStep, 4 | LMDeployAgentSingleStep, 5 | OpenAIAgentSingleStep) 6 | 7 | __all__ = [ 8 | 'BaseAgent', 9 | 'OpenAIAgentSingleStep', 10 | 'LMDeployAgentSingleStep', 11 | 'GoogleAIAgentSingleStep', 12 | 'AnthropicAgentSingleStep', 13 | ] 14 | -------------------------------------------------------------------------------- /playground/agents/base.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | 4 | class BaseAgent(ABC): 5 | 6 | def __init__(self, agent_cfg): 7 | self.agent_cfg = agent_cfg 8 | 9 | @abstractmethod 10 | def get_decision(self, screenshot_path: str, prompt: str): 11 | """ 12 | Given the path to a screenshot of the current game state and a prompt, 13 | this method should return a decision on the next move or action. 
14 | """ 15 | raise NotImplementedError('The method not implemented') 16 | -------------------------------------------------------------------------------- /playground/games/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import BaseGame, BaseGameLogic 2 | from .chess import Chess, ChessQuestionAnswering 3 | from .gomoku import Gomoku, GomokuQuestionAnswering 4 | from .minesweeper import MineSweeper, MinesweeperQuestionAnswering 5 | from .reversi import Reversi, ReversiQuestionAnswering 6 | from .sudoku import Sudoku, SudokuQuestionAnswering 7 | from .tictactoe import TicTacToe, TicTacToeQuestionAnswering 8 | 9 | __all__ = [ 10 | 'BaseGame', 'BaseGameLogic', 'Gomoku', 'TicTacToe', 'MineSweeper', 11 | 'Sudoku', 'Reversi', 'Chess', 'TicTacToeQuestionAnswering', 12 | 'SudokuQuestionAnswering', 'ReversiQuestionAnswering', 13 | 'MinesweeperQuestionAnswering', 'GomokuQuestionAnswering', 14 | 'ChessQuestionAnswering' 15 | ] 16 | -------------------------------------------------------------------------------- /playground/state_code.py: -------------------------------------------------------------------------------- 1 | import json 2 | from enum import Enum 3 | 4 | 5 | class JSONSerializableEnum(Enum): 6 | 7 | def __json__(self): 8 | return self.name 9 | 10 | @classmethod 11 | def to_json(cls, obj): 12 | if isinstance(obj, cls): 13 | return obj.name 14 | return obj 15 | 16 | 17 | class GameStatus(JSONSerializableEnum): 18 | WIN = 101 19 | LOSE = 102 20 | TIE = 103 21 | INVALID_MOVE = 104 22 | IN_PROGRESS = 105 23 | MAX_TRIAL_REACHED = 106 24 | ERROR = 107 25 | 26 | 27 | class GameStatusEncoder(json.JSONEncoder): 28 | 29 | def default(self, obj): 30 | if isinstance(obj, GameStatus): 31 | return obj.name 32 | return super().default(obj) 33 | -------------------------------------------------------------------------------- /generate_benchmark.py: 
-------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | from playground.benchmark import Generator 5 | 6 | os.environ['QT_QPA_PLATFORM'] = 'offscreen' 7 | 8 | 9 | def parse_args(): 10 | parser = argparse.ArgumentParser( 11 | description='Generate offline benchmark for LVLM-Playground') 12 | parser.add_argument('--benchmark-setting', 13 | type=str, 14 | help='Path to the benchmark setting config.', 15 | default='configs/base.py') 16 | return parser.parse_args() 17 | 18 | 19 | def main(): 20 | args = parse_args() 21 | generator = Generator(args.benchmark_setting) 22 | generator.generate_benchmark() 23 | 24 | 25 | if __name__ == '__main__': 26 | main() 27 | -------------------------------------------------------------------------------- /playground/evaluator/base_qa.py: -------------------------------------------------------------------------------- 1 | 2 | class BaseQuestionAnswering: 3 | 4 | def __init__(self, general_prompt, shot=3): 5 | self.general_prompt = general_prompt 6 | self.question_pool = [] 7 | self.shot = shot 8 | 9 | def get_qa_pair(self, game_state): 10 | raise NotImplementedError('Subclasses should implement this method.') 11 | 12 | def get_qa_pairs(self, game_state): 13 | qa_pairs = set() 14 | while len(qa_pairs) < self.shot + 1: 15 | question, answer = self.get_qa_pair(game_state) 16 | qa_pair = (question, answer) 17 | if qa_pair not in qa_pairs: 18 | qa_pairs.add(qa_pair) 19 | 20 | return list(qa_pairs) 21 | 22 | def get_answer(self, game_state, question): 23 | raise NotImplementedError('Subclasses should implement this method.') 24 | -------------------------------------------------------------------------------- /playground/games/base.py: -------------------------------------------------------------------------------- 1 | from playground.state_code import GameStatus 2 | 3 | 4 | class BaseGame: 5 | AI_component = False 6 | 7 | def __init__(self, game_cfg) -> None: 8 | self.status 
= GameStatus.IN_PROGRESS 9 | self.game_cfg = game_cfg 10 | 11 | def get_screenshot(self): 12 | raise NotImplementedError 13 | 14 | def input_move(self, move): 15 | raise NotImplementedError 16 | 17 | def get_game_status(self): 18 | raise NotImplementedError 19 | 20 | def get_random_state(self): 21 | raise NotImplementedError 22 | 23 | def get_rule_state(self): 24 | raise NotImplementedError 25 | 26 | def calculate_score(self): 27 | """Calculate score based on current game state.""" 28 | raise NotImplementedError 29 | 30 | 31 | class BaseGameLogic: 32 | 33 | def parse_e2e(self, lmm_output): 34 | """Parse e2e output to a move.""" 35 | raise NotImplementedError('Subclasses must implement parse_e2e') 36 | -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | 5 | from PyQt5.QtWidgets import QApplication 6 | 7 | from playground import Recipe 8 | 9 | os.environ['QT_QPA_PLATFORM'] = 'offscreen' 10 | 11 | 12 | def parse_args(): 13 | parser = argparse.ArgumentParser( 14 | description='Large Multi-modal Model Playground') 15 | parser.add_argument('--exp-recipe', 16 | type=str, 17 | help='Path to the game config.', 18 | default='configs/recipe/base.py') 19 | parser.add_argument('--agent-cfg', 20 | type=str, 21 | help='Path to the agent config.', 22 | default='configs/agents/internvl/internvl2-1b.py') 23 | return parser.parse_args() 24 | 25 | 26 | def main(): 27 | app = QApplication(sys.argv) # noqa 28 | 29 | args = parse_args() 30 | recipe = Recipe(args) 31 | recipe.run_experiments() 32 | 33 | 34 | if __name__ == '__main__': 35 | main() 36 | -------------------------------------------------------------------------------- /playground/utils/utils.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import random 3 | from io import BytesIO 4 | 5 | import numpy 
as np 6 | import torch 7 | from PIL import Image 8 | 9 | 10 | def set_random_seed(): 11 | """Set the random seed for reproducibility.""" 12 | seed = random.randint(0, 2**32 - 1) 13 | random.seed(seed) 14 | np.random.seed(seed) 15 | torch.manual_seed(seed) 16 | torch.random.manual_seed(seed) 17 | torch.cuda.manual_seed(seed) 18 | torch.cuda.manual_seed_all(seed) 19 | 20 | return seed 21 | 22 | 23 | def encode_image(image_path, size=None): 24 | """Encode an image to a base64 string. Optionally resize the image before 25 | encoding. 26 | """ 27 | with open(image_path, 'rb') as image_file: 28 | image = Image.open(image_file) 29 | 30 | if size: 31 | image = image.resize(size, Image.Resampling.LANCZOS) 32 | 33 | buffered = BytesIO() 34 | image.save(buffered, format='PNG') 35 | return base64.b64encode(buffered.getvalue()).decode('utf-8') 36 | -------------------------------------------------------------------------------- /playground/games/minesweeper/game_cfg.py: -------------------------------------------------------------------------------- 1 | from PyQt5.QtGui import QColor, QImage 2 | 3 | from playground.state_code import GameStatus 4 | 5 | IMG_BOMB = QImage('playground/games/minesweeper/images/bomb.png') 6 | IMG_FLAG = QImage('playground/games/minesweeper/images/flag.png') 7 | IMG_START = QImage('playground/games/minesweeper/images/rocket.png') 8 | IMG_CLOCK = QImage('playground/games/minesweeper/images/clock-select.png') 9 | 10 | STATUS_ICONS = { 11 | GameStatus.INVALID_MOVE: 'playground/games/minesweeper/images/plus.png', 12 | GameStatus.ERROR: 'playground/games/minesweeper/images/plus.png', 13 | GameStatus.IN_PROGRESS: 'playground/games/minesweeper/images/smiley.png', 14 | GameStatus.LOSE: 'playground/games/minesweeper/images/cross.png', 15 | GameStatus.WIN: 'playground/games/minesweeper/images/smiley-lol.png', 16 | } 17 | 18 | NUM_COLORS = { 19 | 1: QColor('#f44336'), 20 | 2: QColor('#9C27B0'), 21 | 3: QColor('#3F51B5'), 22 | 4: QColor('#03A9F4'), 23 | 5: 
QColor('#00BCD4'), 24 | 6: QColor('#4CAF50'), 25 | 7: QColor('#E91E63'), 26 | 8: QColor('#FF9800') 27 | } 28 | 29 | LEVELS = {'easy': (8, 10), 'middle': (12, 30), 'hard': (16, 40)} 30 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/PyCQA/flake8 3 | rev: 5.0.4 4 | hooks: 5 | - id: flake8 6 | args: ["--exclude=tests/*,configs/*"] 7 | - repo: https://github.com/pycqa/isort 8 | rev: 5.12.0 9 | hooks: 10 | - id: isort 11 | - repo: https://github.com/pre-commit/mirrors-yapf 12 | rev: v0.32.0 13 | hooks: 14 | - id: yapf 15 | exclude: ^configs/.* 16 | - repo: https://github.com/codespell-project/codespell 17 | rev: v2.2.1 18 | hooks: 19 | - id: codespell 20 | args: ["--ignore-words-list=ans,ques"] 21 | - repo: https://github.com/pre-commit/pre-commit-hooks 22 | rev: v4.3.0 23 | hooks: 24 | - id: trailing-whitespace 25 | - id: check-yaml 26 | - id: end-of-file-fixer 27 | exclude: ^assets/prompts/.*$ 28 | - id: requirements-txt-fixer 29 | - id: double-quote-string-fixer 30 | - id: check-merge-conflict 31 | - id: fix-encoding-pragma 32 | args: ["--remove"] 33 | - id: mixed-line-ending 34 | args: ["--fix=lf"] 35 | - id: mixed-line-ending 36 | args: ["--fix=lf"] 37 | - repo: https://github.com/executablebooks/mdformat 38 | rev: 0.7.9 39 | hooks: 40 | - id: mdformat 41 | args: ["--number"] 42 | additional_dependencies: 43 | - mdformat_frontmatter 44 | - linkify-it-py 45 | -------------------------------------------------------------------------------- /playground/games/gomoku/gomoku_ui.py: -------------------------------------------------------------------------------- 1 | from PyQt5 import QtCore, QtGui, QtWidgets 2 | 3 | 4 | class Ui_MainWindow(object): 5 | 6 | def setupUi(self, MainWindow): 7 | MainWindow.setObjectName('MainWindow') 8 | MainWindow.resize(1000, 1000) 9 | 
MainWindow.setMinimumSize(QtCore.QSize(1000, 1000)) 10 | MainWindow.setMaximumSize(QtCore.QSize(1000, 1000)) 11 | MainWindow.setStyleSheet('') 12 | self.centralwidget = QtWidgets.QWidget(MainWindow) 13 | self.centralwidget.setObjectName('centralwidget') 14 | self.chessboard = QtWidgets.QWidget(self.centralwidget) 15 | self.chessboard.setGeometry(QtCore.QRect(0, 0, 1000, 1000)) 16 | self.chessboard.setStyleSheet( 17 | 'border-image: url(:/bg/image/chessboard.png);') 18 | self.chessboard.setObjectName('chessboard') 19 | self.result_label = QtWidgets.QLabel(self.centralwidget) 20 | self.result_label.setGeometry(QtCore.QRect(240, 239, 521, 191)) 21 | font = QtGui.QFont() 22 | font.setPointSize(90) 23 | self.result_label.setFont(font) 24 | self.result_label.setText('') 25 | self.result_label.setObjectName('result_label') 26 | self.chessboard.raise_() 27 | self.result_label.raise_() 28 | MainWindow.setCentralWidget(self.centralwidget) 29 | 30 | self.retranslateUi(MainWindow) 31 | QtCore.QMetaObject.connectSlotsByName(MainWindow) 32 | 33 | def retranslateUi(self, MainWindow): 34 | _translate = QtCore.QCoreApplication.translate 35 | MainWindow.setWindowTitle(_translate('MainWindow', 'MainWindow')) 36 | 37 | 38 | import playground.games.gomoku.gomoku_qrc_rc # noqa 39 | -------------------------------------------------------------------------------- /playground/games/gomoku/designer/gobang_ui.ui: -------------------------------------------------------------------------------- 1 | 2 | 3 | MainWindow 4 | 5 | 6 | 7 | 0 8 | 0 9 | 1000 10 | 1000 11 | 12 | 13 | 14 | 15 | 1000 16 | 1000 17 | 18 | 19 | 20 | 21 | 1000 22 | 1000 23 | 24 | 25 | 26 | MainWindow 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 0 36 | 0 37 | 1000 38 | 1000 39 | 40 | 41 | 42 | border-image: url(:/bg/image/chessboard.png); 43 | 44 | 45 | 46 | 47 | 48 | 240 49 | 239 50 | 521 51 | 191 52 | 53 | 54 | 55 | 56 | 90 57 | 58 | 59 | 60 | 61 | 62 | 63 | chessboard 64 | result_label 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 
| -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | 5 | from playground.evaluator import Metric 6 | 7 | 8 | def parse_args(): 9 | parser = argparse.ArgumentParser(description='LVLM-Playground Evaluation') 10 | parser.add_argument('record_path', 11 | type=str, 12 | help='Path to the experiment results JSON file.') 13 | parser.add_argument('--annotation_dir', 14 | type=str, 15 | default='./benchmark', 16 | help='Directory containing annotation JSON files') 17 | parser.add_argument('--output_path', 18 | type=str, 19 | default=None, 20 | help='Path to save the evaluation results JSON file') 21 | return parser.parse_args() 22 | 23 | 24 | def main(): 25 | args = parse_args() 26 | 27 | if not os.path.exists(args.record_path): 28 | print(f"Error: Record file '{args.record_path}' does not exist.") 29 | return 30 | if not os.path.exists(args.annotation_dir): 31 | print(f"Error: Annotation '{args.annotation_dir}' does not exist.") 32 | return 33 | 34 | if args.output_path is None: 35 | record_filename = os.path.splitext(os.path.basename( 36 | args.record_path))[0] 37 | output_path = os.path.join('./evaluation_results', 38 | f'{record_filename}_results.json') 39 | else: 40 | output_path = args.output_path 41 | 42 | print('Starting evaluation...') 43 | metric = Metric(args.record_path, args.annotation_dir) 44 | scores = metric.evaluate_all() 45 | 46 | os.makedirs(os.path.dirname(output_path), exist_ok=True) 47 | metric.save_evaluation(output_path) 48 | print(f"Evaluation results saved to '{output_path}'") 49 | 50 | print('\nDetailed Evaluation Results:') 51 | print(json.dumps(scores, indent=4)) 52 | 53 | print('\nGame Difficulty Weighted Summary:') 54 | for task in metric.weighted_summary: 55 | weighted_avg = metric.weighted_summary[task]['weighted_average'] 56 | print(f'{task}: Weighted Average Score 
= {weighted_avg:.4f}') 57 | 58 | 59 | if __name__ == '__main__': 60 | main() 61 | -------------------------------------------------------------------------------- /playground/games/sudoku/sudoku_generator.py: -------------------------------------------------------------------------------- 1 | from random import randint, shuffle 2 | 3 | 4 | def checkGrid(grid): 5 | for row in range(9): 6 | for col in range(9): 7 | if grid[row][col] == 0: 8 | return False 9 | return True 10 | 11 | 12 | def fillGrid(grid): 13 | numberList = [1, 2, 3, 4, 5, 6, 7, 8, 9] 14 | 15 | for i in range(81): 16 | row = i // 9 17 | col = i % 9 18 | if grid[row][col] == 0: 19 | shuffle(numberList) 20 | 21 | for value in numberList: 22 | if not (value in grid[row]): 23 | if value not in [grid[r][col] for r in range(9)]: 24 | square = [ 25 | grid[r][c] 26 | for r in range(row // 3 * 3, row // 3 * 3 + 3) 27 | for c in range(col // 3 * 3, col // 3 * 3 + 3) 28 | ] 29 | if value not in square: 30 | grid[row][col] = value 31 | if checkGrid(grid): 32 | return True 33 | elif fillGrid(grid): 34 | return True 35 | break 36 | grid[row][col] = 0 37 | return False 38 | 39 | 40 | def solveGrid(grid, counter): 41 | for i in range(81): 42 | row = i // 9 43 | col = i % 9 44 | if grid[row][col] == 0: 45 | for value in range(1, 10): 46 | if not (value in grid[row]): 47 | if value not in [grid[r][col] for r in range(9)]: 48 | square = [ 49 | grid[r][c] 50 | for r in range(row // 3 * 3, row // 3 * 3 + 3) 51 | for c in range(col // 3 * 3, col // 3 * 3 + 3) 52 | ] 53 | if value not in square: 54 | grid[row][col] = value 55 | if checkGrid(grid): 56 | counter[0] += 1 57 | break 58 | elif solveGrid(grid, counter): 59 | return True 60 | break 61 | grid[row][col] = 0 62 | return False 63 | 64 | 65 | def generate_puzzle(grid, attempts): 66 | counter = [0] 67 | 68 | while attempts > 0: 69 | row, col = randint(0, 8), randint(0, 8) 70 | while grid[row][col] == 0: 71 | row, col = randint(0, 8), randint(0, 8) 72 | 73 | backup = 
grid[row][col] 74 | grid[row][col] = 0 75 | 76 | copyGrid = [[grid[r][c] for c in range(9)] for r in range(9)] 77 | counter[0] = 0 78 | 79 | solveGrid(copyGrid, counter) 80 | 81 | if counter[0] != 1: 82 | grid[row][col] = backup 83 | attempts -= 1 84 | 85 | return grid 86 | -------------------------------------------------------------------------------- /playground/games/reversi/reversi_ui.py: -------------------------------------------------------------------------------- 1 | from PyQt5 import QtCore, QtGui, QtWidgets 2 | from PyQt5.QtCore import QRect, Qt 3 | from PyQt5.QtGui import QBrush, QColor, QFont, QPen 4 | 5 | WHITE = Qt.white 6 | BLACK = Qt.black 7 | GREEN = QColor(0, 128, 0) 8 | CELL_SIZE = 400 // 8 9 | 10 | 11 | class Ui_MainWindow(object): 12 | 13 | def setupUi(self, MainWindow): 14 | MainWindow.setObjectName('MainWindow') 15 | MainWindow.resize(500, 600) 16 | MainWindow.setMinimumSize(QtCore.QSize(500, 600)) 17 | MainWindow.setMaximumSize(QtCore.QSize(500, 600)) 18 | self.centralwidget = QtWidgets.QWidget(MainWindow) 19 | self.centralwidget.setObjectName('centralwidget') 20 | 21 | font = QtGui.QFont() 22 | font.setPointSize(12) 23 | 24 | self.player_label = QtWidgets.QLabel(self.centralwidget) 25 | self.player_label.setGeometry(QtCore.QRect(100, 450, 300, 50)) 26 | self.player_label.setFont(font) 27 | self.player_label.setAlignment(QtCore.Qt.AlignCenter) 28 | 29 | self.restart_button = QtWidgets.QPushButton(self.centralwidget) 30 | self.restart_button.setGeometry(QtCore.QRect(200, 500, 100, 40)) 31 | self.restart_button.setText('Restart') 32 | self.restart_button.setFont(font) 33 | 34 | MainWindow.setCentralWidget(self.centralwidget) 35 | 36 | def draw_board(self, qp, board): 37 | for y in range(8): 38 | for x in range(8): 39 | qp.setBrush(QBrush(GREEN)) 40 | qp.drawRect(x * CELL_SIZE + 60, y * CELL_SIZE + 40, CELL_SIZE, 41 | CELL_SIZE) 42 | qp.setPen(QPen(BLACK)) 43 | qp.drawRect(x * CELL_SIZE + 60, y * CELL_SIZE + 40, CELL_SIZE, 44 | CELL_SIZE) 45 
| if board[y][x] == 2: 46 | qp.setBrush(QBrush(WHITE)) 47 | qp.drawEllipse( 48 | QRect(int(x * CELL_SIZE + CELL_SIZE / 6 + 60), 49 | int(y * CELL_SIZE + CELL_SIZE / 6 + 40), 50 | CELL_SIZE * 2 // 3, CELL_SIZE * 2 // 3)) 51 | elif board[y][x] == 1: 52 | qp.setBrush(QBrush(BLACK)) 53 | qp.drawEllipse( 54 | QRect(int(x * CELL_SIZE + CELL_SIZE / 6 + 60), 55 | int(y * CELL_SIZE + CELL_SIZE / 6 + 40), 56 | CELL_SIZE * 2 // 3, CELL_SIZE * 2 // 3)) 57 | 58 | def draw_labels(self, qp): 59 | qp.setPen(Qt.black) 60 | label_font = QFont('Arial', 14, QFont.Bold) 61 | qp.setFont(label_font) 62 | 63 | for col in range(8): 64 | label = str(col + 1) 65 | qp.drawText(col * CELL_SIZE + 60 + CELL_SIZE // 2 - 5, 35, label) 66 | 67 | for row in range(8): 68 | label = chr(ord('A') + row) 69 | qp.drawText(35, row * CELL_SIZE + 40 + CELL_SIZE // 2 + 5, label) 70 | -------------------------------------------------------------------------------- /playground/games/tictactoe/tictactoe_ui.py: -------------------------------------------------------------------------------- 1 | from PyQt5 import QtCore, QtGui, QtWidgets 2 | 3 | 4 | class Ui_MainWindow(object): 5 | 6 | def setupUi(self, MainWindow): 7 | MainWindow.setObjectName('MainWindow') 8 | MainWindow.resize(500, 600) 9 | MainWindow.setMinimumSize(QtCore.QSize(500, 600)) 10 | MainWindow.setMaximumSize(QtCore.QSize(500, 600)) 11 | MainWindow.setStyleSheet('') 12 | self.centralwidget = QtWidgets.QWidget(MainWindow) 13 | self.centralwidget.setObjectName('centralwidget') 14 | 15 | font = QtGui.QFont() 16 | font.setPointSize(1) 17 | 18 | self.buttons = [] 19 | positions = [(i, j) for i in range(3) for j in range(3)] 20 | for idx, pos in enumerate(positions): 21 | button = QtWidgets.QPushButton(self.centralwidget) 22 | button.setGeometry( 23 | QtCore.QRect(50 + pos[1] * 140, 100 + pos[0] * 140, 140, 140)) 24 | button.setFont(font) 25 | button.setText(str(idx + 1)) 26 | button.setObjectName(f'button_{idx + 1}') 27 | self.buttons.append(button) 28 | 
29 | self.row_labels = [] 30 | rows = ['A', 'B', 'C'] 31 | for i, row in enumerate(rows): 32 | label = QtWidgets.QLabel(self.centralwidget) 33 | label.setGeometry(QtCore.QRect(30, 170 + i * 140, 20, 20)) 34 | label.setText(row) 35 | label.setAlignment(QtCore.Qt.AlignCenter) 36 | label.setFont(QtGui.QFont('Arial', 16, QtGui.QFont.Bold)) 37 | self.row_labels.append(label) 38 | 39 | self.column_labels = [] 40 | columns = ['1', '2', '3'] 41 | for i, column in enumerate(columns): 42 | label = QtWidgets.QLabel(self.centralwidget) 43 | label.setGeometry(QtCore.QRect(105 + i * 140, 70, 20, 20)) 44 | label.setText(column) 45 | label.setAlignment(QtCore.Qt.AlignCenter) 46 | label.setFont(QtGui.QFont('Arial', 16, QtGui.QFont.Bold)) 47 | self.column_labels.append(label) 48 | 49 | self.label = QtWidgets.QLabel(self.centralwidget) 50 | self.label.setGeometry(QtCore.QRect(50, 10, 351, 71)) 51 | font = QtGui.QFont() 52 | font.setFamily('URW Gothic') 53 | font.setPointSize(20) 54 | self.label.setFont(font) 55 | 56 | self.label_2 = QtWidgets.QLabel(self.centralwidget) 57 | self.label_2.setGeometry(QtCore.QRect(50, 30, 351, 51)) 58 | font2 = QtGui.QFont() 59 | font2.setFamily('URW Gothic') 60 | font2.setPointSize(20) 61 | font2.setWeight(50) 62 | self.label_2.setFont(font2) 63 | 64 | MainWindow.setCentralWidget(self.centralwidget) 65 | 66 | self.retranslateUi(MainWindow) 67 | QtCore.QMetaObject.connectSlotsByName(MainWindow) 68 | 69 | def retranslateUi(self, MainWindow): 70 | _translate = QtCore.QCoreApplication.translate 71 | MainWindow.setWindowTitle(_translate('MainWindow', 'Tic Tac Toe')) 72 | -------------------------------------------------------------------------------- /playground/evaluator/evaluator.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import time 3 | 4 | import torch 5 | 6 | from playground.simulator import GameSimulator 7 | from playground.utils import set_random_seed 8 | 9 | 10 | class 
Evaluator: 11 | """Evaluator class to run the game with the given agent.""" 12 | 13 | def __init__(self, game_cfg, agent, task, log_file, save_path): 14 | self.game_cfg = game_cfg 15 | self.agent = agent 16 | self.task = task 17 | self.save_path = osp.join(save_path, self.game_cfg.game_name, 18 | self.task, 19 | self.agent.agent_cfg.lmm_agent.name) 20 | self.seed = set_random_seed() 21 | self.log_file = log_file 22 | 23 | def run(self, batch): 24 | if self.task == 'e2e': 25 | return self.run_e2e_game(batch) 26 | elif self.task == 'perceive': 27 | return self.run_perceive(batch) 28 | elif self.task == 'rule': 29 | return self.run_rule(batch) 30 | elif self.task == 'qa': 31 | return self.run_qa(batch) 32 | else: 33 | raise ValueError(f'Invalid task type: {self.task}') 34 | 35 | def run_e2e_game(self, batch): 36 | crt_save_path = osp.join(self.save_path, f'round_{int(time.time())}') 37 | simulator = GameSimulator(self.game_cfg, self.agent, self.seed, 38 | crt_save_path, self.task) 39 | 40 | result = simulator.run_e2e(batch) 41 | 42 | # if self.game_cfg.make_video: 43 | # simulator.make_video() 44 | 45 | return result, simulator 46 | 47 | def run_perceive(self, batch): 48 | crt_save_path = osp.join(self.save_path) 49 | simulator = GameSimulator(self.game_cfg, 50 | self.agent, 51 | self.seed, 52 | crt_save_path, 53 | self.task, 54 | log_file=self.log_file) 55 | result = simulator.perceive(batch) 56 | 57 | return result, simulator 58 | 59 | def run_rule(self, batch): 60 | crt_save_path = osp.join(self.save_path) 61 | simulator = GameSimulator(self.game_cfg, 62 | self.agent, 63 | self.seed, 64 | crt_save_path, 65 | self.task, 66 | log_file=self.log_file) 67 | result = simulator.rule(batch) 68 | 69 | return result, simulator 70 | 71 | def run_qa(self, batch): 72 | crt_save_path = osp.join(self.save_path) 73 | simulator = GameSimulator(self.game_cfg, 74 | self.agent, 75 | self.seed, 76 | crt_save_path, 77 | self.task, 78 | log_file=self.log_file) 79 | result = 
simulator.qa(batch) 80 | 81 | return result, simulator 82 | 83 | def cleanup(self): 84 | torch.cuda.empty_cache() 85 | -------------------------------------------------------------------------------- /configs/games/tictactoe.py: -------------------------------------------------------------------------------- 1 | from playground.games import TicTacToeQuestionAnswering 2 | 3 | _base_ = ['configs/base.py'] 4 | 5 | game_name = 'tictactoe' 6 | game_description = dict( 7 | e2e=('Tic Tac Toe is played on a 3x3 grid. Players take turns placing X ' 8 | 'or O in the cells. The goal is to be the first to form an unbroken ' 9 | 'line of three marks horizontally, vertically, or diagonally. The ' 10 | 'game starts with an empty board, and the O player goes first. The ' 11 | 'grid is labeled with rows A to C and columns 1 to 3. You are ' 12 | 'playing as O, aiming to win by placing marks strategically. Each ' 13 | 'position can only be occupied by one mark, so do not choose a spot ' 14 | 'that is already taken. Based on the board state screenshots, please ' 15 | 'first observe the current situation, then carefully think and ' 16 | 'explain your strategy briefly, and finally output your movement for ' 17 | 'this status. Please strictly follow the following format:\n' 18 | 'Observation: \n' 19 | 'Strategy: \n' 20 | 'Movement: \n' 21 | 'where the observation should briefly summarize the current ' 22 | 'situation, the strategy is a brief explanation of how you plan to ' 23 | 'win the game, and the position can be any combination of rows A to ' 24 | 'C and columns 1 to 3, for example, A1, 2B, or c3.'), 25 | perceive=( 26 | 'Tic Tac Toe is a game played on a 3x3 grid where players take turns ' 27 | 'placing X or O in the cells. Given a screenshot of the game board, ' 28 | 'please determine the current game state using a 3x3 matrix. In this ' 29 | 'matrix, an empty cell should be represented by -1, X should be ' 30 | 'represented by 1, and O should be represented by 0. 
Please strictly ' 31 | 'follow the format:\n' 32 | 'Game State: \n' 33 | 'where is a 3x3 matrix. For example,\n' 34 | 'Game State: [[-1, -1, -1], [-1, -1, -1], [-1, -1, -1]]\n' 35 | 'represents an empty board.'), 36 | rule=( 37 | 'Tic Tac Toe is played on a 3x3 grid. Players take turns placing X or ' 38 | 'O in the cells. The game starts with an empty board. The grid is ' 39 | 'labeled with rows A to C and columns 1 to 3. Each position can only ' 40 | 'be occupied by one mark. Based on the board state, ' 41 | 'please find an empty cell where you can place your next stone.\n' 42 | 'Please strictly follow the following format:\n' 43 | 'Movement: \n' 44 | 'where the position can be any combination of rows A to C and columns ' 45 | '1 to 3, for example, A1, B2, or C3.' 46 | ), 47 | qa=( 48 | 'Tic Tac Toe is a game played on a 3x3 grid where two players take ' 49 | 'turns placing X or O in the cells. The goal is to form a horizontal, ' 50 | 'vertical, or diagonal line with three of your own marks. The grid is ' 51 | 'labeled with rows A to C and columns 1 to 3. Please answer the ' 52 | 'following question based on the provided screenshot of the current ' 53 | 'game state:\n' 54 | '{question}\n' 55 | 'Answer: \n' 56 | 'where should be one of A, B, C, or D.' 
57 |     )
58 | )
59 | 
60 | player_first = True
61 | qa = TicTacToeQuestionAnswering
62 | 
--------------------------------------------------------------------------------
/playground/games/tictactoe/AI.py:
--------------------------------------------------------------------------------
1 | class Minimax:
2 |     """Tic-tac-toe move search using exhaustive (unpruned) minimax.
3 | 
4 |     ``evaluate``, ``minimax`` and ``find_best_move`` work on a 3x3 list of
5 |     lists whose cells hold ``self.bot``, ``self.opponent`` or ``'_'`` for
6 |     an empty square; ``generate_2d`` builds that grid from a flat board.
7 |     """
8 | 
9 |     # Magnitude of a decided position before minimax() applies its depth
10 |     # adjustment.
11 |     WIN_SCORE = 10
12 | 
13 |     def __init__(self, bot, opponent):
14 |         self.bot = bot
15 |         self.opponent = opponent
16 | 
17 |     @staticmethod
18 |     def generate_plugin(lst, cols):
19 |         """Split ``lst`` into consecutive rows of ``cols`` items."""
20 |         return [lst[i:i + cols] for i in range(0, len(lst), cols)]
21 | 
22 |     @staticmethod
23 |     def generate_1d(row, col):
24 |         """Return the row-major index of cell (row, col) on a 3x3 board."""
25 |         return row * 3 + col
26 | 
27 |     def generate_2d(self, board):
28 |         """Convert a flat board into a 3x3 grid.
29 | 
30 |         Cells that are neither ``self.bot`` nor ``self.opponent`` become
31 |         ``'_'`` (empty).
32 |         """
33 |         return self.generate_plugin([
34 |             cell if cell in (self.bot, self.opponent) else '_'
35 |             for cell in board
36 |         ], 3)
37 | 
38 |     @staticmethod
39 |     def is_moves_left(board):
40 |         """Return True if the 3x3 ``board`` still has an empty ('_') cell."""
41 |         return any(cell == '_' for row in board for cell in row)
42 | 
43 |     def reset(self, bot, opponent):
44 |         """Change the symbols played by the bot and its opponent."""
45 |         self.bot = bot
46 |         self.opponent = opponent
47 | 
48 |     def evaluate(self, board):
49 |         """Score a static 3x3 position from the bot's point of view.
50 | 
51 |         Returns ``WIN_SCORE`` if the bot owns a full row, column or
52 |         diagonal, ``-WIN_SCORE`` if the opponent does, else 0. Lines are
53 |         checked rows first, then columns, main diagonal, anti-diagonal.
54 |         """
55 |         lines = [list(row) for row in board]
56 |         lines += [[board[r][c] for r in range(3)] for c in range(3)]
57 |         lines.append([board[i][i] for i in range(3)])
58 |         lines.append([board[i][2 - i] for i in range(3)])
59 |         for a, b, c in lines:
60 |             if a == b == c:
61 |                 if a == self.bot:
62 |                     return self.WIN_SCORE
63 |                 if a == self.opponent:
64 |                     return -self.WIN_SCORE
65 |         return 0
66 | 
67 |     def minimax(self, board, depth, is_max):
68 |         """Return the minimax value of ``board`` searched at ``depth``.
69 | 
70 |         ``is_max`` is True when the bot is to move. Wins are worth
71 |         ``WIN_SCORE - depth`` and losses ``depth - WIN_SCORE``, so the bot
72 |         prefers the quickest win and the slowest loss. When called from
73 |         find_best_move, depth never exceeds 9, so a win still outranks a
74 |         draw (0). ``board`` is mutated during the search and restored.
75 |         """
76 |         score = self.evaluate(board)
77 |         if score == self.WIN_SCORE:
78 |             return score - depth
79 |         if score == -self.WIN_SCORE:
80 |             return score + depth
81 |         if not self.is_moves_left(board):
82 |             return 0
83 | 
84 |         player = self.bot if is_max else self.opponent
85 |         choose = max if is_max else min
86 |         best = -1000 if is_max else 1000
87 |         for i in range(3):
88 |             for j in range(3):
89 |                 if board[i][j] == '_':
90 |                     board[i][j] = player
91 |                     best = choose(
92 |                         best, self.minimax(board, depth + 1, not is_max))
93 |                     board[i][j] = '_'
94 |         return best
95 | 
96 |     def find_best_move(self, board):
97 |         """Return the flat index (0-8) of the bot's best move on ``board``.
98 | 
99 |         Ties go to the first such cell in row-major order. NOTE(review): if
100 |         ``board`` has no empty cell this returns generate_1d(-1, -1) == -4,
101 |         which is not a valid index.
102 |         """
103 |         best_val = -1000
104 |         best_move = (-1, -1)
105 |         for i in range(3):
106 |             for j in range(3):
107 |                 if board[i][j] == '_':
108 |                     board[i][j] = self.bot
109 |                     move_val = self.minimax(board, 1, False)
110 |                     board[i][j] = '_'
111 |                     if move_val > best_val:
112 |                         best_move = (i, j)
113 |                         best_val = move_val
114 |         return self.generate_1d(best_move[0], best_move[1])
115 | 
--------------------------------------------------------------------------------
/plot_radar.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import json
3 | import os
4 | from math import pi
5 | 
6 | import matplotlib.pyplot as plt
7 | from matplotlib.ticker import FormatStrFormatter
8 | 
9 | 
10 | def load_evaluation_results(results_dir='evaluation_results'):
11 |     """Load all evaluation result JSON files from ``results_dir``.
12 | 
13 |     Each ``*_results.json`` file holds one model; the model name is the
14 |     file name without the trailing ``_results.json``. Categories are the
15 |     ``weighted_summary`` keys of the first file in sorted order, and each
16 |     model's score for a category is its ``weighted_average`` (0 if absent).
17 | 
18 |     Returns:
19 |         tuple: ``(categories, models_data)`` where ``models_data`` maps a
20 |         model name to ``{category: score}``, or ``(None, {})`` when no
21 |         result file is found.
22 |     """
23 |     print(f'Scanning directory: {results_dir}')
24 |     # Sorted so the category order and the legend order do not depend on
25 |     # the filesystem's directory-listing order.
26 |     result_files = sorted(
27 |         glob.glob(os.path.join(results_dir, '*_results.json')))
28 |     print(f'Found files: {result_files}')
29 | 
30 |     if not result_files:
31 |         print(f"No evaluation result files found in '{results_dir}'.")
32 |         return None, {}
33 | 
34 |     print(f'Loading first file to determine categories: {result_files[0]}')
35 |     with open(result_files[0], 'r', encoding='utf-8') as f:
36 |         first_data = json.load(f)
37 |     categories = list(first_data.get('weighted_summary', {}).keys())
38 | 
39 |     suffix = '_results'
40 |     models_data = {}
41 |     for file_path in result_files:
42 |         stem = os.path.splitext(os.path.basename(file_path))[0]
43 |         # Strip only the trailing suffix; str.replace() would also delete
44 |         # '_results' occurring elsewhere in the model name.
45 |         model_name = stem[:-len(suffix)] if stem.endswith(suffix) else stem
46 |         print(f'Processing file: {file_path} (Model: {model_name})')
47 |         with open(file_path, 'r', encoding='utf-8') as f:
48 |             data = json.load(f)
49 |         weighted_summary = data.get('weighted_summary', {})
50 |         models_data[model_name] = {
51 |             task: weighted_summary.get(task, {}).get('weighted_average', 0)
52 |             for task in categories
53 |         }
54 |     return categories, models_data
55 | 
56 | 
57 | def create_radar_chart(categories, models_data, output_file='radar_chart.pdf'):
58 |     """Draw all models on one radar chart and save it as a PDF.
59 | 
60 |     Each task axis is min-max normalized across models independently: the
61 |     best model on a task is plotted at 1 and the worst at 0. A task on
62 |     which every model has the same score is plotted at 0 for all of them.
63 |     """
64 |     if not categories or not models_data:
65 |         print('No data to plot.')
66 |         return
67 | 
68 |     print(f'Categories for radar chart: {categories}')
69 |     print(f'Models data for plotting: {json.dumps(models_data, indent=4)}')
70 | 
71 |     num_axes = len(categories)
72 |     angles = [n / float(num_axes) * 2 * pi for n in range(num_axes)]
73 |     angles += angles[:1]  # repeat the first angle to close the polygon
74 | 
75 |     task_ranges = {}
76 |     for task in categories:
77 |         task_values = [scores[task] for scores in models_data.values()]
78 |         min_val = min(task_values)
79 |         max_val = max(task_values)
80 |         if max_val <= min_val:
81 |             # Every model ties: widen the range to avoid dividing by zero.
82 |             max_val = min_val + 1
83 |         task_ranges[task] = (min_val, max_val)
84 |         print(f'Task: {task}, Min: {min_val}, Max: {max_val}')
85 | 
86 |     fig, ax = plt.subplots(figsize=(12, 9), subplot_kw=dict(polar=True))
87 |     try:
88 |         for model, scores in models_data.items():
89 |             values = [(scores[task] - task_ranges[task][0]) /
90 |                       (task_ranges[task][1] - task_ranges[task][0])
91 |                       for task in categories]
92 |             values += values[:1]
93 |             print(f'Normalized values for {model}: {values[:-1]}')
94 | 
95 |             ax.plot(angles, values, linewidth=1, linestyle='solid',
96 |                     label=model)
97 |             ax.fill(angles, values, alpha=0.1)
98 | 
99 |         ax.set_xticks(angles[:-1])
100 |         ax.set_xticklabels(categories, fontsize=18, rotation=90)
101 |         # Format the radial ticks through a formatter rather than
102 |         # set_yticklabels(): fixed labels without a fixed locator can end
103 |         # up on the wrong ticks (Matplotlib warns about this).
104 |         ax.yaxis.set_major_formatter(FormatStrFormatter('%.2f'))
105 |         ax.tick_params(axis='y', labelsize=10)
106 |         ax.legend(loc='upper right', bbox_to_anchor=(1.3, 1))
107 |         fig.savefig(output_file, format='pdf', bbox_inches='tight')
108 |         print(f"Radar chart saved to '{output_file}'")
109 |     finally:
110 |         # Close this figure specifically, even if plotting or saving fails.
111 |         plt.close(fig)
112 | 
113 | 
114 | def main():
115 |     results_dir = 'evaluation_results'
116 |     categories, models_data = load_evaluation_results(results_dir)
117 |     if models_data:
118 |         create_radar_chart(categories, models_data)
119 |     else:
120 |         print('Failed to load data; radar chart not generated.')
121 | 
122 | 
123 | if __name__ == '__main__':
124 |     main()
125 | 
--------------------------------------------------------------------------------
/playground/games/minesweeper/minesweeper_ui.py:
--------------------------------------------------------------------------------
1 | import string
2 | 
3 | from PyQt5.QtCore import QSize, Qt
4 | from PyQt5.QtGui import QBrush, QIcon, QPainter, QPalette, QPen, QPixmap
5 | from PyQt5.QtWidgets import (QGridLayout, QHBoxLayout, QLabel, QPushButton,
6 |                              QWidget)
7 | 
8 | from playground.games.minesweeper.game_cfg import (IMG_BOMB, IMG_CLOCK,
9 |                                                    NUM_COLORS, STATUS_ICONS)
10 | from playground.state_code import GameStatus
11 | 
12 | 
13 | class MinesweeperUI:
14 |     """Widget tree of the Minesweeper window.
15 | 
16 |     Builds a header row (bomb icon, mines label, status button, clock
17 |     label, clock icon) above a ``b_size`` x ``b_size`` grid of ``Pos``
18 |     cells. Columns are labelled 1..b_size and rows with lowercase letters,
19 |     so ``b_size`` must not exceed 26.
20 |     """
21 | 
22 |     def __init__(self, parent, b_size):
23 |         self.centralwidget = QWidget(parent)
24 |         self.centralwidget.setObjectName('centralwidget')
25 |         self.b_size = b_size
26 | 
27 |         self.gridLayout = QGridLayout(self.centralwidget)
28 |         self.headerLayout = QHBoxLayout()
29 |         self.minesLabel = QLabel(self.centralwidget)
30 |         self.clockLabel = QLabel(self.centralwidget)
31 |         self.statusButton = QPushButton(self.centralwidget)
32 | 
33 |         bomb_icon = self._make_icon_label(IMG_BOMB)
34 |         clock_icon = self._make_icon_label(IMG_CLOCK)
35 | 
36 |         self.headerLayout.addWidget(bomb_icon)
37 |         self.headerLayout.addWidget(self.minesLabel)
38 |         self.headerLayout.addWidget(self.statusButton)
39 |         self.headerLayout.addWidget(self.clockLabel)
40 |         self.headerLayout.addWidget(clock_icon)
41 | 
42 |         self.gridLayout.addLayout(self.headerLayout, 0, 0, 1, 1)
43 |         self.gameGrid = QGridLayout()
44 |         self.gridLayout.addLayout(self.gameGrid, 1, 0, 1, 1)
45 | 
46 |         self.statusButton.setIcon(QIcon(STATUS_ICONS[GameStatus.IN_PROGRESS]))
47 |         self.statusButton.setIconSize(QSize(32, 32))
48 |         self.statusButton.setFlat(True)
49 | 
50 |         # Row 0 holds column numbers and column 0 holds row letters; board
51 |         # cell (x, y) sits at grid row y + 1, grid column x + 1.
52 |         for x in range(1, self.b_size + 1):
53 |             label = QLabel(str(x))
54 |             label.setAlignment(Qt.AlignHCenter | Qt.AlignVCenter)
55 |             self.gameGrid.addWidget(label, 0, x)
56 |         for y in range(1, self.b_size + 1):
57 |             label = QLabel(string.ascii_lowercase[y - 1])
58 |             label.setAlignment(Qt.AlignHCenter | Qt.AlignVCenter)
59 |             self.gameGrid.addWidget(label, y, 0)
60 |         for x in range(1, self.b_size + 1):
61 |             for y in range(1, self.b_size + 1):
62 |                 self.gameGrid.addWidget(Pos(x - 1, y - 1), y, x)
63 | 
64 |     def _make_icon_label(self, image):
65 |         """Return a 32x32 label showing ``image`` scaled to fit.
66 | 
67 |         ``image`` is passed to ``QPixmap.fromImage``, so it must be a QImage.
68 |         """
69 |         icon = QLabel(self.centralwidget)
70 |         icon.setPixmap(QPixmap.fromImage(image))
71 |         icon.setFixedSize(QSize(32, 32))
72 |         icon.setScaledContents(True)
73 |         return icon
74 | 
75 | 
76 | class Pos(QWidget):
77 |     """One 20x20 board cell at zero-based board coordinates (x, y)."""
78 | 
79 |     def __init__(self, x, y):
80 |         super().__init__()
81 |         self.setFixedSize(QSize(20, 20))
82 |         # NOTE(review): these attributes shadow QWidget.x()/QWidget.y() on
83 |         # the Python side; kept because other code may read pos.x / pos.y.
84 |         self.x = x
85 |         self.y = y
86 |         self.is_mine = False
87 |         self.adjacent_n = 0
88 |         self.is_revealed = False
89 | 
90 |     def paintEvent(self, event):
91 |         """Draw the cell.
92 | 
93 |         Hidden cells are light grey with a grey border. Revealed cells use
94 |         the widget's window colour and show the bomb image if mined, or the
95 |         adjacent-mine count (coloured via NUM_COLORS) when it is positive.
96 |         """
97 |         p = QPainter(self)
98 |         p.setRenderHint(QPainter.Antialiasing)
99 |         r = event.rect()
100 |         if self.is_revealed:
101 |             # QPalette.Window is the non-obsolete name for Background.
102 |             color = self.palette().color(QPalette.Window)
103 |             outer, inner = color, color
104 |         else:
105 |             outer, inner = Qt.gray, Qt.lightGray
106 |         p.fillRect(r, QBrush(inner))
107 |         pen = QPen(outer)
108 |         pen.setWidth(1)
109 |         p.setPen(pen)
110 |         p.drawRect(r)
111 |         if self.is_revealed:
112 |             if self.is_mine:
113 |                 p.drawPixmap(r, QPixmap.fromImage(IMG_BOMB))
114 |             elif self.adjacent_n > 0:
115 |                 pen = QPen(NUM_COLORS[self.adjacent_n])
116 |                 p.setPen(pen)
117 |                 f = p.font()
118 |                 f.setBold(True)
119 |                 p.setFont(f)
120 |                 p.drawText(r, Qt.AlignHCenter | Qt.AlignVCenter,
121 |                            str(self.adjacent_n))
122 |         p.end()
123 | 
--------------------------------------------------------------------------------
/configs/games/sudoku.py:
--------------------------------------------------------------------------------
1 | from playground.games import SudokuQuestionAnswering
2 | 
3 | _base_ = ['configs/base.py']
4 | 
5 | game_name = 'sudoku'
6 | game_description = dict(
7 |     e2e=(
8 |         'Sudoku is a logic-based puzzle played on a 9x9 grid, subdivided into '
9 |         'nine 3x3 subgrids. The goal is to fill the grid so that each row, '
10 |         'column, and 3x3 subgrid contains all digits from 1 to 9 without '
11 |         'repetition. The grid starts with some cells filled with numbers '
12 |         '(clues), and you must fill in the remaining empty cells one at a '
13 |         'time. Rows are labeled A to I (top to bottom), and columns are '
14 |         'numbered 1 to 9 (left to right). Each move involves placing a digit '
15 |         '(1-9) in an empty cell, ensuring no repetition in its row, column, '
16 |         'or 3x3 subgrid. Based on the current board state screenshot, observe '
17 |         'the situation, formulate a strategy, and output a single valid move '
18 |         'to place a digit.\nPlease strictly follow the format:\n'
19 |         'Observation: \n'
20 |         'Strategy: \n'
21 |         'Movement: \n'
22 |         'where is A to I, is 1 to 9, and is 1 to 9. '
23 |         'For example, "A1 5" places 5 in the top-left cell.'
24 |     ),
25 |     perceive=(
26 |         'Sudoku is a logic-based puzzle played on a 9x9 grid, where the grid '
27 |         'is subdivided into nine 3x3 subgrids. The goal is to fill the grid '
28 |         'so that each row, each column, and each 3x3 subgrid contains all '
29 |         'digits from 1 to 9 without repetition. Given a screenshot of the '
30 |         'Sudoku grid, please represent the current state of the puzzle using '
31 |         'a 9x9 matrix. In this matrix, an empty cell should be represented by '
32 |         '0, and filled cells should contain their respective numbers (1-9). '
33 |         'Please strictly follow the format:\n'
34 |         'Game State: \n'
35 |         'where is a 9x9 matrix. For example,\n'
36 |         'Game State: [[5, 3, 0, 0, 7, 0, 0, 0, 0], '
37 |         '[6, 0, 0, 1, 9, 5, 0, 0, 0], [0, 9, 8, 0, 0, 0, 0, 6, 0], '
38 |         '[8, 0, 0, 0, 6, 0, 0, 0, 3], [4, 0, 0, 8, 0, 3, 0, 0, 1], '
39 |         '[7, 0, 0, 0, 2, 0, 0, 0, 6], [0, 6, 0, 0, 0, 0, 2, 8, 0], '
40 |         '[0, 0, 0, 4, 1, 9, 0, 0, 5], [0, 0, 0, 0, 8, 0, 0, 7, 9]]\n'
41 |         'represents a partially filled Sudoku grid with some cells empty '
42 |         '(represented by 0).'
43 |     ),
44 |     qa=(
45 |         'Sudoku is played on a 9x9 grid, where each row, column, and 3x3 '
46 |         'subgrid must contain the numbers 1 to 9 exactly once. Please answer '
47 |         'the following question based on the provided screenshot of the '
48 |         'current game state:\n'
49 |         '{question}\n'
50 |         'Answer: \n'
51 |         'where should be one of A, B, C, or D.'
52 |     ),
53 |     rule=(
54 |         'Sudoku is played on a 9x9 grid, where each row, column, and 3x3 '
55 |         'subgrid must contain the numbers 1 to 9 exactly once. The grid '
56 |         'starts with the top-left corner as A1, where rows are labeled from '
57 |         'A to I and columns are numbered from 1 to 9. A valid move involves '
58 |         'placing a digit from 1 to 9 in an empty cell, ensuring that the '
59 |         'number does not already appear in the same row, column, or 3x3 '
60 |         'subgrid. Based on the current state of the Sudoku grid, please find '
61 |         'a valid empty cell where you can place a digit and make a valid '
62 |         'move.\nPlease strictly follow the format:\n'
63 |         'Movement: \n'
64 |         'where is A to I (representing rows), is 1 to 9 '
65 |         '(representing columns), and is a number between 1 to 9. For '
66 |         'example, A1 5 means placing the digit 5 in the top-left corner of '
67 |         'the grid.'
68 |     )
69 | )
70 | 
71 | player_first = True
72 | qa = SudokuQuestionAnswering
73 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | 
6 | # C extensions
7 | *.so
8 | 
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 | 
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 | 
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 | 
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 | 
54 | # Translations
55 | *.mo
56 | *.pot
57 | 
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 | 
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 | 
68 | # Scrapy stuff:
69 | .scrapy
70 | 
71 | # Sphinx documentation
72 | docs/_build/
73 | 
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 | 
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 | 
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 | 
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 | 
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 | 
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 | 
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
110 | .pdm.toml
111 | .pdm-python
112 | .pdm-build/
113 | 
114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
115 | __pypackages__/
116 | 
117 | # Celery stuff
118 | celerybeat-schedule
119 | celerybeat.pid
120 | 
121 | # SageMath parsed files
122 | *.sage.py
123 | 
124 | # Environments
125 | .env
126 | .venv
127 | env/
128 | venv/
129 | ENV/
130 | env.bak/
131 | venv.bak/
132 | 
133 | # Spyder project settings
134 | .spyderproject
135 | .spyproject
136 | 
137 | # Rope project settings
138 | .ropeproject
139 | 
140 | # mkdocs documentation
141 | /site
142 | 
143 | # mypy
144 | .mypy_cache/
145 | .dmypy.json
146 | dmypy.json
147 | 
148 | # Pyre type checker
149 | .pyre/
150 | 
151 | # pytype static type analyzer
152 | .pytype/
153 | 
154 | # Cython debug symbols
155 | cython_debug/
156 | 
157 | # PyCharm
158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
160 | # and can be added to the global gitignore or merged into this file. For a more nuclear
161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
162 | #.idea/
163 | game_history
164 | helpers
165 | experiments
166 | archival_experiment_results
167 | rebuttal-exp
168 | benchmark/
169 | evaluation_results
170 | .vscode
171 | 
--------------------------------------------------------------------------------
/configs/games/minesweeper.py:
--------------------------------------------------------------------------------
1 | from playground.games import MinesweeperQuestionAnswering
2 | 
3 | _base_ = ['configs/base.py']
4 | 
5 | game_name = 'minesweeper'
6 | game_description = dict(
7 |     e2e=(
8 |         'Minesweeper is a logic-based puzzle game played on an 8x8 grid with '
9 |         'exactly 10 mines. Each cell can either contain a mine, be '
10 |         'unrevealed, or show the number of adjacent mines (from 0 to 8). The '
11 |         'goal is to reveal all cells that do not contain mines without '
12 |         'triggering any mines. You win when only the 10 mine cells remain '
13 |         'unrevealed. The grid is labeled with rows A to H (top to bottom) '
14 |         'and columns 1 to 8 (left to right). You can only reveal cells (e.g., '
15 |         '"A1"), and flagging is not allowed. Based on the current board state '
16 |         'screenshot, please observe the situation, formulate a strategy, and '
17 |         'output a move to reveal a cell. Please strictly follow the format:\n'
18 |         'Observation: \n'
19 |         'Strategy: \n'
20 |         'Movement: \n'
21 |         'where summarizes the current board state, '
22 |         ' explains your reasoning, and is a move in the '
23 |         'format "A1".'
24 |     ),
25 |     perceive=(
26 |         'Minesweeper is a logic-based puzzle game played on an 8x8 grid with '
27 |         'exactly 10 mines. Each cell can either contain a mine (represented '
28 |         'by 9), or it can be empty. Unrevealed cells should be represented by '
29 |         '-1. Cells that are revealed and contain no adjacent mines are '
30 |         'represented by 0, while cells that are revealed and show a number '
31 |         'from 1 to 8 indicate how many adjacent mines surround the cell. '
32 |         'Mines are represented by the number 9.\nPlease strictly follow the '
33 |         'format:\n'
34 |         'Game State: \n'
35 |         'where is an 8x8 matrix representing the game grid, '
36 |         'with unrevealed cells as -1, mines as 9, and numbers from 0 to 8 '
37 |         'indicating the number of adjacent mines. For example,\n'
38 |         'Game State: [[-1, 1, 1, 2, -1, -1, -1, 1], '
39 |         '[1, 2, 3, 2, -1, 1, 1, 2], '
40 |         '[2, 3, 4, 3, 1, 1, 1, 2], '
41 |         '[1, 2, 2, 2, 2, 2, 2, 1], '
42 |         '[1, 2, 2, 2, 2, 9, 1, 0], '
43 |         '[-1, 1, 2, 3, 2, 1, 0, -1], '
44 |         '[-1, -1, 1, 2, 3, 1, -1, -1], '
45 |         '[1, 2, 9, 1, 1, -1, -1, -1]]\n'
46 |         'This example represents a grid where some cells have been uncovered, '
47 |         'showing numbers indicating adjacent mines, unrevealed cells are '
48 |         'represented by -1, and mines are represented by the number 9.'
49 |     ),
50 |     qa=(
51 |         'Minesweeper is played on an 8x8 grid with exactly 10 mines. Each '
52 |         'cell can either contain a mine, be unrevealed, or show the number of '
53 |         'adjacent mines (from 0 to 8). The grid is labeled with rows A to H '
54 |         '(top to bottom) and columns 1 to 8 (left to right). Please answer '
55 |         'the following question based on the provided screenshot of the '
56 |         'current game state:\n'
57 |         '{question}\n'
58 |         'Answer: \n'
59 |         'where should be one of A, B, C, or D.'
60 |     ),
61 |     rule=(
62 |         'Minesweeper is played on an 8x8 grid with exactly 10 mines. Each '
63 |         'cell can either contain a mine, be unrevealed, or show the number of '
64 |         'adjacent mines (from 0 to 8). The goal is to reveal all cells that '
65 |         'do not contain mines, without triggering any mines. You win when '
66 |         'only the 10 mine cells remain unrevealed. The grid is labeled with '
67 |         'rows A to H (top to bottom) and columns 1 to 8 (left to right). Each '
68 |         'cell can only be revealed once, and flagging is not allowed. Based '
69 |         'on the current board state image, please find a cell to reveal '
70 |         'next.\nPlease strictly follow the following format:\n'
71 |         'Movement: \n'
72 |         'where the position can be any combination of rows A to H and columns '
73 |         '1 to 8, for example, A1, B5, or H8.'
74 |     )
75 | )
76 | 
77 | level = 'easy'
78 | qa = MinesweeperQuestionAnswering
79 | 
--------------------------------------------------------------------------------
/configs/games/gomoku.py:
--------------------------------------------------------------------------------
1 | from playground.games import GomokuQuestionAnswering
2 | 
3 | _base_ = ['configs/base.py']
4 | 
5 | game_name = 'gomoku'
6 | game_description = dict(
7 |     e2e=(
8 |         'Gomoku is played on a 15x15 grid, where players take turns placing '
9 |         'black or white stones on the intersections. The goal is to be the '
10 |         'first to form an unbroken line of five stones horizontally, '
11 |         'vertically, or diagonally. The game starts with an empty board, and '
12 |         'the black player (you) goes first, followed by the white player '
13 |         '(AI). The grid is labeled with columns A to O (left to right) and '
14 |         'rows 1 to 15 (top to bottom). You are playing as black, aiming to '
15 |         'win by placing stones strategically. Each intersection can only be '
16 |         'occupied by one stone, so do not choose a spot that is already '
17 |         'taken. Based on the board state screenshots, please first observe '
18 |         'the current situation, then carefully think and explain your '
19 |         'strategy briefly, and finally output your movement for this status. '
20 |         'Please strictly follow the following format:\n'
21 |         'Observation: \n'
22 |         'Strategy: \n'
23 |         'Movement: \n'
24 |         'where the observation should briefly summarize the current '
25 |         'situation, the strategy is a brief explanation of how you plan to '
26 |         'win the game, and the position can be any combination of columns A '
27 |         'to O and rows 1 to 15, for example, A1, H8, or O15.'
28 |     ),
29 |     perceive=(
30 |         'Gomoku is a game played on a 15x15 grid where players take turns '
31 |         'placing black or white stones on the intersections. Given a '
32 |         'screenshot of the Gomoku board, please determine the current game '
33 |         'state using a 15x15 matrix. In this matrix, an empty intersection '
34 |         'should be represented by 0, a black stone by 1, and a white stone by '
35 |         '2. Please strictly follow the format:\n'
36 |         'Game State: \n'
37 |         'where is a 15x15 matrix. For example,\n'
38 |         'Game State: [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], '
39 |         '[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], '
40 |         '[0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0], '
41 |         '[0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0], '
42 |         '[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], '
43 |         '[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], '
44 |         '[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], '
45 |         '[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], '
46 |         '[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], '
47 |         '[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], '
48 |         '[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], '
49 |         '[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], '
50 |         '[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], '
51 |         '[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], '
52 |         '[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]\n'
53 |         'represents a partially filled Gomoku board.'
54 |     ),
55 |     qa=(
56 |         'Gomoku is played on a 15x15 grid, where black and white stones are '
57 |         'placed in turns. The goal is to place five consecutive stones in a '
58 |         'horizontal, vertical, or diagonal line. Please answer the following '
59 |         'question based on the provided screenshot of the current game '
60 |         'state:\n'
61 |         '{question}\n'
62 |         'Answer: \n'
63 |         'where should be one of A, B, C, or D.'
64 |     ),
65 |     rule=(
66 |         'Gomoku is played on a 15x15 grid, where black and white stones are '
67 |         'placed on the intersections of the grid. The objective is to place '
68 |         'five consecutive stones in a horizontal, vertical, or diagonal line. '
69 |         'The game starts with an empty board. The grid is labeled with '
70 |         'columns A to O (left to right) and rows 1 to 15 (top to bottom). '
71 |         'Each intersection can only be occupied by one stone, either black or '
72 |         'white. Based on the board state, please find an empty intersection '
73 |         'where you can place your next stone.\n'
74 |         'Please strictly follow the following format:\n'
75 |         'Movement: \n'
76 |         'where the position can be any combination of columns A to O and rows '
77 |         '1 to 15, for example, A1, H8, or O15.'
78 |     )
79 | )
80 | 
81 | chessboard_size = 15
82 | player_first = True
83 | qa = GomokuQuestionAnswering
84 | 
--------------------------------------------------------------------------------
/playground/agents/single_step_agents.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pathlib
3 | 
4 | import anthropic
5 | import google.generativeai as genai
6 | import requests
7 | from lmdeploy import pipeline
8 | from lmdeploy.vl import load_image
9 | 
10 | from playground.agents import BaseAgent
11 | from playground.registry import AGENT_REGISTRY
12 | from playground.utils import encode_image
13 | 
14 | 
15 | @AGENT_REGISTRY.register('openai_single')
16 | class OpenAIAgentSingleStep(BaseAgent):
17 |     """Single-turn agent backed by the OpenAI Chat Completions REST API.
18 | 
19 |     The API key is read from the ``OPENAI_API_KEY`` environment variable.
20 |     """
21 | 
22 |     # Seconds to wait for the HTTP response. Without a timeout a stalled
23 |     # connection blocks get_decision() forever; override on a subclass or
24 |     # instance if a model needs longer.
25 |     request_timeout = 600
26 | 
27 |     def __init__(self, agent_cfg):
28 |         super().__init__(agent_cfg)
29 |         api_key = os.getenv('OPENAI_API_KEY')
30 |         self.headers = {
31 |             'Content-Type': 'application/json',
32 |             'Authorization': f'Bearer {api_key}'
33 |         }
34 |         self.base_payload = {
35 |             'model': agent_cfg.lmm_agent.model,
36 |             'max_tokens': agent_cfg.lmm_agent.max_tokens
37 |         }
38 |         self.input_sz = agent_cfg.lmm_agent.image_size
39 | 
40 |     def get_decision(self, screenshot_path: str, prompt: str):
41 |         """Send ``prompt`` and the screenshot; return the reply text.
42 | 
43 |         Raises:
44 |             requests.Timeout: if no response arrives within
45 |                 ``request_timeout`` seconds.
46 |             KeyError: if the response has no ``choices`` (e.g. an API error
47 |                 payload); the message includes the HTTP status and body.
48 |         """
49 |         base64_image = encode_image(screenshot_path, self.input_sz)
50 |         payload = self.base_payload.copy()
51 |         # NOTE(review): the data URL declares image/jpeg, while the
52 |         # Anthropic agent declares image/png for the same encode_image()
53 |         # output; confirm which format encode_image() produces.
54 |         payload['messages'] = [{
55 |             'role':
56 |             'user',
57 |             'content': [{
58 |                 'type': 'text',
59 |                 'text': prompt
60 |             }, {
61 |                 'type': 'image_url',
62 |                 'image_url': {
63 |                     'url': f'data:image/jpeg;base64,{base64_image}'
64 |                 }
65 |             }]
66 |         }]
67 |         response = requests.post(
68 |             'https://api.openai.com/v1/chat/completions',
69 |             headers=self.headers,
70 |             json=payload,
71 |             timeout=self.request_timeout)
72 |         outputs = response.json()
73 |         if 'choices' not in outputs:
74 |             # Surface the API's error payload instead of a bare
75 |             # KeyError('choices'); the exception type stays the same.
76 |             raise KeyError(f"'choices' missing from OpenAI response "
77 |                            f'(HTTP {response.status_code}): {outputs}')
78 |         return outputs['choices'][0]['message']['content']
79 | 
80 | 
81 | @AGENT_REGISTRY.register('google_single')
82 | class GoogleAIAgentSingleStep(BaseAgent):
83 |     """Single-turn agent backed by the ``google.generativeai`` SDK.
84 | 
85 |     The API key is read from the ``GOOGLE_API_KEY`` environment variable.
86 |     """
87 | 
88 |     def __init__(self, agent_cfg):
89 |         super().__init__(agent_cfg)
90 |         self.api_key = os.getenv('GOOGLE_API_KEY')
91 |         genai.configure(api_key=self.api_key)
92 |         self.model = genai.GenerativeModel(
93 |             model_name=agent_cfg.lmm_agent.model)
94 | 
95 |     def get_decision(self, screenshot_path: str, prompt: str):
96 |         """Send ``prompt`` and the raw screenshot bytes; return the reply."""
97 |         image = {
98 |             'mime_type': 'image/png',
99 |             'data': pathlib.Path(screenshot_path).read_bytes()
100 |         }
101 |         outputs = self.model.generate_content([prompt, image])
102 |         return outputs.text
103 | 
104 | 
105 | # NOTE(review): the registry key is spelled 'anhthropic_single'. It is left
106 | # unchanged because configs resolve the agent by this exact string and
107 | # renaming it could break them.
108 | @AGENT_REGISTRY.register('anhthropic_single')
109 | class AnthropicAgentSingleStep(BaseAgent):
110 |     """Single-turn agent backed by the Anthropic Messages API.
111 | 
112 |     No API key is passed to ``anthropic.Anthropic()`` explicitly.
113 |     """
114 | 
115 |     def __init__(self, agent_cfg):
116 |         super().__init__(agent_cfg)
117 |         self.base_payload = {
118 |             'model': agent_cfg.lmm_agent.model,
119 |             'max_tokens': agent_cfg.lmm_agent.max_tokens
120 |         }
121 |         self.input_sz = agent_cfg.lmm_agent.image_size
122 |         self.model = anthropic.Anthropic()
123 | 
124 |     def get_decision(self, screenshot_path: str, prompt: str):
125 |         """Send the screenshot followed by ``prompt``; return the reply."""
126 |         base64_image = encode_image(screenshot_path, self.input_sz)
127 |         payload = self.base_payload.copy()
128 |         payload['messages'] = [{
129 |             'role':
130 |             'user',
131 |             'content': [{
132 |                 'type': 'image',
133 |                 'source': {
134 |                     'type': 'base64',
135 |                     'media_type': 'image/png',
136 |                     'data': base64_image
137 |                 }
138 |             }, {
139 |                 'type': 'text',
140 |                 'text': prompt
141 |             }]
142 |         }]
143 |         outputs = self.model.messages.create(**payload)
144 |         return outputs.content[0].text
145 | 
146 | 
147 | @AGENT_REGISTRY.register('lmdeploy_single')
148 | class LMDeployAgentSingleStep(BaseAgent):
149 |     """Single-turn agent running a vision-language model through lmdeploy."""
150 | 
151 |     def __init__(self, agent_cfg):
152 |         super().__init__(agent_cfg)
153 |         self.is_deepseek_vl = agent_cfg.lmm_agent.name == 'deepseek-vl-7b'
154 |         self.model = pipeline(
155 |             agent_cfg.lmm_agent.model,
156 |             backend_config=agent_cfg.lmm_agent.backend_config)
157 |         self.gen_config = agent_cfg.lmm_agent.general_config
158 | 
159 |     def get_decision(self, screenshot_path: str, prompt: str):
160 |         """Run the pipeline on (prompt, screenshot); return the reply text."""
161 |         image = load_image(screenshot_path)
162 |         if self.is_deepseek_vl:
163 |             # NOTE(review): as written this prepends an empty string, i.e.
164 |             # a no-op; a model-specific image placeholder token was
165 |             # presumably intended here. Verify against the repository.
166 |             prompt = '' + prompt
167 |         outputs = self.model((prompt, image), gen_config=self.gen_config)
168 |         return outputs.text
169 | 
--------------------------------------------------------------------------------
/configs/games/chess.py:
--------------------------------------------------------------------------------
1 | from playground.games import ChessQuestionAnswering
2 | 
3 | _base_ = ['configs/base.py']
4 | 
5 | game_name = 'chess'
6 | game_description = dict(
7 |     e2e=(
8 |         'Chess is a strategy game played on an 8x8 board with 64 squares, '
9 |         'using six types of pieces: pawns, knights, bishops, rooks, queens, '
10 |         'and kings, for both white and black players. The game starts with a '
11 |         'standard initial position: white pieces on ranks 1 and 2, black '
12 |         'pieces on ranks 7 and 8. The board uses algebraic coordinates with '
13 |         'files labeled "a" through "h" from left to right and ranks labeled '
14 |         '"1" through "8" from bottom to top (a1 at bottom-left, h8 at '
15 |         'top-right). White moves first, followed by Black. You are playing as '
16 |         'White, aiming to checkmate the Black king or achieve a favorable '
17 |         'position. Each move must follow standard chess rules and be '
18 |         'expressed in Standard Algebraic Notation (SAN), such as "e4" (pawn '
19 |         'to e4), "Nf3" (knight to f3), or "O-O" (kingside castling). Based on '
20 |         'the board state screenshots, please first observe the current '
21 |         'situation, then carefully think and explain your strategy briefly, '
22 |         'and finally output your movement for this status. Please strictly '
23 |         'follow the following format:\n'
24 |         'Observation: \n'
25 |         'Strategy: \n'
26 |         'Movement: \n'
27 |         'where the observation should briefly summarize the current '
28 |         'situation, the strategy is a brief explanation of how you plan to '
29 |         'win or improve your position, and the position is a legal move in '
30 |         'SAN, for example, "e4", "Nf3", or "O-O".'
31 |     ),
32 |     perceive=(
33 |         'Chess is a strategy game played on an 8x8 board with 64 squares, '
34 |         'using six types of pieces: pawns, knights, bishops, rooks, queens, '
35 |         'and kings, for both white and black players. You are provided with '
36 |         'an image of a chessboard, and your task is to represent the current '
37 |         'state of the game as an 8x8 matrix using the specified numerical '
38 |         'format. Each type of chess piece, both black and white, is '
39 |         'represented by a unique number:\n- Empty squares: 0\n'
40 |         '- White pieces: Pawn=1, Knight=2, Bishop=3, Rook=4, Queen=5, King=6\n'
41 |         '- Black pieces: Pawn=-1, Knight=-2, Bishop=-3, Rook=-4, Queen=-5, '
42 |         'King=-6\n\nFrom the provided chessboard image, convert the visible '
43 |         'board into this 8x8 matrix format. For example, the initial chess '
44 |         'position would be represented as:\n'
45 |         'Game State: [[-4, -2, -3, -5, -6, -3, -2, -4],\n'
46 |         '[-1, -1, -1, -1, -1, -1, -1, -1],\n'
47 |         '[0, 0, 0, 0, 0, 0, 0, 0],\n'
48 |         '[0, 0, 0, 0, 0, 0, 0, 0],\n'
49 |         '[0, 0, 0, 0, 0, 0, 0, 0],\n'
50 |         '[0, 0, 0, 0, 0, 0, 0, 0],\n'
51 |         '[1, 1, 1, 1, 1, 1, 1, 1],\n'
52 |         '[4, 2, 3, 5, 6, 3, 2, 4]]\n\n'
53 |         'Ensure that your output strictly follows this matrix format with no '
54 |         'deviations, based on the pieces shown in the image.'
55 |     ),
56 |     qa=(
57 |         'Chess is a strategy game played on an 8x8 board with 64 squares, '
58 |         'using six types of pieces: pawns, knights, bishops, rooks, queens, '
59 |         'and kings, for both white and black players. The board uses a '
60 |         'coordinate system where columns are labeled "a" through "h" from '
61 |         'left to right, and rows are labeled "1" through "8" from bottom to '
62 |         'top (a1 at bottom-left, h8 at top-right). Please answer the '
63 |         'following question based on the provided screenshot of the current '
64 |         'game state:\n'
65 |         '{question}\n'
66 |         'Answer: \n'
67 |         'where should be one of A, B, C, or D.'
68 |     ),
69 |     rule=(
70 |         'Chess is played on an 8x8 board following standard chess rules. Each '
71 |         'piece moves according to its unique capabilities. The board uses '
72 |         'algebraic coordinates with files labeled "a" through "h" from left '
73 |         'to right and ranks labeled "1" through "8" from bottom to top (a1 at '
74 |         'bottom-left, h8 at top-right). White moves first, followed by Black. '
75 |         'Based on the current board state image, please choose one legal move '
76 |         'for White and output it using Standard Algebraic Notation (SAN). For '
77 |         'example, if White’s pawn on e2 can move to e4, your answer should be '
78 |         '"e4"; if White’s knight on g1 can move to f3, your answer should be '
79 |         '"Nf3".\nPlease strictly follow the following format:\n'
80 |         'Movement: \n'
81 |         'where is the move in SAN (e.g., "e4", "Nf3", "O-O").'
82 |     )
83 | )
84 | 
85 | player_first = True
86 | user_is_white = True
87 | qa = ChessQuestionAnswering
88 | 
--------------------------------------------------------------------------------
/configs/games/reversi.py:
--------------------------------------------------------------------------------
1 | from playground.games import ReversiQuestionAnswering
2 | 
3 | _base_ = ['configs/base.py']
4 | 
5 | game_name = 'reversi'
6 | # NOTE(review): the e2e prompt puts black on D4/E5 and white on D5/E4, but
7 | # the perceive example matrix (rows A-H top to bottom, columns 1-8) has
8 | # white on D4/E5 and black on D5/E4. One of them disagrees with the board
9 | # rendered by the game; confirm against reversi.py and align the prompts.
10 | game_description = dict(
11 |     e2e=(
12 |         'Reversi (also known as Othello) is a strategy board game played on '
13 |         'an 8x8 grid. Players take turns placing black and white pieces on '
14 |         'the board. The goal is to have more pieces of your color on the '
15 |         'board than your opponent by the end of the game. A valid move must '
16 |         'sandwich one or more of the opponent\'s pieces between the newly '
17 |         'placed piece and another of your pieces in a horizontal, vertical, '
18 |         'or diagonal line, flipping the sandwiched pieces to your color. The '
19 |         'game starts with four pieces in the center: two black (at D4 and E5) '
20 |         'and two white (at D5 and E4). The black player (you) goes first, '
21 |         'followed by the white player (AI). The grid is labeled with rows A '
22 |         'to H (top to bottom) and columns 1 to 8 (left to right). You are '
23 |         'playing as black, aiming to maximize your pieces by placing them '
24 |         'strategically. Based on the board state screenshots, please first '
25 |         'observe the current situation, then carefully think and explain your '
26 |         'strategy briefly, and finally output your movement for this status. '
27 |         'Please strictly follow the following format:\n'
28 |         'Observation: \n'
29 |         'Strategy: \n'
30 |         'Movement: \n'
31 |         'where the observation should briefly summarize the current '
32 |         'situation, the strategy is a brief explanation of how you plan to '
33 |         'maximize your pieces, and the position can be any combination of '
34 |         'rows A to H and columns 1 to 8, for example, A1, D4, or H8.'
35 |     ),
36 |     perceive=(
37 |         'Reversi (also known as Othello) is a strategy board game played on '
38 |         'an 8x8 grid, where two players take turns placing black and white '
39 |         'pieces on the board. The goal is to have more pieces of your color '
40 |         'on the board than your opponent by the end of the game. A piece is '
41 |         'placed on an empty square and must sandwich one or more of the '
42 |         'opponent\'s pieces between the newly placed piece and another of the '
43 |         'player\'s pieces in a horizontal, vertical, or diagonal line. The '
44 |         'opponent\'s pieces in between are then flipped to the player\'s '
45 |         'color. Given a screenshot of the Reversi board, please represent the '
46 |         'current state of the game using an 8x8 matrix. In this matrix, empty '
47 |         'cells should be represented by 0, black pieces by 1, and white '
48 |         'pieces by 2. Please strictly follow the format:\n'
49 |         'Game State: \n'
50 |         'where is an 8x8 matrix. For example,\n'
51 |         'Game State: [[0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0], '
52 |         '[0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 2, 1, 0, 0, 0], '
53 |         '[0, 0, 0, 1, 2, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0], '
54 |         '[0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0]]\n'
55 |         'represents a Reversi board with the initial four pieces placed.'
56 |     ),
57 |     qa=(
58 |         'Reversi (also known as Othello) is played on an 8x8 grid where two '
59 |         'players take turns placing black and white pieces on the board. '
60 |         'Please answer the following question based on the provided '
61 |         'screenshot of the current game state:\n'
62 |         '{question}\n'
63 |         'Answer: \n'
64 |         'where should be one of A, B, C, or D.'
65 |     ),
66 |     rule=(
67 |         'Reversi (also known as Othello) is played on an 8x8 grid. Players '
68 |         'take turns placing black and white pieces on the board. The grid '
69 |         'starts with two black pieces and two white pieces in the center '
70 |         '(D4, D5, E4, E5). A valid move consists of placing a piece in such a '
71 |         'way that it sandwiches one or more of the opponent\'s pieces between '
72 |         'the newly placed piece and another of the player\'s pieces in a '
73 |         'horizontal, vertical, or diagonal line. After placing the piece, all '
74 |         'of the opponent\'s pieces in between are flipped to the player\'s '
75 |         'color. The black player (you) goes first, followed by the white '
76 |         'player (AI). The grid is labeled with rows A to H and columns 1 to '
77 |         '8. Based on the current game state, please find a valid position '
78 |         'where you can place your next black piece and flip at least one of '
79 |         'the opponent\'s white pieces.\n'
80 |         'Please strictly follow the format:\n'
81 |         'Movement: \n'
82 |         'where the position can be any combination of rows A to H and columns '
83 |         '1 to 8, for example, A1, D4, or H8.'
84 |     )
85 | )
86 | 
87 | player_first = True
88 | qa = ReversiQuestionAnswering
89 | 
--------------------------------------------------------------------------------
/playground/games/sudoku/sudoku_ui.py:
--------------------------------------------------------------------------------
1 | import string
2 | 
3 | from PyQt5.QtCore import Qt
4 | from PyQt5.QtWidgets import QLabel, QPushButton, QVBoxLayout, QWidget
5 | 
6 | 
7 | class SudokuUI:
8 | 
9 |     def __init__(self, parent):
10 |         self.centralwidget = QWidget(parent)
11 |         self.centralwidget.setObjectName('centralwidget')
12 | 
13 |         layout = QVBoxLayout(self.centralwidget)
14 |         layout.setSpacing(0)
15 |         self.secondPageWidget = QWidget(self.centralwidget)
16 |         layout.addWidget(self.secondPageWidget)
17 | 
18 |         self.centralwidget.setStyleSheet('background-color: white')
19 |         self.secondPageWidget.setStyleSheet(
20 |             'margin: 0; padding: 0px; background-color: white;')
21 | 
22 |         self.setup_second_page()
23 | 
24 |     def setup_second_page(self):
25 |         self.setup_borders()
26 |         self.setup_grid_buttons()
27 |         self.setup_row_labels()
28 |         self.setup_col_labels()
29 | 
30 |     def setup_borders(self):
31 |         self.border_lines = [QWidget(self.secondPageWidget) for _ in range(4)]
32 |         self.border_lines[0].setGeometry(50, 100, 424, 3)
33 |         self.border_lines[1].setGeometry(50, 100, 3, 424)
34 |         self.border_lines[2].setGeometry(470, 100, 3, 424)
35 |         self.border_lines[3].setGeometry(50, 520, 424, 3)
36 |         for line in self.border_lines:
37 |             line.setStyleSheet('background-color: black')
38 | 
39 |         self.main_vertical_lines = [
40 |             QWidget(self.secondPageWidget) for _ in range(2)
41 |         ]
42 | 
self.main_vertical_lines[0].setGeometry(190, 100, 3, 422) 43 | self.main_vertical_lines[1].setGeometry(330, 100, 3, 422) 44 | for line in self.main_vertical_lines: 45 | line.setStyleSheet('background-color: black') 46 | 47 | self.main_horizontal_lines = [ 48 | QWidget(self.secondPageWidget) for _ in range(2) 49 | ] 50 | self.main_horizontal_lines[0].setGeometry(50, 240, 422, 3) 51 | self.main_horizontal_lines[1].setGeometry(50, 380, 422, 3) 52 | for line in self.main_horizontal_lines: 53 | line.setStyleSheet('background-color: black') 54 | 55 | self.internal_vertical_lines = [ 56 | QWidget(self.secondPageWidget) for _ in range(6) 57 | ] 58 | row_gap = 98 59 | for i, line in enumerate(self.internal_vertical_lines): 60 | if i > 0 and i % 2 == 0: 61 | row_gap += 48 62 | line.setGeometry(row_gap, 100, 1, 422) 63 | line.setStyleSheet('background-color: black') 64 | row_gap += 46 65 | 66 | self.internal_horizontal_lines = [ 67 | QWidget(self.secondPageWidget) for _ in range(6) 68 | ] 69 | col_gap = 148 70 | for i, line in enumerate(self.internal_horizontal_lines): 71 | if i > 0 and i % 2 == 0: 72 | col_gap += 48 73 | line.setGeometry(50, col_gap, 422, 1) 74 | line.setStyleSheet('background-color: black') 75 | col_gap += 46 76 | 77 | def setup_grid_buttons(self): 78 | self.puzzle_buttons = [[ 79 | QPushButton('', self.secondPageWidget) for _ in range(9) 80 | ] for _ in range(9)] 81 | gap_col = 0 82 | for j in range(9): 83 | gap_row = 0 84 | if j % 3 == 0: 85 | gap_col += 2 86 | for i in range(9): 87 | if i % 3 == 0: 88 | gap_row += 2 89 | btn = self.puzzle_buttons[j][i] 90 | btn.setGeometry(51 + i * 45 + gap_row, 101 + j * 45 + gap_col, 91 | 45, 45) 92 | btn.setStyleSheet( 93 | 'background-color: white; font-family: sans-serif; font-size: 25px; border: 1px solid black;' # noqa 94 | ) 95 | gap_row += 1 96 | gap_col += 1 97 | 98 | def setup_row_labels(self): 99 | """Setup row labels A-I on the left side of the grid.""" 100 | self.row_labels = [] 101 | for i in range(9): 102 
| label = QLabel(string.ascii_uppercase[i], self.secondPageWidget) 103 | label.setGeometry(30, 101 + i * 45 + (45 - 25) // 2, 20, 25) 104 | label.setStyleSheet( 105 | 'color: black; background-color: white; font-family: sans-serif; font-size: 20px;' # noqa 106 | ) 107 | label.setAlignment(Qt.AlignCenter) 108 | self.row_labels.append(label) 109 | 110 | def setup_col_labels(self): 111 | """Setup column labels 1-9 on the top of the grid.""" 112 | self.col_labels = [] 113 | for i in range(9): 114 | label = QLabel(str(i + 1), self.secondPageWidget) 115 | label.setGeometry(51 + i * 45 + (45 - 25) // 2, 75, 25, 20) 116 | label.setStyleSheet( 117 | 'color: black; background-color: white; font-family: sans-serif; font-size: 20px;' # noqa 118 | ) 119 | label.setAlignment(Qt.AlignCenter) 120 | self.col_labels.append(label) 121 | -------------------------------------------------------------------------------- /playground/games/chess/common/consts.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | from functools import reduce 3 | from itertools import chain, combinations 4 | from random import getrandbits 5 | 6 | _64BITS = 0xFFFFFFFFFFFFFFFF 7 | 8 | WHITE = 0 9 | BLACK = 1 10 | 11 | PIECE_CONVERSION = { 12 | 1: 'P', # 白兵 13 | 2: 'N', # 白马 14 | 3: 'B', # 白象 15 | 4: 'R', # 白车 16 | 5: 'Q', # 白后 17 | 6: 'K', # 白王 18 | 9: 'p', # 黑兵 19 | 10: 'n', # 黑马 20 | 11: 'b', # 黑象 21 | 12: 'r', # 黑车 22 | 13: 'q', # 黑后 23 | 14: 'k' # 黑王 24 | } 25 | 26 | PIECE_MAP = { 27 | 'P': 'wp', # 白兵 28 | 'N': 'wn', # 白马 29 | 'B': 'wb', # 白象 30 | 'R': 'wr', # 白车 31 | 'Q': 'wq', # 白后 32 | 'K': 'wk', # 白王 33 | 'p': 'bp', # 黑兵 34 | 'n': 'bn', # 黑马 35 | 'b': 'bb', # 黑象 36 | 'r': 'br', # 黑车 37 | 'q': 'bq', # 黑后 38 | 'k': 'bk' # 黑王 39 | } 40 | 41 | RANKS = (RANK_1, RANK_2, RANK_3, RANK_4, RANK_5, RANK_6, RANK_7, 42 | RANK_8) = range(8) 43 | 44 | FILES = (A_FILE, B_FILE, C_FILE, D_FILE, E_FILE, F_FILE, G_FILE, 45 | H_FILE) = range(8) 46 | 
47 | RANK_1_BB = 0xFF  # bitboards: square index = rank * 8 + file (a1 = bit 0, h8 = bit 63) 48 | RANK_2_BB = RANK_1_BB << 8 49 | RANK_3_BB = RANK_2_BB << 8 50 | RANK_4_BB = RANK_3_BB << 8 51 | RANK_5_BB = RANK_4_BB << 8 52 | RANK_6_BB = RANK_5_BB << 8 53 | RANK_7_BB = RANK_6_BB << 8 54 | RANK_8_BB = RANK_7_BB << 8 55 | 56 | A_FILE_BB = 0x101010101010101 57 | B_FILE_BB = A_FILE_BB << 1 58 | C_FILE_BB = B_FILE_BB << 1 59 | D_FILE_BB = C_FILE_BB << 1 60 | E_FILE_BB = D_FILE_BB << 1 61 | F_FILE_BB = E_FILE_BB << 1 62 | G_FILE_BB = F_FILE_BB << 1 63 | H_FILE_BB = G_FILE_BB << 1 64 | 65 | COLOURS = (WHITE, BLACK) = range(2) 66 | 67 | ALL_PIECES = 0 68 | PIECE_TYPES = (PAWN, KNIGHT, BISHOP, ROOK, QUEEN, KING) = range(1, 7) 69 | 70 | # Piece codes: low three bits hold the piece type (1-6); bit 3 (value 8) is set for black pieces 71 | NO_PIECE = 0 72 | PIECES = (W_PAWN, W_KNIGHT, W_BISHOP, W_ROOK, W_QUEEN, W_KING, B_PAWN, 73 | B_KNIGHT, B_BISHOP, B_ROOK, B_QUEEN, 74 | B_KING) = list(chain(range(1, 7), range(9, 15))) 75 | 76 | # Shift amounts for each direction (other directions obtained by negation) 77 | EAST = 1 78 | NORTHWEST = 7 79 | NORTH = 8 80 | NORTHEAST = 9 81 | 82 | MIDGAME = 0 83 | ENDGAME = 1 84 | 85 | MATERIAL = [None for _ in range(15)]  # indexed by piece type; each pair is presumably (MIDGAME, ENDGAME) 86 | MATERIAL[NO_PIECE] = (0, 0) 87 | MATERIAL[PAWN] = (128, 213) 88 | MATERIAL[KNIGHT] = (782, 865) 89 | MATERIAL[BISHOP] = (830, 918) 90 | MATERIAL[ROOK] = (1289, 1378) 91 | MATERIAL[QUEEN] = (2529, 2687) 92 | MATERIAL[KING] = (20000, 20000) 93 | 94 | # Penalty for doubled pawns 95 | DOUBLED = (11, 56) 96 | 97 | # Penalty for isolated pawns 98 | ISOLATED = (5, 15) 99 | 100 | # Penalty for backward pawns 101 | BACKWARD = (9, 24) 102 | 103 | MATE = 100000 104 | DRAW = 0 105 | FUTILITY_MARGIN = 400 106 | 107 | # Not using math.inf, as 'INFINITY + 1' is sometimes needed 108 | INFINITY = 1000000 109 | 110 | # Castling sides 111 | KINGSIDE = 1 112 | QUEENSIDE = 1 << 2  # equal to W_KINGSIDE / W_QUEENSIDE; black's bits are these shifted left by 1 (BLACK) 113 | 114 | # Castling types 115 | NO_CASTLING = 0 116 | CASTLING_RIGHTS = (W_KINGSIDE, B_KINGSIDE, W_QUEENSIDE, B_QUEENSIDE) = list( 117 | (1 << i for i in range(4)))
118 | 119 | # Move generation types 120 | ALL = 0 121 | QUIETS = 1 122 | CAPTURES = 2 123 | EVASIONS = 3 124 | 125 | # Move types 126 | NORMAL = 0 127 | PROMOTION = 1 << 14 128 | EN_PASSANT = 2 << 14 129 | CASTLING = 3 << 14 130 | 131 | # Promotion types 132 | KNIGHT_PROMOTION = (KNIGHT - 2) << 12 133 | BISHOP_PROMOTION = (BISHOP - 2) << 12 134 | ROOK_PROMOTION = (ROOK - 2) << 12 135 | QUEEN_PROMOTION = (QUEEN - 2) << 12 136 | 137 | # Transposition table entry types 138 | LOWER = 0 139 | UPPER = 1 140 | EXACT = 2 141 | 142 | MoveDetailed = namedtuple('MoveDetailed', 'move captured') 143 | 144 | StateInfo = namedtuple('StateInfo', 145 | 'zobrist en_passant castling_rights halfmove_clock') 146 | 147 | TTEntry = namedtuple('TTEntry', 'zobrist move depth score type') 148 | ZobristTuple = namedtuple('Zobrist', 'board en_passant castling colour') 149 | 150 | PawnEntry = namedtuple('PawnEntry', 'key score_mg score_eg') 151 | 152 | MaterialEntry = namedtuple('MaterialEntry', 'key material_score imbalance') 153 | 154 | # Random 64-bit integer for each combination of square and piece 155 | ZOBRIST_BOARD = [[None for _ in range(64)] for _ in range(16)] 156 | for piece in PIECES: 157 | for sq in range(64): 158 | ZOBRIST_BOARD[piece][sq] = getrandbits(64) 159 | 160 | # Random 64-bit integer for each en-passant file 161 | ZOBRIST_ENPASSANT = [None for _ in range(8)] 162 | for file_num in range(8): 163 | ZOBRIST_ENPASSANT[file_num] = getrandbits(64) 164 | 165 | # Random 64-bit integer for each combination of castling rights 166 | ZOBRIST_CASTLING = [0 for _ in range(16)] 167 | # Individual castling rights 168 | for cr in CASTLING_RIGHTS: 169 | ZOBRIST_CASTLING[cr] = getrandbits(64) 170 | # Combinations of castling rights 171 | for length in range(2, 5): 172 | combos = combinations(CASTLING_RIGHTS, length) 173 | for combo in combos: 174 | index = reduce(lambda x, y: x | y, combo) 175 | for cr in combo: 176 | ZOBRIST_CASTLING[index] ^= cr 177 | 178 | ZOBRIST_COLOUR = 
getrandbits(64) 179 | -------------------------------------------------------------------------------- /playground/games/reversi/AI.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | 4 | class ReversiAI: 5 | 6 | def valid_move(self, board, x, y, player): 7 | if board[y][x] != 0: 8 | return False 9 | for dx in [-1, 0, 1]: 10 | for dy in [-1, 0, 1]: 11 | if dx == 0 and dy == 0: 12 | continue 13 | nx, ny = x + dx, y + dy 14 | if 0 <= nx < 8 and 0 <= ny < 8 and board[ny][ 15 | nx] == self.opponent(player): 16 | while 0 <= nx < 8 and 0 <= ny < 8: 17 | if board[ny][nx] == player: 18 | return True 19 | if board[ny][nx] == 0: 20 | break 21 | nx += dx 22 | ny += dy 23 | return False 24 | 25 | def make_move(self, board, x, y, player): 26 | board[y][x] = player 27 | for dx in [-1, 0, 1]: 28 | for dy in [-1, 0, 1]: 29 | if dx == 0 and dy == 0: 30 | continue 31 | nx, ny = x + dx, y + dy 32 | if 0 <= nx < 8 and 0 <= ny < 8 and board[ny][ 33 | nx] == self.opponent(player): 34 | nx += dx 35 | ny += dy 36 | while 0 <= nx < 8 and 0 <= ny < 8: 37 | if board[ny][nx] == player: 38 | while True: 39 | nx -= dx 40 | ny -= dy 41 | if nx == x and ny == y: 42 | break 43 | board[ny][nx] = player 44 | break 45 | if board[ny][nx] == 0: 46 | break 47 | nx += dx 48 | ny += dy 49 | 50 | def opponent(self, player): 51 | return 1 if player == 2 else 2 52 | 53 | def score(self, board): 54 | w, b = 0, 0 55 | for row in board: 56 | for cell in row: 57 | if cell == 2: 58 | w += 1 59 | elif cell == 1: 60 | b += 1 61 | return w, b 62 | 63 | def best_move(self, board, depth, player): 64 | moves = [(x, y) for x in range(8) for y in range(8) 65 | if self.valid_move(board, x, y, player)] 66 | if not moves: 67 | return None 68 | 69 | best = None 70 | if player == 2: 71 | max_val = -float('inf') 72 | for x, y in moves: 73 | new_board = copy.deepcopy(board) 74 | self.make_move(new_board, x, y, player) 75 | val = self.alpha_beta(new_board, depth - 1, 
-float('inf'), 76 | float('inf'), self.opponent(player)) 77 | if val > max_val: 78 | max_val = val 79 | best = (x, y) 80 | else: 81 | min_val = float('inf') 82 | for x, y in moves: 83 | new_board = copy.deepcopy(board) 84 | self.make_move(new_board, x, y, player) 85 | val = self.alpha_beta(new_board, depth - 1, -float('inf'), 86 | float('inf'), self.opponent(player)) 87 | if val < min_val: 88 | min_val = val 89 | best = (x, y) 90 | return best 91 | 92 | def alpha_beta(self, board, depth, alpha, beta, player): 93 | if depth == 0: 94 | w, b = self.score(board) 95 | return w - b if player == 2 else b - w 96 | 97 | moves = [(x, y) for x in range(8) for y in range(8) 98 | if self.valid_move(board, x, y, player)] 99 | if not moves: 100 | return self.alpha_beta(board, depth - 1, alpha, beta, 101 | self.opponent(player)) 102 | 103 | if player == 2: 104 | max_val = -float('inf') 105 | for x, y in moves: 106 | new_board = copy.deepcopy(board) 107 | self.make_move(new_board, x, y, player) 108 | val = self.alpha_beta(new_board, depth - 1, alpha, beta, 109 | self.opponent(player)) 110 | max_val = max(max_val, val) 111 | alpha = max(alpha, val) 112 | if beta <= alpha: 113 | break 114 | return max_val 115 | else: 116 | min_val = float('inf') 117 | for x, y in moves: 118 | new_board = copy.deepcopy(board) 119 | self.make_move(new_board, x, y, player) 120 | val = self.alpha_beta(new_board, depth - 1, alpha, beta, 121 | self.opponent(player)) 122 | min_val = min(min_val, val) 123 | beta = min(beta, val) 124 | if beta <= alpha: 125 | break 126 | return min_val 127 | -------------------------------------------------------------------------------- /playground/games/chess/common/flood_fill.py: -------------------------------------------------------------------------------- 1 | # Returns rook attacks in the north direction 2 | def ratks_n(sq, occ): 3 | sq = 1 << sq 4 | flood = sq 5 | sq = (sq << 8) & ~occ 6 | flood |= sq 7 | sq = (sq << 8) & ~occ 8 | flood |= sq 9 | sq = (sq << 8) & ~occ 
10 | flood |= sq 11 | sq = (sq << 8) & ~occ 12 | flood |= sq 13 | sq = (sq << 8) & ~occ 14 | flood |= sq 15 | sq = (sq << 8) & ~occ 16 | flood |= sq 17 | sq = (sq << 8) & ~occ 18 | flood |= sq 19 | flood <<= 8 20 | return flood & 0xFFFFFFFFFFFFFFFF 21 | 22 | 23 | # Returns rook attacks in the east direction 24 | def ratks_e(sq, occ): 25 | occ |= 0x101010101010101 26 | sq = 1 << sq 27 | flood = sq 28 | sq = (sq << 1) & ~occ 29 | flood |= sq 30 | sq = (sq << 1) & ~occ 31 | flood |= sq 32 | sq = (sq << 1) & ~occ 33 | flood |= sq 34 | sq = (sq << 1) & ~occ 35 | flood |= sq 36 | sq = (sq << 1) & ~occ 37 | flood |= sq 38 | sq = (sq << 1) & ~occ 39 | flood |= sq 40 | sq = (sq << 1) & ~occ 41 | flood |= sq 42 | flood <<= 1 43 | flood &= ~0x101010101010101 44 | return flood & 0xFFFFFFFFFFFFFFFF 45 | 46 | 47 | # Returns rook attacks in the south direction 48 | def ratks_s(sq, occ): 49 | sq = 1 << sq 50 | flood = sq 51 | sq = (sq >> 8) & ~occ 52 | flood |= sq 53 | sq = (sq >> 8) & ~occ 54 | flood |= sq 55 | sq = (sq >> 8) & ~occ 56 | flood |= sq 57 | sq = (sq >> 8) & ~occ 58 | flood |= sq 59 | sq = (sq >> 8) & ~occ 60 | flood |= sq 61 | sq = (sq >> 8) & ~occ 62 | flood |= sq 63 | sq = (sq >> 8) & ~occ 64 | flood |= sq 65 | flood >>= 8 66 | return flood & 0xFFFFFFFFFFFFFFFF 67 | 68 | 69 | # Returns rook attacks in the west direction 70 | def ratks_w(sq, occ): 71 | occ |= 0x8080808080808080 72 | sq = 1 << sq 73 | flood = sq 74 | sq = (sq >> 1) & ~occ 75 | flood |= sq 76 | sq = (sq >> 1) & ~occ 77 | flood |= sq 78 | sq = (sq >> 1) & ~occ 79 | flood |= sq 80 | sq = (sq >> 1) & ~occ 81 | flood |= sq 82 | sq = (sq >> 1) & ~occ 83 | flood |= sq 84 | sq = (sq >> 1) & ~occ 85 | flood |= sq 86 | sq = (sq >> 1) & ~occ 87 | flood |= sq 88 | flood >>= 1 89 | flood &= ~0x8080808080808080 90 | return flood & 0xFFFFFFFFFFFFFFFF 91 | 92 | 93 | # Returns bishop attacks in the north-east direction 94 | def batks_ne(sq, occ): 95 | occ |= 0x101010101010101 96 | sq = 1 << sq 97 | flood = sq 98 | 
sq = (sq << 9) & ~occ 99 | flood |= sq 100 | sq = (sq << 9) & ~occ 101 | flood |= sq 102 | sq = (sq << 9) & ~occ 103 | flood |= sq 104 | sq = (sq << 9) & ~occ 105 | flood |= sq 106 | sq = (sq << 9) & ~occ 107 | flood |= sq 108 | sq = (sq << 9) & ~occ 109 | flood |= sq 110 | sq = (sq << 9) & ~occ 111 | flood |= sq 112 | flood <<= 9 113 | flood &= ~0x101010101010101 114 | return flood & 0xFFFFFFFFFFFFFFFF 115 | 116 | 117 | # Returns bishop attacks in the south-east direction 118 | def batks_se(sq, occ): 119 | occ |= 0x101010101010101 120 | sq = 1 << sq 121 | flood = sq 122 | sq = (sq >> 7) & ~occ 123 | flood |= sq 124 | sq = (sq >> 7) & ~occ 125 | flood |= sq 126 | sq = (sq >> 7) & ~occ 127 | flood |= sq 128 | sq = (sq >> 7) & ~occ 129 | flood |= sq 130 | sq = (sq >> 7) & ~occ 131 | flood |= sq 132 | sq = (sq >> 7) & ~occ 133 | flood |= sq 134 | sq = (sq >> 7) & ~occ 135 | flood |= sq 136 | flood >>= 7 137 | flood &= ~0x101010101010101 138 | return flood & 0xFFFFFFFFFFFFFFFF 139 | 140 | 141 | # Returns bishop attacks in the south-west direction 142 | def batks_sw(sq, occ): 143 | occ |= 0x8080808080808080 144 | sq = 1 << sq 145 | flood = sq 146 | sq = (sq >> 9) & ~occ 147 | flood |= sq 148 | sq = (sq >> 9) & ~occ 149 | flood |= sq 150 | sq = (sq >> 9) & ~occ 151 | flood |= sq 152 | sq = (sq >> 9) & ~occ 153 | flood |= sq 154 | sq = (sq >> 9) & ~occ 155 | flood |= sq 156 | sq = (sq >> 9) & ~occ 157 | flood |= sq 158 | sq = (sq >> 9) & ~occ 159 | flood |= sq 160 | flood >>= 9 161 | flood &= ~0x8080808080808080 162 | return flood & 0xFFFFFFFFFFFFFFFF 163 | 164 | 165 | # Returns bishop attacks in the north-west direction 166 | def batks_nw(sq, occ): 167 | occ |= 0x8080808080808080 168 | sq = 1 << sq 169 | flood = sq 170 | sq = (sq << 7) & ~occ 171 | flood |= sq 172 | sq = (sq << 7) & ~occ 173 | flood |= sq 174 | sq = (sq << 7) & ~occ 175 | flood |= sq 176 | sq = (sq << 7) & ~occ 177 | flood |= sq 178 | sq = (sq << 7) & ~occ 179 | flood |= sq 180 | sq = (sq << 7) & ~occ 
181 | flood |= sq 182 | sq = (sq << 7) & ~occ 183 | flood |= sq 184 | flood <<= 7 185 | flood &= ~0x8080808080808080 186 | return flood & 0xFFFFFFFFFFFFFFFF 187 | 188 | 189 | def rook_attacks(sq, occ): 190 | return ratks_n(sq, occ) | ratks_e(sq, occ) | ratks_s(sq, occ) | ratks_w( 191 | sq, occ) 192 | 193 | 194 | def bishop_attacks(sq, occ): 195 | return batks_ne(sq, occ) | batks_se(sq, occ) | batks_sw( 196 | sq, occ) | batks_nw(sq, occ) 197 | -------------------------------------------------------------------------------- /playground/benchmark.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import os.path as osp 4 | import sys 5 | 6 | from pjtools.configurator import AutoConfigurator 7 | from PyQt5.QtWidgets import QApplication 8 | 9 | from playground.registry import GAME_REGISTRY 10 | from playground.utils import set_random_seed 11 | 12 | 13 | class Generator: 14 | 15 | def __init__(self, base_cfg): 16 | cfg = AutoConfigurator.fromfile(base_cfg) 17 | self.benchmark_setting = cfg.benchmark_setting 18 | self.seed = set_random_seed() 19 | self.sample_size = self.benchmark_setting.sample_size 20 | 21 | def generate_benchmark(self): 22 | for task in self.benchmark_setting.offline_task: 23 | for game in self.benchmark_setting.games: 24 | save_path = osp.join(self.benchmark_setting.benchmark_path, 25 | task, game) 26 | if not osp.exists(save_path): 27 | os.makedirs(save_path) 28 | if osp.exists(osp.join(save_path, 'annotation.json')): 29 | print( 30 | f'Benchmark data for {task} in {game} has been found.') 31 | else: 32 | self.render(task, game, save_path) 33 | 34 | def render(self, task, game, save_path): 35 | game_cfg = AutoConfigurator.fromfile(f'configs/games/{game}.py') 36 | app = QApplication(sys.argv) # noqa 37 | if task == 'perceive': 38 | self.render_perceive(game_cfg, save_path) 39 | elif task == 'rule': 40 | self.render_rule(game_cfg, save_path) 41 | elif task == 'qa': 42 | 
self.render_qa(game_cfg, save_path) 43 | else: 44 | raise ValueError(f'Invalid task: {task}') 45 | 46 | def render_perceive(self, game_cfg, save_path): 47 | game_class = GAME_REGISTRY.get(game_cfg.game_name) 48 | annotations = [] 49 | for i in range(self.sample_size): 50 | game = game_class(game_cfg) 51 | gt = game.get_random_state() 52 | screenshot = game.get_screenshot() 53 | screenshot.save(osp.join(save_path, f'{i:07d}.jpg')) 54 | annotation = { 55 | 'file': f'{i:07d}.jpg', 56 | 'gt': gt, 57 | } 58 | annotations.append(annotation) 59 | with open(osp.join(save_path, 'annotation.json'), 60 | 'w', 61 | encoding='utf-8') as json_file: 62 | json.dump( 63 | { 64 | 'task': 'perceive', 65 | 'game': game_cfg.game_name, 66 | 'annotations': annotations, 67 | }, json_file) 68 | 69 | def render_qa(self, game_cfg, save_path): 70 | game_class = GAME_REGISTRY.get(game_cfg.game_name) 71 | annotations = [] 72 | for i in range(self.sample_size): 73 | game = game_class(game_cfg) 74 | random_state = game.get_random_state() 75 | QA = game_cfg.qa(game_cfg.game_description['qa']) 76 | qa_pairs = QA.get_qa_pairs(random_state) 77 | example_qa = '\n'.join(f'Question: {q}\nAnswer: {a}' 78 | for q, a in qa_pairs[:QA.shot]) 79 | question, answer = qa_pairs[QA.shot] 80 | screenshot = game.get_screenshot() 81 | screenshot.save(osp.join(save_path, f'{i:07d}.jpg')) 82 | annotation = { 83 | 'file': f'{i:07d}.jpg', 84 | 'gt': { 85 | 'question': question, 86 | 'answer': answer, 87 | 'example_qa': example_qa 88 | }, 89 | } 90 | annotations.append(annotation) 91 | with open(osp.join(save_path, 'annotation.json'), 92 | 'w', 93 | encoding='utf-8') as json_file: 94 | json.dump( 95 | { 96 | 'task': 'qa', 97 | 'game': game_cfg.game_name, 98 | 'annotations': annotations, 99 | }, json_file) 100 | 101 | def render_rule(self, game_cfg, save_path): 102 | game_class = GAME_REGISTRY.get(game_cfg.game_name) 103 | annotations = [] 104 | for i in range(self.sample_size): 105 | game = game_class(game_cfg) 106 | 
rule_state, valid_movements = game.get_rule_state() 107 | screenshot = game.get_screenshot() 108 | screenshot.save(osp.join(save_path, f'{i:07d}.jpg')) 109 | annotation = { 110 | 'file': f'{i:07d}.jpg', 111 | 'gt': { 112 | 'rule_state': rule_state, 113 | 'valid_movements': valid_movements 114 | }, 115 | } 116 | annotations.append(annotation) 117 | with open(osp.join(save_path, 'annotation.json'), 118 | 'w', 119 | encoding='utf-8') as json_file: 120 | json.dump( 121 | { 122 | 'task': 'rule', 123 | 'game': game_cfg.game_name, 124 | 'annotations': annotations, 125 | }, json_file) 126 | -------------------------------------------------------------------------------- /playground/games/chess/common/attack_tables.py: -------------------------------------------------------------------------------- 1 | import playground.games.chess.common.flood_fill as ff 2 | from playground.games.chess.common.consts import (_64BITS, BISHOP, BLACK, KING, 3 | KNIGHT, QUEEN, ROOK, WHITE) 4 | 5 | # The set of squares for possible rook attack blocker pieces 6 | rook_masks = [ 7 | 0x101010101017e, 0x202020202027c, 0x404040404047a, 0x8080808080876, 8 | 0x1010101010106e, 0x2020202020205e, 0x4040404040403e, 0x8080808080807e, 9 | 0x1010101017e00, 0x2020202027c00, 0x4040404047a00, 0x8080808087600, 10 | 0x10101010106e00, 0x20202020205e00, 0x40404040403e00, 0x80808080807e00, 11 | 0x10101017e0100, 0x20202027c0200, 0x40404047a0400, 0x8080808760800, 12 | 0x101010106e1000, 0x202020205e2000, 0x404040403e4000, 0x808080807e8000, 13 | 0x101017e010100, 0x202027c020200, 0x404047a040400, 0x8080876080800, 14 | 0x1010106e101000, 0x2020205e202000, 0x4040403e404000, 0x8080807e808000, 15 | 0x1017e01010100, 0x2027c02020200, 0x4047a04040400, 0x8087608080800, 16 | 0x10106e10101000, 0x20205e20202000, 0x40403e40404000, 0x80807e80808000, 17 | 0x17e0101010100, 0x27c0202020200, 0x47a0404040400, 0x8760808080800, 18 | 0x106e1010101000, 0x205e2020202000, 0x403e4040404000, 0x807e8080808000, 19 | 0x7e010101010100, 
0x7c020202020200, 0x7a040404040400, 0x76080808080800, 20 | 0x6e101010101000, 0x5e202020202000, 0x3e404040404000, 0x7e808080808000, 21 | 0x7e01010101010100, 0x7c02020202020200, 0x7a04040404040400, 22 | 0x7608080808080800, 0x6e10101010101000, 0x5e20202020202000, 23 | 0x3e40404040404000, 0x7e80808080808000 24 | ] 25 | 26 | # The set of squares for possible bishop attack blocker pieces 27 | bishop_masks = [ 28 | 0x40201008040200, 0x402010080400, 0x4020100a00, 0x40221400, 0x2442800, 29 | 0x204085000, 0x20408102000, 0x2040810204000, 0x20100804020000, 30 | 0x40201008040000, 0x4020100a0000, 0x4022140000, 0x244280000, 0x20408500000, 31 | 0x2040810200000, 0x4081020400000, 0x10080402000200, 0x20100804000400, 32 | 0x4020100a000a00, 0x402214001400, 0x24428002800, 0x2040850005000, 33 | 0x4081020002000, 0x8102040004000, 0x8040200020400, 0x10080400040800, 34 | 0x20100a000a1000, 0x40221400142200, 0x2442800284400, 0x4085000500800, 35 | 0x8102000201000, 0x10204000402000, 0x4020002040800, 0x8040004081000, 36 | 0x100a000a102000, 0x22140014224000, 0x44280028440200, 0x8500050080400, 37 | 0x10200020100800, 0x20400040201000, 0x2000204081000, 0x4000408102000, 38 | 0xa000a10204000, 0x14001422400000, 0x28002844020000, 0x50005008040200, 39 | 0x20002010080400, 0x40004020100800, 0x20408102000, 0x40810204000, 40 | 0xa1020400000, 0x142240000000, 0x284402000000, 0x500804020000, 41 | 0x201008040200, 0x402010080400, 0x2040810204000, 0x4081020400000, 42 | 0xa102040000000, 0x14224000000000, 0x28440200000000, 0x50080402000000, 43 | 0x20100804020000, 0x40201008040200 44 | ] 45 | 46 | # Initialise table for rook attacks, indexed by square 47 | ratk_table = [{} for _ in range(64)] 48 | for sq in range(64): 49 | occ = 0 50 | # Produce attacks with occupancies of all subsets of rook mask 51 | while True: 52 | ratk_table[sq][occ] = ff.rook_attacks(sq, occ) 53 | occ = (occ - rook_masks[sq]) & rook_masks[sq] # Carry-Rippler 54 | if not occ: 55 | break 56 | 57 | # Initialise table for bishop attacks, indexed by 
square 58 | batk_table = [{} for _ in range(64)] 59 | for sq in range(64): 60 | occ = 0 61 | # Produce attacks with occupancies of all subsets of bishop mask 62 | while True: 63 | batk_table[sq][occ] = ff.bishop_attacks(sq, occ) 64 | occ = (occ - bishop_masks[sq]) & bishop_masks[sq] # Carry-Rippler 65 | if not occ: 66 | break 67 | 68 | # Initialise table for non-pawn attacks, indexed by piece type and square 69 | pseudo_attacks = [[0 for _ in range(64)] for _ in range(7)] 70 | for sq in range(64): 71 | bb = 1 << sq 72 | clip_a_file = bb & 0xFEFEFEFEFEFEFEFE 73 | clip_h_file = bb & 0x7F7F7F7F7F7F7F7F 74 | clip_ab_files = bb & 0xFCFCFCFCFCFCFCFC 75 | clip_gh_files = bb & 0X3F3F3F3F3F3F3F3F 76 | 77 | pseudo_attacks[KNIGHT][sq] |= clip_ab_files << 6 78 | pseudo_attacks[KNIGHT][sq] |= clip_gh_files << 10 79 | pseudo_attacks[KNIGHT][sq] |= clip_a_file << 15 80 | pseudo_attacks[KNIGHT][sq] |= clip_h_file << 17 81 | pseudo_attacks[KNIGHT][sq] |= clip_gh_files >> 6 82 | pseudo_attacks[KNIGHT][sq] |= clip_ab_files >> 10 83 | pseudo_attacks[KNIGHT][sq] |= clip_h_file >> 15 84 | pseudo_attacks[KNIGHT][sq] |= clip_a_file >> 17 85 | pseudo_attacks[KNIGHT][sq] &= _64BITS 86 | 87 | pseudo_attacks[BISHOP][sq] |= batk_table[sq][0] 88 | pseudo_attacks[ROOK][sq] |= ratk_table[sq][0] 89 | pseudo_attacks[QUEEN][sq] |= batk_table[sq][0] | ratk_table[sq][0] 90 | 91 | pseudo_attacks[KING][sq] |= clip_h_file << 1 92 | pseudo_attacks[KING][sq] |= clip_a_file << 7 93 | pseudo_attacks[KING][sq] |= bb << 8 94 | pseudo_attacks[KING][sq] |= clip_h_file << 9 95 | pseudo_attacks[KING][sq] |= clip_a_file >> 1 96 | pseudo_attacks[KING][sq] |= clip_h_file >> 7 97 | pseudo_attacks[KING][sq] |= bb >> 8 98 | pseudo_attacks[KING][sq] |= clip_a_file >> 9 99 | pseudo_attacks[KING][sq] &= _64BITS 100 | 101 | # Initialise table for pawn attacks, indexed by colour and square 102 | pawn_attacks = [[0 for _ in range(64)] for _ in range(2)] 103 | for sq in range(64): 104 | bb = 1 << sq 105 | clip_a_file = bb & 
0xFEFEFEFEFEFEFEFE 106 | clip_h_file = bb & 0x7F7F7F7F7F7F7F7F 107 | 108 | pawn_attacks[WHITE][sq] |= clip_a_file << 7 109 | pawn_attacks[WHITE][sq] |= clip_h_file << 9 110 | 111 | pawn_attacks[BLACK][sq] |= clip_h_file >> 7 112 | pawn_attacks[BLACK][sq] |= clip_a_file >> 9 113 | -------------------------------------------------------------------------------- /playground/experiment/recipe.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import json 3 | import os 4 | import os.path as osp 5 | 6 | import torch 7 | from pjtools.configurator import AutoConfigurator 8 | 9 | from playground.evaluator import Evaluator 10 | from playground.registry import AGENT_REGISTRY 11 | from playground.state_code import GameStatusEncoder 12 | 13 | 14 | class Recipe: 15 | 16 | def __init__(self, args): 17 | self.base_cfg = AutoConfigurator.fromfile('configs/base.py') 18 | self.recipe = AutoConfigurator.fromfile(args.exp_recipe) 19 | self.agent_cfg = AutoConfigurator.fromfile(args.agent_cfg) 20 | self.agent_cfg_path = args.agent_cfg 21 | self.benchmark_setting = self.base_cfg.benchmark_setting 22 | 23 | self.save_path = osp.join(self.recipe.save_path, self.recipe.name) 24 | os.makedirs(self.save_path, exist_ok=True) 25 | self.log_file = osp.join(self.save_path, 'evaluation.log') 26 | 27 | self.agent = AGENT_REGISTRY.get(self.agent_cfg.lmm_agent.agent)( 28 | self.agent_cfg) 29 | 30 | self.init_experiment_record() 31 | 32 | def init_experiment_record(self): 33 | self.record_path = osp.join( 34 | self.save_path, 35 | self.agent_cfg.lmm_agent.name + self.recipe.name + '.json') 36 | if osp.exists(self.record_path): 37 | with open(self.record_path, 'r') as f: 38 | self.record = json.load(f) 39 | else: 40 | self.record = {} 41 | 42 | self.update_record_with_new_tasks_and_games() 43 | self.save_record() 44 | 45 | def update_record_with_new_tasks_and_games(self): 46 | tasks = self.recipe.tasks 47 | games = self.recipe.games 48 | sample_size = 
self.benchmark_setting.sample_size 49 | e2e_round = self.benchmark_setting.e2e_round 50 | 51 | for task in tasks: 52 | if task not in self.record: 53 | self.record[task] = {} 54 | repetition_round = sample_size if task != 'e2e' else e2e_round 55 | for game in games: 56 | if game not in self.record[task]: 57 | self.record[task][game] = [None] * repetition_round 58 | 59 | def save_record(self): 60 | with open(self.record_path, 'w') as f: 61 | json.dump(self.record, f, indent=4, cls=GameStatusEncoder) 62 | 63 | def run_experiments(self): 64 | tasks = self.recipe.tasks 65 | games = self.recipe.games 66 | 67 | for task in tasks: 68 | for game in games: 69 | if task not in self.record: 70 | self.record[task] = {} 71 | if game not in self.record[task]: 72 | self.record[task][game] = [None 73 | ] * self.recipe.repetition_round 74 | self.save_record() 75 | 76 | if task != 'e2e': 77 | with open(osp.join(self.benchmark_setting.benchmark_path, 78 | task, game, 'annotation.json'), 79 | 'r', 80 | encoding='utf-8') as json_file: 81 | annotation = json.load(json_file) 82 | assert annotation['game'] == game 83 | assert annotation['task'] == task 84 | assert len(annotation['annotations'] 85 | ) == self.benchmark_setting.sample_size 86 | 87 | completed_rounds = self.record[task][game] 88 | 89 | game_cfg = AutoConfigurator.fromfile( 90 | f'configs/games/{game}.py') 91 | 92 | evaluator = Evaluator(game_cfg, self.agent, task, 93 | self.log_file, self.save_path) 94 | 95 | while None in completed_rounds: 96 | next_round = completed_rounds.index(None) 97 | 98 | print(f'Running experiment for task: {task}, ' 99 | f'game: {game}, round: {next_round + 1}') 100 | 101 | try: 102 | if task != 'e2e': 103 | batch = { 104 | 'task': 105 | task, 106 | 'screenshot_path': 107 | osp.join(self.benchmark_setting.benchmark_path, 108 | task, game, f'{next_round:07d}.jpg'), 109 | 'gt': 110 | annotation['annotations'][next_round]['gt'], 111 | 'game_cfg': 112 | game_cfg 113 | } 114 | else: 115 | batch = 
{'task': task, 'game_cfg': game_cfg} 116 | result, simulator = evaluator.run(batch) 117 | simulator.cleanup() 118 | self.record[task][game][next_round] = result 119 | self.save_record() 120 | completed_rounds = self.record[task][game] 121 | except Exception as e: 122 | print( 123 | f'Error occurred during task {task}, game {game}, ' 124 | f'round {next_round + 1}: {e}') 125 | continue 126 | 127 | print(f'Task: {task}, game: {game} has been completed.') 128 | 129 | torch.cuda.synchronize() 130 | evaluator.cleanup() 131 | del evaluator 132 | torch.cuda.empty_cache() 133 | gc.collect() 134 | 135 | def cleanup(self): 136 | """Clean up resources at the end of the experiment.""" 137 | if hasattr(self.agent, 'model'): 138 | del self.agent.model 139 | del self.agent 140 | torch.cuda.empty_cache() 141 | gc.collect() 142 | -------------------------------------------------------------------------------- /playground/games/chess/chess_ui.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from PyQt5.QtCore import (QEventLoop, QPropertyAnimation, Qt, QThread, 4 | pyqtSignal) 5 | from PyQt5.QtGui import QFont, QPixmap, QResizeEvent 6 | from PyQt5.QtWidgets import QGridLayout, QLabel, QSizePolicy, QWidget 7 | 8 | import playground.games.chess.common.common as common 9 | from playground.games.chess.common.consts import PIECE_CONVERSION, PIECE_MAP 10 | from playground.games.chess.position import Position 11 | 12 | SQR_SIZE = 100 13 | 14 | 15 | class ChessUI(QWidget): 16 | 17 | def __init__(self, parent=None, user_is_white=True): 18 | super().__init__(parent) 19 | self.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding) 20 | self.setContentsMargins(0, 0, 0, 0) 21 | self.setFixedSize(800, 800) 22 | self.layout = QGridLayout() 23 | self.layout.setContentsMargins(0, 0, 0, 0) 24 | self.layout.setSpacing(0) 25 | self.setLayout(self.layout) 26 | self.sqr_size = SQR_SIZE 27 | self.user_is_white = user_is_white 28 | self.position 
= Position(common.starting_fen) 29 | self.search_thread = SearchThread(self) 30 | self.pieces = {} 31 | self.selected_piece = None 32 | self.selected_square = None 33 | self.draw_board_with_labels() 34 | self.reset_board() 35 | 36 | def resizeEvent(self, event: QResizeEvent): 37 | side = min(self.width(), self.height()) 38 | self.setFixedSize(side, side) 39 | super().resizeEvent(event) 40 | 41 | def draw_board_with_labels(self): 42 | font = QFont() 43 | font.setPointSize(14) 44 | font.setBold(True) 45 | 46 | for row in range(8): # 0..7 47 | for col in range(8): # 0..7 48 | square = QWidget(self) 49 | square.setSizePolicy(QSizePolicy.Expanding, 50 | QSizePolicy.Expanding) 51 | if (row + col) % 2 == 0: 52 | square.setStyleSheet('background-color: #F0D9B5;') # 浅 53 | else: 54 | square.setStyleSheet('background-color: #B58863;') # 深 55 | 56 | self.layout.addWidget(square, row + 1, col + 1) 57 | 58 | for i in range(10): 59 | self.layout.setRowStretch(i, 1) 60 | self.layout.setColumnStretch(i, 1) 61 | 62 | file_labels = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h'] 63 | for col in range(8): 64 | lbl = QLabel(file_labels[col], self) 65 | lbl.setAlignment(Qt.AlignCenter) 66 | lbl.setFont(font) 67 | self.layout.addWidget(lbl, 0, col + 1) # top row=0, col+1 68 | 69 | for col in range(8): 70 | lbl = QLabel(file_labels[col], self) 71 | lbl.setAlignment(Qt.AlignCenter) 72 | lbl.setFont(font) 73 | self.layout.addWidget(lbl, 9, col + 1) 74 | 75 | rank_labels = ['8', '7', '6', '5', '4', '3', '2', '1'] 76 | for row in range(8): 77 | lbl = QLabel(rank_labels[row], self) 78 | lbl.setAlignment(Qt.AlignCenter) 79 | lbl.setFont(font) 80 | self.layout.addWidget(lbl, row + 1, 0) 81 | 82 | for row in range(8): 83 | lbl = QLabel(rank_labels[row], self) 84 | lbl.setAlignment(Qt.AlignCenter) 85 | lbl.setFont(font) 86 | self.layout.addWidget(lbl, row + 1, 9) 87 | 88 | def place_piece(self, sqr_name, piece): 89 | if isinstance(piece, str): 90 | piece_symbol = piece 91 | else: 92 | piece_symbol = 
PIECE_CONVERSION.get(piece) 93 | 94 | if not piece_symbol: 95 | print(f'Unknown piece: {piece}') 96 | return 97 | 98 | piece_image = PIECE_MAP.get(piece_symbol) 99 | if not piece_image: 100 | print(f'Unknown piece: {piece_symbol}') 101 | return 102 | 103 | piece_label = PieceLabel(self, piece_image) 104 | pixmap_path = f'./playground/games/chess/assets/pieces/{piece_image}.png' # noqa 105 | if not os.path.exists(pixmap_path): 106 | print(f'Image not found: {pixmap_path}') 107 | else: 108 | piece_label.setPixmap(QPixmap(pixmap_path)) 109 | 110 | col, row = common.square_to_coords[sqr_name] 111 | self.layout.addWidget(piece_label, row + 1, col + 1) 112 | self.pieces[sqr_name] = piece_label 113 | 114 | def move_piece(self, src_sqr, dst_sqr): 115 | piece = self.pieces.get(src_sqr) 116 | if piece: 117 | dst_col, dst_row = common.square_to_coords[dst_sqr] 118 | dst_square = self.layout.itemAtPosition(dst_row + 1, 119 | dst_col + 1).widget() 120 | 121 | animation = QPropertyAnimation(piece, b'pos') 122 | animation.setEndValue(dst_square.pos()) 123 | animation.setDuration(500) 124 | animation.start() 125 | 126 | loop = QEventLoop() 127 | animation.finished.connect(loop.quit) 128 | loop.exec_() 129 | 130 | self.pieces[dst_sqr] = self.pieces.pop(src_sqr) 131 | piece.setParent(None) 132 | self.layout.removeWidget(piece) 133 | 134 | def reset_board(self): 135 | initial_positions = { 136 | 'a8': 'r', 137 | 'b8': 'n', 138 | 'c8': 'b', 139 | 'd8': 'q', 140 | 'e8': 'k', 141 | 'f8': 'b', 142 | 'g8': 'n', 143 | 'h8': 'r', 144 | 'a7': 'p', 145 | 'b7': 'p', 146 | 'c7': 'p', 147 | 'd7': 'p', 148 | 'e7': 'p', 149 | 'f7': 'p', 150 | 'g7': 'p', 151 | 'h7': 'p', 152 | 'a1': 'R', 153 | 'b1': 'N', 154 | 'c1': 'B', 155 | 'd1': 'Q', 156 | 'e1': 'K', 157 | 'f1': 'B', 158 | 'g1': 'N', 159 | 'h1': 'R', 160 | 'a2': 'P', 161 | 'b2': 'P', 162 | 'c2': 'P', 163 | 'd2': 'P', 164 | 'e2': 'P', 165 | 'f2': 'P', 166 | 'g2': 'P', 167 | 'h2': 'P' 168 | } 169 | 170 | self.clear() 171 | for sqr, piece in 
initial_positions.items(): 172 | self.place_piece(sqr, piece) 173 | 174 | def refresh_from_state(self): 175 | self.clear() 176 | for sqr_index in range(64): 177 | piece = self.position.piece_at(sqr_index) 178 | sqr_name = common.squares_san[sqr_index] 179 | if piece: 180 | self.place_piece(sqr_name, piece.symbol()) 181 | 182 | def clear(self): 183 | all_pieces = self.findChildren(QLabel) 184 | for piece in all_pieces: 185 | if isinstance(piece, PieceLabel): 186 | piece.setParent(None) 187 | 188 | 189 | class PieceLabel(QLabel): 190 | 191 | def __init__(self, parent, piece): 192 | super().__init__(parent) 193 | self.piece = piece 194 | self.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding) 195 | self.setMinimumSize(1, 1) 196 | self.setScaledContents(True) 197 | self.setMouseTracking(True) 198 | self.show() 199 | 200 | 201 | class SearchThread(QThread): 202 | move_signal = pyqtSignal(int) 203 | 204 | def __init__(self, board): 205 | super().__init__() 206 | self.board = board 207 | 208 | def run(self): 209 | move = self.board.search.iter_search(time_limit=1) 210 | self.board.parent().computer_move(move) 211 | -------------------------------------------------------------------------------- /playground/games/gomoku/AI.py: -------------------------------------------------------------------------------- 1 | class AI: 2 | 3 | def __init__(self, chessboard): 4 | self.chessboard = chessboard 5 | self.size = len(chessboard) 6 | self.count = 0 7 | 8 | def ai(self, color, deep, pre_evaluate): 9 | if deep >= 2: 10 | temp = self.evaluateBoard(2, self.chessboard) - self.evaluateBoard( 11 | 1, self.chessboard) 12 | return temp 13 | if color == 2: 14 | values = -100000000 15 | else: 16 | values = 100000000 17 | for i in range(15): 18 | for j in range(15): 19 | if self.chessboard[i][j][2] == 0: 20 | if self.judge_empty(i, j): 21 | continue 22 | self.chessboard[i][j][2] = color 23 | evaluate = self.ai(3 - color, deep + 1, values) 24 | if color == 2: 25 | if evaluate > 
pre_evaluate: 26 | self.chessboard[i][j][2] = 0 27 | self.count += 1 28 | return 100000000 29 | else: 30 | if evaluate < pre_evaluate: 31 | self.chessboard[i][j][2] = 0 32 | self.count += 1 33 | return -100000000 34 | if color == 2: 35 | if evaluate >= values: 36 | values = evaluate 37 | else: 38 | if evaluate <= values: 39 | values = evaluate 40 | self.chessboard[i][j][2] = 0 41 | return values 42 | 43 | def judge_empty(self, m, n): 44 | directions = [(-1, 0), (1, 0), (-1, 1), (1, -1), (0, 1), (0, -1), 45 | (1, 1), (-1, -1)] 46 | j = 0 47 | count = 1 48 | while j < len(directions): 49 | a = 0 50 | while a <= 1: 51 | x, y = m, n 52 | a += 1 53 | for i in range(2): 54 | if x + directions[j][0] < 0 or x + directions[j][ 55 | 0] > self.size - 1 or y + directions[j][ 56 | 1] < 0 or y + directions[j][1] > self.size - 1: 57 | break 58 | x += directions[j][0] 59 | y += directions[j][1] 60 | if self.chessboard[x][y][2] == 0: 61 | count += 1 62 | else: 63 | break 64 | j += 1 65 | if count == 17: 66 | return 1 67 | return 0 68 | 69 | def judge(self, m, n): 70 | directions = [(-1, 0), (1, 0), (-1, 1), (1, -1), (0, 1), (0, -1), 71 | (1, 1), (-1, -1)] 72 | j = 0 73 | while j < len(directions): 74 | count = 1 75 | a = 0 76 | while a <= 1: 77 | x, y = m, n 78 | a += 1 79 | for i in range(4): 80 | if x + directions[j][0] < 0 or x + directions[j][ 81 | 0] > self.size - 1 or y + directions[j][ 82 | 1] < 0 or y + directions[j][1] > self.size - 1: 83 | break 84 | x += directions[j][0] 85 | y += directions[j][1] 86 | if self.chessboard[x][y][2] == self.chessboard[m][n][2]: 87 | count += 1 88 | else: 89 | break 90 | j += 1 91 | if count >= 5: 92 | if self.chessboard[m][n][2] == 2: 93 | return 1 94 | return 0 95 | 96 | def evaluateBoard(self, color, chessboard): 97 | values = 0 98 | directions = [(-1, 0), (1, 0), (-1, 1), (1, -1), (0, 1), (0, -1), 99 | (1, 1), (-1, -1)] 100 | directions_2 = [(1, 0), (1, -1), (0, 1), (1, 1)] 101 | for row in range(self.size): 102 | for col in 
range(self.size): 103 | if chessboard[row][col][2] != color: 104 | continue 105 | j = 0 106 | while j < len(directions): 107 | count = 1 108 | a = 0 109 | record = [] 110 | while a <= 1: 111 | x, y = row, col 112 | a += 1 113 | for i in range(4): 114 | if x + directions[j][0] < 0 or x + directions[j][ 115 | 0] > self.size - 1 or y + directions[j][ 116 | 1] < 0 or y + directions[j][ 117 | 1] > self.size - 1: 118 | record.append(3 - color) 119 | break 120 | x += directions[j][0] 121 | y += directions[j][1] 122 | if chessboard[x][y][2] == chessboard[row][col][2]: 123 | count += 1 124 | else: 125 | record.append(chessboard[x][y][2]) 126 | break 127 | j += 1 128 | if count >= 5: 129 | values += 200000 130 | elif count == 4: 131 | if record[0] == record[1] == 0: 132 | values += 70000 133 | elif (record[0] == 0 and record[1] == (3 - color)) or ( 134 | record[0] == (3 - color) and record[1] == 0): 135 | values += 1000 136 | elif count == 3: 137 | if record[0] == record[1] == 0: 138 | values += 1000 139 | elif (record[0] == 0 and record[1] == (3 - color)) or ( 140 | record[0] == (3 - color) and record[1] == 0): 141 | values += 150 142 | elif count == 2: 143 | if record[0] == record[1] == 0: 144 | values += 1000 145 | elif (record[0] == 0 and record[1] == (3 - color)) or ( 146 | record[0] == (3 - color) and record[1] == 0): 147 | values += 150 148 | k = 0 149 | while k < len(directions_2): 150 | x, y = row, col 151 | record = [] 152 | record.append(chessboard[x][y][2]) 153 | for i in range(4): 154 | if i == 1 and len(record) == 2: 155 | if record[0] != record[1] and record[ 156 | 0] and record[1]: 157 | values += 10 158 | if x + directions_2[k][0] < 0 or x + directions_2[ 159 | k][0] > self.size - 1 or y + directions_2[ 160 | k][1] < 0 or y + directions_2[k][ 161 | 1] > self.size - 1: 162 | break 163 | x += directions_2[k][0] 164 | y += directions_2[k][1] 165 | record.append(chessboard[x][y][2]) 166 | if len(record) == 5: 167 | count = record.count(0) 168 | if (count == 1 
and record[1] == 0 169 | and record.count(color) 170 | == 4) or (count == 1 and record[3] == 0 171 | and record.count(color) == 4): 172 | values += 3000 173 | if count == 1 and record[2] == 0 and record.count( 174 | color) == 4: 175 | values += 2600 176 | k += 1 177 | return values 178 | -------------------------------------------------------------------------------- /playground/games/chess/common/movegen.py: -------------------------------------------------------------------------------- 1 | from gmpy2 import bit_scan1 2 | 3 | from playground.games.chess.common.attack_tables import (batk_table, bishop_masks, 4 | pseudo_attacks, 5 | ratk_table, rook_masks) 6 | from playground.games.chess.common.common import pawn_shift 7 | from playground.games.chess.common.consts import ( 8 | A_FILE_BB, ALL, BISHOP_PROMOTION, CAPTURES, H_FILE_BB, KING, KNIGHT, 9 | KNIGHT_PROMOTION, NORTH, NORTHEAST, NORTHWEST, PROMOTION, QUEEN_PROMOTION, 10 | QUIETS, RANK_2_BB, RANK_3_BB, RANK_6_BB, RANK_7_BB, ROOK_PROMOTION, WHITE) 11 | 12 | 13 | def get_king_moves(sq, colour, move_type, occ, player_occ, move_list): 14 | moves = pseudo_attacks[KING][sq] 15 | 16 | if move_type == QUIETS: 17 | moves &= ~occ 18 | elif move_type == CAPTURES: 19 | moves &= player_occ[colour ^ 1] 20 | elif move_type == ALL: 21 | moves &= ~player_occ[colour] 22 | 23 | while moves: 24 | index = bit_scan1(moves) 25 | moves &= moves - 1 26 | move_list.append((sq << 6) + index) 27 | 28 | 29 | def get_knight_moves(sq, colour, move_type, occ, player_occ, move_list): 30 | moves = pseudo_attacks[KNIGHT][sq] 31 | 32 | if move_type == QUIETS: 33 | moves &= ~occ 34 | elif move_type == CAPTURES: 35 | moves &= player_occ[colour ^ 1] 36 | elif move_type == ALL: 37 | moves &= ~player_occ[colour] 38 | 39 | while moves: 40 | index = bit_scan1(moves) 41 | moves &= moves - 1 42 | move_list.append((sq << 6) + index) 43 | 44 | 45 | def get_pawn_moves(bitboard, colour, move_type, occ, player_occ, move_list): 46 | empty = ~occ 47 | enemy_occ 
= player_occ[colour ^ 1] 48 | 49 | if colour == WHITE: 50 | third_rank = RANK_3_BB 51 | seventh_rank = RANK_7_BB 52 | one_step_shift = NORTH 53 | two_step_shift = one_step_shift + NORTH 54 | left_atk_shift = 7 55 | right_atk_shift = 9 56 | left_file = A_FILE_BB 57 | right_file = H_FILE_BB 58 | else: 59 | third_rank = RANK_6_BB 60 | seventh_rank = RANK_2_BB 61 | one_step_shift = -NORTH 62 | two_step_shift = one_step_shift - NORTH 63 | left_atk_shift = -7 64 | right_atk_shift = -9 65 | left_file = H_FILE_BB 66 | right_file = A_FILE_BB 67 | 68 | pawns_not_promoting = bitboard & ~seventh_rank 69 | promoting_pawns = bitboard & seventh_rank 70 | 71 | # Captures and queen promotions 72 | if move_type == CAPTURES or move_type == ALL: 73 | left_atk = pawn_shift[colour]( 74 | (pawns_not_promoting & ~left_file), NORTHWEST) & enemy_occ 75 | right_atk = pawn_shift[colour]( 76 | (pawns_not_promoting & ~right_file), NORTHEAST) & enemy_occ 77 | 78 | promoted_push = pawn_shift[colour](promoting_pawns, NORTH) & empty 79 | promoted_left_atk = pawn_shift[colour]( 80 | (promoting_pawns & ~left_file), NORTHWEST) & enemy_occ 81 | promoted_right_atk = pawn_shift[colour]( 82 | (promoting_pawns & ~right_file), NORTHEAST) & enemy_occ 83 | 84 | while left_atk: 85 | index = bit_scan1(left_atk) 86 | left_atk &= left_atk - 1 87 | move_list.append(((index - left_atk_shift) << 6) + index) 88 | 89 | while right_atk: 90 | index = bit_scan1(right_atk) 91 | right_atk &= right_atk - 1 92 | move_list.append(((index - right_atk_shift) << 6) + index) 93 | 94 | while promoted_push: 95 | index = bit_scan1(promoted_push) 96 | promoted_push &= promoted_push - 1 97 | from_to = ((index - one_step_shift) << 6) + index 98 | move_list.append(PROMOTION + QUEEN_PROMOTION + from_to) 99 | 100 | while promoted_left_atk: 101 | index = bit_scan1(promoted_left_atk) 102 | promoted_left_atk &= promoted_left_atk - 1 103 | from_to = ((index - left_atk_shift) << 6) + index 104 | move_list.append(PROMOTION + QUEEN_PROMOTION + 
from_to) 105 | 106 | while promoted_right_atk: 107 | index = bit_scan1(promoted_right_atk) 108 | promoted_right_atk &= promoted_right_atk - 1 109 | from_to = ((index - right_atk_shift) << 6) + index 110 | move_list.append(PROMOTION + QUEEN_PROMOTION + from_to) 111 | 112 | # Non-captures and underpromotions 113 | if move_type == QUIETS or move_type == ALL: 114 | one_step = pawn_shift[colour](pawns_not_promoting, NORTH) & empty 115 | two_steps = pawn_shift[colour]((one_step & third_rank), NORTH) & empty 116 | 117 | promoted_push = pawn_shift[colour](promoting_pawns, NORTH) & empty 118 | promoted_left_atk = pawn_shift[colour]( 119 | (promoting_pawns & ~left_file), NORTHWEST) & enemy_occ 120 | promoted_right_atk = pawn_shift[colour]( 121 | (promoting_pawns & ~right_file), NORTHEAST) & enemy_occ 122 | 123 | while one_step: 124 | index = bit_scan1(one_step) 125 | one_step &= one_step - 1 126 | move_list.append(((index - one_step_shift) << 6) + index) 127 | 128 | while two_steps: 129 | index = bit_scan1(two_steps) 130 | two_steps &= two_steps - 1 131 | move_list.append(((index - two_step_shift) << 6) + index) 132 | 133 | while promoted_push: 134 | index = bit_scan1(promoted_push) 135 | promoted_push &= promoted_push - 1 136 | generate_underpromotions(index - one_step_shift, index, move_list) 137 | 138 | while promoted_left_atk: 139 | index = bit_scan1(promoted_left_atk) 140 | promoted_left_atk &= promoted_left_atk - 1 141 | generate_underpromotions(index - left_atk_shift, index, move_list) 142 | 143 | while promoted_right_atk: 144 | index = bit_scan1(promoted_right_atk) 145 | promoted_right_atk &= promoted_right_atk - 1 146 | generate_underpromotions(index - right_atk_shift, index, move_list) 147 | 148 | 149 | def get_rook_moves(sq, colour, move_type, occ, player_occ, move_list): 150 | moves = ratk_table[sq][occ & rook_masks[sq]] 151 | 152 | if move_type == QUIETS: 153 | moves &= ~occ 154 | elif move_type == CAPTURES: 155 | moves &= player_occ[colour ^ 1] 156 | elif 
move_type == ALL: 157 | moves &= ~player_occ[colour] 158 | 159 | while moves: 160 | index = bit_scan1(moves) 161 | moves &= moves - 1 162 | move_list.append((sq << 6) + index) 163 | 164 | 165 | def get_bishop_moves(sq, colour, move_type, occ, player_occ, move_list): 166 | moves = batk_table[sq][occ & bishop_masks[sq]] 167 | 168 | if move_type == QUIETS: 169 | moves &= ~occ 170 | elif move_type == CAPTURES: 171 | moves &= player_occ[colour ^ 1] 172 | elif move_type == ALL: 173 | moves &= ~player_occ[colour] 174 | 175 | while moves: 176 | index = bit_scan1(moves) 177 | moves &= moves - 1 178 | move_list.append((sq << 6) + index) 179 | 180 | 181 | def get_queen_moves(sq, colour, move_type, occ, player_occ, move_list): 182 | moves = batk_table[sq][occ & bishop_masks[sq]] | ratk_table[sq][ 183 | occ & rook_masks[sq]] 184 | 185 | if move_type == QUIETS: 186 | moves &= ~occ 187 | elif move_type == CAPTURES: 188 | moves &= player_occ[colour ^ 1] 189 | elif move_type == ALL: 190 | moves &= ~player_occ[colour] 191 | 192 | while moves: 193 | index = bit_scan1(moves) 194 | moves &= moves - 1 195 | move_list.append((sq << 6) + index) 196 | 197 | 198 | def generate_promotions(src_sq, dst_sq, move_list): 199 | from_to = (src_sq << 6) + dst_sq 200 | move_list.append(PROMOTION + KNIGHT_PROMOTION + from_to) 201 | move_list.append(PROMOTION + BISHOP_PROMOTION + from_to) 202 | move_list.append(PROMOTION + ROOK_PROMOTION + from_to) 203 | move_list.append(PROMOTION + QUEEN_PROMOTION + from_to) 204 | 205 | 206 | def generate_underpromotions(src_sq, dst_sq, move_list): 207 | from_to = (src_sq << 6) + dst_sq 208 | move_list.append(PROMOTION + KNIGHT_PROMOTION + from_to) 209 | move_list.append(PROMOTION + BISHOP_PROMOTION + from_to) 210 | move_list.append(PROMOTION + ROOK_PROMOTION + from_to) 211 | -------------------------------------------------------------------------------- /playground/games/tictactoe/tictactoe.py: 
-------------------------------------------------------------------------------- 1 | import random 2 | import re 3 | from random import sample 4 | 5 | from PyQt5.QtGui import QFont, QPainter, QPixmap 6 | from PyQt5.QtWidgets import QMainWindow 7 | 8 | from playground.games import BaseGame, BaseGameLogic 9 | from playground.games.tictactoe.AI import Minimax 10 | from playground.games.tictactoe.tictactoe_ui import Ui_MainWindow 11 | from playground.registry import GAME_REGISTRY 12 | from playground.state_code import GameStatus 13 | 14 | 15 | class TicTacToeLogic(BaseGameLogic): 16 | """Logic for Tic Tac Toe game.""" 17 | 18 | def __init__(self, game_cfg): 19 | self.game_cfg = game_cfg 20 | self.board = [i + 1 for i in range(9)] 21 | self.bot = None 22 | self.opponent = None 23 | self.winner = None 24 | self.is_finish = False 25 | self.status = GameStatus.IN_PROGRESS 26 | self.moves_history = [] 27 | self._initialize_players() 28 | 29 | def _initialize_players(self): 30 | players = sample(['X', 'O'], 2) 31 | self.bot = players[0] 32 | self.opponent = players[1] if self.game_cfg.player_first else players[ 33 | 0] # noqa 34 | self.bot = players[0] if self.game_cfg.player_first else players[1] 35 | 36 | def make_move(self, index, player): 37 | if self.board[index] not in ['X', 'O'] and not self.is_finish: 38 | self.board[index] = player 39 | self._check_winner() 40 | return True 41 | return False 42 | 43 | def _check_winner(self): 44 | win_positions = [(0, 1, 2), (3, 4, 5), (6, 7, 8), (0, 3, 6), (1, 4, 7), 45 | (2, 5, 8), (0, 4, 8), (2, 4, 6)] 46 | for pos in win_positions: 47 | if self.board[pos[0]] == self.board[pos[1]] == self.board[pos[2]]: 48 | self.winner = self.board[pos[0]] 49 | self.is_finish = True 50 | self.status = GameStatus.WIN if self.winner == self.opponent else GameStatus.LOSE # noqa 51 | return 52 | if all(cell in ['X', 'O'] for cell in self.board) and not self.winner: 53 | self.is_finish = True 54 | self.status = GameStatus.TIE 55 | 56 | def 
input_move(self, move): 57 | if self.status != GameStatus.IN_PROGRESS: 58 | return self.status 59 | col_map = {'1': 0, '2': 1, '3': 2} 60 | row_map = {'A': 0, 'B': 1, 'C': 2} 61 | match = re.match(r'([A-Ca-c])([1-3])|([1-3])([A-Ca-c])', move) 62 | if match: 63 | row = match.group(1).upper() if match.group(1) else match.group( 64 | 4).upper() 65 | col = match.group(2) if match.group(2) else match.group(3) 66 | index = row_map[row] * 3 + col_map[col] 67 | if self.make_move(index, self.opponent): 68 | self.moves_history.append(move) 69 | return self.status 70 | self.status = GameStatus.INVALID_MOVE 71 | return self.status 72 | 73 | def get_game_status(self): 74 | return self.status 75 | 76 | def reset_board(self): 77 | self.board = [i + 1 for i in range(9)] 78 | self.winner = None 79 | self.is_finish = False 80 | self.status = GameStatus.IN_PROGRESS 81 | self.moves_history = [] 82 | 83 | def get_random_state(self): 84 | self.reset_board() 85 | positions = [1, 0, -1] 86 | random_state = sample(positions * 3, 9) 87 | if not any(value == -1 for value in random_state): 88 | rand_index = random.randint(0, 8) 89 | random_state[rand_index] = -1 90 | for i, value in enumerate(random_state): 91 | if value == 1: 92 | self.board[i] = 'X' 93 | elif value == 0: 94 | self.board[i] = 'O' 95 | return [random_state[i:i + 3] for i in range(0, 9, 3)] 96 | 97 | def get_rule_state(self): 98 | self.reset_board() 99 | while True: 100 | positions = [-1] * 9 101 | x_count = random.randint(1, 5) 102 | o_count = x_count if random.choice([True, False]) else x_count - 1 103 | if o_count < 0: 104 | o_count = 0 105 | if x_count + o_count >= 9: 106 | continue 107 | positions[:x_count] = [1] * x_count 108 | positions[x_count:x_count + o_count] = [0] * o_count 109 | random.shuffle(positions) 110 | for i, val in enumerate(positions): 111 | if val == 1: 112 | self.board[i] = 'X' 113 | elif val == 0: 114 | self.board[i] = 'O' 115 | self._check_winner() 116 | if self.is_finish: 117 | self.reset_board() 
118 | continue 119 | board_state = [positions[i:i + 3] for i in range(0, 9, 3)] 120 | valid_movements = [] 121 | row_map = {0: 'A', 1: 'B', 2: 'C'} 122 | col_map = {0: '1', 1: '2', 2: '3'} 123 | for i, val in enumerate(positions): 124 | if val == -1: 125 | r, c = divmod(i, 3) 126 | move_str = row_map[r] + col_map[c] 127 | valid_movements.append(move_str) 128 | return board_state, valid_movements 129 | 130 | def calculate_score(self): 131 | """Calculate score based on steps taken and game outcome.""" 132 | player_steps = len(self.moves_history) 133 | base_score = player_steps * 10 134 | bonus_score = 0 135 | if self.status == GameStatus.WIN: 136 | bonus_score = 50 137 | elif self.status == GameStatus.TIE: 138 | bonus_score = 20 139 | total_score = base_score + bonus_score 140 | return total_score 141 | 142 | def parse_e2e(self, lmm_output): 143 | """Parse e2e output to a move.""" 144 | match = re.search(r'Movement:\s*([A-Ca-c][1-3]|[1-3][A-Ca-c])', 145 | lmm_output, re.IGNORECASE) 146 | if match: 147 | move = match.group(1).upper() 148 | if move[0].isdigit(): 149 | move = move[1] + move[0] 150 | return move 151 | return GameStatus.INVALID_MOVE 152 | 153 | 154 | class TicTacToeRenderer(QMainWindow): 155 | """Renderer for Tic Tac Toe UI.""" 156 | 157 | def __init__(self, logic): 158 | super().__init__() 159 | self.ui = Ui_MainWindow() 160 | self.ui.setupUi(self) 161 | self.logic = logic 162 | self.select_font = QFont() 163 | self.select_font.setPointSize(35) 164 | self._update_ui() 165 | 166 | def _update_ui(self): 167 | color_map = {'X': 'red', 'O': 'blue'} 168 | color = color_map.get(self.logic.opponent, 'black') 169 | self.ui.label_2.setText( 170 | f'You are playing as {self.logic.opponent}' # noqa 171 | ) 172 | for i, cell in enumerate(self.logic.board): 173 | button = self.ui.buttons[i] 174 | if cell in ['X', 'O']: 175 | button.setFont(self.select_font) 176 | button.setText(cell) 177 | button.setStyleSheet('color:blue' if cell == 178 | 'O' else 'color:red') 179 | 
else: 180 | button.setText('') 181 | button.setStyleSheet('') 182 | 183 | def get_screenshot(self): 184 | board_width = 500 185 | board_height = 600 186 | screenshot = QPixmap(board_width, board_height) 187 | painter = QPainter(screenshot) 188 | self.render(painter) 189 | painter.end() 190 | return screenshot 191 | 192 | 193 | @GAME_REGISTRY.register('tictactoe') 194 | class TicTacToe(BaseGame): 195 | AI_component = True 196 | 197 | def __init__(self, game_cfg): 198 | super().__init__(game_cfg) 199 | self.logic = TicTacToeLogic(game_cfg) 200 | self.renderer = None 201 | self.minimax = Minimax( 202 | self.logic.bot, 203 | self.logic.opponent) if game_cfg.player_first else None 204 | if not game_cfg.player_first: 205 | self.ai_move() 206 | 207 | def get_screenshot(self): 208 | if self.renderer is None: 209 | self.renderer = TicTacToeRenderer(self.logic) 210 | self.renderer._update_ui() 211 | return self.renderer.get_screenshot() 212 | 213 | def input_move(self, move): 214 | return self.logic.input_move(move) 215 | 216 | def get_game_status(self): 217 | return self.logic.get_game_status() 218 | 219 | def get_random_state(self): 220 | return self.logic.get_random_state() 221 | 222 | def get_rule_state(self): 223 | return self.logic.get_rule_state() 224 | 225 | def ai_move(self): 226 | if not self.AI_component or self.logic.is_finish: 227 | return None 228 | game_match = self.minimax.generate_2d(self.logic.board) 229 | move_index = self.minimax.find_best_move(game_match) 230 | if self.logic.make_move(move_index, self.logic.bot): 231 | return f'{chr(65 + move_index // 3)}{move_index % 3 + 1}' 232 | return None 233 | 234 | def calculate_score(self): 235 | return self.logic.calculate_score() 236 | 237 | def parse_e2e(self, lmm_output): 238 | return self.logic.parse_e2e(lmm_output) 239 | -------------------------------------------------------------------------------- /playground/games/chess/chess.py: 
-------------------------------------------------------------------------------- 1 | import random 2 | import re 3 | 4 | import chess 5 | import chess.engine 6 | from PyQt5.QtWidgets import QMainWindow 7 | 8 | from playground.games import BaseGame, BaseGameLogic 9 | from playground.games.chess.chess_ui import ChessUI 10 | from playground.registry import GAME_REGISTRY 11 | from playground.state_code import GameStatus 12 | 13 | 14 | class ChessLogic(BaseGameLogic): 15 | """Pure logic for Chess game.""" 16 | 17 | def __init__(self, game_cfg): 18 | self.game_cfg = game_cfg 19 | self.user_is_white = game_cfg.user_is_white 20 | self.board = chess.Board() 21 | self.status = GameStatus.IN_PROGRESS 22 | self.turn = 'white' if self.user_is_white else 'black' 23 | self.moves_history = [] 24 | 25 | def make_move(self, move, is_ai=False): 26 | """Make a move on the board.""" 27 | expected_turn = 'white' if self.user_is_white else 'black' 28 | if is_ai: 29 | expected_turn = 'black' if self.user_is_white else 'white' 30 | if self.turn != expected_turn: 31 | return False 32 | try: 33 | if len(move) > 2 and move[0] in 'NBRQK': 34 | adjusted_move = move[0] + move[1:].lower() 35 | else: 36 | adjusted_move = move.lower() 37 | chess_move = self.board.parse_san(adjusted_move) 38 | if chess_move not in self.board.legal_moves: 39 | return False 40 | self.board.push(chess_move) 41 | self.moves_history.append(chess_move) 42 | self.turn = 'black' if self.turn == 'white' else 'white' 43 | self._update_game_status() 44 | return True 45 | except ValueError: 46 | return False 47 | 48 | def _update_game_status(self): 49 | """Update game status based on current board state.""" 50 | if self.board.is_checkmate(): 51 | self.status = GameStatus.LOSE if self.turn == ( 52 | 'white' if self.user_is_white else 'black') else GameStatus.WIN 53 | elif (self.board.is_insufficient_material() 54 | or self.board.is_stalemate() or self.board.is_repetition(count=3) 55 | or self.board.halfmove_clock >= 100): 56 | 
self.status = GameStatus.TIE 57 | else: 58 | self.status = GameStatus.IN_PROGRESS 59 | 60 | def input_move(self, move): 61 | """Process player move in SAN format (e.g., 'e4').""" 62 | if self.status != GameStatus.IN_PROGRESS: 63 | return self.status 64 | if self.make_move(move, is_ai=False): 65 | return self.status 66 | return GameStatus.INVALID_MOVE 67 | 68 | def get_game_status(self): 69 | return self.status 70 | 71 | def reset_board(self): 72 | """Reset the board to initial state.""" 73 | self.board = chess.Board() 74 | self.status = GameStatus.IN_PROGRESS 75 | self.turn = 'white' if self.user_is_white else 'black' 76 | self.moves_history = [] 77 | 78 | def get_random_state(self): 79 | """Generate a random game state.""" 80 | self.reset_board() 81 | num_moves = random.randint(5, 55) 82 | for _ in range(num_moves): 83 | legal_moves = list(self.board.legal_moves) 84 | if not legal_moves: 85 | break 86 | move = random.choice(legal_moves) 87 | self.board.push(move) 88 | self.moves_history.append(move) 89 | 90 | piece_to_numeric = { 91 | chess.PAWN: 1, 92 | chess.KNIGHT: 2, 93 | chess.BISHOP: 3, 94 | chess.ROOK: 4, 95 | chess.QUEEN: 5, 96 | chess.KING: 6, 97 | None: 0 98 | } 99 | board_matrix = [[0 for _ in range(8)] for _ in range(8)] 100 | for i in range(64): 101 | piece = self.board.piece_at(i) 102 | if piece: 103 | value = piece_to_numeric[piece.piece_type] 104 | if piece.color == chess.BLACK: 105 | value = -value 106 | board_matrix[7 - (i // 8)][i % 8] = value 107 | 108 | self._update_game_status() 109 | return board_matrix 110 | 111 | def get_rule_state(self): 112 | """Generate a rule state with valid movements.""" 113 | self.reset_board() 114 | num_moves = random.randint(5, 55) 115 | for _ in range(num_moves): 116 | legal_moves = list(self.board.legal_moves) 117 | if not legal_moves: 118 | break 119 | move = random.choice(legal_moves) 120 | self.board.push(move) 121 | self.moves_history.append(move) 122 | 123 | while self.board.turn != (chess.WHITE 124 | if 
self.user_is_white else chess.BLACK): 125 | legal_moves = list(self.board.legal_moves) 126 | if not legal_moves: 127 | break 128 | move = random.choice(legal_moves) 129 | self.board.push(move) 130 | self.moves_history.append(move) 131 | 132 | fen = self.board.fen() 133 | valid_movements = [self.board.san(m) for m in self.board.legal_moves] 134 | self._update_game_status() 135 | return fen, valid_movements 136 | 137 | def calculate_score(self): 138 | """Calculate score based on steps and captured enemy pieces.""" 139 | white_steps = len( 140 | [m for i, m in enumerate(self.moves_history) if i % 2 == 0]) 141 | step_score = white_steps * 10 142 | 143 | piece_values = { 144 | chess.PAWN: 1, 145 | chess.KNIGHT: 3, 146 | chess.BISHOP: 3, 147 | chess.ROOK: 5, 148 | chess.QUEEN: 9, 149 | chess.KING: 0 150 | } 151 | 152 | initial_black_piece_value = (8 * piece_values[chess.PAWN] + 153 | 2 * piece_values[chess.KNIGHT] + 154 | 2 * piece_values[chess.BISHOP] + 155 | 2 * piece_values[chess.ROOK] + 156 | 1 * piece_values[chess.QUEEN]) 157 | 158 | black_piece_value = 0 159 | for square in chess.SQUARES: 160 | piece = self.board.piece_at(square) 161 | if piece and piece.color == chess.BLACK: 162 | black_piece_value += piece_values.get(piece.piece_type, 0) 163 | 164 | captured_black_value = initial_black_piece_value - black_piece_value 165 | piece_bonus = captured_black_value * 5 166 | 167 | outcome_bonus = 0 168 | if self.status == GameStatus.WIN: 169 | outcome_bonus = 1000 170 | elif self.status == GameStatus.TIE: 171 | outcome_bonus = 500 172 | 173 | total_score = step_score + piece_bonus + outcome_bonus 174 | return total_score 175 | 176 | def parse_e2e(self, lmm_output): 177 | """Parse e2e output to a move in SAN format.""" 178 | match = re.search( 179 | r'Movement:\s*([a-hA-H][1-8][a-hA-H][1-8]|[a-hA-H][1-8]|O-O|O-O-O|(?:N|B|R|Q|K)?[a-hA-H]?[1-8]?x?[a-hA-H][1-8](?:=[QRNB])?|(?:N|B|R|Q|K)[a-hA-H][1-8])', # noqa 180 | lmm_output, 181 | re.IGNORECASE) 182 | if match: 183 | 
return match.group(1) 184 | return GameStatus.INVALID_MOVE 185 | 186 | 187 | class ChessRenderer(QMainWindow): 188 | """Renderer for Chess UI.""" 189 | 190 | def __init__(self, logic): 191 | super().__init__() 192 | self.logic = logic 193 | self.ui = ChessUI(self, user_is_white=self.logic.user_is_white) 194 | self.setCentralWidget(self.ui) 195 | self.ui.position = self.logic.board 196 | self.ui.reset_board() 197 | 198 | def get_screenshot(self): 199 | """Generate screenshot of the current board.""" 200 | self.ui.position = self.logic.board 201 | self.ui.refresh_from_state() 202 | screenshot = self.ui.grab() 203 | return screenshot 204 | 205 | 206 | @GAME_REGISTRY.register('chess') 207 | class Chess(BaseGame): 208 | AI_component = True 209 | 210 | def __init__(self, game_cfg): 211 | super().__init__(game_cfg) 212 | self.logic = ChessLogic(game_cfg) 213 | self.renderer = None 214 | self.engine = chess.engine.SimpleEngine.popen_uci( 215 | '/usr/games/stockfish') 216 | 217 | def __del__(self): 218 | """Cleanup engine resources.""" 219 | if hasattr(self, 'engine'): 220 | self.engine.quit() 221 | 222 | def get_screenshot(self): 223 | if self.renderer is None: 224 | self.renderer = ChessRenderer(self.logic) 225 | return self.renderer.get_screenshot() 226 | 227 | def input_move(self, move): 228 | return self.logic.input_move(move) 229 | 230 | def get_game_status(self): 231 | return self.logic.get_game_status() 232 | 233 | def get_random_state(self): 234 | return self.logic.get_random_state() 235 | 236 | def get_rule_state(self): 237 | return self.logic.get_rule_state() 238 | 239 | def ai_move(self): 240 | """Calculate and apply AI move using Stockfish.""" 241 | if not self.AI_component or self.logic.status != GameStatus.IN_PROGRESS: # noqa 242 | return None 243 | 244 | result = self.engine.play(self.logic.board, 245 | chess.engine.Limit(time=1.0)) 246 | chess_move = result.move 247 | if chess_move in self.logic.board.legal_moves: 248 | san_move = 
self.logic.board.san(chess_move) 249 | if self.logic.make_move(san_move, is_ai=True): 250 | return san_move 251 | return None 252 | 253 | def calculate_score(self): 254 | return self.logic.calculate_score() 255 | 256 | def parse_e2e(self, lmm_output): 257 | return self.logic.parse_e2e(lmm_output) 258 | -------------------------------------------------------------------------------- /playground/games/sudoku/sudoku.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import re 3 | import time 4 | 5 | from PyQt5.QtCore import QTimer 6 | from PyQt5.QtGui import QPainter, QPixmap 7 | from PyQt5.QtWidgets import QMainWindow 8 | 9 | from playground.games import BaseGame, BaseGameLogic 10 | from playground.games.sudoku import sudoku_generator 11 | from playground.games.sudoku.sudoku_ui import SudokuUI 12 | from playground.registry import GAME_REGISTRY 13 | from playground.state_code import GameStatus 14 | 15 | 16 | class SudokuLogic(BaseGameLogic): 17 | """Pure logic for Sudoku game.""" 18 | 19 | def __init__(self, game_cfg): 20 | self.game_cfg = game_cfg 21 | self.b_size = 9 22 | self.solution = [[0 for _ in range(self.b_size)] 23 | for _ in range(self.b_size)] 24 | self.puzzle = [] 25 | self.assigned = [[False for _ in range(self.b_size)] 26 | for _ in range(self.b_size)] 27 | self.status = GameStatus.IN_PROGRESS 28 | self.moves_history = [] 29 | self.timer_start = int(time.time()) 30 | self.pause_time = 0 31 | self.start_game() 32 | 33 | def start_game(self): 34 | """Initialize the Sudoku puzzle.""" 35 | self.status = GameStatus.IN_PROGRESS 36 | solution = [[0 for _ in range(self.b_size)] 37 | for _ in range(self.b_size)] 38 | sudoku_generator.fillGrid(solution) 39 | self.solution = copy.deepcopy(solution) 40 | self.puzzle = sudoku_generator.generate_puzzle(solution, 5) 41 | self.assigned = [[self.puzzle[y][x] != 0 for x in range(self.b_size)] 42 | for y in range(self.b_size)] 43 | self.moves_history = [] 44 | 45 | def 
input_move(self, move): 46 | """Process move in format 'A1 5' (row A, col 1, number 5).""" 47 | if self.status != GameStatus.IN_PROGRESS: 48 | return self.status 49 | 50 | match = re.match(r'([A-Ia-i])([1-9])\s([1-9])', move) 51 | if not match: 52 | return GameStatus.INVALID_MOVE 53 | 54 | row = ord(match.group(1).upper()) - ord('A') 55 | col = int(match.group(2)) - 1 56 | number = int(match.group(3)) 57 | 58 | if not (0 <= row < self.b_size and 0 <= col < self.b_size): 59 | return GameStatus.INVALID_MOVE 60 | 61 | if self.assigned[row][col]: 62 | return GameStatus.INVALID_MOVE 63 | 64 | if self.puzzle[row][col] != 0: 65 | return GameStatus.INVALID_MOVE 66 | 67 | if not self._is_valid_move(row, col, number): 68 | return GameStatus.INVALID_MOVE 69 | 70 | self.moves_history.append(move) 71 | self.puzzle[row][col] = number 72 | self._check_win() 73 | return self.status 74 | 75 | def _is_valid_move(self, row, col, number): 76 | """Check if placing number at (row, col) is valid.""" 77 | if number in self.puzzle[row]: 78 | return False 79 | if number in [self.puzzle[r][col] for r in range(self.b_size)]: 80 | return False 81 | start_row, start_col = (row // 3) * 3, (col // 3) * 3 82 | for y in range(start_row, start_row + 3): 83 | for x in range(start_col, start_col + 3): 84 | if self.puzzle[y][x] == number: 85 | return False 86 | return True 87 | 88 | def _check_win(self): 89 | """Check if the puzzle is solved correctly.""" 90 | for y in range(self.b_size): 91 | for x in range(self.b_size): 92 | if self.puzzle[y][ 93 | x] == 0 or self.puzzle[y][x] != self.solution[y][x]: 94 | return 95 | self.status = GameStatus.WIN 96 | 97 | def get_game_status(self): 98 | return self.status 99 | 100 | def get_random_state(self): 101 | return copy.deepcopy(self.puzzle) 102 | 103 | def get_rule_state(self): 104 | valid_movements = [] 105 | for y in range(self.b_size): 106 | for x in range(self.b_size): 107 | if self.puzzle[y][x] == 0: 108 | candidates = set(range(1, 10)) 109 | for col in 
range(self.b_size): 110 | if self.puzzle[y][col] != 0: 111 | candidates.discard(self.puzzle[y][col]) 112 | for row in range(self.b_size): 113 | if self.puzzle[row][x] != 0: 114 | candidates.discard(self.puzzle[row][x]) 115 | start_row, start_col = (y // 3) * 3, (x // 3) * 3 116 | for row in range(start_row, start_row + 3): 117 | for col in range(start_col, start_col + 3): 118 | if self.puzzle[row][col] != 0: 119 | candidates.discard(self.puzzle[row][col]) 120 | for num in sorted(candidates): 121 | valid_movements.append( 122 | f"{chr(y + ord('A'))}{x + 1} {num}") 123 | return self.puzzle, valid_movements 124 | 125 | def calculate_score(self): 126 | """Calculate score based on filled numbers""" 127 | base_score = sum( 128 | 1 for y in range(self.b_size) for x in range(self.b_size) 129 | if self.puzzle[y][x] != 0 and not self.assigned[y][x]) * 2 130 | 131 | correct_count = sum( 132 | 1 for y in range(self.b_size) for x in range(self.b_size) 133 | if self.puzzle[y][x] != 0 and not self.assigned[y][x] 134 | and self.puzzle[y][x] == self.solution[y][x]) 135 | correct_score = correct_count * 10 136 | 137 | outcome_bonus = 1000 if self.status == GameStatus.WIN else 0 138 | 139 | total_score = base_score + correct_score + outcome_bonus 140 | return total_score 141 | 142 | def parse_e2e(self, lmm_output): 143 | """Parse e2e output to a move.""" 144 | match = re.search(r'Movement:\s*([A-Ia-i][1-9]\s[1-9])', lmm_output, 145 | re.IGNORECASE) 146 | if match: 147 | move = match.group(1).upper() 148 | return move 149 | return GameStatus.INVALID_MOVE 150 | 151 | 152 | class SudokuRenderer(QMainWindow): 153 | """Renderer for Sudoku UI.""" 154 | 155 | def __init__(self, logic): 156 | super().__init__() 157 | self.logic = logic 158 | self.ui = SudokuUI(self) 159 | self.setCentralWidget(self.ui.centralwidget) 160 | self._update_ui_from_logic() 161 | self.adjust_window_size() 162 | self.timer = QTimer(self) 163 | self.timer.timeout.connect(self.update_time) 164 | self.timer.start(1000) 
165 | self.show() 166 | 167 | def adjust_window_size(self): 168 | self.setFixedSize(550, 700) 169 | 170 | def _update_ui_from_logic(self): 171 | """Sync UI with logic state.""" 172 | for y in range(self.logic.b_size): 173 | for x in range(self.logic.b_size): 174 | btn = self.ui.puzzle_buttons[y][x] 175 | if self.logic.puzzle[y][x] != 0: 176 | btn.setText(str(self.logic.puzzle[y][x])) 177 | if self.logic.assigned[y][x]: 178 | btn.setStyleSheet( 179 | 'color: black; background-color: white; font-family: sans-serif; font-size: 25px; border: 1px solid black;' # noqa 180 | ) 181 | else: 182 | btn.setStyleSheet( 183 | 'color: blue; background-color: white; font-family: sans-serif; font-size: 25px; border: 1px solid black;' # noqa 184 | ) 185 | btn.setDisabled(True) 186 | else: 187 | btn.setText('') 188 | btn.setStyleSheet( 189 | 'color: black; background-color: white; font-family: sans-serif; font-size: 25px; border: 1px solid black;' # noqa 190 | ) 191 | btn.setEnabled(True) 192 | 193 | def update_time(self): 194 | """Update the timer display.""" 195 | elapsed_time = int(time.time() - self.logic.timer_start - 196 | self.logic.pause_time) 197 | self.ui.show_time.setText(self.time_int_to_string(elapsed_time)) 198 | 199 | def time_int_to_string(self, i): 200 | return time.strftime('%H:%M:%S', time.gmtime(i)) 201 | 202 | def get_screenshot(self): 203 | """Generate screenshot of the entire window.""" 204 | self._update_ui_from_logic() 205 | screenshot = QPixmap(self.width(), self.height()) 206 | painter = QPainter(screenshot) 207 | painter.setRenderHint(QPainter.Antialiasing) 208 | self.render(painter) 209 | painter.end() 210 | return screenshot 211 | 212 | 213 | @GAME_REGISTRY.register('sudoku') 214 | class Sudoku(BaseGame): 215 | AI_component = False 216 | 217 | def __init__(self, game_cfg): 218 | super().__init__(game_cfg) 219 | self.logic = SudokuLogic(game_cfg) 220 | self.renderer = None 221 | 222 | def get_screenshot(self): 223 | if self.renderer is None: 224 | 
self.renderer = SudokuRenderer(self.logic) 225 | return self.renderer.get_screenshot() 226 | 227 | def input_move(self, move): 228 | return self.logic.input_move(move) 229 | 230 | def get_game_status(self): 231 | return self.logic.get_game_status() 232 | 233 | def get_random_state(self): 234 | return self.logic.get_random_state() 235 | 236 | def get_rule_state(self): 237 | return self.logic.get_rule_state() 238 | 239 | def ai_move(self): 240 | return None 241 | 242 | def calculate_score(self): 243 | return self.logic.calculate_score() 244 | 245 | def parse_e2e(self, lmm_output): 246 | return self.logic.parse_e2e(lmm_output) 247 | -------------------------------------------------------------------------------- /playground/games/reversi/reversi.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import random 3 | import re 4 | 5 | from PyQt5.QtGui import QColor, QPainter, QPixmap 6 | from PyQt5.QtWidgets import QMainWindow 7 | 8 | from playground.games import BaseGame, BaseGameLogic 9 | from playground.games.reversi.AI import ReversiAI 10 | from playground.games.reversi.reversi_ui import Ui_MainWindow 11 | from playground.registry import GAME_REGISTRY 12 | from playground.state_code import GameStatus 13 | 14 | 15 | class ReversiLogic(BaseGameLogic): 16 | """Pure logic for Reversi game.""" 17 | 18 | def __init__(self, game_cfg): 19 | self.game_cfg = game_cfg 20 | self.board = [[0 for _ in range(8)] for _ in range(8)] 21 | self.board[3][3], self.board[3][4], self.board[4][3], self.board[4][ 22 | 4] = 2, 1, 1, 2 23 | self.current_player = 1 24 | self.last_move = None 25 | self.player_steps = 0 26 | self.status = GameStatus.IN_PROGRESS 27 | self.no_move_count = 0 28 | self.ai = ReversiAI() 29 | 30 | def make_move(self, x, y): 31 | """Make a move on the board and flip pieces.""" 32 | if not self.valid_move(x, y): 33 | return False 34 | self.ai.make_move(self.board, x, y, self.current_player) 35 | self.last_move = (x, y) 
36 | if self.current_player == 1: 37 | self.player_steps += 1 38 | self._check_game_over() 39 | return True 40 | 41 | def valid_move(self, x, y): 42 | """Check if a move is valid for the current player.""" 43 | return self.ai.valid_move(self.board, x, y, self.current_player) 44 | 45 | def switch_player(self): 46 | """Switch the current player.""" 47 | self.current_player = self.ai.opponent(self.current_player) 48 | 49 | def _check_game_over(self): 50 | """Check if the game is over (no valid moves for both players).""" 51 | if not any(self.valid_move(x, y) for x in range(8) for y in range(8)): 52 | self.no_move_count += 1 53 | self.switch_player() 54 | else: 55 | self.no_move_count = 0 56 | 57 | if self.no_move_count >= 2: 58 | white_score, black_score = self.ai.score(self.board) 59 | if white_score > black_score: 60 | self.status = GameStatus.LOSE 61 | elif black_score > white_score: 62 | self.status = GameStatus.WIN 63 | else: 64 | self.status = GameStatus.TIE 65 | 66 | def input_move(self, move): 67 | """Process player move from input string (e.g., 'D4').""" 68 | if self.status != GameStatus.IN_PROGRESS: 69 | return self.status 70 | col_map = { 71 | '1': 0, 72 | '2': 1, 73 | '3': 2, 74 | '4': 3, 75 | '5': 4, 76 | '6': 5, 77 | '7': 6, 78 | '8': 7 79 | } 80 | row_map = { 81 | 'A': 0, 82 | 'B': 1, 83 | 'C': 2, 84 | 'D': 3, 85 | 'E': 4, 86 | 'F': 5, 87 | 'G': 6, 88 | 'H': 7 89 | } 90 | move = re.sub(r'\s+', '', move).upper() 91 | match = re.match(r'([A-H])([1-8])|([1-8])([A-H])', move) 92 | if not match: 93 | return GameStatus.INVALID_MOVE 94 | if match.group(1) and match.group(2): 95 | row, col = row_map[match.group(1)], col_map[match.group(2)] 96 | else: 97 | col, row = col_map[match.group(3)], row_map[match.group(4)] 98 | if not (0 <= row < 8 and 0 <= col < 8) or not self.valid_move( 99 | col, row): 100 | return GameStatus.INVALID_MOVE 101 | self.make_move(col, row) 102 | self.switch_player() 103 | return self.status 104 | 105 | def get_game_status(self): 106 | 
return self.status 107 | 108 | def reset_board(self): 109 | """Reset the game board.""" 110 | self.board = [[0 for _ in range(8)] for _ in range(8)] 111 | self.board[3][3], self.board[3][4], self.board[4][3], self.board[4][ 112 | 4] = 2, 1, 1, 2 113 | self.current_player = 1 114 | self.last_move = None 115 | self.player_steps = 0 116 | self.status = GameStatus.IN_PROGRESS 117 | self.no_move_count = 0 118 | 119 | def get_random_state(self): 120 | """Generate a random game state.""" 121 | self.reset_board() 122 | total_cells = 8 * 8 123 | stone_ranges = { 124 | 'sparse': (10, 25), 125 | 'mild': (26, 40), 126 | 'dense': (41, 56) 127 | } 128 | range_choice = random.choice(list(stone_ranges.keys())) 129 | min_stones, max_stones = stone_ranges[range_choice] 130 | min_stones = max(0, min(min_stones, total_cells)) 131 | max_stones = max(0, min(max_stones, total_cells)) 132 | 133 | total_stones = random.randint(min_stones, max_stones) 134 | black_stones = random.randint(total_stones * 30 // 100, 135 | total_stones * 70 // 100) 136 | white_stones = total_stones - black_stones 137 | empty_cells = total_cells - total_stones 138 | 139 | pieces = [1] * black_stones + [2] * white_stones + [0] * empty_cells 140 | random.shuffle(pieces) 141 | 142 | for i in range(8): 143 | for j in range(8): 144 | self.board[i][j] = pieces[i * 8 + j] 145 | 146 | self._check_game_over() 147 | if self.status != GameStatus.IN_PROGRESS: 148 | return self.get_random_state() 149 | return self.board 150 | 151 | def get_rule_state(self): 152 | """Generate a rule state with valid movements.""" 153 | valid_state_found = False 154 | valid_moves = [] 155 | 156 | row_labels = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H'] 157 | col_labels = ['1', '2', '3', '4', '5', '6', '7', '8'] 158 | 159 | while not valid_state_found: 160 | board_state = self.get_random_state() 161 | valid_moves = [] 162 | 163 | for x in range(8): 164 | for y in range(8): 165 | if self.ai.valid_move(board_state, x, y, 1): 166 | move_str = 
f'{row_labels[y]}{col_labels[x]}' 167 | valid_moves.append(move_str) 168 | 169 | if valid_moves: 170 | valid_state_found = True 171 | return board_state, valid_moves 172 | 173 | def calculate_score(self): 174 | """Calculate score based on player's steps and remaining pieces.""" 175 | step_score = self.player_steps * 10 176 | black_count = sum(1 for i in range(8) for j in range(8) 177 | if self.board[i][j] == 1) 178 | piece_bonus = black_count * 20 179 | outcome_bonus = 0 180 | if self.status == GameStatus.WIN: 181 | outcome_bonus = 1000 182 | elif self.status == GameStatus.TIE: 183 | outcome_bonus = 500 184 | total_score = step_score + piece_bonus + outcome_bonus 185 | return total_score 186 | 187 | def parse_e2e(self, lmm_output): 188 | """Parse e2e output to a move.""" 189 | match = re.search(r'Movement:\s*([A-Ha-h][1-8]|[1-8][A-Ha-h])', 190 | lmm_output, re.IGNORECASE) 191 | if match: 192 | move = match.group(1).upper() 193 | if move[0].isdigit(): 194 | move = move[1] + move[0] 195 | return move 196 | return GameStatus.INVALID_MOVE 197 | 198 | 199 | class ReversiRenderer(QMainWindow): 200 | 201 | def __init__(self, logic): 202 | super().__init__() 203 | self.ui = Ui_MainWindow() 204 | self.ui.setupUi(self) 205 | self.logic = logic 206 | 207 | def get_screenshot(self): 208 | board_width = 500 209 | board_height = 600 210 | screenshot = QPixmap(board_width, board_height) 211 | screenshot.fill(QColor(255, 255, 255)) 212 | painter = QPainter(screenshot) 213 | self.render(painter) 214 | self.ui.draw_board(painter, self.logic.board) 215 | self.ui.draw_labels(painter) 216 | painter.end() 217 | return screenshot 218 | 219 | 220 | @GAME_REGISTRY.register('reversi') 221 | class Reversi(BaseGame): 222 | AI_component = True 223 | 224 | def __init__(self, game_cfg): 225 | super().__init__(game_cfg) 226 | self.logic = ReversiLogic(game_cfg) 227 | self.renderer = None 228 | 229 | def get_screenshot(self): 230 | if self.renderer is None: 231 | self.renderer = 
ReversiRenderer(self.logic) 232 | return self.renderer.get_screenshot() 233 | 234 | def input_move(self, move): 235 | return self.logic.input_move(move) 236 | 237 | def get_game_status(self): 238 | return self.logic.get_game_status() 239 | 240 | def get_random_state(self): 241 | return self.logic.get_random_state() 242 | 243 | def get_rule_state(self): 244 | return self.logic.get_rule_state() 245 | 246 | def ai_move(self): 247 | if not self.AI_component or self.logic.status != GameStatus.IN_PROGRESS: # noqa 248 | return None 249 | 250 | best_move = self.logic.ai.best_move(copy.deepcopy(self.logic.board), 3, 251 | self.logic.current_player) 252 | if best_move: 253 | x, y = best_move 254 | if self.logic.make_move(x, y): 255 | self.logic.switch_player() 256 | row_labels = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H'] 257 | col_labels = ['1', '2', '3', '4', '5', '6', '7', '8'] 258 | return f'{row_labels[y]}{col_labels[x]}' 259 | return None 260 | 261 | def calculate_score(self): 262 | return self.logic.calculate_score() 263 | 264 | def parse_e2e(self, lmm_output): 265 | return self.logic.parse_e2e(lmm_output) 266 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 |

👾Are Large Vision Language Models Good Game Players?

4 | 5 | 10 | 🎮The University of Adelaide 🕹️Zhejiang University 11 | 12 |
13 | 14 |
15 | 16 | License: MIT 17 |
18 | 19 |
20 | 21 | ![](assets/LVLM-Playground.jpg) 22 | 23 | LVLM-Playground is a benchmark to evaluate Large Vision Language Models (LVLMs) on game-playing tasks, assessing their perception, reasoning, and decision-making across six classic games. This repository provides tools to run experiments, analyze performance, and visualize results. For further details, please refer to our paper [here](https://openreview.net/pdf?id=c4OGMNyzPT). 24 | 25 | ## News 26 | 27 | \[**2025.Mar**\] LVLM-Playground is released! 🚀
28 | \[**2025.Feb**\] LVLM-Playground has been accepted to ICLR 2025! 🎉 Check the paper [here](https://openreview.net/pdf?id=c4OGMNyzPT). 29 | 30 | ## Installation 31 | 32 | 1. **Clone the Repository**: 33 | 34 | ```bash 35 | git clone https://github.com/xinke-wang/LVLM-Playground.git 36 | cd LVLM-Playground 37 | ``` 38 | 39 | 2. **Set Up a Conda Environment** 40 | 41 | ```bash 42 | conda create -n playground python=3.11 -y 43 | conda activate playground 44 | pip install -r requirements.txt 45 | ``` 46 | 47 | 3. **Install Stockfish for Chess** 48 | 49 | To run experiments on Chess, you need to install Stockfish. 50 | 51 | - With `sudo` privileges, install [Stockfish](https://stockfishchess.org/) via your package manager: 52 | 53 | ```bash 54 | sudo apt-get install stockfish 55 | ``` 56 | 57 | - Alternatively, you can download the latest Stockfish binary from [stockfishchess.org](https://stockfishchess.org/download/). 58 | 59 | - Extract the binary and place the `stockfish` executable in your system PATH or the project directory. Note that the chess game (`playground/games/chess/chess.py`) launches the engine from the hard-coded path `/usr/games/stockfish`, so update that path if your executable is installed elsewhere. 60 | 61 | ## Data Preparation 62 | 63 | 1. **Download Pre-generated Benchmark Data**: 64 | 65 | To facilitate reproducibility of the experiments in our paper, we provide pre-generated benchmark data. You can download the data by running the following command: 66 | 67 | ```bash 68 | wget https://universityofadelaide.box.com/shared/static/9xx4brpiipqmmyomau2v522frtijx930.zip -O benchmark.zip 69 | unzip benchmark.zip -d . 70 | ``` 71 | 72 | After unzipping, you should have the following directory structure: 73 | 74 | ``` 75 | LVLM-Playground 76 | ├── benchmark 77 | │ ├── perceive 78 | │ │ ├── chess 79 | │ │ │ ├── 0000000.jpg 80 | │ │ │ ├── 0000001.jpg 81 | │ │ │ ├── ... 82 | │ │ │ └── annotation.json 83 | │ │ ├── gomoku 84 | │ │ ├── minesweeper 85 | │ │ ├── reversi 86 | │ │ ├── sudoku 87 | │ │ └── tictactoe 88 | │ ├── qa 89 | │ └── rule 90 | ``` 91 | 92 | 2.
**Generating a Custom Benchmark (Optional)**: 93 | 94 | Alternatively, you can generate a new benchmark dataset by running the following command: 95 | 96 | ```bash 97 | python generate_benchmark.py 98 | ``` 99 | 100 | You can modify the `configs/base.py` file to customize the benchmark generation process. 101 | 102 | ```python 103 | # configs/base.py 104 | benchmark_setting = dict( 105 | games=['tictactoe', 'gomoku', 'minesweeper', 'reversi', 'sudoku', 'chess'], 106 | sample_size=2000, 107 | e2e_round=100, 108 | offline_task=['perceive', 'qa', 'rule'], 109 | benchmark_path='benchmark' 110 | ) 111 | ``` 112 | 113 | - Adjust `games` to include or exclude specific games. 114 | - Modify `sample_size` to control the number of samples per game. 115 | - Change `benchmark_path` to specify the output directory. 116 | 117 | ## Running Experiments 118 | 119 | Once the data is ready, run experiments using: 120 | 121 | ```bash 122 | python run.py --exp-recipe configs/recipe/base.py --agent-cfg configs/agents/internvl/internvl2-1b.py 123 | ``` 124 | 125 | `--exp-recipe` specifies the experiment settings, and `--agent-cfg` specifies the agent configuration. If you are using a commercial model (e.g., from OpenAI, Google, or Anthropic) as the agent, ensure you have the necessary API keys set as environment variables (e.g., `OPENAI_API_KEY`, `GOOGLE_API_KEY`). The framework can **automatically resume** an experiment after an unexpected termination, as long as you set the same experiment name in the experiment recipe config (`configs/recipe/base.py`). 126 | 127 | We provide several pre-defined agent configurations in the `configs/agents` directory, including three widely used commercial APIs [Gemini](configs/agents/google), [Claude](configs/agents/anhthropic), and [ChatGPT](configs/agents/openai), as well as open-source models supported by [LMDeploy](https://github.com/InternLM/lmdeploy). You can find the pre-set configurations in `configs/agents`, and modify them to customize the LVLM settings.
128 | 129 | You can customize the experiment settings by modifying the configuration file `configs/recipe/base.py`. 130 | 131 | ```python 132 | # configs/recipe/base.py 133 | name = 'standard' 134 | save_path = 'experiments' 135 | tasks = ['perceive', 'qa', 'rule', 'e2e'] 136 | games = ['tictactoe', 'reversi', 'gomoku', 'minesweeper', 'sudoku', 'chess'] 137 | ``` 138 | 139 | - Adjust `tasks` to include or exclude specific tasks. 140 | - Modify `games` to specify the games to evaluate. 141 | - Change `save_path` to specify the output directory. 142 | - Set `name` to identify the experiment. 143 | 144 | ## Evaluating Results 145 | 146 | Once you have run the experiments, the results will be saved in the `experiments` directory with a name specified in the experiment recipe. You can evaluate the results using: 147 | 148 | ```bash 149 | python evaluate.py --exp-path experiments/standard/gpt4o.json 150 | ``` 151 | 152 | Evaluation results will be saved by default in the `evaluation_results/` directory. 153 | 154 | ## Visualizing Results 155 | 156 | To visualize the evaluation results, generate a radar chart comparing LVLMs across tasks: 157 | 158 | ```bash 159 | python plot_radar.py 160 | ``` 161 | 162 | This will automatically create a radar chart (`radar_chart.pdf`) in the current directory, illustrating performance differences. 163 | 164 | To compare with the results in our paper, you can download the evaluation files from [here](https://universityofadelaide.box.com/s/tn398x5zyj5eq0e05atfcja40w1xdf83) and place them in the `evaluation_results` directory. 
This includes the evaluation results for [GPT-4o-240806](https://openai.com/index/gpt-4o-system-card/), [Gemini-1.5pro](https://blog.google/technology/ai/google-gemini-next-generation-model-february-2024/), [Claude-3.5-sonnet](https://www.anthropic.com/news/claude-3-5-sonnet), [Qwen2-vl-7b](https://huggingface.co/Qwen/Qwen2-VL-7B), [DeepSeek-vl-7b](https://huggingface.co/deepseek-ai/deepseek-vl-7b-base), [Phi3-vl](https://huggingface.co/microsoft/Phi-3-vision-128k-instruct), [LLaVA-1.6-7b](https://huggingface.co/liuhaotian/llava-v1.6-vicuna-7b), and [InternVL2-8b](https://huggingface.co/OpenGVLab/InternVL2-8B). 165 | 166 | ![](assets/radar_chart.jpg) 167 | 168 | ## Evaluating Customized Models 169 | 170 | To evaluate a customized LVLM, follow these steps: 171 | 172 | 1. **Implement Your Model**: 173 | 174 | Implement your model by inheriting the `BaseAgent` class in `agents/base_agent.py`, and registering it with `AGENT_REGISTRY`. You may use the following template: 175 | 176 | ```python 177 | from playground.agents import BaseAgent 178 | from playground.registry import AGENT_REGISTRY 179 | 180 | @AGENT_REGISTRY.register('custom_single') 181 | class CustomAgentSingleStep(BaseAgent): 182 | 183 | def __init__(self, agent_cfg): 184 | super().__init__(agent_cfg) 185 | # Initialize your model, API, or configuration here 186 | pass 187 | 188 | def get_decision(self, screenshot_path: str, prompt: str): 189 | # Implement logic to process the screenshot and prompt, return a decision 190 | pass 191 | ``` 192 | 193 | 2. **Configure Your Model in** `configs/agents`: 194 | 195 | Create or modify a configuration file (e.g., configs/agents/custom_agent.py) to define your model’s settings. 
Example: 196 | 197 | ```python 198 | lmm_agent = dict( 199 | agent='custom_single', 200 | model='your_model_name', 201 | max_tokens=512, 202 | image_size=512, 203 | backend_config=None, 204 | general_config=None, 205 | name='custom_agent' 206 | ) 207 | ``` 208 | 209 | Ensure the `agent` field matches the registered name in the `AGENT_REGISTRY`. After defining and configuring your model, follow the standard steps to run experiments (`python run.py`), evaluate results (`python evaluate.py`), and visualize performance (`python plot_radar.py`). 210 | 211 | ## Acknowledgements 212 | 213 | We acknowledge the authors of the following repositories for providing the game UIs and search-based AI implementations: 214 | 215 | - [Python-Chess](https://github.com/niklasf/python-chess) 216 | - [Moonsweeper](https://www.pythonguis.com/examples/python-minesweeper/) 217 | - [Sudoku-in-Python](https://github.com/humzah286/Sudoku-in-python) 218 | - [Gobang](https://github.com/sgsx11/Gobang) 219 | - [GUI-tic-tac-toe-AI](https://github.com/Erfan-ram/GUI-tic-tac-toe-Ai) 220 | - [reversi-minimax-algorithm](https://github.com/abkhan04/reversi-minimax-algorithm) 221 | 222 | 223 | ## Contact 224 | 225 | If you have any questions or suggestions, please feel free to open an issue or contact us via email Xinyu Wang [xinyu.wang02@adelaide.edu.au](mailto:xinyu.wang02@adelaide.edu.au). 
226 | 227 | ## Citation 228 | 229 | If you find this repository useful for your research, please consider citing our paper: 230 | 231 | ```bibtex 232 | @inproceedings{wang2025large, 233 | title={Are Large Vision Language Models Good Game Players?}, 234 | author={Wang, Xinyu and Zhuang, Bohan and Wu, Qi}, 235 | booktitle={International Conference on Learning Representations}, 236 | year={2025} 237 | } 238 | ``` 239 | -------------------------------------------------------------------------------- /playground/games/minesweeper/minesweeper.py: -------------------------------------------------------------------------------- 1 | import random 2 | import re 3 | import time 4 | 5 | from PyQt5.QtGui import QIcon, QPainter, QPixmap 6 | from PyQt5.QtWidgets import QMainWindow 7 | 8 | from playground.games import BaseGame, BaseGameLogic 9 | from playground.games.minesweeper.game_cfg import LEVELS, STATUS_ICONS 10 | from playground.games.minesweeper.minesweeper_ui import MinesweeperUI 11 | from playground.registry import GAME_REGISTRY 12 | from playground.state_code import GameStatus 13 | 14 | 15 | class MinesweeperLogic(BaseGameLogic): 16 | """Pure logic for Minesweeper game.""" 17 | 18 | def __init__(self, game_cfg): 19 | self.game_cfg = game_cfg 20 | self.level = game_cfg.level 21 | self.b_size, self.n_mines = LEVELS[self.level] 22 | self.board = None 23 | self.status = GameStatus.IN_PROGRESS 24 | self.moves_history = [] 25 | self.timer_start = 0 26 | self.reset_board() 27 | 28 | def reset_board(self): 29 | """Reset the game board to initial state without revealing.""" 30 | self.board = [[-1 for _ in range(self.b_size)] 31 | for _ in range(self.b_size)] 32 | self.status = GameStatus.IN_PROGRESS 33 | self.moves_history = [] 34 | self.timer_start = int(time.time()) 35 | 36 | mine_positions = set() 37 | while len(mine_positions) < self.n_mines: 38 | x, y = random.randint(0, self.b_size - 1), random.randint( 39 | 0, self.b_size - 1) 40 | if (x, y) not in mine_positions: 41 | 
mine_positions.add((x, y)) 42 | self.board[y][x] = 9 43 | 44 | for y in range(self.b_size): 45 | for x in range(self.b_size): 46 | if self.board[y][x] != 9: 47 | self.board[y][x] = -1 48 | 49 | def _get_adjacency_n(self, x, y, mine_positions): 50 | """Calculate number of adjacent mines.""" 51 | count = 0 52 | for xi in range(max(0, x - 1), min(self.b_size, x + 2)): 53 | for yi in range(max(0, y - 1), min(self.b_size, y + 2)): 54 | if (xi, yi) in mine_positions: 55 | count += 1 56 | return count 57 | 58 | def _expand_reveal(self, x, y): 59 | """Reveal surrounding cells if no adjacent mines.""" 60 | if not (0 <= x < self.b_size 61 | and 0 <= y < self.b_size) or self.board[y][x] >= 0: 62 | return 63 | mine_positions = {(xi, yi) 64 | for yi in range(self.b_size) 65 | for xi in range(self.b_size) 66 | if self.board[yi][xi] in [9, 10]} 67 | self.board[y][x] = self._get_adjacency_n(x, y, mine_positions) 68 | if self.board[y][x] == 0: 69 | for xi in range(max(0, x - 1), min(self.b_size, x + 2)): 70 | for yi in range(max(0, y - 1), min(self.b_size, y + 2)): 71 | self._expand_reveal(xi, yi) 72 | 73 | def input_move(self, move): 74 | """Process move in format 'A1'.""" 75 | if self.status != GameStatus.IN_PROGRESS: 76 | return self.status 77 | pattern = re.compile(r'([a-zA-Z])([0-9]+)|([0-9]+)([a-zA-Z])') 78 | match = pattern.match(move) 79 | if not match: 80 | return GameStatus.INVALID_MOVE 81 | 82 | row, col = (match.group(1), 83 | match.group(2)) if match.group(1) else (match.group(4), 84 | match.group(3)) 85 | row = row.lower() 86 | if row.isalpha(): 87 | y = ord(row) - ord('a') 88 | x = int(col) - 1 89 | else: 90 | y = int(row) - 1 91 | x = ord(col.lower()) - ord('a') 92 | 93 | if not (0 <= x < self.b_size and 0 <= y < self.b_size): 94 | return GameStatus.INVALID_MOVE 95 | 96 | if 0 <= self.board[y][x] <= 8: 97 | return GameStatus.INVALID_MOVE 98 | 99 | self.moves_history.append(move) 100 | if self.board[y][x] == 9: 101 | self.board[y][x] = 10 102 | self.status = 
GameStatus.LOSE 103 | return self.status 104 | 105 | self._expand_reveal(x, y) 106 | self._check_win() 107 | return self.status 108 | 109 | def _check_win(self): 110 | """Check if only mines (10) remain unrevealed.""" 111 | unrevealed_count = sum(row.count(-1) for row in self.board) 112 | if unrevealed_count == self.n_mines: 113 | self.status = GameStatus.WIN 114 | 115 | def get_game_status(self): 116 | return self.status 117 | 118 | def get_random_state(self): 119 | """Generate a random game state with ~50% cells revealed.""" 120 | self.reset_board() 121 | game_state = [[-1 for _ in range(self.b_size)] 122 | for _ in range(self.b_size)] 123 | total_cells = self.b_size * self.b_size 124 | cells_to_reveal = total_cells // 2 + 1 125 | positions = [(x, y) for x in range(self.b_size) 126 | for y in range(self.b_size)] 127 | revealed_positions = random.sample(positions, cells_to_reveal) 128 | 129 | mine_positions = {(x, y) 130 | for y in range(self.b_size) 131 | for x in range(self.b_size) if self.board[y][x] == 9} 132 | for y in range(self.b_size): 133 | for x in range(self.b_size): 134 | if (x, y) in revealed_positions: 135 | if self.board[y][x] == 9: 136 | game_state[y][x] = 9 137 | else: 138 | game_state[y][x] = self._get_adjacency_n( 139 | x, y, mine_positions) 140 | else: 141 | game_state[y][x] = -1 142 | return game_state 143 | 144 | def get_rule_state(self): 145 | """Generate a rule state with valid movements.""" 146 | game_state = self.get_random_state() 147 | valid_movements = [] 148 | for y in range(self.b_size): 149 | for x in range(self.b_size): 150 | if game_state[y][x] == -1: 151 | pos_str = f"{chr(y + ord('A'))}{x + 1}" 152 | valid_movements.append(pos_str) 153 | return game_state, valid_movements 154 | 155 | def calculate_score(self): 156 | """Calculate score based on steps, revealed cells, and game outcome.""" 157 | step_score = len(self.moves_history) * 10 158 | reveal_bonus = sum(1 for y in range(self.b_size) 159 | for x in range(self.b_size) 160 | 
if 0 <= self.board[y][x] <= 8) * 2 161 | outcome_bonus = 1000 if self.status == GameStatus.WIN else 0 162 | total_score = step_score + reveal_bonus + outcome_bonus 163 | return total_score 164 | 165 | def parse_e2e(self, lmm_output): 166 | """Parse e2e output to a move.""" 167 | match = re.search(r'Movement:\s*([A-Ha-h][1-8]|[1-8][A-Ha-h])', 168 | lmm_output, re.IGNORECASE) 169 | if match: 170 | move = match.group(1).upper() 171 | if move[0].isdigit(): 172 | move = move[1] + move[0] 173 | return move 174 | return GameStatus.INVALID_MOVE 175 | 176 | 177 | class MinesweeperRenderer(QMainWindow): 178 | """Renderer for Minesweeper UI.""" 179 | 180 | def __init__(self, logic): 181 | super().__init__() 182 | self.logic = logic 183 | self.ui = MinesweeperUI(self, self.logic.b_size) 184 | self.setCentralWidget(self.ui.centralwidget) 185 | self._update_ui_from_logic() 186 | self.adjust_window_size() 187 | self.show() 188 | 189 | def adjust_window_size(self): 190 | window_width = 50 + self.logic.b_size * 20 191 | window_height = 100 + self.logic.b_size * 20 192 | self.setFixedSize(window_width, window_height) 193 | 194 | def _update_ui_from_logic(self): 195 | """Sync UI with logic state.""" 196 | self.ui.minesLabel.setText(f'{self.logic.n_mines:03d}') 197 | elapsed = int( 198 | time.time() 199 | ) - self.logic.timer_start if self.logic.status == GameStatus.IN_PROGRESS else 0 # noqa 200 | self.ui.clockLabel.setText(f'{elapsed:03d}') 201 | self.ui.statusButton.setIcon(QIcon(STATUS_ICONS[self.logic.status])) 202 | for y in range(self.logic.b_size): 203 | for x in range(self.logic.b_size): 204 | widget = self.ui.gameGrid.itemAtPosition(y + 1, x + 1).widget() 205 | widget.is_mine = (self.logic.board[y][x] in [9, 10]) 206 | widget.is_revealed = (self.logic.board[y][x] >= 0 207 | and self.logic.board[y][x] != 9 208 | ) or self.logic.board[y][x] == 10 209 | widget.adjacent_n = self.logic.board[y][ 210 | x] if widget.is_revealed and not widget.is_mine else 0 211 | widget.update() 212 
| 213 | def get_screenshot(self): 214 | """Generate screenshot of the entire window.""" 215 | self._update_ui_from_logic() 216 | window_width = self.width() 217 | window_height = self.height() 218 | screenshot = QPixmap(window_width, window_height) 219 | painter = QPainter(screenshot) 220 | painter.setRenderHint(QPainter.Antialiasing) 221 | self.render(painter) 222 | painter.end() 223 | return screenshot 224 | 225 | 226 | @GAME_REGISTRY.register('minesweeper') 227 | class MineSweeper(BaseGame): 228 | AI_component = False 229 | 230 | def __init__(self, game_cfg): 231 | super().__init__(game_cfg) 232 | self.logic = MinesweeperLogic(game_cfg) 233 | self.renderer = None 234 | 235 | def get_screenshot(self): 236 | if self.renderer is None: 237 | self.renderer = MinesweeperRenderer(self.logic) 238 | return self.renderer.get_screenshot() 239 | 240 | def input_move(self, move): 241 | return self.logic.input_move(move) 242 | 243 | def get_game_status(self): 244 | return self.logic.get_game_status() 245 | 246 | def get_random_state(self): 247 | return self.logic.get_random_state() 248 | 249 | def get_rule_state(self): 250 | return self.logic.get_rule_state() 251 | 252 | def ai_move(self): 253 | return None 254 | 255 | def calculate_score(self): 256 | return self.logic.calculate_score() 257 | 258 | def parse_e2e(self, lmm_output): 259 | return self.logic.parse_e2e(lmm_output) 260 | --------------------------------------------------------------------------------