├── .gitignore ├── README.md ├── README2.md ├── Result.txt ├── __init__.py ├── config.py ├── dqn.py ├── ensemble.py ├── envi.py ├── game.py ├── models └── .gitkeep ├── net.py ├── outs ├── .gitkeep └── plot.py ├── precompiled └── .gitkeep ├── rule_based ├── __init__.py ├── rule_play.py └── utils │ ├── __init__.py │ ├── card.py │ ├── decomposer.py │ ├── evaluator.py │ ├── rule_based_model.py │ └── utils.py ├── server ├── CFR.py ├── __init__.py ├── app.py ├── client.py ├── config.py ├── core.py ├── init.py ├── mcts │ ├── __init__.py │ ├── backup.py │ ├── card.py │ ├── default_policy.py │ ├── evaluator.py │ ├── get_bestchild.py │ ├── get_moves.py │ ├── interface.py │ ├── tree.py │ └── tree_policy.py └── rule_utils │ ├── __init__.py │ ├── card.py │ ├── decomposer.py │ ├── evaluator.py │ ├── rule_based_model.py │ └── utils.py ├── test.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Prerequisites 2 | *.d 3 | 4 | # Compiled Object files 5 | *.slo 6 | *.lo 7 | *.o 8 | *.obj 9 | 10 | # Precompiled Headers 11 | *.gch 12 | *.pch 13 | 14 | # Compiled Dynamic libraries 15 | *.so 16 | *.dylib 17 | *.dll 18 | 19 | # Fortran module files 20 | *.mod 21 | *.smod 22 | 23 | # Compiled Static libraries 24 | *.lai 25 | *.la 26 | *.a 27 | *.lib 28 | 29 | # Executables 30 | *.exe 31 | *.out 32 | *.app 33 | # Byte-compiled / optimized / DLL files 34 | __pycache__/ 35 | *.py[cod] 36 | *$py.class 37 | 38 | # C extensions 39 | *.so 40 | 41 | # Distribution / packaging 42 | .Python 43 | build/ 44 | develop-eggs/ 45 | dist/ 46 | downloads/ 47 | eggs/ 48 | .eggs/ 49 | lib/ 50 | lib64/ 51 | parts/ 52 | sdist/ 53 | var/ 54 | wheels/ 55 | pip-wheel-metadata/ 56 | share/python-wheels/ 57 | *.egg-info/ 58 | .installed.cfg 59 | *.egg 60 | MANIFEST 61 | 62 | # PyInstaller 63 | # Usually these files are written by a python script from a template 64 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
65 | *.manifest 66 | *.spec 67 | 68 | # Installer logs 69 | pip-log.txt 70 | pip-delete-this-directory.txt 71 | 72 | # Unit test / coverage reports 73 | htmlcov/ 74 | .tox/ 75 | .nox/ 76 | .coverage 77 | .coverage.* 78 | .cache 79 | nosetests.xml 80 | coverage.xml 81 | *.cover 82 | .hypothesis/ 83 | .pytest_cache/ 84 | 85 | # Translations 86 | *.mo 87 | *.pot 88 | 89 | # Django stuff: 90 | *.log 91 | local_settings.py 92 | db.sqlite3 93 | db.sqlite3-journal 94 | 95 | # Flask stuff: 96 | instance/ 97 | .webassets-cache 98 | 99 | # Scrapy stuff: 100 | .scrapy 101 | 102 | # Sphinx documentation 103 | docs/_build/ 104 | 105 | # PyBuilder 106 | target/ 107 | 108 | # Jupyter Notebook 109 | .ipynb_checkpoints 110 | 111 | # IPython 112 | profile_default/ 113 | ipython_config.py 114 | 115 | # pyenv 116 | .python-version 117 | 118 | # pipenv 119 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 120 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 121 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 122 | # install all needed dependencies. 
123 | #Pipfile.lock 124 | 125 | # celery beat schedule file 126 | celerybeat-schedule 127 | 128 | # SageMath parsed files 129 | *.sage.py 130 | 131 | # Environments 132 | .env 133 | .venv 134 | env/ 135 | venv/ 136 | ENV/ 137 | env.bak/ 138 | venv.bak/ 139 | 140 | # Spyder project settings 141 | .spyderproject 142 | .spyproject 143 | 144 | # Rope project settings 145 | .ropeproject 146 | 147 | # mkdocs documentation 148 | /site 149 | 150 | # mypy 151 | .mypy_cache/ 152 | .dmypy.json 153 | dmypy.json 154 | 155 | # Pyre type checker 156 | .pyre/ 157 | 158 | 159 | precompiled/* 160 | !precompiled/.gitkeep 161 | .idea/* 162 | 163 | outs/* 164 | !outs/.gitkeep 165 | models/* 166 | !models/.gitkeep 167 | outs/images/* 168 | !outs/images/.gitkeep 169 | __pycache__/* 170 | backup/* 171 | !outs/plot.py 172 | tmp.py 173 | *.db -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 模型测试 2 | 3 | DHCP是rule based model,用来训练网络的模型,且训练的都是地主。 4 | 5 | 每个模型进行了1000轮测试,胜率如下所示。 6 | 7 | 8 | | 地主开始训练时刻 | 训练第n次结果 | 对抗random胜率 | 对抗RHCP农民胜率 | 9 | | --- | --- | --- | --- | 10 | |0802_1836|78500_44|85|37.9| 11 | |0803_0959|8000|83|37.7| 12 | |0804_0912|4500_57|91.5|53.6| 13 | |0804_1045|3500_53|91.1|54.4| 14 | |0804_1423|3700_54|95.9|56.2| 15 | |0804_2022|lord_scratch4000|94.5|54.3| 16 | |0805_1019|2900_54|93.4|55.8| 17 | |0805_1049|lord_4000|—|58.1| 18 | |0806_1906|zero_lord_3000|—|52| 19 | |0806_1906|zero_lord_4000|—|50.5| 20 | |0806_1905|zero_lord_7000|—|42.8| 21 | |0806_1905|zero_lord_13000|—|28.1| 22 | |0807_1340(调整γ和状态)|lord_2900_54|—|53.1| 23 | |0807_1340(调整γ和状态)|lord_4000|—|55.7| 24 | |0808_0852|3300_53|-|57.5| 25 | |0808_0852|3500_59|-|58.7| 26 | |0808_0852|4700_60|-|55.2| 27 | 28 | 29 | 30 | | 农民开始训练时刻 | 训练第n次结果 | 对抗DHCP地主胜率 | 31 | | --- | --- | --- | 32 | |0806_1906|zero_up+down_3000(农民对抗地主)|17.0+21.3| 33 | |0806_1906|zero_up+down_4000(农民对抗地主)|18.6+19.3| 34 | 
|0806_1905|zero_up+down_7000(农民对抗地主)|15.6+18.7| 35 | |0806_1905|zero_up+down_13000(农民对抗地主)|13.1+16.7| 36 | |0807_1344(调整γ和状态)|zero_up+down_3000(农民对抗地主)|20.8+21.8(规则地主:57.4)| 37 | |0807_1344(调整γ和状态)|zero_up+down_4000(农民对抗地主)|17.2+25.8(规则地主:57.0)| 38 | |0807_1344(调整γ和状态)|zero_up+down_6000(农民对抗地主)|21.9+23.5(规则地主:54.6)| 39 | |0808_0918|5800_59|19.6+20.2| 40 | |0808_0854|4000|21.3+24.8| 41 | |0808_0854|6000|22.8+21.6| 42 | 43 | 44 | 45 | 46 | 47 | -------------------------------------------------------------------------------- /README2.md: -------------------------------------------------------------------------------- 1 | # 模型测试 2 | 3 | DHCP是rule based model,用来训练网络的模型,且训练的都是地主。 4 | 5 | 每个模型进行了1000轮测试,胜率如下所示。 6 | 7 | | 开始训练时刻 | 训练第n次结果 | 对抗random胜率 | 对抗RHCP胜率 | 8 | | --- | --- | --- | --- | 9 | |08.02 18:36|78500|85|37.9| 10 | |08.03 09:59|8000|83|37.7| 11 | |08.04 09:12|4500|91.5|53.6| 12 | |08.04 14:23|3700|95.9|56.2| 13 | |08.05 10:19|2900|93.4|55.8| 14 | |08.05 10:49|4000|—|58.1| 15 | |08.07 13:40|4000|—|55.7| 16 | |08.07 13:44|6000(训练的农民)|—|21.9+23.5(对手:规则地主胜率54.6)| 17 | 18 | 19 | 20 | 21 | 22 | -------------------------------------------------------------------------------- /Result.txt: -------------------------------------------------------------------------------- 1 | Level Count for Landlord: 2 | Counter({'huangjin_5': 86827, 'zuanshi_5': 78508, 'huangjin_4': 67529, 'huangjin_3': 50552, 'zuanshi_4': 42770, 'huangjin_2': 40143, 'huangjin_1': 31695, 'bojin_5': 25383, 'bojin_4': 21045, 'zuanshi_3': 19730, 'bojin_3': 19223, 'bojin_2': 17196, 'bojin_1': 16585, 'zuanshi_2': 9488, 'zuanshi_1': 5307, 'xingyao_5': 4793, 'xingyao_4': 374, 'xingyao_3': 223, 'xingyao_2': 80, 'xingyao_1': 6}) 3 | 4 | 5 | Level Count for Landlord_down: 6 | Counter({'huangjin_5': 36890, 'zuanshi_5': 30046, 'huangjin_4': 27827, 'huangjin_3': 20641, 'zuanshi_4': 16375, 'huangjin_2': 16164, 'huangjin_1': 13002, 'bojin_5': 10364, 'bojin_4': 8340, 'zuanshi_3': 7651, 'bojin_3': 7599, 'bojin_2': 
6997, 'bojin_1': 6702, 'zuanshi_2': 3693, 'xingyao_5': 2137, 'zuanshi_1': 2061, 'xingyao_4': 287, 'xingyao_3': 156, 'xingyao_2': 83, 'xingyao_1': 5}) 7 | 8 | 9 | Level Count for Landlord_up: 10 | Counter({'huangjin_5': 41167, 'zuanshi_5': 33893, 'huangjin_4': 31687, 'huangjin_3': 24033, 'huangjin_2': 18769, 'zuanshi_4': 18429, 'huangjin_1': 14620, 'bojin_5': 11706, 'bojin_4': 9540, 'bojin_3': 8553, 'zuanshi_3': 8416, 'bojin_2': 7642, 'bojin_1': 7350, 'zuanshi_2': 4384, 'xingyao_5': 2397, 'zuanshi_1': 2344, 'xingyao_4': 307, 'xingyao_3': 198, 'xingyao_2': 83, 'xingyao_1': 5}) 11 | 12 | 13 | Percent of Level, Landlord: 14 | {'xingyao_5': 0.5138844215717808, 'xingyao_4': 0.38636363636363635, 'xingyao_3': 0.38648180242634317, 'xingyao_2': 0.3252032520325203, 'xingyao_1': 0.375, 'zuanshi_5': 0.5511383181112975, 'zuanshi_4': 0.5513445226493413, 'zuanshi_3': 0.551163505321675, 'zuanshi_2': 0.5401651010532309, 'zuanshi_1': 0.5464373970345964, 'bojin_5': 0.534908224980507, 'bojin_4': 0.540655105973025, 'bojin_3': 0.5434063604240282, 'bojin_2': 0.5401602010365949, 'bojin_1': 0.5413389039396808, 'huangjin_5': 0.5265944542830111, 'huangjin_4': 0.5315444377100667, 'huangjin_3': 0.5308634196542961, 'huangjin_2': 0.5346981725185146, 'huangjin_1': 0.5343324847851375} 15 | 16 | 17 | Percent of Level, landlord_down 18 | {'xingyao_5': 0.22911975983703228, 'xingyao_4': 0.2964876033057851, 'xingyao_3': 0.2703639514731369, 'xingyao_2': 0.33739837398373984, 'xingyao_1': 0.3125, 'zuanshi_5': 0.21092757306226176, 'zuanshi_4': 0.21108876685487404, 'zuanshi_3': 0.21373299438500434, 'zuanshi_2': 0.2102476515798463, 'zuanshi_1': 0.21221169686985172, 'bojin_5': 0.2184055802583609, 'bojin_4': 0.21425818882466283, 'bojin_3': 0.21481272084805653, 'bojin_2': 0.2197895398146694, 'bojin_1': 0.21875510004243234, 'huangjin_5': 0.22373304868877514, 'huangjin_4': 0.2190360744000063, 'huangjin_3': 0.21675802826959023, 'huangjin_2': 0.21530182748148544, 'huangjin_1': 0.21919517170457037} 19 | 20 | 21 | 
Percent of Level, landlord_up 22 | {'xingyao_5': 0.25699581859118686, 'xingyao_4': 0.31714876033057854, 'xingyao_3': 0.3431542461005199, 'xingyao_2': 0.33739837398373984, 'xingyao_1': 0.3125, 'zuanshi_5': 0.2379341088264407, 'zuanshi_4': 0.23756671049578468, 'zuanshi_3': 0.23510350029332067, 'zuanshi_2': 0.24958724736692287, 'zuanshi_1': 0.2413509060955519, 'bojin_5': 0.24668619476113207, 'bojin_4': 0.24508670520231213, 'bojin_3': 0.24178091872791518, 'bojin_2': 0.24005025914873568, 'bojin_1': 0.23990599601788687, 'huangjin_5': 0.24967249702821379, 'huangjin_4': 0.24941948788992704, 'huangjin_3': 0.25237855207611365, 'huangjin_2': 0.25, 'huangjin_1': 0.24647234351029215} 23 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/charleschen003/doudizhu-rl/56993b04d227e4718969209ab542142d406d3241/__init__.py -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import logging 4 | from datetime import datetime 5 | from pytz import timezone 6 | 7 | # Hyper Parameters for DQN 8 | GAMMA = 0.95 # discount factor for target q 9 | EPSILON_HIGH = 0.5 # starting value of epsilon 10 | EPSILON_LOW = 0.01 # final value of epsilon 11 | REPLAY_SIZE = 20000 # experience replay buffer size 12 | BATCH_SIZE = 256 # size of minibatch 13 | DECAY = int((8000 * (2 / 3)) / 5) # epsilon decay config 1000 for 8000 14 | UPDATE_TARGET_EVERY = 20 # target-net参数更新频率 15 | 16 | CARDS = range(3, 18) 17 | STR = [str(i) for i in range(3, 11)] + ['J', 'Q', 'K', 'A', '2', '小', '大'] 18 | DICT = dict(zip(CARDS, STR)) 19 | 20 | DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 21 | WORK_DIR, _ = os.path.split(os.path.abspath(__file__)) 22 | 23 | MODEL_DIR = 
class DQNFirst:
    """DQN agent that scores (state, action) pairs with a policy/target net pair.

    ``net_cls`` is instantiated twice without arguments: once as the policy
    network that is optimised, once as the target network that only receives
    periodic copies of the policy weights (see ``update_target``).
    """

    def __init__(self, net_cls):
        super().__init__()
        self.epsilon = conf.EPSILON_HIGH
        self.replay_buffer = deque(maxlen=conf.REPLAY_SIZE)

        self.policy_net = net_cls().to(conf.DEVICE)
        self.target_net = net_cls().to(conf.DEVICE)
        self.target_net.load_state_dict(self.policy_net.state_dict())
        self.optimizer = optim.Adam(self.policy_net.parameters(), 1e-4)

    def perceive(self, state, action, reward, next_state, next_action, done):
        """Store one transition and run a single optimisation step.

        :param state: state tensor the action was chosen in
        :param action: action tensor that was played
        :param reward: scalar reward observed after the action
        :param next_state: state tensor observed afterwards
        :param next_action: action tensor used to bootstrap the target
        :param done: truthy when the episode ended with this transition
        :return: the batch loss as a float, or ``None`` while the replay
            buffer holds fewer than ``conf.BATCH_SIZE`` transitions
        """
        self.replay_buffer.append(
            (state, action, reward, next_state, next_action, done))
        if len(self.replay_buffer) < conf.BATCH_SIZE:
            return None

        batch = random.sample(self.replay_buffer, conf.BATCH_SIZE)
        s0, a0, r1, s1, a1, finished = zip(*batch)
        s0 = torch.stack(s0)
        a0 = torch.stack(a0)
        s1 = torch.stack(s1)
        a1 = torch.stack(a1)
        r1 = torch.tensor(r1, dtype=torch.float)\
            .view(conf.BATCH_SIZE, -1).to(conf.DEVICE)
        finished = torch.tensor(finished, dtype=torch.float)\
            .view(conf.BATCH_SIZE, -1).to(conf.DEVICE)

        # The bootstrap target is a constant for the optimiser, so it is
        # computed without recording an autograd graph.
        with torch.no_grad():
            next_q = self.target_net(s1, a1)
            y_true = r1 + (1 - finished) * conf.GAMMA * next_q
        y_pred = self.policy_net(s0, a0)

        # MSELoss takes (input, target); the value is symmetric, but this
        # keeps the prediction in the ``input`` slot.
        loss = nn.MSELoss()(y_pred, y_true)
        res = loss.item()
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return res

    def e_greedy_action(self, face, actions):
        """Pick an action epsilon-greedily.

        :param face: current state tensor
        :param actions: candidate actions stacked along the first dimension
        :return: one entry of ``actions``
        """
        if random.random() <= self.epsilon:
            # Exploring: the Q-values are not needed, so skip the forward pass.
            idx = np.random.randint(0, actions.shape[0])
            return actions[idx]
        return self.greedy_action(face, actions)

    def greedy_action(self, face, actions):
        """Return the candidate action with the highest predicted Q-value.

        :param face: current state tensor
        :param actions: candidate actions stacked along the first dimension
        :return: one entry of ``actions``
        """
        with torch.no_grad():
            q_value = self.policy_net(face, actions)
        idx = torch.argmax(q_value).item()
        return actions[idx]

    def update_epsilon(self, episode):
        """Decay epsilon exponentially from EPSILON_HIGH towards EPSILON_LOW."""
        span = conf.EPSILON_HIGH - conf.EPSILON_LOW
        self.epsilon = conf.EPSILON_LOW + \
            span * np.exp(-1.0 * episode / conf.DECAY)

    def update_target(self, episode):
        """Copy policy weights into the target net every UPDATE_TARGET_EVERY episodes."""
        if episode % conf.UPDATE_TARGET_EVERY == 0:
            self.target_net.load_state_dict(self.policy_net.state_dict())
