├── model_handler.py
├── README.md
├── filereader.py
├── env.py
├── DL_agent.py
└── PolicyGradient.py

/model_handler.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # encoding: utf-8
3 | '''
4 | @file: model_handler.py
5 | @time: ????
6 | Shrink the model so it can run on botzone
7 | '''
8 | import torch
9 | import argparse
10 | 
11 | parser = argparse.ArgumentParser(description='Building the smaller model')
12 | parser.add_argument('-o', '--old_path', type=str, default='../models/super_model_2', help='path to original model')
13 | parser.add_argument('-n', '--new_path', type=str, default='../models/super_model_small', help='path to smaller model')
14 | 
15 | args = parser.parse_args()
16 | checkpoint = torch.load(args.old_path)
17 | state = {'model': checkpoint['model']}
18 | torch.save(state, args.new_path, _use_new_zipfile_serialization=False)
--------------------------------------------------------------------------------

/README.md:
--------------------------------------------------------------------------------
1 | # Deep Learning & Reinforcement Learning for Chinese Official Mahjong (国标麻将) on Botzone
2 | 
3 | **Course project for the Reinforcement Learning course at Peking University, Fall 2020**
4 | 
5 | 
6 | ## File structure
7 | - The deep-learning part consists of `filereader.py` and `DL_agent.py`: the former converts the raw training data into the same input/output format used on botzone, and the latter trains the DL model; with a few parameter changes it can also run as a bot on botzone.
8 | - The reinforcement-learning part consists of `PolicyGradient.py`, `A3C.py` and `PolicyGradient_naive.py`. The first is the main line of work; the latter two are early attempts that were not carefully debugged or maintained.
9 | - The utility script `model_handler.py` shrinks a saved model so that it can run on botzone.
10 | - The data files consist of three parts: the raw human game records, the deep-learning training data (i.e. the output of filereader.py), and the trained models. The initial model obtained by deep learning is super_model_2 in the models folder, the model further trained by reinforcement learning is rl_pg_new, and the model used on botzone is super_model_small.
11 |   + Download link: [百度网盘](https://pan.baidu.com/s/1wpPBHq3MRngMQx9EAS6-aw )
12 |   + Extraction code: agmm
13 | 
14 | 
15 | ## How to run the code
16 | **Every script can be run from the command line; `python xxx.py` runs it with default parameters. The parameters are explained below.**
17 | 
18 | #### filereader.py
19 | - `-h, --help` show the help message
20 | - `-lp, --load_path` folder containing the raw training data
21 | - `-sp, --save_path` output folder for the parsed training data; it must be created beforehand
22 | - `-tn, --thread_number, default=32` total number of threads; thread i handles the files numbered i + k * thread_number, k = 0, 1, 2... (e.g. with 32 threads, thread 3 handles files 3, 35, 67, ...)
23 | - `-tr, --thread_round, default=4` how many batches the threads are split into; tn/tr threads run at the same time, and this value is best kept equal to the number of physical CPU cores
24 | 
25 | #### DL_agent.py
26 | This file can be uploaded to botzone and used as a bot.
27 | - Recommended trial run: `python DL_agent.py -t -l -lp path_to_pretrained_model`
28 | - `-h, --help` show the help message
29 | - `-t, --train, default=False` whether to train the model
30 | - `-s, --save, default=False` whether to save the model
31 | - `-l, --load, default=False` whether to load a pretrained model
32 | - `-lp, --load_path` path to load the model from
33 | - `-sp, --save_path` path to save the model to
34 | - `-b, --batch_size, default=1000` training batch size
35 | - `--training_data` folder containing the training data, i.e. the output folder of filereader.py
36 | - `-pi, --print_interval, default=2000` print the overall prediction accuracy every this many games
37 | - `-si, --save_interval, default=5000` save the model every this many games
38 | - `--botzone, default=False` this option cannot be set on the command line. To upload the file to botzone as a bot, change the line in the source to `parser.add_argument('--botzone', action='store_true', default=True, help='whether to run the model on botzone')`
39 | 
40 | #### PolicyGradient.py
41 | - Recommended trial run: `python PolicyGradient.py -p 1 -o path_to_pretrained_model -n path_to_pretrained_model -S path_to_rl_model -s -lp none -bs 500 -lr 8e-6`
42 | - The other two reinforcement-learning attempts performed poorly and ran slowly, so they were not maintained afterwards. Their command-line parameters are essentially covered by the ones listed here and are not repeated; see `python A3C.py -h` and `python PolicyGradient_naive.py` for details.
43 | - `-h, --help` show the help message
44 | - `-lr, --learning_rate, default=1e-5` learning rate
45 | - `-s, --save, default=False` whether to save the model
46 | - `-o, --old_path` path to a pretrained model; the bot using this model acts as the opponent and is not trained
47 | - `-n, --new_path` path to a pretrained model; the bot using this model is the one being trained
48 | - `-S, --save_path` path to save the model; only takes effect when -s is set
49 | - `-p, --num_process_per_gpu, default=1` number of processes per GPU; when running on CPU this is the total number of processes
50 | - `-rn, --round_number, default=10` how many different deals are played per episode
51 | - `-rt, --repeated_times, default=10` how many times each deal is repeated, i.e. each episode runs rn * rt games in parallel
52 | - `-ti, --train_interval, default=2` train (i.e. update the global model) every this many episodes. Setting ti too large makes the manager thread hold too much data and raises an error; the exact cause is unknown
53 | - `-ji, --join_interval, default=2` after how many training steps the global model is copied back to all local models, i.e. every ti * ji episodes
54 | - `-pi, --print_interval, default=10` print statistics every this many episodes
55 | - `-si, --save_interval, default=20` save the model every this many episodes
56 | - `--eps_clip, default=0.3` clipping threshold for the importance-sampling ratio
57 | - `--entropy_weight, default=1e-3` initial weight of the entropy loss
58 | - `--entropy_target, default=1e-4` target value of the average entropy, used to adapt the entropy loss weight
59 | - `--entropy_step, default=0.01` step size of that adaptation
60 | - `-lp, --log_path` file path for saving the log output; set to none to disable logging. Output to stdout is unaffected
61 | - `-e, --epochs, default=1` how many times the collected training data are reused. Most references suggest that not reusing the data works best; for efficiency, in the current implementation setting a value other than 1 is merely equivalent to raising the learning rate
62 | - `-bs, --batch_size, default=1000` maximum batch size during training, used to avoid running out of GPU memory
63 | 
64 | #### model_handler.py
65 | - Converts a trained model into a smaller one that can run on botzone by stripping the training-progress and optimizer information
66 | - Usage: `python model_handler.py -o path_to_original_model -n path_to_smaller_model`
67 | 
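Putting the four scripts together, a full pass from raw records to a botzone-ready model looks roughly like the sketch below. The data paths are placeholders and the model names follow the defaults and examples mentioned above; adjust both to your own setup.

```bash
# 1. Convert the raw human game records into botzone-style request/response data
python filereader.py -lp path_to_raw_data -sp training_data -tn 32 -tr 4
# 2. Train the supervised (deep-learning) model on the converted data
python DL_agent.py -t -s -sp ../models/super_model_2 --training_data training_data
# 3. Optionally fine-tune it with the policy-gradient self-play loop
python PolicyGradient.py -p 1 -o ../models/super_model_2 -n ../models/super_model_2 -S ../models/rl_pg_new -s -lp none
# 4. Strip optimizer/progress state so the checkpoint is small enough for botzone
python model_handler.py -o ../models/rl_pg_new -n ../models/super_model_small
```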
--------------------------------------------------------------------------------

/filereader.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # encoding: utf-8
3 | '''
4 | @file: filereader.py
5 | @time: 2020/12/19 18:34
6 | Convert the raw training data
7 | '''
8 | # _*_coding:utf-8_*_
9 | import time, threading
10 | import os
11 | import json
12 | import argparse
13 | 
14 | parser = argparse.ArgumentParser(description='Model')
15 | parser.add_argument('-lp', '--load_path', type=str, default=r'C:\Users\zrf19\Desktop\大四上\强化学习\麻将\mjdata\output2017',
16 |                     help='from where to load raw data')
17 | parser.add_argument('-sp', '--save_path', type=str, default='training_data', help='where to save parsed training data')
18 | parser.add_argument('-tn', '--thread_number', type=int, default=32,
19 |                     help='thread number in total, for thread i, it parses files i + k * thread_number, k = 0, 1, 2...')
20 | parser.add_argument('-tr', 
'--thread_round', type=int, default=4, 21 | help='how many rounds the threads are run, tn // tr = threads running simultaneously') 22 | args = parser.parse_args() 23 | 24 | def round(outcome, fanxing, score, fname, zhuangjia, requests, responses): 25 | return { 26 | 'outcome' : outcome, 27 | 'fanxing' : fanxing, 28 | 'score' : score, 29 | 'fname' : fname, 30 | 'zhuangjia': zhuangjia, 31 | 'requests' : requests, 32 | 'responses' : responses 33 | } 34 | 35 | class Reader(threading.Thread): 36 | def __init__(self, files, id): 37 | super(Reader, self).__init__() 38 | self.files = files 39 | self.id = id 40 | 41 | def run(self): 42 | self.rounds = [] 43 | self.removed_hua = 0 44 | self.removed_cuohu = 0 45 | self.removed_fan = 0 46 | for filenum, file in enumerate(self.files): 47 | if filenum % 1000 == 0: 48 | print('{}/{}'.format(filenum, len(self.files))) 49 | with open(file, encoding='utf-8') as f: 50 | line = f.readline() 51 | count = 1 52 | requests = [[], [], [], []] 53 | responses = [[], [], [], []] 54 | zhuangjia = 0 55 | outcome = '' 56 | fanxing = [] 57 | score = 0 58 | fname = file 59 | flag = True 60 | already_hu = False 61 | draw_hua = False 62 | while line: 63 | # print(line) 64 | line = line.strip('\n').split('\t') 65 | if count == 2: 66 | # print(line) 67 | outcome = line[-1] 68 | fanxing = list(map(lambda x: x.strip("'"), line[2].strip('[]').split(','))) 69 | if outcome != '荒庄': 70 | fanshu = 0 71 | for fan in fanxing: 72 | try: 73 | description, this_score = fan.split('-') 74 | except: 75 | description = fan 76 | if description == '全带五': 77 | this_score = 16 78 | elif description == '三同刻': 79 | this_score = 16 80 | else: 81 | print(fan) 82 | if description == '花牌': 83 | continue 84 | fanshu += int(this_score) 85 | if fanshu < 8: 86 | flag = False 87 | self.removed_fan += 1 88 | break 89 | score = int(line[1]) 90 | # quan = line[0] 91 | # print(fanxing) 92 | if count > 2 and count <= 6: 93 | playerID = int(line[0]) 94 | request = "1 0 0 0 0 " 95 | cards = line[1] 96 | cards = list(map(lambda x: x.strip("'"), cards.strip('[]').split(','))) 97 | if len(cards) == 14: 98 | draw_card = cards[-1] 99 | hua_count = 0 100 | for card in cards: 101 | if 'H' in card: 102 | hua_count += 1 103 | draw_card = card 104 | draw_hua = True 105 | if hua_count > 1: 106 | self.removed_hua += 1 107 | flag = False 108 | break 109 | cards.remove(draw_card) 110 | zhuangjia = playerID 111 | else: 112 | if 'H' in ' '.join(cards): 113 | self.removed_hua += 1 114 | flag = False 115 | break 116 | draw_card = None 117 | request += (' '.join(cards)) 118 | requests[playerID].append(request) 119 | if draw_card is not None: 120 | requests[playerID].append('2 ' + draw_card) 121 | responses[playerID].append("PASS") 122 | else: 123 | requests[playerID].append('3 {} DRAW'.format(zhuangjia)) 124 | responses[playerID].append("PASS") 125 | if count > 6: 126 | playerID = int(line[0]) 127 | action = line[1] 128 | cards = line[2] 129 | if draw_hua and action != '补花': 130 | self.removed_hua += 1 131 | flag = False 132 | break 133 | if action == '吃': 134 | middel_card = list(map(lambda x: x.strip("'"), cards.strip('[]').split(',')))[1] 135 | next_line = f.readline() 136 | play_card = next_line.strip('\n').split('\t')[2] 137 | play_card = play_card.strip("[]'") 138 | _cards = [middel_card, play_card] 139 | _action = action 140 | elif action == '碰': 141 | next_line = f.readline() 142 | play_card = next_line.strip('\n').split('\t')[2] 143 | play_card = play_card.strip("[]'") 144 | _cards = [play_card] 145 | _action = action 146 
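# '补花' (flower replacement): the botzone format has no flower tiles, so the draw that
# produced the flower is rolled back by popping the last request/response of every player.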
| elif action == '补花': 147 | # flag = False 148 | draw_hua = False 149 | for i in range(4): 150 | requests[i].pop() 151 | responses[i].pop() 152 | # if not flag: 153 | # self.removed_hua += 1 154 | # break 155 | line = f.readline() 156 | if not line: 157 | for i in range(4): 158 | requests[i].pop() 159 | last_hua = False 160 | for i in range(4): 161 | if 'H' in requests[i][-1]: 162 | last_hua = True 163 | if last_hua: 164 | for i in range(4): 165 | requests[i].pop() 166 | responses[i].pop() 167 | count += 1 168 | continue 169 | else: 170 | card = list(map(lambda x: x.strip("'"), cards.strip('[]').split(',')))[0] 171 | _cards = [card] 172 | _action = action 173 | if action == '和牌': 174 | already_hu = True 175 | if action == '摸牌' or action == '补花后摸牌' or action == '杠后摸牌': 176 | if 'H' in card: 177 | draw_hua = True 178 | for i in range(4): 179 | request = get_request(_action, playerID, _cards, i) 180 | # print(request) 181 | response = get_response(_action, playerID, _cards, i) 182 | # print(response) 183 | 184 | requests[i].append(request) 185 | if response is not None: 186 | responses[i].append(response) 187 | 188 | line = f.readline() 189 | # 胡牌之后就没有了 190 | if line and already_hu: 191 | self.removed_cuohu += 1 192 | flag = False 193 | # print(fname) 194 | break 195 | if not line: 196 | for i in range(4): 197 | requests[i].pop() 198 | last_hua = False 199 | for i in range(4): 200 | if 'H' in requests[i][-1]: 201 | last_hua = True 202 | if last_hua: 203 | for i in range(4): 204 | requests[i].pop() 205 | responses[i].pop() 206 | count += 1 207 | if flag: 208 | self.rounds.append(round(outcome, fanxing, score, fname, zhuangjia, requests, responses)) 209 | 210 | def get_res(self): 211 | with open('{}/Tread {}-mini.json'.format(args.save_path, self.id), 'w') as file_obj: 212 | json.dump(self.rounds, file_obj) 213 | return self.removed_hua, self.removed_cuohu, self.removed_fan 214 | 215 | def get_request(action, playerid, cards, myplayerid): 216 | playerid = str(playerid) 217 | myplayerid = str(myplayerid) 218 | request = None 219 | if action == '打牌': 220 | request = ['3', playerid, 'PLAY', cards[0]] 221 | if action == '摸牌' or action == '补花后摸牌' or action == '杠后摸牌': 222 | if playerid == myplayerid: 223 | request = ['2', cards[0]] 224 | else: 225 | request = ['3', playerid, 'DRAW'] 226 | if action == '吃': 227 | request = ['3', playerid, 'CHI'] + cards 228 | if action == '碰': 229 | request = ['3', playerid, 'PENG', cards[0]] 230 | if action == '明杠' or action == '暗杠': 231 | request = ['3', playerid, 'GANG'] 232 | if action == '补杠': 233 | request = ['3', playerid, 'BUGANG', cards[0]] 234 | if request is None: 235 | return None 236 | return ' '.join(request) 237 | 238 | def get_response(action, playerid, cards, myplayerid): 239 | if playerid != myplayerid: 240 | response = ['PASS'] 241 | else: 242 | if action == '打牌': 243 | response = ['PLAY', cards[0]] 244 | if action == '摸牌' or action == '补花后摸牌' or action == '杠后摸牌': 245 | response = ['PASS'] 246 | if action == '吃': 247 | response = ['CHI'] + cards 248 | if action == '碰': 249 | response = ['PENG', cards[0]] 250 | if action == '明杠': 251 | response = ['GANG'] 252 | if action == '暗杠': 253 | response = ['GANG', cards[0]] 254 | if action == '补杠': 255 | response = ['BUGANG', cards[0]] 256 | if action == '和牌': 257 | response = ['HU'] 258 | # print(action) 259 | return ' '.join(response) 260 | 261 | 262 | if __name__ == '__main__': 263 | # reader = Reader(['C:\\Users\\zrf19\\Desktop\\强化学习\\麻将\\mjdata\\output2017/PLAY/2017-07-29-305.txt'], 10086) 264 | # 
reader.start() 265 | # reader.join() 266 | # reader.get_res() 267 | #线程数量 268 | thread_num = args.thread_number 269 | thread_rounds = args.thread_round 270 | thread_per_round = thread_num // thread_rounds 271 | #起始时间 272 | t = [] 273 | folder = args.load_path 274 | dirs = os.listdir(folder) 275 | files = [] 276 | for dir in dirs: 277 | subfolder = folder + '/' + dir 278 | for file in os.listdir(subfolder): 279 | files.append(subfolder + '/' + file) 280 | # files = files[:10000] 281 | filenum = len(files) 282 | #生成线程 283 | for i in range(thread_num): 284 | t.append(Reader(files[i::thread_num], i)) 285 | rm_hua = 0 286 | rm_cuohu = 0 287 | rm_fan = 0 288 | for this_round in range(thread_rounds): 289 | #开启线程 290 | for i in range(thread_per_round): 291 | t[i+thread_per_round*this_round].start() 292 | for i in range(thread_per_round): 293 | t[i+thread_per_round*this_round].join() 294 | r_h, r_c, r_f = t[i+thread_per_round*this_round].get_res() 295 | rm_hua += r_h 296 | rm_cuohu += r_c 297 | rm_fan += r_f 298 | print("一共{}局记录,其中补花错误{}局,错和{}局,番数不足{}局,训练数据共{}局".format(filenum, rm_hua, rm_cuohu, rm_fan, 299 | filenum - rm_hua - rm_cuohu - rm_fan)) -------------------------------------------------------------------------------- /env.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # encoding: utf-8 3 | ''' 4 | @author: zrf 5 | @license: (C) Copyright 2013-2017, Node Supply Chain Manager Corporation Limited. 6 | @contact: deamoncao100@gmail.com 7 | @software:XXXX 8 | @file: DL_agent.py 9 | @time: 2020/12/10 13:27 10 | @desc: 11 | ''' 12 | import sys 13 | import os 14 | curPath = os.path.abspath(os.path.dirname(__file__)) 15 | rootPath = os.path.split(curPath)[0] 16 | sys.path.append("..") 17 | import json 18 | from MahjongGB import MahjongFanCalculator 19 | import torch 20 | import torch.optim as optim 21 | from enum import Enum 22 | import numpy as np 23 | import torch.nn as nn 24 | import torch.nn.functional as F 25 | 26 | import math 27 | from copy import deepcopy 28 | from torch.distributions import Categorical 29 | import torch.multiprocessing as mp 30 | import argparse 31 | import time 32 | from DL_agent import agent as dl_agent 33 | # from botzone.new_main import agent as dl_search_agent0 34 | # from botzone.new_main_1 import agent as dl_search_agent1 35 | # from new_main_2 import agent as dl_search_agent2 36 | # from new_main_3 import agent as dl_search_agent3 37 | 38 | parser = argparse.ArgumentParser(description='the environment to test bots') 39 | parser.add_argument('-o', '--old_path', type=str, default='../models/super_model_2', help='path to stable model') 40 | parser.add_argument('-n', '--new_path', type=str, default='../models/rl_pg_new', help='path to trained model') 41 | parser.add_argument('-p', '--num_process_per_gpu', type=int, default=1, help='number of processes to run per gpu') 42 | parser.add_argument('-pi', '--print_interval', type=int, default=500, help='how often to print') 43 | 44 | args = parser.parse_args() 45 | 46 | class requests(Enum): 47 | initialHand = 1 48 | drawCard = 2 49 | DRAW = 4 50 | PLAY = 5 51 | PENG = 6 52 | CHI = 7 53 | GANG = 8 54 | BUGANG = 9 55 | MINGGANG = 10 56 | ANGANG = 11 57 | 58 | class responses(Enum): 59 | PASS = 0 60 | PLAY = 1 61 | HU = 2 62 | # 需要区分明杠和暗杠 63 | MINGGANG = 3 64 | ANGANG = 4 65 | BUGANG = 5 66 | PENG = 6 67 | CHI = 7 68 | need_cards = [0, 1, 0, 0, 1, 1, 0, 1] 69 | loss_weight = [1, 1, 5, 2, 2, 2, 2, 2] 70 | 71 | class cards(Enum): 72 | # 饼万条 73 | B = 0 74 | W = 9 
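# Suit base offsets into the 34 flat card indices: B (饼/dots) 0-8, W (万/characters) 9-17,
# T (条/bamboo) 18-26, F (风/winds) 27-30, J (箭牌/dragons) 31-33.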
75 | T = 18 76 | # 风 77 | F = 27 78 | # 箭牌 79 | J = 31 80 | 81 | 82 | class ActorCritic(nn.Module): 83 | def __init__(self, card_feat_depth, num_extra_feats, num_cards, num_actions): 84 | super().__init__() 85 | hidden_channels = [8, 16, 32] 86 | hidden_layers_size = [512, 1024] 87 | linear_length = hidden_channels[1] * num_cards * card_feat_depth 88 | self.linear_length = linear_length + num_extra_feats 89 | # self.number_card_net = nn.Sequential( 90 | # nn.Conv2d(3, hidden_channels[0], 3, stride=1, padding=1), 91 | # nn.ReLU(), 92 | # ) 93 | self.card_net = nn.Sequential( 94 | nn.Conv2d(1, hidden_channels[0], 3, stride=1, padding=1), 95 | nn.ReLU(), 96 | nn.Conv2d(hidden_channels[0], hidden_channels[1], 5, stride=1, padding=2), 97 | ) 98 | self.card_play_decision_net = nn.Sequential( 99 | nn.Linear(num_extra_feats + linear_length, hidden_layers_size[0]), 100 | nn.Sigmoid(), 101 | nn.Linear(hidden_layers_size[0], hidden_layers_size[1]), 102 | nn.ReLU(), 103 | nn.Linear(hidden_layers_size[1], num_cards), 104 | nn.Softmax(dim=1) 105 | ) 106 | self.chi_peng_decision_net = nn.Sequential( 107 | nn.Linear(num_extra_feats + linear_length, hidden_layers_size[0]), 108 | nn.ReLU(), 109 | nn.Linear(hidden_layers_size[0], num_cards), 110 | nn.Softmax(dim=1) 111 | ) 112 | self.action_decision_net = nn.Sequential( 113 | nn.Linear(num_extra_feats + linear_length, hidden_layers_size[1]), 114 | nn.ReLU(), 115 | nn.Linear(hidden_layers_size[1], num_actions), 116 | nn.Softmax(dim=1) 117 | ) 118 | self.critic = nn.Sequential( 119 | nn.Linear(num_extra_feats + linear_length, hidden_layers_size[0]), 120 | nn.ReLU(), 121 | nn.Linear(hidden_layers_size[0], hidden_layers_size[1]), 122 | nn.Sigmoid(), 123 | nn.Linear(hidden_layers_size[1], 2048), 124 | nn.ReLU(), 125 | nn.Linear(2048, 1) 126 | ) 127 | 128 | # play, chi_gang, 129 | def forward(self, card_feats, extra_feats, device, decide_which, mask): 130 | assert decide_which in ['play', 'chi_gang', 'action'] 131 | card_feats = torch.from_numpy(card_feats).to(device).unsqueeze(1).to(torch.float32) 132 | card_layer = self.card_net(card_feats) 133 | batch_size = card_layer.shape[0] 134 | extra_feats_tensor = torch.from_numpy(extra_feats).to(torch.float32).to(device) 135 | linear_layer = torch.cat((card_layer.view(batch_size, -1), extra_feats_tensor), dim=1) 136 | mask_tensor = torch.from_numpy(mask).to(torch.float32).to(device) 137 | if decide_which == 'play': 138 | card_probs = self.card_play_decision_net(linear_layer) 139 | valid_card_play = self.mask_unavailable_actions(card_probs, mask_tensor) 140 | return valid_card_play 141 | elif decide_which == 'action': 142 | # print(linear_layer.shape) 143 | action_probs = self.action_decision_net(linear_layer) 144 | valid_actions = self.mask_unavailable_actions(action_probs, mask_tensor) 145 | # print(valid_actions, valid_card_play) 146 | return valid_actions 147 | else: 148 | card_probs = self.chi_peng_decision_net(linear_layer) 149 | valid_card_play = self.mask_unavailable_actions(card_probs, mask_tensor) 150 | return valid_card_play 151 | 152 | def mask_unavailable_actions(self, result, valid_actions_tensor): 153 | valid_actions = result * valid_actions_tensor 154 | if valid_actions.sum() > 0: 155 | masked_actions = valid_actions / valid_actions.sum() 156 | else: 157 | masked_actions = valid_actions_tensor / valid_actions_tensor.sum() 158 | return masked_actions 159 | 160 | 161 | class MahjongEnv: 162 | def __init__(self): 163 | self.test = True 164 | self.total_cards = 34 165 | self.total_actions = len(responses) - 2 
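# responses also contains the two list-valued members need_cards and loss_weight,
# so len(responses) - 2 is the number of real actions (PASS .. CHI).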
166 | self.print_interval = args.print_interval 167 | self.round_count = 0 168 | # state = {'model': self.model.state_dict(), 'optimizer': self.optimizer.state_dict()} 169 | # torch.save(state, self.model_path, _use_new_zipfile_serialization=False) 170 | self.bots = [] 171 | self.winner = np.zeros(4) 172 | self.dianpaoer = np.zeros(4) 173 | self.winning_rate = np.zeros(4) 174 | self.win_steps = [] 175 | self.losses = [] 176 | self.scores = np.zeros(4) 177 | # 0, 2为老model 178 | if args.old_path == args.new_path: 179 | model_path = args.old_path 180 | self.bots = [dl_agent(model_path), dl_search_agent0(model_path), dl_agent(model_path), dl_search_agent1(model_path)] 181 | # for i in range(4): 182 | # if i != 0: 183 | # self.bots.append(dl_agent(model_path)) 184 | # else: 185 | # self.bots.append(dl_search_agent0(model_path)) 186 | else: 187 | for i in range(4): 188 | if i % 2 == 0: 189 | self.bots.append(dl_agent(args.old_path)) 190 | else: 191 | self.bots.append(dl_agent(args.new_path)) 192 | self.reset(True) 193 | 194 | 195 | def reset_for_test(self, initial=False, global_counter=0): 196 | self.round_count += 1 197 | if self.round_count % 4 == 0: 198 | self.reset(initial, global_counter) 199 | else: 200 | self.tile_wall = deepcopy(self.doc_tile) 201 | self.men = (self.men + 1) % 4 202 | self.bots_order = [self.bots[(i + self.men) % 4] for i in range(4)] 203 | self.turnID = 0 204 | self.drawer = 0 205 | for bot, reward in zip(self.bots, self.scores): 206 | bot.reset() 207 | 208 | 209 | def reset(self, initial=False, global_counter=0, global_winning_rate=None): 210 | all_tiles = np.arange(self.total_cards) 211 | all_tiles = all_tiles.repeat(4) 212 | np.random.shuffle(all_tiles) 213 | # 用pop,从后面摸牌 214 | self.tile_wall = np.reshape(all_tiles, (4, -1)).tolist() 215 | self.doc_tile = deepcopy(self.tile_wall) 216 | self.quan = np.random.choice(4) 217 | self.men = np.random.choice(4) 218 | 219 | # 这一局bots的order,牌墙永远下标和bot一致 220 | self.bots_order = [self.bots[(i + self.men) % 4] for i in range(4)] 221 | self.turnID = 0 222 | self.drawer = 0 223 | if not initial: 224 | for bot, reward in zip(self.bots_order, self.scores): 225 | bot.reset() 226 | if self.round_count % (4 * self.print_interval) == 0: 227 | win_sum = self.winner.sum() 228 | total_rounts = 4 * self.print_interval 229 | print( 230 | '目前进行了{}轮,在前{}轮中,new bot winning rate: {:.2%}, new bot total score: {},' 231 | 'new bot dianpao: {}, old bot score: {}, dianpao: {}, old bot winning rate: {:.2%}\n' 232 | ' 和牌{}局,荒庄{}局,和牌率{:.2%},平均和牌回合数{}'.format( 233 | self.round_count, total_rounts, self.winner[1::2].sum() / win_sum, self.scores[1::2].sum(), 234 | self.dianpaoer[1::2].sum(), self.scores[::2].sum(), self.dianpaoer[::2].sum(), 235 | self.winner[::2].sum() / win_sum, self.winner.sum(), 236 | total_rounts - self.winner.sum(), 237 | self.winner.sum() / total_rounts, 238 | sum(self.win_steps) / self.winner.sum() 239 | )) 240 | self.winner = np.zeros(4) 241 | self.dianpaoer = np.zeros(4) 242 | self.scores = np.zeros(4) 243 | self.win_steps = [] 244 | 245 | 246 | def run_round(self): 247 | fan_count = 0 248 | player_id = 0 249 | dianpaoer = None 250 | outcome = '' 251 | player_responses = [] 252 | self.drawn = False 253 | while True: 254 | if self.turnID == 0: 255 | for id, player in enumerate(self.bots_order): 256 | player_responses.append(player.step('0 %d %d' % (id, self.quan))) 257 | elif self.turnID == 1: 258 | player_responses = [] 259 | for id, player in enumerate(self.bots_order): 260 | request = ['1'] 261 | for i in range(4): 262 | 
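# Botzone's turn-1 request is "1" + four flower counts + the 13 dealt tiles;
# flower tiles are never dealt here, so the four counts are always 0.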
request.append('0') 263 | for i in range(13): 264 | request.append(self.getCardName(self.tile_wall[id].pop())) 265 | request = ' '.join(request) 266 | player_responses.append(player.step(request)) 267 | else: 268 | requests = self.parse_response(player_responses) 269 | if requests[0] in ['hu', 'huangzhuang']: 270 | outcome = requests[0] 271 | if outcome == 'hu': 272 | player_id = int(requests[1]) 273 | fan_count = int(requests[2]) 274 | dianpaoer_id = requests[3] 275 | self.winner[self.bots.index(self.bots_order[player_id])] += 1 276 | self.scores[self.bots.index(self.bots_order[player_id])] += fan_count 277 | if dianpaoer_id != 'None': 278 | dianpaoer_id = int(dianpaoer_id) 279 | self.scores[self.bots.index(self.bots_order[dianpaoer_id])] -= 0.5 * fan_count 280 | self.dianpaoer[self.bots.index(self.bots_order[dianpaoer_id])] += 1 281 | self.win_steps.append(self.turnID) 282 | break 283 | else: 284 | player_responses = [] 285 | for i in range(4): 286 | player_responses.append(self.bots_order[i].step(requests[i])) 287 | self.turnID += 1 288 | # print('{} {}'.format(outcome, fan_count)) 289 | # difen = 8 290 | # if outcome == 'hu': 291 | # for i in range(4): 292 | # if i == player_id: 293 | # self.scores[i] = 20 294 | # elif i == dianpaoer: 295 | # self.scores[i] = -5 296 | # else: 297 | # self.scores[i] = -1 298 | 299 | # for i in range(4): 300 | # if i == player_id: 301 | # self.scores[i] = 10 302 | # if dianpaoer is None: 303 | # # 自摸 304 | # self.scores[i] = 3 * (difen + fan_count) 305 | # else: 306 | # self.scores[i] = 3 * difen + fan_count 307 | # else: 308 | # if dianpaoer is None: 309 | # self.scores[i] = -0.5 * (difen + fan_count) 310 | # else: 311 | # if i == dianpaoer: 312 | # self.scores[i] = -2 * (difen + fan_count) 313 | # else: 314 | # self.scores[i] = -0.5 * difen 315 | # print(self.scores) 316 | 317 | def parse_response(self, player_responses): 318 | requests = [] 319 | for id, response in enumerate(player_responses): 320 | response = response.split(' ') 321 | response_name = response[0] 322 | if response_name == 'HU': 323 | return ['hu', id, response[1], response[2]] 324 | if response_name == 'PENG': 325 | requests = [] 326 | for i in range(4): 327 | requests.append('3 %d PENG %s' % (id, response[1])) 328 | self.drawer = (id + 1) % 4 329 | break 330 | if response_name == "GANG": 331 | requests = [] 332 | for i in range(4): 333 | requests.append('3 %d GANG' % (id)) 334 | self.drawer = id 335 | break 336 | if response_name == 'CHI': 337 | for i in range(4): 338 | requests.append('3 %d CHI %s %s' % (id, response[1], response[2])) 339 | self.drawer = (id + 1) % 4 340 | if response_name == 'PLAY': 341 | for i in range(4): 342 | requests.append('3 %d PLAY %s' % (id, response[1])) 343 | self.drawer = (id + 1) % 4 344 | if response_name == 'BUGANG': 345 | for i in range(4): 346 | requests.append('3 %d BUGANG %s' % (id, response[1])) 347 | self.drawer = id 348 | # 所有人pass,摸牌 349 | if len(requests) == 0: 350 | if len(self.tile_wall[self.drawer]) == 0: 351 | return ['huangzhuang', 0] 352 | draw_card = self.tile_wall[self.drawer].pop() 353 | for i in range(4): 354 | if i == self.drawer: 355 | requests.append('2 %s' % self.getCardName(draw_card)) 356 | else: 357 | requests.append('3 %d DRAW' % self.drawer) 358 | return requests 359 | 360 | def getCardInd(self, cardName): 361 | return cards[cardName[0]].value + int(cardName[1]) - 1 362 | 363 | def getCardName(self, cardInd): 364 | num = 1 365 | while True: 366 | if cardInd in cards._value2member_map_: 367 | break 368 | num += 1 369 | 
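# step down until a suit base index (0/9/18/27/31) is reached; num becomes the 1-based rank within the suit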
cardInd -= 1 370 | return cards(cardInd).name + str(num) 371 | 372 | import time 373 | def train_thread(global_episode_counter): 374 | 375 | env = MahjongEnv() 376 | while True: 377 | env.run_round() 378 | env.reset_for_test(False, global_episode_counter.value) 379 | global_episode_counter.value += 1 380 | 381 | 382 | def main(): 383 | num_processes_per_gpu = args.num_process_per_gpu 384 | new_model_path = args.new_path 385 | mp.set_start_method('spawn') # required to avoid Conv2d froze issue 386 | # critic 387 | gpu_count = torch.cuda.device_count() 388 | num_processes = gpu_count * num_processes_per_gpu 389 | 390 | # multiprocesses, Hogwild! style update 391 | processes = [] 392 | init_episode_counter_val = 0 393 | global_episode_counter = mp.Value('i', init_episode_counter_val) 394 | # each worker_thread creates its own environment and trains agents 395 | for rank in range(num_processes): 396 | # only write summaries in one of the workers, since they are identical 397 | # worker_summary_queue = summary_queue if rank == 0 else None 398 | worker_thread = mp.Process( 399 | target=train_thread, args=(global_episode_counter, )) 400 | worker_thread.daemon = True 401 | worker_thread.start() 402 | processes.append(worker_thread) 403 | time.sleep(2) 404 | 405 | # wait for all processes to finish 406 | try: 407 | killed_process_count = 0 408 | for process in processes: 409 | process.join() 410 | killed_process_count += 1 if process.exitcode == 1 else 0 411 | if killed_process_count >= num_processes: 412 | # exit if only monitor and writer alive 413 | raise SystemExit 414 | except (KeyboardInterrupt, SystemExit): 415 | for process in processes: 416 | # without killing child process, process.terminate() will cause orphans 417 | # ref: https://thebearsenal.blogspot.com/2018/01/creation-of-orphan-process-in-linux.html 418 | # kill_child_processes(process.pid) 419 | process.terminate() 420 | process.join() 421 | 422 | if __name__ == '__main__': 423 | main() -------------------------------------------------------------------------------- /DL_agent.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # encoding: utf-8 3 | ''' 4 | @author: zrf 5 | @time: 2020/12/10 13:27 6 | 深度学习 7 | ''' 8 | import json 9 | from MahjongGB import MahjongFanCalculator 10 | import torch 11 | import torch.optim as optim 12 | from enum import Enum 13 | import numpy as np 14 | import torch.nn as nn 15 | import torch.nn.functional as F 16 | import sys 17 | import os 18 | import random 19 | from copy import deepcopy 20 | import argparse 21 | 22 | parser = argparse.ArgumentParser(description='Model') 23 | parser.add_argument('-t', '--train', action='store_true', default=False, help='whether to train model') 24 | parser.add_argument('-s', '--save', action='store_true', default=False, help='whether to store model') 25 | parser.add_argument('-l', '--load', action='store_true', default=False, help='whether to load model') 26 | parser.add_argument('-lp', '--load_path', type=str, default='../models/super_model_2', help='from where to load model') 27 | parser.add_argument('-sp', '--save_path', type=str, default='../models/super_model_2', help='save model path') 28 | parser.add_argument('-b', '--batch_size', type=int, default=1000, help='training batch size') 29 | parser.add_argument('--training_data', type=str, default='../training_data', help='path to training data folder') 30 | parser.add_argument('--botzone', action='store_true', default=False, help='whether to run the 
model on botzone') 31 | parser.add_argument('-pi', '--print_interval', type=int, default=2000, help='how often to print') 32 | parser.add_argument('-si', '--save_interval', type=int, default=5000, help='how often to save') 33 | args = parser.parse_args() 34 | 35 | 36 | class requests(Enum): 37 | initialHand = 1 38 | drawCard = 2 39 | DRAW = 4 40 | PLAY = 5 41 | PENG = 6 42 | CHI = 7 43 | GANG = 8 44 | BUGANG = 9 45 | MINGGANG = 10 46 | ANGANG = 11 47 | 48 | class responses(Enum): 49 | PASS = 0 50 | PLAY = 1 51 | HU = 2 52 | # 需要区分明杠和暗杠 53 | MINGGANG = 3 54 | ANGANG = 4 55 | BUGANG = 5 56 | PENG = 6 57 | CHI = 7 58 | need_cards = [0, 1, 0, 0, 1, 1, 0, 1] 59 | loss_weight = [1, 1, 5, 2, 2, 2, 2, 2] 60 | 61 | class cards(Enum): 62 | # 饼万条 63 | B = 0 64 | W = 9 65 | T = 18 66 | # 风 67 | F = 27 68 | # 箭牌 69 | J = 31 70 | 71 | # store all data 72 | class dataManager: 73 | def __init__(self): 74 | self.reset('all') 75 | 76 | def reset(self, which_part): 77 | assert which_part in ['all', 'play', 'action', 'chi_gang'] 78 | doc_dict = { 79 | "card_feats": [], 80 | "extra_feats": [], 81 | "mask": [], 82 | "target": [] 83 | } 84 | if which_part == 'all': 85 | self.doc = { 86 | "play": deepcopy(doc_dict), 87 | "action": deepcopy(doc_dict), 88 | "chi_gang": deepcopy(doc_dict) 89 | } 90 | self.training = { 91 | "play": deepcopy(doc_dict), 92 | "action": deepcopy(doc_dict), 93 | "chi_gang": deepcopy(doc_dict) 94 | } 95 | else: 96 | for key, item in self.training[which_part].items(): 97 | self.doc[which_part][key].extend(item) 98 | self.training[which_part] = deepcopy(doc_dict) 99 | 100 | # 可以将所有训练数据保存成numpy,但是占据空间过大,不推荐 101 | def save_data(self, round): 102 | # np.save('training_data/round {}.npy'.format(round), self.doc) 103 | self.reset('all') 104 | 105 | class myModel(nn.Module): 106 | def __init__(self, card_feat_depth, num_extra_feats, num_cards, num_actions): 107 | super(myModel, self).__init__() 108 | hidden_channels = [8, 16, 32] 109 | hidden_layers_size = [512, 1024] 110 | linear_length = hidden_channels[1] * num_cards * card_feat_depth 111 | self.card_net = nn.Sequential( 112 | nn.Conv2d(1, hidden_channels[0], 3, stride=1, padding=1), 113 | nn.ReLU(), 114 | nn.Conv2d(hidden_channels[0], hidden_channels[1], 5, stride=1, padding=2), 115 | ) 116 | self.card_play_decision_net = nn.Sequential( 117 | nn.Linear(num_extra_feats + linear_length, hidden_layers_size[0]), 118 | nn.Sigmoid(), 119 | nn.Linear(hidden_layers_size[0], hidden_layers_size[1]), 120 | nn.ReLU(), 121 | nn.Linear(hidden_layers_size[1], num_cards), 122 | ) 123 | self.chi_peng_decision_net = nn.Sequential( 124 | nn.Linear(num_extra_feats + linear_length, hidden_layers_size[0]), 125 | nn.ReLU(), 126 | nn.Linear(hidden_layers_size[0], num_cards), 127 | ) 128 | self.action_decision_net = nn.Sequential( 129 | nn.Linear(num_extra_feats + linear_length, hidden_layers_size[1]), 130 | nn.ReLU(), 131 | nn.Linear(hidden_layers_size[1], num_actions), 132 | ) 133 | 134 | # play, chi_gang, 135 | def forward(self, card_feats, extra_feats, device, decide_which, mask): 136 | assert decide_which in ['play', 'chi_gang', 'action'] 137 | card_feats = torch.from_numpy(card_feats).to(device).unsqueeze(1).to(torch.float32) 138 | card_layer = self.card_net(card_feats) 139 | batch_size = card_layer.shape[0] 140 | extra_feats_tensor = torch.from_numpy(extra_feats).to(torch.float32).to(device) 141 | linear_layer = torch.cat((card_layer.view(batch_size, -1), extra_feats_tensor), dim=1) 142 | mask_tensor = torch.from_numpy(mask).to(torch.float32).to(device) 143 | 
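# Each head below returns raw scores multiplied by the 0/1 legality mask: illegal entries
# become 0, F.cross_entropy consumes the masked scores during training, and make_decision()
# in step() simply picks the largest non-zero entry at inference time.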
if decide_which == 'play': 144 | card_probs = self.card_play_decision_net(linear_layer) 145 | valid_card_play = card_probs * mask_tensor 146 | return valid_card_play 147 | elif decide_which == 'action': 148 | action_probs = self.action_decision_net(linear_layer) 149 | valid_actions = action_probs * mask_tensor 150 | return valid_actions 151 | else: 152 | card_probs = self.chi_peng_decision_net(linear_layer) 153 | valid_card_play = card_probs * mask_tensor 154 | return valid_card_play 155 | 156 | def train_backward(self, loss, optim): 157 | optim.zero_grad() 158 | loss.backward() 159 | optim.step() 160 | 161 | # 添加 牌墙无牌不能杠 162 | class MahjongHandler(): 163 | def __init__(self, train, model_path, load_model=False, save_model=True, botzone=False, batch_size=1000): 164 | use_cuda = torch.cuda.is_available() 165 | self.botzone = botzone 166 | self.device = torch.device("cuda" if use_cuda else "cpu") 167 | if not botzone: 168 | print('using ' + str(self.device)) 169 | self.train = train 170 | self.training_data = dataManager() 171 | self.model_path = model_path 172 | self.load_model = load_model 173 | self.save_model = save_model 174 | self.total_cards = 34 175 | self.learning_rate = 1e-4 176 | self.action_loss_weight = responses.loss_weight.value 177 | self.action_weight = torch.from_numpy((np.array(responses.loss_weight.value))).to(device=self.device, dtype=torch.float32) 178 | self.card_loss_weight = 2 179 | self.total_actions = len(responses) - 2 180 | self.model = myModel( 181 | card_feat_depth=14, 182 | num_extra_feats=self.total_actions + 16, 183 | num_cards=self.total_cards, 184 | num_actions=self.total_actions 185 | ).to(self.device) 186 | self.optimizer = optim.Adam(self.model.parameters(), lr=self.learning_rate) 187 | self.best_precision = 0 188 | self.batch_size = batch_size 189 | self.print_interval = args.print_interval 190 | self.save_interval = args.save_interval 191 | self.round_count = 0 192 | self.match = np.zeros(self.total_actions) 193 | self.count = np.zeros(self.total_actions) 194 | if self.load_model: 195 | checkpoint = torch.load(self.model_path, map_location=self.device) 196 | self.model.load_state_dict(checkpoint['model']) 197 | if not botzone: 198 | self.optimizer.load_state_dict(checkpoint['optimizer']) 199 | try: 200 | self.round_count = checkpoint['progress'] 201 | except KeyError: 202 | self.round_count = 0 203 | if not train: 204 | self.model.eval() 205 | self.reset(True) 206 | 207 | def reset(self, initial=False): 208 | self.hand_free = np.zeros(self.total_cards, dtype=int) 209 | self.history = np.zeros(self.total_cards, dtype=int) 210 | self.player_history = np.zeros((4, self.total_cards), dtype=int) 211 | self.player_on_table = np.zeros((4, self.total_cards), dtype=int) 212 | self.hand_fixed = self.player_on_table[0] 213 | self.player_last_play = np.zeros(4, dtype=int) 214 | self.player_angang = np.zeros(4, dtype=int) 215 | self.fan_count = 0 216 | self.hand_fixed_data = [] 217 | self.turnID = 0 218 | self.tile_count = [21, 21, 21, 21] 219 | self.myPlayerID = 0 220 | self.quan = 0 221 | self.prev_request = '' 222 | self.an_gang_card = '' 223 | # test training acc 224 | if self.train and not self.botzone and not initial and self.round_count % self.print_interval == 0: 225 | training_data = self.training_data.doc 226 | with torch.no_grad(): 227 | print('-'*50) 228 | for kind in ['play', 'action', 'chi_gang']: 229 | data = training_data[kind] 230 | probs = self.model(np.array(data['card_feats']), 231 | np.array(data['extra_feats']), self.device, 232 | kind, 
np.array(data['mask'])) 233 | target_tensor = torch.from_numpy(np.array(data['target'])).to(device=self.device, dtype=torch.int64) 234 | losses = F.cross_entropy(probs, target_tensor).to(torch.float32) 235 | pred = torch.argmax(probs, dim=1) 236 | acc = (pred == target_tensor).sum() / probs.shape[0] 237 | print('{}: acc {} loss {}'.format(kind, float(acc.cpu()), float(losses.mean().cpu()))) 238 | if kind == 'action': 239 | counts = np.zeros(self.total_actions, dtype=float) 240 | matches = np.zeros(self.total_actions, dtype=float) 241 | for p, t in zip(list(pred.cpu()), list(target_tensor.cpu())): 242 | p = int(p) 243 | t = int(t) 244 | if p == t: 245 | matches[p] += 1 246 | counts[t] += 1 247 | accs = matches / counts 248 | for i in range(self.total_actions): 249 | print('{}: {},{} {:.2%}'.format(responses(i).name, matches[i], counts[i], accs[i])) 250 | self.training_data.save_data(self.round_count) 251 | if self.save_model and not initial and self.round_count % self.save_interval == 0: 252 | state = {'model': self.model.state_dict(), 'optimizer': self.optimizer.state_dict(), 'progress': self.round_count} 253 | torch.save(state, args.save_path, _use_new_zipfile_serialization=False) 254 | self.round_count += 1 255 | self.loss = [] 256 | 257 | def step_for_train(self, request=None, response_target=None, fname=None): 258 | if fname: 259 | self.fname = fname 260 | if request is None: 261 | if self.turnID == 0: 262 | inputJSON = json.loads(input()) 263 | request = inputJSON['requests'][0].split(' ') 264 | else: 265 | request = input().split(' ') 266 | else: 267 | request = request.split(' ') 268 | 269 | request = self.build_hand_history(request) 270 | if self.turnID <= 1: 271 | if self.botzone: 272 | print(json.dumps({"response": "PASS"})) 273 | else: 274 | available_action_mask, available_card_mask = self.build_available_action_mask(request) 275 | card_feats = self.build_input(self.hand_free, self.history, self.player_history, 276 | self.player_on_table, self.player_last_play) 277 | extra_feats = np.concatenate((self.player_angang[1:], available_action_mask, 278 | [self.hand_free.sum()], *np.eye(4)[[self.quan, self.myPlayerID]], self.tile_count)) 279 | 280 | def judge_response(available_action_mask): 281 | if available_action_mask.sum() == available_action_mask[responses.PASS.value]: 282 | return False 283 | return True 284 | 285 | if self.train and response_target is not None and judge_response(available_action_mask): 286 | training_data = self.training_data.training 287 | for kind in ['play', 'action', 'chi_gang']: 288 | if len(training_data[kind]['card_feats']) >= self.batch_size: 289 | data = training_data[kind] 290 | probs = self.model(np.array(data['card_feats']), 291 | np.array(data['extra_feats']), self.device, 292 | kind, np.array(data['mask'])) 293 | target_tensor = torch.from_numpy(np.array(data['target'])).to(device=self.device, dtype=torch.int64) 294 | if kind == 'action': 295 | losses = F.cross_entropy(probs, target_tensor, weight=self.action_weight).to(torch.float32) 296 | else: 297 | losses = F.cross_entropy(probs, target_tensor).to(torch.float32) 298 | self.model.train_backward(losses, self.optimizer) 299 | self.training_data.reset(kind) 300 | 301 | response_target = response_target.split(' ') 302 | response_name = response_target[0] 303 | if response_name == 'GANG': 304 | if len(response_target) > 1: 305 | response_name = 'ANGANG' 306 | self.an_gang_card = response_target[-1] 307 | else: 308 | response_name = 'MINGGANG' 309 | if available_action_mask.sum() > 1: 310 | 
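# only record an action-head sample when more than one action is legal;
# a forced move (mask sum == 1) adds no information for the action classifier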
action_target = responses[response_name].value 311 | data = training_data["action"] 312 | data['card_feats'].append(card_feats) 313 | data['mask'].append(available_action_mask) 314 | data['target'].append(action_target) 315 | data['extra_feats'].append(extra_feats) 316 | 317 | available_card_mask = available_card_mask[responses[response_name].value] 318 | if responses[response_name] in [responses.CHI, responses.ANGANG, responses.BUGANG]: 319 | data = training_data["chi_gang"] 320 | data['mask'].append(available_card_mask) 321 | data['card_feats'].append(card_feats) 322 | data['extra_feats'].append(extra_feats) 323 | data['target'].append(self.getCardInd(response_target[1])) 324 | 325 | if responses[response_name] in [responses.PLAY, responses.CHI, responses.PENG]: 326 | if responses[response_name] == responses.PLAY: 327 | play_target = self.getCardInd(response_target[1]) 328 | card_mask = available_card_mask 329 | else: 330 | if responses[response_name] == responses.CHI: 331 | chi_peng_ind = self.getCardInd(response_target[1]) 332 | else: 333 | chi_peng_ind = self.getCardInd(request[-1]) 334 | play_target = self.getCardInd(response_target[-1]) 335 | card_feats, extra_feats, card_mask = self.simulate_chi_peng(request, responses[response_name], chi_peng_ind, True) 336 | data = training_data['play'] 337 | data['card_feats'].append(card_feats) 338 | data['mask'].append(card_mask) 339 | data['extra_feats'].append(extra_feats) 340 | data['target'].append(play_target) 341 | 342 | 343 | self.prev_request = request 344 | self.turnID += 1 345 | 346 | def step(self, request=None, response_target=None, fname=None): 347 | if fname: 348 | self.fname = fname 349 | if request is None: 350 | if self.turnID == 0: 351 | inputJSON = json.loads(input()) 352 | request = inputJSON['requests'][0].split(' ') 353 | else: 354 | request = input().split(' ') 355 | else: 356 | request = request.split(' ') 357 | 358 | request = self.build_hand_history(request) 359 | if self.turnID <= 1: 360 | response = 'PASS' 361 | else: 362 | def make_decision(probs): 363 | vals = probs.data.cpu().numpy()[0] 364 | max = -1000000 365 | decision = 0 366 | for i, val in enumerate(vals): 367 | if val != 0 and val > max: 368 | max = val 369 | decision = i 370 | return decision 371 | 372 | available_action_mask, available_card_mask = self.build_available_action_mask(request) 373 | card_feats = self.build_input(self.hand_free, self.history, self.player_history, 374 | self.player_on_table, self.player_last_play) 375 | extra_feats = np.concatenate((self.player_angang[1:], available_action_mask, 376 | [self.hand_free.sum()], *np.eye(4)[[self.quan, self.myPlayerID]], 377 | self.tile_count)) 378 | action_probs = self.model(np.array([card_feats]), 379 | np.array([extra_feats]), self.device, 380 | 'action', np.array([available_action_mask])) 381 | action = make_decision(action_probs) 382 | cards = [] 383 | if responses(action) in [responses.CHI, responses.ANGANG, responses.BUGANG]: 384 | card_probs = self.model(np.array([card_feats]), 385 | np.array([extra_feats]), self.device, 386 | 'chi_gang', np.array([available_card_mask[action]])) 387 | card_ind = make_decision(card_probs) 388 | cards.append(card_ind) 389 | if responses(action) in [responses.PLAY, responses.CHI, responses.PENG]: 390 | if responses(action) == responses.PLAY: 391 | card_mask = available_card_mask[action] 392 | else: 393 | if responses(action) == responses.CHI: 394 | chi_peng_ind = cards[0] 395 | else: 396 | chi_peng_ind = self.getCardInd(request[-1]) 397 | card_feats, 
extra_feats, card_mask = self.simulate_chi_peng(request, responses(action), 398 | chi_peng_ind, True) 399 | card_probs = self.model(np.array([card_feats]), 400 | np.array([extra_feats]), self.device, 401 | 'play', np.array([card_mask])) 402 | card_ind = make_decision(card_probs) 403 | cards.append(card_ind) 404 | response = self.build_output(responses(action), cards) 405 | if responses(action) == responses.ANGANG: 406 | self.an_gang_card = self.getCardName(cards[0]) 407 | 408 | self.prev_request = request 409 | self.turnID += 1 410 | if self.botzone: 411 | print(json.dumps({"response": response})) 412 | else: 413 | return response 414 | 415 | def build_input(self, my_free, history, play_history, on_table, last_play): 416 | temp = np.array([my_free, 4 - history]) 417 | one_hot_last_play = np.eye(self.total_cards)[last_play] 418 | card_feats = np.concatenate((temp, on_table, play_history, one_hot_last_play)) 419 | return card_feats 420 | 421 | def build_result_summary(self, response, response_target): 422 | if response_target.split(' ')[0] == 'CHI': 423 | print(response, response_target) 424 | resp_name = response.split(' ')[0] 425 | resp_target_name = response_target.split(' ')[0] 426 | if resp_target_name == 'GANG': 427 | if len(response_target.split(' ')) > 1: 428 | resp_target_name = 'ANGANG' 429 | else: 430 | resp_target_name = 'MINGGANG' 431 | if resp_name == 'GANG': 432 | if len(response.split(' ')) > 1: 433 | resp_name = 'ANGANG' 434 | else: 435 | resp_name = 'MINGGANG' 436 | self.count[responses[resp_target_name].value] += 1 437 | if response == response_target: 438 | self.match[responses[resp_name].value] += 1 439 | 440 | def simulate_chi_peng(self, request, response, chi_peng_ind, only_feature=False): 441 | last_card_played = self.getCardInd(request[-1]) 442 | available_card_play_mask = np.zeros(self.total_cards, dtype=int) 443 | my_free, on_table = self.hand_free.copy(), self.player_on_table.copy() 444 | if response == responses.CHI: 445 | my_free[chi_peng_ind - 1:chi_peng_ind + 2] -= 1 446 | my_free[last_card_played] += 1 447 | on_table[0][chi_peng_ind - 1:chi_peng_ind + 2] += 1 448 | is_chi = True 449 | else: 450 | chi_peng_ind = last_card_played 451 | my_free[last_card_played] -= 2 452 | on_table[0][last_card_played] += 3 453 | is_chi = False 454 | self.build_available_card_mask(available_card_play_mask, responses.PLAY, last_card_played, 455 | chi_peng_ind=chi_peng_ind, is_chi=is_chi) 456 | card_feats = self.build_input(my_free, self.history, self.player_history, on_table, self.player_last_play) 457 | if only_feature: 458 | action_mask = np.zeros(self.total_actions, dtype=int) 459 | action_mask[responses.PLAY.value] = 1 460 | extra_feats = np.concatenate((self.player_angang[1:], action_mask, [my_free.sum()], *np.eye(4)[[self.quan, self.myPlayerID]], self.tile_count)) 461 | return card_feats, extra_feats, available_card_play_mask 462 | card_play_probs = self.model(card_feats, self.device, decide_cards=True, card_mask=available_card_play_mask) 463 | return card_play_probs 464 | 465 | def build_available_action_mask(self, request): 466 | available_action_mask = np.zeros(self.total_actions, dtype=int) 467 | available_card_mask = np.zeros((self.total_actions, self.total_cards), dtype=int) 468 | requestID = int(request[0]) 469 | playerID = int(request[1]) 470 | myPlayerID = self.myPlayerID 471 | try: 472 | last_card = request[-1] 473 | last_card_ind = self.getCardInd(last_card) 474 | except: 475 | last_card = '' 476 | last_card_ind = 0 477 | # 摸牌回合 478 | if requests(requestID) == 
requests.drawCard: 479 | for response in [responses.PLAY, responses.ANGANG, responses.BUGANG]: 480 | if self.tile_count[self.myPlayerID] == 0 and response in [responses.ANGANG, responses.BUGANG]: 481 | continue 482 | self.build_available_card_mask(available_card_mask[response.value], response, last_card_ind) 483 | if available_card_mask[response.value].sum() > 0: 484 | available_action_mask[response.value] = 1 485 | # 杠上开花 486 | if requests(int(self.prev_request[0])) in [requests.ANGANG, requests.BUGANG]: 487 | isHu = self.judgeHu(last_card, playerID, True) 488 | # 这里胡的最后一张牌其实不一定是last_card,因为可能是吃了上家胡,需要知道上家到底打的是哪张 489 | else: 490 | isHu = self.judgeHu(last_card, playerID, False) 491 | if isHu >= 8: 492 | available_action_mask[responses.HU.value] = 1 493 | self.fan_count = isHu 494 | else: 495 | available_action_mask[responses.PASS.value] = 1 496 | # 别人出牌 497 | if requests(requestID) in [requests.PENG, requests.CHI, requests.PLAY]: 498 | if playerID != myPlayerID: 499 | for response in [responses.PENG, responses.MINGGANG, responses.CHI]: 500 | # 不是上家 501 | if response == responses.CHI and (self.myPlayerID - playerID) % 4 != 1: 502 | continue 503 | # 最后一张,不能吃碰杠 504 | if self.tile_count[(playerID + 1) % 4] == 0: 505 | continue 506 | self.build_available_card_mask(available_card_mask[response.value], response, last_card_ind) 507 | if available_card_mask[response.value].sum() > 0: 508 | available_action_mask[response.value] = 1 509 | # 是你必须现在决定要不要抢胡 510 | isHu = self.judgeHu(last_card, playerID, False, dianPao=True) 511 | if isHu >= 8: 512 | available_action_mask[responses.HU.value] = 1 513 | self.fan_count = isHu 514 | # 抢杠胡 515 | if requests(requestID) == requests.BUGANG and playerID != myPlayerID: 516 | isHu = self.judgeHu(last_card, playerID, True, dianPao=True) 517 | if isHu >= 8: 518 | available_action_mask[responses.HU.value] = 1 519 | self.fan_count = isHu 520 | return available_action_mask, available_card_mask 521 | 522 | def build_available_card_mask(self, available_card_mask, response, last_card_ind, chi_peng_ind=None, is_chi=False): 523 | if response == responses.PLAY: 524 | # 正常出牌 525 | if chi_peng_ind is None: 526 | for i, card_num in enumerate(self.hand_free): 527 | if card_num > 0: 528 | available_card_mask[i] = 1 529 | else: 530 | # 吃了再出 531 | if is_chi: 532 | for i, card_num in enumerate(self.hand_free): 533 | if i in [chi_peng_ind - 1, chi_peng_ind, chi_peng_ind + 1] and i != last_card_ind: 534 | if card_num > 1: 535 | available_card_mask[i] = 1 536 | elif card_num > 0: 537 | available_card_mask[i] = 1 538 | else: 539 | for i, card_num in enumerate(self.hand_free): 540 | if i == chi_peng_ind: 541 | if card_num > 2: 542 | available_card_mask[i] = 1 543 | elif card_num > 0: 544 | available_card_mask[i] = 1 545 | elif response == responses.PENG: 546 | if self.hand_free[last_card_ind] >= 2: 547 | available_card_mask[last_card_ind] = 1 548 | elif response == responses.CHI: 549 | # 数字牌才可以吃 550 | if last_card_ind < cards.F.value: 551 | card_name = self.getCardName(last_card_ind) 552 | card_number = int(card_name[1]) 553 | for i in [-1, 0, 1]: 554 | middle_card = card_number + i 555 | if middle_card >= 2 and middle_card <= 8: 556 | can_chi = True 557 | for card in range(last_card_ind + i - 1, last_card_ind + i + 2): 558 | if card != last_card_ind and self.hand_free[card] == 0: 559 | can_chi = False 560 | if can_chi: 561 | available_card_mask[last_card_ind + i] = 1 562 | elif response == responses.ANGANG: 563 | for card in range(len(self.hand_free)): 564 | if self.hand_free[card] == 
4: 565 | available_card_mask[card] = 1 566 | elif response == responses.MINGGANG: 567 | if self.hand_free[last_card_ind] == 3: 568 | available_card_mask[last_card_ind] = 1 569 | elif response == responses.BUGANG: 570 | for card in range(len(self.hand_free)): 571 | if self.hand_fixed[card] == 3 and self.hand_free[card] == 1: 572 | for card_combo in self.hand_fixed_data: 573 | if card_combo[1] == self.getCardName(card) and card_combo[0] == 'PENG': 574 | available_card_mask[card] = 1 575 | else: 576 | available_card_mask[last_card_ind] = 1 577 | return available_card_mask 578 | 579 | def judgeHu(self, last_card, playerID, isGANG, dianPao=False): 580 | hand = [] 581 | for ind, cardcnt in enumerate(self.hand_free): 582 | for _ in range(cardcnt): 583 | hand.append(self.getCardName(ind)) 584 | if self.history[self.getCardInd(last_card)] == 4: 585 | isJUEZHANG = True 586 | else: 587 | isJUEZHANG = False 588 | if self.tile_count[(playerID + 1) % 4] == 0: 589 | isLAST = True 590 | else: 591 | isLAST = False 592 | if not dianPao: 593 | hand.remove(last_card) 594 | try: 595 | ans = MahjongFanCalculator(tuple(self.hand_fixed_data), tuple(hand), last_card, 0, playerID==self.myPlayerID, 596 | isJUEZHANG, isGANG, isLAST, self.myPlayerID, self.quan) 597 | except Exception as err: 598 | if str(err) == 'ERROR_NOT_WIN': 599 | return 0 600 | else: 601 | if not self.botzone: 602 | print(hand, last_card, self.hand_fixed_data) 603 | print(self.fname) 604 | print(err) 605 | return 0 606 | else: 607 | fan_count = 0 608 | for fan in ans: 609 | fan_count += fan[0] 610 | return fan_count 611 | 612 | def build_hand_history(self, request): 613 | # 第0轮,确定位置 614 | if self.turnID == 0: 615 | _, myPlayerID, quan = request 616 | self.myPlayerID = int(myPlayerID) 617 | self.other_players_id = [(self.myPlayerID - i) % 4 for i in range(4)] 618 | self.player_positions = {} 619 | for position, id in enumerate(self.other_players_id): 620 | self.player_positions[id] = position 621 | self.quan = int(quan) 622 | return request 623 | # 第一轮,发牌 624 | if self.turnID == 1: 625 | for i in range(5, 18): 626 | cardInd = self.getCardInd(request[i]) 627 | self.hand_free[cardInd] += 1 628 | self.history[cardInd] += 1 629 | return request 630 | if int(request[0]) == 3: 631 | request[0] = str(requests[request[2]].value) 632 | elif int(request[0]) == 2: 633 | request.insert(1, self.myPlayerID) 634 | request = self.maintain_status(request, self.hand_free, self.history, self.player_history, 635 | self.player_on_table, self.player_last_play, self.player_angang) 636 | return request 637 | 638 | def maintain_status(self, request, my_free, history, play_history, on_table, last_play, angang): 639 | requestID = int(request[0]) 640 | playerID = int(request[1]) 641 | player_position = self.player_positions[playerID] 642 | if requests(requestID) in [requests.drawCard, requests.DRAW]: 643 | self.tile_count[playerID] -= 1 644 | if requests(requestID) == requests.drawCard: 645 | my_free[self.getCardInd(request[-1])] += 1 646 | history[self.getCardInd(request[-1])] += 1 647 | elif requests(requestID) == requests.PLAY: 648 | play_card = self.getCardInd(request[-1]) 649 | play_history[player_position][play_card] += 1 650 | last_play[player_position] = play_card 651 | # 自己 652 | if player_position == 0: 653 | my_free[play_card] -= 1 654 | else: 655 | history[play_card] += 1 656 | elif requests(requestID) == requests.PENG: 657 | # 上一步一定有play 658 | last_card_ind = self.getCardInd(self.prev_request[-1]) 659 | play_card_ind = self.getCardInd(request[-1]) 660 | 
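# a PENG exposes three copies of the claimed tile on that player's side;
# request[-1] is the tile the claiming player discards right afterwards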
on_table[player_position][last_card_ind] = 3 661 | play_history[player_position][play_card_ind] += 1 662 | last_play[player_position] = play_card_ind 663 | if player_position != 0: 664 | history[last_card_ind] += 2 665 | history[play_card_ind] += 1 666 | else: 667 | # 记录peng来源于哪个玩家 668 | last_player = int(self.prev_request[1]) 669 | last_player_pos = self.player_positions[last_player] 670 | self.hand_fixed_data.append(('PENG', self.prev_request[-1], last_player_pos)) 671 | my_free[last_card_ind] -= 2 672 | my_free[play_card_ind] -= 1 673 | elif requests(requestID) == requests.CHI: 674 | # 上一步一定有play 675 | last_card_ind = self.getCardInd(self.prev_request[-1]) 676 | middle_card, play_card = request[3:5] 677 | middle_card_ind = self.getCardInd(middle_card) 678 | play_card_ind = self.getCardInd(play_card) 679 | on_table[player_position][middle_card_ind-1:middle_card_ind+2] += 1 680 | if player_position != 0: 681 | history[middle_card_ind-1:middle_card_ind+2] += 1 682 | history[last_card_ind] -= 1 683 | history[play_card_ind] += 1 684 | else: 685 | # CHI,中间牌名,123代表上家的牌是第几张 686 | self.hand_fixed_data.append(('CHI', middle_card, last_card_ind - middle_card_ind + 2)) 687 | my_free[middle_card_ind-1:middle_card_ind+2] -= 1 688 | my_free[last_card_ind] += 1 689 | my_free[play_card_ind] -= 1 690 | elif requests(requestID) == requests.GANG: 691 | # 暗杠 692 | if requests(int(self.prev_request[0])) in [requests.drawCard, requests.DRAW]: 693 | request[2] = requests.ANGANG.name 694 | if player_position == 0: 695 | gangCard = self.an_gang_card 696 | # print(gangCard) 697 | if gangCard == '' and not self.botzone: 698 | print(self.prev_request) 699 | print(request) 700 | print(self.fname) 701 | gangCardInd = self.getCardInd(gangCard) 702 | # 记录gang来源于哪个玩家(可能来自自己,暗杠) 703 | self.hand_fixed_data.append(('GANG', gangCard, 0)) 704 | on_table[0][gangCardInd] = 4 705 | my_free[gangCardInd] = 0 706 | else: 707 | angang[player_position] += 1 708 | else: 709 | # 明杠 710 | gangCardInd = self.getCardInd(self.prev_request[-1]) 711 | request[2] = requests.MINGGANG.name 712 | history[gangCardInd] = 4 713 | on_table[player_position][gangCardInd] = 4 714 | if player_position == 0: 715 | # 记录gang来源于哪个玩家 716 | last_player = int(self.prev_request[1]) 717 | self.hand_fixed_data.append( 718 | ('GANG', self.prev_request[-1], self.player_positions[last_player])) 719 | my_free[gangCardInd] = 0 720 | elif requests(requestID) == requests.BUGANG: 721 | bugang_card_ind = self.getCardInd(request[-1]) 722 | history[bugang_card_ind] = 4 723 | on_table[player_position][bugang_card_ind] = 4 724 | if player_position == 0: 725 | for id, comb in enumerate(self.hand_fixed_data): 726 | if comb[1] == request[-1]: 727 | self.hand_fixed_data[id] = ('GANG', comb[1], comb[2]) 728 | break 729 | my_free[bugang_card_ind] = 0 730 | return request 731 | 732 | def build_output(self, response, cards_ind): 733 | if (responses.need_cards.value[response.value] == 1 and response != responses.CHI) or response == responses.PENG: 734 | response_name = response.name 735 | if response == responses.ANGANG: 736 | response_name = 'GANG' 737 | return '{} {}'.format(response_name, self.getCardName(cards_ind[0])) 738 | if response == responses.CHI: 739 | return 'CHI {} {}'.format(self.getCardName(cards_ind[0]), self.getCardName(cards_ind[1])) 740 | response_name = response.name 741 | if response == responses.MINGGANG: 742 | response_name = 'GANG' 743 | return response_name 744 | 745 | 746 | def getCardInd(self, cardName): 747 | if cardName[0] == 'H': 748 | print('hua ' + 
self.fname) 749 | return cards[cardName[0]].value + int(cardName[1]) - 1 750 | 751 | def getCardName(self, cardInd): 752 | num = 1 753 | while True: 754 | if cardInd in cards._value2member_map_: 755 | break 756 | num += 1 757 | cardInd -= 1 758 | return cards(cardInd).name + str(num) 759 | 760 | def train_main(): 761 | train = args.train 762 | load = args.load 763 | save = args.save 764 | model_path = args.load_path 765 | batch_size = args.batch_size 766 | 767 | 768 | my_bot = MahjongHandler(train=train, model_path=model_path, load_model=load, save_model=save, batch_size=batch_size) 769 | count = 0 770 | restore_count = my_bot.round_count 771 | training_data_files = os.listdir(args.training_data) 772 | while True: 773 | for fname in training_data_files: 774 | if fname[-1] == 'y': 775 | continue 776 | with open('{}/{}'.format(args.training_data, fname), 'r') as f: 777 | rounds_data = json.load(f) 778 | random.shuffle(rounds_data) 779 | for round_data in rounds_data: 780 | for j in range(4): 781 | count += 1 782 | if count < restore_count: 783 | continue 784 | if count % 500 == 0: 785 | print(count) 786 | train_requests = round_data["requests"][j] 787 | first_request = '0 {} {}'.format(j, 0) 788 | train_requests.insert(0, first_request) 789 | train_responses = ['PASS'] + round_data["responses"][j] 790 | for _request, _response in zip(train_requests, train_responses): 791 | my_bot.step_for_train(_request, _response, round_data['fname']) 792 | my_bot.reset() 793 | count = 0 794 | my_bot.round_count = 1 795 | restore_count = 0 796 | 797 | def run_main(): 798 | my_bot = MahjongHandler(train=False, model_path='data/naive_model', load_model=True, save_model=False, botzone=True) 799 | while True: 800 | my_bot.step() 801 | print('>>>BOTZONE_REQUEST_KEEP_RUNNING<<<') 802 | sys.stdout.flush() 803 | 804 | # for competing in env 805 | class agent: 806 | def __init__(self, model_path): 807 | self.my_bot = MahjongHandler(train=False, model_path=model_path, load_model=True, save_model=False, botzone=False) 808 | self.reset() 809 | 810 | def reset(self): 811 | self.turnID = 0 812 | self.my_bot.reset(False) 813 | 814 | def step(self, request): 815 | return self.my_bot.step(request=request) 816 | 817 | if __name__ == '__main__': 818 | if args.botzone: 819 | run_main() 820 | else: 821 | train_main() 822 | -------------------------------------------------------------------------------- /PolicyGradient.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # encoding: utf-8 3 | ''' 4 | @file: PolicyGradient.py 5 | @time: ????
6 | Reinforcement learning model 7 | ''' 8 | 9 | import json 10 | from MahjongGB import MahjongFanCalculator 11 | import torch 12 | import torch.optim as optim 13 | from enum import Enum 14 | import numpy as np 15 | import torch.nn as nn 16 | import torch.nn.functional as F 17 | import sys 18 | import os 19 | import math 20 | from copy import deepcopy 21 | from torch.distributions import Categorical 22 | import torch.multiprocessing as mp 23 | from multiprocessing.managers import BaseManager 24 | import argparse 25 | import time 26 | 27 | curPath = os.path.abspath(os.path.dirname(__file__)) 28 | rootPath = os.path.split(curPath)[0] 29 | sys.path.append(rootPath) 30 | 31 | parser = argparse.ArgumentParser(description='Policy-Gradient Model') 32 | parser.add_argument('-lr', '--learning_rate', type=float, default=1e-5, help='learning rate') 33 | parser.add_argument('-s', '--save', action='store_true', default=False, help='whether to store model') 34 | parser.add_argument('-o', '--old_path', type=str, default='../models/super_model_2', help='path to old model as components') 35 | parser.add_argument('-n', '--new_path', type=str, default='../models/super_model_2', help='path to load training model') 36 | parser.add_argument('-S', '--save_path', type=str, default='../models/rl_pg_2', help='path to save model') 37 | parser.add_argument('-p', '--num_process_per_gpu', type=int, default=1, help='number of processes to run per gpu') 38 | # pi and si are counted in episodes; with rn=10 and rt=10, one episode is 10*10 = 100 games 39 | parser.add_argument('-pi', '--print_interval', type=int, default=10, help='how often to print') 40 | parser.add_argument('-si', '--save_interval', type=int, default=20, help='how often to save') 41 | parser.add_argument('-ti', '--train_interval', type=int, default=2, help='how often to train backward') 42 | parser.add_argument('-ji', '--join_interval', type=int, default=2, help='how often to update shared model') 43 | parser.add_argument('--eps_clip', type=float, default=0.3, help='epsilon for p/q clipped') 44 | parser.add_argument('--entropy_weight', type=float, default=1e-3, help='initial entropy loss weight') 45 | parser.add_argument('--entropy_target', type=float, default=1e-4, help='targeted entropy value') 46 | parser.add_argument('--entropy_step', type=float, default=0.01, help='entropy change per step') 47 | parser.add_argument('-rn', '--round_number', type=int, default=10, help='round number to run in parallel') 48 | parser.add_argument('-rt', '--repeated_times', type=int, default=10, help='the repeated times for one round') 49 | parser.add_argument('-e', '--epochs', type=int, default=1, help='training epochs for stored data') 50 | parser.add_argument('-lp', '--log_path', type=str, default='../logs/log.txt', help='log path, set to "none" for no logging') 51 | parser.add_argument('-bs', '--batch_size', type=int, default=1000, help='max batch size to update model') 52 | args = parser.parse_args() 53 | 54 | class requests(Enum): 55 | initialHand = 1 56 | drawCard = 2 57 | DRAW = 4 58 | PLAY = 5 59 | PENG = 6 60 | CHI = 7 61 | GANG = 8 62 | BUGANG = 9 63 | MINGGANG = 10 64 | ANGANG = 11 65 | 66 | class responses(Enum): 67 | PASS = 0 68 | PLAY = 1 69 | HU = 2 70 | # exposed (MING) and concealed (AN) gangs must be distinguished 71 | MINGGANG = 3 72 | ANGANG = 4 73 | BUGANG = 5 74 | PENG = 6 75 | CHI = 7 76 | need_cards = [0, 1, 0, 0, 1, 1, 0, 1] 77 | loss_weight = [1, 1, 5, 2, 2, 2, 2, 2] 78 | 79 | class cards(Enum): 80 | # suits: B = circles (bing), W = characters (wan), T = bamboo (tiao) 81 | B = 0 82 | W = 9 83 | T = 18 84 | # winds 85 | F = 27 86 | # dragons 87 | J = 31 88 | 89 | # Memory for a certain kind of action 90 | class
memory_for_kind: 91 | def __init__(self): 92 | self.states = {'card_feats': [], 93 | 'extra_feats': [], 94 | 'masks': []} 95 | self.actions = [] 96 | self.rewards = [] 97 | self.logprobs = [] 98 | 99 | def add_memory(self, card_feats=None, extra_feats=None, mask=None, action=None, logprob=None, reward=None): 100 | if card_feats is not None: 101 | self.states['card_feats'].append(card_feats) 102 | self.states['extra_feats'].append(extra_feats) 103 | self.states['masks'].append(mask) 104 | if action is not None: 105 | self.actions.append(action) 106 | if logprob is not None: 107 | self.logprobs.append(logprob) 108 | if reward is not None: 109 | turn_length = len(self.actions)-len(self.rewards) 110 | if turn_length > 0: 111 | self.rewards.extend([reward/turn_length]*turn_length) 112 | 113 | def merge_memory(self, memory): 114 | for key, val in self.states.items(): 115 | val.extend(memory.states[key]) 116 | self.rewards.extend(memory.rewards) 117 | self.actions.extend(memory.actions) 118 | self.logprobs.extend(memory.logprobs) 119 | 120 | def clear_memory(self): 121 | for value in self.states.values(): 122 | del value[:] 123 | del self.actions[:] 124 | del self.rewards[:] 125 | del self.logprobs[:] 126 | 127 | 128 | class Memory: # collected from old policy 129 | def __init__(self): 130 | self.memory = { 131 | 'action': memory_for_kind(), 132 | 'chi_gang': memory_for_kind(), 133 | 'play': memory_for_kind() 134 | } 135 | 136 | def get_memory(self): 137 | return self.memory 138 | 139 | def add_memory(self, kind=None, card_feats=None, extra_feats=None, mask=None, action=None, logprob=None, reward=None): 140 | if kind is None: 141 | for memory in self.memory.values(): 142 | memory.add_memory(card_feats, extra_feats, mask, action, logprob, reward) 143 | else: 144 | self.memory[kind].add_memory(card_feats, extra_feats, mask, action, logprob, reward) 145 | 146 | def merge_memory(self, round_memory): 147 | for kind, memory in round_memory.memory.items(): 148 | self.memory[kind].merge_memory(memory) 149 | 150 | def clear_memory(self): 151 | for memory in self.memory.values(): 152 | memory.clear_memory() 153 | 154 | # Model 155 | class Policy(nn.Module): 156 | def __init__(self, card_feat_depth=14, num_extra_feats=24, num_cards=34, num_actions=8): 157 | super().__init__() 158 | self.entropy_weight = args.entropy_weight 159 | hidden_channels = [8, 16, 32] 160 | hidden_layers_size = [512, 1024] 161 | linear_length = hidden_channels[1] * num_cards * card_feat_depth 162 | self.linear_length = linear_length + num_extra_feats 163 | # self.number_card_net = nn.Sequential( 164 | # nn.Conv2d(3, hidden_channels[0], 3, stride=1, padding=1), 165 | # nn.ReLU(), 166 | # ) 167 | self.card_net = nn.Sequential( 168 | nn.Conv2d(1, hidden_channels[0], 3, stride=1, padding=1), 169 | nn.ReLU(), 170 | nn.Conv2d(hidden_channels[0], hidden_channels[1], 5, stride=1, padding=2), 171 | ) 172 | self.card_play_decision_net = nn.Sequential( 173 | nn.Linear(num_extra_feats + linear_length, hidden_layers_size[0]), 174 | nn.Sigmoid(), 175 | nn.Linear(hidden_layers_size[0], hidden_layers_size[1]), 176 | nn.ReLU(), 177 | nn.Linear(hidden_layers_size[1], num_cards), 178 | nn.Softmax(dim=1) 179 | ) 180 | self.chi_peng_decision_net = nn.Sequential( 181 | nn.Linear(num_extra_feats + linear_length, hidden_layers_size[0]), 182 | nn.ReLU(), 183 | nn.Linear(hidden_layers_size[0], num_cards), 184 | nn.Softmax(dim=1) 185 | ) 186 | self.action_decision_net = nn.Sequential( 187 | nn.Linear(num_extra_feats + linear_length, hidden_layers_size[1]), 188 
| nn.ReLU(), 189 | nn.Linear(hidden_layers_size[1], num_actions), 190 | nn.Softmax(dim=1) 191 | ) 192 | 193 | def forward(self, card_feats, extra_feats, device, decide_which, mask): 194 | assert decide_which in ['play', 'chi_gang', 'action'] 195 | card_feats = torch.from_numpy(card_feats).to(device).unsqueeze(1).to(torch.float32) 196 | card_layer = self.card_net(card_feats) 197 | batch_size = card_layer.shape[0] 198 | extra_feats_tensor = torch.from_numpy(extra_feats).to(torch.float32).to(device) 199 | linear_layer = torch.cat((card_layer.view(batch_size, -1), extra_feats_tensor), dim=1) 200 | mask_tensor = torch.from_numpy(mask).to(torch.float32).to(device) 201 | if decide_which == 'play': 202 | card_probs = self.card_play_decision_net(linear_layer) 203 | valid_card_play = self.mask_unavailable_actions(card_probs, mask_tensor) 204 | return valid_card_play 205 | elif decide_which == 'action': 206 | # print(linear_layer.shape) 207 | action_probs = self.action_decision_net(linear_layer) 208 | valid_actions = self.mask_unavailable_actions(action_probs, mask_tensor) 209 | return valid_actions 210 | else: 211 | card_probs = self.chi_peng_decision_net(linear_layer) 212 | valid_card_play = self.mask_unavailable_actions(card_probs, mask_tensor) 213 | return valid_card_play 214 | 215 | def mask_unavailable_actions(self, result, valid_actions_tensor): 216 | replace_nan = torch.isnan(result) 217 | result_no_nan = result.masked_fill(mask=replace_nan, value=1e-9) 218 | masked_actions = result_no_nan * valid_actions_tensor 219 | return masked_actions 220 | 221 | 222 | class MahjongEnv: 223 | def __init__(self, old_model_path=None, shared_model=None, gpu_id=0, shared_memory=None, lock=None): 224 | use_cuda = torch.cuda.is_available() 225 | self.repeated_times = args.repeated_times 226 | self.round_number = args.round_number * self.repeated_times 227 | self.shared_memory = shared_memory 228 | self.local_memory = Memory() 229 | self.memory_this_rounds = [Memory() for _ in range(4 * self.round_number)] 230 | self.device = torch.device("cuda:%d" % gpu_id if use_cuda else "cpu") 231 | print('using ' + str(self.device)) 232 | self.total_cards = 34 233 | self.total_actions = len(responses) - 2 234 | self.model = Policy().to(self.device) 235 | self.old_model = Policy().to(self.device) 236 | self.shared_model = shared_model 237 | self.old_model_code = 0 238 | self.new_model_code = 1 239 | self.models = [self.old_model, self.model] 240 | # 加载预训练好的actor网络 241 | old_check_point = torch.load(old_model_path, map_location=self.device) 242 | old_dict = self.model.state_dict() 243 | for k in old_check_point['model'].keys(): 244 | old_dict[k] = old_check_point['model'][k] 245 | self.old_model.load_state_dict(old_dict) 246 | self.old_model.eval() 247 | self.model.eval() 248 | self.lock = lock 249 | self.gpu_id = gpu_id 250 | self.round_count = 0 251 | self.model.load_state_dict(shared_model.state_dict()) 252 | # state = {'model': self.model.state_dict(), 'optimizer': self.optimizer.state_dict()} 253 | # torch.save(state, self.model_path, _use_new_zipfile_serialization=False) 254 | self.bots = [] 255 | self.total_cards = 34 256 | self.total_actions = len(responses) - 2 257 | self.train_count = 0 258 | self.winners = np.zeros(4, dtype=int) 259 | self.win_steps = [] 260 | self.losses = [] 261 | self.scores = np.zeros((self.round_number, 4), dtype=float) 262 | # 0, 2为老model 263 | for _ in range(self.round_number): 264 | bots = [] 265 | for i in range(4): 266 | if i % 2 == self.old_model_code: 267 | 
bots.append(MahjongHandler(self.old_model, self.device)) 268 | else: 269 | bots.append(MahjongHandler(self.model, self.device)) 270 | self.bots.append(bots) 271 | self.reset(True) 272 | 273 | def reset(self, initial=False): 274 | self.tile_walls = [] 275 | self.quans = [] 276 | self.mens = [] 277 | self.bots_orders = [] 278 | # 初始摸牌bot为0 279 | self.drawers = [] 280 | all_tiles = np.arange(self.total_cards) 281 | # 构建牌墙 282 | all_tiles = all_tiles.repeat(4) 283 | for round_id in range(self.round_number): 284 | if round_id % self.repeated_times == 0: 285 | np.random.shuffle(all_tiles) 286 | quan = np.random.choice(4) 287 | men = np.random.choice(4) 288 | # 用pop,从后面摸牌 289 | self.tile_walls.append(np.reshape(all_tiles, (4, -1)).tolist()) 290 | self.quans.append(quan) 291 | self.mens.append(men) 292 | # 这一局bots的order,牌墙永远下标和bot一致 293 | self.bots_orders.append([self.bots[round_id][(i + self.mens[-1]) % 4] for i in range(4)]) 294 | self.drawers.append(0) 295 | 296 | if not initial: 297 | with self.lock: 298 | # 每局结束将memory汇总 299 | self.shared_memory.merge_memory(self.local_memory) 300 | self.local_memory.clear_memory() 301 | self.round_count += self.round_number 302 | for bots in self.bots_orders: 303 | for bot in bots: 304 | bot.reset() 305 | self.scores = np.zeros((self.round_number, 4), dtype=float) 306 | 307 | def run_round(self): 308 | turnID = 0 309 | player_responses = [['PASS'] * 4 for _ in range(self.round_number)] 310 | finished = np.zeros(self.round_number, dtype=int) 311 | state_dimension = 5 312 | while finished.sum() < self.round_number: 313 | data = { 314 | # card_feats, extra_feats, action_mask, card_mask, bot 315 | self.old_model_code: [[], [], [], [], []], 316 | self.new_model_code: [[], [], [], [], []], 317 | # which bot the data came from 318 | "order": [[], []] 319 | } 320 | lengths = { 321 | self.old_model_code: 0, 322 | self.new_model_code: 0 323 | } 324 | for round_id in range(self.round_number): 325 | if finished[round_id]: 326 | continue 327 | if turnID == 0: 328 | for id, player in enumerate(self.bots_orders[round_id]): 329 | player.step('0 %d %d' % (id, self.quans[round_id])) 330 | elif turnID == 1: 331 | for id, player in enumerate(self.bots_orders[round_id]): 332 | request = ['1'] 333 | for i in range(4): 334 | request.append('0') 335 | for i in range(13): 336 | request.append(self.getCardName(self.tile_walls[round_id][id].pop())) 337 | request = ' '.join(request) 338 | player.step(request) 339 | else: 340 | requests = self.parse_response(player_responses[round_id], round_id) 341 | if requests[0] in ['hu', 'huangzhuang']: 342 | outcome = requests[0] 343 | if outcome == 'hu': 344 | winner_id = int(requests[1]) 345 | self.winners[self.getBotInd(round_id, winner_id)] += 1 346 | self.win_steps.append(turnID) 347 | fan_count = int(requests[2]) 348 | dianpaoer = requests[3] 349 | if dianpaoer == 'None': 350 | dianpaoer = None 351 | if dianpaoer is not None: 352 | dianpaoer = int(dianpaoer) 353 | self.calculate_scores(round_id, winner_id, dianpaoer, fan_count, mode='naive') 354 | finished[round_id] = 1 355 | else: 356 | for i in range(4): 357 | this_bot = self.bots_orders[round_id][i] 358 | this_model = this_bot.model 359 | model_code = self.models.index(this_model) 360 | state = self.bots_orders[round_id][i].step(requests[i]) 361 | for j in range(state_dimension): 362 | data[model_code][j].append(state[j]) 363 | data["order"][model_code].append((round_id, i)) 364 | lengths[model_code] += 1 365 | 366 | # all bots stepped, if turnID > 1, should produce responses through 
model 367 | if turnID > 1: 368 | for key in [self.old_model_code, self.new_model_code]: 369 | actions, cards = self.step(data[key], data["order"][key], self.models[key]) 370 | action_num = len(actions) 371 | for i in range(action_num): 372 | round_id, bot_id = data["order"][key][i] 373 | response = self.bots_orders[round_id][bot_id].build_response(actions[i], cards[i]) 374 | player_responses[round_id][bot_id] = response 375 | turnID += 1 376 | 377 | # end while, all games finished, merge all bots' memory to local memory 378 | self.regularize_scores() 379 | for i in range(self.round_number): 380 | for j in range(4): 381 | if self.scores[i][j] != 0 and self.getBotInd(i, j) % 2 == self.new_model_code: 382 | self.memory_this_rounds[i*4+j].add_memory(reward=self.scores[i][j]) 383 | self.local_memory.merge_memory(self.memory_this_rounds[i*4+j]) 384 | self.memory_this_rounds[i*4+j].clear_memory() 385 | 386 | def step(self, state, order, model): 387 | number_of_data = len(state[0]) 388 | actions = np.zeros(number_of_data, dtype=int) 389 | cards = [[] for _ in range(number_of_data)] 390 | training_data = { 391 | 'action': { 392 | 'card_feats': [], 393 | 'extra_feats': [], 394 | 'masks': [], 395 | 'mapping': [] 396 | }, 397 | 'play': { 398 | 'card_feats': [], 399 | 'extra_feats': [], 400 | 'masks': [], 401 | 'mapping': [] 402 | }, 403 | 'chi_gang': { 404 | 'card_feats': [], 405 | 'extra_feats': [], 406 | 'masks': [], 407 | 'mapping': [] 408 | } 409 | } 410 | 411 | # collect all bots' states 412 | def collect_data(kind, card_feats, extra_feats, mask, idx): 413 | if mask.sum() > 1: 414 | training_data[kind]['card_feats'].append(card_feats) 415 | training_data[kind]['masks'].append(mask) 416 | training_data[kind]['extra_feats'].append(extra_feats) 417 | training_data[kind]['mapping'].append(idx) 418 | else: 419 | action = np.argmax(mask) 420 | if kind == 'action': 421 | actions[idx] = action 422 | else: 423 | cards[idx].append(action) 424 | 425 | # run forward with data together 426 | def run_forward(kind): 427 | if len(training_data[kind]['card_feats']) == 0: 428 | return 429 | probs = model( 430 | np.array(training_data[kind]['card_feats']), 431 | np.array(training_data[kind]['extra_feats']), 432 | self.device, 433 | kind, 434 | np.array(training_data[kind]['masks'])) 435 | probs_dist = Categorical(probs) 436 | if model != self.old_model: 437 | action_tensor = probs_dist.sample() 438 | else: 439 | action_tensor = torch.argmax(probs, dim=1) 440 | log_probs = probs_dist.log_prob(action_tensor).detach().cpu().numpy() 441 | action_numpy = action_tensor.cpu().numpy() 442 | action_idx = training_data[kind]['mapping'] 443 | for idx, action in zip(action_idx, action_numpy): 444 | if kind == 'action': 445 | actions[idx] = int(action) 446 | else: 447 | cards[idx].append(int(action)) 448 | if model != self.old_model: 449 | for idx, action, log_prob, card_feat, extra_feat, mask in zip(action_idx, action_numpy, log_probs, 450 | training_data[kind]['card_feats'], training_data[kind]['extra_feats'], training_data[kind]['masks']): 451 | if log_prob >= -0.1: 452 | continue 453 | original_order = order[idx][0] * 4 + order[idx][1] 454 | self.memory_this_rounds[original_order].add_memory(kind, card_feat, extra_feat, mask, action, log_prob) 455 | kind = 'action' 456 | for i, card_feats, extra_feats, available_action_mask, _, _ in zip(np.arange(number_of_data), *state): 457 | collect_data(kind, card_feats, extra_feats, available_action_mask, i) 458 | 459 | run_forward(kind) 460 | 461 | kind = 'chi_gang' 462 | for i, 
card_feats, extra_feats, _, available_card_mask, _ in zip(np.arange(number_of_data), *state): 463 | action = actions[i] 464 | if responses(action) in [responses.CHI, responses.ANGANG, responses.BUGANG]: 465 | card_mask = available_card_mask[action] 466 | collect_data(kind, card_feats, extra_feats, card_mask, i) 467 | 468 | run_forward(kind) 469 | 470 | kind = 'play' 471 | 472 | for i, card_feats, extra_feats, _, available_card_mask, bot in zip(np.arange(number_of_data), *state): 473 | action = actions[i] 474 | if responses(action) in [responses.PLAY, responses.CHI, responses.PENG]: 475 | if responses(action) == responses.PLAY: 476 | card_mask = available_card_mask[action] 477 | else: 478 | request = bot.prev_request 479 | if responses(action) == responses.CHI: 480 | chi_peng_ind = cards[i][0] 481 | else: 482 | chi_peng_ind = self.getCardInd(request[-1]) 483 | card_feats, extra_feats, card_mask = bot.simulate_chi_peng(request, responses(action), chi_peng_ind) 484 | collect_data(kind, card_feats, extra_feats, card_mask, i) 485 | 486 | run_forward(kind) 487 | 488 | return actions, cards 489 | 490 | 491 | # 不和牌,分数都是0,不会调用这个函数 492 | def calculate_scores(self, round_id, winner_id=0, dianpaoer=None, fan_count=0, difen=8, mode='naive'): 493 | assert mode in ['naive', 'botzone'] 494 | if mode == 'botzone': 495 | for i in range(4): 496 | if i == winner_id: 497 | self.scores[round_id][i] = 10 498 | if dianpaoer is None: 499 | # 自摸 500 | self.scores[round_id][i] = 3 * (difen + fan_count) 501 | else: 502 | self.scores[round_id][i] = 3 * difen + fan_count 503 | else: 504 | if dianpaoer is None: 505 | self.scores[round_id][i] = -0.5 * (difen + fan_count) 506 | else: 507 | if i == dianpaoer: 508 | self.scores[round_id][i] = -2 * (difen + fan_count) 509 | else: 510 | self.scores[round_id][i] = -0.5 * difen 511 | else: 512 | for i in range(4): 513 | if i == winner_id: 514 | self.scores[round_id][i] = fan_count 515 | elif i == dianpaoer: 516 | self.scores[round_id][i] = -0.5 * fan_count 517 | else: 518 | self.scores[round_id][i] = 0 519 | 520 | # 将得分减去平均得分 521 | def regularize_scores(self): 522 | # print(self.scores) 523 | for i in range(self.round_number // self.repeated_times): 524 | start_round = i * self.repeated_times 525 | end_round = (i + 1) * self.repeated_times 526 | self.scores[start_round:end_round] -= np.mean(self.scores[start_round:end_round], axis=0) 527 | # print(self.scores) 528 | 529 | def parse_response(self, player_responses, round_id): 530 | requests = [] 531 | for id, response in enumerate(player_responses): 532 | response = response.split(' ') 533 | response_name = response[0] 534 | if response_name == 'HU': 535 | return ['hu', id, response[1], response[2]] 536 | if response_name == 'PENG': 537 | requests = [] 538 | for i in range(4): 539 | requests.append('3 %d PENG %s' % (id, response[1])) 540 | self.drawers[round_id] = (id + 1) % 4 541 | break 542 | if response_name == "GANG": 543 | requests = [] 544 | for i in range(4): 545 | requests.append('3 %d GANG' % (id)) 546 | self.drawers[round_id] = id 547 | break 548 | if response_name == 'CHI': 549 | for i in range(4): 550 | requests.append('3 %d CHI %s %s' % (id, response[1], response[2])) 551 | self.drawers[round_id] = (id + 1) % 4 552 | if response_name == 'PLAY': 553 | for i in range(4): 554 | requests.append('3 %d PLAY %s' % (id, response[1])) 555 | self.drawers[round_id] = (id + 1) % 4 556 | if response_name == 'BUGANG': 557 | for i in range(4): 558 | requests.append('3 %d BUGANG %s' % (id, response[1])) 559 | self.drawers[round_id] 
= id 560 | # 所有人pass,摸牌 561 | if len(requests) == 0: 562 | if len(self.tile_walls[round_id][self.drawers[round_id]]) == 0: 563 | return ['huangzhuang', 0] 564 | draw_card = self.tile_walls[round_id][self.drawers[round_id]].pop() 565 | for i in range(4): 566 | if i == self.drawers[round_id]: 567 | requests.append('2 %s' % self.getCardName(draw_card)) 568 | else: 569 | requests.append('3 %d DRAW' % self.drawers[round_id]) 570 | return requests 571 | 572 | def print_log(self, current_round, print_interval, winners, time_cost): 573 | win_sum = sum(winners) 574 | total_rounds = current_round * self.round_number 575 | rounds_this_stage = print_interval * self.round_number 576 | print( 577 | 'total rounds: {}, during the last {} rounds, new bot winning rate: {:.2%}, old bot winning rate: {:.2%}\n' 578 | 'Hu {} rounds,Huang-zhuang {} rounds,hu ratio {:.2%},average rounds to hu: {}, took {:.2f} minutes per 10000 rounds'.format( 579 | total_rounds, rounds_this_stage, 580 | sum(winners[1::2]) / win_sum, 581 | sum(winners[::2]) / win_sum, win_sum, 582 | rounds_this_stage - win_sum, 583 | win_sum / rounds_this_stage, 584 | sum(self.win_steps) / len(self.win_steps), 585 | time_cost 586 | )) 587 | if args.log_path != 'none': 588 | with open(args.log_path, 'a+') as f: 589 | print( 590 | '{} {} {:.2%} {:.2%} {} {} {:.2%} {} {:.2f}'.format( 591 | total_rounds, rounds_this_stage, 592 | sum(winners[1::2]) / win_sum, 593 | sum(winners[::2]) / win_sum, win_sum, 594 | rounds_this_stage - win_sum, 595 | win_sum / rounds_this_stage, 596 | sum(self.win_steps) / len(self.win_steps), 597 | time_cost 598 | ), file=f) 599 | self.win_steps = [] 600 | 601 | def getBotInd(self, round_id, bot_id): 602 | return self.bots[round_id].index(self.bots_orders[round_id][bot_id]) 603 | 604 | def getCardInd(self, cardName): 605 | return cards[cardName[0]].value + int(cardName[1]) - 1 606 | 607 | def getCardName(self, cardInd): 608 | num = 1 609 | while True: 610 | if cardInd in cards._value2member_map_: 611 | break 612 | num += 1 613 | cardInd -= 1 614 | return cards(cardInd).name + str(num) 615 | 616 | 617 | # 维护对局环境 618 | class MahjongHandler: 619 | def __init__(self, model, device): 620 | self.total_cards = 34 621 | self.total_actions = len(responses) - 2 622 | self.model = model 623 | self.optimizer = optim 624 | self.device = device 625 | self.reset() 626 | 627 | def reset(self): 628 | self.hand_free = np.zeros(self.total_cards, dtype=int) 629 | self.history = np.zeros(self.total_cards, dtype=int) 630 | self.player_history = np.zeros((4, self.total_cards), dtype=int) 631 | self.player_on_table = np.zeros((4, self.total_cards), dtype=int) 632 | self.hand_fixed = self.player_on_table[0] 633 | self.player_last_play = np.zeros(4, dtype=int) 634 | self.player_angang = np.zeros(4, dtype=int) 635 | self.fan_count = 0 636 | self.hand_fixed_data = [] 637 | self.turnID = 0 638 | self.tile_count = [21, 21, 21, 21] 639 | self.myPlayerID = 0 640 | self.quan = 0 641 | self.prev_request = '' 642 | self.an_gang_card = '' 643 | self.dianpaoer = None 644 | self.doc_data = { 645 | "old_probs": [], 646 | "new_probs": [], 647 | "entropy": [] 648 | } 649 | 650 | def step(self, request): 651 | if request is None: 652 | if self.turnID == 0: 653 | inputJSON = json.loads(input()) 654 | request = inputJSON['requests'][0].split(' ') 655 | else: 656 | request = input().split(' ') 657 | else: 658 | request = request.split(' ') 659 | 660 | request = self.build_hand_history(request) 661 | if self.turnID <= 1: 662 | self.prev_request = request 663 | 
self.turnID += 1 664 | return 665 | else: 666 | available_action_mask, available_card_mask = self.build_available_action_mask(request) 667 | card_feats = self.build_input(self.hand_free, self.history, self.player_history, 668 | self.player_on_table, self.player_last_play) 669 | extra_feats = np.concatenate((self.player_angang[1:], available_action_mask, 670 | [self.hand_free.sum()], *np.eye(4)[[self.quan, self.myPlayerID]], 671 | self.tile_count)) 672 | self.prev_request = request 673 | self.turnID += 1 674 | return card_feats, extra_feats, available_action_mask, available_card_mask, self 675 | 676 | 677 | 678 | def build_input(self, my_free, history, play_history, on_table, last_play): 679 | temp = np.array([my_free, 4 - history]) 680 | one_hot_last_play = np.eye(self.total_cards)[last_play] 681 | card_feats = np.concatenate((temp, on_table, play_history, one_hot_last_play)) 682 | return card_feats 683 | 684 | def build_response(self, action, cards): 685 | response = self.build_output(responses(action), cards) 686 | if responses(action) == responses.ANGANG: 687 | self.an_gang_card = self.getCardName(cards[0]) 688 | self.turnID += 1 689 | self.response = response 690 | return response 691 | 692 | def simulate_chi_peng(self, request, response, chi_peng_ind): 693 | last_card_played = self.getCardInd(request[-1]) 694 | available_card_play_mask = np.zeros(self.total_cards, dtype=int) 695 | my_free, on_table = self.hand_free.copy(), self.player_on_table.copy() 696 | if response == responses.CHI: 697 | my_free[chi_peng_ind - 1:chi_peng_ind + 2] -= 1 698 | my_free[last_card_played] += 1 699 | on_table[0][chi_peng_ind - 1:chi_peng_ind + 2] += 1 700 | is_chi = True 701 | else: 702 | chi_peng_ind = last_card_played 703 | my_free[last_card_played] -= 2 704 | on_table[0][last_card_played] += 3 705 | is_chi = False 706 | self.build_available_card_mask(available_card_play_mask, responses.PLAY, last_card_played, 707 | chi_peng_ind=chi_peng_ind, is_chi=is_chi) 708 | card_feats = self.build_input(my_free, self.history, self.player_history, on_table, self.player_last_play) 709 | 710 | action_mask = np.zeros(self.total_actions, dtype=int) 711 | action_mask[responses.PLAY.value] = 1 712 | extra_feats = np.concatenate((self.player_angang[1:], action_mask, [my_free.sum()], *np.eye(4)[[self.quan, self.myPlayerID]], self.tile_count)) 713 | return card_feats, extra_feats, available_card_play_mask 714 | 715 | def build_available_action_mask(self, request): 716 | available_action_mask = np.zeros(self.total_actions, dtype=int) 717 | available_card_mask = np.zeros((self.total_actions, self.total_cards), dtype=int) 718 | requestID = int(request[0]) 719 | playerID = int(request[1]) 720 | myPlayerID = self.myPlayerID 721 | try: 722 | last_card = request[-1] 723 | last_card_ind = self.getCardInd(last_card) 724 | except: 725 | last_card = '' 726 | last_card_ind = 0 727 | # draw turn (we just drew a tile) 728 | if requests(requestID) == requests.drawCard: 729 | for response in [responses.PLAY, responses.ANGANG, responses.BUGANG]: 730 | if self.tile_count[self.myPlayerID] == 0 and response in [responses.ANGANG, responses.BUGANG]: 731 | continue 732 | self.build_available_card_mask(available_card_mask[response.value], response, last_card_ind) 733 | if available_card_mask[response.value].sum() > 0: 734 | available_action_mask[response.value] = 1 735 | # winning on the replacement tile after a gang (gang-shang-kai-hua) 736 | if requests(int(self.prev_request[0])) in [requests.ANGANG, requests.BUGANG]: 737 | isHu = self.judgeHu(last_card, playerID, True) 738 | # the winning tile here is not necessarily last_card: the hu may follow a CHI of the upstream discard, so we would need to know which tile the upstream player actually played 739
| else: 740 | isHu = self.judgeHu(last_card, playerID, False) 741 | if isHu >= 8: 742 | available_action_mask[responses.HU.value] = 1 743 | self.fan_count = isHu 744 | else: 745 | available_action_mask[responses.PASS.value] = 1 746 | # another player discarded 747 | if requests(requestID) in [requests.PENG, requests.CHI, requests.PLAY]: 748 | if playerID != myPlayerID: 749 | for response in [responses.PENG, responses.MINGGANG, responses.CHI]: 750 | # not the upstream player, so no CHI 751 | if response == responses.CHI and (self.myPlayerID - playerID) % 4 != 1: 752 | continue 753 | # last tile, no CHI/PENG/GANG allowed 754 | if self.tile_count[(playerID + 1) % 4] == 0: 755 | continue 756 | self.build_available_card_mask(available_card_mask[response.value], response, last_card_ind) 757 | if available_card_mask[response.value].sum() > 0: 758 | available_action_mask[response.value] = 1 759 | # must decide right now whether to claim the win (hu) on this discard 760 | isHu = self.judgeHu(last_card, playerID, False, dianPao=True) 761 | if isHu >= 8: 762 | available_action_mask[responses.HU.value] = 1 763 | self.fan_count = isHu 764 | # robbing the kong (qiang-gang hu) 765 | if requests(requestID) == requests.BUGANG and playerID != myPlayerID: 766 | isHu = self.judgeHu(last_card, playerID, True, dianPao=True) 767 | if isHu >= 8: 768 | available_action_mask[responses.HU.value] = 1 769 | self.fan_count = isHu 770 | return available_action_mask, available_card_mask 771 | 772 | def build_available_card_mask(self, available_card_mask, response, last_card_ind, chi_peng_ind=None, is_chi=False): 773 | if response == responses.PLAY: 774 | # normal discard 775 | if chi_peng_ind is None: 776 | for i, card_num in enumerate(self.hand_free): 777 | if card_num > 0: 778 | available_card_mask[i] = 1 779 | else: 780 | # discard right after a CHI 781 | if is_chi: 782 | for i, card_num in enumerate(self.hand_free): 783 | if i in [chi_peng_ind - 1, chi_peng_ind, chi_peng_ind + 1] and i != last_card_ind: 784 | if card_num > 1: 785 | available_card_mask[i] = 1 786 | elif card_num > 0: 787 | available_card_mask[i] = 1 788 | else: 789 | for i, card_num in enumerate(self.hand_free): 790 | if i == chi_peng_ind: 791 | if card_num > 2: 792 | available_card_mask[i] = 1 793 | elif card_num > 0: 794 | available_card_mask[i] = 1 795 | elif response == responses.PENG: 796 | if self.hand_free[last_card_ind] >= 2: 797 | available_card_mask[last_card_ind] = 1 798 | elif response == responses.CHI: 799 | # only numbered suit tiles can be CHI'd 800 | if last_card_ind < cards.F.value: 801 | card_name = self.getCardName(last_card_ind) 802 | card_number = int(card_name[1]) 803 | for i in [-1, 0, 1]: 804 | middle_card = card_number + i 805 | if middle_card >= 2 and middle_card <= 8: 806 | can_chi = True 807 | for card in range(last_card_ind + i - 1, last_card_ind + i + 2): 808 | if card != last_card_ind and self.hand_free[card] == 0: 809 | can_chi = False 810 | if can_chi: 811 | available_card_mask[last_card_ind + i] = 1 812 | elif response == responses.ANGANG: 813 | for card in range(len(self.hand_free)): 814 | if self.hand_free[card] == 4: 815 | available_card_mask[card] = 1 816 | elif response == responses.MINGGANG: 817 | if self.hand_free[last_card_ind] == 3: 818 | available_card_mask[last_card_ind] = 1 819 | elif response == responses.BUGANG: 820 | for card in range(len(self.hand_free)): 821 | if self.hand_fixed[card] == 3 and self.hand_free[card] == 1: 822 | for card_combo in self.hand_fixed_data: 823 | if card_combo[1] == self.getCardName(card) and card_combo[0] == 'PENG': 824 | available_card_mask[card] = 1 825 | else: 826 | available_card_mask[last_card_ind] = 1 827 | return available_card_mask 828 | 829 | def judgeHu(self, last_card, playerID,
isGANG, dianPao=False): 830 | hand = [] 831 | for ind, cardcnt in enumerate(self.hand_free): 832 | for _ in range(cardcnt): 833 | hand.append(self.getCardName(ind)) 834 | if self.history[self.getCardInd(last_card)] == 4: 835 | isJUEZHANG = True 836 | else: 837 | isJUEZHANG = False 838 | if self.tile_count[(playerID + 1) % 4] == 0: 839 | isLAST = True 840 | else: 841 | isLAST = False 842 | if not dianPao: 843 | hand.remove(last_card) 844 | try: 845 | ans = MahjongFanCalculator(tuple(self.hand_fixed_data), tuple(hand), last_card, 0, playerID==self.myPlayerID, 846 | isJUEZHANG, isGANG, isLAST, self.myPlayerID, self.quan) 847 | except Exception as err: 848 | # print(hand, last_card, self.hand_fixed_data) 849 | # print(err) 850 | if str(err) == 'ERROR_NOT_WIN': 851 | return 0 852 | else: 853 | with open('error.txt', 'a+') as f: 854 | print(self.prev_request, file=f) 855 | print(self.response, file=f) 856 | print(hand, last_card, self.hand_fixed_data, file=f) 857 | print(err, file=f) 858 | print(self.prev_request) 859 | print(self.response) 860 | print(hand, last_card, self.hand_fixed_data) 861 | return 0 862 | else: 863 | fan_count = 0 864 | # with open('hu.txt', 'a+') as f: 865 | # print(ans, file=f) 866 | for fan in ans: 867 | fan_count += fan[0] 868 | if dianPao: 869 | self.dianpaoer = playerID 870 | return fan_count 871 | 872 | def build_hand_history(self, request): 873 | # 第0轮,确定位置 874 | if self.turnID == 0: 875 | _, myPlayerID, quan = request 876 | self.myPlayerID = int(myPlayerID) 877 | self.other_players_id = [(self.myPlayerID - i) % 4 for i in range(4)] 878 | self.player_positions = {} 879 | for position, id in enumerate(self.other_players_id): 880 | self.player_positions[id] = position 881 | self.quan = int(quan) 882 | return request 883 | # 第一轮,发牌 884 | if self.turnID == 1: 885 | for i in range(5, 18): 886 | cardInd = self.getCardInd(request[i]) 887 | self.hand_free[cardInd] += 1 888 | self.history[cardInd] += 1 889 | return request 890 | if int(request[0]) == 3: 891 | request[0] = str(requests[request[2]].value) 892 | elif int(request[0]) == 2: 893 | request.insert(1, str(self.myPlayerID)) 894 | request = self.maintain_status(request, self.hand_free, self.history, self.player_history, 895 | self.player_on_table, self.player_last_play, self.player_angang) 896 | return request 897 | 898 | def maintain_status(self, request, my_free, history, play_history, on_table, last_play, angang): 899 | requestID = int(request[0]) 900 | playerID = int(request[1]) 901 | player_position = self.player_positions[playerID] 902 | if requests(requestID) in [requests.drawCard, requests.DRAW]: 903 | self.tile_count[playerID] -= 1 904 | if requests(requestID) == requests.drawCard: 905 | my_free[self.getCardInd(request[-1])] += 1 906 | history[self.getCardInd(request[-1])] += 1 907 | elif requests(requestID) == requests.PLAY: 908 | play_card = self.getCardInd(request[-1]) 909 | play_history[player_position][play_card] += 1 910 | last_play[player_position] = play_card 911 | # 自己 912 | if player_position == 0: 913 | my_free[play_card] -= 1 914 | else: 915 | history[play_card] += 1 916 | elif requests(requestID) == requests.PENG: 917 | # 上一步一定有play 918 | last_card_ind = self.getCardInd(self.prev_request[-1]) 919 | play_card_ind = self.getCardInd(request[-1]) 920 | on_table[player_position][last_card_ind] = 3 921 | play_history[player_position][play_card_ind] += 1 922 | last_play[player_position] = play_card_ind 923 | if player_position != 0: 924 | history[last_card_ind] += 2 925 | history[play_card_ind] += 1 926 
| else: 927 | # 记录peng来源于哪个玩家 928 | last_player = int(self.prev_request[1]) 929 | last_player_pos = self.player_positions[last_player] 930 | self.hand_fixed_data.append(('PENG', self.prev_request[-1], last_player_pos)) 931 | my_free[last_card_ind] -= 2 932 | my_free[play_card_ind] -= 1 933 | elif requests(requestID) == requests.CHI: 934 | # 上一步一定有play 935 | last_card_ind = self.getCardInd(self.prev_request[-1]) 936 | middle_card, play_card = request[3:5] 937 | middle_card_ind = self.getCardInd(middle_card) 938 | play_card_ind = self.getCardInd(play_card) 939 | on_table[player_position][middle_card_ind-1:middle_card_ind+2] += 1 940 | if player_position != 0: 941 | history[middle_card_ind-1:middle_card_ind+2] += 1 942 | history[last_card_ind] -= 1 943 | history[play_card_ind] += 1 944 | else: 945 | # CHI,中间牌名,123代表上家的牌是第几张 946 | self.hand_fixed_data.append(('CHI', middle_card, last_card_ind - middle_card_ind + 2)) 947 | my_free[middle_card_ind-1:middle_card_ind+2] -= 1 948 | my_free[last_card_ind] += 1 949 | my_free[play_card_ind] -= 1 950 | elif requests(requestID) == requests.GANG: 951 | # 暗杠 952 | if requests(int(self.prev_request[0])) in [requests.drawCard, requests.DRAW]: 953 | request[2] = requests.ANGANG.name 954 | if player_position == 0: 955 | gangCard = self.an_gang_card 956 | # print(gangCard) 957 | if gangCard == '': 958 | print(self.prev_request) 959 | print(request) 960 | gangCardInd = self.getCardInd(gangCard) 961 | # 记录gang来源于哪个玩家(可能来自自己,暗杠) 962 | self.hand_fixed_data.append(('GANG', gangCard, 0)) 963 | on_table[0][gangCardInd] = 4 964 | my_free[gangCardInd] = 0 965 | else: 966 | angang[player_position] += 1 967 | else: 968 | # 明杠 969 | gangCardInd = self.getCardInd(self.prev_request[-1]) 970 | request[2] = requests.MINGGANG.name 971 | history[gangCardInd] = 4 972 | on_table[player_position][gangCardInd] = 4 973 | if player_position == 0: 974 | # 记录gang来源于哪个玩家 975 | last_player = int(self.prev_request[1]) 976 | self.hand_fixed_data.append( 977 | ('GANG', self.prev_request[-1], self.player_positions[last_player])) 978 | my_free[gangCardInd] = 0 979 | elif requests(requestID) == requests.BUGANG: 980 | bugang_card_ind = self.getCardInd(request[-1]) 981 | history[bugang_card_ind] = 4 982 | on_table[player_position][bugang_card_ind] = 4 983 | if player_position == 0: 984 | for id, comb in enumerate(self.hand_fixed_data): 985 | if comb[1] == request[-1]: 986 | self.hand_fixed_data[id] = ('GANG', comb[1], comb[2]) 987 | break 988 | my_free[bugang_card_ind] = 0 989 | return request 990 | 991 | def build_output(self, response, cards_ind): 992 | if (responses.need_cards.value[response.value] == 1 and response != responses.CHI) or response == responses.PENG: 993 | response_name = response.name 994 | if response == responses.ANGANG: 995 | response_name = 'GANG' 996 | return '{} {}'.format(response_name, self.getCardName(cards_ind[0])) 997 | if response == responses.CHI: 998 | return 'CHI {} {}'.format(self.getCardName(cards_ind[0]), self.getCardName(cards_ind[1])) 999 | response_name = response.name 1000 | if response == responses.MINGGANG: 1001 | response_name = 'GANG' 1002 | if response == responses.HU: 1003 | return '{} {} {}'.format(response_name, self.fan_count, self.dianpaoer) 1004 | return response_name 1005 | 1006 | 1007 | def getCardInd(self, cardName): 1008 | return cards[cardName[0]].value + int(cardName[1]) - 1 1009 | 1010 | def getCardName(self, cardInd): 1011 | num = 1 1012 | while True: 1013 | if cardInd in cards._value2member_map_: 1014 | break 1015 | num += 1 1016 | 
cardInd -= 1 1017 | return cards(cardInd).name + str(num) 1018 | 1019 | 1020 | def ensure_shared_grads(model, shared_model, device): 1021 | """ ensure proper initialization of global grad""" 1022 | # NOTE: due to no backward passes has ever been ran on the global model 1023 | # NOTE: ref: https://discuss.pytorch.org/t/problem-on-variable-grad-data/957 1024 | for shared_param, local_param in zip(shared_model.parameters(), 1025 | model.parameters()): 1026 | if 'cuda' in str(device): 1027 | # GPU 1028 | if local_param.grad is None: 1029 | shared_param._grad = None 1030 | else: 1031 | shared_param._grad = local_param.grad.clone().cpu() # pylint: disable=W0212 1032 | else: 1033 | # CPU 1034 | if shared_param.grad is not None: 1035 | return 1036 | else: 1037 | shared_param._grad = local_param.grad # pylint: disable=W0212 1038 | 1039 | def update(all_memories: dict, model, shared_model, optimizer, device): 1040 | print('-' * 30) 1041 | print('Updating model...') 1042 | model.load_state_dict(shared_model.state_dict()) 1043 | model.zero_grad() 1044 | optimizer.zero_grad() 1045 | for kind, memory in all_memories.items(): 1046 | # Monte Carlo estimation of rewards 1047 | rewards = memory.rewards 1048 | if len(rewards) == 0: 1049 | continue 1050 | 1051 | # Normalize rewards 1052 | rewards = torch.tensor(rewards).to(device, dtype=torch.float32) 1053 | # rewards = (rewards - rewards.mean()) / (rewards.std() + 1e-5) 1054 | 1055 | card_feats = np.array(memory.states['card_feats']) 1056 | feature_size = card_feats.shape[0] 1057 | batch_number = feature_size // args.batch_size + int(bool(feature_size % args.batch_size)) 1058 | print('Training for {} model, training data length {}'.format(kind, feature_size)) 1059 | extra_feats = np.array(memory.states['extra_feats']) 1060 | masks = np.array(memory.states['masks']) 1061 | old_actions = torch.tensor(memory.actions).to(device) 1062 | old_logprobs = torch.tensor(memory.logprobs).to(device) 1063 | # print(old_actions.shape, old_logprobs.shape) 1064 | 1065 | # Train policy for K epochs: sampling and updating 1066 | for _ in range(args.epochs): 1067 | for i in range(batch_number): 1068 | start = i * args.batch_size 1069 | end = min(feature_size, start + args.batch_size) 1070 | # Evaluate old actions and values using current policy 1071 | new_probs = model( 1072 | card_feats[start:end], 1073 | extra_feats[start:end], 1074 | device, 1075 | kind, 1076 | masks[start:end] 1077 | ) 1078 | new_probs_dist = Categorical(new_probs) 1079 | new_logprobs = new_probs_dist.log_prob(old_actions[start:end]) 1080 | entropy = new_probs_dist.entropy() 1081 | 1082 | # Importance ratio: p/q 1083 | ratios = torch.exp(new_logprobs - old_logprobs[start:end].detach()) 1084 | 1085 | # Actor loss using Surrogate loss 1086 | surr1 = ratios * rewards[start:end] 1087 | surr2 = torch.clamp(ratios, 1 - args.eps_clip, 1 + args.eps_clip) * rewards[start:end] 1088 | shared_model.entropy_weight = shared_model.entropy_weight + args.entropy_step * ( 1089 | args.entropy_target - float(entropy.mean().data.cpu())) 1090 | loss = - torch.min(surr1, surr2) - shared_model.entropy_weight * entropy 1091 | # loss = - surr1 - shared_model.entropy_weight * entropy 1092 | 1093 | # Backward gradients 1094 | # optimizer.zero_grad() 1095 | loss.mean().backward() 1096 | torch.nn.utils.clip_grad_norm_(shared_model.parameters(), 10) 1097 | ensure_shared_grads(model, shared_model, device) 1098 | optimizer.step() 1099 | """ 1100 | PolicyGradient.py -p 4 -rn 50 -rt 40 -o models/super_model_2 -n models/rl_pg -S 
models/rl_pg_new -s -lp logs/new_log -bs 8000 -lr 2e-6 -ti 8 -ji 10 -pi 160 -si 320 -e 1 1101 | """ 1102 | class SharedAdam(optim.Adam): 1103 | # pylint: disable=C0103 1104 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0): 1105 | super(SharedAdam, self).__init__(params, lr, betas, eps, weight_decay) 1106 | 1107 | for group in self.param_groups: 1108 | for p in group['params']: 1109 | state = self.state[p] 1110 | state['step'] = torch.zeros(1) 1111 | state['exp_avg'] = p.data.new().resize_as_(p.data).zero_() 1112 | state['exp_avg_sq'] = p.data.new().resize_as_(p.data).zero_() 1113 | state['max_exp_avg_sq'] = p.data.new().resize_as_(p.data).zero_() 1114 | 1115 | def share_memory(self): 1116 | for group in self.param_groups: 1117 | for p in group['params']: 1118 | state = self.state[p] 1119 | state['step'].share_memory_() 1120 | state['exp_avg'].share_memory_() 1121 | state['exp_avg_sq'].share_memory_() 1122 | state['max_exp_avg_sq'].share_memory_() 1123 | 1124 | def step(self, closure=None): 1125 | """ 1126 | Performs a single optimization step. 1127 | Args: 1128 | closure (callable, optional): A closure that reevaluates the model 1129 | Returns: 1130 | loss 1131 | """ 1132 | loss = None 1133 | if closure is not None: 1134 | loss = closure() 1135 | 1136 | for group in self.param_groups: 1137 | for p in group['params']: 1138 | if p.grad is None: 1139 | continue 1140 | grad = p.grad.data 1141 | if grad.is_sparse: 1142 | raise RuntimeError( 1143 | 'Adam does not support sparse gradients, please consider SparseAdam instead' 1144 | ) 1145 | amsgrad = group['amsgrad'] 1146 | 1147 | state = self.state[p] 1148 | 1149 | # State initialization 1150 | if len(state) == 0: 1151 | state['step'] = 0 1152 | # Exponential moving average of gradient values 1153 | state['exp_avg'] = torch.zeros_like(p.data) 1154 | # Exponential moving average of squared gradient values 1155 | state['exp_avg_sq'] = torch.zeros_like(p.data) 1156 | if amsgrad: 1157 | # Maintains max of all exp. moving avg. of sq. grad. values 1158 | state['max_exp_avg_sq'] = torch.zeros_like(p.data) 1159 | 1160 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 1161 | if amsgrad: 1162 | max_exp_avg_sq = state['max_exp_avg_sq'] 1163 | beta1, beta2 = group['betas'] 1164 | 1165 | state['step'] += 1 1166 | 1167 | if group['weight_decay'] != 0: 1168 | grad = grad.add(group['weight_decay'], p.data) 1169 | 1170 | # Decay the first and second moment running average coefficient 1171 | exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) 1172 | exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) 1173 | if amsgrad: 1174 | # Maintains the maximum of all 2nd moment running avg. till now 1175 | torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) 1176 | # Use the max. for normalizing running avg. 
of gradient 1177 | denom = max_exp_avg_sq.sqrt().add_(group['eps']) 1178 | else: 1179 | denom = exp_avg_sq.sqrt().add_(group['eps']) 1180 | 1181 | bias_correction1 = 1 - beta1**state['step'] 1182 | bias_correction2 = 1 - beta2**state['step'] 1183 | step_size = group['lr'] * math.sqrt( 1184 | bias_correction2) / bias_correction1 1185 | p.data.addcdiv_(exp_avg, denom, value=-step_size[0]) # first arg has to be scalar 1186 | 1187 | return loss 1188 | 1189 | # thread for training 1190 | def train_thread(gpu_id, shared_model, global_episode_counter, global_training_counter, optimizer, lock, shared_memory, winners): 1191 | # shared_model must be on cpu, so create a duplicate to update on gpu 1192 | model_for_update = Policy().to("cuda:%d" % gpu_id if torch.cuda.is_available() else "cpu") 1193 | for n, p in model_for_update.named_parameters(): 1194 | if 'card_net' in n: 1195 | p.requires_grad = False 1196 | start_time = time.perf_counter() 1197 | start_counter = global_episode_counter.value 1198 | save = args.save 1199 | old_model_path = args.old_path 1200 | save_path = args.save_path 1201 | print_interval = args.print_interval 1202 | save_interval = args.save_interval 1203 | train_interval = args.train_interval 1204 | 1205 | env = MahjongEnv(old_model_path=old_model_path, 1206 | shared_model=shared_model, 1207 | gpu_id=gpu_id, 1208 | shared_memory=shared_memory, 1209 | lock=lock) 1210 | while True: 1211 | env.run_round() 1212 | env.reset(False) 1213 | with lock: 1214 | global_episode_counter.value += 1 1215 | current_round = global_episode_counter.value 1216 | for i in range(4): 1217 | winners[i] += env.winners[i] 1218 | env.winners = np.zeros(4, dtype=int) 1219 | 1220 | # update policy 1221 | if current_round % train_interval == 0: 1222 | update(shared_memory.get_memory(), model_for_update, shared_model, optimizer, env.device) 1223 | shared_memory.clear_memory() 1224 | global_training_counter.value += 1 1225 | 1226 | # update local policy 1227 | if global_training_counter.value - env.train_count >= args.join_interval: 1228 | env.train_count = global_training_counter.value 1229 | env.model.load_state_dict(shared_model.state_dict()) 1230 | 1231 | if current_round % print_interval == 0: 1232 | total_rounds = (current_round - start_counter) * env.round_number 1233 | this_time = time.perf_counter() 1234 | time_cost = (this_time - start_time) / (60 * (total_rounds / 10000)) 1235 | env.print_log(current_round, print_interval, winners, time_cost) 1236 | for i in range(4): 1237 | winners[i] = 0 1238 | 1239 | if save and current_round % save_interval == 0: 1240 | env.losses = [] 1241 | env.model.load_state_dict(shared_model.state_dict()) 1242 | print('total rounds: %d, saving model...' 
% current_round) 1243 | state = {'model': shared_model.state_dict(), 'optimizer': optimizer.state_dict(), 'counter': current_round} 1244 | torch.save(state, save_path, _use_new_zipfile_serialization=False) 1245 | 1246 | # 这两个类也许是多余的,抄的网上 1247 | class MyManager(BaseManager): 1248 | pass 1249 | 1250 | def ManagerStarter(): 1251 | m = MyManager() 1252 | m.start() 1253 | return m 1254 | 1255 | def main(): 1256 | mp.set_start_method('spawn') # required to avoid Conv2d froze issue 1257 | num_processes_per_gpu = args.num_process_per_gpu 1258 | new_model_path = args.new_path 1259 | lr = args.learning_rate 1260 | shared_model = Policy( 1261 | card_feat_depth=14, 1262 | num_extra_feats=24, 1263 | num_cards=34, 1264 | num_actions=8, 1265 | ) 1266 | for n, p in shared_model.named_parameters(): 1267 | if 'card_net' in n: 1268 | p.requires_grad = False 1269 | checkpoint = torch.load(new_model_path) 1270 | new_dict = shared_model.state_dict() 1271 | for k in checkpoint['model'].keys(): 1272 | new_dict[k] = checkpoint['model'][k] 1273 | shared_model.load_state_dict(new_dict) 1274 | shared_model.share_memory() 1275 | gpu_count = torch.cuda.device_count() 1276 | num_processes = gpu_count * num_processes_per_gpu if gpu_count > 0 else num_processes_per_gpu 1277 | 1278 | lock = mp.Lock() 1279 | optimizer = SharedAdam(shared_model.parameters(), lr=lr) 1280 | optimizer.share_memory() 1281 | 1282 | MyManager.register('Memory', Memory) 1283 | manager = ManagerStarter() 1284 | shared_memory = manager.Memory() 1285 | 1286 | # multiprocesses, Hogwild! style update 1287 | processes = [] 1288 | try: 1289 | init_episode_counter_val = checkpoint['counter'] 1290 | optimizer.load_state_dict(checkpoint['optimizer']) 1291 | except KeyError: 1292 | init_episode_counter_val = 0 1293 | max_winning_rate = 0.0 1294 | global_episode_counter = mp.Value('i', init_episode_counter_val) 1295 | global_training_counter = mp.Value('i', 0) 1296 | winners_count = mp.Array('i', 4, lock=True) 1297 | # each worker_thread creates its own environment and trains agents 1298 | for rank in range(num_processes): 1299 | worker_thread = mp.Process( 1300 | target=train_thread, args=(rank // num_processes_per_gpu, shared_model, global_episode_counter, 1301 | global_training_counter, optimizer, lock, shared_memory, winners_count)) 1302 | worker_thread.daemon = True 1303 | worker_thread.start() 1304 | processes.append(worker_thread) 1305 | time.sleep(2) 1306 | 1307 | # wait for all processes to finish 1308 | try: 1309 | killed_process_count = 0 1310 | for process in processes: 1311 | process.join() 1312 | killed_process_count += 1 if process.exitcode == 1 else 0 1313 | if killed_process_count >= num_processes: 1314 | # exit if only monitor and writer alive 1315 | raise SystemExit 1316 | except (KeyboardInterrupt, SystemExit): 1317 | for process in processes: 1318 | # without killing child process, process.terminate() will cause orphans 1319 | # ref: https://thebearsenal.blogspot.com/2018/01/creation-of-orphan-process-in-linux.html 1320 | # kill_child_processes(process.pid) 1321 | process.terminate() 1322 | process.join() 1323 | 1324 | if __name__ == '__main__': 1325 | main() --------------------------------------------------------------------------------
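A minimal local-play sketch (illustrative only, not a repository file): it assumes the `agent` wrapper defined in DL_agent.py above and a checkpoint at the `../models/super_model_2` path that appears as an argparse default elsewhere in this repository; the example hand and drawn tile are made up, and the request strings follow the botzone format parsed by `build_hand_history`. What `step` returns depends on DL_agent.py's handler, which is only partially shown above.

# Illustrative driver; paths and tiles are assumptions, not part of the project.
from DL_agent import agent

bot = agent(model_path='../models/super_model_2')   # assumed checkpoint location
bot.step('0 0 0')                                    # turn 0: seat 0, prevailing wind (quan) 0
deal = 'W1 W2 W3 B1 B2 B3 T1 T2 T3 F1 F1 F1 J1'      # 13 tiles, named as in getCardName
bot.step('1 0 0 0 0 ' + deal)                        # turn 1: the initial hand
response = bot.step('2 W4')                          # turn 2: draw W4; the handler decides PLAY/GANG/HU etc.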