├── README.md ├── actor_n ├── __pycache__ │ └── model.cpython-36.pyc ├── actor.py ├── game.py ├── guandan ├── kill.sh ├── model.py ├── start.sh └── utils │ ├── __pycache__ │ ├── data_trans.cpython-36.pyc │ ├── data_trans.cpython-38.pyc │ ├── logger.cpython-36.pyc │ ├── logger.cpython-38.pyc │ ├── utils.cpython-36.pyc │ └── utils.cpython-38.pyc │ ├── data_trans.py │ ├── logger.py │ └── utils.py ├── actor_torch ├── __pycache__ │ └── model.cpython-36.pyc ├── actor.py ├── danserver ├── game.py ├── kill.sh ├── model.py ├── q_network.ckpt ├── rekill.sh ├── restart.py ├── restart.sh ├── start.sh └── utils │ ├── __pycache__ │ ├── data_trans.cpython-36.pyc │ ├── logger.cpython-36.pyc │ └── utils.cpython-36.pyc │ ├── data_trans.py │ ├── logger.py │ └── utils.py ├── create_container.sh ├── learner_n ├── __pycache__ │ ├── common.cpython-36.pyc │ └── common.cpython-38.pyc ├── agents │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ └── __init__.cpython-38.pyc │ └── dqn │ │ ├── __init__.py │ │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ ├── __init__.cpython-38.pyc │ │ ├── dqn_agent.cpython-36.pyc │ │ ├── dqn_agent.cpython-38.pyc │ │ ├── guandan_agent.cpython-36.pyc │ │ └── guandan_agent.cpython-38.pyc │ │ └── guandan_agent.py ├── build │ ├── conda │ │ └── env_linux.yaml │ └── docker │ │ └── Dockerfile ├── common.py ├── core │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ ├── __init__.cpython-38.pyc │ │ ├── agent.cpython-36.pyc │ │ ├── agent.cpython-38.pyc │ │ ├── env.cpython-36.pyc │ │ ├── env.cpython-38.pyc │ │ ├── mem_pool.cpython-36.pyc │ │ ├── mem_pool.cpython-38.pyc │ │ ├── model.cpython-36.pyc │ │ ├── model.cpython-38.pyc │ │ ├── registry.cpython-36.pyc │ │ ├── registry.cpython-38.pyc │ │ ├── utils.cpython-36.pyc │ │ └── utils.cpython-38.pyc │ ├── agent.py │ ├── env.py │ ├── mem_pool.py │ ├── model.py │ ├── registry.py │ └── utils.py ├── kill_all.sh ├── kill_learner.sh ├── learner.py ├── models │ ├── __init__.py │ ├── 
__pycache__ │ │ ├── __init__.cpython-36.pyc │ │ ├── __init__.cpython-38.pyc │ │ ├── ac_model.cpython-36.pyc │ │ ├── ac_model.cpython-38.pyc │ │ ├── custom_model.cpython-36.pyc │ │ ├── custom_model.cpython-38.pyc │ │ ├── distributions.cpython-36.pyc │ │ ├── distributions.cpython-38.pyc │ │ ├── q_model.cpython-36.pyc │ │ ├── q_model.cpython-38.pyc │ │ ├── tf_v1_model.cpython-36.pyc │ │ ├── tf_v1_model.cpython-38.pyc │ │ ├── utils.cpython-36.pyc │ │ └── utils.cpython-38.pyc │ ├── distributions.py │ ├── q_model.py │ ├── tf_v1_model.py │ └── utils.py ├── start_all.sh └── utils │ ├── __init__.py │ ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── __init__.cpython-38.pyc │ ├── cmdline.cpython-36.pyc │ ├── cmdline.cpython-38.pyc │ ├── logger.cpython-36.pyc │ └── logger.cpython-38.pyc │ ├── cmdline.py │ ├── logger.py │ └── mpi_util.py ├── learner_torch ├── __pycache__ │ ├── common.cpython-38.pyc │ ├── mem_pool.cpython-38.pyc │ ├── model.cpython-38.pyc │ └── ppo.cpython-38.pyc ├── common.py ├── kill_all.sh ├── kill_learner.sh ├── learner.py ├── mem_pool.py ├── model.py ├── ppo.py ├── start_all.sh └── utils │ ├── __init__.py │ ├── __pycache__ │ ├── __init__.cpython-38.pyc │ ├── cmdline.cpython-38.pyc │ ├── logger.cpython-38.pyc │ ├── mpi_pytorch.cpython-38.pyc │ └── mpi_tools.cpython-38.pyc │ ├── cmdline.py │ ├── logger.py │ ├── model_utils.py │ ├── mpi_pytorch.py │ └── mpi_tools.py ├── rm_container.sh ├── wintest ├── ai1 │ ├── action.py │ ├── client0.py │ ├── client1.py │ ├── client2.py │ ├── client3.py │ ├── state.py │ └── utils.py ├── ai2 │ ├── CountValue.py │ ├── CreateActionList.py │ ├── PlayCard.py │ ├── action.py │ ├── client0.py │ ├── client1.py │ ├── client2.py │ ├── client3.py │ ├── config.py │ ├── state.py │ └── strategy.py ├── ai3 │ ├── action.py │ ├── client0.py │ ├── client1.py │ ├── client2.py │ ├── client3.py │ ├── message_Reyn_CUR.py │ └── state.py ├── ai4 │ ├── action.py │ ├── client.py │ ├── client1.py │ ├── client2.py │ ├── client3.py │ ├── client4.py │ 
├── state.py │ └── utils.py ├── ai5 │ ├── action.py │ ├── active.py │ ├── back_tribute.py │ ├── client.py │ ├── client1.py │ ├── client2.py │ ├── client3.py │ ├── client4.py │ ├── passive.py │ ├── state.py │ └── utils.py ├── ai6 │ ├── Myfunc1014.py │ ├── __init__.py │ ├── action.py │ ├── action2.py │ ├── action3.py │ ├── action4.py │ ├── client1.py │ ├── client2.py │ ├── client3.py │ ├── client4.py │ ├── data1.txt │ ├── data2.txt │ ├── data3.txt │ ├── data4.txt │ ├── lasthand.py │ ├── state.py │ ├── state2.py │ ├── state3.py │ └── state4.py ├── ai7 │ ├── action.py │ ├── client1.py │ ├── client2.py │ ├── client3.py │ ├── client4.py │ ├── mysolve.py │ └── state.py ├── ai8 │ ├── action.py │ ├── client1.py │ ├── client2.py │ ├── client3.py │ ├── client4.py │ └── state.py ├── create_container.sh ├── danzero │ ├── __pycache__ │ │ ├── model.cpython-36.pyc │ │ └── util.cpython-36.pyc │ ├── actor.py │ ├── client0.py │ ├── client2.py │ ├── model.py │ ├── q_network.ckpt │ └── util.py ├── readme.md └── torch │ ├── __pycache__ │ ├── model.cpython-36.pyc │ └── util.cpython-36.pyc │ ├── actor.py │ ├── client1.py │ ├── client3.py │ ├── danserver │ ├── evaluate_dqn.py │ ├── evaluate_model.py │ ├── kill.sh │ ├── kill_auto.sh │ ├── model.py │ ├── q_network.ckpt │ ├── testmodel.sh │ ├── testvsdqn.sh │ └── util.py └── 离线平台使用说明.pdf /README.md: -------------------------------------------------------------------------------- 1 | As this work is based on "DanZero: Mastering GuanDan Game with Reinforcement Learning", the code is also built on 2 | the repository "https://github.com/AltmanD/guandan_mcc/tree/main". 
3 | 4 | ## Install 5 | Required libraries (if you only train the DMC model, torch is not required): 6 | 7 | Linux 20.04 8 | 9 | python=3.8 (learner), python=3.6 (actor) 10 | 11 | tensorflow=1.15.5+nv22.2 (learner) or 1.15.4 (actor) 12 | 13 | numpy=1.18.5 14 | 15 | websocket (ws4py)=0.5.1 16 | 17 | pyarrow=5.0.0 18 | 19 | pyzmq=22.3.0 20 | 21 | torch=1.9.1+cpu (actor) or 1.13.1+cu116 (learner) 22 | 23 | To enable communication between Docker containers, you can refer to https://cloud.tencent.com/developer/article/1013167. 24 | If you use Docker, follow create_container.sh to set up the Docker network. 25 | Then enter the learner container, run "ssh-keygen -t rsa" to create the public key file, and copy it to the authorized_keys file in 26 | each actor container. After that, edit /etc/ssh/ssh_config to set "StrictHostKeyChecking" to no. 27 | The containers can then communicate with each other directly. 28 | 29 | For convenience of installation, we provide the actor image at 30 | Link: https://pan.baidu.com/s/1ICAKWF3F-LxraphzqNjYhg?pwd=0704 31 | Extraction code: 0704 32 | 33 | ## Run 34 | The commands to run the code directly are as follows: 35 | 36 | actor: 37 | python actor_n/actor.py 38 | 39 | learner: 40 | python learner_n/learner.py 41 | 42 | A start shell script (start_all.sh) is also provided in the learner directory. 43 | 44 | ## Evaluation 45 | 46 | The evaluation code is in the ./wintest directory, which contains its own readme with instructions.
47 | -------------------------------------------------------------------------------- /actor_n/__pycache__/model.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/submit-paper/Danzero_plus/e2b900b01096e0743de945f9963df35cd544f36d/actor_n/__pycache__/model.cpython-36.pyc -------------------------------------------------------------------------------- /actor_n/guandan: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/submit-paper/Danzero_plus/e2b900b01096e0743de945f9963df35cd544f36d/actor_n/guandan -------------------------------------------------------------------------------- /actor_n/kill.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | ps -ef | grep guandan | awk '{print $2}' | xargs kill -9 3 | ps -ef | grep server | awk '{print $2}' | xargs kill -9 4 | ps aux|grep python|grep -v grep|cut -c 9-15|xargs kill -9 5 | ps -ef | grep actor | awk '{print $2}' | xargs kill -9 6 | ps -ef | grep game | awk '{print $2}' | xargs kill -9 7 | ps -ef | grep python | awk '{print $2}' | xargs kill -9 8 | rm /home/luyd/game_out.log 9 | -------------------------------------------------------------------------------- /actor_n/model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | from tensorflow.keras.backend import get_session 4 | 5 | 6 | def combined_shape(length, shape=None): 7 | if shape is None: 8 | return (length,) 9 | return (length, shape) if np.isscalar(shape) else (length, *shape) 10 | 11 | 12 | def placeholder(dtype=tf.float32, shape=None): 13 | return tf.placeholder(dtype=dtype, shape=combined_shape(None, shape)) 14 | 15 | 16 | def mlp(x, hidden_sizes=(32,), activation=tf.tanh, output_activation=None): 17 | for h in hidden_sizes[:-1]: 18 | x = tf.layers.dense(x, units=h, 
activation=activation) 19 | return tf.layers.dense(x, units=hidden_sizes[-1], activation=output_activation) 20 | 21 | 22 | class GDModel(): 23 | def __init__(self, observation_space, action_space, config=None, model_id='0', session=None): 24 | with tf.variable_scope(model_id): 25 | self.x_ph = placeholder(shape=observation_space) 26 | # self.z = placeholder(shape=action_space) 27 | # self.zero = placeholder(shape=128) 28 | 29 | # 输出张量 30 | self.values = None 31 | 32 | # Initialize Tensorflow session 33 | if session is None: 34 | session = get_session() 35 | self.sess = session 36 | 37 | self.scope = model_id 38 | self.observation_space = observation_space 39 | self.action_space = action_space 40 | self.model_id = model_id 41 | self.config = config 42 | 43 | # Set configurations 44 | if config is not None: 45 | self.load_config(config) 46 | 47 | # Build up model 48 | self.build() 49 | 50 | # Build assignment ops 51 | self._weight_ph = None 52 | self._to_assign = None 53 | self._nodes = None 54 | self._build_assign() 55 | 56 | # Build saver 57 | self.saver = tf.train.Saver(tf.trainable_variables()) 58 | 59 | # 参数初始化 60 | self.sess.run(tf.global_variables_initializer()) 61 | 62 | def set_weights(self, weights) -> None: 63 | feed_dict = {self._weight_ph[var.name]: weight 64 | for (var, weight) in zip(tf.trainable_variables(self.scope), weights)} 65 | self.sess.run(self._nodes, feed_dict=feed_dict) 66 | 67 | def get_weights(self): 68 | return self.sess.run(tf.trainable_variables(self.scope)) 69 | 70 | def save(self, path) -> None: 71 | self.saver.save(self.sess, str(path)) 72 | 73 | def load(self, path) -> None: 74 | self.saver.restore(self.sess, str(path)) 75 | 76 | def _build_assign(self): 77 | self._weight_ph, self._to_assign = dict(), dict() 78 | variables = tf.trainable_variables(self.scope) 79 | for var in variables: 80 | self._weight_ph[var.name] = tf.placeholder(var.value().dtype, var.get_shape().as_list()) 81 | self._to_assign[var.name] = 
var.assign(self._weight_ph[var.name]) 82 | self._nodes = list(self._to_assign.values()) 83 | 84 | def forward(self, x_batch): 85 | return self.sess.run(self.values, feed_dict={self.x_ph: x_batch}) 86 | 87 | def build(self) -> None: 88 | with tf.variable_scope(self.scope): 89 | with tf.variable_scope('v'): 90 | self.values = mlp(self.x_ph, [512, 512, 512, 512, 512, 1], activation='tanh', 91 | output_activation=None) 92 | 93 | if __name__ == '__main__': 94 | model = GDModel((567,), (5, 216)) 95 | with open('/home/luyd/guandan/actor_torch/q_network.ckpt', 'rb') as f: 96 | import pickle 97 | new_weights = pickle.load(f) 98 | model.set_weights(new_weights) 99 | b = np.load("/home/luyd/guandan/actor_ppo/debug128.npy", allow_pickle=True).item() 100 | state = b['x_batch'][7] 101 | info = model.forward(state) 102 | info = info.reshape(-1,) 103 | info = info.argsort()[-10:][::-1].tolist() 104 | print(info) 105 | -------------------------------------------------------------------------------- /actor_n/start.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | nohup /home/luyd/guandan/actor_n/guandan 1000000000000 >/dev/null 2>&1 & 3 | sleep 0.5s 4 | nohup /root/miniconda3/envs/guandan/bin/python -u /home/luyd/guandan/actor_n/actor.py > /home/luyd/actor_out.log 2>&1 & 5 | sleep 0.5s 6 | nohup /root/miniconda3/envs/guandan/bin/python -u /home/luyd/guandan/actor_n/game.py > /home/luyd/game_out.log 2>&1 & -------------------------------------------------------------------------------- /actor_n/utils/__pycache__/data_trans.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/submit-paper/Danzero_plus/e2b900b01096e0743de945f9963df35cd544f36d/actor_n/utils/__pycache__/data_trans.cpython-36.pyc -------------------------------------------------------------------------------- /actor_n/utils/__pycache__/data_trans.cpython-38.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/submit-paper/Danzero_plus/e2b900b01096e0743de945f9963df35cd544f36d/actor_n/utils/__pycache__/data_trans.cpython-38.pyc -------------------------------------------------------------------------------- /actor_n/utils/__pycache__/logger.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/submit-paper/Danzero_plus/e2b900b01096e0743de945f9963df35cd544f36d/actor_n/utils/__pycache__/logger.cpython-36.pyc -------------------------------------------------------------------------------- /actor_n/utils/__pycache__/logger.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/submit-paper/Danzero_plus/e2b900b01096e0743de945f9963df35cd544f36d/actor_n/utils/__pycache__/logger.cpython-38.pyc -------------------------------------------------------------------------------- /actor_n/utils/__pycache__/utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/submit-paper/Danzero_plus/e2b900b01096e0743de945f9963df35cd544f36d/actor_n/utils/__pycache__/utils.cpython-36.pyc -------------------------------------------------------------------------------- /actor_n/utils/__pycache__/utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/submit-paper/Danzero_plus/e2b900b01096e0743de945f9963df35cd544f36d/actor_n/utils/__pycache__/utils.cpython-38.pyc -------------------------------------------------------------------------------- /actor_n/utils/data_trans.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import os 3 | import pickle 4 | import time 5 | from itertools import count 6 | from pathlib import Path 7 | 
from typing import Any, Tuple 8 | 9 | import zmq 10 | 11 | 12 | def find_new_weights(current_model_id: int, ckpt_path: Path) -> Tuple[Any, int]: 13 | try: 14 | ckpt_files = sorted(os.listdir(ckpt_path), key=lambda p: int(p.split('.')[0])) 15 | latest_file = ckpt_files[-1] 16 | except IndexError: 17 | # No checkpoint file. NOTE(review): a non-numeric file name in ckpt_path makes the int() sort key above raise ValueError, which is not caught here. 18 | return None, -1 19 | new_model_id = int(latest_file.split('.')[0]) 20 | 21 | if int(new_model_id) > current_model_id: 22 | loaded = False 23 | while not loaded: # NOTE(review): retries immediately with no sleep (busy-wait) and never gives up while the file keeps failing to unpickle 24 | try: 25 | with open(ckpt_path / latest_file, 'rb') as f: 26 | new_weights = pickle.load(f) # NOTE(review): unpickles the bytes run_weights_subscriber received over the network; only safe if the learner endpoint is trusted 27 | loaded = True 28 | except (EOFError, pickle.UnpicklingError): 29 | # The weights file has not finished being written yet; try again 30 | pass 31 | 32 | return new_weights, new_model_id 33 | else: 34 | return None, current_model_id 35 | 36 | 37 | def create_experiment_dir(args, prefix: str) -> None: 38 | if args.exp_path is None: 39 | args.exp_path = prefix + datetime.datetime.fromtimestamp(time.time()).strftime('%Y-%m-%d-%H-%M-%S') 40 | args.exp_path = Path(args.exp_path) 41 | 42 | if args.exp_path.exists(): 43 | os.system(f'rm -rf {args.exp_path}') # NOTE(review): unquoted shell command; a path containing spaces or shell metacharacters breaks it - shutil.rmtree would avoid the shell entirely 44 | # raise FileExistsError(f'Experiment directory {str(args.exp_path)!r} already exists') 45 | 46 | args.exp_path.mkdir() 47 | 48 | 49 | def run_weights_subscriber(args, unknown_args): 50 | """Subscribe weights from Learner and save them locally""" 51 | context = zmq.Context() 52 | socket = context.socket(zmq.SUB) 53 | socket.connect(f'tcp://{args.ip}:{args.param_port}') 54 | socket.setsockopt_string(zmq.SUBSCRIBE, '') # Subscribe everything 55 | for model_id in count(1): # Starts from 1 56 | while True: # poll with NOBLOCK, sleeping 1s between attempts, until one message arrives 57 | try: 58 | weights = socket.recv(flags=zmq.NOBLOCK) 59 | # Weights received 60 | with open(args.ckpt_path / f'{model_id}.ckpt', 'wb') as f: 61 | f.write(weights) 62 | 63 | if model_id > args.num_saved_ckpt: 64 | os.remove(args.ckpt_path / f'{model_id - args.num_saved_ckpt}.ckpt') 65 | break 66 | except zmq.Again: 67 | pass 68 | 69 | # For not
cpu-intensive 70 | time.sleep(1) 71 | -------------------------------------------------------------------------------- /actor_torch/__pycache__/model.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/submit-paper/Danzero_plus/e2b900b01096e0743de945f9963df35cd544f36d/actor_torch/__pycache__/model.cpython-36.pyc -------------------------------------------------------------------------------- /actor_torch/danserver: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/submit-paper/Danzero_plus/e2b900b01096e0743de945f9963df35cd544f36d/actor_torch/danserver -------------------------------------------------------------------------------- /actor_torch/kill.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | ps -ef | grep danserver | grep -v grep | awk '{print $2}' | xargs kill -9 3 | ps -ef | grep actor.py | grep -v grep | awk '{print $2}' | xargs kill -9 4 | ps -ef | grep game.py | grep -v grep | awk '{print $2}' | xargs kill -9 5 | ps -ef | grep restart.py | grep -v grep | awk '{print $2}' | xargs kill -9 6 | -------------------------------------------------------------------------------- /actor_torch/model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | #import scipy.signal 3 | import torch 4 | import torch.nn as nn 5 | from torch.distributions.categorical import Categorical 6 | 7 | 8 | def combined_shape(length, shape=None): 9 | if shape is None: 10 | return (length,) 11 | return (length, shape) if np.isscalar(shape) else (length, *shape) 12 | 13 | def orthogonal_init(layer, gain=1.0): 14 | nn.init.orthogonal_(layer.weight, gain=gain) 15 | nn.init.constant_(layer.bias, 0) 16 | 17 | def mlp(sizes, activation, output_activation=nn.Identity, use_init=False): 18 | layers = [] 19 | for j in range(len(sizes)-1): 20 | 
act = activation if j < len(sizes)-2 else output_activation 21 | if use_init: 22 | net = nn.Linear(sizes[j], sizes[j+1]) 23 | orthogonal_init(net) 24 | layers += [net, act()] 25 | else: 26 | layers += [nn.Linear(sizes[j], sizes[j+1]), act()] 27 | return nn.Sequential(*layers) 28 | 29 | def shared_mlp(obs_dim, sizes, activation, use_init=False): # 分两个叉,一个是过softmax的logits,另一个不过,就是单纯的q(s,a),这里是前面的共享层 30 | layers = [] 31 | shapes = [obs_dim] + list(sizes) 32 | for j in range(len(shapes) - 1): 33 | act = activation 34 | if use_init: 35 | net = nn.Linear(shapes[j], shapes[j+1]) 36 | orthogonal_init(net) 37 | layers += [net, act()] 38 | else: 39 | layers += [nn.Linear(shapes[j], shapes[j + 1]), act()] 40 | return nn.Sequential(*layers) 41 | 42 | 43 | def count_vars(module): 44 | return sum([np.prod(p.shape) for p in module.parameters()]) 45 | 46 | 47 | class Actor(nn.Module): 48 | def _distribution(self, obs): 49 | raise NotImplementedError 50 | 51 | def _log_prob_from_distribution(self, pi, act): 52 | raise NotImplementedError 53 | 54 | def forward(self, obs, act=None, legalaction=torch.tensor(list(range(10))).to(torch.float32)): 55 | # Produce action distributions for given observations, and 56 | # optionally compute the log likelihood of given actions under 57 | # those distributions. 
58 | pi = self._distribution(obs, legalaction) 59 | logp_a = None 60 | if act is not None: 61 | logp_a = self._log_prob_from_distribution(pi, act) 62 | return pi, logp_a 63 | 64 | 65 | class MLPCategoricalActor(Actor): 66 | def __init__(self, obs_dim, act_dim, hidden_sizes, activation): 67 | super().__init__() 68 | self.logits_net = mlp([obs_dim] + list(hidden_sizes) + [act_dim], activation) 69 | 70 | def _distribution(self, obs, legal_action): 71 | logits = torch.squeeze(self.logits_net(obs)) - (1 - legal_action) * 1e6 72 | return Categorical(logits=logits) 73 | 74 | def _log_prob_from_distribution(self, pi, act): 75 | return pi.log_prob(act) 76 | 77 | 78 | class MLPCritic(nn.Module): 79 | def __init__(self, obs_dim, hidden_sizes, activation): 80 | super().__init__() 81 | self.v_net = mlp([obs_dim] + list(hidden_sizes) + [1], activation) 82 | 83 | def forward(self, obs): 84 | return torch.squeeze(self.v_net(obs), -1) # Critical to ensure v has right shape. 85 | 86 | 87 | class MLPQ(nn.Module): 88 | def __init__(self, obs_dim, hidden_sizes, activation): 89 | super().__init__() 90 | self.q_net = mlp([obs_dim] + list(hidden_sizes) + [1], activation) 91 | 92 | def forward(self, obs): 93 | return torch.squeeze(self.q_net(obs), -1) # Critical to ensure q has right shape. 
94 | 95 | 96 | class MLPActorCritic(nn.Module): 97 | def __init__(self, observation_space, action_space, 98 | hidden_sizes=(512, 512, 512, 512, 256), activation=nn.Tanh): 99 | super().__init__() 100 | 101 | obs_dim = observation_space 102 | self.shared = shared_mlp(obs_dim[1], hidden_sizes, activation, use_init=True) 103 | self.pi = mlp([hidden_sizes[-1], 128, action_space], activation, use_init=True) # policy head: outputs logits (the shared trunk above uses observation_space[1], e.g. 567 for (10, 567), as its input width) 104 | self.v = mlp([hidden_sizes[-1], 128, 1], activation, use_init=True) # value head: outputs q(s,a) 105 | 106 | 107 | def step(self, obs, legal_action): # sample one action; returns (action, value, log-prob) as numpy arrays 108 | obs = torch.tensor(obs).to(torch.float32) 109 | legal_action = torch.tensor(legal_action).to(torch.float32) 110 | with torch.no_grad(): 111 | shared_feature = self.shared(obs) 112 | # print(shared_feature.shape, legal_action.shape) 113 | logits = torch.squeeze(self.pi(shared_feature)) - (1 - legal_action) * 1e8 # entries where legal_action == 0 are lowered by 1e8, effectively masking them out 114 | #print('share_feature', self.pi(shared_feature).shape, 'logits', logits.shape, 'legal_action', legal_action) 115 | pi = Categorical(logits=logits) 116 | a = pi.sample() 117 | logp_a = pi.log_prob(a) # log pi of the sampled action 118 | 119 | value = torch.squeeze(self.v(shared_feature), -1) 120 | 121 | return a.numpy(), value.numpy(), logp_a.numpy() 122 | 123 | def act(self, obs): 124 | return self.step(obs)[0] # NOTE(review): step() also requires legal_action, so this call raises TypeError as written 125 | 126 | def get_weights(self): 127 | return self.state_dict() 128 | 129 | 130 | class MLPQNetwork(nn.Module): 131 | def __init__(self, observation_space, 132 | hidden_sizes=(512, 512, 512, 512, 512), activation=nn.Tanh): 133 | super().__init__() 134 | 135 | obs_dim = observation_space 136 | 137 | # build Q function 138 | self.q = MLPQ(obs_dim, hidden_sizes, activation) 139 | 140 | def load_tf_weights(self, weights): 141 | name = ['q_net.0.weight', 'q_net.0.bias', 'q_net.2.weight', 'q_net.2.bias', 'q_net.4.weight', 'q_net.4.bias', 'q_net.6.weight', 'q_net.6.bias', 'q_net.8.weight', 'q_net.8.bias', 'q_net.10.weight', 'q_net.10.bias'] # hard-coded for six Linear layers (the default hidden_sizes); even indices skip the activation modules mlp() interleaves 142 | tensor_weights = [] 143 | for weight in weights: # assumes weights is the flat [kernel, bias, kernel, bias, ...] list from the TF model - TODO confirm 144 | temp =
torch.tensor(weight).T 145 | tensor_weights.append(temp) 146 | new_weights = dict(zip(name, tensor_weights)) 147 | self.q.load_state_dict(new_weights) 148 | print('load tf weights success') 149 | 150 | def get_max_n_index(self, data, n): 151 | q_list = self.q(torch.tensor(data).to(torch.float32)) 152 | q_list = q_list.detach().numpy() 153 | return q_list.argsort()[-n:][::-1].tolist() 154 | 155 | 156 | if __name__ == '__main__': 157 | model = MLPActorCritic((10, 567), 1) 158 | # state = np.random.random((513, )) 159 | # action1 = np.random.random((54, )) 160 | # action2 = np.random.random((54, )) 161 | # action3 = np.random.random((54, )) 162 | # b = np.load("/home/zhaoyp/guandan_tog/actor_torch/debug145.npy", allow_pickle=True).item() 163 | # print(b.keys()) 164 | # print(b['obs_cut'].shape) 165 | # print(b['obs'].shape) 166 | import objgraph 167 | 168 | # print('time1') 169 | # objgraph.show_most_common_types(limit=30) 170 | # objgraph.show_growth() 171 | 172 | state = np.zeros((10,567)) 173 | legal_index = np.ones(10) 174 | 175 | # print('time2') 176 | # objgraph.show_most_common_types(limit=30) 177 | # objgraph.show_growth() 178 | 179 | # a, v, p = model.step(state, legal_index) 180 | 181 | # print('time3') 182 | # objgraph.show_most_common_types(limit=30) 183 | # objgraph.show_growth() 184 | 185 | # print(a,v,p) 186 | # print(type(a),type(v),type(p)) 187 | -------------------------------------------------------------------------------- /actor_torch/q_network.ckpt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/submit-paper/Danzero_plus/e2b900b01096e0743de945f9963df35cd544f36d/actor_torch/q_network.ckpt -------------------------------------------------------------------------------- /actor_torch/rekill.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | ps -ef | grep danserver | grep -v grep | awk '{print $2}' | xargs kill -9 3 | ps -ef | grep 
actor.py | grep -v grep | awk '{print $2}' | xargs kill -9 4 | ps -ef | grep game.py | grep -v grep | awk '{print $2}' | xargs kill -9 -------------------------------------------------------------------------------- /actor_torch/restart.py: -------------------------------------------------------------------------------- 1 | import os 2 | import psutil 3 | import time 4 | 5 | def find_procs_by_name(name): 6 | "Return a list of processes matching 'name'." 7 | ls = [] 8 | for p in psutil.process_iter(["name", "exe", "cmdline"]): 9 | if name == p.info['name'] or \ 10 | p.info['exe'] and os.path.basename(p.info['exe']) == name or \ 11 | p.info['cmdline'] and p.info['cmdline'][0] == name: 12 | ls.append(p) 13 | return ls 14 | 15 | res = find_procs_by_name('/root/miniconda3/envs/guandan/bin/python') 16 | print(res) 17 | print(len(res)) 18 | 19 | while True: 20 | time.sleep(120) 21 | res = find_procs_by_name('/root/miniconda3/envs/guandan/bin/python') 22 | # res = find_procs_by_name('python') 23 | if len(res) < 10: 24 | print('restart actor') 25 | os.system("bash /home/zhaoyp/guandan_tog/actor_torch/restart.sh") 26 | time.sleep(300) 27 | -------------------------------------------------------------------------------- /actor_torch/restart.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | nohup /home/zhaoyp/guandan_tog/actor_torch/danserver 100000 >/dev/null 2>&1 & 3 | sleep 1s 4 | nohup /root/miniconda3/envs/guandan/bin/python -u /home/zhaoyp/guandan_tog/actor_torch/actor.py > /home/zhaoyp/actor_out.log 2>&1 & 5 | sleep 1s 6 | nohup /root/miniconda3/envs/guandan/bin/python -u /home/zhaoyp/guandan_tog/actor_torch/game.py > /home/zhaoyp/game_out.log 2>&1 & 7 | -------------------------------------------------------------------------------- /actor_torch/start.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | nohup /home/zhaoyp/guandan_tog/actor_torch/danserver 100000 
>/dev/null 2>&1 & 3 | sleep 0.5s 4 | nohup /root/miniconda3/envs/guandan/bin/python -u /home/zhaoyp/guandan_tog/actor_torch/actor.py > /home/zhaoyp/actor_out.log 2>&1 & 5 | sleep 0.5s 6 | nohup /root/miniconda3/envs/guandan/bin/python -u /home/zhaoyp/guandan_tog/actor_torch/game.py > /home/zhaoyp/game_out.log 2>&1 & 7 | sleep 0.5s 8 | nohup /root/miniconda3/envs/guandan/bin/python -u /home/zhaoyp/guandan_tog/actor_torch/restart.py > /home/zhaoyp/restart_out.log 2>&1 & 9 | -------------------------------------------------------------------------------- /actor_torch/utils/__pycache__/data_trans.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/submit-paper/Danzero_plus/e2b900b01096e0743de945f9963df35cd544f36d/actor_torch/utils/__pycache__/data_trans.cpython-36.pyc -------------------------------------------------------------------------------- /actor_torch/utils/__pycache__/logger.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/submit-paper/Danzero_plus/e2b900b01096e0743de945f9963df35cd544f36d/actor_torch/utils/__pycache__/logger.cpython-36.pyc -------------------------------------------------------------------------------- /actor_torch/utils/__pycache__/utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/submit-paper/Danzero_plus/e2b900b01096e0743de945f9963df35cd544f36d/actor_torch/utils/__pycache__/utils.cpython-36.pyc -------------------------------------------------------------------------------- /actor_torch/utils/data_trans.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import os 3 | import io 4 | import pickle 5 | import time 6 | from itertools import count 7 | from pathlib import Path 8 | from typing import Any, Tuple 9 | import threading 10 | import 
torch 11 | import zmq 12 | 13 | 14 | class CPU_Unpickler(pickle.Unpickler): 15 | def find_class(self, module, name): 16 | if module == 'torch.storage' and name == '_load_from_bytes': 17 | return lambda b: torch.load(io.BytesIO(b), map_location='cpu') 18 | else: return super().find_class(module, name) 19 | 20 | 21 | def find_new_weights(current_model_id: int, ckpt_path: Path) -> Tuple[Any, int]: 22 | try: 23 | ckpt_files = sorted(os.listdir(ckpt_path), key=lambda p: int(p.split('.')[0])) 24 | latest_file = ckpt_files[-1] 25 | except IndexError: 26 | # No checkpoint file 27 | return None, -1 28 | new_model_id = int(latest_file.split('.')[0]) 29 | 30 | if int(new_model_id) > current_model_id: 31 | loaded = False 32 | while not loaded: 33 | try: 34 | with open(ckpt_path / latest_file, 'rb') as f: 35 | new_weights = CPU_Unpickler(f).load() 36 | loaded = True 37 | except (EOFError, pickle.UnpicklingError): 38 | # The file of weights does not finish writing 39 | pass 40 | return new_weights, new_model_id 41 | else: 42 | return None, current_model_id 43 | 44 | 45 | def create_experiment_dir(args, prefix: str) -> None: 46 | if args.exp_path is None: 47 | args.exp_path = prefix + datetime.datetime.fromtimestamp(time.time()).strftime('%Y-%m-%d-%H-%M-%S') 48 | args.exp_path = Path(args.exp_path) 49 | 50 | if args.exp_path.exists(): 51 | os.system(f'rm -rf {args.exp_path}') 52 | # raise FileExistsError(f'Experiment directory {str(args.exp_path)!r} already exists') 53 | 54 | args.exp_path.mkdir() 55 | 56 | 57 | def run_weights_subscriber(args, unknown_args): 58 | """Subscribe weights from Learner and save them locally""" 59 | context = zmq.Context() 60 | socket = context.socket(zmq.SUB) 61 | socket.connect(f'tcp://{args.ip}:{args.param_port}') 62 | socket.setsockopt_string(zmq.SUBSCRIBE, '') # Subscribe everything 63 | def recv_weight(): 64 | try: 65 | weights = socket.recv(flags=zmq.NOBLOCK) 66 | # Weights received 67 | with open(args.ckpt_path / f'{model_id}.pth', 'wb') as f: 
68 | f.write(weights) 69 | del weights 70 | 71 | if model_id > args.num_saved_ckpt: 72 | os.remove(args.ckpt_path / f'{model_id - args.num_saved_ckpt}.pth') 73 | except zmq.Again: 74 | pass 75 | for model_id in count(1): # Starts from 1. NOTE(review): the inner while True below has no break, so model_id never advances past 1; recv_weight (which reads model_id from this enclosing scope) therefore writes every payload to 1.pth, and its num_saved_ckpt cleanup never runs (for num_saved_ckpt >= 1). The commented-out code below broke out of the inner loop after a successful receive. 76 | t = 3 77 | recv_weight_thread = threading.Timer(t, recv_weight) # never start()ed; run() is invoked directly below 78 | while True: 79 | recv_weight_thread.run() # runs in the current thread: waits t seconds, then calls recv_weight 80 | recv_weight_thread.finished.clear() # Timer.run() sets `finished`; clearing it makes the next run() wait again 81 | # try: 82 | # weights = socket.recv(flags=zmq.NOBLOCK) 83 | # # Weights received 84 | # with open(args.ckpt_path / f'{model_id}.ckpt', 'wb') as f: 85 | # f.write(weights) 86 | 87 | # if model_id > args.num_saved_ckpt: 88 | # os.remove(args.ckpt_path / f'{model_id - args.num_saved_ckpt}.ckpt') 89 | # break 90 | # except zmq.Again: 91 | # pass 92 | 93 | # # For not cpu-intensive 94 | # time.sleep(1) 95 | 96 | -------------------------------------------------------------------------------- /create_container.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # docker network create --driver bridge --subnet=172.15.15.0/24 --gateway=172.15.15.1 guandanNet 3 | #docker run -itd --gpus all --network=guandanNet --ip 172.15.15.2 --name guandan_learner -v /home/luyd/guandan_tog/:/home/luyd/guandan_tog -w /home/luyd/guandan_tog/ nvcr.io/nvidia/tensorflow:22.02-tf1-py3 4 | for i in {3..13} 5 | do 6 | docker run -itd --network=guandanNet --ip 172.15.15.$i --name guandan_actor_$i -v /home/zhaoyp/log/log$i:/home/root/log -v /home/zhaoyp/guandan_tog:/home/zhaoyp/guandan_tog -w /home/zhaoyp/guandan_tog guandan_actor:v5 /bin/bash 7 | done 8 | for i in {14..43} 9 | do 10 | docker run -itd --network=guandanNet --ip 172.15.15.$i --name guandan_actor_$i -v /home/zhaoyp/guandan_tog:/home/zhaoyp/guandan_tog -w /home/zhaoyp/guandan_tog guandan_actor:v5 /bin/bash 11 | done 12 | 13 | -------------------------------------------------------------------------------- /learner_n/__pycache__/common.cpython-36.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/submit-paper/Danzero_plus/e2b900b01096e0743de945f9963df35cd544f36d/learner_n/__pycache__/common.cpython-36.pyc -------------------------------------------------------------------------------- /learner_n/__pycache__/common.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/submit-paper/Danzero_plus/e2b900b01096e0743de945f9963df35cd544f36d/learner_n/__pycache__/common.cpython-38.pyc -------------------------------------------------------------------------------- /learner_n/agents/__init__.py: -------------------------------------------------------------------------------- 1 | from core.registry import Registry 2 | 3 | agent_registry = Registry('Agent') 4 | 5 | from agents.dqn import * 6 | -------------------------------------------------------------------------------- /learner_n/agents/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/submit-paper/Danzero_plus/e2b900b01096e0743de945f9963df35cd544f36d/learner_n/agents/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /learner_n/agents/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/submit-paper/Danzero_plus/e2b900b01096e0743de945f9963df35cd544f36d/learner_n/agents/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /learner_n/agents/dqn/__init__.py: -------------------------------------------------------------------------------- 1 | from .guandan_agent import MCAgent 2 | 3 | __all__ = ['MCAgent'] 4 | -------------------------------------------------------------------------------- 
/learner_n/agents/dqn/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/submit-paper/Danzero_plus/e2b900b01096e0743de945f9963df35cd544f36d/learner_n/agents/dqn/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /learner_n/agents/dqn/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/submit-paper/Danzero_plus/e2b900b01096e0743de945f9963df35cd544f36d/learner_n/agents/dqn/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /learner_n/agents/dqn/__pycache__/dqn_agent.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/submit-paper/Danzero_plus/e2b900b01096e0743de945f9963df35cd544f36d/learner_n/agents/dqn/__pycache__/dqn_agent.cpython-36.pyc -------------------------------------------------------------------------------- /learner_n/agents/dqn/__pycache__/dqn_agent.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/submit-paper/Danzero_plus/e2b900b01096e0743de945f9963df35cd544f36d/learner_n/agents/dqn/__pycache__/dqn_agent.cpython-38.pyc -------------------------------------------------------------------------------- /learner_n/agents/dqn/__pycache__/guandan_agent.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/submit-paper/Danzero_plus/e2b900b01096e0743de945f9963df35cd544f36d/learner_n/agents/dqn/__pycache__/guandan_agent.cpython-36.pyc -------------------------------------------------------------------------------- /learner_n/agents/dqn/__pycache__/guandan_agent.cpython-38.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/submit-paper/Danzero_plus/e2b900b01096e0743de945f9963df35cd544f36d/learner_n/agents/dqn/__pycache__/guandan_agent.cpython-38.pyc -------------------------------------------------------------------------------- /learner_n/agents/dqn/guandan_agent.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict 2 | 3 | import models.utils as utils 4 | import numpy as np 5 | import tensorflow as tf 6 | from agents import agent_registry 7 | from core import Agent 8 | from tensorflow.train import AdamOptimizer 9 | 10 | 11 | @agent_registry.register('MC') 12 | class MCAgent(Agent): 13 | def __init__(self, model_cls, observation_space, action_space, config=None, lr=0.001, 14 | *args, **kwargs): 15 | # Define parameters 16 | self.lr = lr 17 | self.lamda = 0.65 18 | 19 | self.policy_model = None 20 | self.loss = None 21 | self.train_q = None 22 | 23 | self.target_ph = utils.placeholder(shape=(1)) 24 | self.old_q = utils.placeholder(shape=(1)) 25 | 26 | super(MCAgent, self).__init__(model_cls, observation_space, action_space, config, *args, **kwargs) 27 | 28 | def build(self) -> None: 29 | self.policy_model = self.model_instances[0] 30 | # cliped_q = tf.clip_by_value(self.old_q / self.policy_model.values, 1-self.lamda, 1+self.lamda) 31 | # self.loss = tf.reduce_mean((cliped_q - self.target_ph) ** 2) 32 | self.loss = tf.reduce_mean((self.policy_model.values - self.target_ph) ** 2) 33 | self.train_q = tf.train.RMSPropOptimizer(learning_rate=self.lr, epsilon=1e-5).minimize(self.loss) 34 | self.policy_model.sess.run(tf.global_variables_initializer()) 35 | 36 | 37 | def learn(self, training_data: Dict[str, np.ndarray], *args, **kwargs) -> None: 38 | x_no_action, action, q, reward = [training_data[key] for key in ['x_no_action', 'action', 'q', 'reward']] 39 | x_batch = np.concatenate([x_no_action, action], axis=-1) 
40 | 41 | _, loss, values = self.policy_model.sess.run([self.train_q, self.loss, self.policy_model.values], 42 | feed_dict={ 43 | self.policy_model.x_ph: x_batch, 44 | self.old_q: q, 45 | self.target_ph: reward}) 46 | return { 47 | 'loss': loss, 48 | 'values': values 49 | } 50 | 51 | def set_weights(self, weights, *args, **kwargs) -> None: 52 | self.policy_model.set_weights(weights) 53 | 54 | def get_weights(self, *args, **kwargs) -> Any: 55 | return self.policy_model.get_weights() 56 | 57 | def save(self, path, *args, **kwargs) -> None: 58 | self.policy_model.save(path) 59 | 60 | def load(self, path, *args, **kwargs) -> None: 61 | self.policy_model.load(path) 62 | -------------------------------------------------------------------------------- /learner_n/build/conda/env_linux.yaml: -------------------------------------------------------------------------------- 1 | name: framework 2 | channels: 3 | - defaults 4 | dependencies: 5 | - _libgcc_mutex=0.1=main 6 | - _openmp_mutex=4.5=1_gnu 7 | - _tflow_select=2.1.0=gpu 8 | - absl-py=0.13.0=py36h06a4308_0 9 | - astor=0.8.1=py36h06a4308_0 10 | - blas=1.0=mkl 11 | - c-ares=1.17.1=h27cfd23_0 12 | - ca-certificates=2021.7.5=h06a4308_1 13 | - certifi=2021.5.30=py36h06a4308_0 14 | - coverage=5.5=py36h27cfd23_2 15 | - cudatoolkit=10.0.130=0 16 | - cudnn=7.6.5=cuda10.0_0 17 | - cupti=10.0.130=0 18 | - cython=0.29.24=py36h295c915_0 19 | - gast=0.2.2=py36_0 20 | - google-pasta=0.2.0=py_0 21 | - grpcio=1.36.1=py36h2157cd5_1 22 | - h5py=2.10.0=py36hd6299e0_1 23 | - hdf5=1.10.6=hb1b8bf9_0 24 | - importlib-metadata=3.10.0=py36h06a4308_0 25 | - intel-openmp=2021.3.0=h06a4308_3350 26 | - keras-applications=1.0.8=py_1 27 | - keras-preprocessing=1.1.2=pyhd3eb1b0_0 28 | - ld_impl_linux-64=2.35.1=h7274673_9 29 | - libffi=3.3=he6710b0_2 30 | - libgcc-ng=9.3.0=h5101ec6_17 31 | - libgfortran-ng=7.5.0=ha8ba4b0_17 32 | - libgfortran4=7.5.0=ha8ba4b0_17 33 | - libgomp=9.3.0=h5101ec6_17 34 | - libprotobuf=3.17.2=h4ff587b_1 35 | - 
libstdcxx-ng=9.3.0=hd4cf53a_17 36 | - markdown=3.3.4=py36h06a4308_0 37 | - mkl=2020.2=256 38 | - mkl-service=2.3.0=py36he8ac12f_0 39 | - mkl_fft=1.3.0=py36h54f3939_0 40 | - mkl_random=1.1.1=py36h0573a6f_0 41 | - ncurses=6.2=he6710b0_1 42 | - numpy=1.19.2=py36h54aff64_0 43 | - numpy-base=1.19.2=py36hfa32c7d_0 44 | - openssl=1.1.1k=h27cfd23_0 45 | - opt_einsum=3.3.0=pyhd3eb1b0_1 46 | - pip=21.0.1=py36h06a4308_0 47 | - protobuf=3.17.2=py36h295c915_0 48 | - python=3.6.13=h12debd9_1 49 | - readline=8.1=h27cfd23_0 50 | - scipy=1.5.2=py36h0b6359f_0 51 | - setuptools=52.0.0=py36h06a4308_0 52 | - six=1.16.0=pyhd3eb1b0_0 53 | - sqlite=3.36.0=hc218d9a_0 54 | - tensorboard=1.15.0=pyhb230dea_0 55 | - tensorflow=1.15.0=gpu_py36h5a509aa_0 56 | - tensorflow-base=1.15.0=gpu_py36h9dcbed7_0 57 | - tensorflow-estimator=1.15.1=pyh2649769_0 58 | - tensorflow-gpu=1.15.0=h0d30ee6_0 59 | - termcolor=1.1.0=py36h06a4308_1 60 | - tk=8.6.10=hbc83047_0 61 | - typing_extensions=3.10.0.0=pyh06a4308_0 62 | - webencodings=0.5.1=py36_1 63 | - werkzeug=0.16.1=py_0 64 | - wheel=0.37.0=pyhd3eb1b0_0 65 | - wrapt=1.12.1=py36h7b6447c_1 66 | - xz=5.2.5=h7b6447c_0 67 | - zipp=3.5.0=pyhd3eb1b0_0 68 | - zlib=1.2.11=h7b6447c_3 69 | - pip: 70 | - atari-py==0.2.6 71 | - cffi==1.14.6 72 | - cloudpickle==1.6.0 73 | - cmake==3.21.1.post1 74 | - dataclasses==0.8 75 | - gym==0.18.3 76 | - horovod==0.22.1 77 | - opencv-python==4.5.3.56 78 | - pillow==8.2.0 79 | - psutil==5.8.0 80 | - pyarrow==5.0.0 81 | - pycparser==2.20 82 | - pyglet==1.5.15 83 | - pyyaml==5.4.1 84 | - pyzmq==22.2.1 85 | - zmq==0.0.0 86 | -------------------------------------------------------------------------------- /learner_n/build/docker/Dockerfile: -------------------------------------------------------------------------------- 1 | # ================================================================== 2 | # module list 3 | # ------------------------------------------------------------------ 4 | # python 3.6 (apt) 5 | # tensorflow 1.15 (pip) 6 | # 
================================================================== 7 | 8 | FROM tensorflow/tensorflow:1.15.4-gpu 9 | ENV LANG C.UTF-8 10 | RUN APT_INSTALL="apt-get install -y --no-install-recommends" && \ 11 | PIP_INSTALL="pip --no-cache-dir install --upgrade" && \ 12 | GIT_CLONE="git clone --depth 10" && \ 13 | 14 | apt-get update && \ 15 | 16 | # ================================================================== 17 | # tools 18 | # ------------------------------------------------------------------ 19 | 20 | DEBIAN_FRONTEND=noninteractive $APT_INSTALL \ 21 | apt-utils \ 22 | build-essential \ 23 | ca-certificates \ 24 | openssh-server \ 25 | net-tools \ 26 | iputils-ping \ 27 | cmake \ 28 | wget \ 29 | git \ 30 | vim \ 31 | libssl-dev \ 32 | libxss1 \ 33 | libgl1-mesa-glx \ 34 | htop \ 35 | curl \ 36 | unzip \ 37 | unrar \ 38 | && \ 39 | 40 | # ================================================================== 41 | # python 42 | # ------------------------------------------------------------------ 43 | 44 | DEBIAN_FRONTEND=noninteractive $APT_INSTALL \ 45 | software-properties-common \ 46 | && \ 47 | add-apt-repository ppa:deadsnakes/ppa && \ 48 | apt-get update && \ 49 | DEBIAN_FRONTEND=noninteractive $APT_INSTALL \ 50 | python3.6 \ 51 | python3.6-dev \ 52 | python3-distutils-extra \ 53 | && \ 54 | wget -O ~/get-pip.py \ 55 | https://bootstrap.pypa.io/get-pip.py && \ 56 | python3.6 ~/get-pip.py && \ 57 | rm -r -f /usr/local/bin/python* && \ 58 | ln -s /usr/bin/python3.6 /usr/local/bin/python3 && \ 59 | ln -s /usr/bin/python3.6 /usr/local/bin/python && \ 60 | $PIP_INSTALL \ 61 | psutil \ 62 | numpy \ 63 | scipy \ 64 | pandas \ 65 | cloudpickle \ 66 | Cython \ 67 | tqdm \ 68 | && \ 69 | 70 | # ================================================================== 71 | # tensorflow 72 | # ------------------------------------------------------------------ 73 | 74 | $PIP_INSTALL \ 75 | tensorflow-gpu==1.15.4 \ 76 | tensorflow-probability~=0.7 \ 77 | && \ 78 | 79 | # 
================================================================== 80 | # framework 81 | # ------------------------------------------------------------------ 82 | 83 | $PIP_INSTALL \ 84 | pyarrow \ 85 | pyzmq \ 86 | opencv-python \ 87 | atari-py \ 88 | gym \ 89 | horovod \ 90 | psutil \ 91 | influxdb \ 92 | && \ 93 | 94 | # ================================================================== 95 | # config & cleanup 96 | # ------------------------------------------------------------------ 97 | 98 | ldconfig && \ 99 | apt-get clean && \ 100 | apt-get autoremove && \ 101 | rm -rf /var/lib/apt/lists/* /tmp/* ~/* 102 | 103 | EXPOSE 6006 -------------------------------------------------------------------------------- /learner_n/common.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import time 3 | import warnings 4 | from pathlib import Path 5 | 6 | import yaml 7 | 8 | from agents import agent_registry 9 | from core import Agent 10 | from models import model_registry 11 | 12 | 13 | def get_agent(args, unknown_args): 14 | model_cls = model_registry.get(args.model) 15 | agent_cls = agent_registry.get(args.alg) 16 | agent = agent_cls(model_cls, args.observation_space, args.action_space, args.agent_config, **unknown_args) 17 | return agent 18 | 19 | 20 | def load_yaml_config(args, role_type: str) -> None: 21 | if role_type not in {'actor', 'learner'}: 22 | raise ValueError('Invalid role type') 23 | 24 | # Load config file 25 | if args.config is not None: 26 | with open(args.config) as f: 27 | config = yaml.load(f, Loader=yaml.FullLoader) 28 | else: 29 | config = None 30 | 31 | if config is not None and isinstance(config, dict): 32 | if role_type in config: 33 | for k, v in config[role_type].items(): 34 | if k in args: 35 | setattr(args, k, v) 36 | else: 37 | warnings.warn(f"Invalid config item '{k}' ignored", RuntimeWarning) 38 | args.agent_config = config['agent'] if 'agent' in config else None 39 | else: 40 | 
args.agent_config = None 41 | 42 | 43 | def save_yaml_config(config_path: Path, args, role_type: str, agent: Agent) -> None: 44 | class Dumper(yaml.Dumper): 45 | def increase_indent(self, flow=False, *_, **__): 46 | return super().increase_indent(flow=flow, indentless=False) 47 | 48 | if role_type not in {'actor', 'learner'}: 49 | raise ValueError('Invalid role type') 50 | 51 | with open(config_path, 'w') as f: 52 | args_config = {k: v for k, v in vars(args).items() if 53 | not k.endswith('path') and k != 'agent_config' and k != 'config'} 54 | yaml.dump({role_type: args_config}, f, sort_keys=False, Dumper=Dumper) 55 | f.write('\n') 56 | yaml.dump({'agent': agent.export_config()}, f, sort_keys=False, Dumper=Dumper) 57 | 58 | 59 | def create_experiment_dir(args, prefix: str) -> None: 60 | if args.exp_path is None: 61 | args.exp_path = prefix + datetime.datetime.fromtimestamp(time.time()).strftime('%Y-%m-%d-%H-%M-%S') 62 | args.exp_path = Path(args.exp_path) 63 | 64 | if args.exp_path.exists(): 65 | raise FileExistsError(f'Experiment directory {str(args.exp_path)!r} already exists') 66 | 67 | args.exp_path.mkdir() 68 | -------------------------------------------------------------------------------- /learner_n/core/__init__.py: -------------------------------------------------------------------------------- 1 | from .agent import Agent 2 | from .env import Env 3 | from .mem_pool import MemPool 4 | from .model import Model 5 | from .registry import Registry 6 | -------------------------------------------------------------------------------- /learner_n/core/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/submit-paper/Danzero_plus/e2b900b01096e0743de945f9963df35cd544f36d/learner_n/core/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /learner_n/core/__pycache__/__init__.cpython-38.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/submit-paper/Danzero_plus/e2b900b01096e0743de945f9963df35cd544f36d/learner_n/core/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /learner_n/core/__pycache__/agent.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/submit-paper/Danzero_plus/e2b900b01096e0743de945f9963df35cd544f36d/learner_n/core/__pycache__/agent.cpython-36.pyc -------------------------------------------------------------------------------- /learner_n/core/__pycache__/agent.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/submit-paper/Danzero_plus/e2b900b01096e0743de945f9963df35cd544f36d/learner_n/core/__pycache__/agent.cpython-38.pyc -------------------------------------------------------------------------------- /learner_n/core/__pycache__/env.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/submit-paper/Danzero_plus/e2b900b01096e0743de945f9963df35cd544f36d/learner_n/core/__pycache__/env.cpython-36.pyc -------------------------------------------------------------------------------- /learner_n/core/__pycache__/env.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/submit-paper/Danzero_plus/e2b900b01096e0743de945f9963df35cd544f36d/learner_n/core/__pycache__/env.cpython-38.pyc -------------------------------------------------------------------------------- /learner_n/core/__pycache__/mem_pool.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/submit-paper/Danzero_plus/e2b900b01096e0743de945f9963df35cd544f36d/learner_n/core/__pycache__/mem_pool.cpython-36.pyc -------------------------------------------------------------------------------- /learner_n/core/__pycache__/mem_pool.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/submit-paper/Danzero_plus/e2b900b01096e0743de945f9963df35cd544f36d/learner_n/core/__pycache__/mem_pool.cpython-38.pyc -------------------------------------------------------------------------------- /learner_n/core/__pycache__/model.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/submit-paper/Danzero_plus/e2b900b01096e0743de945f9963df35cd544f36d/learner_n/core/__pycache__/model.cpython-36.pyc -------------------------------------------------------------------------------- /learner_n/core/__pycache__/model.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/submit-paper/Danzero_plus/e2b900b01096e0743de945f9963df35cd544f36d/learner_n/core/__pycache__/model.cpython-38.pyc -------------------------------------------------------------------------------- /learner_n/core/__pycache__/registry.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/submit-paper/Danzero_plus/e2b900b01096e0743de945f9963df35cd544f36d/learner_n/core/__pycache__/registry.cpython-36.pyc -------------------------------------------------------------------------------- /learner_n/core/__pycache__/registry.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/submit-paper/Danzero_plus/e2b900b01096e0743de945f9963df35cd544f36d/learner_n/core/__pycache__/registry.cpython-38.pyc 
-------------------------------------------------------------------------------- /learner_n/core/__pycache__/utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/submit-paper/Danzero_plus/e2b900b01096e0743de945f9963df35cd544f36d/learner_n/core/__pycache__/utils.cpython-36.pyc -------------------------------------------------------------------------------- /learner_n/core/__pycache__/utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/submit-paper/Danzero_plus/e2b900b01096e0743de945f9963df35cd544f36d/learner_n/core/__pycache__/utils.cpython-38.pyc -------------------------------------------------------------------------------- /learner_n/core/agent.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from abc import ABC, abstractmethod 3 | from pathlib import Path 4 | from typing import Any, Type, Union, Dict, List, Tuple 5 | 6 | import numpy as np 7 | 8 | from .model import Model 9 | from .utils import get_config_params 10 | 11 | 12 | class Agent(ABC): 13 | def __init__(self, model_cls: Type[Model], observation_space: Any, action_space: Any, config: dict = None, 14 | *args, **kwargs) -> None: 15 | """ 16 | This method MUST be called between (0.) and (4.) in subclasses for doing initialization works 17 | 18 | 0. [IN '__init__' of SUBCLASSES] Define parameters, model configurations and other related variables 19 | 1. If 'config' is not 'None', set specified configuration parameters (which appear after 'config') for agent or 20 | specified configurations for model 21 | 2. Initialize model instances 22 | 3. Build training part of computational graph 23 | 4. 
[IN '__init__' of SUBCLASSES] Do other operations if necessary 24 | 25 | :param model_cls: Model class that agent adopts 26 | :param observation_space: Env observation space 27 | :param action_space: Env action space 28 | :param config: Configurations for agent and models 29 | :param args: Positional configurations for agent only (ignored if specified in 'config') 30 | :param kwargs: Keyword configurations for agent only (ignored if specified in 'config') 31 | """ 32 | self.model_cls = model_cls 33 | self.observation_space = observation_space 34 | self.action_space = action_space 35 | 36 | # 1. Set configurations 37 | if config is not None: 38 | self.load_config(config) 39 | 40 | # 2. Initialize instances of 'model_cls' 41 | self.model_instances = None 42 | self._init_model_instances(config) 43 | 44 | # 3. Build training part of computational graph 45 | self.build() 46 | 47 | @abstractmethod 48 | def build(self) -> None: 49 | """Build computational graph for training""" 50 | pass 51 | 52 | @abstractmethod 53 | def set_weights(self, *args, **kwargs) -> None: 54 | pass 55 | 56 | @abstractmethod 57 | def get_weights(self, *args, **kwargs) -> Any: 58 | pass 59 | 60 | @abstractmethod 61 | def save(self, path: Path, *args, **kwargs) -> None: 62 | """Save the checkpoint file""" 63 | pass 64 | 65 | @abstractmethod 66 | def load(self, path: Path, *args, **kwargs) -> None: 67 | """Load the checkpoint file""" 68 | pass 69 | 70 | @abstractmethod 71 | def learn(self, training_data: Dict[str, np.ndarray], *args, **kwargs) -> Union[Dict[str, float], None]: 72 | """ 73 | Train the agent with data generated by 'prepare_update' 74 | :param training_data: A dictionary of lists of training_data, such as: 75 | {'state': [[1, 2], [3, 4]], 'action': [1, 0], 'value': [0.1, 0.3]} 76 | :param args: Optional positional arguments 77 | :param kwargs: Optional keyword arguments 78 | :return: Training statistics 79 | """ 80 | pass 81 | 82 | def export_config(self) -> dict: 83 | """Export 
dictionary as configurations""" 84 | param_dict = {p: getattr(self, p) for p in get_config_params(self)} 85 | 86 | if len(self.model_instances) == 1: 87 | model_config = self.model_instances[0].export_config() 88 | else: 89 | model_config = [x.export_config() for x in self.model_instances] 90 | param_dict.update({'model': model_config}) 91 | 92 | return param_dict 93 | 94 | def load_config(self, config: dict) -> None: 95 | """Load dictionary as configurations and initialize model instances""" 96 | for key, val in config.items(): 97 | if key in get_config_params(self): 98 | self.__dict__[key] = val 99 | elif key != 'model': 100 | warnings.warn(f"Invalid config item '{key}' ignored", RuntimeWarning) 101 | 102 | def predict(self, state: Any, *args, **kwargs) -> Any: 103 | """Get the action distribution at specific state""" 104 | return self.model_instances[0].forward(state, *args, **kwargs) 105 | 106 | def policy(self, state: Any, *args, **kwargs) -> Any: 107 | """Choose action during exploitation""" 108 | return np.argmax(self.predict(state, *args, **kwargs)[0]) 109 | 110 | def sample(self, state: Any, *args, **kwargs) -> Tuple[Any, Dict]: 111 | """Return action and other information (value, distribution et al) during exploration/sampling""" 112 | p = self.predict(state, *args, **kwargs)[0] 113 | return np.random.choice(len(p), p=p), {} 114 | 115 | def _init_model_instances(self, config: Union[dict, None]) -> None: 116 | """Initialize model instances""" 117 | self.model_instances = [] 118 | 119 | def create_model_instance(_c: dict): 120 | ret = {} 121 | for k, v in _c.items(): 122 | if k in valid_config: 123 | ret[k] = v 124 | else: 125 | warnings.warn(f"Invalid config item '{k}' ignored", RuntimeWarning) 126 | self.model_instances.append(self.model_cls(self.observation_space, self.action_space, **ret)) 127 | 128 | if config is not None and 'model' in config: 129 | model_config = config['model'] 130 | valid_config = get_config_params(self.model_cls) 131 | 132 | if 
isinstance(model_config, list): 133 | for _, c in enumerate(model_config): 134 | create_model_instance(c) 135 | elif isinstance(model_config, dict): 136 | create_model_instance(model_config) 137 | else: 138 | self.model_instances.append(self.model_cls(self.observation_space, self.action_space)) 139 | -------------------------------------------------------------------------------- /learner_n/core/env.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Any, Tuple 3 | 4 | 5 | class Env(ABC): 6 | def __init__(self, *args, **kwargs) -> None: 7 | pass 8 | 9 | @abstractmethod 10 | def step(self, action: Any, *args, **kwargs) -> Tuple[Any, Any, Any, Any]: 11 | pass 12 | 13 | @abstractmethod 14 | def reset(self, *args, **kwargs) -> Any: 15 | pass 16 | 17 | @abstractmethod 18 | def get_action_space(self) -> Any: 19 | pass 20 | 21 | @abstractmethod 22 | def get_observation_space(self) -> Any: 23 | pass 24 | 25 | @abstractmethod 26 | def calc_reward(self, *args, **kwargs) -> Any: 27 | """Reshape rewards""" 28 | pass 29 | 30 | @abstractmethod 31 | def render(self, *args, **kwargs) -> None: 32 | pass 33 | 34 | @abstractmethod 35 | def close(self) -> None: 36 | pass 37 | -------------------------------------------------------------------------------- /learner_n/core/mem_pool.py: -------------------------------------------------------------------------------- 1 | import random 2 | import time 3 | from collections import defaultdict, deque 4 | from multiprocessing.managers import BaseManager 5 | from typing import Dict, List 6 | 7 | import numpy as np 8 | 9 | __all__ = ['MemPool', 'MultiprocessingMemPool', 'MemPoolManager'] 10 | 11 | 12 | class MemPool: 13 | def __init__(self, capacity: int = None, keys: List[str] = None) -> None: 14 | self._keys = keys 15 | if keys is None: 16 | self.data = defaultdict(lambda: deque(maxlen=capacity)) 17 | else: 18 | self.data = {key: 
deque(maxlen=capacity) for key in keys} 19 | 20 | def push(self, data: Dict[str, np.ndarray]) -> None: 21 | """Push data into memory pool""" 22 | for key, value in data.items(): 23 | self.data[key].extend(value) 24 | 25 | if self._keys is None: 26 | self._keys = list(self.data.keys()) 27 | 28 | def sample(self, size: int = -1) -> Dict[str, np.ndarray]: 29 | """ 30 | Sample training data from memory pool 31 | :param size: The number of sample data, default '-1' that indicates all data 32 | :return: The sampled and concatenated training data 33 | """ 34 | 35 | num = len(self) 36 | indices = list(range(num)) 37 | if 0 < size < num: 38 | indices = random.sample(indices, size) 39 | 40 | result = {} 41 | for key in self._keys: 42 | result[key] = np.stack([self.data[key][idx] for idx in indices]) 43 | return result 44 | 45 | def clear(self) -> None: 46 | """Clear all data""" 47 | for key in self._keys: 48 | self.data[key].clear() 49 | 50 | def __len__(self): 51 | if self._keys is None: 52 | return 0 53 | return len(self.data[self._keys[0]]) 54 | 55 | 56 | class MultiprocessingMemPool(MemPool): 57 | def __init__(self, capacity: int = None, keys: List[str] = None) -> None: 58 | super().__init__(capacity, keys) 59 | 60 | self._receiving_data_throughput = None 61 | self._consuming_data_throughput = None 62 | 63 | def push(self, data: Dict[str, np.ndarray]) -> None: 64 | super().push(data) 65 | 66 | if self._receiving_data_throughput is not None: 67 | self._receiving_data_throughput += len(data[self._keys[0]]) 68 | 69 | def sample(self, size: int = -1) -> Dict[str, np.ndarray]: 70 | data = super().sample(size) 71 | 72 | if self._consuming_data_throughput is not None: 73 | self._consuming_data_throughput += len(data[self._keys[0]]) 74 | 75 | # super().clear() 76 | 77 | return data 78 | 79 | def clear(self) -> None: 80 | super().clear() 81 | 82 | self._receiving_data_throughput = None 83 | self._consuming_data_throughput = None 84 | 85 | def _get_receiving_data_throughput(self): 
86 | return self._receiving_data_throughput 87 | 88 | def _get_consuming_data_throughput(self): 89 | return self._consuming_data_throughput 90 | 91 | def _reset_receiving_data_throughput(self): 92 | self._receiving_data_throughput = 0 93 | 94 | def _reset_consuming_data_throughput(self): 95 | self._consuming_data_throughput = 0 96 | 97 | @classmethod 98 | def record_throughput(cls, obj, interval=10): 99 | """Print receiving and consuming periodically""" 100 | 101 | while True: 102 | obj._reset_receiving_data_throughput() 103 | obj._reset_consuming_data_throughput() 104 | 105 | time.sleep(interval) 106 | 107 | print(f'Receiving FPS: {obj._get_receiving_data_throughput() / interval:.2f}, ' 108 | f'Consuming FPS: {obj._get_consuming_data_throughput() / interval:.2f}') 109 | 110 | 111 | class MemPoolManager(BaseManager): 112 | pass 113 | 114 | 115 | MemPoolManager.register('MemPool', MultiprocessingMemPool, 116 | exposed=['__len__', 'push', 'sample', 'clear', '_get_receiving_data_throughput', 117 | '_get_consuming_data_throughput', '_reset_receiving_data_throughput', 118 | '_reset_consuming_data_throughput']) 119 | -------------------------------------------------------------------------------- /learner_n/core/model.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from pathlib import Path 3 | from typing import Any 4 | 5 | from .utils import get_config_params 6 | 7 | 8 | class Model(ABC): 9 | def __init__(self, observation_space: Any, action_space: Any, config: dict = None, model_id: str = '0', 10 | *args, **kwargs) -> None: 11 | """ 12 | This method MUST be called after (0.) in subclasses 13 | 14 | 0. [IN '__init__' of SUBCLASSES] Define parameters, layers, tensors and other related variables 15 | 1. If 'config' is not 'None', set specified configuration parameters (which appear after 'config') 16 | 2. 
Build model 17 | 18 | :param model_id: The identifier of the model 19 | :param config: Configurations of hyper-parameters 20 | :param args: Positional configurations (ignored if specified in 'config') 21 | :param kwargs: Keyword configurations (ignored if specified in 'config') 22 | """ 23 | self.observation_space = observation_space 24 | self.action_space = action_space 25 | self.model_id = model_id 26 | self.config = config 27 | 28 | # 1. Set configurations 29 | if config is not None: 30 | self.load_config(config) 31 | 32 | # 2. Build up model 33 | self.build() 34 | 35 | @abstractmethod 36 | def build(self, *args, **kwargs) -> None: 37 | """Build the computational graph""" 38 | pass 39 | 40 | @abstractmethod 41 | def set_weights(self, weights: Any, *args, **kwargs) -> None: 42 | pass 43 | 44 | @abstractmethod 45 | def get_weights(self, *args, **kwargs) -> Any: 46 | pass 47 | 48 | @abstractmethod 49 | def forward(self, states: Any, *args, **kwargs) -> Any: 50 | pass 51 | 52 | @abstractmethod 53 | def save(self, path: Path, *args, **kwargs) -> None: 54 | pass 55 | 56 | @abstractmethod 57 | def load(self, path: Path, *args, **kwargs) -> None: 58 | pass 59 | 60 | def export_config(self) -> dict: 61 | """Export dictionary as configurations""" 62 | config_params = get_config_params(self) 63 | 64 | return {p: getattr(self, p) for p in config_params} 65 | 66 | def load_config(self, config: dict) -> None: 67 | """Load dictionary as configurations and build model"""  # NOTE(review): despite the docstring this only copies values into self.__dict__; the graph is built afterwards by self.build() in __init__
68 | for key, val in config.items(): 69 | if key in get_config_params(Model.__init__):  # NOTE(review): if this is core.utils.get_config_params, it raises ValueError for anything that is not an Agent/Model class or instance, and Model.__init__ is a plain function, so this looks like it raises whenever config is not None; get_config_params(self) was probably intended - confirm
70 | self.__dict__[key] = val 71 | -------------------------------------------------------------------------------- /learner_n/core/registry.py: -------------------------------------------------------------------------------- 1 | class Registry: 2 | """A registry to map strings to classes""" 3 | 4 | def __init__(self, name: str) -> None: 5 | self._name = name 6 | self._obj_map = {} 7 | 8 | def do_register(self, name, cls): 9
| assert name not in self._obj_map, f'An object named {name!r} was already registered in {self._name!r} registry!'  # NOTE(review): assert is skipped under python -O, in which case a duplicate name silently overwrites the earlier entry
10 | self._obj_map[name] = cls 11 | 12 | def register(self, name):  # decorator factory: @registry.register('name') registers the decorated class and returns it unchanged
13 | def _register(cls): 14 | self.do_register(name, cls) 15 | return cls 16 | 17 | return _register 18 | 19 | def get(self, name): 20 | ret = self._obj_map.get(name) 21 | if ret is None: 22 | raise KeyError(f'No object named {name!r} found in {self._name!r} registry!') 23 | return ret 24 | -------------------------------------------------------------------------------- /learner_n/core/utils.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | from typing import List 3 | 4 | 5 | def get_config_params(obj_or_cls) -> List[str]: 6 | """ 7 | Return configurable parameters in 'Agent.__init__' and 'Model.__init__' which appear after 'config' 8 | :param obj_or_cls: An instance of 'Agent' / 'Model' OR their corresponding classes (NOT base classes) 9 | :return: A list of configurable parameters 10 | """ 11 | import core # Import inside function to avoid cyclic import 12 | 13 | if inspect.isclass(obj_or_cls): 14 | if not issubclass(obj_or_cls, core.Agent) and not issubclass(obj_or_cls, core.Model): 15 | raise ValueError("Only accepts subclasses of 'Agent' or 'Model'") 16 | else: 17 | if not isinstance(obj_or_cls, core.Agent) and not isinstance(obj_or_cls, core.Model): 18 | raise ValueError("Only accepts instances 'Agent' or 'Model'") 19 | 20 | sig = list(inspect.signature(obj_or_cls.__init__).parameters.keys()) 21 | 22 | config_params = [] 23 | config_part = False 24 | for param in sig: 25 | if param == 'config': 26 | # Following parameters should be what we want 27 | config_part = True 28 | elif param in {'args', 'kwargs'}:  # skip parameters literally named args/kwargs (the *args/**kwargs catch-alls)
29 | pass 30 | elif config_part: 31 | config_params.append(param) 32 | 33 | return config_params 34 | -------------------------------------------------------------------------------- /learner_n/kill_all.sh:
-------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | for i in {3..83} 3 | do 4 | sshpass ssh root@172.15.15.$i "bash /home/zhaoyp/guandan/actor_n/kill.sh" 5 | done 6 | ps aux|grep python|grep -v grep|cut -c 9-15|xargs kill -9  # NOTE(review): takes characters 9-15 of each `ps aux` line (presumably the PID column) and force-kills every line mentioning python on this host, not only the learner; pgrep/pkill -f would be less fragile - confirm intent
-------------------------------------------------------------------------------- /learner_n/kill_learner.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # ps -ef | grep learner.py | awk '{print $2}' | xargs kill -9 3 | ps aux|grep python|grep -v grep|cut -c 9-15|xargs kill -9  # NOTE(review): same column-based PID extraction as kill_all.sh; kills every python process on the host, not just learner.py
4 | -------------------------------------------------------------------------------- /learner_n/learner.py: -------------------------------------------------------------------------------- 1 | import multiprocessing 2 | import os 3 | import pickle 4 | import warnings 5 | from argparse import ArgumentParser 6 | from collections import defaultdict 7 | from multiprocessing import Process 8 | 9 | import numpy as np 10 | import tensorflow as tf 11 | import zmq 12 | from pyarrow import deserialize 13 | from tensorflow.keras.backend import set_session 14 | 15 | from common import (create_experiment_dir, get_agent, load_yaml_config, 16 | save_yaml_config) 17 | from core.mem_pool import MemPoolManager, MultiprocessingMemPool 18 | from utils import logger 19 | from utils.cmdline import parse_cmdline_kwargs 20 | 21 | warnings.filterwarnings("ignore") 22 | os.environ["TF_CPP_MIN_LOG_LEVEL"] = '3' 23 | tf.logging.set_verbosity(tf.logging.ERROR) 24 | config = tf.ConfigProto() 25 | config.gpu_options.allow_growth = True 26 | set_session(tf.Session(config=config)) 27 | 28 | parser = ArgumentParser() 29 | parser.add_argument('--alg', type=str, default='MC', help='The RL algorithm') 30 | parser.add_argument('--env', type=str, default='GuanDan', help='The game environment') 31 | parser.add_argument('--data_port', type=int, default=5000, help='Learner server port to receive training
data') 32 | parser.add_argument('--param_port', type=int, default=5001, help='Learner server to publish model parameters') 33 | parser.add_argument('--model', type=str, default='guandan_model', help='Training model') 34 | parser.add_argument('--pool_size', type=int, default=1024, help='The max length of data pool') 35 | parser.add_argument('--batch_size', type=int, default=32, help='The batch size for training') 36 | parser.add_argument('--training_freq', type=int, default=2,  # NOTE(review): the help text says this may be fractional, but type=int rejects values such as 0.5 on the command line (and num_receptions below is an integer Value)
37 | help='How many receptions of new data are between each training, ' 38 | 'which can be fractional to represent more than one training per reception') 39 | parser.add_argument('--keep_training', type=bool, default=False,  # NOTE(review): type=bool turns any non-empty string, including 'False', into True
40 | help="No matter whether new data is received recently, keep training as long as the data is enough " 41 | "and ignore `--training_freq`") 42 | parser.add_argument('--config', type=str, default=None, help='Directory to config file') 43 | parser.add_argument('--exp_path', type=str, default=None, help='Directory to save logging data and config file') 44 | parser.add_argument('--record_throughput_interval', type=int, default=5, 45 | help='The time interval between each throughput record') 46 | parser.add_argument('--num_envs', type=int, default=1, help='The number of environment copies') 47 | parser.add_argument('--ckpt_save_freq', type=int, default=3000, help='The number of updates between each weights saving') 48 | parser.add_argument('--ckpt_save_type', type=str, default='weight', help='Type of checkpoint file will be recorded : weight(smaller) or checkpoint(bigger') 49 | parser.add_argument('--observation_space', type=int, default=(567,),  # NOTE(review): type=int cannot parse a tuple from the command line (only the default works) and the help text looks copy-pasted; the same applies to --action_space
50 | help='The YAML configuration file') 51 | parser.add_argument('--action_space', type=int, default=(5, 216), 52 | help='The YAML configuration file') 53 | parser.add_argument('--epsilon', type=float, default=0.01, 54 | help='Epsilon') 55 | 56 | def main(): 57 | # Parse input parameters 58 | args, unknown_args =
parser.parse_known_args() 59 | unknown_args = parse_cmdline_kwargs(unknown_args) 60 | 61 | # Load config file 62 | load_yaml_config(args, 'learner') 63 | 64 | # Expose socket to actor(s) 65 | context = zmq.Context() 66 | weights_socket = context.socket(zmq.PUB) 67 | weights_socket.bind(f'tcp://*:{args.param_port}') 68 | 69 | agent = get_agent(args, unknown_args) 70 | # with open('./last.ckpt', 'rb') as f: 71 | # weight = pickle.load(f) 72 | # agent.set_weights(weight) 73 | # print('Fineturn Success') 74 | 75 | # Configure experiment directory 76 | create_experiment_dir(args, 'LEARNER-') 77 | save_yaml_config(args.exp_path / 'config.yaml', args, 'learner', agent) 78 | args.log_path = args.exp_path / 'log' 79 | args.ckpt_path = args.exp_path / 'ckpt' 80 | args.ckpt_path.mkdir() 81 | args.log_path.mkdir() 82 | 83 | logger.configure(str(args.log_path)) 84 | 85 | receiving_condition = multiprocessing.Condition() 86 | num_receptions = multiprocessing.Value('i', 0) 87 | 88 | # Start memory pool in another process 89 | manager = MemPoolManager() 90 | manager.start() 91 | mem_pool = manager.MemPool(capacity=args.pool_size) 92 | Process(target=recv_data, 93 | args=(args.data_port, mem_pool, receiving_condition, num_receptions, args.keep_training)).start() 94 | 95 | # Print throughput statistics 96 | Process(target=MultiprocessingMemPool.record_throughput, args=(mem_pool, args.record_throughput_interval)).start() 97 | 98 | model_save_freq = 0 99 | model_init_flag = 0 100 | log_times = 0 101 | while True: 102 | if model_init_flag == 0:  # publish the initial weights once, before any training has happened
103 | weights_socket.send(pickle.dumps(agent.get_weights())) 104 | model_init_flag = 1 105 | 106 | if len(mem_pool) >= args.batch_size: 107 | # Sync weights to actor 108 | weights = agent.get_weights() 109 | weights_socket.send(pickle.dumps(weights)) 110 | 111 | if model_save_freq%args.ckpt_save_freq == 0: 112 | if args.ckpt_save_type == 'checkpoint': 113 | agent.save(args.ckpt_path / 'ckpt') 114 | elif args.ckpt_save_type == 'weight': 115 |
with open(args.ckpt_path / f'adduniversal{model_save_freq}.ckpt', 'wb') as f: 116 | pickle.dump(weights, f) 117 | 118 | if args.keep_training: 119 | agent.learn(mem_pool.sample(size=args.batch_size)) 120 | else: 121 | with receiving_condition: 122 | while num_receptions.value < args.training_freq: 123 | receiving_condition.wait()  # woken by recv_data's notify() after each received message
124 | # data = mem_pool.sample(size=args.batch_size) 125 | data = mem_pool.sample()  # NOTE(review): no size= here, unlike the keep_training branch; check MemPool.sample's default
126 | num_receptions.value -= args.training_freq 127 | # Training 128 | stat = agent.learn(data) 129 | if log_times%1000 == 0:  # NOTE(review): log_times starts at 0 and is only incremented in the `else:` branch below, never after a logged step, so it looks like it stays 0 and stats are logged on every training step; presumably it should be incremented unconditionally - confirm
130 | stats = defaultdict(list) 131 | for k, v in stat.items():  # NOTE(review): runs before the `stat is not None` check below, so a None result from agent.learn would raise AttributeError here
132 | stats[k].append(v) 133 | stat = {k: np.array(v).mean() for k, v in stats.items()} 134 | if stat is not None: 135 | for k, v in stat.items(): 136 | logger.record_tabular(k, v) 137 | logger.dump_tabular() 138 | else: 139 | log_times += 1 140 | 141 | model_save_freq += 1 142 | 143 | 144 | def recv_data(data_port, mem_pool, receiving_condition, num_receptions, keep_training): 145 | context = zmq.Context() 146 | data_socket = context.socket(zmq.REP) 147 | data_socket.bind(f'tcp://*:{data_port}') 148 | 149 | while True: 150 | data: dict = deserialize(data_socket.recv())  # NOTE(review): deserializes bytes straight off the network; only safe if every peer that can reach data_port is trusted
151 | data_socket.send(b'200')  # acknowledge each message (a REP socket must reply before it can recv again)
152 | 153 | if keep_training: 154 | mem_pool.push(data) 155 | else: 156 | with receiving_condition: 157 | mem_pool.push(data) 158 | num_receptions.value += 1 159 | receiving_condition.notify()  # wakes main(), which waits on this condition until training_freq receptions have accumulated
160 | 161 | 162 | if __name__ == '__main__': 163 | main() 164 | -------------------------------------------------------------------------------- /learner_n/models/__init__.py: -------------------------------------------------------------------------------- 1 | from core.registry import Registry 2 | 3 | from models.tf_v1_model import TFV1Model 4 | 5 | model_registry = Registry('Model') 6 | 7 | from models.q_model import *  # must stay after model_registry is defined: q_model.py imports model_registry from this package
8 | -------------------------------------------------------------------------------- /learner_n/models/__pycache__/__init__.cpython-36.pyc: --------------------------------------------------------------------------------
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/submit-paper/Danzero_plus/e2b900b01096e0743de945f9963df35cd544f36d/learner_n/models/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /learner_n/models/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/submit-paper/Danzero_plus/e2b900b01096e0743de945f9963df35cd544f36d/learner_n/models/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /learner_n/models/__pycache__/ac_model.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/submit-paper/Danzero_plus/e2b900b01096e0743de945f9963df35cd544f36d/learner_n/models/__pycache__/ac_model.cpython-36.pyc -------------------------------------------------------------------------------- /learner_n/models/__pycache__/ac_model.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/submit-paper/Danzero_plus/e2b900b01096e0743de945f9963df35cd544f36d/learner_n/models/__pycache__/ac_model.cpython-38.pyc -------------------------------------------------------------------------------- /learner_n/models/__pycache__/custom_model.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/submit-paper/Danzero_plus/e2b900b01096e0743de945f9963df35cd544f36d/learner_n/models/__pycache__/custom_model.cpython-36.pyc -------------------------------------------------------------------------------- /learner_n/models/__pycache__/custom_model.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/submit-paper/Danzero_plus/e2b900b01096e0743de945f9963df35cd544f36d/learner_n/models/__pycache__/custom_model.cpython-38.pyc -------------------------------------------------------------------------------- /learner_n/models/__pycache__/distributions.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/submit-paper/Danzero_plus/e2b900b01096e0743de945f9963df35cd544f36d/learner_n/models/__pycache__/distributions.cpython-36.pyc -------------------------------------------------------------------------------- /learner_n/models/__pycache__/distributions.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/submit-paper/Danzero_plus/e2b900b01096e0743de945f9963df35cd544f36d/learner_n/models/__pycache__/distributions.cpython-38.pyc -------------------------------------------------------------------------------- /learner_n/models/__pycache__/q_model.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/submit-paper/Danzero_plus/e2b900b01096e0743de945f9963df35cd544f36d/learner_n/models/__pycache__/q_model.cpython-36.pyc -------------------------------------------------------------------------------- /learner_n/models/__pycache__/q_model.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/submit-paper/Danzero_plus/e2b900b01096e0743de945f9963df35cd544f36d/learner_n/models/__pycache__/q_model.cpython-38.pyc -------------------------------------------------------------------------------- /learner_n/models/__pycache__/tf_v1_model.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/submit-paper/Danzero_plus/e2b900b01096e0743de945f9963df35cd544f36d/learner_n/models/__pycache__/tf_v1_model.cpython-36.pyc -------------------------------------------------------------------------------- /learner_n/models/__pycache__/tf_v1_model.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/submit-paper/Danzero_plus/e2b900b01096e0743de945f9963df35cd544f36d/learner_n/models/__pycache__/tf_v1_model.cpython-38.pyc -------------------------------------------------------------------------------- /learner_n/models/__pycache__/utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/submit-paper/Danzero_plus/e2b900b01096e0743de945f9963df35cd544f36d/learner_n/models/__pycache__/utils.cpython-36.pyc -------------------------------------------------------------------------------- /learner_n/models/__pycache__/utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/submit-paper/Danzero_plus/e2b900b01096e0743de945f9963df35cd544f36d/learner_n/models/__pycache__/utils.cpython-38.pyc -------------------------------------------------------------------------------- /learner_n/models/distributions.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | class CategoricalPd:  # categorical distribution over the last axis of `logits` (unnormalized log-probabilities)
5 | def __init__(self, logits): 6 | self.logits = logits 7 | 8 | def mode(self): 9 | return tf.argmax(self.logits, axis=-1) 10 | 11 | def logp(self, x): 12 | return -self.neglogp(x) 13 | 14 | def neglogp(self, x): 15 | # return tf.nn.sparse_softmax_cross_entropy_with_logits(logits=self.logits, labels=x) 16 | # Note: we can't use sparse_softmax_cross_entropy_with_logits because 17 | # the implementation does not allow second-order derivatives...
18 | if x.dtype in {tf.uint8, tf.int32, tf.int64}: 19 | # one-hot encoding 20 | x_shape_list = x.shape.as_list() 21 | logits_shape_list = self.logits.get_shape().as_list()[:-1] 22 | for xs, ls in zip(x_shape_list, logits_shape_list): 23 | if xs is not None and ls is not None: 24 | assert xs == ls, 'shape mismatch: {} in x vs {} in logits'.format(xs, ls) 25 | 26 | x = tf.one_hot(x, self.logits.get_shape().as_list()[-1]) 27 | else: 28 | # already encoded 29 | assert x.shape.as_list() == self.logits.shape.as_list() 30 | 31 | return tf.nn.softmax_cross_entropy_with_logits_v2(logits=self.logits, labels=x) 32 | 33 | def kl(self, other): 34 | a0 = self.logits - tf.reduce_max(self.logits, axis=-1, keepdims=True)  # shift by the row max so exp() cannot overflow; the shift cancels after normalization
35 | a1 = other.logits - tf.reduce_max(other.logits, axis=-1, keepdims=True) 36 | ea0 = tf.exp(a0) 37 | ea1 = tf.exp(a1) 38 | z0 = tf.reduce_sum(ea0, axis=-1, keepdims=True) 39 | z1 = tf.reduce_sum(ea1, axis=-1, keepdims=True) 40 | p0 = ea0 / z0 41 | return tf.reduce_sum(p0 * (a0 - tf.log(z0) - a1 + tf.log(z1)), axis=-1) 42 | 43 | def entropy(self): 44 | a0 = self.logits - tf.reduce_max(self.logits, axis=-1, keepdims=True) 45 | ea0 = tf.exp(a0) 46 | z0 = tf.reduce_sum(ea0, axis=-1, keepdims=True) 47 | p0 = ea0 / z0 48 | return tf.reduce_sum(p0 * (tf.log(z0) - a0), axis=-1) 49 | 50 | def sample(self): 51 | u = tf.random_uniform(tf.shape(self.logits), dtype=self.logits.dtype) 52 | return tf.argmax(self.logits - tf.log(-tf.log(u)), axis=-1)  # Gumbel-max trick: argmax of logits plus Gumbel(0, 1) noise draws one sample per row
53 | -------------------------------------------------------------------------------- /learner_n/models/q_model.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Any 3 | 4 | import tensorflow as tf 5 | 6 | import models.utils as utils 7 | from models import model_registry 8 | from models.tf_v1_model import TFV1Model 9 | 10 | __all__ = ['QModel', 'QMLPModel', 'QCNNModel', 'GDModel'] 11 | 12 | 13 | class QModel(TFV1Model, ABC): 14 | def
__init__(self, observation_space, action_space, config=None, model_id='0', *args, **kwargs): 15 | with tf.variable_scope(model_id): 16 | self.x_ph = utils.placeholder(shape=observation_space) 17 | 18 | # Output tensor; assigned by the subclass's build() 19 | self.values = None 20 | 21 | # The parent __init__ calls build() 22 | super(QModel, self).__init__(observation_space, action_space, config, model_id, scope=model_id, 23 | *args, **kwargs) 24 | 25 | # Initialize parameters (runs the initializer for every global variable in the graph, not only this model's) 26 | self.sess.run(tf.global_variables_initializer()) 27 | 28 | def forward(self, x_batch: Any, z: Any, *args, **kwargs) -> Any:  # NOTE(review): z is a required argument but is never used
29 | return self.sess.run(self.values, feed_dict={self.x_ph: x_batch}) 30 | 31 | @abstractmethod 32 | def build(self, *args, **kwargs) -> None: 33 | pass 34 | 35 | 36 | @model_registry.register('guandan_model') 37 | class GDModel(QModel): 38 | def build(self) -> None: 39 | with tf.variable_scope(self.scope): 40 | with tf.variable_scope('v'): 41 | self.values = utils.mlp(self.x_ph, [512, 512, 512, 512, 512, 1], activation='tanh', 42 | output_activation=None) 43 | print('model build success') 44 | 45 | 46 | @model_registry.register('qmlp') 47 | class QMLPModel(QModel): 48 | def build(self) -> None: 49 | with tf.variable_scope(self.scope): 50 | with tf.variable_scope('q'): 51 | self.values = utils.mlp(self.x_ph, [24, 24, self.action_space], activation='relu',  # NOTE(review): assumes self.action_space is an int, since it is used as the output width
52 | output_activation=None) 53 | 54 | 55 | @model_registry.register('qcnn') 56 | class QCNNModel(QModel): 57 | def build(self) -> None: 58 | with tf.variable_scope(self.scope): 59 | with tf.variable_scope('cnn_base'): 60 | layers = [{'filters': 16, 'kernel_size': 8, 'strides': 4, 'activation': 'relu'}, 61 | {'filters': 32, 'kernel_size': 4, 'strides': 2, 'activation': 'relu'}] 62 | feat = self.x_ph 63 | for layer in layers: 64 | feat = tf.layers.conv2d(feat, **layer) 65 | feat = tf.layers.flatten(feat) 66 | with tf.variable_scope('q'): 67 | self.values = utils.mlp(feat, [256, self.action_space], activation='relu', 68 | output_activation=None) 69 |
-------------------------------------------------------------------------------- /learner_n/models/tf_v1_model.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from pathlib import Path 3 | from typing import Any 4 | 5 | import tensorflow as tf 6 | from core import Model 7 | from tensorflow.keras.backend import get_session 8 | 9 | 10 | class TFV1Model(Model, ABC): 11 | def __init__(self, observation_space: Any, action_space: Any, config=None, model_id='0', session=None, scope=None, 12 | *args, **kwargs): 13 | self.scope = scope 14 | 15 | # Initialize Tensorflow session 16 | if session is None: 17 | session = get_session() 18 | self.sess = session 19 | 20 | super(TFV1Model, self).__init__(observation_space, action_space, config, model_id, *args, **kwargs)  # the base __init__ calls self.build(), and the q_model builds read self.scope, so scope and sess are set above first
21 | 22 | # Build assignment ops 23 | self._weight_ph = None 24 | self._to_assign = None 25 | self._nodes = None 26 | self._build_assign() 27 | 28 | # Build saver 29 | self.saver = tf.train.Saver(tf.trainable_variables())  # NOTE(review): not limited to self.scope (unlike get/set_weights), so save()/load() cover every trainable variable in the graph
30 | 31 | def set_weights(self, weights, *args, **kwargs) -> None: 32 | feed_dict = {self._weight_ph[var.name]: weight 33 | for (var, weight) in zip(tf.trainable_variables(self.scope), weights)}  # weights must follow the order get_weights() returns; extra entries are silently dropped by zip
34 | self.sess.run(self._nodes, feed_dict=feed_dict) 35 | 36 | def get_weights(self, *args, **kwargs) -> Any: 37 | return self.sess.run(tf.trainable_variables(self.scope)) 38 | 39 | def save(self, path: Path, *args, **kwargs) -> None: 40 | self.saver.save(self.sess, str(path)) 41 | 42 | def load(self, path: Path, *args, **kwargs) -> None: 43 | self.saver.restore(self.sess, str(path)) 44 | 45 | def _build_assign(self): 46 | self._weight_ph, self._to_assign = dict(), dict() 47 | variables = tf.trainable_variables(self.scope) 48 | for var in variables: 49 | self._weight_ph[var.name] = tf.placeholder(var.value().dtype, var.get_shape().as_list()) 50 | self._to_assign[var.name] = var.assign(self._weight_ph[var.name]) 51 | self._nodes
= list(self._to_assign.values()) 52 | 53 | @abstractmethod 54 | def build(self, *args, **kwargs) -> None: 55 | pass 56 | 57 | @abstractmethod 58 | def forward(self, states: Any, *args, **kwargs) -> Any: 59 | pass 60 | -------------------------------------------------------------------------------- /learner_n/models/utils.py: -------------------------------------------------------------------------------- 1 | """Copied from https://github.com/openai/spinningup/blob/master/spinup/algos/tf1/ppo/core.py""" 2 | import numpy as np 3 | import tensorflow as tf 4 | 5 | EPS = 1e-8 6 | 7 | __all__ = ['conv', 'fc', 'conv_to_fc', 'placeholder', 'mlp', 'actor'] 8 | 9 | 10 | def ortho_init(scale=1.0): 11 | def _ortho_init(shape, dtype, partition_info=None): 12 | # lasagne ortho init for tf 13 | shape = tuple(shape) 14 | if len(shape) == 2: 15 | flat_shape = shape 16 | elif len(shape) == 4: # assumes NHWC 17 | flat_shape = (np.prod(shape[:-1]), shape[-1]) 18 | else: 19 | raise NotImplementedError 20 | a = np.random.normal(0.0, 1.0, flat_shape) 21 | u, _, v = np.linalg.svd(a, full_matrices=False) 22 | q = u if u.shape == flat_shape else v # pick the one with the correct shape 23 | q = q.reshape(shape) 24 | return (scale * q[:shape[0], :shape[1]]).astype(np.float32) 25 | 26 | return _ortho_init 27 | 28 | 29 | def conv(x, scope, *, nf, rf, stride, pad='VALID', init_scale=1.0, data_format='NHWC', one_dim_bias=False): 30 | if data_format == 'NHWC': 31 | channel_ax = 3 32 | strides = [1, stride, stride, 1] 33 | bshape = [1, 1, 1, nf] 34 | elif data_format == 'NCHW': 35 | channel_ax = 1 36 | strides = [1, 1, stride, stride] 37 | bshape = [1, nf, 1, 1] 38 | else: 39 | raise NotImplementedError 40 | bias_var_shape = [nf] if one_dim_bias else [1, nf, 1, 1] 41 | nin = x.get_shape()[channel_ax].value 42 | wshape = [rf, rf, nin, nf] 43 | with tf.variable_scope(scope): 44 | w = tf.get_variable("w", wshape, initializer=ortho_init(init_scale)) 45 | b = tf.get_variable("b", bias_var_shape,
initializer=tf.constant_initializer(0.0)) 46 | if not one_dim_bias and data_format == 'NHWC': 47 | b = tf.reshape(b, bshape) 48 | return tf.nn.conv2d(x, w, strides=strides, padding=pad, data_format=data_format) + b 49 | 50 | 51 | def fc(x, scope, nh, *, init_scale=1.0, init_bias=0.0): 52 | with tf.variable_scope(scope): 53 | nin = x.get_shape()[1].value 54 | w = tf.get_variable("w", [nin, nh], initializer=ortho_init(init_scale)) 55 | b = tf.get_variable("b", [nh], initializer=tf.constant_initializer(init_bias)) 56 | return tf.matmul(x, w) + b 57 | 58 | 59 | def conv_to_fc(x): 60 | nh = np.prod([v.value for v in x.get_shape()[1:]]) 61 | x = tf.reshape(x, [-1, nh]) 62 | return x 63 | 64 | 65 | def combined_shape(length, shape=None): 66 | if shape is None: 67 | return (length,) 68 | return (length, shape) if np.isscalar(shape) else (length, *shape)  # prepend the (possibly None) batch length to a scalar or tuple shape
69 | 70 | 71 | def placeholder(dtype=tf.float32, shape=None): 72 | return tf.placeholder(dtype=dtype, shape=combined_shape(None, shape)) 73 | 74 | 75 | def gaussian_likelihood(x, mu, log_std): 76 | pre_sum = -0.5 * (((x - mu) / (tf.exp(log_std) + EPS)) ** 2 + 2 * log_std + np.log(2 * np.pi)) 77 | return tf.reduce_sum(pre_sum, axis=1) 78 | 79 | 80 | def mlp(x, hidden_sizes=(32,), activation=tf.tanh, output_activation=None):  # the last entry of hidden_sizes is the output width
81 | for h in hidden_sizes[:-1]: 82 | x = tf.layers.dense(x, units=h, activation=activation) 83 | return tf.layers.dense(x, units=hidden_sizes[-1], activation=output_activation) 84 | 85 | 86 | def categorical_policy(logits, action, act_dim): 87 | logp_all = tf.nn.log_softmax(logits) 88 | pi = tf.squeeze(tf.multinomial(logits, 1), axis=1) 89 | logp = tf.reduce_sum(tf.one_hot(action, depth=act_dim) * logp_all, axis=1) 90 | logp_pi = tf.reduce_sum(tf.one_hot(pi, depth=act_dim) * logp_all, axis=1) 91 | return pi, logp, logp_pi 92 | 93 | 94 | def gaussian_policy(mu, action, act_dim): 95 | log_std = tf.get_variable(name='log_std', initializer=-0.5 * np.ones(act_dim, dtype=np.float32)) 96 | std = tf.exp(log_std)
97 | pi = mu + tf.random_normal(tf.shape(mu)) * std 98 | logp = gaussian_likelihood(action, mu, log_std) 99 | logp_pi = gaussian_likelihood(pi, mu, log_std) 100 | return pi, logp, logp_pi 101 | 102 | 103 | _mode = ['categorical', 'gaussian'] 104 | 105 | 106 | def actor(logits, action, act_dim, mode='categorical'): 107 | assert mode in _mode 108 | policy = eval(mode + '_policy')  # NOTE(review): eval on a caller-supplied string, guarded only by the assert above (stripped under python -O); a dict lookup {'categorical': categorical_policy, 'gaussian': gaussian_policy} would avoid eval
109 | 110 | pi, logp, logp_pi = policy(logits, action, act_dim) 111 | return pi, logp, logp_pi 112 | -------------------------------------------------------------------------------- /learner_n/start_all.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # sshpass ssh root@172.15.15.3 "bash /home/luyd/guandan/actor_n/start.sh" 3 | # nohup python -u learner.py > ./learner_out.log 2>&1 & 4 | for i in {3..23} 5 | do 6 | sshpass ssh root@172.15.15.$i "bash /home/zhaoyp/guandan/actor_n/start.sh" 7 | echo $i 8 | sleep 0.1s 9 | done 10 | 11 | nohup python -u learner.py > ./learner_out.log 2>&1 &  # the learner is launched after the actors on hosts .3-.23 and before those on .24-.83
12 | 13 | for i in {24..83} 14 | do 15 | sshpass ssh root@172.15.15.$i "bash /home/zhaoyp/guandan/actor_n/start.sh" 16 | echo $i 17 | sleep 0.1s 18 | done 19 | -------------------------------------------------------------------------------- /learner_n/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/submit-paper/Danzero_plus/e2b900b01096e0743de945f9963df35cd544f36d/learner_n/utils/__init__.py -------------------------------------------------------------------------------- /learner_n/utils/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/submit-paper/Danzero_plus/e2b900b01096e0743de945f9963df35cd544f36d/learner_n/utils/__pycache__/__init__.cpython-36.pyc --------------------------------------------------------------------------------
/learner_n/utils/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/submit-paper/Danzero_plus/e2b900b01096e0743de945f9963df35cd544f36d/learner_n/utils/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /learner_n/utils/__pycache__/cmdline.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/submit-paper/Danzero_plus/e2b900b01096e0743de945f9963df35cd544f36d/learner_n/utils/__pycache__/cmdline.cpython-36.pyc -------------------------------------------------------------------------------- /learner_n/utils/__pycache__/cmdline.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/submit-paper/Danzero_plus/e2b900b01096e0743de945f9963df35cd544f36d/learner_n/utils/__pycache__/cmdline.cpython-38.pyc -------------------------------------------------------------------------------- /learner_n/utils/__pycache__/logger.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/submit-paper/Danzero_plus/e2b900b01096e0743de945f9963df35cd544f36d/learner_n/utils/__pycache__/logger.cpython-36.pyc -------------------------------------------------------------------------------- /learner_n/utils/__pycache__/logger.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/submit-paper/Danzero_plus/e2b900b01096e0743de945f9963df35cd544f36d/learner_n/utils/__pycache__/logger.cpython-38.pyc -------------------------------------------------------------------------------- /learner_n/utils/cmdline.py: -------------------------------------------------------------------------------- 1 | def parse_cmdline_kwargs(args): 2 | """ 3 | Copied 
from https://github.com/openai/baselines/blob/master/baselines/run.py 4 | convert a list of '='-spaced command-line arguments to a dictionary, evaluating python objects when possible 5 | """ 6 | 7 | def parse(v): 8 | 9 | assert isinstance(v, str) 10 | try: 11 | return eval(v) 12 | except (NameError, SyntaxError): 13 | return v 14 | 15 | return {k: parse(v) for k, v in parse_unknown_args(args).items()} 16 | 17 | 18 | def parse_unknown_args(args): 19 | """ 20 | Copied from https://github.com/openai/baselines/blob/master/baselines/common/cmd_util.py 21 | Parse arguments not consumed by arg parser into a dictionary 22 | """ 23 | retval = {} 24 | preceded_by_key = False 25 | for arg in args: 26 | if arg.startswith('--'): 27 | if '=' in arg: 28 | key = arg.split('=')[0][2:] 29 | value = arg.split('=')[1] 30 | retval[key] = value 31 | else: 32 | key = arg[2:] 33 | preceded_by_key = True 34 | elif preceded_by_key: 35 | retval[key] = arg 36 | preceded_by_key = False 37 | 38 | return retval 39 | -------------------------------------------------------------------------------- /learner_n/utils/mpi_util.py: -------------------------------------------------------------------------------- 1 | """Copied from https://github.com/openai/baselines/blob/master/baselines/common/mpi_util.py""" 2 | import os 3 | import platform 4 | import shutil 5 | import subprocess 6 | import sys 7 | import warnings 8 | from collections import defaultdict 9 | 10 | import numpy as np 11 | 12 | try: 13 | from mpi4py import MPI 14 | except ImportError: 15 | MPI = None 16 | 17 | 18 | def sync_from_root(sess, variables, comm=None): 19 | """ 20 | Send the root node's parameters to every worker. 21 | Arguments: 22 | sess: the TensorFlow session. 
# REVIEW NOTE (repository digest, collapsed lines): remainder of learner_n/utils/mpi_util.py
# (copied from OpenAI baselines per its module docstring), followed by raw-URL stubs for
# learner_torch/__pycache__ entries.
# - sync_from_root(): broadcasts sess.run(variables) from the root rank and tf.assign()s
#   the received values on every rank. comm defaults to MPI.COMM_WORLD, which raises
#   AttributeError when mpi4py is missing (MPI is then None).
# - gpu_count(): returns 0 when nvidia-smi is not on PATH. Otherwise it returns the number
#   of output lines minus 2 (presumably the CSV header and the trailing empty element).
# - setup_mpi_gpus(): acts only when CUDA_VISIBLE_DEVICES is unset. It sets an empty list
#   on macOS; elsewhere it sets this process's local rank on the node.
# - get_local_rank_size(): returns (local rank, number of ranks on this node), computed
#   from an allgather of (rank, platform.node()).
# - share_file(): rank 0 reads and broadcasts the file bytes; local rank 0 on every node
#   writes them to the same path. It ends with a Barrier.
# - dict_gather(): per-key 'mean' or 'sum' across ranks; returns d unchanged when comm is None.
# - mpi_weighted_mean(): the result is only produced on rank 0 ({} elsewhere). Values that
#   cannot be converted with float() are skipped with a warning.
23 | variables: all parameter variables including optimizer's 24 | """ 25 | if comm is None: comm = MPI.COMM_WORLD 26 | import tensorflow as tf 27 | values = comm.bcast(sess.run(variables)) 28 | sess.run([tf.assign(var, val) 29 | for (var, val) in zip(variables, values)]) 30 | 31 | 32 | def gpu_count(): 33 | """ 34 | Count the GPUs on this machine. 35 | """ 36 | if shutil.which('nvidia-smi') is None: 37 | return 0 38 | output = subprocess.check_output(['nvidia-smi', '--query-gpu=gpu_name', '--format=csv']) 39 | return max(0, len(output.split(b'\n')) - 2) 40 | 41 | 42 | def setup_mpi_gpus(): 43 | """ 44 | Set CUDA_VISIBLE_DEVICES to MPI rank if not already set 45 | """ 46 | if 'CUDA_VISIBLE_DEVICES' not in os.environ: 47 | if sys.platform == 'darwin': # This Assumes if you're on OSX you're just 48 | ids = [] # doing a smoke test and don't want GPUs 49 | else: 50 | lrank, _lsize = get_local_rank_size(MPI.COMM_WORLD) 51 | ids = [lrank] 52 | os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, ids)) 53 | 54 | 55 | def get_local_rank_size(comm): 56 | """ 57 | Returns the rank of each process on its machine 58 | The processes on a given machine will be assigned ranks 59 | 0, 1, 2, ..., N-1, 60 | where N is the number of processes on this machine.
61 | 62 | Useful if you want to assign one gpu per machine 63 | """ 64 | this_node = platform.node() 65 | ranks_nodes = comm.allgather((comm.Get_rank(), this_node)) 66 | node2rankssofar = defaultdict(int) 67 | local_rank = None 68 | for (rank, node) in ranks_nodes: 69 | if rank == comm.Get_rank(): 70 | local_rank = node2rankssofar[node] 71 | node2rankssofar[node] += 1 72 | assert local_rank is not None 73 | return local_rank, node2rankssofar[this_node] 74 | 75 | 76 | def share_file(comm, path): 77 | """ 78 | Copies the file from rank 0 to all other ranks 79 | Puts it in the same place on all machines 80 | """ 81 | localrank, _ = get_local_rank_size(comm) 82 | if comm.Get_rank() == 0: 83 | with open(path, 'rb') as fh: 84 | data = fh.read() 85 | comm.bcast(data) 86 | else: 87 | data = comm.bcast(None) 88 | if localrank == 0: 89 | os.makedirs(os.path.dirname(path), exist_ok=True) 90 | with open(path, 'wb') as fh: 91 | fh.write(data) 92 | comm.Barrier() 93 | 94 | 95 | def dict_gather(comm, d, op='mean', assert_all_have_data=True): 96 | """ 97 | Perform a reduction operation over dicts 98 | """ 99 | if comm is None: return d 100 | alldicts = comm.allgather(d) 101 | size = comm.size 102 | k2li = defaultdict(list) 103 | for d in alldicts: 104 | for (k, v) in d.items(): 105 | k2li[k].append(v) 106 | result = {} 107 | for (k, li) in k2li.items(): 108 | if assert_all_have_data: 109 | assert len(li) == size, "only %i out of %i MPI workers have sent '%s'" % (len(li), size, k) 110 | if op == 'mean': 111 | result[k] = np.mean(li, axis=0) 112 | elif op == 'sum': 113 | result[k] = np.sum(li, axis=0) 114 | else: 115 | assert 0, op 116 | return result 117 | 118 | 119 | def mpi_weighted_mean(comm, local_name2valcount): 120 | """ 121 | Perform a weighted average over dicts that are each on a different node 122 | Input: local_name2valcount: dict mapping key -> (value, count) 123 | Returns: key -> mean 124 | """ 125 | all_name2valcount = comm.gather(local_name2valcount) 126 | if
comm.rank == 0: 127 | name2sum = defaultdict(float) 128 | name2count = defaultdict(float) 129 | for n2vc in all_name2valcount: 130 | for (name, (val, count)) in n2vc.items(): 131 | try: 132 | val = float(val) 133 | except ValueError: 134 | if comm.rank == 0: 135 | warnings.warn('WARNING: tried to compute mean on non-float {}={}'.format(name, val)) 136 | else: 137 | name2sum[name] += val * count 138 | name2count[name] += count 139 | return {name: name2sum[name] / name2count[name] for name in name2sum} 140 | else: 141 | return {} 142 | -------------------------------------------------------------------------------- /learner_torch/__pycache__/common.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/submit-paper/Danzero_plus/e2b900b01096e0743de945f9963df35cd544f36d/learner_torch/__pycache__/common.cpython-38.pyc -------------------------------------------------------------------------------- /learner_torch/__pycache__/mem_pool.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/submit-paper/Danzero_plus/e2b900b01096e0743de945f9963df35cd544f36d/learner_torch/__pycache__/mem_pool.cpython-38.pyc -------------------------------------------------------------------------------- /learner_torch/__pycache__/model.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/submit-paper/Danzero_plus/e2b900b01096e0743de945f9963df35cd544f36d/learner_torch/__pycache__/model.cpython-38.pyc -------------------------------------------------------------------------------- /learner_torch/__pycache__/ppo.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/submit-paper/Danzero_plus/e2b900b01096e0743de945f9963df35cd544f36d/learner_torch/__pycache__/ppo.cpython-38.pyc
-------------------------------------------------------------------------------- /learner_torch/common.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import time 3 | import warnings 4 | from pathlib import Path 5 | import inspect 6 | 7 | from typing import List 8 | import yaml 9 | 10 | def load_yaml_config(args, role_type: str) -> None: 11 | if role_type not in {'actor', 'learner'}: 12 | raise ValueError('Invalid role type') 13 | 14 | # Load config file 15 | if args.config is not None: 16 | with open(args.config) as f: 17 | config = yaml.load(f, Loader=yaml.FullLoader) 18 | else: 19 | config = None 20 | 21 | if config is not None and isinstance(config, dict): 22 | if role_type in config: 23 | for k, v in config[role_type].items(): 24 | if k in args: 25 | setattr(args, k, v) 26 | else: 27 | warnings.warn(f"Invalid config item '{k}' ignored", RuntimeWarning) 28 | args.agent_config = config['agent'] if 'agent' in config else None 29 | else: 30 | args.agent_config = None 31 | 32 | 33 | def save_yaml_config(config_path: Path, args, role_type: str, agent) -> None: 34 | class Dumper(yaml.Dumper): 35 | def increase_indent(self, flow=False, *_, **__): 36 | return super().increase_indent(flow=flow, indentless=False) 37 | 38 | if role_type not in {'actor', 'learner'}: 39 | raise ValueError('Invalid role type') 40 | 41 | with open(config_path, 'w') as f: 42 | args_config = {k: v for k, v in vars(args).items() if 43 | not k.endswith('path') and k != 'agent_config' and k != 'config'} 44 | yaml.dump({role_type: args_config}, f, sort_keys=False, Dumper=Dumper) 45 | f.write('\n') 46 | yaml.dump({'agent': agent.export_config()}, f, sort_keys=False, Dumper=Dumper) 47 | 48 | 49 | def create_experiment_dir(args, prefix: str) -> None: 50 | if args.exp_path is None: 51 | args.exp_path = prefix + datetime.datetime.fromtimestamp(time.time()).strftime('%Y-%m-%d-%H-%M-%S') 52 | args.exp_path = Path(args.exp_path) 53 | 54 | if 
args.exp_path.exists(): 55 | raise FileExistsError(f'Experiment directory {str(args.exp_path)!r} already exists') 56 | 57 | args.exp_path.mkdir() 58 | 59 | 60 | def get_config_params(obj_or_cls) -> List[str]: 61 | """ 62 | Return configurable parameters in 'Agent.__init__' and 'Model.__init__' which appear after 'config' 63 | :param obj_or_cls: An instance of 'Agent' / 'Model' OR their corresponding classes (NOT base classes) 64 | :return: A list of configurable parameters 65 | """ 66 | # import core # Import inside function to avoid cyclic import 67 | 68 | # if inspect.isclass(obj_or_cls): 69 | # if not issubclass(obj_or_cls, core.Agent) and not issubclass(obj_or_cls, core.Model): 70 | # raise ValueError("Only accepts subclasses of 'Agent' or 'Model'") 71 | # else: 72 | # if not isinstance(obj_or_cls, core.Agent) and not isinstance(obj_or_cls, core.Model): 73 | # raise ValueError("Only accepts instances 'Agent' or 'Model'") 74 | 75 | sig = list(inspect.signature(obj_or_cls.__init__).parameters.keys()) 76 | 77 | config_params = [] 78 | config_part = False 79 | for param in sig: 80 | if param == 'config': 81 | # Following parameters should be what we want 82 | config_part = True 83 | elif param in {'args', 'kwargs'}: 84 | pass 85 | elif config_part: 86 | config_params.append(param) 87 | 88 | return config_params 89 | -------------------------------------------------------------------------------- /learner_torch/kill_all.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | for i in {3..43} 3 | do 4 | sshpass ssh root@172.15.15.$i "bash /home/zhaoyp/guandan_tog/actor_torch/kill.sh" 5 | done 6 | 7 | ps aux|grep python|grep -v grep|cut -c 9-15|xargs kill -9 8 | 9 | 10 | 11 | rm /home/zhaoyp/guandan_tog/learner_torch/ckpt_bak/*.pth 12 | -------------------------------------------------------------------------------- /learner_torch/kill_learner.sh: 
-------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # ps -ef | grep learner.py | awk '{print $2}' | xargs kill -9 3 | ps aux|grep python|grep -v grep|cut -c 9-15|xargs kill -9 4 | -------------------------------------------------------------------------------- /learner_torch/mem_pool.py: -------------------------------------------------------------------------------- 1 | import random 2 | import time 3 | from collections import defaultdict, deque 4 | from multiprocessing.managers import BaseManager 5 | from typing import Dict, List 6 | 7 | import numpy as np 8 | 9 | __all__ = ['MemPool', 'MultiprocessingMemPool', 'MemPoolManager'] 10 | 11 | 12 | class MemPool: 13 | def __init__(self, capacity: int = None, keys: List[str] = None) -> None: 14 | self._keys = keys 15 | if keys is None: 16 | self.data = defaultdict(lambda: deque(maxlen=capacity)) 17 | else: 18 | self.data = {key: deque(maxlen=capacity) for key in keys} 19 | 20 | def push(self, data: Dict[str, np.ndarray]) -> None: 21 | """Push data into memory pool""" 22 | for key, value in data.items(): 23 | self.data[key].extend(value) 24 | 25 | if self._keys is None: 26 | self._keys = list(self.data.keys()) 27 | 28 | def sample(self, size: int = -1) -> Dict[str, np.ndarray]: 29 | """ 30 | Sample training data from memory pool 31 | :param size: The number of sample data, default '-1' that indicates all data 32 | :return: The sampled and concatenated training data 33 | """ 34 | 35 | num = len(self) 36 | indices = list(range(num)) 37 | if 0 < size < num: 38 | indices = random.sample(indices, size) 39 | 40 | result = {} 41 | for key in self._keys: 42 | result[key] = np.stack([self.data[key][idx] for idx in indices]) 43 | return result 44 | 45 | def clear(self) -> None: 46 | """Clear all data""" 47 | for key in self._keys: 48 | self.data[key].clear() 49 | 50 | def __len__(self): 51 | if self._keys is None: 52 | return 0 53 | return len(self.data[self._keys[0]]) 54 | 55 | 56 
# REVIEW NOTE (repository digest, collapsed lines): rest of learner_torch/mem_pool.py
# (MultiprocessingMemPool, MemPoolManager), the whole of learner_torch/ppo.py and
# learner_torch/start_all.sh, and raw-URL stubs for learner_torch/utils.
# - MultiprocessingMemPool counts pushed/sampled rows only while its counters are not None.
#   record_throughput() resets them to 0, sleeps `interval` seconds and prints both FPS.
#   NOTE(review): clear() sets the counters back to None, so a clear() during that sleep
#   makes the print divide None and raise TypeError.
# - PPOAgent.update() runs up to train_iters steps:
#   - it stops early, before backward(), when the MPI-averaged approximate KL exceeds
#     1.5 * target_kl;
#   - loss = loss_pi + 0.5 * loss_v + 0.05 * loss_ent;
#   - gradients are averaged across ranks and clipped to norm 10, with a 0.1 s sleep
#     after every optimizer step.
#   NOTE(review): with train_iters == 0 the return statement references unbound locals.
# - compute_loss():
#   - the unclipped term uses the ratio clamped to [0, 3]; the clipped term uses
#     [1 - clip_ratio, 1 + clip_ratio];
#   - the value loss is 0.5 * MSE against data['ret'];
#   - loss_ent is the negated mean policy entropy.
# - export_config(): NOTE(review): PPOAgent.__init__ has no 'config' parameter, so
#   common.get_config_params() finds nothing after it and this presumably returns {}.
# - ppo.py's __main__ block is a debugging harness that loads a hard-coded .npy path.
# - start_all.sh: starts actors on 172.15.15.3-13, then learner.py via nohup, then 14-43.
| class MultiprocessingMemPool(MemPool): 57 | def __init__(self, capacity: int = None, keys: List[str] = None) -> None: 58 | super().__init__(capacity, keys) 59 | 60 | self._receiving_data_throughput = None 61 | self._consuming_data_throughput = None 62 | 63 | def push(self, data: Dict[str, np.ndarray]) -> None: 64 | super().push(data) 65 | 66 | if self._receiving_data_throughput is not None: 67 | self._receiving_data_throughput += len(data[self._keys[0]]) 68 | 69 | def sample(self, size: int = -1) -> Dict[str, np.ndarray]: 70 | data = super().sample(size) 71 | 72 | if self._consuming_data_throughput is not None: 73 | self._consuming_data_throughput += len(data[self._keys[0]]) 74 | 75 | # super().clear() 76 | 77 | return data 78 | 79 | def clear(self) -> None: 80 | super().clear() 81 | 82 | self._receiving_data_throughput = None 83 | self._consuming_data_throughput = None 84 | 85 | def _get_receiving_data_throughput(self): 86 | return self._receiving_data_throughput 87 | 88 | def _get_consuming_data_throughput(self): 89 | return self._consuming_data_throughput 90 | 91 | def _reset_receiving_data_throughput(self): 92 | self._receiving_data_throughput = 0 93 | 94 | def _reset_consuming_data_throughput(self): 95 | self._consuming_data_throughput = 0 96 | 97 | @classmethod 98 | def record_throughput(cls, obj, interval=10): 99 | """Print receiving and consuming periodically""" 100 | 101 | while True: 102 | obj._reset_receiving_data_throughput() 103 | obj._reset_consuming_data_throughput() 104 | 105 | time.sleep(interval) 106 | 107 | print(f'Receiving FPS: {obj._get_receiving_data_throughput() / interval:.2f}, ' 108 | f'Consuming FPS: {obj._get_consuming_data_throughput() / interval:.2f}') 109 | 110 | 111 | class MemPoolManager(BaseManager): 112 | pass 113 | 114 | 115 | MemPoolManager.register('MemPool', MultiprocessingMemPool, 116 | exposed=['__len__', 'push', 'sample', 'clear', '_get_receiving_data_throughput', 117 | '_get_consuming_data_throughput',
'_reset_receiving_data_throughput', 118 | '_reset_consuming_data_throughput']) 119 | -------------------------------------------------------------------------------- /learner_torch/ppo.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import time 3 | import torch 4 | from utils.mpi_pytorch import (mpi_avg_grads, setup_pytorch_for_mpi, 5 | sync_params) 6 | from utils.mpi_tools import (mpi_avg, mpi_fork, num_procs, proc_id) 7 | from common import get_config_params 8 | from torch.optim import Adam, RMSprop 9 | 10 | 11 | class PPOAgent: 12 | def __init__(self, model, clip_ratio=0.2, lr=1e-4, train_iters=20, target_kl=0.01) -> None: 13 | self.ac = model 14 | self.clip_ratio = clip_ratio 15 | self.lr = lr 16 | self.train_iters = train_iters 17 | self.target_kl = target_kl 18 | self.device = 'cuda' if torch.cuda.is_available() else 'cpu' 19 | self.optimizer = Adam(self.ac.parameters(), lr=self.lr, eps=1e-5) 20 | setup_pytorch_for_mpi() 21 | sync_params(self.ac) 22 | 23 | def update(self, data): 24 | # Set up optimizers for policy and value function 25 | 26 | # pi_l_old, pi_info_old = self.compute_loss_pi(data) 27 | # pi_l_old = pi_l_old.item() 28 | # v_l_old = self.compute_loss_v(data).item() 29 | 30 | # Train policy with multiple steps of gradient descent 31 | for _ in range(self.train_iters): 32 | self.optimizer.zero_grad() 33 | loss_pi, loss_v, loss_ent, pi_info = self.compute_loss(data) 34 | loss = loss_pi + 0.5 * loss_v + 0.05 * loss_ent 35 | kl = mpi_avg(pi_info['kl']) 36 | if kl > 1.5 * self.target_kl: 37 | break 38 | loss.backward() 39 | mpi_avg_grads(self.ac) # average grads across MPI processes 40 | torch.nn.utils.clip_grad_norm_(self.ac.parameters(), 10) 41 | self.optimizer.step() 42 | time.sleep(0.1) 43 | #time.sleep(1) 44 | 45 | return { 46 | 'pg_loss': loss_pi.cpu().detach().numpy(), 47 | 'vf_loss': loss_v.cpu().detach().numpy(), 48 | 'entropy': pi_info['ent'], 49 | 'clip_rate': pi_info['cf'], 50
| 'kl': pi_info['kl'], 51 | } 52 | 53 | # Set up function for computing PPO policy loss 54 | def compute_loss(self, data): 55 | obs, act, adv, logp_old, legalaction = torch.tensor(data['obs']).to(torch.float32).to(self.device), torch.tensor(data['act']).to(torch.float32).to(self.device), torch.tensor(data['adv']).to(torch.float32).to(self.device), torch.tensor(data['logp']).to(torch.float32).to(self.device), torch.tensor(data['legal']).to(torch.float32).to(self.device) 56 | 57 | 58 | # Policy loss 59 | pi, logp, value = self.ac.forward(obs, act, legalaction) 60 | ratio = torch.exp(logp - logp_old) 61 | clipped_ratio = torch.clamp(ratio, 0.0, 3.0) 62 | clip_adv = torch.clamp(ratio, 1-self.clip_ratio, 1+self.clip_ratio) * adv 63 | loss_pi = -(torch.min(clipped_ratio * adv, clip_adv)).mean() 64 | 65 | # value loss 66 | ret = torch.tensor(data['ret']).to(torch.float32).to(self.device) 67 | #print('reward shape', ret.shape) 68 | #print('value', value, 'ret', ret.shape, ret) 69 | loss_v = ((value - ret) ** 2).mean() * 0.5 70 | 71 | # entropy loss 72 | loss_ent = -1 * pi.entropy().mean() 73 | # Useful extra info 74 | approx_kl = (logp_old - logp).mean().item() 75 | ent = pi.entropy().mean().item() 76 | clipped = ratio.gt(1+self.clip_ratio) | ratio.lt(1-self.clip_ratio) 77 | clipfrac = torch.as_tensor(clipped, dtype=torch.float32).mean().item() 78 | pi_info = dict(kl=approx_kl, ent=ent, cf=clipfrac) 79 | 80 | return loss_pi, loss_v, loss_ent, pi_info 81 | 82 | def get_weights(self): 83 | return self.ac.get_weights() 84 | 85 | def export_config(self) -> dict: 86 | """Export dictionary as configurations""" 87 | param_dict = {p: getattr(self, p) for p in get_config_params(self)} 88 | return param_dict 89 | 90 | 91 | if __name__ == '__main__': 92 | from model import MLPActorCritic, MLPQNetwork 93 | device = 'cuda' 94 | model = MLPActorCritic((10, 567), 1).to(device) 95 | model_q = MLPQNetwork(567).to(device) 96 | #ppoagent = PPOAgent(model) 97 | b =
np.load("/home/zhaoyp/guandan_tog/actor_ppo/debug128.npy", allow_pickle=True).item() 98 | state = b['x_batch'][0] 99 | print(b['actions']) 100 | index2action = model_q.get_max_10index(torch.tensor(state).to(torch.float32).to(device)) 101 | state = state[index2action] 102 | obs = {'obs': state, 'act': b['actions'][0], 'logp': b['neglogps'][0], 103 | 'adv': b['returns'][0]-b['values'][0], 'ret': b['returns'][0]} 104 | print(obs, obs['obs'].shape) 105 | #info = ppoagent.update(obs) 106 | #print(info) 107 | # print(torch.cuda.is_available()) 108 | -------------------------------------------------------------------------------- /learner_torch/start_all.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # sshpass ssh root@172.15.15.3 "bash /home/zhaoyp/guandan_tog/actor_n/start.sh" 3 | # nohup python -u learner.py > ./learner_out.log 2>&1 & 4 | #sshpass ssh root@172.15.15.3 "bash /home/zhaoyp/guandan_tog/actor_torch/start.sh" 5 | #sleep 0.1s 6 | for i in {3..13} 7 | do 8 | sshpass ssh root@172.15.15.$i "bash /home/zhaoyp/guandan_tog/actor_torch/start.sh" 9 | echo $i 10 | sleep 0.1s 11 | done 12 | 13 | nohup python -u learner.py > ./learner_out.log 2>&1 & 14 | 15 | for i in {14..43} 16 | do 17 | sshpass ssh root@172.15.15.$i "bash /home/zhaoyp/guandan_tog/actor_torch/start.sh" 18 | echo $i 19 | #sleep 3s 20 | done 21 | -------------------------------------------------------------------------------- /learner_torch/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/submit-paper/Danzero_plus/e2b900b01096e0743de945f9963df35cd544f36d/learner_torch/utils/__init__.py -------------------------------------------------------------------------------- /learner_torch/utils/__pycache__/__init__.cpython-38.pyc: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/submit-paper/Danzero_plus/e2b900b01096e0743de945f9963df35cd544f36d/learner_torch/utils/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /learner_torch/utils/__pycache__/cmdline.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/submit-paper/Danzero_plus/e2b900b01096e0743de945f9963df35cd544f36d/learner_torch/utils/__pycache__/cmdline.cpython-38.pyc -------------------------------------------------------------------------------- /learner_torch/utils/__pycache__/logger.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/submit-paper/Danzero_plus/e2b900b01096e0743de945f9963df35cd544f36d/learner_torch/utils/__pycache__/logger.cpython-38.pyc -------------------------------------------------------------------------------- /learner_torch/utils/__pycache__/mpi_pytorch.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/submit-paper/Danzero_plus/e2b900b01096e0743de945f9963df35cd544f36d/learner_torch/utils/__pycache__/mpi_pytorch.cpython-38.pyc -------------------------------------------------------------------------------- /learner_torch/utils/__pycache__/mpi_tools.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/submit-paper/Danzero_plus/e2b900b01096e0743de945f9963df35cd544f36d/learner_torch/utils/__pycache__/mpi_tools.cpython-38.pyc -------------------------------------------------------------------------------- /learner_torch/utils/cmdline.py: -------------------------------------------------------------------------------- 1 | def parse_cmdline_kwargs(args): 2 | """ 3 | Copied from https://github.com/openai/baselines/blob/master/baselines/run.py 4 | convert 
a list of '='-spaced command-line arguments to a dictionary, evaluating python objects when possible 5 | """ 6 | 7 | def parse(v): 8 | 9 | assert isinstance(v, str) 10 | try: 11 | return eval(v) 12 | except (NameError, SyntaxError): 13 | return v 14 | 15 | return {k: parse(v) for k, v in parse_unknown_args(args).items()} 16 | 17 | 18 | def parse_unknown_args(args): 19 | """ 20 | Copied from https://github.com/openai/baselines/blob/master/baselines/common/cmd_util.py 21 | Parse arguments not consumed by arg parser into a dictionary 22 | """ 23 | retval = {} 24 | preceded_by_key = False 25 | for arg in args: 26 | if arg.startswith('--'): 27 | if '=' in arg: 28 | key = arg.split('=')[0][2:] 29 | value = arg.split('=')[1] 30 | retval[key] = value 31 | else: 32 | key = arg[2:] 33 | preceded_by_key = True 34 | elif preceded_by_key: 35 | retval[key] = arg 36 | preceded_by_key = False 37 | 38 | return retval 39 | -------------------------------------------------------------------------------- /learner_torch/utils/model_utils.py: -------------------------------------------------------------------------------- 1 | """Copied from https://github.com/openai/spinningup/blob/master/spinup/algos/tf1/ppo/core.py""" 2 | import numpy as np 3 | import tensorflow as tf 4 | 5 | EPS = 1e-8 6 | 7 | __all__ = ['conv', 'fc', 'conv_to_fc', 'placeholder', 'mlp', 'actor'] 8 | 9 | 10 | def ortho_init(scale=1.0): 11 | def _ortho_init(shape, dtype, partition_info=None): 12 | # lasagne ortho init for tf 13 | shape = tuple(shape) 14 | if len(shape) == 2: 15 | flat_shape = shape 16 | elif len(shape) == 4: # assumes NHWC 17 | flat_shape = (np.prod(shape[:-1]), shape[-1]) 18 | else: 19 | raise NotImplementedError 20 | a = np.random.normal(0.0, 1.0, flat_shape) 21 | u, _, v = np.linalg.svd(a, full_matrices=False) 22 | q = u if u.shape == flat_shape else v # pick the one with the correct shape 23 | q = q.reshape(shape) 24 | return (scale * q[:shape[0], :shape[1]]).astype(np.float32) 25 | 26 | return 
_ortho_init 27 | 28 | 29 | def conv(x, scope, *, nf, rf, stride, pad='VALID', init_scale=1.0, data_format='NHWC', one_dim_bias=False): 30 | if data_format == 'NHWC': 31 | channel_ax = 3 32 | strides = [1, stride, stride, 1] 33 | bshape = [1, 1, 1, nf] 34 | elif data_format == 'NCHW': 35 | channel_ax = 1 36 | strides = [1, 1, stride, stride] 37 | bshape = [1, nf, 1, 1] 38 | else: 39 | raise NotImplementedError 40 | bias_var_shape = [nf] if one_dim_bias else [1, nf, 1, 1] 41 | nin = x.get_shape()[channel_ax].value 42 | wshape = [rf, rf, nin, nf] 43 | with tf.variable_scope(scope): 44 | w = tf.get_variable("w", wshape, initializer=ortho_init(init_scale)) 45 | b = tf.get_variable("b", bias_var_shape, initializer=tf.constant_initializer(0.0)) 46 | if not one_dim_bias and data_format == 'NHWC': 47 | b = tf.reshape(b, bshape) 48 | return tf.nn.conv2d(x, w, strides=strides, padding=pad, data_format=data_format) + b 49 | 50 | 51 | def fc(x, scope, nh, *, init_scale=1.0, init_bias=0.0): 52 | with tf.variable_scope(scope): 53 | nin = x.get_shape()[1].value 54 | w = tf.get_variable("w", [nin, nh], initializer=ortho_init(init_scale)) 55 | b = tf.get_variable("b", [nh], initializer=tf.constant_initializer(init_bias)) 56 | return tf.matmul(x, w) + b 57 | 58 | 59 | def conv_to_fc(x): 60 | nh = np.prod([v.value for v in x.get_shape()[1:]]) 61 | x = tf.reshape(x, [-1, nh]) 62 | return x 63 | 64 | 65 | def combined_shape(length, shape=None): 66 | if shape is None: 67 | return (length,) 68 | return (length, shape) if np.isscalar(shape) else (length, *shape) 69 | 70 | 71 | def placeholder(dtype=tf.float32, shape=None): 72 | return tf.placeholder(dtype=dtype, shape=combined_shape(None, shape)) 73 | 74 | 75 | def gaussian_likelihood(x, mu, log_std): 76 | pre_sum = -0.5 * (((x - mu) / (tf.exp(log_std) + EPS)) ** 2 + 2 * log_std + np.log(2 * np.pi)) 77 | return tf.reduce_sum(pre_sum, axis=1) 78 | 79 | 80 | def mlp(x, hidden_sizes=(32,), activation=tf.tanh, output_activation=None): 81 | for 
h in hidden_sizes[:-1]: 82 | x = tf.layers.dense(x, units=h, activation=activation) 83 | return tf.layers.dense(x, units=hidden_sizes[-1], activation=output_activation) 84 | 85 | 86 | def categorical_policy(logits, action, act_dim): 87 | logp_all = tf.nn.log_softmax(logits) 88 | pi = tf.squeeze(tf.multinomial(logits, 1), axis=1) 89 | logp = tf.reduce_sum(tf.one_hot(action, depth=act_dim) * logp_all, axis=1) 90 | logp_pi = tf.reduce_sum(tf.one_hot(pi, depth=act_dim) * logp_all, axis=1) 91 | return pi, logp, logp_pi 92 | 93 | 94 | def gaussian_policy(mu, action, act_dim): 95 | log_std = tf.get_variable(name='log_std', initializer=-0.5 * np.ones(act_dim, dtype=np.float32)) 96 | std = tf.exp(log_std) 97 | pi = mu + tf.random_normal(tf.shape(mu)) * std 98 | logp = gaussian_likelihood(action, mu, log_std) 99 | logp_pi = gaussian_likelihood(pi, mu, log_std) 100 | return pi, logp, logp_pi 101 | 102 | 103 | _mode = ['categorical', 'gaussian'] 104 | 105 | 106 | def actor(logits, action, act_dim, mode='categorical'): 107 | assert mode in _mode 108 | policy = eval(mode + '_policy') 109 | 110 | pi, logp, logp_pi = policy(logits, action, act_dim) 111 | return pi, logp, logp_pi 112 | -------------------------------------------------------------------------------- /learner_torch/utils/mpi_pytorch.py: -------------------------------------------------------------------------------- 1 | import multiprocessing 2 | import numpy as np 3 | import os 4 | import torch 5 | from utils.mpi_tools import broadcast, mpi_avg, num_procs, proc_id 6 | 7 | def setup_pytorch_for_mpi(): 8 | """ 9 | Avoid slowdowns caused by each separate process's PyTorch using 10 | more than its fair share of CPU resources. 
# REVIEW NOTE (repository digest, collapsed lines): rest of learner_torch/utils/mpi_pytorch.py,
# the whole of learner_torch/utils/mpi_tools.py and rm_container.sh, and the first part of
# wintest/ai1/client0.py.
# - setup_pytorch_for_mpi(): divides torch's thread count evenly across MPI ranks (min 1).
# - mpi_avg_grads()/sync_params(): no-ops with a single rank. Otherwise mpi_avg_grads()
#   averages gradients in place through a NumPy view (p.grad.numpy()), and sync_params()
#   broadcasts parameters from rank 0.
#   NOTE(review): .numpy() only works on CPU tensors and p.grad must not be None. Confirm
#   these are never called on a CUDA model with more than one rank.
# - mpi_tools:
#   - mpi_fork() re-launches the script under mpirun; the IN_MPI env var prevents
#     recursion.
#   - mpi_op() all-reduces as float32.
#   - mpi_statistics_scalar() returns the global mean/std (and optionally min/max).
# - rm_container.sh: force-removes docker containers guandan_actor_3 .. guandan_actor_43.
# - client0.py: websocket client that feeds every message to State.parse and answers
#   "actionList" prompts with the index chosen by Action.rule_parse.
#   Chinese inline comments translated to English (comment text only).
11 | """ 12 | #print('Proc %d: Reporting original number of Torch threads as %d.'%(proc_id(), torch.get_num_threads()), flush=True) 13 | if torch.get_num_threads()==1: 14 | return 15 | fair_num_threads = max(int(torch.get_num_threads() / num_procs()), 1) 16 | torch.set_num_threads(fair_num_threads) 17 | #print('Proc %d: Reporting new number of Torch threads as %d.'%(proc_id(), torch.get_num_threads()), flush=True) 18 | 19 | def mpi_avg_grads(module): 20 | """ Average contents of gradient buffers across MPI processes. """ 21 | if num_procs()==1: 22 | return 23 | for p in module.parameters(): 24 | p_grad_numpy = p.grad.numpy() # numpy view of tensor data 25 | avg_p_grad = mpi_avg(p.grad) 26 | p_grad_numpy[:] = avg_p_grad[:] 27 | 28 | def sync_params(module): 29 | """ Sync all parameters of module across all MPI processes. """ 30 | if num_procs()==1: 31 | return 32 | for p in module.parameters(): 33 | p_numpy = p.data.numpy() 34 | broadcast(p_numpy) -------------------------------------------------------------------------------- /learner_torch/utils/mpi_tools.py: -------------------------------------------------------------------------------- 1 | from mpi4py import MPI 2 | import os, subprocess, sys 3 | import numpy as np 4 | 5 | 6 | def mpi_fork(n, bind_to_core=False): 7 | """ 8 | Re-launches the current script with workers linked by MPI. 9 | 10 | Also, terminates the original process that launched it. 11 | 12 | Taken almost without modification from the Baselines function of the 13 | `same name`_. 14 | 15 | .. _`same name`: https://github.com/openai/baselines/blob/master/baselines/common/mpi_fork.py 16 | 17 | Args: 18 | n (int): Number of process to split into. 19 | 20 | bind_to_core (bool): Bind each MPI process to a core.
21 | """ 22 | if n<=1: 23 | return 24 | if os.getenv("IN_MPI") is None: 25 | env = os.environ.copy() 26 | env.update( 27 | MKL_NUM_THREADS="1", 28 | OMP_NUM_THREADS="1", 29 | IN_MPI="1" 30 | ) 31 | args = ["mpirun", "-np", str(n)] 32 | if bind_to_core: 33 | args += ["-bind-to", "core"] 34 | args += [sys.executable] + sys.argv 35 | subprocess.check_call(args, env=env) 36 | sys.exit() 37 | 38 | 39 | def msg(m, string=''): 40 | print(('Message from %d: %s \t '%(MPI.COMM_WORLD.Get_rank(), string))+str(m)) 41 | 42 | def proc_id(): 43 | """Get rank of calling process.""" 44 | return MPI.COMM_WORLD.Get_rank() 45 | 46 | def allreduce(*args, **kwargs): 47 | return MPI.COMM_WORLD.Allreduce(*args, **kwargs) 48 | 49 | def num_procs(): 50 | """Count active MPI processes.""" 51 | return MPI.COMM_WORLD.Get_size() 52 | 53 | def broadcast(x, root=0): 54 | MPI.COMM_WORLD.Bcast(x, root=root) 55 | 56 | def mpi_op(x, op): 57 | x, scalar = ([x], True) if np.isscalar(x) else (x, False) 58 | x = np.asarray(x, dtype=np.float32) 59 | buff = np.zeros_like(x, dtype=np.float32) 60 | allreduce(x, buff, op=op) 61 | return buff[0] if scalar else buff 62 | 63 | def mpi_sum(x): 64 | return mpi_op(x, MPI.SUM) 65 | 66 | def mpi_avg(x): 67 | """Average a scalar or vector over MPI processes.""" 68 | return mpi_sum(x) / num_procs() 69 | 70 | def mpi_statistics_scalar(x, with_min_and_max=False): 71 | """ 72 | Get mean/std and optional min/max of scalar x across MPI processes. 73 | 74 | Args: 75 | x: An array containing samples of the scalar to produce statistics 76 | for. 77 | 78 | with_min_and_max (bool): If true, return min and max of x in 79 | addition to mean and std.
80 | """ 81 | x = np.array(x, dtype=np.float32) 82 | global_sum, global_n = mpi_sum([np.sum(x), len(x)]) 83 | mean = global_sum / global_n 84 | 85 | global_sum_sq = mpi_sum(np.sum((x - mean)**2)) 86 | std = np.sqrt(global_sum_sq / global_n) # compute global std 87 | 88 | if with_min_and_max: 89 | global_min = mpi_op(np.min(x) if len(x) > 0 else np.inf, op=MPI.MIN) 90 | global_max = mpi_op(np.max(x) if len(x) > 0 else -np.inf, op=MPI.MAX) 91 | return mean, std, global_min, global_max 92 | return mean, std -------------------------------------------------------------------------------- /rm_container.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | for i in {3..43} 3 | do 4 | docker rm -f guandan_actor_$i 5 | done 6 | 7 | -------------------------------------------------------------------------------- /wintest/ai1/client0.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2020/10/1 16:30 3 | # @Author : Duofeng Wu 4 | # @File : client.py 5 | # @Description: 6 | 7 | 8 | import json 9 | from ws4py.client.threadedclient import WebSocketClient 10 | from state import State 11 | from action import Action 12 | 13 | 14 | class ExampleClient(WebSocketClient): 15 | 16 | def __init__(self, url): 17 | super().__init__(url) 18 | self.state = State("client0") 19 | self.action = Action("client0") 20 | 21 | def opened(self): 22 | pass 23 | 24 | def closed(self, code, reason=None): 25 | print("Closed down", code, reason) 26 | 27 | def received_message(self, message): 28 | message = json.loads(str(message)) # Deserialize the received message into a Python dict 29 | self.state.parse(message) # Let the State object parse the game state 30 | if "actionList" in message: # When an action must be chosen, let the Action object pick it 31 | act_index = self.action.rule_parse(message,self.state._myPos,self.state.remain_cards,self.state.history, 32 | 
self.state.remain_cards_classbynum,self.state.pass_num, 33 | self.state.my_pass_num,self.state.tribute_result) 34 |
35 | print(act_index) 36 | self.send(json.dumps({"actIndex": act_index})) 37 | 38 | if __name__ == '__main__': 39 | try: 40 | ws = ExampleClient('ws://127.0.0.1:23456/game/client0') 41 | 42 | #ws = ExampleClient('ws://112.124.24.226:80/game/gd/15251763326255578') 43 | # ws = ExampleClient('ws://101.37.15.53:80/game/gd/15251763326255578') 44 | ws.connect() 45 | ws.run_forever() 46 | except KeyboardInterrupt: 47 | ws.close() 48 | -------------------------------------------------------------------------------- /wintest/ai1/client1.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2020/10/1 16:30 3 | # @Author : Duofeng Wu 4 | # @File : client.py 5 | # @Description: 6 | 7 | 8 | import json 9 | from ws4py.client.threadedclient import WebSocketClient 10 | from state import State 11 | from action import Action 12 | 13 | 14 | class ExampleClient(WebSocketClient): 15 | 16 | def __init__(self, url): 17 | super().__init__(url) 18 | self.state = State("client0") 19 | self.action = Action("client0") 20 | 21 | def opened(self): 22 | pass 23 | 24 | def closed(self, code, reason=None): 25 | print("Closed down", code, reason) 26 | 27 | def received_message(self, message): 28 | message = json.loads(str(message)) # 先序列化收到的消息,转为Python中的字典 29 | self.state.parse(message) # 调用状态对象来解析状态 30 | if "actionList" in message: # 需要做出动作选择时调用动作对象进行解析 31 | act_index = self.action.rule_parse(message,self.state._myPos,self.state.remain_cards,self.state.history, 32 | self.state.remain_cards_classbynum,self.state.pass_num, 33 | self.state.my_pass_num,self.state.tribute_result) 34 | 35 | print(act_index) 36 | self.send(json.dumps({"actIndex": act_index})) 37 | 38 | if __name__ == '__main__': 39 | try: 40 | ws = ExampleClient('ws://127.0.0.1:23456/game/client1') 41 | 42 | #ws = ExampleClient('ws://112.124.24.226:80/game/gd/15251763326255578') 43 | # ws = ExampleClient('ws://101.37.15.53:80/game/gd/15251763326255578') 44 | 
ws.connect() 45 | ws.run_forever() 46 | except KeyboardInterrupt: 47 | ws.close() 48 | -------------------------------------------------------------------------------- /wintest/ai1/client2.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2020/10/1 16:30 3 | # @Author : Duofeng Wu 4 | # @File : client.py 5 | # @Description: 6 | 7 | 8 | import json 9 | from ws4py.client.threadedclient import WebSocketClient 10 | from state import State 11 | from action import Action 12 | 13 | 14 | class ExampleClient(WebSocketClient): 15 | 16 | def __init__(self, url): 17 | super().__init__(url) 18 | self.state = State("client2") 19 | self.action = Action("client2") 20 | 21 | def opened(self): 22 | pass 23 | 24 | def closed(self, code, reason=None): 25 | print("Closed down", code, reason) 26 | 27 | def received_message(self, message): 28 | message = json.loads(str(message)) # 先序列化收到的消息,转为Python中的字典 29 | self.state.parse(message) # 调用状态对象来解析状态 30 | if "actionList" in message: # 需要做出动作选择时调用动作对象进行解析 31 | # act_index = self.action.random_parse(message) 32 | act_index = self.action.rule_parse(message,self.state._myPos,self.state.remain_cards,self.state.history, 33 | self.state.remain_cards_classbynum,self.state.pass_num, 34 | self.state.my_pass_num,self.state.tribute_result) 35 | print(act_index) 36 | self.send(json.dumps({"actIndex": act_index})) 37 | 38 | 39 | if __name__ == '__main__': 40 | try: 41 | 42 | ws = ExampleClient('ws://127.0.0.1:23456/game/client2') 43 | 44 | # ws = ExampleClient('ws://112.124.24.226:80/game/gd/15251763326776392') 45 | # ws = ExampleClient('ws://101.37.15.53:80/game/gd/15251763326776392') 46 | ws.connect() 47 | ws.run_forever() 48 | except KeyboardInterrupt: 49 | ws.close() 50 | -------------------------------------------------------------------------------- /wintest/ai1/client3.py: -------------------------------------------------------------------------------- 1 | # -*- coding: 
utf-8 -*- 2 | # @Time : 2020/10/1 16:30 3 | # @Author : Duofeng Wu 4 | # @File : client.py 5 | # @Description: 6 | 7 | 8 | import json 9 | from ws4py.client.threadedclient import WebSocketClient 10 | from state import State 11 | from action import Action 12 | 13 | 14 | class ExampleClient(WebSocketClient): 15 | 16 | def __init__(self, url): 17 | super().__init__(url) 18 | self.state = State("client2") 19 | self.action = Action("client2") 20 | 21 | def opened(self): 22 | pass 23 | 24 | def closed(self, code, reason=None): 25 | print("Closed down", code, reason) 26 | 27 | def received_message(self, message): 28 | message = json.loads(str(message)) # 先序列化收到的消息,转为Python中的字典 29 | self.state.parse(message) # 调用状态对象来解析状态 30 | if "actionList" in message: # 需要做出动作选择时调用动作对象进行解析 31 | # act_index = self.action.random_parse(message) 32 | act_index = self.action.rule_parse(message,self.state._myPos,self.state.remain_cards,self.state.history, 33 | self.state.remain_cards_classbynum,self.state.pass_num, 34 | self.state.my_pass_num,self.state.tribute_result) 35 | print(act_index) 36 | self.send(json.dumps({"actIndex": act_index})) 37 | 38 | 39 | if __name__ == '__main__': 40 | try: 41 | 42 | ws = ExampleClient('ws://127.0.0.1:23456/game/client3') 43 | 44 | # ws = ExampleClient('ws://112.124.24.226:80/game/gd/15251763326776392') 45 | # ws = ExampleClient('ws://101.37.15.53:80/game/gd/15251763326776392') 46 | ws.connect() 47 | ws.run_forever() 48 | except KeyboardInterrupt: 49 | ws.close() 50 | -------------------------------------------------------------------------------- /wintest/ai2/action.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2020/10/1 21:32 3 | # @Author : Duofeng Wu 4 | # @File : action.py 5 | # @Description: 动作类 6 | 7 | from random import randint 8 | 9 | # 中英文对照表 10 | ENG2CH = { 11 | "Single": "单张", 12 | "Pair": "对子", 13 | "Trips": "三张", 14 | "ThreePair": "三连对", 15 | "ThreeWithTwo": 
"三带二", 16 | "TwoTrips": "钢板", 17 | "Straight": "顺子", 18 | "StraightFlush": "同花顺", 19 | "Bomb": "炸弹", 20 | "PASS": "过" 21 | } 22 | 23 | 24 | class Action(object): 25 | 26 | def __init__(self): 27 | self.action = [] 28 | self.act_range = -1 29 | 30 | def GetIndexFromBack(self, msg, retValue): #"actionList": [['back', 'back', ['S2']], ['back', 'back', ['H2']] 31 | retIndex = 0 32 | print("retValue:", retValue) 33 | retAction = retValue['action'] 34 | for action in msg["actionList"]: 35 | if (action[2] == retAction): 36 | retIndex = msg["actionList"].index(action) 37 | print("选择动作:", retIndex, "动作为:", msg["actionList"][retIndex]) 38 | return retIndex 39 | 40 | def GetIndexFromPlay(self, msg, retValue): 41 | #print("actionlist:",msg["actionList"]) 42 | 43 | sortedAction = retValue["action"] 44 | if retValue["type"] != "PASS": 45 | sortedAction.sort() 46 | print("retValue:",retValue) 47 | retIndex = 0 48 | for action in msg["actionList"]: 49 | if (action[2]!="PASS"): action[2].sort() 50 | #print("retvalue:",retValue["type"], retValue["rank"], sortedAction) 51 | #print("actionfromlist:",action[0], action[1], action[2]) 52 | if (action[0]==retValue["type"] and action[1]==retValue["rank"] and action[2]==sortedAction): 53 | retIndex=msg["actionList"].index(action) 54 | print("选择动作:", retIndex, "动作为:", msg["actionList"][retIndex]) 55 | return retIndex 56 | 57 | def parse(self, msg): 58 | self.action = msg["actionList"] 59 | self.act_range = msg["indexRange"] 60 | print(self.action) 61 | print("可选动作范围为:0至{}".format(self.act_range)) 62 | return randint(0, self.act_range) 63 | -------------------------------------------------------------------------------- /wintest/ai2/client0.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2020/10/1 16:30 3 | # @Author : Duofeng Wu 4 | # @File : client.py 5 | # @Description: 6 | 7 | 8 | import json 9 | 10 | from ws4py.client.threadedclient import WebSocketClient 11 | 
12 | from action import Action 13 | from state import State 14 | 15 | 16 | class ExampleClient(WebSocketClient): 17 | 18 | def __init__(self, url): 19 | super().__init__(url) 20 | self.state = State() 21 | self.action = Action() 22 | 23 | def opened(self): 24 | pass 25 | 26 | def closed(self, code, reason=None): 27 | print("Closed down", code, reason) 28 | 29 | def received_message(self, message): 30 | message = json.loads(str(message)) # 先序列化收到的消息,转为Python中的字典 31 | self.state.parse(message) # 调用状态对象来解析状态 32 | if "actionList" in message: # 需要做出动作选择时调用动作对象进行解析 33 | try: 34 | if message["stage"]=="play": 35 | act_index = self.action.GetIndexFromPlay(message, self.state.retValue) 36 | elif message["stage"]=="back": 37 | act_index = self.action.GetIndexFromBack(message, self.state.retValue) 38 | else: 39 | act_index = self.action.parse(message) 40 | except: 41 | act_index = self.action.parse(message) 42 | self.send(json.dumps({"actIndex": act_index})) 43 | 44 | 45 | if __name__ == '__main__': 46 | try: 47 | ws = ExampleClient('ws://127.0.0.1:23456/game/client0') 48 | ws.connect() 49 | ws.run_forever() 50 | except KeyboardInterrupt: 51 | ws.close() 52 | -------------------------------------------------------------------------------- /wintest/ai2/client1.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2020/10/1 16:30 3 | # @Author : Duofeng Wu 4 | # @File : client.py 5 | # @Description: 6 | 7 | 8 | import json 9 | 10 | from ws4py.client.threadedclient import WebSocketClient 11 | 12 | from action import Action 13 | from state import State 14 | 15 | 16 | class ExampleClient(WebSocketClient): 17 | 18 | def __init__(self, url): 19 | super().__init__(url) 20 | self.state = State() 21 | self.action = Action() 22 | 23 | def opened(self): 24 | pass 25 | 26 | def closed(self, code, reason=None): 27 | print("Closed down", code, reason) 28 | 29 | def received_message(self, message): 30 | message = 
json.loads(str(message)) # 先序列化收到的消息,转为Python中的字典 31 | self.state.parse(message) # 调用状态对象来解析状态 32 | if "actionList" in message: # 需要做出动作选择时调用动作对象进行解析 33 | try: 34 | if message["stage"]=="play": 35 | act_index = self.action.GetIndexFromPlay(message, self.state.retValue) 36 | elif message["stage"]=="back": 37 | act_index = self.action.GetIndexFromBack(message, self.state.retValue) 38 | else: 39 | act_index = self.action.parse(message) 40 | except: 41 | act_index = self.action.parse(message) 42 | self.send(json.dumps({"actIndex": act_index})) 43 | 44 | 45 | if __name__ == '__main__': 46 | try: 47 | ws = ExampleClient('ws://127.0.0.1:23456/game/client1') 48 | ws.connect() 49 | ws.run_forever() 50 | except KeyboardInterrupt: 51 | ws.close() 52 | -------------------------------------------------------------------------------- /wintest/ai2/client2.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2020/10/1 16:30 3 | # @Author : Duofeng Wu 4 | # @File : client.py 5 | # @Description: 6 | 7 | 8 | import json 9 | 10 | from ws4py.client.threadedclient import WebSocketClient 11 | 12 | from action import Action 13 | from state import State 14 | 15 | 16 | class ExampleClient(WebSocketClient): 17 | 18 | def __init__(self, url): 19 | super().__init__(url) 20 | self.state = State() 21 | self.action = Action() 22 | 23 | def opened(self): 24 | pass 25 | 26 | def closed(self, code, reason=None): 27 | print("Closed down", code, reason) 28 | 29 | def received_message(self, message): 30 | message = json.loads(str(message)) # 先序列化收到的消息,转为Python中的字典 31 | self.state.parse(message) # 调用状态对象来解析状态 32 | if "actionList" in message: # 需要做出动作选择时调用动作对象进行解析 33 | try: 34 | if message["stage"]=="play": 35 | act_index = self.action.GetIndexFromPlay(message, self.state.retValue) 36 | elif message["stage"]=="back": 37 | act_index = self.action.GetIndexFromBack(message, self.state.retValue) 38 | else: 39 | act_index = 
self.action.parse(message) 40 | except: 41 | act_index = self.action.parse(message) 42 | self.send(json.dumps({"actIndex": act_index})) 43 | 44 | 45 | if __name__ == '__main__': 46 | try: 47 | ws = ExampleClient('ws://127.0.0.1:23456/game/client2') 48 | ws.connect() 49 | ws.run_forever() 50 | except KeyboardInterrupt: 51 | ws.close() 52 | -------------------------------------------------------------------------------- /wintest/ai2/client3.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2020/10/1 16:30 3 | # @Author : Duofeng Wu 4 | # @File : client.py 5 | # @Description: 6 | 7 | 8 | import json 9 | 10 | from ws4py.client.threadedclient import WebSocketClient 11 | 12 | from action import Action 13 | from state import State 14 | 15 | 16 | class ExampleClient(WebSocketClient): 17 | 18 | def __init__(self, url): 19 | super().__init__(url) 20 | self.state = State() 21 | self.action = Action() 22 | 23 | def opened(self): 24 | pass 25 | 26 | def closed(self, code, reason=None): 27 | print("Closed down", code, reason) 28 | 29 | def received_message(self, message): 30 | message = json.loads(str(message)) # 先序列化收到的消息,转为Python中的字典 31 | self.state.parse(message) # 调用状态对象来解析状态 32 | if "actionList" in message: # 需要做出动作选择时调用动作对象进行解析 33 | try: 34 | if message["stage"]=="play": 35 | act_index = self.action.GetIndexFromPlay(message, self.state.retValue) 36 | elif message["stage"]=="back": 37 | act_index = self.action.GetIndexFromBack(message, self.state.retValue) 38 | else: 39 | act_index = self.action.parse(message) 40 | except: 41 | act_index = self.action.parse(message) 42 | self.send(json.dumps({"actIndex": act_index})) 43 | 44 | 45 | if __name__ == '__main__': 46 | try: 47 | ws = ExampleClient('ws://127.0.0.1:23456/game/client3') 48 | ws.connect() 49 | ws.run_forever() 50 | except KeyboardInterrupt: 51 | ws.close() 52 | -------------------------------------------------------------------------------- 
/wintest/ai2/config.py: -------------------------------------------------------------------------------- 1 | cardRanks=['2','3','4','5','6','7','8','9','T','J','Q','K','A','B','R'] 2 | cardColors=['S','H','C','D'] 3 | #cardTypes=['StraightFlush', 'Bomb', 'ThreePair', 'TwoTrips', 'ThreeWithTwo', 'Straight', 'Trips', 'Pair', 'Single'] 4 | cardTypes=['StraightFlush', 'Bomb', 'ThreePair', 'TwoTrips', 'Straight', 'ThreeWithTwo', 'Trips', 'Pair', 'Single'] 5 | 6 | class CompareRank(): 7 | def Larger(self, type, rank, card, formerAction, curRank): #('Straight','5','9',['S4','S5','H6','H7,'D8']) -> True 8 | if (rank == 'JOKER'): # all 4 Jokers 9 | return True 10 | elif (formerAction['rank'] == 'JOKER'): 11 | return False 12 | if ((type == 'StraightFlush' or type == 'Bomb') and formerAction['type']!='Bomb' and formerAction['type']!='StraightFlush'): 13 | return True 14 | if (type != 'Bomb' and type != 'StraightFlush' and (formerAction['type'] == 'Bomb' or formerAction['type'] == 'StraightFlush')): 15 | return False 16 | 17 | r1 = cardRanks.index(rank) 18 | r2 = cardRanks.index(formerAction['rank']) 19 | 20 | #print(type, r1, r2) 21 | if (type=='Bomb'): 22 | if (formerAction['type']=='Bomb'): 23 | if (card>len(formerAction['action'])): 24 | return True 25 | elif (card r2 29 | elif (formerAction['type']=='StraightFlush'): 30 | if card>5: 31 | return True 32 | else: 33 | return False 34 | else: 35 | return True 36 | elif (type=='StraightFlush'): 37 | if (formerAction['type'] == 'Bomb'): 38 | if (len(formerAction['action']) <= 5): 39 | return True 40 | else: 41 | return False 42 | elif (formerAction['type'] == 'StraightFlush'): 43 | if (r1 == cardRanks.index('A')): r1 = -1 44 | if (r2 == cardRanks.index('A')): r2 = -1 45 | return r1 > r2 46 | else: 47 | return True 48 | elif (type=='Trips' or type=='Pair' or type=='Single' or type=='ThreeWithTwo'): 49 | if rank == curRank: 50 | r1 = cardRanks.index('A') + 0.5 51 | if formerAction['rank'] == curRank: 52 | r2 = 
cardRanks.index('A') + 0.5 53 | return r1 > r2 54 | elif (type=='ThreePair' or type=='TripsPair' or type=='Straight'): 55 | if (r1 == cardRanks.index('A')): r1 = -1 56 | if (r2 == cardRanks.index('A')): r2 = -1 57 | return r1 > r2 58 | 59 | def Smaller(self, type, rank, card, formerAction, curRank): # ('Straight','5','9',['S4','S5','H6','H7,'D8']) -> False 60 | if (type == formerAction['type'] and rank == formerAction['rank']): 61 | if (type == 'ThreeWithTwo'): 62 | formerCard = '' 63 | for action in formerAction['action']: 64 | if action[1]!=formerAction['rank']: 65 | formerCard = action[1] 66 | r1 = cardRanks.index(card) 67 | r2 = cardRanks.index(formerCard) 68 | if card == curRank: 69 | r1 = cardRanks.index('A') + 0.5 70 | if formerCard == curRank: 71 | r2 = cardRanks.index('A') + 0.5 72 | return r1 < r2 73 | else: 74 | return False 75 | else: 76 | return not self.Larger(type, rank, card, formerAction, curRank) 77 | 78 | #print(CompareRank().Larger('Bomb','T',4,{'type':'Bomb','rank':'A','action':['SA', 'HA', 'HA', 'DA']}, 'T')) 79 | 80 | #print(not CompareRank().Larger('Pair', 'A', 'A', {'action': ['S6', 'H6', 'C6', 'C6'], 'type': 'Bomb', 'rank': '6'}, '3')) 81 | 82 | #print(CompareRank().Smaller('ThreeWithTwo','J','T',{'type':'ThreeWithTwo','rank':'J','action':['S4', 'H4', 'SJ', 'SJ', 'CJ']}, '2')) -------------------------------------------------------------------------------- /wintest/ai3/action.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2020/10/1 21:32 3 | # @Author : Duofeng Wu 4 | # @File : action.py 5 | # @Description: 动作类 6 | # 版本号:INDEX OS2.0.0 7 | 8 | from random import randint 9 | from message_Reyn_CUR import check_message 10 | 11 | # 中英文对照表 12 | ENG2CH = { 13 | "Single": "单张", 14 | "Pair": "对子", 15 | "Trips": "三张", 16 | "ThreePair": "三连对", 17 | "ThreeWithTwo": "三带二", 18 | "TwoTrips": "钢板", 19 | "Straight": "顺子", 20 | "StraightFlush": "同花顺", 21 | "Bomb": "炸弹", 22 | 
"PASS": "过" 23 | } 24 | 25 | 26 | class Action(object): 27 | 28 | def __init__(self): 29 | self.action = [] 30 | self.act_range = -1 31 | self.AI_choice = -1 32 | 33 | #该为完全随机数的行动 34 | def parse(self, msg): 35 | self.action = msg["actionList"] 36 | self.act_range = msg["indexRange"] 37 | print(self.action) 38 | print("可选动作范围为:0至{}".format(self.act_range)) 39 | return randint(0, self.act_range) 40 | 41 | #该为有AI加持的确定行动 42 | def parse_AI(self, msg, pos): 43 | self.action = msg["actionList"] 44 | self.act_range = msg["indexRange"] 45 | print(self.action) 46 | #运行AI来确定需要出的牌 47 | self.AI_choice = check_message(msg,pos) 48 | #由于没有考虑进贡,故而随机,否则bug 49 | if self.AI_choice == None: 50 | return randint(0, self.act_range) 51 | print("AI选择的出牌编号为:{}".format(self.AI_choice)) 52 | return self.AI_choice -------------------------------------------------------------------------------- /wintest/ai3/client0.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2020/10/1 16:30 3 | # @Author : Duofeng Wu 4 | # @File : client.py 5 | # @Description: 6 | # 对抗AI1号 7 | 8 | from random import randint 9 | 10 | import json 11 | 12 | pos = 0 13 | from ws4py.client.threadedclient import WebSocketClient 14 | 15 | from action import Action 16 | from state import State 17 | 18 | 19 | class ExampleClient(WebSocketClient): 20 | 21 | def __init__(self, url): 22 | super().__init__(url) 23 | self.state = State() 24 | self.action = Action() 25 | 26 | def opened(self): 27 | pass 28 | 29 | def closed(self, code, reason=None): 30 | print("Closed down", code, reason) 31 | 32 | def received_message(self, message): 33 | message = json.loads(str(message)) # 先序列化收到的消息,转为Python中的字典 34 | self.state.parse(message) # 调用状态对象来解析状态 35 | if 'myPos' in message: 36 | global pos 37 | pos = message['myPos'] 38 | if "actionList" in message: # 需要做出动作选择时调用动作对象进行解析 39 | #由AI进行选择,座位号随时读取 40 | try: 41 | act_index = self.action.parse_AI(message,pos) 42 | except: 
43 | act_index = randint(0, message['indexRange']) 44 | self.send(json.dumps({"actIndex": act_index})) 45 | 46 | 47 | 48 | if __name__ == '__main__': 49 | try: 50 | # ws = ExampleClient('ws://127.0.0.1:9618/game/gd/client1') 51 | ws = ExampleClient('ws://127.0.0.1:23456/game/client0') 52 | ws.connect() 53 | ws.run_forever() 54 | except KeyboardInterrupt: 55 | ws.close() 56 | 57 | -------------------------------------------------------------------------------- /wintest/ai3/client1.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2020/10/1 16:30 3 | # @Author : Duofeng Wu 4 | # @File : client.py 5 | # @Description: 6 | # 对抗AI1号 7 | 8 | from random import randint 9 | 10 | import json 11 | 12 | pos = 0 13 | from ws4py.client.threadedclient import WebSocketClient 14 | 15 | from action import Action 16 | from state import State 17 | 18 | 19 | class ExampleClient(WebSocketClient): 20 | 21 | def __init__(self, url): 22 | super().__init__(url) 23 | self.state = State() 24 | self.action = Action() 25 | 26 | def opened(self): 27 | pass 28 | 29 | def closed(self, code, reason=None): 30 | print("Closed down", code, reason) 31 | 32 | def received_message(self, message): 33 | message = json.loads(str(message)) # 先序列化收到的消息,转为Python中的字典 34 | self.state.parse(message) # 调用状态对象来解析状态 35 | if 'myPos' in message: 36 | global pos 37 | pos = message['myPos'] 38 | if "actionList" in message: # 需要做出动作选择时调用动作对象进行解析 39 | #由AI进行选择,座位号随时读取 40 | try: 41 | act_index = self.action.parse_AI(message,pos) 42 | except: 43 | act_index = randint(0, message['indexRange']) 44 | self.send(json.dumps({"actIndex": act_index})) 45 | 46 | 47 | 48 | if __name__ == '__main__': 49 | try: 50 | # ws = ExampleClient('ws://127.0.0.1:9618/game/gd/client1') 51 | ws = ExampleClient('ws://127.0.0.1:23456/game/client1') 52 | ws.connect() 53 | ws.run_forever() 54 | except KeyboardInterrupt: 55 | ws.close() 56 | 57 | 
-------------------------------------------------------------------------------- /wintest/ai3/client2.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2020/10/1 16:30 3 | # @Author : Duofeng Wu 4 | # @File : client.py 5 | # @Description: 6 | # 对抗AI2号 7 | 8 | from random import randint 9 | 10 | import json 11 | 12 | pos = 0 13 | from ws4py.client.threadedclient import WebSocketClient 14 | 15 | from action import Action 16 | from state import State 17 | 18 | 19 | class ExampleClient(WebSocketClient): 20 | 21 | def __init__(self, url): 22 | super().__init__(url) 23 | self.state = State() 24 | self.action = Action() 25 | 26 | def opened(self): 27 | pass 28 | 29 | def closed(self, code, reason=None): 30 | print("Closed down", code, reason) 31 | 32 | def received_message(self, message): 33 | message = json.loads(str(message)) # 先序列化收到的消息,转为Python中的字典 34 | self.state.parse(message) # 调用状态对象来解析状态 35 | if 'myPos' in message: 36 | global pos 37 | pos = message['myPos'] 38 | if "actionList" in message: # 需要做出动作选择时调用动作对象进行解析 39 | #由AI进行选择,座位号随时读取 40 | try: 41 | act_index = self.action.parse_AI(message,pos) 42 | except: 43 | act_index = randint(0, message['indexRange']) 44 | self.send(json.dumps({"actIndex": act_index})) 45 | 46 | 47 | if __name__ == '__main__': 48 | try: 49 | ws = ExampleClient('ws://127.0.0.1:23456/game/client2') 50 | # ws = ExampleClient('ws://127.0.0.1:9618/game/gd/client3') 51 | ws.connect() 52 | ws.run_forever() 53 | except KeyboardInterrupt: 54 | ws.close() 55 | 56 | 57 | -------------------------------------------------------------------------------- /wintest/ai3/client3.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2020/10/1 16:30 3 | # @Author : Duofeng Wu 4 | # @File : client.py 5 | # @Description: 6 | # 对抗AI2号 7 | 8 | from random import randint 9 | 10 | import json 11 | 12 | pos = 0 13 | from 
ws4py.client.threadedclient import WebSocketClient 14 | 15 | from action import Action 16 | from state import State 17 | 18 | 19 | class ExampleClient(WebSocketClient): 20 | 21 | def __init__(self, url): 22 | super().__init__(url) 23 | self.state = State() 24 | self.action = Action() 25 | 26 | def opened(self): 27 | pass 28 | 29 | def closed(self, code, reason=None): 30 | print("Closed down", code, reason) 31 | 32 | def received_message(self, message): 33 | message = json.loads(str(message)) # 先序列化收到的消息,转为Python中的字典 34 | self.state.parse(message) # 调用状态对象来解析状态 35 | if 'myPos' in message: 36 | global pos 37 | pos = message['myPos'] 38 | if "actionList" in message: # 需要做出动作选择时调用动作对象进行解析 39 | #由AI进行选择,座位号随时读取 40 | try: 41 | act_index = self.action.parse_AI(message,pos) 42 | except: 43 | act_index = randint(0, message['indexRange']) 44 | self.send(json.dumps({"actIndex": act_index})) 45 | 46 | 47 | if __name__ == '__main__': 48 | try: 49 | ws = ExampleClient('ws://127.0.0.1:23456/game/client3') 50 | # ws = ExampleClient('ws://127.0.0.1:9618/game/gd/client3') 51 | ws.connect() 52 | ws.run_forever() 53 | except KeyboardInterrupt: 54 | ws.close() 55 | 56 | 57 | -------------------------------------------------------------------------------- /wintest/ai4/client.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2020/10/1 16:30 3 | # @Author : Duofeng Wu 4 | # @File : client.py 5 | # @Description: 6 | 7 | 8 | import json 9 | from ws4py.client.threadedclient import WebSocketClient 10 | from state import State 11 | from action import Action 12 | 13 | 14 | class ExampleClient(WebSocketClient): 15 | 16 | def __init__(self, url): 17 | super().__init__(url) 18 | self.state = State() 19 | self.action = Action() 20 | 21 | def opened(self): 22 | pass 23 | 24 | def closed(self, code, reason=None): 25 | print("Closed down", code, reason) 26 | 27 | def received_message(self, message): 28 | message = 
json.loads(str(message)) # 先序列化收到的消息,转为Python中的字典 29 | self.state.parse(message) # 调用状态对象来解析状态 30 | if "actionList" in message: # 需要做出动作选择时调用动作对象进行解析 31 | act_index = self.action.parse(message) 32 | self.send(json.dumps({"actIndex": act_index})) 33 | 34 | 35 | if __name__ == '__main__': 36 | try: 37 | ws = ExampleClient('ws://127.0.0.1:23456/game/client') 38 | ws.connect() 39 | ws.run_forever() 40 | except KeyboardInterrupt: 41 | ws.close() 42 | -------------------------------------------------------------------------------- /wintest/ai4/client1.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2020/10/1 16:30 3 | # @Author : Duofeng Wu 4 | # @File : client.py 5 | # @Description: 6 | 7 | 8 | import json 9 | from ws4py.client.threadedclient import WebSocketClient 10 | from state import State 11 | from action import Action 12 | 13 | 14 | class ExampleClient(WebSocketClient): 15 | 16 | def __init__(self, url): 17 | super().__init__(url) 18 | self.state = State("client1") 19 | self.action = Action("client1") 20 | 21 | def opened(self): 22 | pass 23 | 24 | def closed(self, code, reason=None): 25 | print("Closed down", code, reason) 26 | 27 | def received_message(self, message): 28 | message = json.loads(str(message)) # 先序列化收到的消息,转为Python中的字典 29 | self.state.parse(message) # 调用状态对象来解析状态 30 | if "actionList" in message: # 需要做出动作选择时调用动作对象进行解析 31 | act_index = self.action.rule_parse(message,self.state._myPos,self.state.remain_cards,self.state.history, 32 | self.state.remain_cards_classbynum,self.state.pass_num, 33 | self.state.my_pass_num,self.state.tribute_result) 34 | 35 | print(act_index) 36 | self.send(json.dumps({"actIndex": act_index})) 37 | 38 | if __name__ == '__main__': 39 | try: 40 | # ws = ExampleClient('ws://127.0.0.1:9618/game/gd/client1') 41 | ws = ExampleClient('ws://127.0.0.1:23456/game/client0') 42 | ws.connect() 43 | ws.run_forever() 44 | except KeyboardInterrupt: 45 | ws.close() 
46 | -------------------------------------------------------------------------------- /wintest/ai4/client2.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2020/10/1 16:30 3 | # @Author : Duofeng Wu 4 | # @File : client.py 5 | # @Description: 6 | 7 | 8 | import json 9 | from ws4py.client.threadedclient import WebSocketClient 10 | from state import State 11 | from action import Action 12 | 13 | 14 | class ExampleClient(WebSocketClient): 15 | 16 | def __init__(self, url): 17 | super().__init__(url) 18 | self.state = State("client2") 19 | self.action = Action("client2") 20 | 21 | def opened(self): 22 | pass 23 | 24 | def closed(self, code, reason=None): 25 | print("Closed down", code, reason) 26 | 27 | def received_message(self, message): 28 | message = json.loads(str(message)) # 先序列化收到的消息,转为Python中的字典 29 | self.state.parse(message) # 调用状态对象来解析状态 30 | if "actionList" in message: # 需要做出动作选择时调用动作对象进行解析 31 | # act_index = self.action.random_parse(message) 32 | act_index = self.action.rule_parse(message,self.state._myPos,self.state.remain_cards,self.state.history, 33 | self.state.remain_cards_classbynum,self.state.pass_num, 34 | self.state.my_pass_num,self.state.tribute_result) 35 | print(act_index) 36 | self.send(json.dumps({"actIndex": act_index})) 37 | 38 | 39 | if __name__ == '__main__': 40 | try: 41 | #ws = ExampleClient('ws://114.55.107.187:9618/game/gd/client1') 42 | ws = ExampleClient('ws://127.0.0.1:23456/game/client1') 43 | # ws = ExampleClient('ws://127.0.0.1:9618/game/gd/13913359464455242') 44 | ws.connect() 45 | ws.run_forever() 46 | except KeyboardInterrupt: 47 | ws.close() 48 | -------------------------------------------------------------------------------- /wintest/ai4/client3.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2020/10/1 16:30 3 | # @Author : Duofeng Wu 4 | # @File : client.py 5 | # 
@Description: 6 | 7 | 8 | import json 9 | from ws4py.client.threadedclient import WebSocketClient 10 | from state import State 11 | from action import Action 12 | 13 | 14 | class ExampleClient(WebSocketClient): 15 | 16 | def __init__(self, url): 17 | super().__init__(url) 18 | self.state = State("client3") 19 | self.action = Action("client3") 20 | 21 | def opened(self): 22 | pass 23 | 24 | def closed(self, code, reason=None): 25 | print("Closed down", code, reason) 26 | 27 | def received_message(self, message): 28 | message = json.loads(str(message)) # 先序列化收到的消息,转为Python中的字典 29 | self.state.parse(message) # 调用状态对象来解析状态 30 | if "actionList" in message: # 需要做出动作选择时调用动作对象进行解析 31 | act_index = self.action.rule_parse(message,self.state._myPos,self.state.remain_cards,self.state.history, 32 | self.state.remain_cards_classbynum,self.state.pass_num, 33 | self.state.my_pass_num,self.state.tribute_result) 34 | print(act_index) 35 | self.send(json.dumps({"actIndex": act_index})) 36 | 37 | 38 | if __name__ == '__main__': 39 | try: 40 | #ws = ExampleClient('ws://114.55.107.187:9618/game/gd/client3') 41 | ws = ExampleClient('ws://127.0.0.1:23456/game/client2') 42 | ws.connect() 43 | ws.run_forever() 44 | except KeyboardInterrupt: 45 | ws.close() 46 | -------------------------------------------------------------------------------- /wintest/ai4/client4.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2020/10/1 16:30 3 | # @Author : Duofeng Wu 4 | # @File : client.py 5 | # @Description: 6 | 7 | 8 | import json 9 | from ws4py.client.threadedclient import WebSocketClient 10 | from state import State 11 | from action import Action 12 | 13 | 14 | class ExampleClient(WebSocketClient): 15 | 16 | def __init__(self, url): 17 | super().__init__(url) 18 | self.state = State("client4") 19 | self.action = Action("client4") 20 | 21 | def opened(self): 22 | pass 23 | 24 | def closed(self, code, reason=None): 25 | 
print("Closed down", code, reason) 26 | 27 | def received_message(self, message): 28 | message = json.loads(str(message)) # 先序列化收到的消息,转为Python中的字典 29 | self.state.parse(message) # 调用状态对象来解析状态 30 | if "actionList" in message: 31 | # act_index = self.action.random_parse(message)# 需要做出动作选择时调用动作对象进行解析 32 | act_index = self.action.rule_parse(message,self.state._myPos,self.state.remain_cards,self.state.history, 33 | self.state.remain_cards_classbynum,self.state.pass_num, 34 | self.state.my_pass_num,self.state.tribute_result) 35 | self.send(json.dumps({"actIndex": act_index})) 36 | 37 | 38 | if __name__ == '__main__': 39 | try: 40 | #ws = ExampleClient('ws://114.55.107.187:9618/game/gd/client3') 41 | ws = ExampleClient('ws://127.0.0.1:23456/game/client3') 42 | # ws = ExampleClient('ws://127.0.0.1:9618/game/gd/13913359464588075') 43 | ws.connect() 44 | ws.run_forever() 45 | except KeyboardInterrupt: 46 | ws.close() 47 | -------------------------------------------------------------------------------- /wintest/ai5/action.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 12020/10/1 21:32 3 | # @Author : Duofeng Wu 4 | # @File : action.py 5 | # @Description: 动作类 6 | 7 | import copy 8 | import logging 9 | from random import randint, random 10 | 11 | from active import * 12 | from passive import * 13 | from utils import * 14 | 15 | # 中英文对照表 16 | ENG2CH = { 17 | "Single": "单张", 18 | "Pair": "对子", 19 | "Trips": "三张", 20 | "ThreePair": "三连对", 21 | "ThreeWithTwo": "三带二", 22 | "TwoTrips": "钢板", 23 | "Straight": "顺子", 24 | "StraightFlush": "同花顺", 25 | "Bomb": "炸弹", 26 | "PASS": "过" 27 | } 28 | from back_tribute import * 29 | 30 | 31 | class Action(object): 32 | 33 | def __init__(self): 34 | self.action = [] 35 | self.act_range = -1 36 | 37 | def rule_parse(self,msg,mypos,remaincards,history,remain_cards_classbynum,pass_num,my_pass_num,tribute_result): 38 | self.action = msg["actionList"] 39 | if 
len(self.action) == 1: 40 | return 0 41 | if msg["stage"] == "play" and msg["greaterPos"] != mypos and msg["curPos"] != -1: 42 | try: 43 | 44 | numofplayers = [history['0']["remain"],history['1']["remain"],history['2']["remain"],history['3']["remain"]] 45 | numofnext = numofplayers[(mypos + 1) % 4] 46 | if numofnext != 0: 47 | print("下家还有{}张牌".format(numofnext)) 48 | else: 49 | numofpre = numofplayers[(mypos - 1) % 4] 50 | print("下家已完牌,上家还有{}张牌".format(numofpre)) 51 | self.act = passive(self.action, msg["handCards"], msg["curRank"], msg['curAction'],msg["greaterAction"],mypos, 52 | msg["greaterPos"],remaincards, numofplayers,pass_num,my_pass_num,remain_cards_classbynum) 53 | except Exception as e: 54 | print(str(e)) 55 | self.act = 1 56 | 57 | elif msg["stage"] == "play" and (msg["greaterPos"] == -1 or msg["curPos"] == -1): 58 | print("主动出牌") 59 | try: 60 | numofplayers = [history['0']["remain"], history['1']["remain"], history['2']["remain"], 61 | history['3']["remain"]] 62 | numofnext = numofplayers[(mypos + 1) % 4] 63 | if numofnext != 0: 64 | print("下家还有{}张牌".format(numofnext)) 65 | else: 66 | numofpre = numofplayers[(mypos - 1) % 4] 67 | print("下家已完牌,上家还有{}张牌".format(numofpre)) 68 | self.act = active(self.action, msg["handCards"], msg["curRank"],numofplayers,mypos,remaincards) 69 | except Exception as e: 70 | print(e) 71 | self.act = 0 72 | elif msg["stage"] == "back": 73 | try: 74 | self.act = back_action(msg,mypos,tribute_result) 75 | except Exception as e: 76 | print(e) 77 | self.act = 0 78 | elif msg["stage"] == "tribute": 79 | try: 80 | self.act = tribute(self.action,msg["curRank"]) 81 | except Exception as e: 82 | print(e) 83 | self.act = 0 84 | else: 85 | self.act_range = msg["indexRange"] 86 | self.act = randint(0, self.act_range) 87 | try: 88 | if self.action[self.act][0]=="PASS": 89 | print("当我过的时候我可以选择哪些牌:") 90 | print(self.action) 91 | except Exception as e: 92 | print("日志打印失败") 93 | print(e) 94 | return self.act 95 | 96 | def 
random_parse(self,msg): 97 | self.action = msg["actionList"] 98 | self.act_range = msg["indexRange"] 99 | return randint(0,self.act_range) 100 | 101 | -------------------------------------------------------------------------------- /wintest/ai5/client.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2020/10/1 16:30 3 | # @Author : Duofeng Wu 4 | # @File : client.py 5 | # @Description: 6 | 7 | 8 | import json 9 | from ws4py.client.threadedclient import WebSocketClient 10 | from state import State 11 | from action import Action 12 | 13 | 14 | class ExampleClient(WebSocketClient): 15 | 16 | def __init__(self, url): 17 | super().__init__(url) 18 | self.state = State() 19 | self.action = Action() 20 | 21 | def opened(self): 22 | pass 23 | 24 | def closed(self, code, reason=None): 25 | print("Closed down", code, reason) 26 | 27 | def received_message(self, message): 28 | message = json.loads(str(message)) # 先序列化收到的消息,转为Python中的字典 29 | self.state.parse(message) # 调用状态对象来解析状态 30 | if "actionList" in message: # 需要做出动作选择时调用动作对象进行解析 31 | act_index = self.action.parse(message) 32 | self.send(json.dumps({"actIndex": act_index})) 33 | 34 | 35 | if __name__ == '__main__': 36 | try: 37 | ws = ExampleClient('ws://127.0.0.1:23456/game/client') 38 | ws.connect() 39 | ws.run_forever() 40 | except KeyboardInterrupt: 41 | ws.close() 42 | -------------------------------------------------------------------------------- /wintest/ai5/client1.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2020/10/1 16:30 3 | # @Author : Duofeng Wu 4 | # @File : client.py 5 | # @Description: 6 | 7 | 8 | import json 9 | from ws4py.client.threadedclient import WebSocketClient 10 | from state import State 11 | from action import Action 12 | 13 | 14 | class ExampleClient(WebSocketClient): 15 | 16 | def __init__(self, url): 17 | super().__init__(url) 18 | 
self.state = State() 19 | self.action = Action() 20 | 21 | def opened(self): 22 | pass 23 | 24 | def closed(self, code, reason=None): 25 | print("Closed down", code, reason) 26 | 27 | def received_message(self, message): 28 | message = json.loads(str(message)) # 先序列化收到的消息,转为Python中的字典 29 | self.state.parse(message) # 调用状态对象来解析状态 30 | if "actionList" in message: # 需要做出动作选择时调用动作对象进行解析 31 | act_index = self.action.rule_parse(message,self.state._myPos,self.state.remain_cards,self.state.history, 32 | self.state.remain_cards_classbynum,self.state.pass_num, 33 | self.state.my_pass_num,self.state.tribute_result) 34 | 35 | print(act_index) 36 | self.send(json.dumps({"actIndex": act_index})) 37 | 38 | if __name__ == '__main__': 39 | try: 40 | ws = ExampleClient('ws://127.0.0.1:23456/game/client0') 41 | ws.connect() 42 | ws.run_forever() 43 | except KeyboardInterrupt: 44 | ws.close() 45 | -------------------------------------------------------------------------------- /wintest/ai5/client2.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2020/10/1 16:30 3 | # @Author : Duofeng Wu 4 | # @File : client.py 5 | # @Description: 6 | 7 | 8 | import json 9 | from ws4py.client.threadedclient import WebSocketClient 10 | from state import State 11 | from action import Action 12 | 13 | 14 | class ExampleClient(WebSocketClient): 15 | 16 | def __init__(self, url): 17 | super().__init__(url) 18 | self.state = State() 19 | self.action = Action() 20 | 21 | def opened(self): 22 | pass 23 | 24 | def closed(self, code, reason=None): 25 | print("Closed down", code, reason) 26 | 27 | def received_message(self, message): 28 | message = json.loads(str(message)) # 先序列化收到的消息,转为Python中的字典 29 | self.state.parse(message) # 调用状态对象来解析状态 30 | if "actionList" in message: # 需要做出动作选择时调用动作对象进行解析 31 | act_index = self.action.rule_parse(message,self.state._myPos,self.state.remain_cards,self.state.history, 32 | 
self.state.remain_cards_classbynum,self.state.pass_num, 33 | self.state.my_pass_num,self.state.tribute_result) 34 | print(act_index) 35 | self.send(json.dumps({"actIndex": act_index})) 36 | 37 | 38 | if __name__ == '__main__': 39 | try: 40 | # ws = ExampleClient('ws://127.0.0.1:80/game/gd/13913359464455242') 41 | ws = ExampleClient('ws://127.0.0.1:23456/game/client1') 42 | ws.connect() 43 | ws.run_forever() 44 | except KeyboardInterrupt: 45 | ws.close() 46 | -------------------------------------------------------------------------------- /wintest/ai5/client3.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2020/10/1 16:30 3 | # @Author : Duofeng Wu 4 | # @File : client.py 5 | # @Description: 6 | 7 | 8 | import json 9 | 10 | from ws4py.client.threadedclient import WebSocketClient 11 | 12 | from action import Action 13 | from state import State 14 | 15 | 16 | class ExampleClient(WebSocketClient): 17 | 18 | def __init__(self, url): 19 | super().__init__(url) 20 | self.state = State() 21 | self.action = Action() 22 | 23 | def opened(self): 24 | pass 25 | 26 | def closed(self, code, reason=None): 27 | print("Closed down", code, reason) 28 | 29 | def received_message(self, message): 30 | message = json.loads(str(message)) # 先序列化收到的消息,转为Python中的字典 31 | self.state.parse(message) # 调用状态对象来解析状态 32 | if "actionList" in message: # 需要做出动作选择时调用动作对象进行解析 33 | act_index = self.action.rule_parse(message,self.state._myPos,self.state.remain_cards,self.state.history, 34 | self.state.remain_cards_classbynum,self.state.pass_num, 35 | self.state.my_pass_num,self.state.tribute_result) 36 | print(act_index) 37 | self.send(json.dumps({"actIndex": act_index})) 38 | 39 | 40 | if __name__ == '__main__': 41 | try: 42 | ws = ExampleClient('ws://127.0.0.1:23456/game/client2') 43 | ws.connect() 44 | ws.run_forever() 45 | except KeyboardInterrupt: 46 | ws.close() 47 | 
-------------------------------------------------------------------------------- /wintest/ai5/client4.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2020/10/1 16:30 3 | # @Author : Duofeng Wu 4 | # @File : client.py 5 | # @Description: 6 | 7 | 8 | import json 9 | from ws4py.client.threadedclient import WebSocketClient 10 | from state import State 11 | from action import Action 12 | 13 | 14 | class ExampleClient(WebSocketClient): 15 | 16 | def __init__(self, url): 17 | super().__init__(url) 18 | self.state = State() 19 | self.action = Action() 20 | 21 | def opened(self): 22 | pass 23 | 24 | def closed(self, code, reason=None): 25 | print("Closed down", code, reason) 26 | 27 | def received_message(self, message): 28 | message = json.loads(str(message)) # 先序列化收到的消息,转为Python中的字典 29 | self.state.parse(message) # 调用状态对象来解析状态 30 | if "actionList" in message: 31 | # act_index = self.action.random_parse(message)# 需要做出动作选择时调用动作对象进行解析 32 | act_index = self.action.rule_parse(message,self.state._myPos,self.state.remain_cards,self.state.history, 33 | self.state.remain_cards_classbynum,self.state.pass_num, 34 | self.state.my_pass_num,self.state.tribute_result) 35 | self.send(json.dumps({"actIndex": act_index})) 36 | 37 | 38 | if __name__ == '__main__': 39 | try: 40 | #ws = ExampleClient('ws://114.55.107.187:9618/game/gd/client3') 41 | ws = ExampleClient('ws://127.0.0.1:23456/game/client3') 42 | # ws = ExampleClient('ws://112.124.24.226:80/game/gd/13913359464588075') 43 | ws.connect() 44 | ws.run_forever() 45 | except KeyboardInterrupt: 46 | ws.close() 47 | -------------------------------------------------------------------------------- /wintest/ai6/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/submit-paper/Danzero_plus/e2b900b01096e0743de945f9963df35cd544f36d/wintest/ai6/__init__.py 
-------------------------------------------------------------------------------- /wintest/ai6/client1.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2020/10/1 16:30 3 | # @Author : Duofeng Wu 4 | # @File : client.py 5 | # @Description: 6 | 7 | 8 | import json 9 | 10 | from ws4py.client.threadedclient import WebSocketClient 11 | 12 | from action import Action 13 | from state import State 14 | 15 | 16 | class ExampleClient(WebSocketClient): 17 | 18 | def __init__(self, url): 19 | super().__init__(url) 20 | self.state = State() 21 | self.action = Action() 22 | 23 | def opened(self): 24 | pass 25 | 26 | def closed(self, code, reason=None): 27 | print("Closed down", code, reason) 28 | 29 | def received_message(self, message): 30 | message = json.loads(str(message)) # 先序列化收到的消息,转为Python中的字典 31 | self.state.parse(message) # 调用状态对象来解析状态 32 | # print("我的主机号",message['myPos']) 33 | # if (message['type'] == 'notify'): 34 | # myclient = self.state.notify_play() 35 | if "actionList" in message: # 需要做出动作选择时调用动作对象进行解析 36 | act_index = self.action.parse(message) 37 | self.send(json.dumps({"actIndex": act_index})) 38 | 39 | if __name__ == '__main__': 40 | try: 41 | # ws = ExampleClient('ws://39.108.189.48:80/game/gd/18762111338605314') 42 | ws = ExampleClient('ws://127.0.0.1:23456/game/client0') 43 | # ws = ExampleClient('ws://127.0.0.1:9618/game/gd/client1') 44 | ws.connect() 45 | ws.run_forever() 46 | except KeyboardInterrupt: 47 | ws.close() 48 | -------------------------------------------------------------------------------- /wintest/ai6/client2.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2020/10/1 16:30 3 | # @Author : Duofeng Wu 4 | # @File : client.py 5 | # @Description: 6 | 7 | 8 | import json 9 | from ws4py.client.threadedclient import WebSocketClient 10 | from state2 import State 11 | from action2 import Action 
12 | 13 | 14 | class ExampleClient(WebSocketClient): 15 | 16 | def __init__(self, url): 17 | super().__init__(url) 18 | self.state = State() 19 | self.action = Action() 20 | 21 | def opened(self): 22 | pass 23 | 24 | def closed(self, code, reason=None): 25 | print("Closed down", code, reason) 26 | 27 | def received_message(self, message): 28 | message = json.loads(str(message)) # 先序列化收到的消息,转为Python中的字典 29 | self.state.parse(message) # 调用状态对象来解析状态 30 | if "actionList" in message: # 需要做出动作选择时调用动作对象进行解析 31 | act_index = self.action.parse(message) 32 | self.send(json.dumps({"actIndex": act_index})) 33 | 34 | 35 | if __name__ == '__main__': 36 | try: 37 | ws = ExampleClient('ws://127.0.0.1:23456/game/client1') 38 | ws.connect() 39 | ws.run_forever() 40 | except KeyboardInterrupt: 41 | ws.close() 42 | -------------------------------------------------------------------------------- /wintest/ai6/client3.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2020/10/1 16:30 3 | # @Author : Duofeng Wu 4 | # @File : client.py 5 | # @Description: 6 | 7 | 8 | import json 9 | 10 | from ws4py.client.threadedclient import WebSocketClient 11 | 12 | from action3 import Action 13 | from state3 import State 14 | 15 | 16 | class ExampleClient(WebSocketClient): 17 | 18 | def __init__(self, url): 19 | super().__init__(url) 20 | self.state = State() 21 | self.action = Action() 22 | 23 | def opened(self): 24 | pass 25 | 26 | def closed(self, code, reason=None): 27 | print("Closed down", code, reason) 28 | 29 | def received_message(self, message): 30 | message = json.loads(str(message)) # 先序列化收到的消息,转为Python中的字典 31 | self.state.parse(message) # 调用状态对象来解析状态 32 | if "actionList" in message: # 需要做出动作选择时调用动作对象进行解析 33 | act_index = self.action.parse(message) 34 | self.send(json.dumps({"actIndex": act_index})) 35 | 36 | 37 | if __name__ == '__main__': 38 | try: 39 | # ws = 
ExampleClient('ws://39.108.189.48:80/game/gd/18762111338284605') 40 | # # ws = ExampleClient('ws://114.55.107.187:23456/game/18762111338605314') 41 | # ws = ExampleClient('ws://112.124.24.226:80/game/gd/17550225823901100') 42 | ws = ExampleClient('ws://127.0.0.1:23456/game/client2') 43 | ws.connect() 44 | ws.run_forever() 45 | except KeyboardInterrupt: 46 | ws.close() 47 | -------------------------------------------------------------------------------- /wintest/ai6/client4.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2020/10/1 16:30 3 | # @Author : Duofeng Wu 4 | # @File : client.py 5 | # @Description: 6 | 7 | 8 | import json 9 | from ws4py.client.threadedclient import WebSocketClient 10 | from state4 import State 11 | from action4 import Action 12 | 13 | 14 | class ExampleClient(WebSocketClient): 15 | 16 | def __init__(self, url): 17 | super().__init__(url) 18 | self.state = State() 19 | self.action = Action() 20 | 21 | def opened(self): 22 | pass 23 | 24 | def closed(self, code, reason=None): 25 | print("Closed down", code, reason) 26 | 27 | def received_message(self, message): 28 | message = json.loads(str(message)) # 先序列化收到的消息,转为Python中的字典 29 | self.state.parse(message) # 调用状态对象来解析状态 30 | if "actionList" in message: # 需要做出动作选择时调用动作对象进行解析 31 | act_index = self.action.parse(message) 32 | self.send(json.dumps({"actIndex": act_index})) 33 | 34 | 35 | if __name__ == '__main__': 36 | try: 37 | ws = ExampleClient('ws://127.0.0.1:23456/game/client3') 38 | ws.connect() 39 | ws.run_forever() 40 | except KeyboardInterrupt: 41 | ws.close() 42 | -------------------------------------------------------------------------------- /wintest/ai6/data1.txt: -------------------------------------------------------------------------------- 1 | 0 -------------------------------------------------------------------------------- /wintest/ai6/data2.txt: 
-------------------------------------------------------------------------------- 1 | 1 -------------------------------------------------------------------------------- /wintest/ai6/data3.txt: -------------------------------------------------------------------------------- 1 | 2 -------------------------------------------------------------------------------- /wintest/ai6/data4.txt: -------------------------------------------------------------------------------- 1 | 3 -------------------------------------------------------------------------------- /wintest/ai7/action.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2020/10/19 19:30 3 | # @Author : Duofeng Wu && Zenghui Qian 4 | # @File : action.py 5 | # @Description: 动作类 6 | 7 | from random import randint 8 | 9 | from mysolve import solve 10 | 11 | # 中英文对照表 12 | ENG2CH = { 13 | "Single": "单张", 14 | "Pair": "对子", 15 | "Trips": "三张", 16 | "ThreePair": "三连对", 17 | "ThreeWithTwo": "三带二", 18 | "TwoTrips": "钢板", 19 | "Straight": "顺子", 20 | "StraightFlush": "同花顺", 21 | "Bomb": "炸弹", 22 | "PASS": "过" 23 | } 24 | 25 | 26 | class Action(object): 27 | 28 | def __init__(self): 29 | self.action = [] 30 | self.act_range = -1 31 | 32 | def parse(self, msg, mate_pos): # 增加了一个新参数mate_pos,表示队友的位置 33 | self.action = msg["actionList"] 34 | self.act_range = msg["indexRange"] 35 | print(self.action) 36 | print("可选动作范围为:0至{}".format(self.act_range)) 37 | index = solve(msg, mate_pos) 38 | return index 39 | -------------------------------------------------------------------------------- /wintest/ai7/client1.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2020/10/19 19:30 3 | # @Author : Duofeng Wu && Zenghui Qian 4 | # @File : client.py 5 | # @Description: 6 | 7 | 8 | import json 9 | 10 | from ws4py.client.threadedclient import WebSocketClient 11 | 12 | from action import Action 13 | 
from state import State 14 | 15 | 16 | class ExampleClient(WebSocketClient): 17 | 18 | def __init__(self, url): 19 | super().__init__(url) 20 | self.state = State() 21 | self.action = Action() 22 | self.my_pos = -1 # 增加了一个属性,用来记录自己的位置 23 | self.mate_pos = -1 # 增加了一个属性,用来记录队友的位置 24 | 25 | def opened(self): 26 | pass 27 | 28 | def closed(self, code, reason=None): 29 | print("Closed down", code, reason) 30 | 31 | def received_message(self, message): 32 | message = json.loads(str(message)) # 先序列化收到的消息,转为Python中的字典 33 | self.state.parse(message) 34 | if message["stage"] == "beginning": # 先从beginning阶段获取自己的位置 35 | self.my_pos = message["myPos"] # 根据自己的位置推断出队友的位置 36 | self.mate_pos = (self.my_pos+2) % 4 37 | 38 | if "actionList" in message: # 需要做出动作选择时调用动作对象进行解析 39 | act_index = self.action.parse(message, self.mate_pos) 40 | # 在action.parse中增加了一个新参数,传入队友的位置 41 | 42 | # act_index = self.action.parse(message, -1) # 传入-1时默认为“笨笨”操作,不会调用算法 43 | self.send(json.dumps({"actIndex": act_index})) 44 | 45 | 46 | if __name__ == '__main__': 47 | try: 48 | ws = ExampleClient('ws://127.0.0.1:23456/game/client0') 49 | # ws = ExampleClient('ws://112.124.24.226:80/game/gd/19852273119160234') 50 | # ws = ExampleClient('ws://112.124.24.226:80/game/gd/19852273119160234') 51 | ws.connect() 52 | ws.run_forever() 53 | except KeyboardInterrupt: 54 | ws.close() 55 | -------------------------------------------------------------------------------- /wintest/ai7/client2.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2020/10/19 19:30 3 | # @Author : Duofeng Wu && Zenghui Qian 4 | # @File : client.py 5 | # @Description: 6 | 7 | 8 | import json 9 | 10 | from ws4py.client.threadedclient import WebSocketClient 11 | 12 | from action import Action 13 | from state import State 14 | 15 | 16 | class ExampleClient(WebSocketClient): 17 | 18 | def __init__(self, url): 19 | super().__init__(url) 20 | self.state = State() 21 | self.action 
= Action() 22 | self.my_pos = -1 # 增加了一个属性,用来记录自己的位置 23 | self.mate_pos = -1 # 增加了一个属性,用来记录队友的位置 24 | 25 | def opened(self): 26 | pass 27 | 28 | def closed(self, code, reason=None): 29 | print("Closed down", code, reason) 30 | 31 | def received_message(self, message): 32 | message = json.loads(str(message)) # 先序列化收到的消息,转为Python中的字典 33 | self.state.parse(message) 34 | if message["stage"] == "beginning": # 先从beginning阶段获取自己的位置 35 | self.my_pos = message["myPos"] # 根据自己的位置推断出队友的位置 36 | self.mate_pos = (self.my_pos+2) % 4 37 | 38 | if "actionList" in message: # 需要做出动作选择时调用动作对象进行解析 39 | act_index = self.action.parse(message, self.mate_pos) 40 | # 在action.parse中增加了一个新参数,传入队友的位置 41 | 42 | # act_index = self.action.parse(message, -1) # 传入-1时默认为“笨笨”操作,不会调用算法 43 | self.send(json.dumps({"actIndex": act_index})) 44 | 45 | 46 | if __name__ == '__main__': 47 | try: 48 | ws = ExampleClient('ws://127.0.0.1:23456/game/client1') 49 | # ws = ExampleClient('ws://112.124.24.226:80/game/gd/19852273119160234') 50 | # ws = ExampleClient('ws://112.124.24.226:80/game/gd/19852273119160234') 51 | ws.connect() 52 | ws.run_forever() 53 | except KeyboardInterrupt: 54 | ws.close() 55 | -------------------------------------------------------------------------------- /wintest/ai7/client3.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2020/10/19 19:30 3 | # @Author : Duofeng Wu && Zenghui Qian 4 | # @File : client.py 5 | # @Description: 6 | 7 | 8 | import json 9 | 10 | from ws4py.client.threadedclient import WebSocketClient 11 | 12 | from action import Action 13 | from state import State 14 | 15 | 16 | class ExampleClient(WebSocketClient): 17 | 18 | def __init__(self, url): 19 | super().__init__(url) 20 | self.state = State() 21 | self.action = Action() 22 | self.my_pos = -1 # 增加了一个属性,用来记录自己的位置 23 | self.mate_pos = -1 # 增加了一个属性,用来记录队友的位置 24 | 25 | def opened(self): 26 | pass 27 | 28 | def closed(self, code, reason=None): 
29 | print("Closed down", code, reason) 30 | 31 | def received_message(self, message): 32 | message = json.loads(str(message)) # 先序列化收到的消息,转为Python中的字典 33 | self.state.parse(message) 34 | if message["stage"] == "beginning": # 先从beginning阶段获取自己的位置 35 | self.my_pos = message["myPos"] # 根据自己的位置推断出队友的位置 36 | self.mate_pos = (self.my_pos+2) % 4 37 | 38 | if "actionList" in message: # 需要做出动作选择时调用动作对象进行解析 39 | act_index = self.action.parse(message, self.mate_pos) 40 | # 在action.parse中增加了一个新参数,传入队友的位置 41 | 42 | # act_index = self.action.parse(message, -1) # 传入-1时默认为“笨笨”操作,不会调用算法 43 | self.send(json.dumps({"actIndex": act_index})) 44 | 45 | 46 | if __name__ == '__main__': 47 | try: 48 | ws = ExampleClient('ws://127.0.0.1:23456/game/client2') 49 | # ws = ExampleClient('ws://112.124.24.226:80/game/gd/19852273119160234') 50 | # ws = ExampleClient('ws://112.124.24.226:80/game/gd/19852273119160234') 51 | ws.connect() 52 | ws.run_forever() 53 | except KeyboardInterrupt: 54 | ws.close() 55 | -------------------------------------------------------------------------------- /wintest/ai7/client4.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2020/10/19 19:30 3 | # @Author : Duofeng Wu && Zenghui Qian 4 | # @File : client.py 5 | # @Description: 6 | 7 | 8 | import json 9 | 10 | from ws4py.client.threadedclient import WebSocketClient 11 | 12 | from action import Action 13 | from state import State 14 | 15 | 16 | class ExampleClient(WebSocketClient): 17 | 18 | def __init__(self, url): 19 | super().__init__(url) 20 | self.state = State() 21 | self.action = Action() 22 | self.my_pos = -1 # 增加了一个属性,用来记录自己的位置 23 | self.mate_pos = -1 # 增加了一个属性,用来记录队友的位置 24 | 25 | def opened(self): 26 | pass 27 | 28 | def closed(self, code, reason=None): 29 | print("Closed down", code, reason) 30 | 31 | def received_message(self, message): 32 | message = json.loads(str(message)) # 先序列化收到的消息,转为Python中的字典 33 | 
self.state.parse(message) 34 | if message["stage"] == "beginning": # 先从beginning阶段获取自己的位置 35 | self.my_pos = message["myPos"] # 根据自己的位置推断出队友的位置 36 | self.mate_pos = (self.my_pos+2) % 4 37 | 38 | if "actionList" in message: # 需要做出动作选择时调用动作对象进行解析 39 | act_index = self.action.parse(message, self.mate_pos) 40 | # 在action.parse中增加了一个新参数,传入队友的位置 41 | 42 | # act_index = self.action.parse(message, -1) # 传入-1时默认为“笨笨”操作,不会调用算法 43 | self.send(json.dumps({"actIndex": act_index})) 44 | 45 | 46 | if __name__ == '__main__': 47 | try: 48 | ws = ExampleClient('ws://127.0.0.1:23456/game/client3') 49 | # ws = ExampleClient('ws://112.124.24.226:80/game/gd/19852273119160234') 50 | # ws = ExampleClient('ws://112.124.24.226:80/game/gd/19852273119160234') 51 | ws.connect() 52 | ws.run_forever() 53 | except KeyboardInterrupt: 54 | ws.close() 55 | -------------------------------------------------------------------------------- /wintest/ai8/action.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2020/10/1 21:32 3 | # @Author : Duofeng Wu 4 | # @File : action.py 5 | # @Description: 动作类 6 | 7 | from random import randint 8 | 9 | # 中英文对照表 10 | ENG2CH = { 11 | "Single": "单张", 12 | "Pair": "对子", 13 | "Trips": "三张", 14 | "ThreePair": "三连对", 15 | "ThreeWithTwo": "三带二", 16 | "TwoTrips": "钢板", 17 | "Straight": "顺子", 18 | "StraightFlush": "同花顺", 19 | "Bomb": "炸弹", 20 | "PASS": "过" 21 | } 22 | 23 | 24 | class Action(object): 25 | 26 | def __init__(self): 27 | self.action = [] 28 | self.act_range = -1 29 | 30 | def parse(self, msg): 31 | self.action = msg["actionList"] 32 | self.act_range = msg["indexRange"] 33 | # print(self.action) 34 | print("可选动作范围为:0至{}".format(self.act_range)) 35 | return randint(0, self.act_range) 36 | -------------------------------------------------------------------------------- /wintest/create_container.sh: -------------------------------------------------------------------------------- 1 | for i in 
{50..52} 2 | do 3 | docker run -itd --network=guandanNet --ip 172.15.15.$i --name guandan_actor_$i -v /home/zhaoyp/guandan_tog:/home/zhaoyp/guandan_tog -w /home/zhaoyp/guandan_tog guandan_actor:v5 /bin/bash 4 | done 5 | -------------------------------------------------------------------------------- /wintest/danzero/__pycache__/model.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/submit-paper/Danzero_plus/e2b900b01096e0743de945f9963df35cd544f36d/wintest/danzero/__pycache__/model.cpython-36.pyc -------------------------------------------------------------------------------- /wintest/danzero/__pycache__/util.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/submit-paper/Danzero_plus/e2b900b01096e0743de945f9963df35cd544f36d/wintest/danzero/__pycache__/util.cpython-36.pyc -------------------------------------------------------------------------------- /wintest/danzero/actor.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import time 4 | from argparse import ArgumentParser 5 | from multiprocessing import Process 6 | 7 | import numpy as np 8 | import tensorflow as tf 9 | import zmq 10 | from pyarrow import deserialize, serialize 11 | from tensorflow.keras.backend import set_session 12 | 13 | from model import GDModel 14 | 15 | parser = ArgumentParser() 16 | parser.add_argument('--observation_space', type=int, default=(567, ), 17 | help='The YAML configuration file') 18 | parser.add_argument('--action_space', type=int, default=(5, 216), 19 | help='The YAML configuration file') 20 | parser.add_argument('--model', type=str, default='../model/', 21 | help='The YAML configuration file') 22 | 23 | class Player(): 24 | def __init__(self, args) -> None: 25 | # Set 'allow_growth' 26 | os.environ["TF_CPP_MIN_LOG_LEVEL"] = '3' 27 | 
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR) 28 | config = tf.ConfigProto() 29 | config.gpu_options.allow_growth = True 30 | set_session(tf.Session(config=config)) 31 | 32 | # 数据初始化 33 | self.args = args 34 | self.init_time = time.time() 35 | 36 | # 模型初始化 37 | self.model = GDModel(args.observation_space, (5, 216)) 38 | with open('./q_network.ckpt', 'rb') as f: 39 | new_weights = pickle.load(f) 40 | self.model.set_weights(new_weights) 41 | 42 | def sample(self, state) -> int: 43 | output = self.model.forward(state['x_batch']) 44 | action_idx = np.argmax(output) 45 | return action_idx 46 | 47 | 48 | def run_one_player(index, args): 49 | player = Player(args) 50 | 51 | # 初始化zmq 52 | context = zmq.Context() 53 | socket = context.socket(zmq.REP) 54 | socket.bind(f'tcp://*:{6000+index}') 55 | 56 | action_index = 0 57 | while True: 58 | state = deserialize(socket.recv()) 59 | action_index = player.sample(state) 60 | # print(f'actor{index} do action number {action_index}') 61 | socket.send(serialize(action_index).to_buffer()) 62 | 63 | 64 | def main(): 65 | # 参数传递 66 | args, _ = parser.parse_known_args() 67 | 68 | def exit_wrapper(index, *x, **kw): 69 | """Exit all actors on KeyboardInterrupt (Ctrl-C)""" 70 | try: 71 | run_one_player(index, *x, **kw) 72 | except KeyboardInterrupt: 73 | if index == 0: 74 | for _i, _p in enumerate(players): 75 | if _i != index: 76 | _p.terminate() 77 | 78 | players = [] 79 | for i in [0, 2]: 80 | print(f'start{i}') 81 | p = Process(target=exit_wrapper, args=(i, args)) 82 | p.start() 83 | time.sleep(0.5) 84 | players.append(p) 85 | 86 | for player in players: 87 | player.join() 88 | 89 | if __name__ == '__main__': 90 | main() 91 | -------------------------------------------------------------------------------- /wintest/danzero/model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | from tensorflow.keras.backend import get_session 4 | 5 
| 6 | def mlp(x, hidden_sizes=(32,), activation=tf.tanh, output_activation=None): 7 | for h in hidden_sizes[:-1]: 8 | x = tf.layers.dense(x, units=h, activation=activation) 9 | return tf.layers.dense(x, units=hidden_sizes[-1], activation=output_activation) 10 | 11 | 12 | def placeholder(dtype=tf.float32, shape=None): 13 | return tf.placeholder(dtype=dtype, shape=combined_shape(None, shape)) 14 | 15 | 16 | def combined_shape(length, shape=None): 17 | if shape is None: 18 | return (length,) 19 | return (length, shape) if np.isscalar(shape) else (length, *shape) 20 | 21 | 22 | class GDModel(): 23 | def __init__(self, observation_space, action_space, config=None, model_id='0', *args, **kwargs): 24 | with tf.variable_scope(model_id): 25 | self.x_ph = placeholder(shape=observation_space) 26 | 27 | # 输出张量 28 | self.values = None 29 | self.scope = model_id 30 | 31 | # Initialize Tensorflow session 32 | self.sess = get_session() 33 | self.observation_space = observation_space 34 | self.action_space = action_space 35 | self.model_id = model_id 36 | self.config = config 37 | 38 | # 2. 
Build up model 39 | self.build() 40 | 41 | # Build assignment ops 42 | self._weight_ph = None 43 | self._to_assign = None 44 | self._nodes = None 45 | self._build_assign() 46 | 47 | # 参数初始化 48 | self.sess.run(tf.global_variables_initializer()) 49 | 50 | 51 | def set_weights(self, weights) -> None: 52 | feed_dict = {self._weight_ph[var.name]: weight 53 | for (var, weight) in zip(tf.trainable_variables(scope=self.scope), weights)} 54 | self.sess.run(self._nodes, feed_dict=feed_dict) 55 | 56 | def _build_assign(self): 57 | self._weight_ph, self._to_assign = dict(), dict() 58 | variables = tf.trainable_variables(self.scope) 59 | for var in variables: 60 | self._weight_ph[var.name] = tf.placeholder(var.value().dtype, var.get_shape().as_list()) 61 | self._to_assign[var.name] = var.assign(self._weight_ph[var.name]) 62 | self._nodes = list(self._to_assign.values()) 63 | 64 | def forward(self, x_batch): 65 | return self.sess.run(self.values, feed_dict={self.x_ph: x_batch}) 66 | 67 | def build(self) -> None: 68 | with tf.variable_scope(self.scope): 69 | with tf.variable_scope('v'): 70 | self.values = mlp(self.x_ph, [512, 512, 512, 512, 512, 1], activation='tanh', 71 | output_activation=None) 72 | -------------------------------------------------------------------------------- /wintest/danzero/q_network.ckpt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/submit-paper/Danzero_plus/e2b900b01096e0743de945f9963df35cd544f36d/wintest/danzero/q_network.ckpt -------------------------------------------------------------------------------- /wintest/readme.md: -------------------------------------------------------------------------------- 1 | We provide the evaluation code in this directory. 2 | For the rule-based bots, they can be used for the four positions. 3 | For the DMC model, they need to be deployed at player0 and player2 while the PPO model is deployed at player1 and player3. 
4 | 5 | In ./torch, we give the evaluation shell file, which you can follow to conduct needed evaluation. Here we also give an 6 | example to show how to evaluate the model during the training process, that you can refer to the evaluate_xxx.py. 7 | Because the interval between models to be saved is almost the same, you can adjust the model id to get the checkpoints you want. 8 | 9 | Here we give an example to conduct the evaluation. After copying the tested models to the target dir, you can just 10 | execute the command "bash testmodel.sh xx", where xx is the model_id. How it is set can be referred to in ./torch/actor.py. 11 | If you want to execute the evaluation during training, maybe you can run "nohup python -u evaluate_xx.py > xx.log &". 12 | In this way, you can just see the log file to see how many models have been tested. 13 | -------------------------------------------------------------------------------- /wintest/torch/__pycache__/model.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/submit-paper/Danzero_plus/e2b900b01096e0743de945f9963df35cd544f36d/wintest/torch/__pycache__/model.cpython-36.pyc -------------------------------------------------------------------------------- /wintest/torch/__pycache__/util.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/submit-paper/Danzero_plus/e2b900b01096e0743de945f9963df35cd544f36d/wintest/torch/__pycache__/util.cpython-36.pyc -------------------------------------------------------------------------------- /wintest/torch/actor.py: -------------------------------------------------------------------------------- 1 | import time 2 | from argparse import ArgumentParser 3 | from multiprocessing import Process 4 | from random import randint 5 | 6 | import numpy as np 7 | import zmq 8 | import pickle 9 | import torch 10 | import io 11 | from model import 
MLPActorCritic, MLPQNetwork 12 | from pyarrow import deserialize, serialize 13 | 14 | ActionNumber = 2 15 | 16 | parser = ArgumentParser() 17 | parser.add_argument('--ip', type=str, default='172.15.15.2', 18 | help='IP address of learner server') 19 | parser.add_argument('--data_port', type=int, default=5000, 20 | help='Learner server port to send training data') 21 | parser.add_argument('--param_port', type=int, default=5001, 22 | help='Learner server port to subscribe model parameters') 23 | parser.add_argument('--exp_path', type=str, default='/home/root/log', 24 | help='Directory to save logging data, model parameters and config file') 25 | parser.add_argument('--num_saved_ckpt', type=int, default=4, 26 | help='Number of recent checkpoint files to be saved') 27 | parser.add_argument('--observation_space', type=int, default=(567,), 28 | help='The YAML configuration file') 29 | parser.add_argument('--action_space', type=int, default=(5, 216), 30 | help='The YAML configuration file') 31 | parser.add_argument('--epsilon', type=float, default=0.01, 32 | help='Epsilon') 33 | parser.add_argument('--iter', type=int, default=0, 34 | help='update steps for the tested model') 35 | 36 | class CPU_Unpickler(pickle.Unpickler): 37 | def find_class(self, module, name): 38 | if module == 'torch.storage' and name == '_load_from_bytes': 39 | return lambda b: torch.load(io.BytesIO(b), map_location='cpu') 40 | else: return super().find_class(module, name) 41 | 42 | class Player(): 43 | def __init__(self, args) -> None: 44 | # 模型初始化 45 | self.model_id = args.iter * 2000 + 500 46 | self.model = MLPActorCritic((ActionNumber, 516+ActionNumber * 54), ActionNumber) 47 | with open('./models/ppo{}.pth'.format(self.model_id), 'rb') as f: 48 | new_weights = CPU_Unpickler(f).load() 49 | print('load model:', self.model_id) 50 | self.model.set_weights(new_weights) 51 | self.model_q = MLPQNetwork(567) 52 | with open('./q_network.ckpt', 'rb') as f: 53 | tf_weights = pickle.load(f) 54 | 
self.model_q.load_tf_weights(tf_weights) 55 | 56 | def sample(self, state) -> int: 57 | states = state['x_batch'] 58 | legal_action = ActionNumber 59 | legal_index = np.ones(ActionNumber) 60 | state_no_action = state['x_no_action'] 61 | if len(states) >= ActionNumber: 62 | indexs = self.model_q.get_max_n_index(states, ActionNumber) 63 | dqn_states = np.asarray(states[indexs]) 64 | top_actions = dqn_states[:, -54:].flatten() 65 | states = np.concatenate((state_no_action, top_actions)) 66 | 67 | elif len(states) < ActionNumber: 68 | legal_action = len(states) 69 | legal_index[legal_action:] = np.zeros(ActionNumber-legal_action) 70 | top_indexs = self.model_q.get_max_n_index(states, ActionNumber) 71 | dqn_states = np.asarray(states[top_indexs]) 72 | top_actions = dqn_states[:,-54:].flatten() 73 | states = np.concatenate((state_no_action, top_actions)) # 把动作先添加进来 74 | supple = np.zeros(54 * (ActionNumber - legal_action)) 75 | states = np.concatenate((states,supple)) 76 | indexs = list(range(ActionNumber)) 77 | 78 | action = self.model.step(states, legal_index) 79 | return indexs[action] 80 | 81 | 82 | def run_one_player(index, args): 83 | player = Player(args) 84 | 85 | # 初始化zmq 86 | context = zmq.Context() 87 | socket = context.socket(zmq.REP) 88 | socket.bind(f'tcp://*:{6000+index}') 89 | 90 | action_index = 0 91 | while True: 92 | state = deserialize(socket.recv()) 93 | action_index = player.sample(state) 94 | # print(f'actor{index} do action number {action_index}') 95 | socket.send(serialize(action_index).to_buffer()) 96 | 97 | 98 | def main(): 99 | # 参数传递 100 | args, _ = parser.parse_known_args() 101 | 102 | def exit_wrapper(index, *x, **kw): 103 | """Exit all actors on KeyboardInterrupt (Ctrl-C)""" 104 | try: 105 | run_one_player(index, *x, **kw) 106 | except KeyboardInterrupt: 107 | if index == 0: 108 | for _i, _p in enumerate(players): 109 | if _i != index: 110 | _p.terminate() 111 | 112 | players = [] 113 | for i in [1, 3]: 114 | # print(f'start{i}') 115 | p = 
Process(target=exit_wrapper, args=(i, args)) 116 | p.start() 117 | time.sleep(0.5) 118 | players.append(p) 119 | 120 | for player in players: 121 | player.join() 122 | 123 | 124 | if __name__ == '__main__': 125 | main() 126 | -------------------------------------------------------------------------------- /wintest/torch/danserver: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/submit-paper/Danzero_plus/e2b900b01096e0743de945f9963df35cd544f36d/wintest/torch/danserver -------------------------------------------------------------------------------- /wintest/torch/evaluate_dqn.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | 4 | 5 | def copy(): 6 | path = '/home/zhaoyp/guandan_tog/learner_torch/LEARNER-2023-11-08-15-38-36/ckpt/' 7 | files = os.listdir(path) 8 | dic = {} 9 | for name in files: 10 | form = name.split('.')[0] 11 | num = int(form[3:]) - 500 12 | if num > 0 and num % 5000 == 0: 13 | k = num // 5000 14 | dic[k] = path + name 15 | 16 | dest = '/home/zhaoyp/guandan_tog/wintest/torch/models' 17 | exist = os.listdir(dest) 18 | for k, v in dic.items(): 19 | if v.split('/')[-1] not in exist: 20 | os.system('cp ' + v + ' '+ dest) 21 | res = os.listdir(dest) 22 | return len(res) 23 | 24 | 25 | def current_log(oppo): 26 | path = '/home/zhaoyp/guandan_tog/wintest/torch/' 27 | files = os.listdir(path) 28 | tested = [] 29 | for name in files: 30 | latter = oppo + '.log' 31 | if latter in name and 'res' in name: 32 | val = int(name.split('v')[0][3:]) 33 | tested.append(val) 34 | return tested 35 | 36 | 37 | def check(num, oppo): 38 | tested = current_log(oppo) 39 | if num not in tested: 40 | return False 41 | else: 42 | return True 43 | 44 | 45 | oppo = 'dqn' 46 | while True: 47 | flag = 0 48 | nums = copy() 49 | time.sleep(10) 50 | tested = current_log(oppo) 51 | time.sleep(10) 52 | print(nums, tested) 53 | for i in range(1, nums+1): 
54 | if i not in tested: 55 | flag = 1 56 | break 57 | if flag == 1: 58 | os.system('bash testvsdqn.sh ' + str(i)) 59 | print('testing {}'.format(i)) 60 | time.sleep(10) 61 | res = check(i, oppo) 62 | while not res: 63 | time.sleep(300) 64 | res = check(i, oppo) 65 | os.system('bash kill_auto.sh') 66 | print('model index {} test finish'.format(i)) 67 | time.sleep(10) 68 | else: 69 | time.sleep(600) 70 | -------------------------------------------------------------------------------- /wintest/torch/evaluate_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | 4 | 5 | def copy(): 6 | path = '/home/zhaoyp/guandan_tog/learner_torch/LEARNER-2023-11-08-15-38-36/ckpt/' 7 | files = os.listdir(path) 8 | dic = {} 9 | for name in files: 10 | form = name.split('.')[0] 11 | num = int(form[3:]) - 500 12 | if num > 0 and num % 5000 == 0: 13 | k = num // 5000 14 | dic[k] = path + name 15 | 16 | dest = '/home/zhaoyp/guandan_tog/wintest/torch/models' 17 | exist = os.listdir(dest) 18 | for k, v in dic.items(): 19 | if v.split('/')[-1] not in exist: 20 | os.system('cp ' + v + ' '+ dest) 21 | res = os.listdir(dest) 22 | return len(res) 23 | 24 | def current_log(oppo): 25 | path = '/home/zhaoyp/guandan_tog/wintest/torch/' 26 | files = os.listdir(path) 27 | tested = [] 28 | for name in files: 29 | latter = oppo + '.log' 30 | if latter in name: 31 | val = int(name.split('v')[0][3:]) 32 | tested.append(val) 33 | return tested 34 | 35 | def check(num, oppo): 36 | tested = current_log(oppo) 37 | if num not in tested: 38 | return False 39 | else: 40 | return True 41 | 42 | oppo = '4' 43 | while True: 44 | flag = 0 45 | nums = copy() 46 | time.sleep(10) 47 | tested = current_log(oppo) 48 | time.sleep(10) 49 | print(nums, tested) 50 | for i in range(1, nums+1): 51 | if i not in tested: 52 | flag = 1 53 | break 54 | if flag == 1: 55 | os.system('bash testmodel.sh ' + str(i)) 56 | print('testing {}'.format(i)) 57 | 
time.sleep(10) 58 | res = check(i, oppo) 59 | while not res: 60 | time.sleep(300) 61 | res = check(i, oppo) 62 | os.system('bash kill_auto.sh') 63 | print('model index {} test finish'.format(i)) 64 | time.sleep(10) 65 | else: 66 | time.sleep(600) 67 | -------------------------------------------------------------------------------- /wintest/torch/kill.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | ps -ef | grep guandan | awk '{print $2}' | xargs kill -9 3 | ps -ef | grep server | awk '{print $2}' | xargs kill -9 4 | ps aux|grep python|grep -v grep|cut -c 9-15|xargs kill -9 5 | ps -ef | grep actor | awk '{print $2}' | xargs kill -9 6 | ps -ef | grep game | awk '{print $2}' | xargs kill -9 7 | ps -ef | grep python | awk '{print $2}' | xargs kill -9 8 | -------------------------------------------------------------------------------- /wintest/torch/kill_auto.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | ps -ef | grep guandan | awk '{print $2}' | xargs kill -9 3 | ps -ef | grep server | awk '{print $2}' | xargs kill -9 4 | #ps aux|grep python|grep -v grep|cut -c 9-15|xargs kill -9 5 | ps -ef | grep actor | awk '{print $2}' | xargs kill -9 6 | ps -ef | grep game | awk '{print $2}' | xargs kill -9 7 | ps -ef | grep client | awk '{print $2}' | xargs kill -9 8 | #ps -ef | grep python | awk '{print $2}' | xargs kill -9 9 | -------------------------------------------------------------------------------- /wintest/torch/model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | #import scipy.signal 3 | import torch 4 | import torch.nn as nn 5 | from torch.distributions.categorical import Categorical 6 | 7 | 8 | def combined_shape(length, shape=None): 9 | if shape is None: 10 | return (length,) 11 | return (length, shape) if np.isscalar(shape) else (length, *shape) 12 | 13 | 14 | def mlp(sizes, activation, 
output_activation=nn.Identity): 15 | layers = [] 16 | for j in range(len(sizes)-1): 17 | act = activation if j < len(sizes)-2 else output_activation 18 | layers += [nn.Linear(sizes[j], sizes[j+1]), act()] 19 | return nn.Sequential(*layers) 20 | 21 | def shared_mlp(obs_dim, sizes, activation): # 分两个叉,一个是过softmax的logits,另一个不过,就是单纯的q(s,a),这里是前面的共享层 22 | layers = [] 23 | shapes = [obs_dim] + list(sizes) 24 | for j in range(len(shapes) - 1): 25 | act = activation 26 | layers += [nn.Linear(shapes[j], shapes[j + 1]), act()] 27 | return nn.Sequential(*layers) 28 | 29 | 30 | def count_vars(module): 31 | return sum([np.prod(p.shape) for p in module.parameters()]) 32 | 33 | 34 | class Actor(nn.Module): 35 | def _distribution(self, obs): 36 | raise NotImplementedError 37 | 38 | def _log_prob_from_distribution(self, pi, act): 39 | raise NotImplementedError 40 | 41 | def forward(self, obs, act=None, legalaction=torch.tensor(list(range(10))).to(torch.float32)): 42 | # Produce action distributions for given observations, and 43 | # optionally compute the log likelihood of given actions under 44 | # those distributions. 
45 | pi = self._distribution(obs, legalaction) 46 | logp_a = None 47 | if act is not None: 48 | logp_a = self._log_prob_from_distribution(pi, act) 49 | return pi, logp_a 50 | 51 | 52 | class MLPCategoricalActor(Actor): 53 | def __init__(self, obs_dim, act_dim, hidden_sizes, activation): 54 | super().__init__() 55 | self.logits_net = mlp([obs_dim] + list(hidden_sizes) + [act_dim], activation) 56 | 57 | def _distribution(self, obs, legal_action): 58 | logits = torch.squeeze(self.logits_net(obs)) - (1 - legal_action) * 1e6 59 | return Categorical(logits=logits) 60 | 61 | def _log_prob_from_distribution(self, pi, act): 62 | return pi.log_prob(act) 63 | 64 | 65 | class MLPCritic(nn.Module): 66 | def __init__(self, obs_dim, hidden_sizes, activation): 67 | super().__init__() 68 | self.v_net = mlp([obs_dim] + list(hidden_sizes) + [1], activation) 69 | 70 | def forward(self, obs): 71 | return torch.squeeze(self.v_net(obs), -1) # Critical to ensure v has right shape. 72 | 73 | 74 | class MLPQ(nn.Module): 75 | def __init__(self, obs_dim, hidden_sizes, activation): 76 | super().__init__() 77 | self.q_net = mlp([obs_dim] + list(hidden_sizes) + [1], activation) 78 | 79 | def forward(self, obs): 80 | return torch.squeeze(self.q_net(obs), -1) # Critical to ensure q has right shape. 
81 | 82 | 83 | class MLPActorCritic(nn.Module): 84 | def __init__(self, observation_space, action_space, 85 | hidden_sizes=(512, 512, 512, 512, 256), activation=nn.Tanh): 86 | super().__init__() 87 | 88 | obs_dim = observation_space 89 | self.shared = shared_mlp(obs_dim[1], hidden_sizes, activation) 90 | self.pi = mlp([hidden_sizes[-1], 128, action_space], activation) # 输出logits 91 | self.v = mlp([hidden_sizes[-1], 128, 1], activation) # 输出q(s,a) 92 | 93 | 94 | def step(self, obs, legal_action): 95 | obs = torch.tensor(obs).to(torch.float32) 96 | legal_action = torch.tensor(legal_action).to(torch.float32) 97 | with torch.no_grad(): 98 | shared_feature = self.shared(obs) 99 | # print(shared_feature.shape, legal_action.shape) 100 | logits = torch.squeeze(self.pi(shared_feature)) - (1 - legal_action) * 1e8 101 | a = torch.argmax(logits) 102 | # del obs, legal_action 103 | # return a.numpy().item(), v.numpy().item(), logp_a.numpy().item() 104 | return a.numpy() 105 | 106 | def act(self, obs): 107 | return self.step(obs)[0] 108 | 109 | def set_weights(self, weights): 110 | self.load_state_dict(weights) 111 | 112 | def get_weights(self): 113 | return self.state_dict() 114 | 115 | 116 | class MLPQNetwork(nn.Module): 117 | def __init__(self, observation_space, 118 | hidden_sizes=(512, 512, 512, 512, 512), activation=nn.Tanh): 119 | super().__init__() 120 | 121 | obs_dim = observation_space 122 | 123 | # build Q function 124 | self.q = MLPQ(obs_dim, hidden_sizes, activation) 125 | 126 | def load_tf_weights(self, weights): 127 | name = ['q_net.0.weight', 'q_net.0.bias', 'q_net.2.weight', 'q_net.2.bias', 'q_net.4.weight', 'q_net.4.bias', 'q_net.6.weight', 'q_net.6.bias', 'q_net.8.weight', 'q_net.8.bias', 'q_net.10.weight', 'q_net.10.bias'] 128 | tensor_weights = [] 129 | for weight in weights: 130 | temp = torch.tensor(weight).T 131 | tensor_weights.append(temp) 132 | new_weights = dict(zip(name, tensor_weights)) 133 | self.q.load_state_dict(new_weights) 134 | print('load tf 
weights success') 135 | 136 | def get_max_n_index(self, data, n): 137 | #data = data[:,:-3] 138 | q_list = self.q(torch.tensor(data).to(torch.float32)) 139 | q_list = q_list.detach().numpy() 140 | return q_list.argsort()[-n:][::-1].tolist() 141 | 142 | if __name__ == '__main__': 143 | model = MLPActorCritic((10, 567), 1) 144 | model_q = MLPQNetwork(567) 145 | b = np.load("/home/zhaoyp/guandan_tog/actor_ppo/debug128.npy", allow_pickle=True).item() 146 | print(b.keys()) 147 | state = b['x_batch'][0] 148 | n = 3 149 | index2action = model_q.get_max_n_index(torch.tensor(state).to(torch.float32),n) 150 | 151 | # state = np.random.random((513, )) 152 | # action1 = np.random.random((54, )) 153 | # action2 = np.random.random((54, )) 154 | # action3 = np.random.random((54, )) 155 | # b = np.load("/home/zhaoyp/guandan_tog_tog/actor_torch/debug145.npy", allow_pickle=True).item() 156 | # print(b.keys()) 157 | # print(b['obs_cut'].shape) 158 | # print(b['obs'].shape) 159 | 160 | # print('time1') 161 | # objgraph.show_most_common_types(limit=30) 162 | # objgraph.show_growth() 163 | 164 | 165 | # print('time2') 166 | # objgraph.show_most_common_types(limit=30) 167 | # objgraph.show_growth() 168 | 169 | # a, v, p = model.step(state, legal_index) 170 | 171 | # print('time3') 172 | # objgraph.show_most_common_types(limit=30) 173 | # objgraph.show_growth() 174 | 175 | # print(a,v,p) 176 | # print(type(a),type(v),type(p)) 177 | -------------------------------------------------------------------------------- /wintest/torch/q_network.ckpt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/submit-paper/Danzero_plus/e2b900b01096e0743de945f9963df35cd544f36d/wintest/torch/q_network.ckpt -------------------------------------------------------------------------------- /wintest/torch/testmodel.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | nohup 
/home/zhaoyp/guandan_tog/wintest/torch/danserver 50 >/dev/null 2>&1 & 3 | sleep 0.5s 4 | nohup /root/miniconda3/envs/guandan/bin/python /home/zhaoyp/guandan_tog/wintest/ai4/client1.py >/dev/null 2>&1 & 5 | # /root/miniconda3/envs/guandan/bin/python /home/zhaoyp/guandan_tog/wintest/random_clien0.py 2>&1 & 6 | sleep 0.5s 7 | nohup /root/miniconda3/envs/guandan/bin/python /home/zhaoyp/guandan_tog/wintest/torch/client1.py --resfile res$1v4.log >/dev/null 2>&1 & 8 | # /root/miniconda3/envs/guandan/bin/python /home/zhaoyp/guandan_tog/wintest/newversion/my/client1.py --resfile res$1v9.log 2>&1 & 9 | sleep 0.5s 10 | nohup /root/miniconda3/envs/guandan/bin/python /home/zhaoyp/guandan_tog/wintest/ai4/client3.py >/dev/null 2>&1 & 11 | sleep 0.5s 12 | nohup /root/miniconda3/envs/guandan/bin/python /home/zhaoyp/guandan_tog/wintest/torch/client3.py >/dev/null 2>&1 & 13 | sleep 0.5s 14 | nohup /root/miniconda3/envs/guandan/bin/python /home/zhaoyp/guandan_tog/wintest/torch/actor.py --iter $1 >/dev/null 2>&1 & 15 | # /root/miniconda3/envs/guandan/bin/python /home/zhaoyp/guandan_tog/wintest/newversion/my/actor.py --model $1 2>&1 & 16 | echo $1 17 | 18 | -------------------------------------------------------------------------------- /wintest/torch/testvsdqn.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | nohup /home/zhaoyp/guandan_tog/wintest/torch/danserver 10 >/dev/null 2>&1 & 3 | sleep 0.5s 4 | nohup /root/miniconda3/envs/guandan/bin/python /home/zhaoyp/guandan_tog/wintest/danzero/client0.py >/dev/null 2>&1 & 5 | # /root/miniconda3/envs/guandan/bin/python /home/zhaoyp/guandan_tog/wintest/random_clien0.py 2>&1 & 6 | sleep 0.5s 7 | nohup /root/miniconda3/envs/guandan/bin/python /home/zhaoyp/guandan_tog/wintest/torch/client1.py --resfile res$1vdqn.log > /dev/null 2>&1 & 8 | # /root/miniconda3/envs/guandan/bin/python /home/zhaoyp/guandan_tog/wintest/newversion/my/client1.py --resfile res$1v9.log 2>&1 & 9 | sleep 0.5s 10 | nohup 
/root/miniconda3/envs/guandan/bin/python /home/zhaoyp/guandan_tog/wintest/danzero/client2.py >/dev/null 2>&1 & 11 | sleep 0.5s 12 | nohup /root/miniconda3/envs/guandan/bin/python /home/zhaoyp/guandan_tog/wintest/torch/client3.py >/dev/null 2>&1 & 13 | sleep 0.5s 14 | nohup /root/miniconda3/envs/guandan/bin/python /home/zhaoyp/guandan_tog/wintest/danzero/actor.py >/dev/null 2>&1 & 15 | sleep 0.5s 16 | nohup /root/miniconda3/envs/guandan/bin/python /home/zhaoyp/guandan_tog/wintest/torch/actor.py --iter $1 >/dev/null 2>&1 & 17 | 18 | echo $1 19 | -------------------------------------------------------------------------------- /离线平台使用说明.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/submit-paper/Danzero_plus/e2b900b01096e0743de945f9963df35cd544f36d/离线平台使用说明.pdf --------------------------------------------------------------------------------