class Game:
    """Driver that plays, trains and evaluates three-seat Dou Dizhu games.

    Each seat (``'lord'``, ``'down'``, ``'up'``) is either a DQN agent built as
    ``dqns_dict[role](nets_dict[role])`` or, when no net is supplied for the
    role, played by the environment's own ``step_auto``.
    """

    predictor = Predictor()

    def __init__(self, env_cls, nets_dict, dqns_dict, reward_dict=None,
                 train_dict=None, preload=None, seed=None, debug=False):
        """
        :param env_cls: environment class, called as ``env_cls(debug=, seed=)``
        :param nets_dict: role -> network class (falsy for a rule-based seat)
        :param dqns_dict: role -> agent class, same keys as ``nets_dict``
        :param reward_dict: role -> terminal reward magnitude
        :param train_dict: role -> whether that agent keeps training
        :param preload: role -> model name loaded into both of the agent's nets
        :param seed: forwarded to the environment
        :param debug: forwarded to the environment
        """
        if reward_dict is None:
            reward_dict = {'lord': 100, 'down': 50, 'up': 50}
        if train_dict is None:
            train_dict = {'lord': True, 'down': True, 'up': True}
        if preload is None:
            preload = {}
        assert not (nets_dict.keys() ^ dqns_dict.keys()), 'Net and DQN must match'

        self.lord_wins, self.down_wins, self.up_wins = [], [], []
        self.lord_total_loss = self.down_total_loss = self.up_total_loss = 0
        self.lord_loss_count = self.down_loss_count = self.up_loss_count = 0
        self.up_total_wins = self.lord_total_wins = self.down_total_wins = 0
        self.up_recent_wins = self.lord_recent_wins = self.down_recent_wins = 0
        self.lord_max_wins = self.farmer_max_wins = 0

        self.env = env_cls(debug=debug, seed=seed)
        self.lord = self.down = self.up = None
        self.lord_train = self.down_train = self.up_train = False
        for role in ['lord', 'down', 'up']:
            if not nets_dict.get(role):
                continue
            agent = dqns_dict[role](nets_dict[role])
            setattr(self, role, agent)
            setattr(self, '{}_train'.format(role), train_dict[role])
            model = preload.get(role)
            if model:
                agent.target_net.load(model)
                agent.policy_net.load(model)

        self.lord_s0 = self.down_s0 = self.up_s0 = None
        self.lord_a0 = self.down_a0 = self.up_a0 = None
        self.reward_dict = reward_dict
        self.preload = preload
        self.train_dict = train_dict

    def accumulate_loss(self, name, loss):
        """Add a non-zero training loss to the running totals of role ``name``."""
        assert name in {'up', 'down', 'lord'}
        if not loss:
            return
        if name == 'lord':
            self.lord_loss_count += 1
            self.lord_total_loss += loss
        elif name == 'down':
            self.down_loss_count += 1
            self.down_total_loss += loss
        else:
            self.up_loss_count += 1
            self.up_total_loss += loss

    def save_win_rates(self, episode):
        """Record recent win counts, checkpoint on a new best, dump them as JSON."""
        self.lord_wins.append(self.lord_recent_wins)
        self.up_wins.append(self.up_recent_wins)
        self.down_wins.append(self.down_recent_wins)

        # Lord is the only AI seat: checkpoint it on a new best window.
        if self.lord and self.up is None and self.down is None:
            if self.lord_recent_wins > self.lord_max_wins:
                self.lord_max_wins = self.lord_recent_wins
                self.lord.policy_net.save(
                    '{}_lord_{}_{}'.format(BEGIN, episode, self.lord_max_wins))

        # Frozen AI lord: checkpoint the farmers on a new combined best.
        if self.lord and not self.lord_train:
            farmer_wins = self.up_recent_wins + self.down_recent_wins
            if farmer_wins > self.farmer_max_wins:
                self.farmer_max_wins = farmer_wins
                for role in ('up', 'down'):
                    agent = getattr(self, role)
                    if agent:  # a rule-based farmer has no network to save
                        agent.policy_net.save('{}_{}_{}_{}'.format(
                            BEGIN, role, episode, self.farmer_max_wins))

        data = {'lord': self.lord_wins, 'down': self.down_wins, 'up': self.up_wins}
        path = os.path.join(conf.WIN_DIR, conf.name_dir(BEGIN))
        os.makedirs(os.path.dirname(path), exist_ok=True)
        with open('{}.json'.format(path), 'w') as f:
            json.dump(data, f)

    def reset_recent(self):
        """Zero the per-window win counters and loss accumulators."""
        self.lord_recent_wins = self.up_recent_wins = self.down_recent_wins = 0
        self.lord_total_loss = self.down_total_loss = self.up_total_loss = 0
        self.lord_loss_count = self.down_loss_count = self.up_loss_count = 0

    def step(self, ai):
        """Let seat ``ai`` play one move and return the environment's done flag."""
        assert ai in {'lord', 'down', 'up'}
        agent = getattr(self, ai)
        if not agent:  # rule-based seat
            _, done, _ = self.env.step_auto()
            return done

        training = getattr(self, '{}_train'.format(ai))
        s0 = self.env.face
        choose = agent.e_greedy_action if training else agent.greedy_action
        a0 = choose(s0, self.env.valid_actions())
        if training:
            # Remember the pending (state, action) until feedback arrives.
            setattr(self, '{}_s0'.format(ai), s0)
            setattr(self, '{}_a0'.format(ai), a0)
        _, done, _ = self.env.step_manual(a0)
        return done

    def feedback(self, ai, done, punish=False):
        """Close the pending transition of seat ``ai`` and train on it.

        Nothing happens for rule-based or frozen seats, or when the seat has
        not acted yet in the current game.

        :param ai: role name
        :param done: whether the game just ended
        :param punish: negate the terminal reward (the seat's side lost)
        """
        assert ai in {'lord', 'up', 'down'}
        agent = getattr(self, ai)
        if not (agent and getattr(self, '{}_train'.format(ai))):
            return
        s0 = getattr(self, '{}_s0'.format(ai))
        a0 = getattr(self, '{}_a0'.format(ai))
        if a0 is None:
            return

        reward = 0
        if done:
            reward = self.reward_dict[ai]
            if punish:
                reward = -reward
        s1 = self.env.face
        if done:
            a1 = torch.zeros((15, 4), dtype=torch.float).to(conf.DEVICE)
        else:
            a1 = agent.greedy_action(s1, self.env.valid_actions())
        loss = agent.perceive(s0, a0, reward, s1, a1, done)
        self.accumulate_loss(ai, loss)

    def lord_turn(self):
        """Play the landlord's move and hand out the feedback it triggers."""
        done = self.step('lord')
        if not done:
            if self.down_a0 is not None:  # down has already moved this game
                self.feedback('down', done)
        else:  # landlord wins
            if self.down_a0 is not None:
                self.feedback('down', done, punish=True)
            self.feedback('up', done, punish=True)
            self.feedback('lord', done)
            self.lord_total_wins += 1
            self.lord_recent_wins += 1
        return done

    def down_turn(self):
        """Play the down farmer's move and hand out the feedback it triggers."""
        done = self.step('down')
        if not done:
            if self.up_a0 is not None:
                self.feedback('up', done)
        else:  # farmers win
            self.feedback('up', done)
            self.feedback('lord', done, punish=True)
            self.feedback('down', done)
            self.down_recent_wins += 1
            self.down_total_wins += 1
        return done

    def up_turn(self):
        """Play the up farmer's move and hand out the feedback it triggers."""
        done = self.step('up')
        if not done:
            self.feedback('lord', done)  # zero reward for the landlord
        else:  # farmers win
            self.feedback('lord', done, punish=True)
            self.feedback('down', done)
            self.feedback('up', done)
            self.up_total_wins += 1
            self.up_recent_wins += 1
        return done

    def play(self):
        """Play one full game in seat order lord -> down -> up."""
        # Drop pending transitions of the previous game; otherwise the
        # ``*_a0 is not None`` guards pass on the first move of a new game and
        # a transition mixing two episodes is pushed into the replay buffer.
        self.lord_s0 = self.down_s0 = self.up_s0 = None
        self.lord_a0 = self.down_a0 = self.up_a0 = None
        self.env.reset()
        self.env.prepare()
        while not (self.lord_turn() or self.down_turn() or self.up_turn()):
            pass

    def train(self, episodes, log_every=100, model_every=1000):
        """Run ``episodes`` games, logging and checkpointing periodically.

        :param episodes: number of games to play
        :param log_every: window size for win-rate logging
        :param model_every: save every AI policy net each this many episodes
        """
        if not ((self.lord and self.lord_train)
                or (self.up and self.up_train)
                or (self.down and self.down_train)):
            print('No agent need train.')
            return
        print('Logged at {}'.format(LOG_PATH))
        messages = ''
        for role in ['up', 'lord', 'down']:
            agent = getattr(self, role)
            m = '{}: {} based model.'.format(role, 'AI' if agent else 'Rule')
            if agent:
                preload = self.preload.get(role)
                if preload:
                    m += ' With pretrained model {}.'.format(preload)
                else:
                    m += ' Without pretrained model.'
                if self.train_dict.get(role):
                    m += ' Continue training.'
            messages += '\n{}'.format(m)
        logger.info(messages + '\n------------------------------------')
        print(messages)

        start_time = time.time()
        for episode in range(1, episodes + 1):
            self.play()

            if episode % log_every == 0:
                end_time = time.time()
                message = (
                    'Reach at round {}, recent {} rounds takes {:.2f}seconds\n'
                    '\tUp recent/total win: {:.2%}/{:.2%} [Mean loss: {:.2f}]\n'
                    '\tLord recent/total win: {:.2%}/{:.2%} [Mean loss: {:.2f}]\n'
                    '\tDown recent/total win: {:.2%}/{:.2%} [Mean loss: {:.2f}]\n'
                ).format(
                    episode, log_every, end_time - start_time,
                    self.up_recent_wins / log_every,
                    self.up_total_wins / episode,
                    self.up_total_loss / (self.up_loss_count + 1e-3),
                    self.lord_recent_wins / log_every,
                    self.lord_total_wins / episode,
                    self.lord_total_loss / (self.lord_loss_count + 1e-3),
                    self.down_recent_wins / log_every,
                    self.down_total_wins / episode,
                    self.down_total_loss / (self.down_loss_count + 1e-3))
                logger.info(message)
                self.save_win_rates(episode)
                self.reset_recent()
                start_time = time.time()

            if episode % model_every == 0:
                for role in ['lord', 'down', 'up']:
                    agent = getattr(self, role)
                    if agent:
                        agent.policy_net.save(
                            '{}_{}_{}'.format(BEGIN, role, episode))

            for role in ['lord', 'down', 'up']:
                agent = getattr(self, role)
                if agent:
                    agent.update_epsilon(episode)
                    agent.update_target(episode)

    @staticmethod
    def _load_agents(nets_dict, dqns_dict, model_dict):
        """Build one greedy agent per AI role and load its saved policy net."""
        agents = {'up': None, 'lord': None, 'down': None}
        for role in ['up', 'lord', 'down']:
            if nets_dict.get(role) is not None:
                print('AI based {}.'.format(role))
                agents[role] = dqns_dict[role](nets_dict[role])
                agents[role].policy_net.load(model_dict[role])
            else:
                print('Rule based {}.'.format(role))
        return agents

    @staticmethod
    def _print_win_rates(episode, print_every, elapsed, wins, total_wins):
        """Print recent and cumulative win rates for the three seats."""
        message = ('Reach at {}, Last {} rounds takes {:.2f}seconds\n'
                   '\tUp recent/total win rate: {:.2%}/{:.2%}\n'
                   '\tLord recent/total win rate: {:.2%}/{:.2%}\n'
                   '\tDown recent/total win rate: {:.2%}/{:.2%}\n')
        args = (episode, print_every, elapsed,
                wins['up'] / print_every, total_wins['up'] / episode,
                wins['lord'] / print_every, total_wins['lord'] / episode,
                wins['down'] / print_every, total_wins['down'] / episode)
        print(message.format(*args))

    @staticmethod
    def _query_server(server_url, payload, max_retries, retry_delay=3):
        """POST ``payload`` as JSON and return the decoded reply.

        A failed request or an undecodable reply is retried after
        ``retry_delay`` seconds; the last error is re-raised once
        ``max_retries`` attempts have failed.
        """
        for attempt in range(1, max_retries + 1):
            try:
                response = requests.post(server_url, json=payload)
                return json.loads(response.content)
            except (requests.exceptions.RequestException,
                    json.decoder.JSONDecodeError):
                print(payload)
                print('Server break, retry 3s later')
                if attempt >= max_retries:
                    raise
                time.sleep(retry_delay)
        raise ValueError('max_retries must be at least 1')

    @staticmethod
    def compete(env_cls, nets_dict, dqns_dict, model_dict, total=1000,
                print_every=100, debug=True):
        """Play ``total`` evaluation games with greedy agents.

        :return: ``collections.Counter`` mapping role -> games that ended on
            that role's move
        """
        import collections
        assert not (nets_dict.keys() ^ dqns_dict.keys()), 'Net and DQN must match'
        assert not (nets_dict.keys() ^ model_dict.keys()), 'Net and Model must match'
        wins = collections.Counter()
        total_wins = collections.Counter()
        ai = Game._load_agents(nets_dict, dqns_dict, model_dict)

        env = env_cls(debug=debug)
        start_time = time.time()
        for episode in range(1, total + 1):
            if debug:
                print('\n-------------------------------------------')
            env.reset()
            env.prepare()
            done = False
            while not done:
                for role in ['lord', 'down', 'up']:
                    if ai[role]:
                        action = ai[role].greedy_action(
                            env.face, env.valid_actions())
                        _, done, _ = env.step_manual(action)
                    else:
                        _, done, _ = env.step_auto()
                    if done:  # the seat that just moved emptied its hand
                        wins[role] += 1
                        total_wins[role] += 1
                        break

            if episode % print_every == 0:
                Game._print_win_rates(episode, print_every,
                                      time.time() - start_time,
                                      wins, total_wins)
                wins = collections.Counter()
                start_time = time.time()
        return total_wins

    @classmethod
    def ensemble_compete(cls, env_cls, nets_dict, dqns_dict, model_dict, total=1000,
                         print_every=100, debug=True,
                         server_url='http://117.78.4.26:5000', max_retries=5):
        """Like ``compete``, but the landlord's moves come from a remote server.

        :param server_url: endpoint that receives the game-state payload
        :param max_retries: attempts per request before the error is re-raised
        :return: ``collections.Counter`` mapping role -> games that ended on
            that role's move
        """
        import collections
        assert not (nets_dict.keys() ^ dqns_dict.keys()), 'Net and DQN must match'
        assert not (nets_dict.keys() ^ model_dict.keys()), 'Net and Model must match'
        wins = collections.Counter()
        total_wins = collections.Counter()
        ai = cls._load_agents(nets_dict, dqns_dict, model_dict)
        # 0 = seat before the landlord, 1 = landlord, 2 = seat after the landlord
        role2id = {'lord': 1, 'down': 2, 'up': 0}

        env = env_cls(debug=debug)
        start_time = time.time()
        for episode in range(1, total + 1):
            if debug:
                print('\n-------------------------------------------')
            env.reset()
            env.prepare()

            last_taken = {0: [], 1: [], 2: []}
            history = {0: [], 1: [], 2: []}
            left = {0: 17, 1: 20, 2: 17}
            hand_cards = {0: [], 1: [], 2: []}
            done = False
            while not done:
                for role in ['lord', 'down', 'up']:
                    role_id = role2id[role]
                    cur_cards = [int(i) for i in env.get_curr_handcards()]
                    hand_cards[role_id] = cur_cards
                    if role == 'lord':  # ensemble decision from the server
                        payload = {
                            'role_id': 1,
                            'last_taken': last_taken,
                            'cur_cards': cur_cards,  # order does not matter
                            'history': history,
                            'left': left,
                            'hand_cards': hand_cards,
                            'debug': True,  # ask the server for debug output
                        }
                        res = cls._query_server(server_url, payload, max_retries)
                        played = [int(i) for i in res['data']]
                        arr = cls.predictor.mock_env.cards2arr(played)
                        onehot = cls.predictor.mock_env.batch_arr2onehot([arr])[0]
                        _, done, _ = env.step_manual(onehot)
                    elif ai[role]:
                        onehot = ai[role].greedy_action(
                            env.face, env.valid_actions())
                        # Bookkeeping needs card values, not the one-hot tensor.
                        played = [int(i) for i in
                                  env.arr2cards(env.onehot2arr(onehot))]
                        _, done, _ = env.step_manual(onehot)
                    else:
                        cards, done, _ = env.step_auto()
                        played = [int(i) for i in cards]

                    last_taken[role_id] = played
                    history[role_id].extend(played)
                    left[role_id] -= len(played)
                    # Multiset difference: a set difference would also drop
                    # the unplayed copies of every rank that was played.
                    remaining = (collections.Counter(hand_cards[role_id])
                                 - collections.Counter(played))
                    hand_cards[role_id] = list(remaining.elements())
                    if not done and not hand_cards[role_id]:
                        print(set(hand_cards[role_id]), set(played))
                    if done:  # the seat that just moved emptied its hand
                        wins[role] += 1
                        total_wins[role] += 1
                        break

            if episode % print_every == 0:
                cls._print_win_rates(episode, print_every,
                                     time.time() - start_time,
                                     wins, total_wins)
                wins = collections.Counter()
                start_time = time.time()
        return total_wins
class Env(CEnv):
    """Wrapper around the precompiled ``CEnv`` that tracks played cards.

    Card values run from 3 to 17 (3..10, J, Q, K, A, 2, black joker, colour
    joker); a count array has 15 entries, one per value, indexed ``value - 3``.
    Seat ids used for ``left``/``history``/``recent_handout``: 0 is the seat
    before the landlord, 1 the landlord, 2 the seat after the landlord.
    """

    def __init__(self, debug=False, seed=None):
        # A falsy seed (None or 0) leaves the underlying engine unseeded.
        if seed:
            super(Env, self).__init__(seed=seed)
        else:
            super(Env, self).__init__()
        self._reset_tracking()
        self.debug = debug

    def _reset_tracking(self):
        """(Re)create the per-game bookkeeping arrays."""
        self.taken = np.zeros((15,))
        # ``np.int`` was removed in NumPy 1.24; the builtin is the same dtype.
        self.left = np.array([17, 20, 17], dtype=int)
        self.history = collections.defaultdict(lambda: np.zeros((15,)))
        self.recent_handout = collections.defaultdict(lambda: np.zeros((15,)))
        self.old_cards = dict()

    def reset(self):
        """Reset the underlying engine and the bookkeeping."""
        super(Env, self).reset()
        self._reset_tracking()

    def _update(self, role, cards):
        """Record that seat ``role`` just played ``cards`` (card values)."""
        cards = np.asarray(cards, dtype=int)  # accept lists as well as arrays
        self.left[role] -= len(cards)
        for card, count in Counter((cards - 3).tolist()).items():
            self.taken[card] += count
            self.history[role][card] += count
        self.recent_handout[role] = self.cards2arr(cards)
        if self.debug:
            if role == 1:
                char, name = '#', '地主'
            elif role == 0:
                char, name = '$', '上家'
            else:
                char, name = '$', '下家'
            handcards = self.cards2str(self.old_cards[role])
            print('\n{} {}手牌: {}'.format(char, name, handcards), end='')
            input()  # pause until the user presses enter
            print('{} {}出牌: {},分别剩余: {}'.format(
                char, name, self.cards2str(cards), self.left))

    def step_manual(self, onehot_cards):
        """Play the move given as a 15 * 4 one-hot and return the engine's result."""
        role = self.get_role_ID() - 1
        self.old_cards[role] = self.get_curr_handcards()
        cards = self.arr2cards(self.onehot2arr(onehot_cards))
        self._update(role, cards)
        return super(Env, self).step_manual(cards)

    def step_auto(self):
        """Let the engine choose the move; return its result tuple unchanged."""
        role = self.get_role_ID() - 1
        self.old_cards[role] = self.get_curr_handcards()
        cards, done, extra = super(Env, self).step_auto()
        self._update(role, cards)
        return cards, done, extra

    def step_random(self):
        """Play a uniformly random valid move."""
        role = self.get_role_ID() - 1
        self.old_cards[role] = self.get_curr_handcards()
        actions = self.valid_actions(tensor=False)
        cards = self.arr2cards(random.choice(actions))
        self._update(role, cards)
        return super(Env, self).step_manual(cards)

    @property
    def face(self):
        """
        :return: 4 * 15 * 4 float tensor on DEVICE: the current hand, all cards
            played so far, and the two planes from ``get_state_prob``
        """
        handcards = self.cards2arr(self.get_curr_handcards())
        known = self.batch_arr2onehot([handcards, self.taken])
        prob = self.get_state_prob().reshape(2, 15, 4)
        face = np.concatenate((known, prob))
        return torch.tensor(face, dtype=torch.float).to(DEVICE)

    def valid_actions(self, tensor=True):
        """
        :param tensor: return a batch_size * 15 * 4 float tensor on DEVICE
            when true, otherwise the raw result of ``r.get_moves``
        :return: the moves playable from the current hand
        """
        handcards = self.cards2arr(self.get_curr_handcards())
        last_two = self.get_last_two_cards()
        # Respond to the first non-empty of the two previous plays, if any.
        if last_two[0]:
            last = last_two[0]
        elif last_two[1]:
            last = last_two[1]
        else:
            last = []
        actions = r.get_moves(handcards, self.cards2arr(last))
        if not tensor:
            return actions
        return torch.tensor(self.batch_arr2onehot(actions),
                            dtype=torch.float).to(DEVICE)

    @classmethod
    def arr2cards(cls, arr):
        """
        :param arr: count array of length 15
        :return: sorted int array of card values, e.g. three aces and two
            threes become [3, 3, 14, 14, 14]
        """
        res = [idx + 3 for idx in range(15) for _ in range(int(arr[idx]))]
        return np.array(res, dtype=int)

    @classmethod
    def cards2arr(cls, cards):
        """
        :param cards: iterable of card values in 3..17
        :return: int count array of length 15
        """
        arr = np.zeros((15,), dtype=int)
        for card in cards:
            arr[card - 3] += 1
        return arr

    @classmethod
    def batch_arr2onehot(cls, batch_arr):
        """
        :param batch_arr: sequence of count arrays
        :return: int array (len(batch_arr), 15, 4) whose row for a card has
            its first ``count`` entries set to 1
        """
        res = np.zeros((len(batch_arr), 15, 4), dtype=int)
        for idx, arr in enumerate(batch_arr):
            for card_idx, count in enumerate(arr):
                if count > 0:
                    res[idx][card_idx][:int(count)] = 1
        return res

    @classmethod
    def onehot2arr(cls, onehot_cards):
        """
        :param onehot_cards: 15 * 4 one-hot (array or tensor)
        :return: int count array of shape (15,)
        """
        if isinstance(onehot_cards, torch.Tensor):
            # One host copy instead of one device read per row.
            onehot_cards = onehot_cards.detach().cpu().numpy()
        res = np.zeros((15,), dtype=int)
        for idx, onehot in enumerate(onehot_cards):
            res[idx] = int(sum(onehot))
        return res

    def cards2str(self, cards):
        """Map card values to their display strings via ``conf.DICT``."""
        return [conf.DICT[i] for i in cards]
