├── .gitignore ├── .idea ├── .gitignore ├── code.iml ├── deployment.xml ├── inspectionProfiles │ └── profiles_settings.xml ├── misc.xml └── modules.xml ├── README.md ├── __init__.py ├── codes ├── dataset │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-38.pyc │ │ ├── __init__.cpython-39.pyc │ │ ├── dataloader.cpython-38.pyc │ │ ├── dataloader.cpython-39.pyc │ │ ├── dataset.cpython-38.pyc │ │ └── dataset.cpython-39.pyc │ ├── dataloader.py │ └── dataset.py ├── env │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-38.pyc │ │ ├── __init__.cpython-39.pyc │ │ ├── environment.cpython-38.pyc │ │ └── environment.cpython-39.pyc │ ├── env_wrapper.py │ └── environment.py ├── mcts │ ├── MCTS.py │ ├── __init__.py │ └── __pycache__ │ │ ├── MCTS.cpython-38.pyc │ │ ├── MCTS.cpython-39.pyc │ │ ├── __init__.cpython-38.pyc │ │ └── __init__.cpython-39.pyc ├── multi_runner │ ├── MCTSF.py │ ├── __init__.py │ ├── __pycache__ │ │ ├── MCTSF.cpython-39.pyc │ │ ├── __init__.cpython-39.pyc │ │ └── envs.cpython-39.pyc │ └── envs.py ├── net │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-38.pyc │ │ ├── __init__.cpython-39.pyc │ │ ├── network.cpython-38.pyc │ │ └── network.cpython-39.pyc │ └── network.py ├── scripts │ └── check_redundancy.py ├── trainer │ ├── Player.py │ ├── Trainer.py │ ├── __init__.py │ ├── __pycache__ │ │ ├── Player.cpython-39.pyc │ │ ├── Trainer.cpython-38.pyc │ │ ├── Trainer.cpython-39.pyc │ │ ├── __init__.cpython-38.pyc │ │ ├── __init__.cpython-39.pyc │ │ ├── loss.cpython-38.pyc │ │ └── loss.cpython-39.pyc │ └── loss.py └── utils │ ├── __init__.py │ ├── __pycache__ │ ├── __init__.cpython-38.pyc │ ├── __init__.cpython-39.pyc │ ├── util_functions.cpython-38.pyc │ └── util_functions.cpython-39.pyc │ └── util_functions.py ├── config ├── S_4.yaml ├── S_4_remote.yaml ├── S_9.yaml └── my_conf.yaml ├── main.py ├── record.txt └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | /exp/ 2 | /data/ -------------------------------------------------------------------------------- /.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # Default ignored files 2 | /shelf/ 3 | /workspace.xml 4 | # Datasource local storage ignored files 5 | /dataSources/ 6 | /dataSources.local.xml 7 | # Editor-based HTTP Client requests 8 | /httpRequests/ 9 | -------------------------------------------------------------------------------- /.idea/code.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 12 | -------------------------------------------------------------------------------- /.idea/deployment.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /.idea/modules.xml: 
-------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # OpenTensor 2 | 3 | This is the code implementation of our paper "OpenTensor: Reproducing Faster Matrix Multiplication Discovering Algorithms", presented at the 37th Conference on Neural Information Processing Systems Workshop (NeurIPS 2023). 4 | 5 | We provide the code to generate synthetic tensors, train OpenTensor and perform tensor decomposition. 6 | 7 | ## Config 8 | 9 | All configs should be contained in a yaml file. We provide some config templates in the `./config` folder. For example, `./config/S_4.yaml` is the config file for decomposing the $4 \times 4 \times 4$ matrix multiplication tensor, which is equivalent to discovering a $2 \times 2$ matrix multiplication algorithm. 10 | 11 | ## Generating synthetic data 12 | 13 | ``` 14 | mkdir data 15 | python main.py --config ./config/S_4.yaml --mode generate_data 16 | ``` 17 | 18 | This command generates 100000 synthetic tensors and saves them to the `./data` folder. 19 | 20 | ## Training OpenTensor 21 | 22 | ``` 23 | mkdir exp 24 | python main.py --config ./config/S_4.yaml --mode train 25 | ``` 26 | 27 | The model parameters and the tensorboard log files are all saved in subfolders of `./exp`. 28 | 29 | ## Testing OpenTensor 30 | 31 | ``` 32 | python main.py --config ./config/S_4.yaml --mode infer --run_dir $run_dir 33 | ``` 34 | 35 | where `$run_dir` is a subfolder of `./exp` containing the model parameters of OpenTensor. This command discovers a decomposition of the matrix multiplication tensor with the OpenTensor model.
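
## Config sketch

The exact keys live in the yaml files under `./config`; the sketch below is inferred from the constructor signatures in `codes/` (`Environment`, `Net`, `MCTS`) and is illustrative only. The field names are assumptions, not a copy of `S_4.yaml`:

```yaml
# Hypothetical config sketch; field names inferred from code, not copied from S_4.yaml.
S_size: 4            # dimension of u, v, w (a 4x4x4 tensor corresponds to 2x2 matmul)
R_limit: 12          # step limit per decomposition episode
T: 7                 # history length fed to the network
N_steps: 6           # tokens per action (each token covers 3 * S_size / N_steps entries)
coefficients: [0, 1, -1]
N_samples: 32        # actions sampled from the policy head per MCTS expansion
simulate_times: 400  # MCTS simulations per move
```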
36 | 37 | 38 | ## Citing us 39 | 40 | If our work has been helpful to you, please feel free to cite us: 41 | 42 | ```latex 43 | @article{sun2024opentensor, 44 | title={OpenTensor: Reproducing Faster Matrix Multiplication Discovering Algorithms}, 45 | author={Sun, Yiwen and Li, Wenye}, 46 | journal={arXiv preprint arXiv:2405.20748}, 47 | year={2024} 48 | } 49 | ``` -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YiwenAI/OpenTensor/9727d20732ceed34eac108c76e23316df32c4981/__init__.py -------------------------------------------------------------------------------- /codes/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | from codes.dataset.dataset import * 2 | from codes.dataset.dataloader import * -------------------------------------------------------------------------------- /codes/dataset/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YiwenAI/OpenTensor/9727d20732ceed34eac108c76e23316df32c4981/codes/dataset/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /codes/dataset/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YiwenAI/OpenTensor/9727d20732ceed34eac108c76e23316df32c4981/codes/dataset/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /codes/dataset/__pycache__/dataloader.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YiwenAI/OpenTensor/9727d20732ceed34eac108c76e23316df32c4981/codes/dataset/__pycache__/dataloader.cpython-38.pyc -------------------------------------------------------------------------------- /codes/dataset/__pycache__/dataloader.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YiwenAI/OpenTensor/9727d20732ceed34eac108c76e23316df32c4981/codes/dataset/__pycache__/dataloader.cpython-39.pyc -------------------------------------------------------------------------------- /codes/dataset/__pycache__/dataset.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YiwenAI/OpenTensor/9727d20732ceed34eac108c76e23316df32c4981/codes/dataset/__pycache__/dataset.cpython-38.pyc -------------------------------------------------------------------------------- /codes/dataset/__pycache__/dataset.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YiwenAI/OpenTensor/9727d20732ceed34eac108c76e23316df32c4981/codes/dataset/__pycache__/dataset.cpython-39.pyc -------------------------------------------------------------------------------- /codes/dataset/dataloader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.data 3 | 4 | 5 | class MultiEpochsDataLoader(torch.utils.data.DataLoader): 6 | 7 | def __init__(self, *args, **kwargs): 8 | super().__init__(*args, **kwargs) 9 | self._DataLoader__initialized = False 10 | self.batch_sampler = 
_RepeatSampler(self.batch_sampler) 11 | self._DataLoader__initialized = True 12 | self.iterator = super().__iter__() 13 | 14 | def __len__(self): 15 | return len(self.batch_sampler.sampler) 16 | 17 | def __iter__(self): 18 | for i in range(len(self)): 19 | yield next(self.iterator) 20 | 21 | 22 | class _RepeatSampler(object): 23 | """ Sampler that repeats forever. 24 | Args: 25 | sampler (Sampler) 26 | """ 27 | 28 | def __init__(self, sampler): 29 | self.sampler = sampler 30 | 31 | def __iter__(self): 32 | while True: 33 | # for i in range(len(self.sampler)): 34 | yield from iter(self.sampler) -------------------------------------------------------------------------------- /codes/dataset/dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from torch.utils.data import Dataset 4 | from tqdm import tqdm 5 | 6 | import sys 7 | import os 8 | import copy 9 | import itertools 10 | sys.path.append(os.path.abspath(os.path.join(".."))) 11 | sys.path.append(os.path.abspath(os.path.join("."))) 12 | from codes.utils import * 13 | 14 | class TupleDataset(Dataset): 15 | 16 | def __init__(self, 17 | T, 18 | S_size, 19 | N_steps, 20 | coefficients, 21 | self_data=[], 22 | synthetic_data=[], 23 | debug=False, 24 | save_type="traj"): 25 | # examples: A list of episodes, including: 26 | # 1. state (Network input) 27 | # 1.1 tensors (np) 28 | # 1.2 scalars (np) 29 | # 2. action (np) 30 | # 3. reward 31 | 32 | self.T = T 33 | self.S_size = S_size 34 | self.N_steps = N_steps 35 | self.coefficients = coefficients 36 | token_len = 3 * S_size // N_steps 37 | self.N_logits = len(coefficients) ** token_len 38 | self.ct = 0 39 | self.save_type = save_type 40 | 41 | print("Preprocessing dataset...") 42 | self.self_data = self_data 43 | self.synthetic_data = synthetic_data 44 | self.data = self_data + synthetic_data 45 | self.data_iterer = itertools.cycle(self.data) 46 | self.self_examples = [] 47 | self.synthetic_examples = [] 48 | 49 | #TODO: Randomize sign permutation. 50 | #TODO: Reformulate data format. 51 | # Canonicalize actions & to logits. 52 | if save_type == "tuple": 53 | for episode in tqdm(synthetic_data): 54 | state, action, reward = episode 55 | action = self.action_to_logits(canonicalize_action(action)) 56 | self.synthetic_examples.append([state, action, reward]) 57 | for episode in tqdm(self_data): 58 | state, action, reward = episode 59 | action = self.action_to_logits(canonicalize_action(action)) 60 | self.self_examples.append([state, action, reward]) 61 | self.examples = self.self_examples + self.synthetic_examples 62 | 63 | else: # Traj format data. 64 | self._prepare_examples_from_trajs() 65 | 66 | def _prepare_examples_from_trajs(self): 67 | ''' 68 | This function permutes self.xxx_data (without modifying it in place) 69 | and builds the corresponding examples. 70 | ''' 71 | S_size = self.S_size 72 | T = self.T 73 | 74 | self_examples, synthetic_examples = [], [] 75 | 76 | for traj in tqdm(self.self_data): 77 | # Reverse the order: from decomposition order to synthesis order.
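# A deep copy is taken first so the cached trajectory in self.self_data is left untouched.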
78 | _traj = copy.deepcopy(traj) 79 | _states, _actions, _rewards = _traj 80 | _states.reverse(), _actions.reverse(), _rewards.reverse() 81 | _traj = [_states, _actions, _rewards] 82 | 83 | new_traj = self.permutate_traj(_traj) 84 | self_examples.extend(self.traj_to_episode(new_traj)) 85 | 86 | for traj in tqdm(self.synthetic_data): 87 | new_traj = self.permutate_traj(traj) 88 | synthetic_examples.extend(self.traj_to_episode(new_traj)) 89 | 90 | self.self_examples, self.synthetic_examples = self_examples, synthetic_examples 91 | self.examples = self_examples + synthetic_examples 92 | 93 | def _permutate_traj(self, trajs_n=5000): 94 | assert self.save_type == "traj" 95 | print("Permutate!") 96 | for _ in range(trajs_n): 97 | self_traj = next(self.data_iterer) 98 | new_traj = self.permutate_traj(self_traj) 99 | new_episodes = self.traj_to_episode(new_traj) 100 | n = len(new_episodes) 101 | self.examples = self.examples[n:] + new_episodes 102 | 103 | def __len__(self): 104 | return len(self.examples) 105 | 106 | def __getitem__(self, idx): 107 | state, action, reward = self.examples[idx] 108 | tensor, scalar = state 109 | action = self.logits_to_action(action) 110 | tensor, action = self.random_sign_permutation(tensor, action) # Data aug. 111 | action = canonicalize_action(action) #FIXME: Is it needed? 112 | action = self.action_to_logits(action) 113 | # self._permutate_traj() # Permutate traj. 114 | return [tensor, scalar], action, reward 115 | 116 | def traj_to_episode(self, traj): 117 | results = [] 118 | T, S_size = self.T, self.S_size 119 | states, actions, rewards = traj 120 | states = list(reversed(states)); actions = list(reversed(actions)); rewards = list(reversed(rewards)) 121 | actions_tensor = [action2tensor(action) for action in actions] 122 | for idx, state in enumerate(states): 123 | tensors = np.zeros((T, S_size, S_size, S_size), dtype=np.int32) 124 | tensors[0] = state # state. 125 | if idx != 0: 126 | # History actions. 127 | tensors[1:(idx+1)] = np.stack(reversed(actions_tensor[max(idx-(T-1), 0):idx]), axis=0) 128 | scalars = np.array([idx, idx, idx]) #FIXME: Havn't decided the scalars. 129 | 130 | cur_state = [tensors, scalars] 131 | action = self.action_to_logits(canonicalize_action(actions[idx])) 132 | reward = rewards[idx] 133 | results.append([cur_state, action, reward]) 134 | return results 135 | 136 | def permutate_traj(self, traj): 137 | S_size = self.S_size 138 | states, actions, rewards = traj # [T, S, S, S], [T, 3, S], [T], synthesis order. 139 | final_state = states[0] - action2tensor(actions[0]) # If synthesis data, final_state = 0, rewards[0] = -1. 140 | 141 | # Shuffle the traj. 142 | new_actions = actions.copy() 143 | np.random.shuffle(new_actions) 144 | new_states = [] 145 | new_rewards = copy.deepcopy(rewards) 146 | # sample = np.zeros((S_size, S_size, S_size), dtype=np.int32) 147 | sample = final_state 148 | for action in new_actions: 149 | sample = sample + action2tensor(action) 150 | new_states.append(sample.copy()) 151 | new_traj = [new_states, new_actions, new_rewards] # synthesis order. 152 | return new_traj 153 | 154 | def action_to_logits(self, 155 | action): 156 | ''' 157 | action: A [3, S_size] array. 158 | ''' 159 | 160 | # Break action into tokens. 161 | token_len = 3 * self.S_size // self.N_steps 162 | coefficients = self.coefficients 163 | action = action.reshape((-1, token_len)) # [N_steps, token_len] 164 | 165 | # Get logits. 166 | logits = [] # Start sign. 167 | for token in action: # Get one logit. 
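# Each token is read as a base-len(coefficients) number, most significant digit first, so every token maps to a unique logit in {0, ..., N_logits - 1}.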
168 | # token = token.to_list() 169 | logit = 0 170 | if torch.is_tensor(token): 171 | token = torch.flip(token, dims=(0,)) 172 | else: 173 | token = token[::-1] 174 | for idx, v in enumerate(token): 175 | logit += coefficients.index(v) * (len(coefficients) ** idx) 176 | logits.append(logit) 177 | 178 | return np.array(logits, dtype=np.int32) 179 | 180 | def logits_to_action(self, logits): 181 | ''' 182 | logit: N_steps values of {0, 1, ..., N_logits - 1}. 183 | e.g.: 184 | If: 185 | token_len = 2 186 | coefficients = [0, 1, -1] 187 | N_steps = 6 188 | Then: 189 | [0, 1, 2, 3, 4, 5] -> [0 0 | 0 1 | 0 -1 | 1 0 | 1 1 | 1 -1 ] 190 | ''' 191 | token_len = 3 * self.S_size // self.N_steps 192 | coefficients = self.coefficients 193 | action = [] 194 | for logit in logits: # Get one action 195 | token = [] 196 | if logit == self.N_logits: 197 | raise # Mean that there is a start sign in the middle of action. 198 | for _ in range(token_len): # Get one token 199 | idx = logit % len(coefficients) 200 | token.append(coefficients[idx]) 201 | logit = logit // len(coefficients) 202 | token.reverse() 203 | action.extend(token) 204 | 205 | action = np.array(action, dtype=np.int32).reshape((3, -1)) 206 | return action 207 | 208 | def random_sign_permutation(self, 209 | tensor, action): 210 | trans_1, trans_2, trans_3 = \ 211 | (np.random.binomial(1, .5, self.S_size) * 2 - 1).astype(np.int32), \ 212 | (np.random.binomial(1, .5, self.S_size) * 2 - 1).astype(np.int32), \ 213 | (np.random.binomial(1, .5, self.S_size) * 2 - 1).astype(np.int32) 214 | tensor = np.einsum('i, j, k, bijk -> bijk', trans_1, trans_2, trans_3, tensor, 215 | dtype=np.int32) 216 | action = np.stack([action[0]*trans_1, action[1]*trans_2, action[2]*trans_3], axis=0) 217 | return tensor, action 218 | 219 | 220 | if __name__ == '__main__': 221 | dataset = TupleDataset(T=7, 222 | S_size=4, 223 | N_steps=6, 224 | coefficients=[0, 1, -1], 225 | synthetic_data=np.load("data/traj_data/100000_S4T7_scalar3.npy", allow_pickle=True).tolist(), 226 | debug=True) 227 | # from torch.utils.data import DataLoader 228 | # dataloader = DataLoader(dataset, batch_size=64, shuffle=True) 229 | # res = next(iter(dataloader)) 230 | import pdb; pdb.set_trace() -------------------------------------------------------------------------------- /codes/env/__init__.py: -------------------------------------------------------------------------------- 1 | from codes.env.environment import * -------------------------------------------------------------------------------- /codes/env/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YiwenAI/OpenTensor/9727d20732ceed34eac108c76e23316df32c4981/codes/env/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /codes/env/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YiwenAI/OpenTensor/9727d20732ceed34eac108c76e23316df32c4981/codes/env/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /codes/env/__pycache__/environment.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YiwenAI/OpenTensor/9727d20732ceed34eac108c76e23316df32c4981/codes/env/__pycache__/environment.cpython-38.pyc 
-------------------------------------------------------------------------------- /codes/env/__pycache__/environment.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YiwenAI/OpenTensor/9727d20732ceed34eac108c76e23316df32c4981/codes/env/__pycache__/environment.cpython-39.pyc -------------------------------------------------------------------------------- /codes/env/env_wrapper.py: -------------------------------------------------------------------------------- 1 | # vector env -------------------------------------------------------------------------------- /codes/env/environment.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from math import sqrt 3 | 4 | import sys 5 | import os 6 | sys.path.append(os.path.abspath(os.path.join(".."))) 7 | sys.path.append(os.path.abspath(os.path.join("."))) 8 | from codes.utils import * 9 | 10 | # todo: consider rewriting this in the standard gym style: init / reset / step 11 | # todo: prefer fewer functions and more member variables 12 | 13 | class Environment(): 14 | ''' 15 | Defines the game's actions, states and rewards. 16 | state: np.ndarray, [4, 4, 4] 17 | action: np.ndarray, [3, 4] (representing u, v, w) 18 | results: [[s_1, a_1, pi_1, r_1], ...] 19 | 20 | Provided functionality: 21 | play: run one game 22 | generate synthetic tensors 23 | plus other operations related to states, actions and rewards 24 | ''' 25 | 26 | def __init__(self, 27 | S_size, 28 | R_limit, 29 | init_state=None, 30 | T=7, 31 | **kwargs): 32 | ''' 33 | S_size: the dimension of u, v, w 34 | R_limit: the upper limit on the number of game steps 35 | ''' 36 | # Parameters. 37 | self.S_size = S_size 38 | self.R_limit = R_limit 39 | self.T = T 40 | # Environment variables. 41 | if init_state is None: 42 | init_state = self.get_init_state(S_size) 43 | self.cur_state = init_state 44 | self.accumulate_reward = 0 45 | self.step_ct = 0 46 | 47 | # History variables. 48 | self.hist_actions = [np.zeros_like(self.cur_state) for _ in range(self.T-1)] 49 | 50 | 51 | def get_init_state(self, 52 | S_size, 53 | no_base_change=False): 54 | ''' 55 | Build an initial state. 56 | S_size: the dimension of u, v, w 57 | Returns: state. 58 | ''' 59 | ##### 60 | # Note: change-of-basis data augmentation can be added here. 61 | ##### 62 | 63 | def one_hot(idx): 64 | temp = np.zeros((S_size, ), dtype=np.int32) 65 | temp[idx] = 1 66 | return temp 67 | 68 | # 1. Get the original Matmul-Tensor. 69 | init_state = np.zeros((S_size, S_size, S_size), dtype=np.int32) 70 | n = round(sqrt(S_size)) 71 | 72 | for i in range(n): # Build the vectors in the natural basis. 73 | for j in range(n): 74 | z_idx = i * n + j 75 | z = one_hot(z_idx) # C_{i,j} = c_{i*n + j} 76 | for k in range(n): 77 | x_idx = i * n + k # A_{i,k} = a_{i*n + k} 78 | y_idx = k * n + j # B_{k,j} = b_{k*n + j} 79 | x, y = one_hot(x_idx), one_hot(y_idx) 80 | init_state += outer(x, y, z) 81 | 82 | # 2. Change of Basis. 83 | #FIXME: We haven't applied "basis change" operation. 84 | # Randomly get a transform matrix.
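# The transform is the product of a random upper-triangular matrix P and a random lower-triangular matrix L with entries in {0, 1, -1} (off-diagonal entries are zero with probability p0 = .985) and diagonal entries forced to +/-1, so trans_mat = P @ L is sparse and invertible.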
if not no_base_change: 86 | p0 = .985 87 | P, L = np.random.choice([0, 1, -1], size=(S_size, S_size), p=[p0, (1-p0)/2, (1-p0)/2], replace=True),\ 88 | np.random.choice([0, 1, -1], size=(S_size, S_size), p=[p0, (1-p0)/2, (1-p0)/2], replace=True) 89 | for i in range(S_size): 90 | P[i, i] = np.random.choice([1, -1], size=(1,), p=[.5, .5]) 91 | L[i, i] = np.random.choice([1, -1], size=(1,), p=[.5, .5]) 92 | P, L = np.triu(P), np.tril(L) 93 | trans_mat = np.matmul(P, L) 94 | 95 | init_state = change_basis_tensor(tensor=init_state, 96 | trans_mat=trans_mat) 97 | # import pdb; pdb.set_trace() 98 | 99 | return init_state 100 | 101 | def step(self, 102 | action): 103 | ''' 104 | Apply the state transition, update the reward, and return whether the game has ended. 105 | ''' 106 | u, v, w = action 107 | self.cur_state -= outer(u, v, w) 108 | self.accumulate_reward -= 1 109 | self.step_ct += 1 110 | self.hist_actions.append(action2tensor(action)) 111 | # Check for termination. 112 | if self.is_terminate(): 113 | return True 114 | if self.step_ct >= self.R_limit: 115 | self.accumulate_reward += self.terminate_reward() 116 | return True 117 | return False 118 | 119 | def terminate_reward(self): 120 | ''' 121 | The penalty applied when the game is truncated. 122 | Returns: reward 123 | ''' 124 | state = self.cur_state 125 | # terminate_reward = 0 126 | # for z_idx in range(self.S_size): 127 | # terminate_reward -= np.linalg.matrix_rank(np.mat(state[..., z_idx], dtype=np.int32)) 128 | terminate_reward = -terminate_rank_approx(state) 129 | return terminate_reward 130 | 131 | def is_terminate(self): 132 | ''' 133 | Check whether cur_state is zero. 134 | Returns: bool 135 | ''' 136 | return is_zero_tensor(self.cur_state) 137 | 138 | def reset(self, 139 | init_state=None, 140 | no_base_change=False): 141 | ''' 142 | Reset the environment. 143 | ''' 144 | if init_state is None: 145 | init_state = self.get_init_state(self.S_size, no_base_change) 146 | self.cur_state = init_state 147 | self.accumulate_reward = 0 148 | self.step_ct = 0 149 | self.hist_actions = [np.zeros_like(self.cur_state) for _ in range(self.T-1)] 150 | 151 | def get_network_input(self): 152 | ''' 153 | Organize the variables into the network input format. 154 | ''' 155 | T = self.T 156 | S_size = self.S_size 157 | hist_actions = self.hist_actions[-(T-1):] 158 | hist_actions.reverse() 159 | tensors = np.zeros((T, S_size, S_size, S_size), dtype=np.int32) 160 | tensors[0] = self.cur_state 161 | tensors[1:] = np.stack(hist_actions, axis=0) 162 | scalars = np.array([self.step_ct, self.step_ct, self.step_ct]) #FIXME: Haven't decided the scalars.
163 | 164 | return tensors, scalars 165 | 166 | 167 | if __name__ == '__main__': 168 | test_env = Environment(S_size=4, 169 | R_limit=8) 170 | test_action = np.array([ 171 | [0, 0, 1, 0], 172 | [1, 1, 0, 0], 173 | [0, 1, 0, 0] 174 | ]) 175 | for _ in range(8): 176 | print(test_env.step(test_action)) 177 | print(test_env.accumulate_reward) 178 | import pdb; pdb.set_trace() -------------------------------------------------------------------------------- /codes/mcts/MCTS.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import math 4 | import networkx as nx 5 | import matplotlib.pyplot as plt 6 | from tqdm import tqdm 7 | 8 | import sys 9 | import os 10 | import copy 11 | from typing import Tuple 12 | sys.path.append(os.path.abspath(os.path.join(".."))) 13 | sys.path.append(os.path.abspath(os.path.join("."))) 14 | from codes.env import Environment 15 | from codes.net import Net 16 | from codes.utils import * 17 | 18 | 19 | class Node(): 20 | ''' 21 | A node in the MCTS tree. 22 | ''' 23 | 24 | def __init__(self, 25 | state, 26 | parent, 27 | pre_action, 28 | pre_action_idx, 29 | is_terminal=False): 30 | ##### 31 | # Initialize a node here, 32 | # including Q, N, and its children and parent. 33 | # e.g.: 34 | # self.Q = 0 35 | # ...... 36 | ##### 37 | 38 | self.parent = parent # parent: A Node instance (or None). 39 | self.pre_action = pre_action # pre_action: Action (or None). 40 | self.pre_action_idx = pre_action_idx 41 | self.is_leaf = True 42 | self.is_terminal = is_terminal 43 | self.state = state # state: Tensor. 44 | 45 | self.actions = [] # A list for actions. 46 | self.children = [] # A list for nodes. 47 | self.N = [] # A list for visit counts. 48 | self.Q = [] # A list for action value. 49 | self.pi = [] # A list for empirical policy probability. 50 | self.children_n = 0 51 | 52 | node = self 53 | depth = 0 54 | while node.parent: 55 | depth += 1 56 | node = node.parent 57 | self.depth = depth 58 | 59 | 60 | def expand(self, 61 | net: Net, 62 | noise=False, 63 | network_output=None, 64 | R_limit=12): 65 | ''' 66 | Expand this node. 67 | Return the value of this state. 68 | ''' 69 | # 1. Check terminal situation. 70 | if not self.is_leaf: 71 | raise Exception("This node has been expanded.") 72 | self.is_leaf = False 73 | 74 | if self.is_terminal: # Means the state is terminal. Only propagate. 75 | node = self 76 | node.is_leaf = True 77 | if is_zero_tensor(node.state): 78 | value = 0 79 | else: 80 | value = -1 * terminate_rank_approx(node.state) 81 | while node.parent is not None: 82 | action_idx = node.pre_action_idx 83 | node = node.parent 84 | node.N[action_idx] += 1 85 | v = (value + -1 * (self.depth - node.depth)) 86 | node.Q[action_idx] = v / node.N[action_idx] +\ 87 | node.Q[action_idx] * (node.N[action_idx] - 1) / node.N[action_idx] 88 | 89 | return 90 | 91 | #FIXME: Here we can apply a transposition table. 92 | # 2. Get network output. 93 | # 2.1. If using the network to infer: 94 | if network_output is None: 95 | tensors, scalars = self.get_network_input(net) 96 | 97 | net.set_mode("infer") 98 | with torch.no_grad(): 99 | output = net([tensors, scalars]) 100 | _, value, policy, prob = *net.value(output), *net.policy(output) # policy: [1, N_samples, 3, S_size] 101 | del output, tensors, scalars 102 | value, policy = value[0], policy[0] 103 | policy = [canonicalize_action(action) for action in policy] 104 | 105 | # 2.2. If we already have network output: 106 | else: 107 | value, policy = network_output 108 | 109 | # 2.3. Add noise for root node's expand.
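# Mixing in random actions (one per four policy samples) plays the role of root exploration noise, letting the search try moves the policy head assigned negligible probability.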
110 | if noise: 111 | noise_actions = [canonicalize_action(random_action()) for _ in range(len(policy) // 4)] 112 | policy = policy + noise_actions 113 | 114 | # 3. Get empirical policy probability. 115 | N_samples = net.N_samples 116 | rec = [False for _ in range(N_samples)] # "True" represents having been recorded. 117 | actions = [] 118 | pi = [] 119 | for pos in range(N_samples): # Naive loop. 120 | action = policy[pos] 121 | if not rec[pos]: 122 | # Count. 123 | actions.append(action) 124 | rec[pos] = True 125 | ct = 1 126 | for i in range(pos+1, N_samples): 127 | if rec[i]: # Have been counted. 128 | continue 129 | if is_equal(policy[i], action): 130 | ct += 1 131 | rec[i] = True 132 | pi.append(ct / N_samples) 133 | # import pdb; pdb.set_trace() 134 | self.actions = actions 135 | self.pi = pi 136 | self.children_n = len(actions) 137 | 138 | # 4. Init records. 139 | self.N = [0 for _ in range(len(actions))] 140 | self.Q = [0 for _ in range(len(actions))] 141 | 142 | # 5. Expand the children nodes. 143 | for idx, action in enumerate(actions): 144 | child_state = self.state - action2tensor(action) 145 | child_depth = self.depth + 1 146 | child_node = Node(state=child_state, 147 | parent=self, 148 | pre_action=action, 149 | pre_action_idx=idx, 150 | is_terminal=(is_zero_tensor(child_state) or child_depth >= R_limit)) 151 | self.children.append(child_node) 152 | 153 | # 6. Backward propagate. 154 | node = self 155 | while node.parent is not None: 156 | action_idx = node.pre_action_idx 157 | node = node.parent 158 | node.N[action_idx] += 1 159 | v = (value + -1 * (self.depth - node.depth)) 160 | node.Q[action_idx] = v / node.N[action_idx] +\ 161 | node.Q[action_idx] * (node.N[action_idx] - 1) / node.N[action_idx] 162 | 163 | 164 | def select(self, c=None): 165 | ''' 166 | Choose the best child. 167 | Return the chosen node. 168 | ''' 169 | if self.is_leaf: 170 | raise Exception("Cannot choose a leaf node.") 171 | 172 | if c is None: 173 | c = 1.25 + math.log((1+19652+sum(self.N)) / 19652) 174 | 175 | scores = [self.Q[i] + c * self.pi[i] * math.sqrt(sum(self.N)) / (1 + self.N[i]) 176 | for i in range(self.children_n)] 177 | 178 | return self.children[np.argmax(scores)], scores 179 | 180 | 181 | def get_network_input(self, net): 182 | # Get state for net evaluation. 183 | # State: 184 | # Tensors: [cur_state, last t=1 action, last t=2 action, ... last t=T-1 action] 185 | # Scalars: [depth(step_ct)] 186 | T = net.T 187 | tensors = np.zeros([T, *self.state.shape], dtype=np.int32) 188 | tensors[0] = self.state # Current state. 189 | node = self 190 | for t in range(1, T): 191 | if node.parent is None: 192 | break 193 | tensors[t] = action2tensor(node.pre_action) 194 | node = node.parent 195 | scalars = np.array([self.depth, self.depth, self.depth]) #FIXME: Havn't decided the scalars. 
196 | 197 | return tensors, scalars 198 | 199 | 200 | 201 | class MCTS(): 202 | ''' 203 | Monte Carlo tree search. 204 | ''' 205 | 206 | def __init__(self, 207 | init_state, 208 | simulate_times=400, 209 | R_limit=12, 210 | **kwargs): 211 | ''' 212 | Pass in the hyperparameters. 213 | ''' 214 | 215 | self.simulate_times = simulate_times 216 | self.R_limit = R_limit 217 | if init_state is not None: 218 | self.root_node = Node(state=init_state, 219 | parent=None, 220 | pre_action=None, 221 | pre_action_idx=None) 222 | 223 | 224 | def __call__(self, 225 | state, 226 | net: Net, 227 | log=False, 228 | verbose=False, 229 | noise=False): 230 | ''' 231 | Run one MCTS search. 232 | Returns: action, actions, visit_pi 233 | ''' 234 | 235 | assert is_equal(state, self.root_node.state), "State is inconsistent." 236 | iter_item = range(self.simulate_times) if verbose else tqdm(range(self.simulate_times)) 237 | R_limit = self.R_limit 238 | for simu in iter_item: 239 | # Select a leaf node. 240 | node = self.root_node 241 | while not node.is_leaf: 242 | node, scores = node.select() #FIXME: Need to control the factor c. 243 | node.expand(net, noise=noise, R_limit=R_limit) 244 | 245 | actions = self.root_node.actions 246 | N = self.root_node.N 247 | visit_ratio = (np.array(N) / sum(N)).tolist() 248 | action = actions[np.argmax(visit_ratio)] 249 | 250 | if log: 251 | log_txt = self.log() 252 | return action, actions, visit_ratio, log_txt 253 | 254 | return action, actions, visit_ratio 255 | 256 | 257 | def move(self, 258 | action): 259 | ''' 260 | Advance the MCTS tree by one step. 261 | ''' 262 | assert not self.root_node.is_leaf, "Cannot move a leaf node." 263 | 264 | # Get the action idx. 265 | action_idx = None 266 | for idx, child_action in enumerate(self.root_node.actions): 267 | if is_equal(child_action, action): 268 | action_idx = idx 269 | 270 | # Delete other children and move. 271 | self.root_node.children = [self.root_node.children[action_idx]] 272 | self.root_node = self.root_node.children[0] 273 | 274 | 275 | def reset(self, 276 | state, 277 | simulate_times=None, 278 | R_limit=None): 279 | ''' 280 | Reset MCTS. 281 | ''' 282 | if simulate_times is not None: 283 | self.simulate_times = simulate_times 284 | if R_limit is not None: 285 | self.R_limit = R_limit 286 | self.root_node = Node(state=state, 287 | parent=None, 288 | pre_action=None, 289 | pre_action_idx=None) 290 | 291 | 292 | def visualize(self): 293 | ''' 294 | Visualize the tree. 295 | ''' 296 | # Create a graph. 297 | graph = nx.DiGraph() 298 | close_set = [self.root_node] 299 | 300 | while close_set != []: 301 | node = close_set.pop() 302 | if not node.is_leaf: 303 | [graph.add_edge( 304 | node, 305 | child 306 | ) for child in node.children] 307 | [close_set.append(child) for child in node.children] 308 | 309 | nx.draw(graph, with_labels=True, font_weight='bold') 310 | raise NotImplementedError 311 | plt.show() 312 | 313 | 314 | def log(self): 315 | ''' 316 | Get the log text. 317 | ''' 318 | node = self.root_node 319 | state_txt = str(node.state) # state.
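# Report the five most-visited children of the root, together with their Q values and selection scores.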
320 | _, scores = node.select() 321 | N, Q, scores, children = np.array(node.N), np.array(node.Q), np.array(scores), np.array(node.actions) 322 | top_k_idx = np.argsort(N)[-5:] 323 | N, Q, scores, children = N[top_k_idx], Q[top_k_idx], scores[top_k_idx], children[top_k_idx] 324 | 325 | N_txt, Q_txt, scores_txt, children_txt = str(N), str(Q), str(scores), str(children) 326 | 327 | log_txt = "\n".join( 328 | ["\nCur state: \n", state_txt, 329 | "\nDepth: \n", str(node.depth), 330 | "\nchildren: \n", children_txt, 331 | "\nscores: \n", scores_txt, 332 | "\nQ: \n", Q_txt, 333 | "\nN: \n", N_txt,] 334 | ) 335 | 336 | return log_txt 337 | 338 | 339 | if __name__ == '__main__': 340 | 341 | init_state = np.array( 342 | [[[1, 0, 0, 0], 343 | [0, 1, 0, 0], 344 | [0, 0, 0, 0], 345 | [0, 0, 0, 0]], 346 | 347 | [[0, 0, 0, 0], 348 | [0, 0, 0, 0], 349 | [1, 0, 0, 0], 350 | [0, 1, 0, 0]], 351 | 352 | [[0, 0, 1, 0], 353 | [0, 0, 0, 1], 354 | [0, 0, 0, 0], 355 | [0, 0, 0, 0]], 356 | 357 | [[0, 0, 0, 0], 358 | [0, 0, 0, 0], 359 | [0, 0, 1, 0], 360 | [0, 0, 0, 1]]]) 361 | root_node = Node(state=init_state, 362 | parent=None, 363 | pre_action=None, 364 | pre_action_idx=None) 365 | 366 | from net import Net 367 | net = Net(N_samples=4) # For debugging. 368 | 369 | ############ Debug for Node ############ 370 | # import pdb; pdb.set_trace() 371 | # root_node.expand(net) 372 | # children_node = root_node.select() 373 | # children_node.expand(net) 374 | # import pdb; pdb.set_trace() 375 | 376 | ############ Debug for MCYS ############ 377 | mcts = MCTS(init_state=init_state, 378 | simulate_times=20) 379 | import pdb; pdb.set_trace() 380 | action, actions, pi = mcts(init_state, net) 381 | import pdb; pdb.set_trace() 382 | mcts.move(action) 383 | state = init_state - action2tensor(action) 384 | import pdb; pdb.set_trace() 385 | action, actions, pi = mcts(state, net) 386 | import pdb; pdb.set_trace() -------------------------------------------------------------------------------- /codes/mcts/__init__.py: -------------------------------------------------------------------------------- 1 | from codes.mcts.MCTS import * -------------------------------------------------------------------------------- /codes/mcts/__pycache__/MCTS.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YiwenAI/OpenTensor/9727d20732ceed34eac108c76e23316df32c4981/codes/mcts/__pycache__/MCTS.cpython-38.pyc -------------------------------------------------------------------------------- /codes/mcts/__pycache__/MCTS.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YiwenAI/OpenTensor/9727d20732ceed34eac108c76e23316df32c4981/codes/mcts/__pycache__/MCTS.cpython-39.pyc -------------------------------------------------------------------------------- /codes/mcts/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YiwenAI/OpenTensor/9727d20732ceed34eac108c76e23316df32c4981/codes/mcts/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /codes/mcts/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YiwenAI/OpenTensor/9727d20732ceed34eac108c76e23316df32c4981/codes/mcts/__pycache__/__init__.cpython-39.pyc 
-------------------------------------------------------------------------------- /codes/multi_runner/MCTSF.py: -------------------------------------------------------------------------------- 1 | # MCTSF: Monte Carlo Tree Search Forest. 2 | from codes.mcts import * 3 | from typing import List 4 | 5 | class MCTSF(): 6 | ''' 7 | Simulation over multiple Monte Carlo trees. 8 | ''' 9 | 10 | def __init__(self, 11 | mcts_list: List[MCTS], 12 | simulate_times=400): 13 | 14 | self.mcts_list = mcts_list 15 | self.mcts_n = len(mcts_list) 16 | self.simulate_times = simulate_times 17 | self.R_limit = mcts_list[0].R_limit 18 | 19 | 20 | def __call__(self, 21 | state_list, 22 | net: Net, 23 | noise=False): 24 | 25 | R_limit = self.R_limit 26 | 27 | # Check the states. 28 | for mcts, state in zip(self.mcts_list, state_list): 29 | assert is_equal(state, mcts.root_node.state), "State is inconsistent." 30 | 31 | # Simulate. 32 | for simu in tqdm(range(self.simulate_times)): 33 | # Select the leaf nodes. 34 | node_list = [] 35 | for mcts in self.mcts_list: 36 | node = mcts.root_node 37 | while not node.is_leaf: 38 | node, scores = node.select() 39 | node_list.append(node) 40 | 41 | # Get network input to expand... 42 | node: Node 43 | batch_tensors, batch_scalars = [], [] 44 | for node in node_list: 45 | tensors, scalars = node.get_network_input(net) 46 | batch_tensors.append(tensors); batch_scalars.append(scalars) 47 | batch_tensors, batch_scalars = np.array(batch_tensors), np.array(batch_scalars) 48 | # Infer... 49 | net.set_mode("infer") 50 | with torch.no_grad(): 51 | batch_output = net([batch_tensors, batch_scalars]) 52 | _, batch_value, batch_policy, prob = *net.value(batch_output), *net.policy(batch_output) 53 | # batch_value: [B,] batch_policy: [B, N_samples, 3, S_size] 54 | del batch_output, batch_tensors, batch_scalars 55 | batch_policy = [[canonicalize_action(action) for action in policy] for policy in batch_policy] 56 | 57 | # Expand... 58 | for node, value, policy in zip(node_list, batch_value, batch_policy): 59 | node.expand(net, network_output=(value, policy), R_limit=R_limit, noise=noise) 60 | 61 | # Get results.
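# The move played in each tree is the child with the highest visit ratio N_i / sum(N), mirroring the single-tree MCTS.__call__.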
62 | actions_list = [mcts.root_node.actions for mcts in self.mcts_list] 63 | N_list = [mcts.root_node.N for mcts in self.mcts_list] 64 | visit_ratio_list = [(np.array(N) / sum(N)).tolist() for N in N_list] 65 | action_list = [actions_list[idx][np.argmax(visit_ratio_list[idx])] for idx in range(self.mcts_n)] 66 | 67 | return action_list, actions_list, visit_ratio_list 68 | 69 | 70 | def reset(self, 71 | state_list): 72 | 73 | for mcts, state in zip(self.mcts_list, state_list): 74 | mcts.reset(state, simulate_times=self.simulate_times) 75 | 76 | 77 | def move(self, 78 | action_list): 79 | 80 | for mcts, action in zip(self.mcts_list, action_list): 81 | mcts.move(action) -------------------------------------------------------------------------------- /codes/multi_runner/__init__.py: -------------------------------------------------------------------------------- 1 | from codes.multi_runner.envs import * 2 | from codes.multi_runner.MCTSF import * -------------------------------------------------------------------------------- /codes/multi_runner/__pycache__/MCTSF.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YiwenAI/OpenTensor/9727d20732ceed34eac108c76e23316df32c4981/codes/multi_runner/__pycache__/MCTSF.cpython-39.pyc -------------------------------------------------------------------------------- /codes/multi_runner/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YiwenAI/OpenTensor/9727d20732ceed34eac108c76e23316df32c4981/codes/multi_runner/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /codes/multi_runner/__pycache__/envs.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YiwenAI/OpenTensor/9727d20732ceed34eac108c76e23316df32c4981/codes/multi_runner/__pycache__/envs.cpython-39.pyc -------------------------------------------------------------------------------- /codes/multi_runner/envs.py: -------------------------------------------------------------------------------- 1 | from codes.env import * 2 | from typing import List 3 | 4 | class ENVS(): 5 | 6 | def __init__(self, 7 | env_list: List[Environment]): 8 | 9 | self.env_list = env_list 10 | self.env_n = len(env_list) 11 | self.terminate_list = [False] * self.env_n 12 | 13 | 14 | def reset(self): 15 | 16 | for env in self.env_list: 17 | env.reset() 18 | self.terminate_list = [False] * self.env_n 19 | 20 | 21 | def step(self, 22 | action_list): 23 | 24 | for idx, (env, action) in enumerate(zip(self.env_list, action_list)): 25 | if not self.terminate_list[idx]: 26 | self.terminate_list[idx] = env.step(action) 27 | 28 | 29 | def get_curstates(self): 30 | 31 | state_list = [] 32 | for env in self.env_list: 33 | state_list.append(env.cur_state.copy()) 34 | 35 | return state_list 36 | 37 | 38 | def is_all_terminated(self): 39 | 40 | flag = True 41 | for sub_flag in self.terminate_list: 42 | flag = flag & sub_flag 43 | 44 | return flag 45 | 46 | 47 | def get_rewards(self): 48 | 49 | reward_list = [] 50 | for env in self.env_list: 51 | reward_list.append(env.accumulate_reward) 52 | 53 | return reward_list 54 | 55 | 56 | def get_stepcts(self): 57 | 58 | step_ct_list = [] 59 | for env in self.env_list: 60 | step_ct_list.append(env.step_ct) 61 | 62 | return step_ct_list 
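How `ENVS` and `MCTSF` compose is not shown in this file; the trainer presumably drives them together. A minimal self-play sketch, assuming only the constructors and methods defined above and a GPU-resident `Net` (the loop structure is illustrative, and the masking of already-finished trees is elided for brevity):

```python
from codes.env import Environment
from codes.mcts import MCTS
from codes.net import Net
from codes.multi_runner import ENVS, MCTSF

n_envs, S_size, R_limit = 4, 4, 12
envs = ENVS([Environment(S_size=S_size, R_limit=R_limit) for _ in range(n_envs)])
net = Net(S_size=S_size).cuda()

# One search tree per environment, each rooted at that environment's current state.
states = envs.get_curstates()
mctsf = MCTSF([MCTS(init_state=s, R_limit=R_limit) for s in states],
              simulate_times=400)

while not envs.is_all_terminated():
    # Run batched simulations over all trees, then act.
    actions, _, _ = mctsf(states, net, noise=True)
    envs.step(actions)     # advance every live environment
    mctsf.move(actions)    # re-root every tree at the chosen child
    states = envs.get_curstates()

print(envs.get_rewards())
```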
-------------------------------------------------------------------------------- /codes/net/__init__.py: -------------------------------------------------------------------------------- 1 | from codes.net.network import * -------------------------------------------------------------------------------- /codes/net/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YiwenAI/OpenTensor/9727d20732ceed34eac108c76e23316df32c4981/codes/net/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /codes/net/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YiwenAI/OpenTensor/9727d20732ceed34eac108c76e23316df32c4981/codes/net/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /codes/net/__pycache__/network.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YiwenAI/OpenTensor/9727d20732ceed34eac108c76e23316df32c4981/codes/net/__pycache__/network.cpython-38.pyc -------------------------------------------------------------------------------- /codes/net/__pycache__/network.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YiwenAI/OpenTensor/9727d20732ceed34eac108c76e23316df32c4981/codes/net/__pycache__/network.cpython-39.pyc -------------------------------------------------------------------------------- /codes/net/network.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import math 6 | 7 | import sys 8 | import os 9 | sys.path.append(os.path.abspath(os.path.join(".."))) 10 | sys.path.append(os.path.abspath(os.path.join("."))) 11 | from codes.env import * 12 | from codes.mcts import * 13 | from codes.utils import * 14 | 15 | # Input: 16 | # Tensors of shape of [B,T,S,S,S]. First one is current tensor. 17 | # Scalars of shape of [B,s]. 18 | 19 | 20 | class Torso(nn.Module): 21 | ''' 22 | The network torso (encoder). 23 | ''' 24 | 25 | def __init__(self, 26 | S_size=4, 27 | channel=3, 28 | T=7, 29 | scalar_size=3, 30 | n_attentive=8, 31 | mode="train", 32 | device='cuda', 33 | **kwargs): 34 | super(Torso, self).__init__() 35 | self.S_size = S_size 36 | self.channel = channel 37 | self.scalar_size = scalar_size 38 | self.T = T 39 | self.n_attentive = n_attentive 40 | self.mode = mode 41 | self.device = device 42 | 43 | self.attentive_modes = nn.ModuleList([AttentiveModes(S_size, channel) for _ in range(n_attentive)]) 44 | self.scalar2grid = nn.ModuleList([nn.Linear(scalar_size, S_size**2) for _ in range(3)]) # s -> S*S 45 | self.grid2grid = nn.ModuleList([nn.Linear(S_size**2*(T*S_size+1), S_size**2*channel) for _ in range(3)]) # S*S*TS+1 -> S*S*c 46 | #WARNING: grid2grid should be T*S_size+1 -> channel! 47 | 48 | 49 | def forward(self, x): 50 | 51 | # Input: 52 | # Tensors of shape of [B,T,S,S,S]. First one is current tensor. (numpy) 53 | # Scalars of shape of [B,s]. (numpy) 54 | 55 | S_size = self.S_size 56 | T = self.T 57 | channel = self.channel 58 | n_attentive = self.n_attentive 59 | 60 | input_t, input_s = x # Tensor input and Scalar input.
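# Inputs may arrive as numpy arrays (from the environment/MCTS) or as tensors (from the dataloader); both are cast to float and moved to the model device below.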
if not torch.is_tensor(input_t): 62 | input_t, input_s = torch.from_numpy(input_t).float().to(self.device), torch.from_numpy(input_s).float().to(self.device) 63 | else: 64 | input_t, input_s = input_t.float().to(self.device), input_s.float().to(self.device) 65 | if self.mode == "infer": 66 | # assert len(input_t.shape) == 4 and len(input_s.shape) == 1, \ 67 | # "Infer mode does not support batch." 68 | if len(input_t.shape) == 4 and len(input_s.shape) == 1: 69 | input_t = input_t[None]; input_s = input_s[None] # Add a batch dim. 70 | batch_size = input_t.shape[0] 71 | 72 | # 1. Project to grids. 73 | x1 = torch.reshape(torch.permute(input_t, (0,2,3,4,1)), (batch_size, S_size, S_size, T*S_size)) # [B,S,S,TS] 74 | x2 = torch.reshape(torch.permute(input_t, (0,4,2,3,1)), (batch_size, S_size, S_size, T*S_size)) 75 | x3 = torch.reshape(torch.permute(input_t, (0,3,4,2,1)), (batch_size, S_size, S_size, T*S_size)) 76 | g = [x1, x2, x3] 77 | 78 | # 2. To grids. 79 | for idx in range(3): 80 | p = torch.reshape(self.scalar2grid[idx](input_s), (batch_size, S_size, S_size, 1)) 81 | g[idx] = torch.concat([g[idx], p], dim=-1) 82 | g[idx] = torch.reshape(self.grid2grid[idx](torch.reshape(g[idx], (batch_size, -1))), (batch_size, S_size, S_size, channel)) # [B,S,S,c] 83 | 84 | # 3. Attentive modes. 85 | x1, x2, x3 = g 86 | for idx in range(n_attentive): 87 | x1, x2, x3 = self.attentive_modes[idx]([x1, x2, x3]) 88 | 89 | # 4. Final stack. 90 | e = torch.reshape(torch.stack([x1, x2, x3], axis=1), (batch_size, 3*S_size**2, channel)) # [B, 3*S**2, c] 91 | 92 | return e 93 | 94 | 95 | def set_mode(self, mode): 96 | assert mode in ["train", "infer"] 97 | self.mode = mode 98 | 99 | 100 | class AttentiveModes(nn.Module): 101 | 102 | ''' 103 | Question: 104 | should the Attention modules share parameters across the forward pass? 105 | ''' 106 | 107 | def __init__(self, 108 | S_size=4, 109 | channel=3, 110 | device='cuda'): 111 | super(AttentiveModes, self).__init__() 112 | self.channel = channel 113 | self.S_size = S_size 114 | 115 | # self.attentions = nn.ModuleList(*[Attention(channel, 116 | # channel, 117 | # 2*S_size, 118 | # 2*S_size, 119 | # False, 120 | # device=device) for _ in range(S_size)]) 121 | 122 | self.attention = Attention(channel, 123 | channel, 124 | 2*S_size, 125 | 2*S_size, 126 | False, device=device) 127 | 128 | def forward(self, x): 129 | 130 | # Input: 131 | # [x1, x2, x3]. Each of them is shaped of [B, S, S, c] 132 | 133 | S_size = self.S_size 134 | channel = self.channel 135 | 136 | for m1, m2 in [(0,1), (2,0), (1,2)]: 137 | a = torch.cat([x[m1], x[m2].transpose(1,2)], axis=2) # a: [B, S, 2S, c] 138 | # for idx in range(S_size): #FIXME: Parallel loop?
139 | # c = self.attentions[idx]([a[:,idx],]) 140 | # x[m1][:,idx] = c[:, :S_size, :] 141 | # x[m2][:,idx] = c[:, S_size:, :] 142 | 143 | a = a.reshape((-1, 2*S_size, channel)) 144 | c = self.attention([a, ]) # c: [B*S, 2S, c] 145 | c = c.reshape((-1, S_size, 2*S_size, channel)) # c: [B, S, 2S, c] 146 | x[m1] = c[:, :, :S_size, :] 147 | x[m2] = c[:, :, S_size:, :] 148 | 149 | return x 150 | 151 | 152 | class Attention(nn.Module): 153 | 154 | def __init__(self, 155 | x_channel=3, 156 | y_channel=3, 157 | N_x=8, # 2S 158 | N_y=8, # 2S 159 | causal_mask=False, 160 | N_heads=16, 161 | d=32, 162 | w=4, 163 | device='cuda'): 164 | 165 | super(Attention, self).__init__() 166 | self.x_channel, self.y_channel = x_channel, y_channel 167 | self.N_x, self.N_y = N_x, N_y 168 | self.causal_mask = causal_mask 169 | self.N_heads = N_heads 170 | self.d, self.w = d, w 171 | self.device = device 172 | 173 | self.x_layer_norm = nn.LayerNorm(x_channel) 174 | self.y_layer_norm = nn.LayerNorm(y_channel) 175 | self.final_layer_norm = nn.LayerNorm(x_channel) 176 | self.W_Q = nn.Linear(x_channel, d * N_heads, bias=False) 177 | self.W_K = nn.Linear(y_channel, d * N_heads, bias=False) 178 | self.W_V = nn.Linear(y_channel, d * N_heads, bias=False) 179 | self.linear_1 = nn.Linear(d * N_heads, x_channel) 180 | self.linear_2 = nn.Linear(x_channel, x_channel * w) 181 | self.linear_3 = nn.Linear(x_channel * w, x_channel) 182 | self.gelu = nn.GELU() 183 | 184 | 185 | def forward(self, x): 186 | 187 | # Input: 188 | # [x, (y)]. If y is missed, y=x. 189 | 190 | N_heads = self.N_heads 191 | N_x, N_y = self.N_x, self.N_y 192 | d, w = self.d, self.w 193 | 194 | if len(x) == 1: 195 | x = x[0] 196 | y = x.clone() 197 | else: 198 | x, y = x 199 | 200 | batch_size = x.shape[0] 201 | 202 | x_norm = self.x_layer_norm(x) # [B, N_x, c_x] 203 | y_norm = self.y_layer_norm(y) # [B, N_y, c_y] 204 | 205 | q_s = self.W_Q(x_norm).view(batch_size, -1, N_heads, d).transpose(1, 2) # [batch_size, N_heads, N_x, d] 206 | k_s = self.W_K(y_norm).view(batch_size, -1, N_heads, d).transpose(1, 2) # [batch_size, N_heads, N_y, d] 207 | v_s = self.W_V(y_norm).view(batch_size, -1, N_heads, d).transpose(1, 2) # [batch_size, N_heads, N_y, d] 208 | 209 | scores = torch.matmul(q_s, k_s.transpose(-1, -2)) / np.sqrt(d) # [batch_size, N_heads, N_x, N_y] 210 | # scores = F.softmax(scores, dim=-1) 211 | if self.causal_mask: 212 | mask = (torch.from_numpy(np.triu(np.ones([batch_size, N_heads, N_x, N_y]), k=1)) == 0).to(self.device) 213 | scores = scores.masked_fill(mask == 0, -1e9) 214 | # scores = scores * torch.from_numpy(np.triu(np.ones([batch_size, N_heads, N_x, N_y]), k=1)).float().to(self.device) 215 | scores = F.softmax(scores, dim=-1) 216 | 217 | o_s = torch.matmul(scores, v_s) # [batch_size, N_heads, N_x, d] 218 | o_s = o_s.transpose(1, 2).contiguous().view(batch_size, -1, N_heads*d) # [batch_size, N_x, N_heads*d] 219 | x = x.reshape(batch_size*N_x, -1) + self.linear_1(o_s.reshape(-1, N_heads*d)) # [batch_size*N_x, c_x] 220 | 221 | x = (x + self.linear_3(self.gelu(self.linear_2(self.final_layer_norm(x))))).reshape(batch_size, N_x, -1) # [batch_size, N_x, c_x] 222 | 223 | return x 224 | 225 | 226 | class PolicyHead(nn.Module): 227 | 228 | def __init__(self, 229 | N_steps=6, 230 | N_logits=9, 231 | N_features=64, 232 | N_heads=32, 233 | N_layers=2, 234 | N_samples=32, 235 | torso_feature_shape=(3*4**2, 3), 236 | mode='train', 237 | device='cuda'): 238 | 239 | super(PolicyHead, self).__init__() 240 | self.N_steps = N_steps 241 | self.N_logits = N_logits 242 | 
self.N_features = N_features 243 | self.N_heads = N_heads 244 | self.N_layers = N_layers 245 | self.torso_feature_shape = torso_feature_shape 246 | self.N_samples = N_samples 247 | self.mode = mode 248 | self.device = device 249 | 250 | self.linear_1 = nn.Linear(N_logits+1, N_features * N_heads, bias=False) 251 | # self.pos_embed = nn.Linear(1, N_features * N_heads) 252 | self.pos_embed = nn.Sequential( 253 | nn.Linear(1, 512), 254 | nn.ReLU(), 255 | nn.Linear(512, N_features * N_heads) 256 | ) 257 | self.self_layer_norms = nn.ModuleList([nn.LayerNorm(N_features * N_heads) for _ in range(N_layers)]) 258 | self.cross_layer_norms = nn.ModuleList([nn.LayerNorm(N_features * N_heads) for _ in range(N_layers)]) #FIXME: How to choose layer norm's channel? 259 | self.self_attentions = nn.ModuleList([Attention(x_channel=N_features * N_heads, 260 | y_channel=N_features * N_heads, 261 | N_x=N_steps+1, 262 | N_y=N_steps+1, 263 | causal_mask=True, 264 | N_heads=N_heads, 265 | device=device) for _ in range(N_layers)]) 266 | self.cross_attentions = nn.ModuleList([Attention(x_channel=N_features * N_heads, 267 | y_channel=torso_feature_shape[1], 268 | N_x=N_steps+1, 269 | N_y=torso_feature_shape[0], 270 | causal_mask=False, 271 | N_heads=N_heads, 272 | device=device) for _ in range(N_layers)]) 273 | self.self_dropouts = nn.ModuleList([nn.Dropout() for _ in range(N_layers)]) 274 | self.cross_dropouts = nn.ModuleList([nn.Dropout() for _ in range(N_layers)]) 275 | self.relu = nn.ReLU() 276 | self.linear_2 = nn.Linear(N_features * N_heads, N_logits+1) 277 | 278 | 279 | def forward(self, x): 280 | 281 | # Input: 282 | # [e, (g)]. e is the features extracted by torso, g is groundtruth (available in train mode) 283 | # e: [B, m, c] 284 | # g: {0,1,... N_logits-1} ^ [B, N_steps], [B, N_steps] 285 | 286 | N_steps = self.N_steps 287 | N_logits = self.N_logits 288 | N_samples = self.N_samples 289 | N_features = self.N_features 290 | N_heads = self.N_heads 291 | device = self.device 292 | assert self.mode in ['train', 'infer'] 293 | 294 | if self.mode == 'train': 295 | e, g = x # g: {0,1,... N_logits-1} ^ [B, N_steps+1], [B, N_steps+1] 296 | if not torch.is_tensor(g): 297 | g = torch.tensor(g).long() 298 | g_onehot = one_hot(g, num_classes=N_logits, shift=True).float().to(device) # [B, N_steps+1, N_logits+1] 299 | #FIXME: We haven't applied "shift" operation. 300 | # # Apply "start" sign. 301 | # g_onehot = torch.cat([torch.zeros_like(g_onehot[:, :1, :]), g_onehot], dim=1) # [B, N_steps+1, N_logits+1] 302 | o, z = self.predict_action_logits(g_onehot, e) # o: [B, N_steps+1, N_logits+1]; z: [B, N_steps+1, N_features*N_heads] 303 | return o[:, :-1, :], z[:, 0] # o: [B, N_steps+1, N_logits+1]; z[:, 0]: [B, N_features*N_heads] 304 | 305 | elif self.mode == 'infer': 306 | e = x[0] # e: [B, m, c] 307 | batch_size = e.shape[0] 308 | a = -2 * torch.ones((batch_size * N_samples, N_steps+1)).long().to(device) # a: {-2,-1,0,1, ... N_logits-1} ^ [N_samples, N_steps+1] 309 | a[:, 0] = -1 # Start sign. 310 | p = torch.ones((batch_size * N_samples,)).float().to(device) 311 | 312 | # We sample N_samples batchly. 
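# Each batch element is replicated N_samples times so that all samples for all batch elements are decoded in one forward pass per autoregressive step.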
313 | e = e.repeat_interleave(N_samples, dim=0) # e: [B*N_samples, m, c] 314 | for i in range(1, N_steps+1): 315 | o, z = self.predict_action_logits(one_hot(a, num_classes=N_logits).float().to(device), e) # z: [B*N_samples, N_steps+1, N_features*N_heads] 316 | prob = F.softmax(o[:, i-1, :-1], dim=1) # o[:, i-1, :-1]: [B*N_samples, N_logits] 317 | sampled_a = torch.multinomial(prob, 1).view(-1) # sampled_a = [B*N_samples,] 318 | # for s in range(batch_size * N_samples): 319 | # p[s] = p[s] * prob[s, sampled_a[s]] 320 | a[:, i] = sampled_a 321 | 322 | if i == 1: 323 | # z1 = z[0, 0].clone() # [N_features*N_heads] 324 | z1 = z.reshape((-1, N_samples, N_steps+1, N_features*N_heads)) 325 | z1 = z1[:, 0, 0, :].clone() # [B, N_features*N_heads] 326 | return a[:, 1:].reshape((-1, N_samples, N_steps)),\ 327 | p.reshape((-1, N_samples)),\ 328 | z1 # [B, N_samples, N_steps], [B, N_samples], [B, N_features*N_heads] 329 | 330 | 331 | def set_mode(self, mode): 332 | 333 | assert mode in ["train", "infer"] 334 | self.mode = mode 335 | 336 | 337 | def set_samples_n(self, 338 | N_samples): 339 | self.N_samples = N_samples 340 | 341 | 342 | def predict_action_logits(self, 343 | a, e): 344 | 345 | N_steps = self.N_steps 346 | N_logits = self.N_logits 347 | N_features = self.N_features 348 | N_heads = self.N_heads 349 | N_layers = self.N_layers 350 | torso_feature_shape = self.torso_feature_shape 351 | device = self.device 352 | 353 | batch_size = a.shape[0] 354 | 355 | x = self.linear_1(a.reshape(batch_size*(N_steps+1), -1)) # [batch_size*N_steps, N_features*N_heads] 356 | x = self.pos_embed(torch.arange(0, (N_steps+1)).repeat(batch_size).float().view((-1,1)).to(device)) + x # [batch_size*N_steps, N_features*N_heads] 357 | x = x.reshape(batch_size, (N_steps+1), -1) # [batch_size, N_steps, N_features*N_heads] 358 | 359 | for layer in range(N_layers): 360 | x = self.self_layer_norms[layer](x) # [batch_size, N_steps, N_features*N_heads] 361 | c = self.self_attentions[layer]([x,]) # [batch_size, N_steps, N_features*N_heads] 362 | # if self.mode == 'train': 363 | # c = self.self_dropouts[layer](c) 364 | # else: 365 | # c = self.self_dropouts[layer].eval()(c) 366 | c = self.self_dropouts[layer](c) 367 | x = x + c # [batch_size, N_steps, N_features*N_heads] 368 | 369 | x = self.cross_layer_norms[layer](x) # [batch_size, N_steps, N_features*N_heads] 370 | c = self.cross_attentions[layer]([x, e]) # [batch_size, N_steps, N_features*N_heads] 371 | # if self.mode == 'train': 372 | # c = self.cross_dropouts[layer](c) 373 | # else: 374 | # c = self.cross_dropouts[layer].eval()(c) 375 | c = self.cross_dropouts[layer](c) 376 | x = x + c # [batch_size, N_steps, N_features*N_heads] 377 | 378 | o = self.linear_2(self.relu(x.reshape(-1, N_features*N_heads))).reshape(batch_size, (N_steps+1), N_logits+1) # [batch_size, N_steps, N_logits] 379 | 380 | return o, x 381 | 382 | 383 | class ValueHead(nn.Module): 384 | 385 | def __init__(self, 386 | N_layers=3, 387 | in_channel=2048, 388 | inter_channel=512, 389 | out_channel=8): 390 | 391 | super(ValueHead, self).__init__() 392 | self.out_channel = out_channel 393 | self.in_channel = in_channel 394 | self.inter_channel = inter_channel 395 | self.N_layers = N_layers 396 | self.mode = "train" 397 | 398 | self.in_linear = nn.Linear(in_channel, inter_channel) 399 | self.linaers = nn.ModuleList([nn.Linear(inter_channel, inter_channel) for _ in range(N_layers-1)]) 400 | self.out_linear = nn.Linear(inter_channel, out_channel) 401 | self.relus = nn.ModuleList([nn.ReLU() for _ in 
range(N_layers)]) 402 | 403 | 404 | def forward(self, x): 405 | 406 | # Input: 407 | # [z]. z: The feature in policy head. 408 | 409 | N_layers = self.N_layers 410 | 411 | z = x[0] 412 | # if self.mode == "infer": 413 | # z = z[None] 414 | z = self.relus[0](self.in_linear(z)) 415 | for layer in range(N_layers-1): 416 | z = self.relus[layer+1](self.linaers[layer](z)) 417 | 418 | q = self.out_linear(z) 419 | # if self.mode == "infer": 420 | # q = q[0] 421 | return q 422 | 423 | 424 | def set_mode(self, mode): 425 | 426 | assert mode in ["train", "infer"] 427 | self.mode = mode 428 | 429 | 430 | class Net(nn.Module): 431 | ''' 432 | Network definition. 433 | ''' 434 | 435 | def __init__(self, 436 | T=7, 437 | S_size=4, 438 | N_steps=6, 439 | coefficients=[0, 1, -1], 440 | N_samples=32, 441 | n_attentive=8, 442 | N_heads=32, 443 | N_features=64, 444 | policy_layers=2, 445 | device='cuda', 446 | channel=3, 447 | scalar_size=3, 448 | value_layers=3, 449 | inter_channel=512, 450 | out_channel=8, 451 | **kwargs): 452 | ''' 453 | Initialization. 454 | ''' 455 | 456 | super(Net, self).__init__() 457 | # Parameters. 458 | self.T = T 459 | self.S_size = S_size 460 | self.N_steps = N_steps 461 | self.coefficients = coefficients 462 | token_len = 3 * S_size // N_steps 463 | N_logits = len(coefficients) ** token_len # len(F) ^ len(token) 464 | self.N_logits = N_logits 465 | self.token_len = token_len 466 | self.N_samples = N_samples 467 | 468 | # Network. 469 | self.torso = Torso(S_size=S_size, T=T, 470 | n_attentive=n_attentive, 471 | device=device, 472 | channel=channel, 473 | scalar_size=scalar_size) 474 | self.policy_head = PolicyHead(N_steps=N_steps, 475 | N_logits=N_logits, 476 | N_samples=N_samples, 477 | N_heads=N_heads, 478 | N_features=N_features, 479 | device=device, 480 | N_layers=policy_layers, 481 | torso_feature_shape=(3*S_size**2, channel)) 482 | self.value_head = ValueHead(in_channel=N_features*N_heads, 483 | N_layers=value_layers, 484 | inter_channel=inter_channel, 485 | out_channel=out_channel) 486 | self.mode = "train" 487 | 488 | 489 | def forward(self, x): 490 | 491 | # Input: 492 | # If train mode: 493 | # Tensors of shape [B,T,S,S,S]. First one is current tensor. (numpy) 494 | # Scalars of shape [B,s]. (numpy) 495 | # (Groundtruth) of shape [B, N_steps]. 496 | # Elif infer mode: 497 | # Tensors of shape [T,S,S,S] 498 | # Scalars of shape [s,] 499 | 500 | if self.mode == 'train': 501 | states, scalars, g = x 502 | self.policy_head.set_mode("train") 503 | self.value_head.set_mode("train") 504 | self.torso.set_mode("train") 505 | 506 | e = self.torso([states, scalars]) 507 | o, z1 = self.policy_head([e, g]) 508 | q = self.value_head([z1]) 509 | 510 | return o, q # o: [B, N_steps, N_logits]; q: [B, out_channels] 511 | 512 | elif self.mode == 'infer': 513 | states, scalars = x 514 | self.policy_head.set_mode("infer") 515 | self.value_head.set_mode("infer") 516 | self.torso.set_mode("infer") 517 | 518 | e = self.torso([states, scalars]) 519 | a, p, z1 = self.policy_head([e]) 520 | q = self.value_head([z1]) 521 | 522 | return a, p, q #FIXME: Need to process q.
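# Caveat: the per-step probability accumulation for `p` is commented out in the sampling loop above, so in
# infer mode `p` is currently returned as all ones rather than the product of the sampled logits' probabilities.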
523 | # a: {0,1,..., N_logits-1} ^ [B, N_samples, N_steps]; p: [B, N_samples,]; q: [B, out_channels,] 524 | 525 | 526 | def set_mode(self, mode): 527 | 528 | assert mode in ["train", "infer"] 529 | self.mode = mode 530 | 531 | 532 | def set_samples_n(self, 533 | N_samples): 534 | self.N_samples = N_samples 535 | self.policy_head.set_samples_n(N_samples) 536 | 537 | 538 | def logits_to_action(self, logits): 539 | ''' 540 | logits: N_steps values of {0, 1, ..., N_logits - 1}. 541 | e.g.: 542 | If: 543 | token_len = 2 544 | coefficients = [0, 1, -1] 545 | N_steps = 6 546 | Then: 547 | [0, 1, 2, 3, 4, 5] -> [0 0 | 0 1 | 0 -1 | 1 0 | 1 1 | 1 -1 ] 548 | ''' 549 | token_len = self.token_len 550 | coefficients = self.coefficients 551 | action = [] 552 | for logit in logits: # Get one action 553 | token = [] 554 | if logit == self.N_logits: 555 | raise ValueError("Got the start-token logit; it cannot be decoded into a token.") 556 | for _ in range(token_len): # Get one token 557 | idx = logit % len(coefficients) 558 | token.append(coefficients[idx]) 559 | logit = logit // len(coefficients) 560 | token.reverse() 561 | action.extend(token) 562 | 563 | action = np.array(action, dtype=np.int32).reshape((3, -1)) 564 | return action 565 | 566 | 567 | def action_to_logits(self, 568 | action): 569 | ''' 570 | action: A [3, S_size] array. 571 | ''' 572 | 573 | # Break action into tokens. 574 | token_len = self.token_len 575 | coefficients = self.coefficients 576 | action = action.reshape((-1, token_len)) # [N_steps, token_len] 577 | 578 | # Get logits. 579 | logits = [] 580 | for token in action: # Get one logit. 581 | # token = token.to_list() 582 | logit = 0 583 | if torch.is_tensor(token): 584 | token = torch.flip(token, dims=(0,)) 585 | else: 586 | token = token[::-1] 587 | for idx, v in enumerate(token): 588 | logit += coefficients.index(v) * (len(coefficients) ** idx) 589 | logits.append(logit) 590 | 591 | return np.array(logits) 592 | 593 | 594 | def value(self, output, u_q=.75): 595 | ''' 596 | Compute the utility value from the network output: the mean of the quantiles at or above the u_q level. 597 | output: output = net(x) 598 | ''' 599 | q = output[-1] 600 | q = q.detach().cpu().numpy() 601 | batch_size, out_channels = q.shape[0], q.shape[1] 602 | 603 | j = math.ceil(u_q * out_channels) 604 | return q, q[:, (j-1):].mean(axis=1) 605 | 606 | 607 | def policy(self, output): 608 | ''' 609 | Get the sampled actions and their probabilities from the network output. 610 | output: output = net(x) 611 | ''' 612 | assert len(output) == 3, "We need the output from infer mode."
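# Worked decode example (mixed-radix, base len(coefficients)): with coefficients [0, 1, -1] and token_len 2,
# logit 5 decodes as 5 % 3 = 2 -> -1, then 5 // 3 = 1 -> 1; reversing gives the token [1, -1], matching the
# table in logits_to_action above.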
613 | 614 | a, p, _ = output 615 | a, p = a.detach().cpu().numpy(), p.detach().cpu().numpy() 616 | 617 | batch_actions = [] 618 | for batch in a: 619 | actions = [] 620 | for logits in batch: 621 | actions.append(self.logits_to_action(logits)) 622 | actions = np.stack(actions, axis=0) 623 | batch_actions.append(actions) 624 | batch_actions = np.stack(batch_actions, axis=0) 625 | 626 | return batch_actions, p 627 | 628 | 629 | 630 | if __name__ == '__main__': 631 | # torso = Torso().to('cuda') 632 | # test_input = [np.random.randint(-1, 1, (64, 7, 4, 4, 4)), np.random.randint(-1, 1, (64, 3))] 633 | # e = torso(test_input) 634 | # import pdb; pdb.set_trace() 635 | 636 | # policy_head = PolicyHead().to('cuda') 637 | # test_g = torch.tensor([0,1,2,3,4,5]).repeat(64).reshape(64, 6).to('cuda') 638 | # train_output = policy_head([e, test_g]) 639 | # import pdb; pdb.set_trace() 640 | 641 | # value_head = ValueHead().to('cuda') 642 | # value_output = value_head([train_output[1]]) 643 | # import pdb; pdb.set_trace() 644 | 645 | # net = Net().to('cuda') 646 | # test_tensor, test_scalar = np.random.randint(-1, 1, (7, 4, 4, 4)), np.random.randint(-1, 1, (3)) 647 | # test_input = [test_tensor[None], test_scalar[None]] 648 | # test_g = torch.tensor([-1,0,1,2,3,4,5]).repeat(1).reshape(1, 7).to('cuda') 649 | # train_output = net([*test_input, test_g]) 650 | # # import pdb; pdb.set_trace() 651 | 652 | # net.set_mode("infer") 653 | # test_input = [test_tensor, test_scalar] 654 | # infer_output = net(test_input) 655 | # import pdb; pdb.set_trace() 656 | 657 | # _, value = net.value(infer_output) 658 | # policy = net.policy(infer_output) 659 | # import pdb; pdb.set_trace() 660 | 661 | net = Net() 662 | logits = np.array([-1,0,1,2,2,1,1]) 663 | action = net.logits_to_action(logits) 664 | logits = net.action_to_logits(action) 665 | import pdb; pdb.set_trace() -------------------------------------------------------------------------------- /codes/scripts/check_redundancy.py: -------------------------------------------------------------------------------- 1 | import os 2 | from tqdm import tqdm 3 | import numpy as np 4 | 5 | S_size = 4 6 | 7 | data_name = "100000_S4T7_scalar3.npy" 8 | data_path = os.path.join(".", "data", "traj_data", data_name) 9 | 10 | data = np.load(data_path, allow_pickle=True) 11 | filtered_data = []; ct = 0 12 | 13 | for traj_idx, traj in tqdm(enumerate(data)): 14 | _, actions, _ = traj 15 | raw_r = len(actions) 16 | flag = False 17 | for (i, j) in [[0,1], [1,2], [2,0]]: 18 | _mat = np.zeros((S_size ** 2, raw_r), dtype=np.int32) 19 | for idx, action in enumerate(actions): 20 | _mat[:, idx] = np.outer(action[i], action[j]).reshape((-1,)) 21 | if np.linalg.matrix_rank(_mat) < raw_r: 22 | ct += 1 23 | print("Found redundancy...") 24 | flag = True 25 | break 26 | 27 | if not flag: 28 | filtered_data.append(traj) 29 | 30 | import pdb; pdb.set_trace() -------------------------------------------------------------------------------- /codes/trainer/Player.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from torch.utils.tensorboard import SummaryWriter 3 | import sys 4 | import os 5 | sys.path.append(os.path.abspath(os.path.join(".."))) 6 | sys.path.append(os.path.abspath(os.path.join("."))) 7 | from codes.trainer.loss import QuantileLoss 8 | from codes.env import * 9 | from codes.mcts import * 10 | from codes.utils import * 11 | from codes.dataset import * 12 | from codes.multi_runner import * 13 | 14 | 15 | class Player(): 16 | 17 | def 
__init__(self, 18 | net, env, mcts, 19 | exp_dir, 20 | simu_times=400, 21 | play_times=10, 22 | num_workers=256, 23 | device="cuda:1", 24 | noise=False): 25 | 26 | self.net = net 27 | self.env = env 28 | self.mcts = mcts 29 | 30 | net.to(device) 31 | 32 | self.exp_dir = exp_dir 33 | self.trainer_logger = SummaryWriter(log_dir=os.path.join(exp_dir, "log")) 34 | 35 | self.simu_times = simu_times 36 | self.play_times = play_times 37 | self.num_workers = num_workers 38 | self.device = device 39 | self.noise = noise 40 | 41 | self.call_ct = 0 42 | 43 | 44 | def play(self, warm_up=False) -> list: 45 | ''' 46 | Play one round of Tensor Games and collect the play records. 47 | Returns: results 48 | ''' 49 | 50 | num_workers = self.num_workers 51 | simu_times = self.simu_times 52 | play_times = self.play_times 53 | noise = self.noise 54 | 55 | if warm_up: 56 | simu_times = 40 57 | play_times = 1 58 | num_workers = 10 59 | 60 | results = [] 61 | avg_steps = 0 62 | net = self.net 63 | env = self.env 64 | mcts = self.mcts 65 | net.set_mode("infer") 66 | net.eval() 67 | 68 | # wkdir. 69 | save_dir = os.path.join(self.exp_dir, "data") 70 | os.makedirs(save_dir, exist_ok=True) 71 | save_path = os.path.join(save_dir, "self_data.npy") 72 | 73 | # Load ckpt. 74 | latest_path = os.path.join(self.exp_dir, "ckpt", "latest.pth") 75 | ct = 0 76 | while True: 77 | if ct > 10000000: 78 | raise Exception 79 | try: 80 | self.load_model(ckpt_path=latest_path, to_device=self.device) 81 | break 82 | except Exception: # The ckpt may still be mid-write; retry. 83 | ct += 1 84 | continue 85 | 86 | # Multi runner. 87 | env_list = [copy.deepcopy(env) for _ in range(num_workers)] 88 | mcts_list = [copy.deepcopy(mcts) for _ in range(num_workers)] 89 | envs = ENVS(env_list) 90 | mctsf = MCTSF(mcts_list, simulate_times=simu_times) 91 | 92 | for game in (range(play_times)): 93 | 94 | envs.reset() 95 | state_list = envs.get_curstates() 96 | mctsf.reset(state_list) 97 | trajs = [] 98 | 99 | while True: 100 | state_list = envs.get_curstates() 101 | action_list, _, _ = mctsf(state_list, net, noise=noise) 102 | envs.step(action_list) 103 | terminate_flag = envs.is_all_terminated() 104 | mctsf.move(action_list) # Move MCTS forward. 105 | # one_result.append([state_list, action_list]) # Record. (s, a). 106 | trajs.append([[state, action] for state, action in zip(state_list, action_list)]) 107 | 108 | if terminate_flag: 109 | reward_list = envs.get_rewards() 110 | # for step in range(env.step_ct): 111 | # one_result[step] += [final_reward + step] # Final results. (s, a, r(s)). 112 | # # Note: 113 | # # a is not included in the history actions of s. 114 | for step, _trajs in enumerate(trajs): 115 | for idx, traj_state in enumerate(_trajs): # traj_state: [state, action] 116 | traj_state += [reward_list[idx] + (step)] 117 | 118 | # Prepare per traj. 119 | for idx in range(num_workers): 120 | one_traj = [_trajs[idx] for _trajs in trajs] # [ [s1, a1, r1], [s2, a2, r2] ... ] 121 | one_traj = [episode for episode in one_traj if not is_zero_tensor(episode[0])] # Filter the invalid state. 122 | 123 | states, actions, rewards = \ 124 | [episode[0] for episode in one_traj], \ 125 | [episode[1] for episode in one_traj], \ 126 | [episode[2] for episode in one_traj] 127 | 128 | one_traj = [states, actions, rewards] 129 | results.append(one_traj) 130 | 131 | step_ct_list = envs.get_stepcts() 132 | batch_avg_step_ct = np.array(step_ct_list).mean() 133 | 134 | avg_steps = (game / (game+1)) * avg_steps + batch_avg_step_ct / (game+1) 135 | break 136 | 137 | # Insert a bubble.
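# The "bubble" appended below is a one-step dummy trajectory copied from the last real trajectory; presumably
# it keeps np.array(results, dtype=object) from collapsing the ragged (n_trajs, 3, traj_length*) layout into a
# fixed-shape array when all real trajectories happen to share one length.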
138 | # results: (n_trajs, 3, traj_length*) 139 | one_state, one_action, one_reward = results[-1][0][0], results[-1][1][0], results[-1][2][0] 140 | bubble_traj = [[one_state,], [one_action,], [one_reward,]] 141 | results.append(bubble_traj) 142 | if save_path is not None: 143 | results = np.array(results, dtype=object) # (n_trajs, 3, traj_length*). *traj_length is not fixed. Decompose order... 144 | np.save(save_path, results) 145 | print("Avg step is: %f" % avg_steps) 146 | 147 | self.trainer_logger.add_text("Self-play", str(one_traj), global_step=self.call_ct) 148 | self.call_ct += 1 149 | 150 | return results, one_traj 151 | 152 | 153 | def run(self): 154 | 155 | self.play(warm_up=True) 156 | while True: 157 | self.play() 158 | print("Finish playing!") 159 | 160 | 161 | def load_model(self, ckpt_path, only_weight=False, to_device="cuda:0"): 162 | ckpt = torch.load(ckpt_path, map_location=to_device) 163 | self.net.load_state_dict(ckpt['model']) 164 | return ckpt['iter'] -------------------------------------------------------------------------------- /codes/trainer/Trainer.py: -------------------------------------------------------------------------------- 1 | import time 2 | import yaml 3 | import copy 4 | import random 5 | import shutil 6 | from tqdm import tqdm 7 | import torch 8 | from torch.utils.data import DataLoader, Subset 9 | from torch.utils.tensorboard import SummaryWriter 10 | import numpy as np 11 | 12 | import sys 13 | import os 14 | sys.path.append(os.path.abspath(os.path.join(".."))) 15 | sys.path.append(os.path.abspath(os.path.join("."))) 16 | from codes.trainer.loss import * 17 | from codes.trainer.Player import * 18 | from codes.env import * 19 | from codes.mcts import * 20 | from codes.utils import * 21 | from codes.dataset import * 22 | from codes.multi_runner import * 23 | 24 | class Trainer(): 25 | ''' 26 | Trains the network. 27 | ''' 28 | 29 | def __init__(self, 30 | net: Net, 31 | env: Environment, 32 | mcts: MCTS, 33 | S_size=4, 34 | T=7, 35 | coefficients=[0, 1, -1], 36 | batch_size=1024, 37 | iters_n=50000, 38 | exp_dir="exp", 39 | exp_name="debug", 40 | device="cuda:0", 41 | self_play_device="cuda:1", 42 | lr=5e-3, 43 | weight_decay=1e-5, 44 | step_size=40000, 45 | gamma=.1, 46 | a_weight=.5, 47 | v_weight=.5, 48 | save_freq=10000, 49 | temp_save_freq=2500, 50 | self_play_freq=10, 51 | self_play_buffer=100000, 52 | grad_clip=4.0, 53 | val_freq=2000, 54 | all_kwargs=None): 55 | ''' 56 | Initialize a Trainer.
57 | It holds the net, env and MCTS. 58 | ''' 59 | self.env = env 60 | self.net = net 61 | self.mcts = mcts 62 | self.S_size = S_size 63 | self.T = T 64 | self.coefficients = coefficients 65 | 66 | self.self_examples = [] 67 | self.synthetic_examples = [] 68 | 69 | self.entropy_loss = torch.nn.CrossEntropyLoss() 70 | self.quantile_loss = QuantileLoss() 71 | self.a_weight = a_weight 72 | self.v_weight = v_weight 73 | 74 | self.optimizer_a = torch.optim.AdamW(net.parameters(), 75 | weight_decay=weight_decay, 76 | lr=lr) 77 | self.scheduler_a = torch.optim.lr_scheduler.StepLR(self.optimizer_a, 78 | step_size=step_size, 79 | gamma=gamma) 80 | self.optimizer_v = torch.optim.AdamW(net.parameters(), 81 | weight_decay=weight_decay, 82 | lr=lr) 83 | self.scheduler_v = torch.optim.lr_scheduler.StepLR(self.optimizer_v, 84 | step_size=step_size, 85 | gamma=gamma) 86 | 87 | self.batch_size = batch_size 88 | self.iters_n = iters_n 89 | self.grad_clip = grad_clip 90 | self.save_freq = save_freq 91 | self.temp_save_freq = temp_save_freq 92 | self.self_play_freq = self_play_freq 93 | self.self_play_buffer = self_play_buffer 94 | self.val_freq = val_freq 95 | 96 | self.exp_dir = exp_dir 97 | self.save_dir = os.path.join(exp_dir, exp_name, str(int(time.time()))) 98 | self.log_dir = os.path.join(self.save_dir, "log") 99 | self.data_dir = os.path.join(self.save_dir, "data") 100 | 101 | self.device = device 102 | self.self_play_device = self_play_device 103 | self.net.to(device) 104 | self.all_kwargs = all_kwargs 105 | 106 | 107 | def generate_synthetic_examples(self, 108 | prob=[.8, .1, .1], 109 | samples_n=10000, 110 | R_limit=12, 111 | save_path=None, 112 | save_type="traj") -> list: 113 | ''' 114 | Generate synthetic tensor examples. 115 | Returns: results 116 | ''' 117 | assert save_type in ["traj", "tuple"] 118 | 119 | S_size = self.S_size 120 | coefficients = self.coefficients 121 | T = self.T 122 | 123 | total_results = [] 124 | for _ in tqdm(range(samples_n)): 125 | R = random.randint(1, R_limit) 126 | for _ in range(10000): 127 | sample = np.zeros((S_size, S_size, S_size), dtype=np.int32) 128 | states = [] 129 | actions = [] 130 | rewards = [] 131 | for r in range(1, (R+1)): 132 | ct = 0 133 | while True: 134 | u = np.random.choice(coefficients, size=(S_size,), p=prob, replace=True) 135 | v = np.random.choice(coefficients, size=(S_size,), p=prob, replace=True) 136 | w = np.random.choice(coefficients, size=(S_size,), p=prob, replace=True) 137 | ct += 1 138 | if not is_zero_tensor(outer(u, v, w)): 139 | break 140 | if ct > 100000: 141 | raise Exception("Oh my god...") 142 | sample = sample + outer(u, v, w) 143 | action = np.stack([u, v, w], axis=0) 144 | actions.append(canonicalize_action(action)) 145 | states.append(sample.copy()) 146 | rewards.append(-r) 147 | 148 | # Check redundancy. 149 | red_flag = False 150 | for (i, j) in [[0,1], [1,2], [2,0]]: 151 | _mat = np.zeros((S_size ** 2, R), dtype=np.int32) 152 | for idx, action in enumerate(actions): 153 | _mat[:, idx] = np.outer(action[i], action[j]).reshape((-1,)) 154 | if np.linalg.matrix_rank(_mat) < R: 155 | red_flag = True 156 | break 157 | 158 | if red_flag: 159 | continue 160 | break 161 | 162 | # Reformulate the results. 163 | if save_type == "tuple": 164 | states.reverse(); actions.reverse(); rewards.reverse() 165 | actions_tensor = [action2tensor(action) for action in actions] 166 | for idx, state in enumerate(states): 167 | tensors = np.zeros((T, S_size, S_size, S_size), dtype=np.int32) 168 | tensors[0] = state # state. 169 | if idx != 0: 170 | # History actions.
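# Layout of `tensors` ([T, S, S, S]): channel 0 is the current residual tensor and channels 1..min(idx, T-1)
# hold the most recent actions as rank-1 tensors, newest first -- hence the reversed slice below.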
171 | tensors[1:(idx+1)] = np.stack(reversed(actions_tensor[max(idx-(T-1), 0):idx]), axis=0) 172 | scalars = np.array([idx, idx, idx]) #FIXME: Haven't decided the scalars. 173 | 174 | cur_state = [tensors, scalars] 175 | action = actions[idx] 176 | reward = rewards[idx] 177 | total_results.append([cur_state, action, reward]) 178 | 179 | else: 180 | traj = [states, actions, rewards] # Note: Synthesis order... 181 | total_results.append(traj) 182 | 183 | if save_path is not None: 184 | np.save(save_path, np.array(total_results, dtype=object)) 185 | 186 | return total_results 187 | 188 | 189 | def learn_one_batch(self, 190 | batch_example) -> torch.autograd.Variable: 191 | ''' 192 | Learn from one batch of tuples. 193 | ''' 194 | 195 | # Groundtruth. 196 | s, a_gt, v_gt = batch_example # s: [tensor, scalar] 197 | a_gt = a_gt.long().to(self.device) 198 | v_gt = v_gt.float().to(self.device) 199 | 200 | # Network infer. 201 | self.net.set_mode("train") 202 | output = self.net([*s, a_gt]) 203 | o, q = output # o: [batch_size, N_steps, N_logits], q: [batch_size, N_quantiles] 204 | 205 | # Losses. 206 | v_loss = self.quantile_loss(q, v_gt) # v_gt: [batch_size,] 207 | o = o.transpose(1,2) # o: [batch_size, N_logits, N_steps] 208 | a_loss = self.entropy_loss(o, a_gt) # a_gt: [batch_size, N_steps], o: [batch_size, N_logits, N_steps] 209 | loss = self.v_weight * v_loss + self.a_weight * a_loss 210 | 211 | del a_gt, v_gt 212 | 213 | return loss, v_loss, a_loss 214 | 215 | 216 | def val_one_episode(self, 217 | episode): 218 | ''' 219 | Validate one episode and log the outputs. 220 | ''' 221 | 222 | state, action, reward = episode 223 | tensor, scalar = state 224 | 225 | self.net.set_mode("infer") 226 | self.net.set_samples_n(4) 227 | output = self.net(state) 228 | a, p, _ = output 229 | a = a.detach().cpu().numpy() 230 | q, v = self.net.value(output) 231 | policy, p = self.net.policy(output) 232 | 233 | policy, p, q, v = policy[0], p[0], q[0], v[0] 234 | 235 | self.net.set_mode("train") 236 | tensor, scalar, action = tensor[None], scalar[None], action[None] 237 | state = [tensor, scalar] 238 | o, _ = self.net([*state, action]) 239 | o = o.detach().cpu().numpy()[0] # o: [N_steps, N_logits] 240 | 241 | log_txt = "\n".join( 242 | ["\nState: \n", str(state[0][0, 0]), 243 | "\nGt action: \n", str(self.net.logits_to_action(action[0])), 244 | "\nGt logit: \n", str(action[0]), 245 | "\nInfer actions: \n", str(policy), 246 | "\nInfer logits: \n", str(a), 247 | "\nprob: \n", str(p), 248 | "\nGt value: \n", str(reward), 249 | "\nInfer value: \n", str(v), 250 | "\nquantile: \n", str(q), 251 | *["\nTop 5 logits for step %d\n: " % step + str(np.argsort(o[step])[-5:]) for step in range(self.net.N_steps)]] 252 | ) 253 | 254 | del a, o, p, _, output, q, v 255 | 256 | return log_txt 257 | 258 | 259 | def learn(self, 260 | resume=None, 261 | only_weight=False, 262 | example_path=None, 263 | self_example_path=None, 264 | save_type="traj", 265 | self_play=False): 266 | ''' 267 | Main training loop. 268 | ''' 269 | optimizer_a = self.optimizer_a 270 | scheduler_a = self.scheduler_a 271 | optimizer_v = self.optimizer_v 272 | scheduler_v = self.scheduler_v 273 | batch_size = self.batch_size 274 | self_play_freq = self.self_play_freq 275 | self_play_buffer = self.self_play_buffer 276 | 277 | # Tensorboard. 278 | os.makedirs(self.save_dir) 279 | os.makedirs(self.log_dir); os.makedirs(self.data_dir, exist_ok=True) # data_dir is written to later; create it up front. 280 | self.log_writer = SummaryWriter(self.log_dir) 281 | 282 | # Save config.
283 | all_kwargs = self.all_kwargs 284 | cfg_path = os.path.join(self.save_dir, "config.yaml") 285 | with open(cfg_path, 'w') as f: 286 | yaml.dump(all_kwargs, f) 287 | 288 | if resume is not None: 289 | # Load model. 290 | old_iter = self.load_model(resume, only_weight) 291 | # Copy log file. 292 | old_exp_dir = os.path.join(os.path.dirname(resume), '..') 293 | # os.system("cp -r %s %s" % (os.path.join(old_exp_dir, 'log', '*'), self.log_dir)) 294 | for log_f in os.listdir(os.path.join(old_exp_dir, "log")): 295 | shutil.copy(os.path.join(old_exp_dir, "log", log_f), self.log_dir) 296 | else: 297 | old_iter = 0 298 | 299 | # Save ckpt. 300 | ckpt_name = "latest.pth" 301 | self.save_model(ckpt_name, old_iter) 302 | 303 | # 1. Get synthetic examples. 304 | if example_path is not None: 305 | self.synthetic_examples.extend(self.load_examples(example_path)) 306 | else: 307 | self.synthetic_examples.extend(self.generate_synthetic_examples(samples_n=3000)) 308 | 309 | if self_example_path is not None: 310 | self.self_examples.extend(self.load_examples(self_example_path)) 311 | 312 | # Dataloader. 313 | dataset = TupleDataset(T=self.T, 314 | S_size=self.S_size, 315 | N_steps=self.net.N_steps, 316 | coefficients=self.coefficients, 317 | self_data=self.self_examples, 318 | synthetic_data=self.synthetic_examples) 319 | dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0) 320 | loader = iter(dataloader) 321 | epoch_ct = 0 322 | 323 | for i in tqdm(range(old_iter, self.iters_n)): 324 | 325 | # 2. self-play for data. 326 | # if i % self_play_freq == 0: 327 | # self.self_examples.extend(self.play(200 if i < 50000 else 800)) 328 | 329 | try: 330 | batch_example = next(loader) 331 | except StopIteration: 332 | dataloader.dataset._permutate_traj() 333 | if self_play: 334 | self_examples = self.get_self_examples() # New self-play data. 335 | if self_examples is not None: 336 | print("Detect new self-data!") 337 | self.self_examples.extend(self_examples) 338 | self.self_examples = self.self_examples[-self_play_buffer:] 339 | np.save(os.path.join(self.data_dir, "total_self_data.npy"), np.array(self.self_examples, dtype=object)) # Whole buffer. 340 | synthetic_examples_n = 2000 if i > 50000 else 100000 341 | dataset = TupleDataset(T=self.T, 342 | S_size=self.S_size, 343 | N_steps=self.net.N_steps, 344 | coefficients=self.coefficients, 345 | self_data=self.self_examples, 346 | synthetic_data=random.sample(self.synthetic_examples, synthetic_examples_n)) 347 | dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0) 348 | else: 349 | print("No detect new self-data...") 350 | 351 | loader = iter(dataloader) 352 | batch_example = next(loader) 353 | print("Epoch: %d finish." 
% epoch_ct) 354 | epoch_ct += 1 355 | 356 | # Multiprocessing could be applied here. 357 | # TODO: when should the network parameters be updated? 358 | optimizer_a.zero_grad() 359 | optimizer_v.zero_grad() 360 | loss, v_loss, a_loss = self.learn_one_batch(batch_example) 361 | loss.backward() 362 | torch.nn.utils.clip_grad_norm_(self.net.parameters(), 363 | max_norm=self.grad_clip) 364 | optimizer_a.step() 365 | optimizer_v.step() 366 | scheduler_a.step() 367 | scheduler_v.step() 368 | 369 | # Logging. 370 | if i % 20 == 0: 371 | # print("Loss: %f, v_loss: %f, a_loss: %f" % 372 | # (loss.detach().cpu().item(), v_loss.detach().cpu().item(), a_loss.detach().cpu().item())) 373 | self.log_writer.add_scalar("loss", loss.detach().cpu().item(), global_step=i) 374 | self.log_writer.add_scalar("v_loss", v_loss.detach().cpu().item(), global_step=i) 375 | self.log_writer.add_scalar("a_loss", a_loss.detach().cpu().item(), global_step=i) 376 | 377 | if i % self.save_freq == 0: 378 | ckpt_name = "it%07d.pth" % i 379 | self.save_model(ckpt_name, i) 380 | 381 | if i % self.temp_save_freq == 0: 382 | ckpt_name = "latest.pth" 383 | self.save_model(ckpt_name, i) 384 | 385 | if i % self.val_freq == 0: 386 | val_episode = dataset[random.randint(0, len(dataset)-1)] 387 | log_txt = self.val_one_episode(val_episode) 388 | self.log_writer.add_text("Infer", log_txt, global_step=i) 389 | 390 | self.save_model("final.pth", i) 391 | 392 | 393 | def infer(self, 394 | init_state=None, 395 | no_base_change=True, 396 | mcts_simu_times=10000, 397 | mcts_samples_n=16, 398 | step_limit=12, 399 | resume=None, 400 | vis=False, 401 | noise=False, 402 | log=True): 403 | 404 | log_actions = [] 405 | 406 | assert resume is not None, "Inference with a randomly initialized network is meaningless." 407 | self.load_model(resume) 408 | if log: 409 | exp_dir = os.path.join(os.path.dirname(resume), '..') 410 | infer_log_dir = os.path.join(exp_dir, "infer") 411 | os.makedirs(infer_log_dir, exist_ok=True) 412 | infer_log_f = os.path.join(infer_log_dir, str(int(time.time()))+'.txt') 413 | 414 | net = self.net 415 | env = self.env 416 | mcts = self.mcts 417 | 418 | env.reset(init_state, no_base_change) 419 | net.set_mode("infer") 420 | net.set_samples_n(mcts_samples_n) 421 | net.eval() 422 | mcts.reset(env.cur_state, simulate_times=mcts_simu_times, R_limit=step_limit) 423 | env.R_limit = step_limit + 1 424 | 425 | step_ct = 0 426 | for step in tqdm(range(step_limit)): 427 | print("Current state is (step%d):" % step) 428 | print(env.cur_state) 429 | 430 | action, actions, pi, log_txt = mcts(env.cur_state, net, log=True, noise=noise) 431 | if vis: 432 | mcts.visualize() 433 | print("We choose action(step%d):" % step) 434 | print(action) 435 | terminate_flag = env.step(action) # Will change self.cur_state. 436 | mcts.move(action) # Move MCTS forward.
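# mcts.move(action) advances the search along the chosen edge -- presumably re-rooting the tree at that child
# so its visit statistics are reused at the next step instead of being rebuilt from scratch.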
437 | log_actions.append(action) 438 | 439 | if log: 440 | with open(infer_log_f, "a") as f: 441 | f.write(log_txt) 442 | f.write("\n\n\n") 443 | 444 | if terminate_flag: 445 | step_ct = step + 1 446 | print("We get to the end!") 447 | break 448 | 449 | 450 | print("Final result:") 451 | print(env.cur_state) 452 | 453 | print("Actions are:") 454 | print(np.stack(log_actions, axis=0)) 455 | 456 | if log: 457 | with open(infer_log_f, "a") as f: 458 | f.write("\n\n\n") 459 | f.write("\nFinal result:\n") 460 | f.write("\n" + str(env.cur_state) + "\n") 461 | f.write("\nActions are:\n") 462 | f.write("\n" + str(np.stack(log_actions, axis=0)) + "\n") 463 | f.write("\n\n\n") 464 | f.write("\nStep ct: %d\n" % step_ct) 465 | 466 | return step_ct 467 | 468 | 469 | def filter_train_data(self, 470 | n=100, 471 | example_path=None, 472 | mcts_simu_times=10000, 473 | mcts_samples_n=16, 474 | step_limit=12, 475 | resume=None): 476 | 477 | assert resume is not None, "Inference with a randomly initialized network is meaningless." 478 | assert example_path is not None 479 | 480 | synthetic_examples = self.load_examples(example_path) 481 | test_data = random.sample(list(synthetic_examples), n) 482 | 483 | exp_dir = os.path.join(os.path.dirname(resume), '..') 484 | infer_log_dir = os.path.join(exp_dir, "infer") 485 | os.makedirs(infer_log_dir, exist_ok=True) 486 | infer_log_f = os.path.join(infer_log_dir, str(int(time.time()))+'.txt') 487 | 488 | for traj in (test_data): 489 | # import pdb; pdb.set_trace() 490 | states, _, _ = traj 491 | raw_r = len(states) 492 | init_state = states[-1] 493 | 494 | result_r = self.infer(init_state=init_state, 495 | no_base_change=True, 496 | mcts_samples_n=mcts_samples_n, 497 | mcts_simu_times=mcts_simu_times, 498 | step_limit=step_limit, 499 | resume=resume, 500 | log=False) 501 | 502 | with open(infer_log_f, "a") as f: 503 | f.write("%d %d\n" % (raw_r, result_r)) 504 | 505 | 506 | 507 | def save_model(self, ckpt_name, iter): 508 | save_dir = os.path.join(self.save_dir, "ckpt") 509 | os.makedirs(save_dir, exist_ok=True) 510 | save_path = os.path.join(save_dir, ckpt_name) 511 | torch.save({'model': self.net.state_dict(), 512 | 'iter': iter, 513 | 'optimizer_a': self.optimizer_a.state_dict(), 514 | 'optimizer_v': self.optimizer_v.state_dict(), 515 | 'scheduler_a': self.scheduler_a.state_dict(), 516 | 'scheduler_v': self.scheduler_v.state_dict()}, save_path) 517 | 518 | 519 | def load_model(self, ckpt_path, only_weight=False, to_device="cuda:0"): 520 | ckpt = torch.load(ckpt_path, map_location=to_device) 521 | self.net.load_state_dict(ckpt['model']) 522 | if not only_weight: 523 | self.optimizer_a.load_state_dict(ckpt['optimizer_a']) 524 | self.optimizer_v.load_state_dict(ckpt['optimizer_v']) 525 | self.scheduler_a.load_state_dict(ckpt['scheduler_a']) 526 | self.scheduler_v.load_state_dict(ckpt['scheduler_v']) 527 | return ckpt['iter'] 528 | 529 | 530 | def load_examples(self, example_path): 531 | return np.load(example_path, allow_pickle=True) 532 | 533 | 534 | def get_self_examples(self): 535 | newest_data_path = os.path.join(self.data_dir, "self_data.npy") # New data from player.
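# File-based handshake with Player: the self-play process writes self_data.npy into data_dir, and the trainer
# renames it to self_data_old.npy after loading so each batch of self-play data is consumed exactly once. The
# retry loop below guards against reading a file that is still being written.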
536 | old_data_path = os.path.join(self.data_dir, "self_data_old.npy") 537 | if os.path.exists(newest_data_path): 538 | ct = 0 539 | while True: 540 | if ct > 10000000: 541 | raise Exception 542 | try: 543 | self_examples = self.load_examples(newest_data_path) 544 | os.system("mv %s %s" % (newest_data_path, old_data_path)) 545 | return self_examples 546 | except: 547 | ct += 1 548 | continue 549 | return None 550 | 551 | 552 | 553 | if __name__ == '__main__': 554 | 555 | conf_path = "./config/my_conf.yaml" 556 | with open(conf_path, 'r', encoding="utf-8") as f: 557 | kwargs = yaml.load(f.read(), Loader=yaml.FullLoader) 558 | 559 | net = Net(**kwargs["net"]) 560 | mcts = MCTS(**kwargs["mcts"], 561 | init_state=None) 562 | env = Environment(**kwargs["env"], 563 | init_state=None) 564 | trainer = Trainer(**kwargs["trainer"], 565 | net=net, env=env, mcts=mcts, 566 | all_kwargs=kwargs) 567 | 568 | 569 | # import pdb; pdb.set_trace() 570 | # res = trainer.play() 571 | # res = trainer.generate_synthetic_examples() 572 | # trainer.learn(example_path="./data/100000_T5_scalar3.npy") 573 | # import pdb; pdb.set_trace() 574 | # trainer.load_model("./exp/debug/1680630182/ckpt/it0020000.pth") 575 | # trainer.infer() 576 | # trainer.infer(resume="./exp/debug/1680764978/ckpt/it0002000.pth") 577 | # import pdb; pdb.set_trace() 578 | # trainer.generate_synthetic_examples(samples_n=3000, save_path="./data/3000_T5_scalar3.npy") 579 | # trainer.learn(resume=None, 580 | # example_path="./data/100000_T5_scalar3.npy") 581 | # trainer.learn(resume=None) 582 | trainer.learn(resume=None, 583 | example_path="./data/3000_T5_scalar3.npy") -------------------------------------------------------------------------------- /codes/trainer/__init__.py: -------------------------------------------------------------------------------- 1 | from codes.trainer.Trainer import * 2 | from codes.trainer.Player import * -------------------------------------------------------------------------------- /codes/trainer/__pycache__/Player.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YiwenAI/OpenTensor/9727d20732ceed34eac108c76e23316df32c4981/codes/trainer/__pycache__/Player.cpython-39.pyc -------------------------------------------------------------------------------- /codes/trainer/__pycache__/Trainer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YiwenAI/OpenTensor/9727d20732ceed34eac108c76e23316df32c4981/codes/trainer/__pycache__/Trainer.cpython-38.pyc -------------------------------------------------------------------------------- /codes/trainer/__pycache__/Trainer.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YiwenAI/OpenTensor/9727d20732ceed34eac108c76e23316df32c4981/codes/trainer/__pycache__/Trainer.cpython-39.pyc -------------------------------------------------------------------------------- /codes/trainer/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YiwenAI/OpenTensor/9727d20732ceed34eac108c76e23316df32c4981/codes/trainer/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /codes/trainer/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/YiwenAI/OpenTensor/9727d20732ceed34eac108c76e23316df32c4981/codes/trainer/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /codes/trainer/__pycache__/loss.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YiwenAI/OpenTensor/9727d20732ceed34eac108c76e23316df32c4981/codes/trainer/__pycache__/loss.cpython-38.pyc -------------------------------------------------------------------------------- /codes/trainer/__pycache__/loss.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YiwenAI/OpenTensor/9727d20732ceed34eac108c76e23316df32c4981/codes/trainer/__pycache__/loss.cpython-39.pyc -------------------------------------------------------------------------------- /codes/trainer/loss.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.autograd import Variable 6 | 7 | class CrossEntropyLoss(nn.Module): 8 | pass 9 | 10 | 11 | class QuantileLoss(nn.Module): 12 | 13 | def __init__(self, 14 | delta=1, 15 | device='cuda'): 16 | super(QuantileLoss, self).__init__() 17 | self.delta = delta 18 | self.device = device 19 | 20 | def forward(self, out, label) -> Variable: 21 | ''' 22 | out: q -> [batch_size, N_quantiles] 23 | label: g -> [batch_size, ]. 24 | ''' 25 | n = out.shape[1] 26 | batch_size = out.shape[0] 27 | # label = torch.ones_like(out) * label 28 | label = label.reshape((batch_size, 1)).repeat(1, n).to(self.device) 29 | tau = ((torch.arange(n) + .5 )/ n).repeat(batch_size, 1).to(self.device) 30 | d = label - out 31 | h = F.huber_loss(out, label, reduction='none', delta=self.delta) 32 | k = torch.abs(tau - (d < 0).float()) 33 | loss = torch.mean(k * h) 34 | return loss -------------------------------------------------------------------------------- /codes/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from codes.utils.util_functions import * -------------------------------------------------------------------------------- /codes/utils/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YiwenAI/OpenTensor/9727d20732ceed34eac108c76e23316df32c4981/codes/utils/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /codes/utils/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YiwenAI/OpenTensor/9727d20732ceed34eac108c76e23316df32c4981/codes/utils/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /codes/utils/__pycache__/util_functions.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YiwenAI/OpenTensor/9727d20732ceed34eac108c76e23316df32c4981/codes/utils/__pycache__/util_functions.cpython-38.pyc -------------------------------------------------------------------------------- /codes/utils/__pycache__/util_functions.cpython-39.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/YiwenAI/OpenTensor/9727d20732ceed34eac108c76e23316df32c4981/codes/utils/__pycache__/util_functions.cpython-39.pyc -------------------------------------------------------------------------------- /codes/utils/util_functions.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | ''' 5 | Note that: 6 | All function use np as input and output. (Except one_hot) 7 | ''' 8 | 9 | def numpy_cvt(a): 10 | if torch.is_tensor(a): 11 | return a.numpy() 12 | return a 13 | 14 | def outer(x, y, z): 15 | # 得到三维张量,三维分别表示xyz 16 | return np.einsum('i,j,k->ijk', x, y, z, dtype=np.int32, casting="same_kind") 17 | 18 | def action2tensor(action): 19 | u, v, w = action 20 | return np.einsum('i,j,k->ijk', u, v, w, dtype=np.int32, casting="same_kind") 21 | 22 | def is_zero_tensor(tensor): 23 | return np.all(tensor == 0) 24 | 25 | def is_equal(a, b): 26 | assert a.shape == b.shape 27 | return np.all((a - b) == 0) 28 | 29 | def canonicalize_action(action): 30 | u, v, w = action 31 | flag_u, flag_v = 1, 1 32 | for e in u: 33 | if e != 0: 34 | flag_u = ((e > 0) * 2 - 1) 35 | u = (u * flag_u).astype(np.int32) 36 | break 37 | for e in v: 38 | if e != 0: 39 | flag_v = ((e > 0) * 2 - 1) 40 | v = (v * flag_v).astype(np.int32) 41 | break 42 | w = (w * flag_v * flag_u).astype(np.int32) 43 | return np.stack([u, v, w]) 44 | 45 | def one_hot(a_s, num_classes, shift=False): 46 | ''' 47 | Note: We return a size of num_classes+1 array. 48 | ''' 49 | if len(a_s.shape) == 1: 50 | result = torch.zeros((a_s.shape[0], num_classes+1)).long() 51 | for idx, a in enumerate(a_s): 52 | if a == -2: 53 | continue 54 | result[idx, a] = 1 55 | if shift: # Append SOS. 56 | result = torch.cat([torch.zeros((1, num_classes+1)).long(), result], dim=0) 57 | result[0, -1] = 1 58 | return result 59 | elif len(a_s.shape) == 2: 60 | result = torch.zeros((a_s.shape[0], a_s.shape[1], num_classes+1)).long() 61 | for batch, a_batch in enumerate(a_s): 62 | for idx, a in enumerate(a_batch): 63 | if a == -2: 64 | continue 65 | result[batch, idx, a] = 1 66 | if shift: # Append SOS. 
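# Encoding summary for one_hot: the result has num_classes+1 columns; a value of -1 lands in the last column
# (the start-token slot, via negative indexing), -2 leaves an all-zero row (padding), and shift=True prepends
# an explicit start-token row so the targets are shifted one step for teacher forcing.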
67 | result = torch.cat([torch.zeros((a_s.shape[0], 1, num_classes+1)).long(), result], dim=1) 68 | result[:, 0, -1] = 1 69 | return result 70 | 71 | def change_basis_tensor(tensor, 72 | trans_mat): 73 | return np.einsum('ij, kl, mn, jln -> ikm', 74 | trans_mat, trans_mat, trans_mat, tensor, casting="same_kind", dtype=np.int32) 75 | 76 | def random_action(coefficients=[0,1,-1], 77 | prob=[.8,.1,.1], 78 | S_size=4): 79 | ct = 0 80 | while True: 81 | u = np.random.choice(coefficients, size=(S_size,), p=prob, replace=True) 82 | v = np.random.choice(coefficients, size=(S_size,), p=prob, replace=True) 83 | w = np.random.choice(coefficients, size=(S_size,), p=prob, replace=True) 84 | ct += 1 85 | if not is_zero_tensor(outer(u, v, w)): 86 | break 87 | if ct > 100000: 88 | raise Exception("Oh my god...") 89 | return np.stack([u, v, w], axis=0) 90 | 91 | def terminate_rank_approx(tensor): 92 | assert len(tensor.shape) == 3 93 | rank_approx = 0 94 | for z_idx in range(tensor.shape[-1]): 95 | rank_approx += np.linalg.matrix_rank(np.mat(tensor[..., z_idx], dtype=np.int32)) 96 | 97 | return rank_approx -------------------------------------------------------------------------------- /config/S_4.yaml: -------------------------------------------------------------------------------- 1 | # base: 2 | # S_size: 4 3 | # T: 5 4 | # coefficients: 5 | # - 0 6 | # - 1 7 | # - -1 8 | # exp_dir: "exp" 9 | # exp_name: "debug" 10 | # device: "cuda" 11 | 12 | trainer: 13 | T: 7 14 | S_size: 4 15 | coefficients: 16 | - 0 17 | - 1 18 | - -1 19 | exp_dir: "exp" 20 | exp_name: "S4T7_selfplay" 21 | device: "cuda:0" 22 | self_play_device: "cuda:1" 23 | lr: 0.0001 24 | weight_decay: 0.00001 25 | gamma: .1 26 | step_size: 10000000 27 | iters_n: 3000000 28 | batch_size: 1024 29 | a_weight: .1 30 | v_weight: .9 31 | save_freq: 25000 32 | temp_save_freq: 2500 33 | self_play_buffer: 10000 34 | self_play_freq: 25 35 | grad_clip: 4.0 36 | val_freq: 1000 37 | 38 | net: 39 | T: 7 40 | S_size: 4 41 | coefficients: 42 | - 0 43 | - 1 44 | - -1 45 | device: 'cuda:0' 46 | 47 | N_steps: 3 48 | N_samples: 32 49 | n_attentive: 6 50 | N_heads: 32 51 | N_features: 64 52 | channel: 3 53 | scalar_size: 3 54 | policy_layers: 2 55 | value_layers: 3 56 | inter_channel: 512 57 | out_channel: 8 58 | 59 | mcts: 60 | simulate_times: 400 61 | R_limit: 8 62 | 63 | env: 64 | R_limit: 8 65 | T: 7 66 | S_size: 4 67 | -------------------------------------------------------------------------------- /config/S_4_remote.yaml: -------------------------------------------------------------------------------- 1 | base: &base 2 | S_size: 4 3 | T: 7 4 | coefficients: 5 | - 0 6 | - 1 7 | - -1 8 | exp_dir: "exp" 9 | exp_name: "S4T7_selfplay" 10 | device: "cuda:0" 11 | self_play_device: "cuda:1" 12 | 13 | trainer: 14 | <<: *base 15 | lr: 0.0001 16 | weight_decay: 0.00001 17 | gamma: .1 18 | step_size: 10000000 19 | iters_n: 3000000 20 | batch_size: 1024 21 | a_weight: .1 22 | v_weight: .9 23 | save_freq: 25000 24 | temp_save_freq: 2500 25 | self_play_buffer: 10000 26 | self_play_freq: 25 27 | grad_clip: 4.0 28 | val_freq: 1000 29 | 30 | net: 31 | <<: *base 32 | N_steps: 3 33 | N_samples: 32 34 | n_attentive: 6 35 | N_heads: 32 36 | N_features: 64 37 | channel: 3 38 | scalar_size: 3 39 | policy_layers: 2 40 | value_layers: 3 41 | inter_channel: 512 42 | out_channel: 8 43 | 44 | mcts: 45 | <<: *base 46 | simulate_times: 400 47 | 48 | env: 49 | <<: *base 50 | R_limit: 8 51 | -------------------------------------------------------------------------------- /config/S_9.yaml: 
-------------------------------------------------------------------------------- 1 | # base: 2 | # S_size: 4 3 | # T: 5 4 | # coefficients: 5 | # - 0 6 | # - 1 7 | # - -1 8 | # exp_dir: "exp" 9 | # exp_name: "debug" 10 | # device: "cuda" 11 | 12 | trainer: 13 | T: 7 14 | S_size: 9 15 | coefficients: 16 | - 0 17 | - 1 18 | - -1 19 | exp_dir: "exp" 20 | exp_name: "first_exp" 21 | device: "cuda" 22 | 23 | lr: 0.0001 24 | weight_decay: 0.00001 25 | gamma: .1 26 | step_size: 10000000 27 | iters_n: 3000000 28 | batch_size: 2048 29 | a_weight: .5 30 | v_weight: .5 31 | save_freq: 50000 32 | self_play_buffer: 100000 33 | self_play_freq: 10 34 | grad_clip: 4.0 35 | val_freq: 1000 36 | 37 | net: 38 | T: 7 39 | S_size: 9 40 | coefficients: 41 | - 0 42 | - 1 43 | - -1 44 | device: 'cuda' 45 | 46 | N_steps: 12 47 | N_samples: 32 48 | n_attentive: 4 49 | N_heads: 16 50 | N_features: 32 51 | channel: 3 52 | scalar_size: 3 53 | policy_layers: 2 54 | value_layers: 3 55 | inter_channel: 256 56 | out_channel: 8 57 | 58 | mcts: 59 | simulate_times: 400 60 | 61 | env: 62 | R_limit: 30 63 | T: 7 64 | S_size: 9 65 | -------------------------------------------------------------------------------- /config/my_conf.yaml: -------------------------------------------------------------------------------- 1 | # base: 2 | # S_size: 4 3 | # T: 5 4 | # coefficients: 5 | # - 0 6 | # - 1 7 | # - -1 8 | # exp_dir: "exp" 9 | # exp_name: "debug" 10 | # device: "cuda" 11 | 12 | trainer: 13 | T: 5 14 | S_size: 4 15 | coefficients: 16 | - 0 17 | - 1 18 | - -1 19 | exp_dir: "exp" 20 | exp_name: "first_exp" 21 | device: "cuda" 22 | 23 | lr: 0.0001 24 | weight_decay: 0.00001 25 | gamma: .1 26 | step_size: 10000000 27 | iters_n: 3000000 28 | batch_size: 2048 29 | a_weight: .5 30 | v_weight: .5 31 | save_freq: 50000 32 | self_play_buffer: 100000 33 | self_play_freq: 10 34 | grad_clip: 4.0 35 | val_freq: 1000 36 | 37 | net: 38 | T: 5 39 | S_size: 4 40 | coefficients: 41 | - 0 42 | - 1 43 | - -1 44 | device: 'cuda' 45 | 46 | N_steps: 3 47 | N_samples: 32 48 | n_attentive: 4 49 | N_heads: 16 50 | N_features: 16 51 | channel: 3 52 | scalar_size: 3 53 | policy_layers: 2 54 | value_layers: 3 55 | inter_channel: 256 56 | out_channel: 8 57 | 58 | mcts: 59 | simulate_times: 400 60 | 61 | env: 62 | R_limit: 12 63 | T: 5 64 | S_size: 4 65 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | import argparse 3 | 4 | from codes.env import Environment 5 | from codes.mcts import MCTS 6 | from codes.net import Net 7 | from codes.trainer import Trainer, Player 8 | 9 | def parse(): 10 | parser = argparse.ArgumentParser(description="OpenTensor") 11 | parser.add_argument('--config', type=str, default="./config/S_4.yaml") 12 | parser.add_argument('--mode', type=str, default="train", help="three modes: [generate_data, train, infer]") 13 | parser.add_argument('--resume', default=None, help="resume ckpt path") 14 | parser.add_argument('--run_dir', default="./exp/S4T7_selfplay/1685590684", help="The run dir to infer") 15 | args = parser.parse_args() 16 | return args 17 | 18 | 19 | if __name__ == '__main__': 20 | 21 | args = parse() 22 | conf_path = args.config 23 | mode = args.mode 24 | resume = args.resume 25 | 26 | with open(conf_path, 'r', encoding="utf-8") as f: 27 | kwargs = yaml.load(f.read(), Loader=yaml.FullLoader) 28 | 29 | net = Net(**kwargs["net"]) 30 | mcts = MCTS(**kwargs["mcts"], 31 | init_state=None) 32 | env 
= Environment(**kwargs["env"], 33 | init_state=None) 34 | trainer = Trainer(**kwargs["trainer"], 35 | net=net, env=env, mcts=mcts, 36 | all_kwargs=kwargs) 37 | 38 | S_size = kwargs["env"]["S_size"] 39 | T = kwargs["env"]["T"] 40 | if mode == "generate_data": 41 | trainer.generate_synthetic_examples(samples_n=100000, 42 | save_path="./data/100000_S%dT%d_scalar3_filtered.npy" % (S_size, T)) 43 | 44 | elif mode == "train": 45 | trainer.learn(resume=resume, 46 | example_path="./data/100000_S%dT%d_scalar3_filtered.npy" % (S_size, T), 47 | self_example_path=None) 48 | 49 | elif mode == "infer": 50 | self_play_net = Net(**kwargs["net"]) 51 | player = Player(net=self_play_net, 52 | env=env, 53 | mcts=mcts, 54 | exp_dir=args.run_dir, 55 | simu_times=800, 56 | play_times=1, 57 | num_workers=64, 58 | device="cuda:0", 59 | noise=True) 60 | player.run() # Running forever... -------------------------------------------------------------------------------- /record.txt: -------------------------------------------------------------------------------- 1 | 1. Formal representation of the network's action prediction 2 | 2. Feed an all-zero tensor at prediction time 3 | 3. Formal representation of MCTS 4 | 4. Which scalars to choose as network input 5 | 5. Should MCTS have no upper limit? 6 | 7 | need to check: 8 | 1. Causal mask? 9 | 2. Normalization bug (resolved) 10 | 3. Free MCTS memory? (GPU memory shrinks the longer it runs) (by deleting the other child nodes?) (also seems resolved) 11 | 4. examples reformulate 12 | 5. Change of basis and random sign permutation 13 | 14 | To do: 15 | 1. Change of basis 16 | 2. random permutation 17 | 3. data aug 18 | 4. self-play train (data format) 19 | 5. dataloader (skipped for now) 20 | 6. config files (the simple part is done) 21 | *7. Batched sampling at infer time? (seems resolved) 22 | 8. How should the scalars be fed in? 23 | 9. Run n trees in parallel? (multithreading) 24 | 10. transition table 25 | 11. Exploration-factor temperature 26 | 27 | 28 | 4/7: 29 | 1. Revised the attentive module (in parallel) 30 | 2. Revised the file layout 31 | 3. Added a dataloader (no obvious speedup) 32 | Everything related to data formulation goes into dataset! 33 | 4. Finished random sign permutation (needs debugging) 34 | 4/9: 35 | 1. Added val 36 | 2. Guess: infer is where the problem is -- should the causal mask be checked?? the sampling?? the position embedding?? 37 | 3. Check the code against the transformer 38 | 4/10: 39 | 1. Added the start token 40 | (N_steps+1, N_logits+1; the last logit denotes the start token; training computes the loss over all N_steps+1 positions; infer only samples non-start tokens) 41 | 2. Fixed a network bug (dropout) 42 | 3. Difference of dropout between train and infer? 43 | 4/11: 44 | 1. Revised LayerNorm 45 | 2. At infer time, the first token is the start token and the rest are padded with zeros 46 | 3. Changed the causal mask 47 | 4. The loss drops too fast... is g_action leaking again? 48 | 5. Removed the bias from some linear layers -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | --extra-index-url https://pypi.ngc.nvidia.com 2 | 3 | nvidia-cuda-runtime-cu11 4 | nvidia-cuda-cupti-cu11 5 | nvidia-cuda-nvcc-cu11 6 | nvidia-nvml-dev-cu11 7 | nvidia-cuda-nvrtc-cu11 8 | nvidia-nvtx-cu11 9 | nvidia-cuda-sanitizer-api-cu11 10 | nvidia-cublas-cu11 11 | nvidia-cufft-cu11 12 | nvidia-curand-cu11 13 | nvidia-cusolver-cu11 14 | nvidia-cusparse-cu11 15 | nvidia-npp-cu11 16 | nvidia-nvjpeg-cu11 --------------------------------------------------------------------------------
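A quick sanity check for a decomposition produced in infer mode: the discovered factors (u, v, w) must sum, as rank-1 outer products, to the tensor the run started from. Below is a minimal sketch using the repo's `outer` utility from `codes.utils`; the helper name `verify_decomposition` is illustrative, not part of the codebase.

```python
import numpy as np
from codes.utils import outer  # repo utility: np.einsum('i,j,k->ijk', u, v, w)

def verify_decomposition(target, actions):
    # target: [S, S, S] integer tensor; actions: iterable of [3, S] arrays (u, v, w).
    recon = np.zeros_like(target)
    for u, v, w in actions:
        recon = recon + outer(u, v, w)  # Accumulate the rank-1 terms.
    return np.array_equal(recon, target)
```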