cards): 160 | res = [conf.DICT[i] for i in cards] 161 | return res 162 | 163 | 164 | class EnvComplicated(Env): 165 | @property 166 | def face(self): 167 | """ 168 | :return: 7 * 15 * 4 的数组,作为当前状态 169 | """ 170 | handcards = self.cards2arr(self.get_curr_handcards()) 171 | role = self.get_role_ID() - 1 172 | h0 = self.history[(role - 1 + 3) % 3] 173 | h1 = self.history[(role + 0 + 3) % 3] 174 | h2 = self.history[(role + 1 + 3) % 3] 175 | known = self.batch_arr2onehot([handcards, self.taken, h0, h1, h2]) 176 | prob = self.get_state_prob().reshape(2, 15, 4) 177 | face = np.concatenate((known, prob)) 178 | return torch.tensor(face, dtype=torch.float).to(DEVICE) 179 | 180 | 181 | class EnvCooperation(Env): 182 | @property 183 | def face(self): 184 | """ 185 | :return: 9 * 15 * 4 的数组,作为当前状态 186 | """ 187 | handcards = self.cards2arr(self.get_curr_handcards()) 188 | role = self.get_role_ID() - 1 189 | h0 = self.history[(role - 1 + 3) % 3] 190 | h1 = self.history[(role + 0 + 3) % 3] 191 | h2 = self.history[(role + 1 + 3) % 3] 192 | b1 = self.recent_handout[(role - 1 + 3) % 3] 193 | b2 = self.recent_handout[(role - 2 + 3) % 3] 194 | known = self.batch_arr2onehot([handcards, self.taken, 195 | h0, h1, h2, b1, b2]) 196 | prob = self.get_state_prob().reshape(2, 15, 4) 197 | face = np.concatenate((known, prob)) 198 | return torch.tensor(face, dtype=torch.float).to(DEVICE) 199 | 200 | 201 | class EnvCooperationSimplify(Env): 202 | @property 203 | def face(self): 204 | """ 205 | :return: 6 * 15 * 4 的数组,作为当前状态 206 | """ 207 | handcards = self.cards2arr(self.get_curr_handcards()) 208 | role = self.get_role_ID() - 1 209 | h0 = self.history[(role - 1 + 3) % 3] 210 | h1 = self.history[(role + 0 + 3) % 3] 211 | h2 = self.history[(role + 1 + 3) % 3] 212 | b1 = self.recent_handout[(role - 1 + 3) % 3] 213 | b2 = self.recent_handout[(role - 2 + 3) % 3] 214 | known = self.batch_arr2onehot([handcards, self.taken, b1, b2]) 215 | prob = self.get_state_prob().reshape(2, 15, 4) 216 | face = 
class Game:
    """Self-play driver for one landlord and two farmers.

    Each seat ('lord', 'down', 'up') is either a DQN agent or, when no net
    is given for it, played by the environment's built-in ``step_auto``.
    A game is played in the fixed order lord -> down -> up.
    """

    ROLES = ('lord', 'down', 'up')

    def __init__(self, env_cls, nets_dict, dqns_dict, reward_dict=None,
                 train_dict=None, preload=None, seed=None, debug=False):
        """
        :param env_cls: environment class, called as ``env_cls(debug=, seed=)``
        :param nets_dict: role -> network (falsy / missing means rule-based)
        :param dqns_dict: role -> agent factory, called with the role's net
        :param reward_dict: role -> terminal reward magnitude
        :param train_dict: role -> whether that agent keeps learning
        :param preload: role -> model name loaded into both of its networks
        :param seed: forwarded to the environment
        :param debug: forwarded to the environment
        """
        if reward_dict is None:
            reward_dict = {'lord': 100, 'down': 50, 'up': 50}
        if train_dict is None:
            train_dict = {'lord': True, 'down': True, 'up': True}
        if preload is None:
            preload = {}
        assert not (nets_dict.keys() ^ dqns_dict.keys()), 'Net and DQN must match'

        self.lord_wins, self.down_wins, self.up_wins = [], [], []
        self.lord_total_loss = self.down_total_loss = self.up_total_loss = 0
        self.lord_loss_count = self.down_loss_count = self.up_loss_count = 0
        self.up_total_wins = self.lord_total_wins = self.down_total_wins = 0
        self.up_recent_wins = self.lord_recent_wins = self.down_recent_wins = 0
        self.lord_max_wins = self.farmer_max_wins = 0

        self.env = env_cls(debug=debug, seed=seed)
        self.lord = self.down = self.up = None
        self.lord_train = self.down_train = self.up_train = False
        for role in ['lord', 'down', 'up']:
            if nets_dict.get(role):
                setattr(self, role, dqns_dict[role](nets_dict[role]))
                setattr(self, '{}_train'.format(role), train_dict[role])
                if preload.get(role):
                    getattr(self, role).target_net.load(preload.get(role))
                    getattr(self, role).policy_net.load(preload.get(role))

        self._clear_pending()
        self.reward_dict = reward_dict
        self.preload = preload
        self.train_dict = train_dict

    def _clear_pending(self):
        """Forget every role's last (state, action) pair."""
        self.lord_s0 = self.down_s0 = self.up_s0 = None
        self.lord_a0 = self.down_a0 = self.up_a0 = None

    def accumulate_loss(self, name, loss):
        """Add a truthy ``loss`` to the running total of role ``name``."""
        assert name in {'up', 'down', 'lord'}
        if loss:
            count_attr = '{}_loss_count'.format(name)
            total_attr = '{}_total_loss'.format(name)
            setattr(self, count_attr, getattr(self, count_attr) + 1)
            setattr(self, total_attr, getattr(self, total_attr) + loss)

    def save_win_rates(self, episode):
        """Append the recent win counts, save new-best models, dump history."""
        self.lord_wins.append(self.lord_recent_wins)
        self.up_wins.append(self.up_recent_wins)
        self.down_wins.append(self.down_recent_wins)
        # Landlord AI against rule-based farmers: keep its best snapshot.
        if self.lord and self.up is None and self.down is None:
            if self.lord_recent_wins > self.lord_max_wins:
                self.lord_max_wins = self.lord_recent_wins
                self.lord.policy_net.save(
                    '{}_lord_{}_{}'.format(BEGIN, episode, self.lord_max_wins))
        # Frozen landlord AI: keep the best snapshot of the farmer agents.
        if self.lord and not self.lord_train:
            farmer_wins = self.up_recent_wins + self.down_recent_wins
            if farmer_wins > self.farmer_max_wins:
                self.farmer_max_wins = farmer_wins
                for role in ('up', 'down'):
                    agent = getattr(self, role)
                    # A rule-based farmer has no network to save.
                    if agent is not None:
                        agent.policy_net.save('{}_{}_{}_{}'.format(
                            BEGIN, role, episode, self.farmer_max_wins))
        # Persist the whole win history as JSON.
        data = {'lord': self.lord_wins, 'down': self.down_wins, 'up': self.up_wins}
        path = os.path.join(conf.WIN_DIR, conf.name_dir(BEGIN))
        dirname = os.path.dirname(path)
        if dirname:
            os.makedirs(dirname, exist_ok=True)
        path = '{}.json'.format(path)
        with open(path, 'w') as f:
            json.dump(data, f)

    def reset_recent(self):
        """Zero the counters that are reported per logging window."""
        self.lord_recent_wins = self.up_recent_wins = self.down_recent_wins = 0
        self.lord_total_loss = self.down_total_loss = self.up_total_loss = 0
        self.lord_loss_count = self.down_loss_count = self.up_loss_count = 0

    def step(self, ai):
        """Let role ``ai`` make one move; return the environment's done flag."""
        assert ai in {'lord', 'down', 'up'}
        agent = getattr(self, ai)
        continue_train = getattr(self, '{}_train'.format(ai))
        if agent:  # an AI seat, not the rule-based player
            s0 = self.env.face
            if continue_train:  # still learning: explore and remember s0
                setattr(self, '{}_s0'.format(ai), s0)
                action_f = agent.e_greedy_action
            else:
                action_f = agent.greedy_action
            a0 = action_f(s0, self.env.valid_actions())
            if continue_train:
                setattr(self, '{}_a0'.format(ai), a0)  # remember a0
            _, done, _ = self.env.step_manual(a0)
        else:
            _, done, _ = self.env.step_auto()
        return done

    def feedback(self, ai, done, punish=False):
        """Give role ``ai`` the outcome of its last move and train on it.

        The reward is ``reward_dict[ai]`` when the game is over (negated if
        ``punish``) and 0 otherwise.  Nothing happens for rule-based or
        frozen agents, or when the role has not moved yet in this game.
        """
        assert ai in {'lord', 'up', 'down'}
        agent = getattr(self, ai)
        if not (agent and getattr(self, '{}_train'.format(ai))):
            return
        s0 = getattr(self, '{}_s0'.format(ai))
        a0 = getattr(self, '{}_a0'.format(ai))
        if s0 is None or a0 is None:
            return  # no pending transition to learn from
        if done:
            reward = self.reward_dict[ai]
            if punish:
                reward = -reward
        else:
            reward = 0
        s1 = self.env.face
        if done:
            a1 = torch.zeros((15, 4), dtype=torch.float).to(conf.DEVICE)
        else:
            a1 = agent.greedy_action(s1, self.env.valid_actions())
        loss = agent.perceive(s0, a0, reward, s1, a1, done)
        self.accumulate_loss(ai, loss)

    def lord_turn(self):
        """Landlord moves; the down farmer then learns from its last move."""
        done = self.step('lord')
        if not done:  # game continues
            if self.down_a0 is not None:  # down has already moved this game
                self.feedback('down', done)
        else:  # game over, landlord wins
            if self.down_a0 is not None:
                self.feedback('down', done, punish=True)
            self.feedback('up', done, punish=True)
            self.feedback('lord', done)
            self.lord_total_wins += 1
            self.lord_recent_wins += 1
        return done

    def down_turn(self):
        """Down farmer moves; the up farmer then learns from its last move."""
        done = self.step('down')
        if not done:  # game continues
            if self.up_a0 is not None:
                self.feedback('up', done)
        else:  # game over, farmers win
            self.feedback('up', done)
            self.feedback('lord', done, punish=True)
            self.feedback('down', done)
            self.down_recent_wins += 1
            self.down_total_wins += 1
        return done

    def up_turn(self):
        """Up farmer moves; the landlord then learns from its last move."""
        done = self.step('up')
        if not done:  # game continues, landlord gets a zero reward
            self.feedback('lord', done)
        else:  # game over, farmers win
            self.feedback('lord', done, punish=True)
            self.feedback('down', done)
            self.feedback('up', done)
            self.up_total_wins += 1
            self.up_recent_wins += 1
        return done

    def play(self):
        """Play one complete game."""
        self.env.reset()
        self.env.prepare()
        # Without this the "has this role moved yet" checks in the turn
        # methods see the previous game's state/action and train on a
        # transition that spans two unrelated games.
        self._clear_pending()
        while True:
            if self.lord_turn():
                break
            if self.down_turn():
                break
            if self.up_turn():
                break

    def train(self, episodes, log_every=100, model_every=1000):
        """Play ``episodes`` games, logging and checkpointing periodically.

        :param episodes: number of games to play
        :param log_every: log win rates / mean losses every this many games
        :param model_every: save every AI's policy net every this many games
        """
        if not ((self.lord and self.lord_train)
                or (self.up and self.up_train)
                or (self.down and self.down_train)):
            print('No agent need train.')
            return
        print('Logged at {}'.format(LOG_PATH))
        messages = ''
        for role in ['up', 'lord', 'down']:
            m = '{}: {} based model.'.format(
                role, 'AI' if getattr(self, role) else 'Rule')
            if getattr(self, role):
                preload = self.preload.get(role)
                if preload:
                    m += ' With pretrained model {}.'.format(preload)
                else:
                    m += ' Without pretrained model.'
                if self.train_dict.get(role):
                    m += ' Continue training.'
            messages += '\n{}'.format(m)
        logger.info(messages + '\n------------------------------------')
        print(messages)
        start_time = time.time()
        for episode in range(1, episodes + 1):
            self.play()

            if episode % log_every == 0:
                end_time = time.time()
                # The 1e-3 keeps the mean defined when no loss was recorded.
                message = (
                    'Reach at round {}, recent {} rounds takes {:.2f}seconds\n'
                    '\tUp recent/total win: {:.2%}/{:.2%} [Mean loss: {:.2f}]\n'
                    '\tLord recent/total win: {:.2%}/{:.2%} [Mean loss: {:.2f}]\n'
                    '\tDown recent/total win: {:.2%}/{:.2%} [Mean loss: {:.2f}]\n'
                ).format(episode, log_every, end_time - start_time,
                         self.up_recent_wins / log_every, self.up_total_wins / episode,
                         self.up_total_loss / (self.up_loss_count + 1e-3),
                         self.lord_recent_wins / log_every, self.lord_total_wins / episode,
                         self.lord_total_loss / (self.lord_loss_count + 1e-3),
                         self.down_recent_wins / log_every, self.down_total_wins / episode,
                         self.down_total_loss / (self.down_loss_count + 1e-3))
                logger.info(message)
                self.save_win_rates(episode)
                self.reset_recent()
                start_time = time.time()
            if episode % model_every == 0:
                for role in ['lord', 'down', 'up']:
                    ai = getattr(self, role)
                    if ai:
                        ai.policy_net.save(
                            '{}_{}_{}'.format(BEGIN, role, episode))

            for role in ['lord', 'down', 'up']:
                ai = getattr(self, role)
                if ai:
                    ai.update_epsilon(episode)
                    ai.update_target(episode)

    @staticmethod
    def compete(env_cls, nets_dict, dqns_dict, model_dict, total=1000,
                print_every=100, debug=True):
        """Play ``total`` evaluation games with greedy agents.

        :param env_cls: environment class, called as ``env_cls(debug=debug)``
        :param nets_dict: role -> network (``None`` / missing = rule-based)
        :param dqns_dict: role -> agent factory, called with the role's net
        :param model_dict: role -> model name loaded into the policy net
        :param total: number of games
        :param print_every: report win rates every this many games
        :param debug: forwarded to the environment
        :return: ``collections.Counter`` mapping role -> games won
        """
        import collections
        assert not (nets_dict.keys() ^ dqns_dict.keys()), 'Net and DQN must match'
        assert not (nets_dict.keys() ^ model_dict.keys()), 'Net and Model must match'
        wins = collections.Counter()
        total_wins = collections.Counter()
        ai = {'up': None, 'lord': None, 'down': None}
        for role in ['up', 'lord', 'down']:
            if nets_dict.get(role) is not None:
                print('AI based {}.'.format(role))
                ai[role] = dqns_dict[role](nets_dict[role])
                ai[role].policy_net.load(model_dict[role])
            else:
                print('Rule based {}.'.format(role))

        env = env_cls(debug=debug)
        start_time = time.time()
        for episode in range(1, total + 1):
            if debug:
                print('\n-------------------------------------------')
            env.reset()
            env.prepare()
            done = False
            while not done:
                for role in ['lord', 'down', 'up']:
                    if ai[role]:
                        action = ai[role].greedy_action(env.face, env.valid_actions())
                        _, done, _ = env.step_manual(action)
                    else:
                        _, done, _ = env.step_auto()
                    if done:  # the role that just moved ended the game and wins
                        wins[role] += 1
                        total_wins[role] += 1
                        break

            if episode % print_every == 0:
                end_time = time.time()
                message = ('Reach at {}, Last {} rounds takes {:.2f}seconds\n'
                           '\tUp recent/total win rate: {:.2%}/{:.2%}\n'
                           '\tLord recent/total win rate: {:.2%}/{:.2%}\n'
                           '\tDown recent/total win rate: {:.2%}/{:.2%}\n')
                args = (episode, print_every, end_time - start_time,
                        wins['up'] / print_every, total_wins['up'] / episode,
                        wins['lord'] / print_every, total_wins['lord'] / episode,
                        wins['down'] / print_every, total_wins['down'] / episode)
                print(message.format(*args))
                wins = collections.Counter()
                start_time = time.time()
        return total_wins
class Net(nn.Module, ABC):
    """Base class adding name-based checkpoint save/load to ``nn.Module``."""

    def save(self, name, max_split=2):
        """Write the state dict to ``<MODEL_DIR>/<name_dir(name)>.pt``."""
        path = os.path.join(conf.MODEL_DIR, conf.name_dir(name, max_split))
        dirname = os.path.dirname(path)
        if dirname:
            os.makedirs(dirname, exist_ok=True)
        path = '{}.pt'.format(path)
        torch.save(self.state_dict(), path)

    def load(self, name=None, abspath=None, max_split=2):
        """Load a state dict and switch the module to eval mode.

        :param name: model name resolved below ``conf.MODEL_DIR``
        :param abspath: full path of the checkpoint; takes precedence
        :param max_split: forwarded to ``conf.name_dir``
        """
        if abspath:
            path = abspath
        else:
            path = os.path.join(conf.MODEL_DIR, conf.name_dir(name, max_split))
            path = '{}.pt'.format(path)
        # Remap to CPU only when running on CPU; otherwise keep the stored
        # device placement.
        map_location = 'cpu' if conf.DEVICE.type == 'cpu' else None
        static_dict = torch.load(path, map_location=map_location)
        self.load_state_dict(static_dict)
        self.eval()
        print("Loaded model from {}.".format(path))


class NetFirst(Net):
    """Scores (state, action) pairs with four row-wise convolutions."""

    def __init__(self, in_channels=5):
        """
        :param in_channels: state planes + 1 action plane (default 5 * 15 * 4)
        """
        super().__init__()
        # Kernel widths 1..4 with stride 4 each give one 256 * 15 * 1 map.
        self.conv1 = nn.Conv2d(in_channels, 256, (1, 1), (1, 4))
        self.conv2 = nn.Conv2d(in_channels, 256, (1, 2), (1, 4))
        self.conv3 = nn.Conv2d(in_channels, 256, (1, 3), (1, 4))
        self.conv4 = nn.Conv2d(in_channels, 256, (1, 4), (1, 4))
        self.convs = (self.conv1, self.conv2, self.conv3, self.conv4)  # 256 * 15 * 4
        self.pool = nn.MaxPool2d((1, 4))  # 256 * 15 * 1
        self.drop = nn.Dropout(0.5)
        self.fc1 = nn.Linear(256 * 15, 256)
        self.fc2 = nn.Linear(256, 1)

    def forward(self, face, actions):
        """
        :param face: current state, 4 * 15 * 4 (or already batched)
        :param actions: all candidate actions, batch_size * 15 * 4
        :return: batch_size * 1 scores
        """
        if face.dim() == 3:
            # One shared state: repeat it for every candidate action.
            face = face.unsqueeze(0).repeat((actions.shape[0], 1, 1, 1))
        actions = actions.unsqueeze(1)
        state_action = torch.cat((face, actions), dim=1)

        x = torch.cat([f(state_action) for f in self.convs], -1)
        x = self.pool(x)
        x = x.view(actions.shape[0], -1)
        x = self.drop(x)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x


class NetComplicated(Net):
    """``NetFirst`` plus a full-height convolution (``conv_shunzi``)."""

    def __init__(self, in_channels=5):
        """
        :param in_channels: state planes + 1 action plane (default 5 * 15 * 4)
        """
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, 256, (1, 1), (1, 4))  # 256 * 15 * 1
        self.conv2 = nn.Conv2d(in_channels, 256, (1, 2), (1, 4))
        self.conv3 = nn.Conv2d(in_channels, 256, (1, 3), (1, 4))
        self.conv4 = nn.Conv2d(in_channels, 256, (1, 4), (1, 4))
        self.convs = (self.conv1, self.conv2, self.conv3, self.conv4)  # 256 * 15 * 4
        self.conv_shunzi = nn.Conv2d(in_channels, 256, (15, 1), 1)  # 256 * 1 * 4
        self.pool = nn.MaxPool2d((1, 4))  # 256 * 15 * 1
        self.drop = nn.Dropout(0.5)
        self.fc1 = nn.Linear(256 * (15 + 4), 256)
        self.fc2 = nn.Linear(256, 1)

    def forward(self, face, actions):
        """
        :param face: current state, face_depth (fixed by the env) * 15 * 4
        :param actions: all candidate actions, batch_size * 15 * 4
        :return: batch_size * 1 scores
        """
        if face.dim() == 3:
            face = face.unsqueeze(0).repeat((actions.shape[0], 1, 1, 1))
        actions = actions.unsqueeze(1)
        state_action = torch.cat((face, actions), dim=1)

        x = torch.cat([f(state_action) for f in self.convs], -1)
        x = self.pool(x)
        x = x.view(actions.shape[0], -1)

        x_shunzi = self.conv_shunzi(state_action).view(actions.shape[0], -1)
        x = torch.cat([x, x_shunzi], -1)

        x = self.drop(x)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x


class NetMoreComplicated(NetComplicated):
    """``NetComplicated`` for an 8 * 15 * 4 input."""

    def __init__(self):
        super().__init__(in_channels=8)


class NetCooperation(NetComplicated):
    """``NetComplicated`` for a 10 * 15 * 4 input."""

    def __init__(self):
        super().__init__(in_channels=10)


class NetCooperationSimplify(NetComplicated):
    """``NetComplicated`` for a 7 * 15 * 4 input."""

    def __init__(self):
        super().__init__(in_channels=7)


class NetFinal(Net):
    """Bank of convolutions over a 7 * 15 * 4 input (unfinished)."""

    def __init__(self):
        super(NetFinal, self).__init__()
        # One group per kernel width; the attribute names are the
        # state-dict keys, so they must stay 'conv<width>_<height>'.
        kernel_heights = {
            1: [1, 5, 6, 7, 8, 9, 10, 11, 12],
            2: [1, 3, 4, 5, 6, 7, 8, 9, 10],
            3: [1, 2, 3, 4, 5, 6],
            4: [1, 2, 3, 4, 5],
        }
        for width, heights in kernel_heights.items():
            for height in heights:
                setattr(self, 'conv{}_{}'.format(width, height),
                        nn.Conv2d(7, 64, (height, width), (1, 4)))

    def forward(self, face, actions):
        """
        :param face: current state, face_depth (fixed by the env) * 15 * 4
        :param actions: all candidate actions, batch_size * 15 * 4
        :return: None
        """
        # NOTE(review): the forward pass was never completed; the
        # convolutions above are not applied and nothing is returned.
        if face.dim() == 3:
            face = face.unsqueeze(0).repeat((actions.shape[0], 1, 1, 1))
        actions = actions.unsqueeze(1)
        state_action = torch.cat((face, actions), dim=1)

        return
alpha=0.3,color='C{}'.format(i)) 32 | sm = gaussian_filter1d(y, sigma=5) 33 | plt.plot(sm, label=k, color='C{}'.format(i)) 34 | plt.title(title) 35 | plt.xlabel('训练总百次数') 36 | plt.ylabel('过去100次AI地主胜率') 37 | plt.legend() 38 | 39 | path = os.path.join(conf.IMG_DIR, conf.name_dir(fn)) 40 | dirname = os.path.dirname(path) 41 | if not os.path.exists(dirname): 42 | os.makedirs(dirname) 43 | path = '{}.svg'.format(path) 44 | plt.savefig(path, format='svg') 45 | print('Saved at {}'.format(path)) 46 | 47 | 48 | if __name__ == '__main__': 49 | plot('0808_0918') 50 | -------------------------------------------------------------------------------- /precompiled/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/charleschen003/doudizhu-rl/56993b04d227e4718969209ab542142d406d3241/precompiled/.gitkeep -------------------------------------------------------------------------------- /rule_based/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/charleschen003/doudizhu-rl/56993b04d227e4718969209ab542142d406d3241/rule_based/__init__.py -------------------------------------------------------------------------------- /rule_based/rule_play.py: -------------------------------------------------------------------------------- 1 | from envi import Env 2 | from rule_based.utils.rule_based_model import RuleBasedModel 3 | 4 | 5 | def rule_play(): 6 | env = Env() 7 | rule = RuleBasedModel() 8 | total_lord_win, total_farmer_win = 0, 0 9 | for episode in range(1, 3000 + 1): 10 | # print(episode) 11 | env.reset() 12 | env.prepare() 13 | r = 0 14 | while r == 0: # r == -1 地主赢, r == 1,农民赢 15 | # lord first 16 | r, _, _ = env.step_manual(rule.choose(env)) 17 | if r == -1: # 地主赢 18 | total_lord_win += 1 19 | else: 20 | h = env.get_curr_handcards() 21 | a, r, _ = env.step_auto() # 下家 22 | print("Auto1", h, "//", a) 23 | if r == 0: 24 | h = 
def get_action_space():
    """Enumerate every card combination, grouped by category.

    The categories are generated in the order of the ``Category``
    constants (empty, single, pair, triple, bomb, 3+1, 3+2, the five
    sequence kinds, the joker pair, 4+1+1, 4+2+2).  As a side effect the
    module-level ``Category2Range`` is rebuilt so that entry ``i`` is the
    ``[start, end)`` index range of category ``i`` in the returned list.

    :return: list of actions; each action is a list of card characters
        from ``Card.cards`` and index 0 is the empty action.
    """
    jokers = ('*', '$')
    max_cards = 20  # cap used to bound the length of the sequences
    rank_of = Card.cards.index
    first_v = Card.to_value('3')
    two_v = Card.to_value('2')
    joker_v = Card.to_value('*')

    def is_joker_pair(extra):
        # Only an attachment made of exactly the two jokers is rejected.
        return len(extra) == 2 and '*' in extra and '$' in extra

    # Rebuild in place: appending on every call made the table grow by 15
    # entries per call.
    del Category2Range[:]
    actions = [[]]
    Category2Range.append([0, 1])

    def close_category(begin):
        """Record the finished category and return the next start index."""
        Category2Range.append([begin, len(actions)])
        return len(actions)

    begin = len(actions)

    # single
    for card in Card.cards:
        actions.append([card])
    begin = close_category(begin)

    # pair, triple, bomb (jokers cannot form any of them)
    for copies in (2, 3, 4):
        for card in Card.cards:
            if card not in jokers:
                actions.append([card] * copies)
        begin = close_category(begin)

    # 3 + 1
    for main in Card.cards:
        if main not in jokers:
            for extra in Card.cards:
                if extra != main:
                    actions.append([main] * 3 + [extra])
    begin = close_category(begin)

    # 3 + 2
    for main in Card.cards:
        if main not in jokers:
            for extra in Card.cards:
                if extra != main and extra not in jokers:
                    actions.append([main] * 3 + [extra] * 2)
    begin = close_category(begin)

    # single sequence
    for start_v in range(first_v, two_v):
        for end_v in range(start_v + 5, joker_v):
            seq = range(start_v, end_v)
            actions.append(sorted(Card.to_cards(seq), key=rank_of))
    begin = close_category(begin)

    # double sequence
    for start_v in range(first_v, two_v):
        for end_v in range(start_v + 3,
                           min(start_v + max_cards // 2 + 1, joker_v)):
            seq = range(start_v, end_v)
            actions.append(sorted(Card.to_cards(seq) * 2, key=rank_of))
    begin = close_category(begin)

    # triple sequence
    for start_v in range(first_v, two_v):
        for end_v in range(start_v + 2,
                           min(start_v + max_cards // 3 + 1, joker_v)):
            seq = range(start_v, end_v)
            actions.append(sorted(Card.to_cards(seq) * 3, key=rank_of))
    begin = close_category(begin)

    # 3 + 1 sequence
    for start_v in range(first_v, two_v):
        for end_v in range(start_v + 2,
                           min(start_v + max_cards // 4 + 1, joker_v)):
            seq = range(start_v, end_v)
            main = Card.to_cards(seq)
            remains = [card for card in Card.cards if card not in main]
            for extra in itertools.combinations(remains, end_v - start_v):
                if not is_joker_pair(extra):
                    actions.append(sorted(main * 3, key=rank_of) + list(extra))
    begin = close_category(begin)

    # 3 + 2 sequence
    for start_v in range(first_v, two_v):
        for end_v in range(start_v + 2,
                           min(start_v + max_cards // 5 + 1, joker_v)):
            seq = range(start_v, end_v)
            main = Card.to_cards(seq)
            remains = [card for card in Card.cards
                       if card not in main and card not in jokers]
            for extra in itertools.combinations(remains, end_v - start_v):
                actions.append(sorted(main * 3, key=rank_of) + list(extra) * 2)
    begin = close_category(begin)

    # bigbang
    actions.append(['*', '$'])
    begin = close_category(begin)

    # 4 + 1 + 1
    for main in Card.cards:
        if main not in jokers:
            remains = [card for card in Card.cards if card != main]
            for extra in itertools.combinations(remains, 2):
                if not is_joker_pair(extra):
                    actions.append([main] * 4 + list(extra))
    begin = close_category(begin)

    # 4 + 2 + 2
    for main in Card.cards:
        if main not in jokers:
            remains = [card for card in Card.cards
                       if card != main and card not in jokers]
            for extra in itertools.combinations(remains, 2):
                actions.append([main] * 4 + list(extra) * 2)
    close_category(begin)

    return actions
| @staticmethod 200 | def val2onehot(cards): 201 | chars = [Card.cards[i - 3] for i in cards] 202 | return Card.char2onehot(chars) 203 | 204 | @staticmethod 205 | def val2onehot60(cards): 206 | counts = Counter(cards) 207 | onehot = np.zeros(60) 208 | for x in cards: 209 | idx = (x - 3) * 4 210 | subvec = np.zeros(4) 211 | subvec[:counts[x]] = 1 212 | onehot[idx:idx + 4] = subvec 213 | return onehot 214 | 215 | # convert char to 0-56 color cards 216 | @staticmethod 217 | def char2color(cards): 218 | result = np.zeros([len(cards)]) 219 | mask = np.zeros([57]) 220 | for i in range(len(cards)): 221 | ind = Card.cards.index(cards[i]) * 4 222 | while mask[ind] == 1: 223 | ind += 1 224 | mask[ind] = 1 225 | result[i] = ind 226 | 227 | return result 228 | 229 | @staticmethod 230 | def onehot2color(cards): 231 | result = [] 232 | for i in range(len(cards)): 233 | if cards[i] == 0: 234 | continue 235 | if i == 53: 236 | result.append(56) 237 | else: 238 | result.append(i) 239 | return np.array(result) 240 | 241 | @staticmethod 242 | def onehot2char(cards): 243 | result = [] 244 | for i in range(len(cards)): 245 | if cards[i] == 0: 246 | continue 247 | if i == 53: 248 | result.append(Card.cards[14]) 249 | else: 250 | result.append(Card.cards[i // 4]) 251 | return result 252 | 253 | @staticmethod 254 | def onehot2val(cards): 255 | result = [] 256 | for i in range(len(cards)): 257 | if cards[i] == 0: 258 | continue 259 | if i == 53: 260 | result.append(17) 261 | else: 262 | result.append(i // 4 + 3) 263 | return result 264 | 265 | @staticmethod 266 | def char2value_3_17(cards): 267 | result = [] 268 | if type(cards) is list or type(cards) is range: 269 | for c in cards: 270 | result.append(Card.cards_to_value[c] + 3) 271 | return np.array(result) 272 | else: 273 | return Card.cards_to_value[cards] + 3 274 | 275 | @staticmethod 276 | def to_value(card): 277 | if type(card) is list or type(card) is range: 278 | val = 0 279 | for c in card: 280 | val += Card.cards_to_value[c] 281 
| return val 282 | else: 283 | return Card.cards_to_value[card] 284 | 285 | @staticmethod 286 | def to_cards(values): 287 | if type(values) is list or type(values) is range: 288 | cards = [] 289 | for v in values: 290 | cards.append(Card.value_to_cards[v]) 291 | return cards 292 | else: 293 | return Card.value_to_cards[values] 294 | 295 | @staticmethod 296 | def to_cards_from_3_17(values): 297 | return Card.np_cards[values - 3].tolist() 298 | 299 | 300 | class CardGroup: 301 | def __init__(self, cards, t, val, len=1): 302 | self.type = t 303 | self.cards = cards 304 | self.value = val 305 | self.len = len 306 | 307 | def bigger_than(self, g): 308 | if self.type == Category.EMPTY: 309 | return g.type != Category.EMPTY 310 | if g.type == Category.EMPTY: 311 | return True 312 | if g.type == Category.BIGBANG: 313 | return False 314 | if self.type == Category.BIGBANG: 315 | return True 316 | if g.type == Category.QUADRIC: 317 | if self.type == Category.QUADRIC and self.value > g.value: 318 | return True 319 | else: 320 | return False 321 | if self.type == Category.QUADRIC or \ 322 | (self.type == g.type and self.len == g.len and self.value > g.value): 323 | return True 324 | else: 325 | return False 326 | 327 | @staticmethod 328 | def isvalid(cards): 329 | return CardGroup.folks(cards) == 1 330 | 331 | @staticmethod 332 | def to_cardgroup(cards): 333 | candidates = CardGroup.analyze(cards) 334 | for c in candidates: 335 | if len(c.cards) == len(cards): 336 | return c 337 | print("cards error!") 338 | print(cards) 339 | raise Exception("Invalid Cards!") 340 | 341 | @staticmethod 342 | def folks(cards): 343 | cand = CardGroup.analyze(cards) 344 | cnt = 10000 345 | # if not cards: 346 | # return 0 347 | # for c in cand: 348 | # remain = list(cards) 349 | # for card in c.cards: 350 | # remain.remove(card) 351 | # if CardGroup.folks(remain) + 1 < cnt: 352 | # cnt = CardGroup.folks(remain) + 1 353 | # return cnt 354 | spec = False 355 | for c in cand: 356 | if c.type == 
@staticmethod
def analyze(cards):
    """List every card group that can be read out of ``cards``.

    The result is sorted by category (in the fixed ``importance`` order
    below) and, within a category, by ascending ``value``.  An empty hand
    yields a single ``Category.EMPTY`` group.

    Limitations visible in the code: four-of-a-kind ranks are removed from
    the hand before anything else is looked for, and straights / pair runs /
    triple runs are assembled only from ranks held exactly once / twice /
    three times.
    """
    cards = list(cards)
    if len(cards) == 0:
        return [CardGroup([], Category.EMPTY, 0)]

    groups = []
    value_of = Card.to_value
    rank_two = Card.to_value('2')

    def add_runs(ranks, copies, min_len, category):
        """Append each run of consecutive ``ranks`` that has at least ``min_len`` ranks.

        ``ranks`` must be sorted ascending.  Scanning stops at the first
        rank (after the leading one) whose value is '2' or higher.
        """
        if len(ranks) == 0:
            return
        run = [ranks[0]]
        for rank in ranks[1:]:
            if value_of(rank) >= rank_two:
                break
            if value_of(rank) == value_of(run[-1]) + 1:
                run.append(rank)
                continue
            if len(run) >= min_len:
                groups.append(CardGroup([r for r in run for _ in range(copies)],
                                        category, value_of(run[0]), len(run)))
            run = [rank]
        if len(run) >= min_len:
            groups.append(CardGroup([r for r in run for _ in range(copies)],
                                    category, value_of(run[0]), len(run)))

    # TODO: this does not rule out Nuke kicker
    if '*' in cards and '$' in cards:
        groups.append(CardGroup(['*', '$'], Category.BIGBANG, 100))

    # Four of a kind: record it, then drop that rank from the working hand.
    quadrics = []
    for rank, held in Counter(cards).items():
        if held == 4:
            quadrics.append(rank)
            groups.append(CardGroup([rank] * 4, Category.QUADRIC, value_of(rank)))
            cards = [c for c in cards if c != rank]

    tally = Counter(cards)
    by_value = Card.cards_to_value.__getitem__
    triples = sorted((r for r in tally if tally[r] == 3), key=by_value)

    # Consecutive runs: >=5 singles, >=3 pairs, >=2 triples.
    add_runs(sorted((r for r in tally if tally[r] == 1), key=by_value),
             1, 5, Category.SINGLE_LINE)
    add_runs(sorted((r for r in tally if tally[r] == 2), key=by_value),
             2, 3, Category.DOUBLE_LINE)
    add_runs(triples, 3, 2, Category.TRIPLE_LINE)

    for rank in triples:
        groups.append(CardGroup([rank] * 3, Category.TRIPLE, value_of(rank)))

    # From here on singles/doubles are used in first-seen hand order
    # (not sorted), exactly as the kickers were enumerated before.
    singles = [r for r in tally if tally[r] == 1]
    doubles = [r for r in tally if tally[r] == 2]

    for rank in singles:
        groups.append(CardGroup([rank], Category.SINGLE, value_of(rank)))
    for rank in doubles:
        groups.append(CardGroup([rank] * 2, Category.DOUBLE, value_of(rank)))

    # 3 + 1 and 3 + 2
    for rank in triples:
        body = [rank] * 3
        for kicker in singles:
            if kicker not in body:
                groups.append(CardGroup(body + [kicker], Category.THREE_ONE, value_of(rank)))
        for kicker in doubles:
            if kicker not in body:
                groups.append(CardGroup(body + [kicker] * 2, Category.THREE_TWO, value_of(rank)))

    # 4 + two singles, 4 + two pairs
    for rank in quadrics:
        for pair_of_singles in itertools.combinations(singles, 2):
            groups.append(CardGroup([rank] * 4 + list(pair_of_singles),
                                    Category.FOUR_TAKE_ONE, value_of(rank)))
        for pair_of_doubles in itertools.combinations(doubles, 2):
            groups.append(CardGroup([rank] * 4 + list(pair_of_doubles) * 2,
                                    Category.FOUR_TAKE_TWO, value_of(rank)))

    # 3 * n + n singles, 3 * n + n pairs
    triple_runs = [g.cards for g in groups if g.type == Category.TRIPLE_LINE]
    for body in triple_runs:
        width = int(len(body) / 3)
        for kickers in itertools.combinations(singles, width):
            groups.append(CardGroup(body + list(kickers), Category.THREE_ONE_LINE,
                                    value_of(body[0]), width))
        for kickers in itertools.combinations(doubles, width):
            groups.append(CardGroup(body + list(kickers) * 2, Category.THREE_TWO_LINE,
                                    value_of(body[0]), width))

    importance = [Category.EMPTY, Category.SINGLE, Category.DOUBLE, Category.DOUBLE_LINE,
                  Category.SINGLE_LINE, Category.THREE_ONE, Category.THREE_TWO,
                  Category.THREE_ONE_LINE, Category.THREE_TWO_LINE, Category.TRIPLE_LINE,
                  Category.TRIPLE, Category.FOUR_TAKE_ONE, Category.FOUR_TAKE_TWO,
                  Category.QUADRIC, Category.BIGBANG]
    # Stable sort: category rank first, then group value.
    groups.sort(key=lambda g: (importance.index(g.type), g.value))
    return groups
class Decomposer:
    """Break a hand into candidate sets of actions that play it out.

    The enumeration itself is delegated to ``get_combinations_nosplit`` /
    ``get_combinations_recursive`` from the external ``env`` module; this
    class prepares their masks and maps the returned row numbers back to
    action indices.
    """

    def __init__(self, num_actions=(100, 21)):
        # Only num_actions[1] is read here: it is the column count of the
        # fine mask, i.e. the longest combination that can be marked.
        self.num_actions = num_actions

    @staticmethod
    def _fine_mask(combs, required, width):
        """Mark, per combination, the positions whose action is in ``required``.

        Returns a ``(len(combs), width)`` boolean array.  Raises
        ``IndexError`` if a marked position lies at or beyond ``width``.
        """
        # Builtin ``bool``: the ``np.bool`` alias was removed in NumPy 1.24.
        mask = np.zeros([len(combs), width], dtype=bool)
        for row, comb in enumerate(combs):
            for col, action in enumerate(comb):
                if action in required:
                    mask[row][col] = True
        return mask

    def get_combinations(self, curr_cards_char, last_cards_char):
        """Enumerate decompositions of ``curr_cards_char``.

        Args:
            curr_cards_char: the hand, as rank strings.
            last_cards_char: the cards to beat; empty when leading.

        Returns:
            ``(combs, fine_mask)``.  Each entry of ``combs`` is a list of
            action indices; when responding it starts with ``0`` and only
            combinations containing at least one index from the "beats the
            last play" set are kept.  ``fine_mask`` marks those positions
            (see ``_fine_mask``) and is ``None`` when leading.
        """
        responding = len(last_cards_char) > 0
        prefix = [0] if responding else []

        if len(curr_cards_char) > 10:
            # Large hand: work on the augmented action table without splitting.
            card_mask = Card.char2onehot60(curr_cards_char).astype(np.uint8)
            mask = augment_action_space_onehot60
            # A row is invalid if it needs any slot the hand does not hold.
            missing = np.expand_dims(1 - card_mask, 0) * mask
            invalid_row_idx = set(np.where(missing > 0)[0])
            if not responding:
                invalid_row_idx.add(0)

            valid_row_idx = [i for i in range(len(augment_action_space)) if i not in invalid_row_idx]

            mask = mask[valid_row_idx, :]
            idx_mapping = dict(zip(range(mask.shape[0]), valid_row_idx))

            # augment mask
            # TODO: known issue: 555444666 will not decompose into 5554 and 66644
            combs = get_combinations_nosplit(mask, card_mask)
            combs = [prefix + [clamp_action_idx(idx_mapping[idx]) for idx in comb] for comb in combs]

            if not responding:
                return combs, None

            target = CardGroup.to_cardgroup(last_cards_char)
            # NOTE(review): these are augmented (un-clamped) row indices while
            # ``combs`` holds clamped ones -- kept as in the original; confirm.
            idx_must_be_contained = {
                idx for idx in valid_row_idx
                if CardGroup.to_cardgroup(augment_action_space[idx]).bigger_than(target)
            }
        else:
            mask = get_mask_onehot60(curr_cards_char, action_space, None) \
                .reshape(len(action_space), 15, 4).sum(-1).astype(np.uint8)
            valid = mask.sum(-1) > 0
            cards_target = Card.char2onehot60(curr_cards_char).reshape(-1, 4).sum(-1).astype(np.uint8)
            # do not feed empty to C++, which will cause infinite loop
            combs = get_combinations_recursive(mask[valid, :], cards_target)
            idx_mapping = dict(zip(range(valid.shape[0]), np.where(valid)[0]))

            combs = [prefix + [idx_mapping[idx] for idx in comb] for comb in combs]

            if not responding:
                return combs, None

            valid[0] = True
            target = CardGroup.to_cardgroup(last_cards_char)
            idx_must_be_contained = {
                idx for idx in range(len(action_space))
                if valid[idx] and CardGroup.to_cardgroup(action_space[idx]).bigger_than(target)
            }

        combs = [comb for comb in combs if not idx_must_be_contained.isdisjoint(comb)]
        fine_mask = self._fine_mask(combs, idx_must_be_contained, self.num_actions[1])
        return combs, fine_mask
class RuleBasedModel:
    """Heuristic player: decompose the hand, score each decomposition, play the best move.

    ``env`` must provide ``get_curr_handcards``, ``get_last_two_cards``,
    ``get_role_ID``, ``cards2arr``, ``arr2cards`` and a ``left`` sequence of
    remaining-card counts indexed by seat (as used below).
    """

    @staticmethod
    def get_hand(env):
        """Current player's hand as a per-rank count array."""
        return env.cards2arr(env.get_curr_handcards())

    @staticmethod
    def get_last(env):
        """Most recent non-empty play among the last two, as a per-rank count array."""
        last_two = env.get_last_two_cards()
        if last_two[0]:
            last = last_two[0]
        elif last_two[1]:
            last = last_two[1]
        else:
            last = []
        return env.cards2arr(last)

    @staticmethod
    def get_id(env):
        """Zero-based seat of the current player."""
        return env.get_role_ID() - 1

    @staticmethod
    def arr2onehot(arr):
        """Turn 15 per-rank counts into a ``(15, 4)`` 0/1 integer grid.

        Row ``r`` has its first ``arr[r]`` entries set to 1.
        """
        # Builtin ``int``: the ``np.int`` alias was removed in NumPy 1.24.
        res = np.zeros((15, 4), dtype=int)
        for card_idx, count in enumerate(arr):
            if count > 0:
                res[card_idx][:int(count)] = 1
        return res

    def choose(self, env):
        """Pick the move to play and return it as a ``(15, 4)`` one-hot grid.

        Every decomposition of the hand is scored as the sum of its groups'
        ``cards_value`` minus ``round_penalty`` per small group; a candidate
        move is scored as if that group had already been played.
        """
        # Hand and last play, converted from count arrays to the rank strings
        # the decomposer uses (it and the engine encode cards differently).
        hand_card = self.get_hand(env)
        trans_hand_card = [card_list[i] for i in range(15) for _ in range(hand_card[i])]
        last_arr = self.get_last(env)  # queried once instead of once per rank
        last_move = [card_list[i] for i in range(15) for _ in range(last_arr[i])]

        combs, fine_mask = Decomposer().get_combinations(trans_hand_card, last_move)

        # The fewer cards the closest opponent holds, the more each extra
        # group left in hand is penalised.
        left_cards = env.left
        min_oppo_cards = min(left_cards[1], left_cards[2]) if self.get_id(env) == 0 else left_cards[0]
        round_penalty = 15 - 12 * min_oppo_cards / 20

        big_ranks = ("2", "*", "$")
        best_move = None
        max_value = -np.inf
        for i, comb in enumerate(combs):
            total_value = sum(cards_value[x] for x in comb)
            # Count groups (after the first position) led by a small rank.
            # The original looked up ``action_space[j]`` with the loop
            # POSITION and compared against the engine's "R"/"B" joker names,
            # which never occur in ``action_space``.
            # NOTE(review): position 0 is skipped as before -- confirm intent.
            small_num = sum(
                1 for pos, action in enumerate(comb)
                if pos > 0 and action > 0 and action_space[action][0] not in big_ranks
            )
            total_value -= small_num * round_penalty

            # Exactly one real group left (optionally preceded by "pass").
            single_hand = len(comb) == 1 or (len(comb) == 2 and comb[0] == 0)
            for j, action in enumerate(comb):
                if action == 0 and min_oppo_cards > 4:
                    # Passing keeps the whole decomposition intact.
                    if total_value > max_value:
                        max_value = total_value
                        best_move = 0
                elif action > 0 and (fine_mask is None or fine_mask[i, j]):
                    if single_hand:
                        # Going out in one play always wins.
                        max_value = np.inf
                        best_move = comb[-1]
                    move_value = total_value - cards_value[action] + round_penalty
                    if move_value > max_value:
                        max_value = move_value
                        best_move = action
        if best_move is None:
            best_move = 0

        best_cards = action_space[best_move]
        move = [best_cards.count(x) for x in card_list]
        # Log: [hand] // [move]
        print("RuleBasedModel", env.arr2cards(hand_card), end=' // ')
        print(env.arr2cards(move))
        return self.arr2onehot(move)
def after_move_cards(cards, player, action):
    """Return all hands as they stand after ``player`` plays ``action``.

    Args:
        cards: one per-rank count list per player.
        player: index of the player who moves.
        action: per-rank counts of the cards played.

    Returns:
        A new outer list.  Only ``player``'s hand is a new object (a list of
        plain ``int``); the other hands are the same objects as in ``cards``.
        ``cards`` itself is not modified.

    Example:
        >>> after_move_cards([[1, 2, 3], [4, 5, 6], [7, 8, 9]], 0, [1, 0, 0])
        [[0, 2, 3], [4, 5, 6], [7, 8, 9]]
    """
    # Shallow copy is sufficient: the only slot that changes is rebound to a
    # freshly built list, never mutated in place.
    remaining = list(cards)
    # tolist() yields builtin ints rather than NumPy scalars.
    remaining[player] = (np.asarray(cards[player]) - np.asarray(action)).tolist()
    return remaining
last_valid_action_pid="CHANCE"): 80 | """ 81 | actions传入的是发牌之后所有有可能的手牌状态 82 | 格式[ [第一种可能的每人手牌状态] , [第二种可能的每人手牌状态] , …… , [最后一种可能的每人手牌状态]] 83 | 其中每一种可能的每人手牌状态的格式是: [ [位置1的初始手牌] , [位置2的初始手牌] , [位置3的初始手牌] ] 84 | 其中[位置1的初始手牌]的格式是一个 15维list 85 | first_to_move传入的是chance node之后第一个行动的玩家是谁 86 | last_move为chance node之前上一个有效出牌 默认为没有(即下一个玩家first_to_move自由出牌) 87 | 88 | """ 89 | # print("chancenode actions:",actions) 90 | super().__init__(parent=None, to_move="CHANCE", actions=actions) # to_move 为特殊名“CHANCE” actions为所有发牌的可能 91 | 92 | ################### 构造孩子结点部分 ###################################### 93 | self.children = { # 孩子结点是一个字典,key是可能的每人手牌情况,value是对应的孩子结点(PlayerMoveGameState类的实例) 94 | hash_actions(cards): PlayerMoveGameState( 95 | # 参数分别是 parent , to_move, actions_history , initial_cards , cards , actions , last_valid_action, last_valid_action_pid 96 | self, first_to_move, [], cards, cards, get_moves_new(cards[first_to_move], last_move), last_move, 97 | last_valid_action_pid 98 | ) for cards in self.actions 99 | } 100 | ###################################################################################################### 101 | 102 | self.information_set = actions # 初始信息集(CHANCE NODE的信息集并没有实际意义) 103 | 104 | self.is_terminal = False 105 | 106 | self.chance_prob = 1. 
class PlayerMoveGameState(GameStateBase):
    """Decision node: one player is to move.  Builds its whole subtree on construction."""

    def __init__(self, parent, to_move, actions_history, initial_cards, cards, actions, last_valid_action,
                 last_valid_action_pid):
        """
        parent                 the preceding state
        to_move                seat of the player acting at this node
        actions_history        every action taken so far, passes included
        initial_cards          the hands as originally dealt
        cards                  the hands at this node
        actions                all actions available to ``to_move`` here
        last_valid_action      most recent play that was not a pass
        last_valid_action_pid  seat that made ``last_valid_action``
        """
        super().__init__(parent=parent, to_move=to_move, actions=actions)

        self.actions_history = actions_history
        self.initial_cards = initial_cards
        self.cards = cards
        self.last_valid_action_pid = last_valid_action_pid
        self.last_valid_action = last_valid_action

        # Terminal as soon as any hand is empty.  Utility is from seat 0's
        # (the landlord's) point of view: +1 if seat 0 is out, -1 otherwise.
        self.is_terminal = False
        for seat, hand in enumerate(self.cards):
            if np.array(hand).sum() == 0:
                self.is_terminal = True
                self.utility = 1 if seat == 0 else -1

        successor = (self.to_move + 1) % 3

        def follow_up(action):
            """Return (actions, last_valid_action, last_valid_action_pid) for the child reached by ``action``."""
            if action != [0] * 15:
                # A real play: the next player has to answer it.
                return (get_moves_new(self.cards[successor], action),
                        action,
                        self.to_move)
            if successor == self.last_valid_action_pid:
                # Pass, and the last real play was the next player's own:
                # the next player leads freely.
                return (get_moves_new(self.cards[successor], [0] * 15),
                        [0] * 15,
                        self.last_valid_action_pid)
            # Pass, and the last real play still stands.
            return (get_moves_new(self.cards[successor], self.last_valid_action),
                    self.last_valid_action,
                    self.last_valid_action_pid)

        # Children, keyed by the hashed action that leads to them.
        if not self.is_terminal:
            self.children = {
                hash_card(a): PlayerMoveGameState(
                    self,
                    successor,
                    self.actions_history + [a],
                    initial_cards,
                    after_move_cards(self.cards, to_move, a),
                    *follow_up(a)
                ) for a in self.actions
            }
        else:
            self.children = {}

        # Information set of the acting player:
        # [that player's initial hand, action 1, action 2, ...]
        seat = self.to_move if self.to_move in (0, 1) else 2
        self.information_set = [self.initial_cards[seat]] + list(self.actions_history)
class CounterfactualRegretMinimizationBase:
    """Common state and recursion for the CFR variants on the three-player tree.

    ``sigma`` and ``cumulative_regrets`` are dicts keyed by the hashed
    information set; each value is a dict keyed by the hashed action
    (probability of the action, and accumulated regret, respectively).
    """

    def __init__(self, root, chance_sampling=False):
        self.root = root
        self.sigma = init_sigma(root)  # starts uniform over each node's actions
        self.cumulative_regrets = init_empty_node_maps(root)  # starts at 0.
        self.chance_sampling = chance_sampling

    def _update_sigma(self, information_set):
        """Regret matching: make sigma proportional to positive cumulative regret.

        If no action has positive regret the strategy becomes uniform.
        """
        key = hash_actions(information_set)
        regrets = self.cumulative_regrets[key]
        positive_total = sum(r for r in regrets.values() if r > 0)
        for act, regret in regrets.items():
            if positive_total > 0:
                self.sigma[key][act] = max(regret, 0.) / positive_total
            else:
                self.sigma[key][act] = 1. / len(regrets)

    def _cumulate_cfr_regret(self, information_set, action, regret):
        """Add ``regret`` to the running total for ``action`` at ``information_set``."""
        self.cumulative_regrets[hash_actions(information_set)][hash_card(action)] += regret

    def _cfr_utility_recursive(self, state, reach_a, reach_b, reach_c):
        """Return the expected utility of ``state`` under ``sigma``, accumulating regrets on the way.

        ``reach_a`` / ``reach_b`` / ``reach_c`` are the products of the
        probabilities with which players 0 / 1 / 2 chose their own actions on
        the path to ``state``.
        """
        if state.is_terminal:
            return state.utility

        if state.to_move == "CHANCE":
            if self.chance_sampling:
                # Sample a single deal and continue as usual.
                return self._cfr_utility_recursive(state.sample_one(), reach_a, reach_b, reach_c)
            # Average over every deal.  Iterating the children dict (not a
            # set of nodes) keeps the summation order reproducible.
            return state.chance_prob * sum(
                self._cfr_utility_recursive(outcome, reach_a, reach_b, reach_c)
                for outcome in state.children.values())

        # The strategy dict for this information set is the same object for
        # every action, so it is looked up once.
        sigma_info = self.sigma[hash_actions(state.information_set)]
        value = 0.
        children_states_utilities = {}
        for action in state.actions:
            act = hash_card(action)
            prob = sigma_info[act]
            # Only the acting player's reach is scaled by the choice.
            child_reach_a = reach_a * (prob if state.to_move == 0 else 1)
            child_reach_b = reach_b * (prob if state.to_move == 1 else 1)
            child_reach_c = reach_c * (prob if state.to_move == 2 else 1)

            child_state_utility = self._cfr_utility_recursive(
                state.children[act], child_reach_a, child_reach_b, child_reach_c)
            value += prob * child_state_utility
            children_states_utilities[act] = child_state_utility

        # Counterfactual reach: product of the OTHER players' reaches.
        if state.to_move == 0:
            cfr_reach = reach_b * reach_c
        elif state.to_move == 1:
            cfr_reach = reach_a * reach_c
        else:
            cfr_reach = reach_a * reach_b

        # Utilities are stated for player 0, so the other two players see
        # them negated.  (The original authors flag this sign handling as
        # uncertain, since it is borrowed from the two-player zero-sum case.)
        sign = 1 if state.to_move == 0 else -1
        for action in state.actions:
            act = hash_card(action)
            action_cfr_regret = sign * cfr_reach * (children_states_utilities[act] - value)
            self._cumulate_cfr_regret(state.information_set, action, action_cfr_regret)

        if self.chance_sampling:
            # update sigma according to cumulative regrets - we can do it here because we are using chance sampling
            # and so we only visit single game_state from an information set (chance is sampled once)
            self._update_sigma(state.information_set)
        return value
def deal(card_n, remainder, cards=None):
    """Enumerate every way to deal ``remainder`` to three players.

    :param card_n: list of three ints, the hand size of each player.
    :param remainder: 15-slot count vector of the cards still to be dealt.
    :param cards: hands dealt so far (three 15-slot count vectors); a fresh
        all-zero table is created when omitted.
    :return: list of deals, each a list of three 15-slot count vectors.
    :raises RuntimeError: if the cards to deal do not match the hand sizes.
    """
    if cards is None:
        # Built per call with three independent rows: a shared default would
        # be handed straight back to the caller by the base case below.
        cards = [[0] * 15 for _ in range(3)]

    if sum(remainder) != sum(card_n):
        raise RuntimeError('deal手牌设置出错')

    if sum(remainder) == 0:
        return [cards]

    output = []
    for i in range(15):
        total = remainder[i]
        if total <= 0:
            continue
        # Split all copies of the first available rank between the players,
        # then recurse on the remaining ranks.
        rest = remainder[:]
        rest[i] = 0
        for num1 in range(min(total, card_n[0]) + 1):
            for num2 in range(min(total - num1, card_n[1]) + 1):
                num3 = total - num1 - num2
                if num3 > card_n[2]:
                    continue
                n = card_n[:]
                n[0] -= num1
                n[1] -= num2
                n[2] -= num3
                c = [hand[:] for hand in cards]
                c[0][i] += num1
                c[1][i] += num2
                c[2][i] += num3
                output.extend(deal(n, rest, c))
        # Only the first available rank is handled at this level.
        break
    return output


def initiate_game(person, card, first_to_move, last_move=None, last_valid_action_pid="CHANCE"):
    """Build the endgame tree for the given situation and return the trained strategy.

    :param person: hand sizes of the three players.
    :param card: 15-slot count vector of cards not yet played.
    :param last_move: 15-slot count vector of the play to beat; all zeros
        when omitted.
    :return: the ``sigma`` mapping of the solver after 8 CFR iterations.
    """
    if last_move is None:
        last_move = [0] * 15
    cards_dealings = deal(person, card)
    game_root = ChanceGameState(cards_dealings, first_to_move, last_move, last_valid_action_pid)
    solver = VanillaCFR(game_root)
    solver.run(8)
    return solver.sigma


def choose(information_set, sigma):
    """Sample an action key for ``information_set`` according to ``sigma``.

    ``information_set`` is a hashed key such as
    ``'010000000000000 100000000000000 '``.
    """
    strategy = sigma[information_set]
    probability = np.array(list(strategy.values()))
    return np.random.choice(list(strategy.keys()), p=probability.ravel())


def card_change(yq_card):
    """Convert a list of card values (3..17) into a 15-slot count vector."""
    wj_card = [0] * 15
    for value in yq_card:
        wj_card[value - 3] += 1
    return wj_card
def yq_outcard(string):
    """Expand a 15-digit count string into a list of card values.

    Digit ``i`` is the number of cards of value ``i + 3``, e.g.
    ``'200000000000000'`` -> ``[3, 3]`` and ``'110000000000000'`` -> ``[3, 4]``.
    """
    output = []
    for i in range(15):
        output.extend([i + 3] * int(string[i]))
    return output


def final_card(payload):
    """Pick the cards to play for an endgame ``payload`` using CFR.

    Reads ``role_id``, ``last_taken``, ``left``, ``history`` and ``cur_cards``
    from ``payload`` and returns the chosen play as a list of card values.
    """
    role_id = payload['role_id']
    # The solver's player index is the role id shifted by one.
    first_to_move = (role_id - 1) % 3
    last_taken = payload['last_taken']

    previous = last_taken[(role_id - 1) % 3]
    if previous:
        # The previous player played: that is the play to beat.
        last_move = card_change(previous)
        last_valid_action_pid = (first_to_move - 1) % 3
    else:
        before_previous = last_taken[(role_id - 2) % 3]
        if before_previous:
            # The previous player passed; the one before them played.
            last_move = card_change(before_previous)
            last_valid_action_pid = (first_to_move - 2) % 3
        else:
            # Both passed: free lead.
            last_move = [0] * 15
            last_valid_action_pid = first_to_move

    person = [payload['left'][1], payload['left'][2], payload['left'][0]]
    history = payload['history']
    # Full deck (four of each of 13 ranks, plus two jokers) minus played cards.
    card = list(np.array([4] * 13 + [1, 1]) - np.array(
        card_change(history[0] + history[1] + history[2])))
    sigma = initiate_game(person, card, first_to_move, last_move, last_valid_action_pid)
    information_set = hash_card(card_change(payload['cur_cards'])) + " "
    return yq_outcard(choose(information_set, sigma))
def get_res(payload):
    """Normalize a JSON payload, ask the AI for a move and return its response.

    JSON object keys arrive as strings, so the per-seat dicts are re-keyed
    from '0'/'1'/'2' to 0/1/2 in place. Unless the request set ``debug``, the
    detailed message is replaced by 'success' after being logged.
    """
    for field in ('history', 'last_taken', 'left', 'hand_card'):
        if field not in payload:
            continue
        per_seat = payload[field]
        for seat in range(3):
            per_seat[seat] = per_seat.pop(str(seat))
    debug = payload.pop('debug', False)
    res = ai.act(payload)
    app.logger.debug(res['msg'])
    if debug is False:
        res['msg'] = 'success'
    return res


@app.route('/', methods=['GET', 'POST'])
def home():
    """POST: return the AI's move for the JSON payload. GET: health check."""
    if request.method != 'POST':
        return 'It works'
    return get_res(request.get_json())


@app.route('/record', methods=['GET', 'POST'])
def record():
    """POST: store one game result row. GET: health check."""
    if request.method != 'POST':
        return 'It works, record'
    payload = request.get_json()
    seats = [str(i) for i in range(3)]
    values = [payload['is_human'][s] for s in seats] + [payload['record'][s] for s in seats]
    columns = ['human_up', 'human_lord', 'human_down',
               'win_up', 'win_lord', 'win_down']
    row = Record(**dict(zip(columns, values)))
    db.session.add(row)
    db.session.commit()
    return {}


if __name__ == '__main__':
    app.run(host='0.0.0.0', port=5000, debug=True)
0代表地主上家,1代表地主,2代表地主下家 28 | 'last_taken': { # 更改处 29 | 0: [], 30 | 1: [15, 15], 31 | 2: [], 32 | }, 33 | 'cur_cards': [16, 14, 13, 12, 12, 11, 11, 10, 9, 7, 6, 6, 4, 4, 4, 4], # 无需保持顺序 34 | 'history': { # 各家走过的牌的历史The environment 35 | 0: [8, 8], 36 | 1: [5, 5, 15, 15], 37 | 2: [7, 7], 38 | }, 39 | 'left': { # 各家剩余的牌 40 | 0: 15, 41 | 1: 16, 42 | 2: 15, 43 | }, 44 | 'debug': True, # 是否返回debug 45 | } 46 | 47 | # res = requests.post(server_url, json=payload1) 48 | # print(json.loads(res.content)) 49 | # 50 | # res = requests.post(server_url, json=payload2) 51 | # print(json.loads(res.content)) 52 | record = { 53 | 'is_human': { 54 | '0': 1, 55 | '1': 0, 56 | '2': 0, 57 | }, 58 | 'record': { 59 | '0': 1, 60 | '1': 0, 61 | '2': 0, 62 | } 63 | } 64 | 65 | payload3 = { 66 | 'role_id': 1, # 0代表地主上家,1代表地主,2代表地主下家 67 | 'last_taken': { # 更改处 68 | 0: [], 69 | 1: [9, 9, 9, 6], 70 | 2: [], 71 | }, 72 | 'cur_cards': [17, 16, 15, 14, 14, 12, 10], # 无需保持顺序 73 | 'history': { # 各家走过的牌的历史The environment 74 | 0: [], 75 | 1: [5, 5, 5, 4, 4, 3, 3, 3, 3, 9, 9, 9, 6], 76 | 2: [11, 11, 11, 8, 8], 77 | }, 78 | 'left': { # 各家剩余的牌 79 | 0: 17, 80 | 1: 7, 81 | 2: 12, 82 | }, 83 | 'hand_card': { 84 | 0: [15, 14, 13, 13, 12, 10, 10, 9, 8, 8, 7, 7, 7, 6, 6, 6, 4], 85 | 1: [17, 16, 15, 14, 14, 12, 10], 86 | 2: [15, 15, 14, 13, 13, 12, 12, 11, 10, 7, 5, 4], 87 | }, 88 | 'debug': False, # 是否返回debug 89 | } 90 | import time 91 | 92 | start = time.time() 93 | for i in range(10): 94 | print(i) 95 | res = requests.post('http://40.115.138.207:5000/', json=payload3) 96 | end = time.time() 97 | print(end - start) 98 | -------------------------------------------------------------------------------- /server/config.py: -------------------------------------------------------------------------------- 1 | from dqn import DQNFirst 2 | from net import NetCooperationSimplify 3 | 4 | # db_url = 'sqlite:///tmp.db' 5 | db_url = 'mysql+pymysql://lyq:lyqhhh@localhost/ddz' 6 | 7 | net_dict = { 8 | 'lord': 
class Predictor:
    """Chooses a move for a seat using a rule model, CFR or a DQN by game phase."""

    def __init__(self):
        self.mock_env = Env(seed=0)
        self.lord = self.up = self.down = None
        for role in ['lord', 'up', 'down']:
            if conf.net_dict[role]:
                ai = conf.dqn_dict[role](conf.net_dict[role])
                ai.policy_net.load(conf.model_dict[role])
                setattr(self, role, ai)
        self.id2ai = {0: self.up, 1: self.lord, 2: self.down}
        self.id2name = {0: '地主上', 1: '地主', 2: '地主下'}

    @staticmethod
    def _last_played(role_id, last_taken):
        """Return the previous player's play, or the one before it if they passed."""
        last = last_taken[(role_id - 1) % 3]
        if not last:
            last = last_taken[(role_id - 2) % 3]
        return last

    def get_prob(self, role_id, cur_cards, history, left):
        """Return the environment's card-probability estimate for the two opponents.

        Known cards are everything in ``history`` plus ``cur_cards``; the hand
        sizes of the next two seats come from ``left``.
        """
        size1, size2 = left[(role_id + 1) % 3], left[(role_id + 2) % 3]
        # Builtin int replaces the np.int alias, which NumPy removed in 1.24.
        taken = np.hstack(list(history.values())).astype(int)
        cards = np.array(cur_cards, dtype=int)
        known = self.mock_env.cards2arr(np.hstack([taken, cards]))
        known = self.mock_env.batch_arr2onehot([known]).flatten()
        return self.mock_env.get_state_prob_manual(known, size1, size2)

    def parse_history(self, role_id, history, last_taken):
        """Return ``cards2arr`` of [all played, h0, h1, h2, b1, b2].

        h0/h1/h2 are the histories of the previous seat, this seat and the
        next seat; b1/b2 are the last plays of the previous seat and the seat
        before that.
        """
        h0 = history[(role_id - 1) % 3]
        h1 = history[role_id % 3]
        h2 = history[(role_id + 1) % 3]
        b1 = last_taken[(role_id - 1) % 3]
        b2 = last_taken[(role_id - 2) % 3]
        taken = h0 + h1 + h2
        return list(map(self.mock_env.cards2arr, [taken, h0, h1, h2, b1, b2]))

    def face(self, role_id, cur_cards, history, left, last_taken, **kwargs):
        """
        :return: current state as a float tensor; per the original docstring
            this is a 6 * 15 * 4 array (4 known planes + 2 probability planes)
        """
        handcards = self.mock_env.cards2arr(cur_cards)
        taken, h0, h1, h2, b1, b2 = self.parse_history(role_id, history, last_taken)
        known = self.mock_env.batch_arr2onehot([handcards, taken, b1, b2])
        prob = self.get_prob(role_id, cur_cards, history, left).reshape(2, 15, 4)
        state = np.concatenate((known, prob))
        return torch.tensor(state, dtype=torch.float).to(DEVICE)

    def valid_actions(self, role_id, cur_cards, last_taken, **kwargs):
        """
        :return: (play to beat, tensor of legal actions); per the original
            docstring the tensor is batch_size * 15 * 4
        """
        last_back = self._last_played(role_id, last_taken)
        cur_cards, last = list(map(self.mock_env.cards2arr, [cur_cards, last_back]))
        actions = r.get_moves(cur_cards, last)
        return last_back, torch.tensor(self.mock_env.batch_arr2onehot(actions),
                                       dtype=torch.float).to(DEVICE)

    def choose(self, role_id, state, actions):
        """Return the greedy action of the seat's DQN as a list of int card values."""
        action = self.id2ai[role_id].greedy_action(state, actions)
        action = self.mock_env.onehot2arr(action)
        return [int(i) for i in self.mock_env.arr2cards(action)]

    def act(self, payload, **kwargs):  # TODO decide which model to use
        """Choose a move for ``payload`` and return ``{'msg', 'status', 'data'}``.

        Model selection: the rule-based model while the player holds 12 or
        more cards, CFR once at most 6 cards remain in total, the DQN
        otherwise.
        """
        if not payload['cur_cards']:
            return {'msg': '无手牌', 'status': False, 'data': []}
        start_time = time.time()
        total_left = sum(payload['left'].values())
        self_left = len(payload['cur_cards'])
        if self_left >= 12:
            name = 'Rule'
            # Module-level rule-based ``choose``, not the method of this class.
            action = choose(payload)
            action = [int(i) for i in self.mock_env.arr2cards(action)]
            last = self._last_played(payload['role_id'], payload['last_taken'])
        elif total_left <= 6:
            name = 'CFR'
            action = final_card(payload)
            last = self._last_played(payload['role_id'], payload['last_taken'])
        else:
            name = 'RL1'
            state = self.face(**payload)
            last, actions = self.valid_actions(**payload)
            action = self.choose(payload['role_id'], state, actions)
        end_time = time.time()
        msg = (('\n\t【{0}】使用模型{1},响应耗时{2:.2f}ms\n'
                '\t【{0}】桌上的牌:{3}\n'
                '\t【{0}】上家出牌:{4}\n'
                '\t【{0}】当前手牌:{5}\n'
                '\t【{0}】本次出牌:{6}')
               .format(self.id2name[payload['role_id']], name,
                       1000 * (end_time - start_time), payload['history'],
                       last, payload['cur_cards'], action))
        res = {'msg': msg, 'status': True, 'data': action}
        return res
def backup(node, reward):
    """Propagate ``reward`` from ``node`` up to the root of the search tree.

    Every node on the parent chain gets its ``visit`` count incremented and
    ``reward`` added to its accumulated reward.
    """
    while node is not None:
        node.visit += 1
        node.reward += reward
        node = node.parent


# Category = Enum('Category', 'EMPTY SINGLE DOUBLE TRIPLE QUADRIC THREE_ONE THREE_TWO SINGLE_LINE DOUBLE_LINE \
#                              TRIPLE_LINE THREE_ONE_LINE THREE_TWO_LINE BIGBANG FOUR_TAKE_ONE FOUR_TAKE_TWO', start=0)


class Category:
    """Integer ids of the hand categories.

    get_action_space() appends one Category2Range entry per category in this
    same order, so a category id indexes Category2Range.
    """
    EMPTY = 0
    SINGLE = 1
    DOUBLE = 2
    TRIPLE = 3
    QUADRIC = 4
    THREE_ONE = 5
    THREE_TWO = 6
    SINGLE_LINE = 7
    DOUBLE_LINE = 8
    TRIPLE_LINE = 9
    THREE_ONE_LINE = 10
    THREE_TWO_LINE = 11
    BIGBANG = 12
    FOUR_TAKE_ONE = 13
    FOUR_TAKE_TWO = 14


# Filled by get_action_space(): entry k is the [start, end) index range of
# the k-th category inside the returned action list.
Category2Range = []


def get_action_space():
    """Enumerate every playable card combination, grouped by category.

    Returns a list of actions, each a list of card characters from
    ``Card.cards``; index 0 is the empty action (pass). As a side effect one
    ``[start, end)`` range per category is appended to the module-level
    ``Category2Range``, so calling this function more than once appends the
    ranges again.

    The order of the sections below fixes the index of every action; do not
    reorder them.
    """
    actions = [[]]
    # The empty action occupies index 0.
    Category2Range.append([0, 1])
    # single
    temp = len(actions)
    for card in Card.cards:  # 15
        actions.append([card])
    Category2Range.append([temp, len(actions)])
    temp = len(actions)
    # pair (jokers cannot form pairs)
    for card in Card.cards:  # 13
        if card != '*' and card != '$':
            actions.append([card] * 2)
    Category2Range.append([temp, len(actions)])
    temp = len(actions)
    # triple
    for card in Card.cards:  # 13
        if card != '*' and card != '$':
            actions.append([card] * 3)
    Category2Range.append([temp, len(actions)])
    temp = len(actions)
    # bomb
    for card in Card.cards:  # 13
        if card != '*' and card != '$':
            actions.append([card] * 4)
    Category2Range.append([temp, len(actions)])
    temp = len(actions)
    # 3 + 1 (the kicker may be a joker)
    for main in Card.cards:
        if main != '*' and main != '$':
            for extra in Card.cards:
                if extra != main:
                    actions.append([main] * 3 + [extra])
    Category2Range.append([temp, len(actions)])
    temp = len(actions)
    # 3 + 2 (the kicker pair cannot be jokers)
    for main in Card.cards:
        if main != '*' and main != '$':
            for extra in Card.cards:
                if extra != main and extra != '*' and extra != '$':
                    actions.append([main] * 3 + [extra] * 2)
    Category2Range.append([temp, len(actions)])
    temp = len(actions)
    # single sequence: at least 5 consecutive ranks, none of them '2' or a joker
    for start_v in range(Card.to_value('3'), Card.to_value('2')):
        for end_v in range(start_v + 5, Card.to_value('*')):
            seq = range(start_v, end_v)
            actions.append(sorted(Card.to_cards(seq), key=lambda c: Card.cards.index(c)))
    Category2Range.append([temp, len(actions)])
    temp = len(actions)
    # double sequence: at least 3 consecutive pairs; the literal 20 in the
    # bounds below is presumably the maximum hand size (an old commented-out
    # "max_cards = 20" suggests so)
    for start_v in range(Card.to_value('3'), Card.to_value('2')):
        for end_v in range(start_v + 3, int(min(start_v + 20 / 2 + 1, Card.to_value('*')))):
            seq = range(start_v, end_v)
            actions.append(sorted(Card.to_cards(seq) * 2, key=lambda c: Card.cards.index(c)))
    Category2Range.append([temp, len(actions)])
    temp = len(actions)
    # triple sequence: at least 2 consecutive triples
    for start_v in range(Card.to_value('3'), Card.to_value('2')):
        for end_v in range(start_v + 2, int(min(start_v + 20 // 3 + 1, Card.to_value('*')))):
            seq = range(start_v, end_v)
            actions.append(sorted(Card.to_cards(seq) * 3, key=lambda c: Card.cards.index(c)))
    Category2Range.append([temp, len(actions)])
    temp = len(actions)
    # 3 + 1 sequence: one distinct single kicker per triple
    for start_v in range(Card.to_value('3'), Card.to_value('2')):
        for end_v in range(start_v + 2, int(min(start_v + 20 / 4 + 1, Card.to_value('*')))):
            seq = range(start_v, end_v)
            main = Card.to_cards(seq)
            remains = [card for card in Card.cards if card not in main]
            for extra in list(itertools.combinations(remains, end_v - start_v)):
                # Skip the case where the two kickers are exactly both jokers.
                if not ('*' in list(extra) and '$' in list(extra) and len(extra) == 2):
                    actions.append(sorted(main * 3, key=lambda c: Card.cards.index(c)) + list(extra))
    Category2Range.append([temp, len(actions)])
    temp = len(actions)
    # 3 + 2 sequence: one distinct kicker rank per triple, jokers excluded
    for start_v in range(Card.to_value('3'), Card.to_value('2')):
        for end_v in range(start_v + 2, int(min(start_v + 20 / 5 + 1, Card.to_value('*')))):
            seq = range(start_v, end_v)
            main = Card.to_cards(seq)
            remains = [card for card in Card.cards if card not in main and card not in ['*', '$']]
            for extra in list(itertools.combinations(remains, end_v - start_v)):
                # NOTE(review): list(extra) * 2 yields [a, b, a, b], not
                # [a, a, b, b]; confirm consumers do not rely on grouping.
                actions.append(sorted(main * 3, key=lambda c: Card.cards.index(c)) + list(extra) * 2)
    Category2Range.append([temp, len(actions)])
    temp = len(actions)
    # bigbang (both jokers)
    actions.append(['*', '$'])
    Category2Range.append([temp, len(actions)])
    temp = len(actions)
    # 4 + 1 + 1: two distinct single kickers, but not both jokers
    for main in Card.cards:
        if main != '*' and main != '$':
            remains = [card for card in Card.cards if card != main]
            for extra in list(itertools.combinations(remains, 2)):
                if not ('*' in list(extra) and '$' in list(extra) and len(extra) == 2):
                    actions.append([main] * 4 + list(extra))
    Category2Range.append([temp, len(actions)])
    temp = len(actions)
    # 4 + 2 + 2: two distinct kicker ranks, jokers excluded
    for main in Card.cards:
        if main != '*' and main != '$':
            remains = [card for card in Card.cards if card != main and card != '*' and card != '$']
            for extra in list(itertools.combinations(remains, 2)):
                actions.append([main] * 4 + list(extra) * 2)
    Category2Range.append([temp, len(actions)])
    temp = len(actions)
    return actions
class Card:
    """Conversions between card representations.

    Representations used below:
    * char  - entries of ``Card.cards`` ('3' .. '2', '*' and '$' for the jokers)
    * value - index into ``Card.cards`` (0..14); "3_17" values are index + 3
    * onehot - length-54 vector, four slots per ordinary rank plus one slot
      per joker (52 and 53); "onehot60" uses four slots for all 15 ranks
    * color - index of one concrete slot, as produced by ``char2color``
    """
    cards = ['3', '4', '5', '6', '7', '8', '9', '10', 'J', 'Q', 'K', 'A', '2', '*', '$']
    np_cards = np.array(cards)
    cards_to_onehot_idx = {x: i * 4 for i, x in enumerate(cards)}
    cards_to_onehot_idx['*'] = 52
    cards_to_onehot_idx['$'] = 53
    cards_to_value = dict(zip(cards, range(len(cards))))
    value_to_cards = {v: c for c, v in cards_to_value.items()}

    def __init__(self):
        pass

    @staticmethod
    def char2onehot(cards):
        """Return the length-54 one-hot vector of a list of card chars."""
        onehot = np.zeros(54)
        # One pass per distinct rank: the slice depends only on the count.
        for x, count in Counter(cards).items():
            idx = Card.cards_to_onehot_idx[x]
            if x in ['*', '$']:
                onehot[idx] = 1
            else:
                subvec = np.zeros(4)
                subvec[:count] = 1
                onehot[idx:idx + 4] = subvec
        return onehot

    @staticmethod
    def char2onehot60(cards):
        """Return the length-60 int32 one-hot vector of a list of card chars."""
        onehot = np.zeros(60, dtype=np.int32)
        for x, count in Counter(cards).items():
            idx = Card.cards.index(x) * 4
            subvec = np.zeros(4)
            subvec[:count] = 1
            onehot[idx:idx + 4] = subvec
        return onehot

    @staticmethod
    def val2onehot(cards):
        """Return the length-54 one-hot vector of a list of 3..17 values."""
        chars = [Card.cards[i - 3] for i in cards]
        return Card.char2onehot(chars)

    @staticmethod
    def val2onehot60(cards):
        """Return the length-60 one-hot vector of a list of 3..17 values."""
        onehot = np.zeros(60)
        for x, count in Counter(cards).items():
            idx = (x - 3) * 4
            subvec = np.zeros(4)
            subvec[:count] = 1
            onehot[idx:idx + 4] = subvec
        return onehot

    # convert char to 0-56 color cards
    @staticmethod
    def char2color(cards):
        """Assign each card char the first free slot of its rank (rank index * 4 + k)."""
        result = np.zeros([len(cards)])
        mask = np.zeros([57])
        for i, card in enumerate(cards):
            ind = Card.cards.index(card) * 4
            # Advance past slots already taken by earlier copies of the rank.
            while mask[ind] == 1:
                ind += 1
            mask[ind] = 1
            result[i] = ind
        return result

    @staticmethod
    def onehot2color(cards):
        """Return the color indices of the set slots; slot 53 maps to color 56."""
        result = []
        for i, bit in enumerate(cards):
            if bit == 0:
                continue
            result.append(56 if i == 53 else i)
        return np.array(result)

    @staticmethod
    def onehot2char(cards):
        """Return the card chars of the set slots of a length-54 one-hot vector."""
        result = []
        for i, bit in enumerate(cards):
            if bit == 0:
                continue
            result.append(Card.cards[14] if i == 53 else Card.cards[i // 4])
        return result

    @staticmethod
    def onehot2val(cards):
        """Return the 3..17 values of the set slots of a length-54 one-hot vector."""
        result = []
        for i, bit in enumerate(cards):
            if bit == 0:
                continue
            result.append(17 if i == 53 else i // 4 + 3)
        return result

    @staticmethod
    def char2value_3_17(cards):
        """Map a card char, or a list/range of them (as an array), to 3..17 values."""
        if isinstance(cards, (list, range)):
            return np.array([Card.cards_to_value[c] + 3 for c in cards])
        return Card.cards_to_value[cards] + 3

    @staticmethod
    def to_value(card):
        """Return the 0..14 value of a card char, or the SUM of values for a list/range."""
        if isinstance(card, (list, range)):
            return sum(Card.cards_to_value[c] for c in card)
        return Card.cards_to_value[card]

    @staticmethod
    def to_cards(values):
        """Map a 0..14 value, or a list/range of them, to card char(s)."""
        if isinstance(values, (list, range)):
            return [Card.value_to_cards[v] for v in values]
        return Card.value_to_cards[values]

    @staticmethod
    def to_cards_from_3_17(values):
        """Map 3..17 values to card chars; ``values`` must support ``values - 3``
        (an int or a NumPy integer array)."""
        return Card.np_cards[values - 3].tolist()
class CardGroup:
    """A playable combination: its cards, category, rank value and run length."""

    def __init__(self, cards, t, val, len=1):
        self.type = t
        self.cards = cards
        self.value = val
        self.len = len

    def bigger_than(self, g):
        """Return whether this group may be played over group ``g``."""
        if self.type == Category.EMPTY:
            return g.type != Category.EMPTY
        if g.type == Category.EMPTY:
            return True
        if g.type == Category.BIGBANG:
            return False
        if self.type == Category.BIGBANG:
            return True
        if g.type == Category.QUADRIC:
            # Only a higher bomb beats a bomb (bigbang handled above).
            return bool(self.type == Category.QUADRIC and self.value > g.value)
        # A bomb beats anything else; otherwise same category, same length
        # and strictly higher value.
        return bool(self.type == Category.QUADRIC or
                    (self.type == g.type and self.len == g.len and self.value > g.value))

    @staticmethod
    def isvalid(cards):
        """Return whether ``cards`` can be played as exactly one group."""
        return CardGroup.folks(cards) == 1

    @staticmethod
    def to_cardgroup(cards):
        """Return the first candidate group that uses every card in ``cards``.

        :raises Exception: if no single group covers all the cards.
        """
        candidates = CardGroup.analyze(cards)
        for c in candidates:
            if len(c.cards) == len(cards):
                return c
        print("cards error!")
        print(cards)
        raise Exception("Invalid Cards!")

    @staticmethod
    def folks(cards):
        """Estimate how many plays are needed to get rid of ``cards``.

        If any compound candidate exists (lines and triple/quad-with-kicker
        groups), the result is 1 + the best estimate for what remains after
        playing one of them; otherwise it is the number of candidates.
        """
        compound = (Category.TRIPLE_LINE, Category.THREE_ONE, Category.THREE_TWO,
                    Category.FOUR_TAKE_ONE, Category.FOUR_TAKE_TWO,
                    Category.THREE_ONE_LINE, Category.THREE_TWO_LINE,
                    Category.SINGLE_LINE, Category.DOUBLE_LINE)
        cand = CardGroup.analyze(cards)
        cnt = 10000
        spec = False
        for c in cand:
            if c.type in compound:
                spec = True
                remain = list(cards)
                for card in c.cards:
                    remain.remove(card)
                # Recurse once per candidate and reuse the result.
                plays = CardGroup.folks(remain) + 1
                if plays < cnt:
                    cnt = plays
        if not spec:
            cnt = len(cand)
        return cnt

    @staticmethod
    def _find_lines(ranks, copies, min_len, category):
        """Return groups for maximal runs of consecutive ranks below '2'.

        ``ranks`` must be sorted by value; each rank contributes ``copies``
        cards, and only runs of at least ``min_len`` ranks are kept. Scanning
        stops at the first rank (after the first entry) that is '2' or higher.
        """
        groups = []
        if not ranks:
            return groups

        def emit(run):
            if len(run) >= min_len:
                run_cards = [rank for rank in run for _ in range(copies)]
                groups.append(CardGroup(run_cards, category, Card.to_value(run[0]), len(run)))

        ceiling = Card.to_value('2')
        run = [ranks[0]]
        for rank in ranks[1:]:
            if Card.to_value(rank) >= ceiling:
                break
            if Card.to_value(rank) == Card.to_value(run[-1]) + 1:
                run.append(rank)
            else:
                emit(run)
                run = [rank]
        emit(run)
        return groups

    @staticmethod
    def analyze(cards):
        """Return candidate groups contained in ``cards``, sorted by importance.

        Candidates are ordered by category importance (singles first, bigbang
        last) and then by value.
        """
        cards = list(cards)
        if len(cards) == 0:
            return [CardGroup([], Category.EMPTY, 0)]
        candidates = []

        # TODO: this does not rule out Nuke kicker
        counts = Counter(cards)
        if '*' in cards and '$' in cards:
            candidates.append(CardGroup(['*', '$'], Category.BIGBANG, 100))

        # Bombs are recorded and their cards removed from further analysis.
        quadrics = []
        for c in counts:
            if counts[c] == 4:
                quadrics.append(c)
                candidates.append(CardGroup([c] * 4, Category.QUADRIC, Card.to_value(c)))
                cards = [a for a in cards if a != c]

        counts = Counter(cards)

        def by_value(k):
            return Card.cards_to_value[k]

        singles = sorted((c for c in counts if counts[c] == 1), key=by_value)
        doubles = sorted((c for c in counts if counts[c] == 2), key=by_value)
        triples = sorted((c for c in counts if counts[c] == 3), key=by_value)

        # Continuous sequences; insertion order matters because the final
        # sort is stable.
        candidates.extend(CardGroup._find_lines(singles, 1, 5, Category.SINGLE_LINE))
        candidates.extend(CardGroup._find_lines(doubles, 2, 3, Category.DOUBLE_LINE))
        candidates.extend(CardGroup._find_lines(triples, 3, 2, Category.TRIPLE_LINE))

        for t in triples:
            candidates.append(CardGroup([t] * 3, Category.TRIPLE, Card.to_value(t)))

        # From here on kickers are taken in first-seen order, not value order.
        counts = Counter(cards)
        singles = [c for c in counts if counts[c] == 1]
        doubles = [c for c in counts if counts[c] == 2]

        # single
        for s in singles:
            candidates.append(CardGroup([s], Category.SINGLE, Card.to_value(s)))

        # double
        for d in doubles:
            candidates.append(CardGroup([d] * 2, Category.DOUBLE, Card.to_value(d)))

        # 3 + 1, 3 + 2
        for c in triples:
            triple = [c] * 3
            for s in singles:
                if s not in triple:
                    candidates.append(CardGroup(triple + [s], Category.THREE_ONE,
                                                Card.to_value(c)))
            for d in doubles:
                if d not in triple:
                    candidates.append(CardGroup(triple + [d] * 2, Category.THREE_TWO,
                                                Card.to_value(c)))

        # 4 + 2
        for c in quadrics:
            for extra in itertools.combinations(singles, 2):
                candidates.append(CardGroup([c] * 4 + list(extra), Category.FOUR_TAKE_ONE,
                                            Card.to_value(c)))
            for extra in itertools.combinations(doubles, 2):
                candidates.append(CardGroup([c] * 4 + list(extra) * 2, Category.FOUR_TAKE_TWO,
                                            Card.to_value(c)))

        # 3 * n + n, 3 * n + 2 * n
        triple_seq = [c.cards for c in candidates if c.type == Category.TRIPLE_LINE]
        for cand in triple_seq:
            cnt = int(len(cand) / 3)
            for extra in itertools.combinations(singles, cnt):
                candidates.append(
                    CardGroup(cand + list(extra), Category.THREE_ONE_LINE,
                              Card.to_value(cand[0]), cnt))
            for extra in itertools.combinations(doubles, cnt):
                candidates.append(
                    CardGroup(cand + list(extra) * 2, Category.THREE_TWO_LINE,
                              Card.to_value(cand[0]), cnt))

        importance = [Category.EMPTY, Category.SINGLE, Category.DOUBLE, Category.DOUBLE_LINE,
                      Category.SINGLE_LINE, Category.THREE_ONE,
                      Category.THREE_TWO, Category.THREE_ONE_LINE, Category.THREE_TWO_LINE,
                      Category.TRIPLE_LINE, Category.TRIPLE, Category.FOUR_TAKE_ONE,
                      Category.FOUR_TAKE_TWO, Category.QUADRIC, Category.BIGBANG]
        rank_of = {category: i for i, category in enumerate(importance)}
        # Same order as the old comparator: importance first, then value;
        # ties keep insertion order because list.sort is stable.
        candidates.sort(key=lambda group: (rank_of[group.type], group.value))
        return candidates
def clamp_action_idx(idx):
    """Map an index into the augmented action space back into ``action_space``.

    Indices below ``len(action_space)`` are returned unchanged.  The next
    ``13 * 3`` indices (the three extra copies of the first 13 singles in
    ``augment_action_space``) fold onto ``1..13``; anything beyond that (the
    extra doubles) is shifted to start at 16 -- presumably the index of the
    first double in ``action_space``; confirm against ``get_action_space``.
    """
    offset = idx - len(action_space)
    if offset < 0:
        return idx
    if offset < 13 * 3:
        return offset % 13 + 1
    return offset - 13 * 3 + 16
def default_policy(node, my_id):
    """Roll out random play from ``node`` to the end of the game and score it.

    Args:
        node: tree node whose ``get_state()`` gives the starting state.
        my_id: player id passed to the final state's ``compute_reward``.

    Returns:
        Whatever ``compute_reward(my_id)`` returns for the terminal state.
    """
    state = node.get_state()
    # Keep playing random moves until the game is over (winner != -1).
    while state.winner == -1:
        state = state.get_next_state_with_random_choice(None)
    return state.compute_reward(my_id)
def _ucb_scores(node, c):
    """Return ``reward/visit + c * sqrt(2 * ln(node.visit) / visit)`` per child.

    Every child must already have ``visit > 0``; otherwise the division
    produces NaN/inf and no child can be selected.
    """
    visit = np.array([n.visit for n in node.children])
    reward = np.array([n.reward for n in node.children])
    return reward / visit + c * np.sqrt(2 * np.log(node.visit) / visit)


def _children_scoring(node, scores, target):
    """Return (as an array) every child of ``node`` whose score equals ``target``."""
    return np.array(node.children)[np.where(scores == target)]


def _pick_one(nodes):
    """Return the only candidate, or a uniformly random one when several tie."""
    if len(nodes) == 1:
        return nodes[0]
    return np.random.choice(nodes)


def UCB1(node, c=0.7):
    """Return the children of ``node`` with the highest UCB score."""
    scores = _ucb_scores(node, c)
    return _children_scoring(node, scores, np.max(scores))


def UCB2(node, c=0.7):
    """Return the children of ``node`` with the lowest UCB score."""
    scores = _ucb_scores(node, c)
    return _children_scoring(node, scores, np.min(scores))


def get_bestchild(node, my_id):
    """Select the child to descend into during the tree policy.

    When the player to move at ``node`` is ``my_id`` the highest-scoring child
    (``UCB1``) is taken, otherwise the lowest-scoring one (``UCB2``).  Ties
    are broken at random.
    """
    select = UCB1 if node.state.my_id == my_id else UCB2
    return _pick_one(select(node))


def get_bestchild_(node):
    """Return the child with the highest mean reward (no exploration term).

    Ties are broken at random.  Every child must have ``visit > 0``.
    """
    visit = np.array([n.visit for n in node.children])
    reward = np.array([n.reward for n in node.children])
    means = reward / visit
    return _pick_one(_children_scoring(node, means, np.max(means)))
# Return every move that can be played from the current hand against the
# previous player's cards.
# input:
#   handcards: dict of counts keyed '3'..'13', '1', '2', '14', '15'
#   lastcards: list of card numbers, e.g. a pair of threes = [3, 3]
# output:
#   list of dict (same keys as handcards)

# Count vectors (15 slots, same order as the keys above) that are dropped
# when pruning: a four-of-a-kind carrying both jokers ...
sidaihuojian = []
t = [0] * 13 + [1, 1]
for i in range(13):
    tt = t[:]
    tt[i] = 4
    sidaihuojian.append(tt)

# ... and two consecutive triples carrying both jokers.
sandaihuojian = []
for i in range(11):
    ttt = t[:]
    ttt[i] = 3
    ttt[i + 1] = 3
    sandaihuojian.append(ttt)


def get_moves(handcards, lastcards):
    """Return the candidate moves for ``handcards`` against ``lastcards``.

    Args:
        handcards: dict of card counts; its values are read in insertion
            order and must follow the key order '3'..'13', '1', '2', '14', '15'.
        lastcards: list of card numbers played by the previous player, or a
            falsy value when there is nothing to beat.

    Returns:
        A list of dicts (same keys as ``handcards``), one per move.  When the
        rule engine reports more than 10 moves, only the lowest- and
        highest-valued ones are kept (``int(n / 3 + 1)`` from each end,
        interleaved low/high), after dropping the joker-kicker patterns above.
    """
    if not lastcards:
        lastcards = []
    index = [str(i) for i in range(3, 14)] + ['1', '2', '14', '15']
    rhandcards = list(handcards.values())
    tem = dict(zip(index, [0] * 15))
    for card in lastcards:
        tem[str(card)] += 1
    rlastcards = list(tem.values())

    rmoves = r.get_moves(rhandcards, rlastcards)
    length = len(rmoves)
    if length <= 10:
        return [dict(zip(index, m)) for m in rmoves]

    # Prune: score each remaining move, then keep both ends of the ranking.
    handnum = sum(rhandcards)
    rrmoves = [m for m in rmoves
               if m not in sidaihuojian and m not in sandaihuojian]
    values = [cards_value[tuple(m)] - 0.1 * (handnum - sum(m)) for m in rrmoves]
    sorted_index = sorted(range(len(values)), key=values.__getitem__)

    moves = []
    low, high = 0, len(sorted_index) - 1
    for _ in range(int(length / 3 + 1)):
        # The count comes from the unfiltered list, so after filtering the two
        # ends can meet; stop there instead of emitting the same move twice
        # (or indexing past the end of the filtered list).
        if low > high:
            break
        moves.append(dict(zip(index, rrmoves[sorted_index[low]])))
        if high != low:
            moves.append(dict(zip(index, rrmoves[sorted_index[high]])))
        low += 1
        high -= 1
    return moves
# Payload ranks 14..17 (A, 2 and the two jokers) in the numbering used by the
# search: A -> 1, 2 -> 2, jokers -> 14, 15.
_PAYLOAD_TO_MCTS_RANK = {14: 1, 15: 2, 16: 14, 17: 15}


def _hand_to_dict(hand):
    """Convert a payload hand (list of ranks 3..17) into the count dict used by the search."""
    return card_list_to_dict(card_to_list(change_card_form_reversal(hand)))


def mcts(payload, computation_budget=1000):
    """Run Monte-Carlo tree search for the player described by ``payload``.

    Args:
        payload: dict with ``'role_id'`` (player index 0..2), ``'hand_card'``
            (one list of card ranks per player) and ``'last_taken'`` (one list
            per player holding that player's most recent cards).  The payload
            is not modified.
        computation_budget: number of select/rollout/backup iterations; must
            be large enough for the root to get at least one child.

    Returns:
        The ``action`` stored on the best child state of the root.
    """
    role_id = payload['role_id']
    # Player ids inside the tree are the payload ids shifted by two.
    my_id = (role_id + 2) % 3
    next_id = (role_id + 1) % 3
    next_next_id = (role_id + 2) % 3

    hands = payload['hand_card']
    my_card = _hand_to_dict(hands[role_id])
    next_card = _hand_to_dict(hands[next_id])
    next_next_card = _hand_to_dict(hands[next_next_id])

    last_move_, last_p_ = get_last_move(role_id, next_id, next_next_id, payload['last_taken'])
    last_move = change_card_form_reversal(last_move_)
    if last_p_ == role_id:
        # Neither opponent has an outstanding play, so we lead freely instead
        # of having to beat our own previous cards.  This matches the rule
        # State applies when a pass returns the turn to the last player, and
        # avoids a root whose only legal move is a pass it is not allowed to
        # choose.
        last_move = []
    last_p = (last_p_ + 2) % 3
    moves_num = len(get_moves(my_card, last_move))
    state = State(my_id, my_card, next_card, next_next_card, last_move, -1, moves_num, None, last_p)
    root = Node(None, None)
    root.set_state(state)

    for _ in range(computation_budget):
        expand_node = tree_policy(root, my_id)
        reward = default_policy(expand_node, my_id)
        backup(expand_node, reward)
    best_next_node = get_bestchild_(root)
    return best_next_node.get_state().action


def change_card_form_reversal(before):
    """Return a new list with payload ranks renumbered for the search.

    e.g. [3, 3, 3, 4, 4, 4, 14, 15, 16, 17] -> [3, 3, 3, 4, 4, 4, 1, 2, 14, 15]
    """
    return [_PAYLOAD_TO_MCTS_RANK.get(rank, rank) for rank in before]


def card_to_list(before):
    """Count cards into 15 slots ordered 3..13, 1, 2, 14, 15.

    e.g. [3, 3, 3, 4, 4, 4, 1, 2] -> [3, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0]
    """
    tem = [0] * 15
    for card in before:
        tem[card - 1] += 1
    # Counted in numeric order 1..15; move ranks 1 and 2 behind 3..13.
    return tem[2:-2] + tem[:2] + tem[-2:]


def card_list_to_dict(card_list):
    """Label the 15 count slots with the keys '3'..'13', '1', '2', '14', '15'.

    e.g. [3, 3, 0, ..., 1, 1, 0, 0] -> {'3': 3, '4': 3, ..., '1': 1, '2': 1, '14': 0, '15': 0}
    """
    card_name = [str(i) for i in range(3, 14)] + ['1', '2', '14', '15']
    return dict(zip(card_name, card_list))


def get_last_move(role_id, next_id, next_next_id, last_taken):
    """Return ``(cards, player_id)`` for the most recent non-empty play.

    The player two seats after us is checked first, then the player after us;
    if both are empty our own entry is returned (possibly empty).
    """
    for player in (next_next_id, next_id):
        if len(last_taken[player]) != 0:
            return last_taken[player], player
    return last_taken[role_id], role_id
class State(object):
    """Game state seen from the player who is about to move.

    ``my_card``, ``next_card`` and ``next_next_card`` are count dicts keyed
    '3'..'13', '1', '2', '14', '15'.  ``last_move`` is the list of cards to
    beat, ``last_pid`` the id of the player who played them, ``winner`` is -1
    while the game is running, ``moves_num`` the number of legal moves and
    ``action`` the move that led to this state.
    """

    # Card keys in the order used by every count dict.
    _CARD_KEYS = tuple([str(i) for i in range(3, 14)] + ['1', '2', '14', '15'])

    def __init__(self, my_id, my_card, next_card, next_next_card, last_move, winner, moves_num, action, last_p):
        self.my_id = my_id
        self.my_card = my_card
        self.next_card = next_card
        self.next_next_card = next_next_card
        self.last_move = last_move
        self.winner = winner
        self.moves_num = moves_num
        self.action = action
        self.last_pid = last_p
        self.untried_actions = []
        self.try_flag = 0

    def init_untried_actions(self, move):
        """Queue ``move`` as a not-yet-expanded action of this state."""
        self.untried_actions.append(move)

    def compute_reward(self, my_id):
        """Return 1 if ``my_id``'s side won, else 0.

        Player 0 wins only when it is the winner itself; players 1 and 2 win
        whenever the winner is not player 0.
        """
        if my_id == 0:
            return 1 if self.winner == my_id else 0
        return 1 if self.winner != 0 else 0

    def get_next_state_with_random_choice(self, untried_move):
        """Play one move and return the state seen by the next player.

        Args:
            untried_move: the move dict to play, or ``None`` to play a random
                legal move.  A pass is replaced by a random non-pass move when
                the current player is also the last one who played.

        Returns:
            A new ``State`` rotated by one seat.
        """
        valid_moves = get_moves(self.my_card, self.last_move)
        moves_num = len(valid_moves)
        # Drawn even when ``untried_move`` is given, to keep RNG use unchanged.
        i = np.random.choice(moves_num)
        tmp = valid_moves[i].copy()
        if untried_move is not None:
            tmp = untried_move
        # The player holding the lead is not allowed to pass.
        while self.is_buchu(tmp) and self.last_pid == self.my_id:
            i = np.random.choice(moves_num)
            tmp = valid_moves[i].copy()

        # Seats rotate: next -> self, next-next -> next, self -> next-next.
        move = []
        next_next_card = self.my_card.copy()
        for k in self._CARD_KEYS:
            played = tmp.get(k, 0)
            move.extend([int(k)] * played)
            next_next_card[k] -= played
        my_id = (self.my_id + 1) % 3
        my_card = self.next_card.copy()
        next_card = self.next_next_card.copy()

        # The mover wins once no cards remain in their hand.
        hand_empty = all(count == 0 for count in next_next_card.values())
        winner = self.my_id if hand_empty else -1

        if move:
            last_move = move.copy()
            last_p = self.my_id
        else:
            # A pass leaves the cards to beat unchanged, unless the turn has
            # come back to the player who played them: that player leads.
            last_p = self.last_pid
            last_move = [] if self.last_pid == my_id else self.last_move.copy()

        moves_num_ = len(get_moves(my_card, last_move))
        return State(my_id, my_card, next_card, next_next_card, last_move, winner, moves_num_, move, last_p)

    @staticmethod
    def is_buchu(move):
        """Return True when ``move`` plays no card at all (a pass).

        Missing keys count as zero cards, matching the ``get(k, 0)`` lookups
        used when the move is applied.
        """
        return all(move.get(k, 0) == 0 for k in State._CARD_KEYS)
def get_action_space():
    """Build the full list of playable card combinations.

    Actions are generated category by category, in the order of the
    ``Category`` constants (empty, single, pair, triple, bomb, 3+1, 3+2,
    single/double/triple sequences, 3+1 and 3+2 sequences, both jokers,
    4+1+1, 4+2+2).  As a side effect the module-level ``Category2Range`` is
    rebuilt so that entry ``k`` is the ``[start, end)`` index range of
    category ``k`` in the returned list.

    Returns:
        A list of actions, each a list of card characters.
    """
    # Rebuild the range table in place: appending to whatever an earlier call
    # left behind would shift every category's range.
    del Category2Range[:]

    max_cards = 20  # bounds sequence lengths (presumably the largest hand size)
    all_cards = Card.cards
    plain_cards = [card for card in all_cards if card != '*' and card != '$']
    low_v = Card.to_value('3')
    two_v = Card.to_value('2')
    joker_v = Card.to_value('*')

    actions = [[]]
    Category2Range.append([0, 1])

    def close_category(start):
        """Record ``[start, len(actions))`` as the next range; return the new start."""
        Category2Range.append([start, len(actions)])
        return len(actions)

    def repeated(seq, times):
        """Cards for the value range ``seq``, each ``times`` times, in rank order."""
        return sorted(Card.to_cards(seq) * times, key=all_cards.index)

    def has_both_jokers(extra):
        """True when a two-card kicker is exactly the two jokers."""
        return '*' in extra and '$' in extra and len(extra) == 2

    temp = len(actions)
    # single
    for card in all_cards:
        actions.append([card])
    temp = close_category(temp)
    # pair, triple, bomb
    for copies in (2, 3, 4):
        for card in plain_cards:
            actions.append([card] * copies)
        temp = close_category(temp)
    # 3 + 1
    for main in plain_cards:
        for extra in all_cards:
            if extra != main:
                actions.append([main] * 3 + [extra])
    temp = close_category(temp)
    # 3 + 2
    for main in plain_cards:
        for extra in plain_cards:
            if extra != main:
                actions.append([main] * 3 + [extra] * 2)
    temp = close_category(temp)
    # single sequence (at least 5 ranks, never reaching '2')
    for start_v in range(low_v, two_v):
        for end_v in range(start_v + 5, joker_v):
            actions.append(repeated(range(start_v, end_v), 1))
    temp = close_category(temp)
    # double sequence (at least 3 ranks)
    for start_v in range(low_v, two_v):
        for end_v in range(start_v + 3, min(start_v + max_cards // 2 + 1, joker_v)):
            actions.append(repeated(range(start_v, end_v), 2))
    temp = close_category(temp)
    # triple sequence (at least 2 ranks)
    for start_v in range(low_v, two_v):
        for end_v in range(start_v + 2, min(start_v + max_cards // 3 + 1, joker_v)):
            actions.append(repeated(range(start_v, end_v), 3))
    temp = close_category(temp)
    # 3 + 1 sequence
    for start_v in range(low_v, two_v):
        for end_v in range(start_v + 2, min(start_v + max_cards // 4 + 1, joker_v)):
            seq = range(start_v, end_v)
            main = Card.to_cards(seq)
            remains = [card for card in all_cards if card not in main]
            for extra in itertools.combinations(remains, end_v - start_v):
                if not has_both_jokers(extra):
                    actions.append(repeated(seq, 3) + list(extra))
    temp = close_category(temp)
    # 3 + 2 sequence
    for start_v in range(low_v, two_v):
        for end_v in range(start_v + 2, min(start_v + max_cards // 5 + 1, joker_v)):
            seq = range(start_v, end_v)
            main = Card.to_cards(seq)
            remains = [card for card in plain_cards if card not in main]
            for extra in itertools.combinations(remains, end_v - start_v):
                actions.append(repeated(seq, 3) + list(extra) * 2)
    temp = close_category(temp)
    # bigbang
    actions.append(['*', '$'])
    temp = close_category(temp)
    # 4 + 1 + 1
    for main in plain_cards:
        remains = [card for card in all_cards if card != main]
        for extra in itertools.combinations(remains, 2):
            if not has_both_jokers(extra):
                actions.append([main] * 4 + list(extra))
    temp = close_category(temp)
    # 4 + 2 + 2
    for main in plain_cards:
        remains = [card for card in plain_cards if card != main]
        for extra in itertools.combinations(remains, 2):
            actions.append([main] * 4 + list(extra) * 2)
    close_category(temp)
    return actions
* 4 209 | subvec = np.zeros(4) 210 | subvec[:counts[x]] = 1 211 | onehot[idx:idx+4] = subvec 212 | return onehot 213 | 214 | # convert char to 0-56 color cards 215 | @staticmethod 216 | def char2color(cards): 217 | result = np.zeros([len(cards)]) 218 | mask = np.zeros([57]) 219 | for i in range(len(cards)): 220 | ind = Card.cards.index(cards[i]) * 4 221 | while mask[ind] == 1: 222 | ind += 1 223 | mask[ind] = 1 224 | result[i] = ind 225 | 226 | return result 227 | 228 | @staticmethod 229 | def onehot2color(cards): 230 | result = [] 231 | for i in range(len(cards)): 232 | if cards[i] == 0: 233 | continue 234 | if i == 53: 235 | result.append(56) 236 | else: 237 | result.append(i) 238 | return np.array(result) 239 | 240 | @staticmethod 241 | def onehot2char(cards): 242 | result = [] 243 | for i in range(len(cards)): 244 | if cards[i] == 0: 245 | continue 246 | if i == 53: 247 | result.append(Card.cards[14]) 248 | else: 249 | result.append(Card.cards[i // 4]) 250 | return result 251 | 252 | @staticmethod 253 | def onehot2val(cards): 254 | result = [] 255 | for i in range(len(cards)): 256 | if cards[i] == 0: 257 | continue 258 | if i == 53: 259 | result.append(17) 260 | else: 261 | result.append(i // 4 + 3) 262 | return result 263 | 264 | @staticmethod 265 | def char2value_3_17(cards): 266 | result = [] 267 | if type(cards) is list or type(cards) is range: 268 | for c in cards: 269 | result.append(Card.cards_to_value[c] + 3) 270 | return np.array(result) 271 | else: 272 | return Card.cards_to_value[cards] + 3 273 | 274 | @staticmethod 275 | def to_value(card): 276 | if type(card) is list or type(card) is range: 277 | val = 0 278 | for c in card: 279 | val += Card.cards_to_value[c] 280 | return val 281 | else: 282 | return Card.cards_to_value[card] 283 | 284 | @staticmethod 285 | def to_cards(values): 286 | if type(values) is list or type(values) is range: 287 | cards = [] 288 | for v in values: 289 | cards.append(Card.value_to_cards[v]) 290 | return cards 291 | else: 
292 | return Card.value_to_cards[values] 293 | 294 | @staticmethod 295 | def to_cards_from_3_17(values): 296 | return Card.np_cards[values-3].tolist() 297 | 298 | 299 | class CardGroup: 300 | def __init__(self, cards, t, val, len=1): 301 | self.type = t 302 | self.cards = cards 303 | self.value = val 304 | self.len = len 305 | 306 | def bigger_than(self, g): 307 | if self.type == Category.EMPTY: 308 | return g.type != Category.EMPTY 309 | if g.type == Category.EMPTY: 310 | return True 311 | if g.type == Category.BIGBANG: 312 | return False 313 | if self.type == Category.BIGBANG: 314 | return True 315 | if g.type == Category.QUADRIC: 316 | if self.type == Category.QUADRIC and self.value > g.value: 317 | return True 318 | else: 319 | return False 320 | if self.type == Category.QUADRIC or \ 321 | (self.type == g.type and self.len == g.len and self.value > g.value): 322 | return True 323 | else: 324 | return False 325 | 326 | @staticmethod 327 | def isvalid(cards): 328 | return CardGroup.folks(cards) == 1 329 | 330 | @staticmethod 331 | def to_cardgroup(cards): 332 | candidates = CardGroup.analyze(cards) 333 | for c in candidates: 334 | if len(c.cards) == len(cards): 335 | return c 336 | print("cards error!") 337 | print(cards) 338 | raise Exception("Invalid Cards!") 339 | 340 | @staticmethod 341 | def folks(cards): 342 | cand = CardGroup.analyze(cards) 343 | cnt = 10000 344 | # if not cards: 345 | # return 0 346 | # for c in cand: 347 | # remain = list(cards) 348 | # for card in c.cards: 349 | # remain.remove(card) 350 | # if CardGroup.folks(remain) + 1 < cnt: 351 | # cnt = CardGroup.folks(remain) + 1 352 | # return cnt 353 | spec = False 354 | for c in cand: 355 | if c.type == Category.TRIPLE_LINE or c.type == Category.THREE_ONE or \ 356 | c.type == Category.THREE_TWO or c.type == Category.FOUR_TAKE_ONE or \ 357 | c.type == Category.FOUR_TAKE_TWO or c.type == Category.THREE_ONE_LINE or \ 358 | c.type == Category.THREE_TWO_LINE or c.type == Category.SINGLE_LINE or \ 
359 | c.type == Category.DOUBLE_LINE: 360 | spec = True 361 | remain = list(cards) 362 | for card in c.cards: 363 | remain.remove(card) 364 | if CardGroup.folks(remain) + 1 < cnt: 365 | cnt = CardGroup.folks(remain) + 1 366 | if not spec: 367 | cnt = len(cand) 368 | return cnt 369 | 370 | @staticmethod 371 | def analyze(cards): 372 | cards = list(cards) 373 | if len(cards) == 0: 374 | return [CardGroup([], Category.EMPTY, 0)] 375 | candidates = [] 376 | 377 | # TODO: this does not rule out Nuke kicker 378 | counts = Counter(cards) 379 | if '*' in cards and '$' in cards: 380 | candidates.append((CardGroup(['*', '$'], Category.BIGBANG, 100))) 381 | # cards.remove('*') 382 | # cards.remove('$') 383 | 384 | quadrics = [] 385 | # quadric 386 | for c in counts: 387 | if counts[c] == 4: 388 | quadrics.append(c) 389 | candidates.append(CardGroup([c] * 4, Category.QUADRIC, Card.to_value(c))) 390 | cards = list(filter(lambda a: a != c, cards)) 391 | 392 | counts = Counter(cards) 393 | singles = [c for c in counts if counts[c] == 1] 394 | doubles = [c for c in counts if counts[c] == 2] 395 | triples = [c for c in counts if counts[c] == 3] 396 | 397 | singles.sort(key=lambda k: Card.cards_to_value[k]) 398 | doubles.sort(key=lambda k: Card.cards_to_value[k]) 399 | triples.sort(key=lambda k: Card.cards_to_value[k]) 400 | 401 | # continuous sequence 402 | if len(singles) > 0: 403 | cnt = 1 404 | cand = [singles[0]] 405 | for i in range(1, len(singles)): 406 | if Card.to_value(singles[i]) >= Card.to_value('2'): 407 | break 408 | if Card.to_value(singles[i]) == Card.to_value(cand[-1]) + 1: 409 | cand.append(singles[i]) 410 | cnt += 1 411 | else: 412 | if cnt >= 5: 413 | candidates.append(CardGroup(cand, Category.SINGLE_LINE, Card.to_value(cand[0]), cnt)) 414 | # for c in cand: 415 | # cards.remove(c) 416 | cand = [singles[i]] 417 | cnt = 1 418 | if cnt >= 5: 419 | candidates.append(CardGroup(cand, Category.SINGLE_LINE, Card.to_value(cand[0]), cnt)) 420 | # for c in cand: 421 | # 
cards.remove(c) 422 | 423 | if len(doubles) > 0: 424 | cnt = 1 425 | cand = [doubles[0]] * 2 426 | for i in range(1, len(doubles)): 427 | if Card.to_value(doubles[i]) >= Card.to_value('2'): 428 | break 429 | if Card.to_value(doubles[i]) == Card.to_value(cand[-1]) + 1: 430 | cand += [doubles[i]] * 2 431 | cnt += 1 432 | else: 433 | if cnt >= 3: 434 | candidates.append(CardGroup(cand, Category.DOUBLE_LINE, Card.to_value(cand[0]), cnt)) 435 | # for c in cand: 436 | # if c in cards: 437 | # cards.remove(c) 438 | cand = [doubles[i]] * 2 439 | cnt = 1 440 | if cnt >= 3: 441 | candidates.append(CardGroup(cand, Category.DOUBLE_LINE, Card.to_value(cand[0]), cnt)) 442 | # for c in cand: 443 | # if c in cards: 444 | # cards.remove(c) 445 | 446 | if len(triples) > 0: 447 | cnt = 1 448 | cand = [triples[0]] * 3 449 | for i in range(1, len(triples)): 450 | if Card.to_value(triples[i]) >= Card.to_value('2'): 451 | break 452 | if Card.to_value(triples[i]) == Card.to_value(cand[-1]) + 1: 453 | cand += [triples[i]] * 3 454 | cnt += 1 455 | else: 456 | if cnt >= 2: 457 | candidates.append(CardGroup(cand, Category.TRIPLE_LINE, Card.to_value(cand[0]), cnt)) 458 | # for c in cand: 459 | # if c in cards: 460 | # cards.remove(c) 461 | cand = [triples[i]] * 3 462 | cnt = 1 463 | if cnt >= 2: 464 | candidates.append(CardGroup(cand, Category.TRIPLE_LINE, Card.to_value(cand[0]), cnt)) 465 | # for c in cand: 466 | # if c in cards: 467 | # cards.remove(c) 468 | 469 | for t in triples: 470 | candidates.append(CardGroup([t] * 3, Category.TRIPLE, Card.to_value(t))) 471 | 472 | counts = Counter(cards) 473 | singles = [c for c in counts if counts[c] == 1] 474 | doubles = [c for c in counts if counts[c] == 2] 475 | 476 | # single 477 | for s in singles: 478 | candidates.append(CardGroup([s], Category.SINGLE, Card.to_value(s))) 479 | 480 | # double 481 | for d in doubles: 482 | candidates.append(CardGroup([d] * 2, Category.DOUBLE, Card.to_value(d))) 483 | 484 | # 3 + 1, 3 + 2 485 | for c in triples: 
486 | triple = [c] * 3 487 | for s in singles: 488 | if s not in triple: 489 | candidates.append(CardGroup(triple + [s], Category.THREE_ONE, 490 | Card.to_value(c))) 491 | for d in doubles: 492 | if d not in triple: 493 | candidates.append(CardGroup(triple + [d] * 2, Category.THREE_TWO, 494 | Card.to_value(c))) 495 | 496 | # 4 + 2 497 | for c in quadrics: 498 | for extra in list(itertools.combinations(singles, 2)): 499 | candidates.append(CardGroup([c] * 4 + list(extra), Category.FOUR_TAKE_ONE, 500 | Card.to_value(c))) 501 | for extra in list(itertools.combinations(doubles, 2)): 502 | candidates.append(CardGroup([c] * 4 + list(extra) * 2, Category.FOUR_TAKE_TWO, 503 | Card.to_value(c))) 504 | # 3 * n + n, 3 * n + 2 * n 505 | triple_seq = [c.cards for c in candidates if c.type == Category.TRIPLE_LINE] 506 | for cand in triple_seq: 507 | cnt = int(len(cand) / 3) 508 | for extra in list(itertools.combinations(singles, cnt)): 509 | candidates.append( 510 | CardGroup(cand + list(extra), Category.THREE_ONE_LINE, 511 | Card.to_value(cand[0]), cnt)) 512 | for extra in list(itertools.combinations(doubles, cnt)): 513 | candidates.append( 514 | CardGroup(cand + list(extra) * 2, Category.THREE_TWO_LINE, 515 | Card.to_value(cand[0]), cnt)) 516 | 517 | importance = [Category.EMPTY, Category.SINGLE, Category.DOUBLE, Category.DOUBLE_LINE, Category.SINGLE_LINE, Category.THREE_ONE, 518 | Category.THREE_TWO, Category.THREE_ONE_LINE, Category.THREE_TWO_LINE, 519 | Category.TRIPLE_LINE, Category.TRIPLE, Category.FOUR_TAKE_ONE, Category.FOUR_TAKE_TWO, 520 | Category.QUADRIC, Category.BIGBANG] 521 | candidates.sort(key=functools.cmp_to_key(lambda x, y: importance.index(x.type) - importance.index(y.type) 522 | if importance.index(x.type) != importance.index(y.type) else x.value - y.value)) 523 | # for c in candidates: 524 | # print c.cards 525 | return candidates 526 | 527 | action_space = get_action_space() 528 | action_space_onehot60 = np.array([Card.char2onehot60(a) for a in 
action_space]) 529 | action_space_category = [action_space[r[0]:r[1]] for r in Category2Range] 530 | 531 | augment_action_space = action_space + action_space_category[Category.SINGLE][:13] * 3 + action_space_category[Category.DOUBLE] 532 | 533 | extra_actions = [] 534 | for j in range(3): 535 | for i in range(13): 536 | tmp = np.zeros([60]) 537 | tmp[i * 4 + j + 1] = 1 538 | extra_actions.append(tmp) 539 | 540 | for i in range(13): 541 | tmp = np.zeros([60]) 542 | tmp[i * 4 + 2:i * 4 + 4] = 1 543 | extra_actions.append(tmp) 544 | 545 | augment_action_space_onehot60 = np.concatenate([action_space_onehot60, np.stack(extra_actions)], 0) 546 | 547 | 548 | def clamp_action_idx(idx): 549 | len_action = len(action_space) 550 | if idx < len_action: 551 | return idx 552 | if idx >= len_action + 13 * 3: 553 | idx = idx - len_action - 13 * 3 + 16 554 | else: 555 | idx = (idx - len_action) % 13 + 1 556 | return idx 557 | 558 | 559 | if __name__ == '__main__': 560 | pass 561 | # print(Card.val2onehot60([3, 3, 16, 17])) 562 | # print(Category2Range) 563 | print(len(action_space_category)) 564 | print(CardGroup.to_cardgroup(['6', '6', 'Q', 'Q', 'Q']).value) 565 | 566 | # print(len(action_space)) 567 | # for a in action_space: 568 | # assert len(a) <= 20 569 | # if len(a) > 0: 570 | # CardGroup.to_cardgroup(a) 571 | # print(a) 572 | # print(action_space_category[Category.SINGLE_LINE.value]) 573 | # print(action_space_category[Category.DOUBLE_LINE.value]) 574 | # print(action_space_category[Category.THREE_ONE.value]) 575 | # CardGroup.to_cardgroup(['6', '6', 'Q', 'Q', 'Q']) 576 | # actions = get_action_space() 577 | # for i in range(1, len(actions)): 578 | # CardGroup.to_cardgroup(actions[i]) 579 | # print(CardGroup.folks(['3', '4', '3', '4', '3', '4', '*', '$'])) 580 | # CardGroup.to_cardgroup(['3', '4', '3', '4', '3', '4', '*', '$']) 581 | # print actions[561] 582 | # print CardGroup.folks(actions[561]) 583 | # CardGroup.to_cardgroup(actions[i]) 584 | # print Card.to_onehot(['3', 
'4', '4', '$']) 585 | # print len(actions) 586 | # print Card.to_cards(1) 587 | # CardGroup.analyze(['3', '3', '3', '4', '4', '4', '10', 'J', 'Q', 'A', 'A', '2', '2', '*', '$']) -------------------------------------------------------------------------------- /server/rule_utils/decomposer.py: -------------------------------------------------------------------------------- 1 | # https://github.com/qq456cvb/doudizhu-C 2 | from server.rule_utils.card import Card, action_space, CardGroup, augment_action_space_onehot60, augment_action_space, clamp_action_idx 3 | from server.rule_utils.utils import get_mask_onehot60 4 | import numpy as np 5 | from envi import env 6 | 7 | 8 | class Decomposer: 9 | def __init__(self, num_actions=(100, 21)): 10 | self.num_actions = num_actions 11 | 12 | def get_combinations(self, curr_cards_char, last_cards_char): 13 | if len(curr_cards_char) > 10: 14 | card_mask = Card.char2onehot60(curr_cards_char).astype(np.uint8) 15 | mask = augment_action_space_onehot60 16 | a = np.expand_dims(1 - card_mask, 0) * mask 17 | invalid_row_idx = set(np.where(a > 0)[0]) 18 | if len(last_cards_char) == 0: 19 | invalid_row_idx.add(0) 20 | 21 | valid_row_idx = [i for i in range(len(augment_action_space)) if i not in invalid_row_idx] 22 | 23 | mask = mask[valid_row_idx, :] 24 | idx_mapping = dict(zip(range(mask.shape[0]), valid_row_idx)) 25 | 26 | # augment mask 27 | # TODO: known issue: 555444666 will not decompose into 5554 and 66644 28 | combs = env.get_combinations_nosplit(mask, card_mask) 29 | combs = [([] if len(last_cards_char) == 0 else [0]) + [clamp_action_idx(idx_mapping[idx]) for idx in comb] for 30 | comb in combs] 31 | 32 | if len(last_cards_char) > 0: 33 | idx_must_be_contained = set( 34 | [idx for idx in valid_row_idx if CardGroup.to_cardgroup(augment_action_space[idx]). 
\ 35 | bigger_than(CardGroup.to_cardgroup(last_cards_char))]) 36 | combs = [comb for comb in combs if not idx_must_be_contained.isdisjoint(comb)] 37 | fine_mask = np.zeros([len(combs), self.num_actions[1]], dtype=np.bool) 38 | for i in range(len(combs)): 39 | for j in range(len(combs[i])): 40 | if combs[i][j] in idx_must_be_contained: 41 | fine_mask[i][j] = True 42 | else: 43 | fine_mask = None 44 | else: 45 | mask = get_mask_onehot60(curr_cards_char, action_space, None).reshape(len(action_space), 15, 4).sum(-1).astype( 46 | np.uint8) 47 | valid = mask.sum(-1) > 0 48 | cards_target = Card.char2onehot60(curr_cards_char).reshape(-1, 4).sum(-1).astype(np.uint8) 49 | # do not feed empty to C++, which will cause infinite loop 50 | combs = env.get_combinations_recursive(mask[valid, :], cards_target) 51 | idx_mapping = dict(zip(range(valid.shape[0]), np.where(valid)[0])) 52 | 53 | combs = [([] if len(last_cards_char) == 0 else [0]) + [idx_mapping[idx] for idx in comb] for comb in combs] 54 | 55 | if len(last_cards_char) > 0: 56 | valid[0] = True 57 | idx_must_be_contained = set( 58 | [idx for idx in range(len(action_space)) if valid[idx] and CardGroup.to_cardgroup(action_space[idx]). 
\ 59 | bigger_than(CardGroup.to_cardgroup(last_cards_char))]) 60 | combs = [comb for comb in combs if not idx_must_be_contained.isdisjoint(comb)] 61 | fine_mask = np.zeros([len(combs), self.num_actions[1]], dtype=np.bool) 62 | for i in range(len(combs)): 63 | for j in range(len(combs[i])): 64 | if combs[i][j] in idx_must_be_contained: 65 | fine_mask[i][j] = True 66 | else: 67 | fine_mask = None 68 | return combs, fine_mask 69 | -------------------------------------------------------------------------------- /server/rule_utils/evaluator.py: -------------------------------------------------------------------------------- 1 | # https://www.jianshu.com/p/9fb001daedcf 2 | from server.rule_utils.card import action_space_category 3 | 4 | char2val = { 5 | "3": 3, "4": 4, "5": 5, "6": 6, 6 | "7": 7, "8": 8, "9": 9, "10": 10, 7 | "J": 11, "Q": 12, "K": 13, "A": 14, 8 | "2": 15, "*": 16, "$": 17 9 | } 10 | 11 | dapai = [] 12 | i = 9 13 | for i in range(11, 15): 14 | dapai.append(sorted(action_space_category[1][i])) 15 | for i in range(8, 13): 16 | dapai.append(sorted(action_space_category[2][i])) 17 | dapai.append(sorted(action_space_category[3][i])) 18 | for i in range(112, 182): 19 | dapai.append(sorted(action_space_category[5][i])) 20 | for i in range(96, 156): 21 | dapai.append(sorted(action_space_category[6][i])) 22 | for a in action_space_category[7]: 23 | if len(a) >= 7: 24 | dapai.append(sorted(a)) 25 | for i in [4, 8, 9, 10, 11, 13, 14]: 26 | for a in action_space_category[i]: 27 | dapai.append(sorted(a)) 28 | 29 | cards_value = [] 30 | for c in range(len(action_space_category)): 31 | for a in action_space_category[c]: 32 | v = None 33 | if c == 0: 34 | v = 0 35 | elif c <= 3: # 1单牌, 2对子, 3三条 36 | v = char2val[a[0]] - 10 # maxCard - 10 37 | if c == 2 and v > 0: 38 | if a == ['2', '2']: 39 | v *= 1.2 40 | elif a == ['A', 'A']: 41 | v *= 1.3 42 | else: 43 | v *= 1.4 # positive + 50% 44 | if c == 3 and v > 0: 45 | if a == ['2', '2', '2']: 46 | v *= 1 47 | elif a == 
['A', 'A', 'A']: 48 | v *= 1.5 49 | else: 50 | v *= 1.8 # positive + 100% 51 | elif c == 4: # 4炸弹 52 | if a == ['2', '2', '2', '2']: 53 | v = 7 54 | else: 55 | v = 9 # 固定9分 56 | elif c <= 6: # 5三带一, 6三带二 57 | v = char2val[a[0]] - 10 # maxCard - 10 58 | if v > 0: 59 | if a[:3] == ['2', '2', '2']: 60 | v *= 1 61 | elif a == ['A', 'A', 'A']: 62 | v *= 1.3 63 | else: 64 | v *= 1.5 # 带牌比三条加得少 65 | elif c <= 9: # 7顺子, 8连对, 9飞机 66 | v = max(0, (char2val[a[-1]] - 10) / 2) # max(0, (maxCard - 10) / 2) 67 | elif c == 10: # 10飞机带小 68 | main_len = len(a) // 4 * 3 69 | v = max(0, (char2val[a[-1]] - 10) / 2) # max(0, (maxCard - 10) / 2) 70 | for i in range(main_len, len(a)): 71 | if char2val[a[i]] > 10: 72 | v += char2val[a[i]] - 10 # 带牌为正加上 73 | elif c == 11: # 11飞机带大 74 | main_len = len(a) // 5 * 3 75 | v = max(0, (char2val[a[-1]] - 10) / 2) # max(0, (maxCard - 10) / 2) 76 | for i in range(main_len, main_len + main_len // 3): 77 | if char2val[a[i]] > 10: 78 | v += 1.5 * (char2val[a[i]] - 10) # 带牌为正加上 79 | elif c == 12: # 12火箭 80 | v = 12 81 | elif c <= 14: # 13四带二只, 14四带二对 82 | v = char2val[a[0]] - 10 # maxCard - 10 83 | if v > 0: 84 | if a[:4] == ['2', '2', '2', '2']: 85 | v *= 1 86 | elif a[:4] == ['A', 'A', 'A', 'A']: 87 | v *= 1.2 88 | else: 89 | v *= 1.5 # 带牌比三条加得少 90 | assert v is not None 91 | 92 | cards_value.append(v) 93 | -------------------------------------------------------------------------------- /server/rule_utils/rule_based_model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from server.rule_utils.card import action_space 4 | from server.rule_utils.decomposer import Decomposer 5 | from server.rule_utils.evaluator import cards_value, dapai 6 | 7 | card_list = [ 8 | "3", "4", "5", "6", 9 | "7", "8", "9", "10", 10 | "J", "Q", "K", "A", 11 | "2", "*", "$" 12 | ] 13 | 14 | 15 | # 返回当前局面上剩余最大的单牌(没考虑炸弹) 16 | def maxcard(other_hand, l, ll): 17 | max = 0 18 | # for i, j in enumerate(other_hand): 19 | 
# if j != 0: 20 | # max = i 21 | for i in other_hand[l]: 22 | if i > max: 23 | max = i 24 | for j in other_hand[ll]: 25 | if j > max: 26 | max = j 27 | return max - 2 28 | 29 | 30 | def choose(payload3): 31 | # 获得手牌 32 | hand_card = payload3['cur_cards'] 33 | 34 | # 拆牌器和引擎用了不同的编码 1 -> A, B -> *, R -> $ 35 | trans_hand_card = [card_list[i - 3] for i in hand_card] 36 | # 获得上家出牌 37 | role_id = payload3['role_id'] 38 | lrole_id = (role_id - 1 + 3) % 3 # 上家ID 39 | llrole_id = (role_id - 2 + 3) % 3 # 上上家ID 40 | last_pid = lrole_id # 上一个有效出牌 41 | tlast_move = payload3['last_taken'][last_pid] 42 | if not tlast_move: 43 | last_pid = llrole_id 44 | tlast_move = payload3['last_taken'][last_pid] 45 | last_move = [card_list[i - 3] for i in tlast_move] 46 | # last_move = [card_list[i] for i in range(15) for _ in range(state.last_move[i])] 47 | # 拆牌 48 | D = Decomposer() 49 | combs, fine_mask = D.get_combinations(trans_hand_card, last_move) 50 | # 根据对手剩余最少牌数决定每多一手牌的惩罚 51 | # left_crads = [sum(p.get_hand_card()) for p in self.game.players] 52 | # min_oppo_crads = min(left_crads[1], left_crads[2]) if self.player_id == 0 else left_crads[0] 53 | min_oppo_crads = min(payload3['left'][lrole_id], payload3['left'][llrole_id]) 54 | round_penalty = 17 - 12 * min_oppo_crads / 20 # 惩罚值调整为与敌人最少手牌数负线性相关 55 | 56 | if not last_move: 57 | if role_id == 0: # 地主 58 | round_penalty += 7 59 | elif role_id == 1: # 地主下家 60 | round_penalty += 5 61 | else: # 地主上家 62 | round_penalty += 3 63 | 64 | if role_id == 2 and not last_move: # 队友没要地主牌 65 | round_penalty += 5 66 | if role_id == 1 and not last_move: # 地主没要队友牌 67 | round_penalty -= 8 68 | 69 | # 寻找最优出牌 70 | best_move = None 71 | max_value = -np.inf 72 | for i in range(len(combs)): 73 | # 手牌总分 74 | total_value = sum([cards_value[x] for x in combs[i]]) 75 | # small_num = 0 76 | # for j in range(0, len(combs[i])): 77 | # if j > 0 and action_space[j][0] not in ["2", "R", "B"]: 78 | # small_num += 1 79 | # total_value -= small_num * round_penalty 80 | 
small_num = hand_card[-1] + hand_card[-2] + hand_card[-3] 81 | small_num = (len(combs[i]) - small_num) # 如果一手牌为小牌, 需要加上惩罚值, 所以要统计小牌数量 82 | total_value -= small_num * round_penalty 83 | 84 | # 手里有火箭和另一手牌 85 | if len(combs[i]) == 3 and combs[i][0] == 0 or len(combs[i]) == 2: 86 | if cards_value[combs[i][-1]] == 12 or cards_value[combs[i][-2]] == 12: 87 | print('*****rule 火箭直接走') 88 | return [0] * 13 + [1, 1], None 89 | 90 | # 下家农民手里只有一张牌,送队友走 91 | # if role_id == 1 and sum(self.game.players[2].get_hand_card()) == 1 and not last_move: 92 | if role_id == 1 and payload3['left'][2] == 1 and not last_move: 93 | for i, j in enumerate(hand_card): 94 | if j != 0: 95 | tem = [0] * 15 96 | tem[i] = 1 97 | print('******rule 下家农民手里只有一张牌,送队友走') 98 | return tem, None 99 | 100 | # 队友出大牌能走就压 101 | if role_id == 2 and len(combs[i]) == 3 and combs[i][0] == 0: 102 | if action_space[combs[i][1]] in dapai and (fine_mask is None or fine_mask[i, 1] == True): 103 | print('******rule 队友出大牌能走就压') 104 | best_move = combs[i][1] 105 | break 106 | elif action_space[combs[i][2]] in dapai and (fine_mask is None or fine_mask[i, 2] == True): 107 | print('******rule 队友出大牌能走就压') 108 | best_move = combs[i][2] 109 | break 110 | 111 | # 队友出大牌走不了就不压 112 | if role_id == 2 and last_pid == 1 and sorted(last_move) in dapai: 113 | print('******rule 队友出大牌走不了就不压') 114 | best_move = 0 115 | break 116 | 117 | for j in range(0, len(combs[i])): 118 | # Pass 得分 119 | if combs[i][j] == 0 and min_oppo_crads > 8: 120 | if total_value > max_value: 121 | max_value = total_value 122 | best_move = 0 123 | # print('pass得分',max_value,end=' // ') 124 | # 出牌得分 125 | elif combs[i][j] > 0 and (fine_mask is None or fine_mask[i, j] == True): # 枚举非pass且fine_mask为True的出牌 126 | # 特判只有一手 127 | if len(combs[i]) == 1 or len(combs[i]) == 2 and combs[i][0] == 0: 128 | max_value = np.inf 129 | best_move = combs[i][-1] 130 | break 131 | 132 | move_value = total_value - cards_value[combs[i][j]] + round_penalty 133 | 134 | # 手里有当前最大牌和另一手牌 135 | 
if len(combs[i]) == 3 and combs[i][0] == 0 or len(combs[i]) == 2: 136 | if combs[i][j] > maxcard(payload3['hand_card'], lrole_id, llrole_id) and combs[i][j] <= 15: 137 | move_value += 100 138 | 139 | # 地主只剩一张牌时别出单牌 140 | # if role_id != 0 and sum(self.game.players[0].get_hand_card()) == 1: 141 | if role_id != 0 and payload3['left'][0] == 1: 142 | if combs[i][j] <= maxcard(payload3['hand_card'], lrole_id, llrole_id): 143 | move_value -= 100 144 | 145 | # 农民只剩一张牌时别出单牌 146 | if role_id == 0 and (payload3['left'][1] == 1 or payload3['left'][2] == 1): 147 | if combs[i][j] <= maxcard(payload3['hand_card'], lrole_id, llrole_id): 148 | move_value -= 100 149 | 150 | if move_value > max_value: 151 | max_value = move_value 152 | best_move = combs[i][j] 153 | if best_move is None: 154 | best_move = 0 155 | 156 | # 最优出牌 157 | best_cards = action_space[best_move] 158 | move = [best_cards.count(x) for x in card_list] 159 | # print('出牌得分', max_value) 160 | # 输出选择的牌组 161 | # print("\nbest comb: ") 162 | # for m in best_comb: 163 | # print(action_space[m], cards_value[m]) 164 | # 输出 player i [手牌] // [出牌] 165 | # print("Player {}".format(role_id), ' ', Card.visual_card(hand_card), end=' // ') 166 | # print(Card.visual_card(move)) 167 | return move 168 | 169 | 170 | if __name__ == "__main__": 171 | payload = { 172 | 'role_id': 1, # 0代表地主上家,1代表地主,2代表地主下家 173 | 'last_taken': { # 更改处 174 | 0: [], 175 | 1: [9, 9, 9, 6], 176 | 2: [], 177 | }, 178 | 'cur_cards': [17, 16, 15, 14, 14, 12, 10], # 无需保持顺序 179 | 'history': { # 各家走过的牌的历史The environment 180 | 0: [], 181 | 1: [5, 5, 5, 4, 4, 3, 3, 3, 3, 9, 9, 9, 6], 182 | 2: [11, 11, 11, 8, 8], 183 | }, 184 | 'left': { # 各家剩余的牌 185 | 0: 17, 186 | 1: 7, 187 | 2: 12, 188 | }, 189 | 'hand_card': { 190 | 0: [15, 14, 13, 13, 12, 10, 10, 9, 8, 8, 7, 7, 7, 6, 6, 6, 4], 191 | 1: [17, 16, 15, 14, 14, 12, 10], 192 | 2: [15, 15, 14, 13, 13, 12, 12, 11, 10, 7, 5, 4], 193 | }, 194 | 'debug': False, # 是否返回debug 195 | } 196 | import time 197 | 198 | start = 
time.time() 199 | print(choose(payload)) 200 | end = time.time() 201 | print(end - start) 202 | -------------------------------------------------------------------------------- /server/rule_utils/utils.py: -------------------------------------------------------------------------------- 1 | import server.rule_utils.card as card 2 | from server.rule_utils.card import action_space 3 | import numpy as np 4 | from collections import Counter 5 | 6 | 7 | action_space_single = action_space[1:16] 8 | action_space_pair = action_space[16:29] 9 | action_space_triple = action_space[29:42] 10 | action_space_quadric = action_space[42:55] 11 | 12 | 13 | def counter_subset(list1, list2): 14 | c1, c2 = Counter(list1), Counter(list2) 15 | for (k, n) in c1.items(): 16 | if n > c2[k]: 17 | return False 18 | return True 19 | 20 | 21 | def get_mask_onehot60(cards, action_space, last_cards): 22 | # 1 valid; 0 invalid 23 | mask = np.zeros([len(action_space), 60]) 24 | if cards is None: 25 | return mask 26 | if len(cards) == 0: 27 | return mask 28 | for j in range(len(action_space)): 29 | if counter_subset(action_space[j], cards): 30 | mask[j] = card.Card.char2onehot60(action_space[j]) 31 | if last_cards is None: 32 | return mask 33 | if len(last_cards) > 0: 34 | for j in range(1, len(action_space)): 35 | if np.sum(mask[j]) > 0 and not card.CardGroup.to_cardgroup(action_space[j]).\ 36 | bigger_than(card.CardGroup.to_cardgroup(last_cards)): 37 | mask[j] = np.zeros([60]) 38 | return mask 39 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | from game import Game 2 | from envi import Env, EnvComplicated, EnvCooperation, EnvCooperationSimplify 3 | from net import NetComplicated, NetMoreComplicated, NetCooperation, NetCooperationSimplify 4 | from dqn import DQNFirst 5 | 6 | 7 | def e1(): 8 | net_dict = { 9 | 'lord': NetMoreComplicated, 10 | } 11 | dqn_dict = { 12 | 
'lord': DQNFirst, 13 | } 14 | model_dict = { 15 | 'lord': '0805_1409_lord_4000', 16 | } 17 | wins = Game.compete(EnvComplicated, net_dict, dqn_dict, model_dict, 18 | total=2000, print_every=100, debug=False) 19 | print(wins) 20 | 21 | 22 | def e2(): 23 | net_dict = { 24 | 'lord': NetComplicated, 25 | } 26 | dqn_dict = { 27 | 'lord': DQNFirst, 28 | } 29 | model_dict = { 30 | 'lord': '0804_2022_lord_scratch3000', 31 | } 32 | wins = Game.compete(Env, net_dict, dqn_dict, model_dict, 33 | total=1000, print_every=100, debug=False) 34 | print(wins) 35 | 36 | 37 | def e_0806_1906_lord(): 38 | net_dict = { 39 | 'lord': NetMoreComplicated, 40 | } 41 | dqn_dict = { 42 | 'lord': DQNFirst, 43 | } 44 | wins = {} 45 | for model in [3, 4]: 46 | model_dict = { 47 | 'lord': '0806_1906_lord_{}000'.format(model), 48 | } 49 | win = Game.compete(EnvComplicated, net_dict, dqn_dict, model_dict, 50 | total=1000, print_every=100, debug=False) 51 | wins[model] = win 52 | return wins 53 | 54 | 55 | def e_0807_1340(): 56 | net_dict = { 57 | 'lord': None, 58 | 'down': NetCooperation, 59 | 'up': NetCooperation, 60 | } 61 | dqn_dict = { 62 | 'lord': None, 63 | 'down': DQNFirst, 64 | 'up': DQNFirst, 65 | } 66 | model_dict = { 67 | 'lord': None, 68 | 'down': '0807_1344_down_3000', 69 | 'up': '0807_1344_up_3000', 70 | } 71 | win = Game.compete(EnvCooperation, net_dict, dqn_dict, model_dict, 72 | total=1000, print_every=100, debug=False) 73 | return win 74 | 75 | 76 | def e0808(): 77 | net_dict = { 78 | 'lord': None, 79 | 'down': NetCooperationSimplify, 80 | 'up': NetCooperationSimplify, 81 | } 82 | dqn_dict = { 83 | 'lord': None, 84 | 'down': DQNFirst, 85 | 'up': DQNFirst, 86 | } 87 | model_dict = { 88 | 'lord': None, 89 | 'down': '0808_0854_down_6000', 90 | 'up': '0808_0854_up_6000', 91 | } 92 | win = Game.compete(EnvCooperationSimplify, net_dict, dqn_dict, model_dict, 93 | total=1000, print_every=100, debug=False) 94 | return win 95 | 96 | 97 | def e_ensemble(): 98 | # 纯RL,58.7%,1000把 99 | from 
ensemble import Game 100 | net_dict = { 101 | 'lord': None, 102 | 'down': None, 103 | 'up': None, 104 | } 105 | dqn_dict = { 106 | 'lord': None, 107 | 'down': None, 108 | 'up': None, 109 | } 110 | model_dict = { 111 | 'lord': None, 112 | 'down': None, 113 | 'up': None, 114 | } 115 | win = Game.ensemble_compete(EnvCooperationSimplify, net_dict, dqn_dict, model_dict, 116 | total=1000, print_every=1, debug=False) 117 | return win 118 | 119 | 120 | # if __name__ == '__main__': 121 | res = e_ensemble() # 246:胜利148 122 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from game import Game 2 | 3 | if __name__ == '__main__': 4 | from envi import EnvCooperation 5 | from net import NetCooperation 6 | from dqn import DQNFirst 7 | 8 | net_dict = { 9 | 'lord': NetCooperation, 10 | 'down': None, 11 | 'up': None, 12 | } 13 | dqn_dict = { 14 | 'lord': DQNFirst, 15 | 'down': None, 16 | 'up': None, 17 | } 18 | reward_dict = { 19 | 'lord': 100, 20 | 'down': None, 21 | 'up': None, 22 | } 23 | train_dict = { 24 | 'lord': True, 25 | 'up': False, 26 | 'down': False, 27 | } 28 | game = Game(EnvCooperation, net_dict, dqn_dict, 29 | reward_dict=reward_dict, train_dict=train_dict, 30 | debug=True) 31 | game.train(20, 5, 10) 32 | --------------------------------------------------------------------------------