├── .gitignore
├── .idea
│   ├── .gitignore
│   ├── code.iml
│   ├── deployment.xml
│   ├── inspectionProfiles
│   │   └── profiles_settings.xml
│   ├── misc.xml
│   └── modules.xml
├── README.md
├── __init__.py
├── codes
│   ├── dataset
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-38.pyc
│   │   │   ├── __init__.cpython-39.pyc
│   │   │   ├── dataloader.cpython-38.pyc
│   │   │   ├── dataloader.cpython-39.pyc
│   │   │   ├── dataset.cpython-38.pyc
│   │   │   └── dataset.cpython-39.pyc
│   │   ├── dataloader.py
│   │   └── dataset.py
│   ├── env
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-38.pyc
│   │   │   ├── __init__.cpython-39.pyc
│   │   │   ├── environment.cpython-38.pyc
│   │   │   └── environment.cpython-39.pyc
│   │   ├── env_wrapper.py
│   │   └── environment.py
│   ├── mcts
│   │   ├── MCTS.py
│   │   ├── __init__.py
│   │   └── __pycache__
│   │       ├── MCTS.cpython-38.pyc
│   │       ├── MCTS.cpython-39.pyc
│   │       ├── __init__.cpython-38.pyc
│   │       └── __init__.cpython-39.pyc
│   ├── multi_runner
│   │   ├── MCTSF.py
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── MCTSF.cpython-39.pyc
│   │   │   ├── __init__.cpython-39.pyc
│   │   │   └── envs.cpython-39.pyc
│   │   └── envs.py
│   ├── net
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-38.pyc
│   │   │   ├── __init__.cpython-39.pyc
│   │   │   ├── network.cpython-38.pyc
│   │   │   └── network.cpython-39.pyc
│   │   └── network.py
│   ├── scripts
│   │   └── check_redundancy.py
│   ├── trainer
│   │   ├── Player.py
│   │   ├── Trainer.py
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── Player.cpython-39.pyc
│   │   │   ├── Trainer.cpython-38.pyc
│   │   │   ├── Trainer.cpython-39.pyc
│   │   │   ├── __init__.cpython-38.pyc
│   │   │   ├── __init__.cpython-39.pyc
│   │   │   ├── loss.cpython-38.pyc
│   │   │   └── loss.cpython-39.pyc
│   │   └── loss.py
│   └── utils
│       ├── __init__.py
│       ├── __pycache__
│       │   ├── __init__.cpython-38.pyc
│       │   ├── __init__.cpython-39.pyc
│       │   ├── util_functions.cpython-38.pyc
│       │   └── util_functions.cpython-39.pyc
│       └── util_functions.py
├── config
│   ├── S_4.yaml
│   ├── S_4_remote.yaml
│   ├── S_9.yaml
│   └── my_conf.yaml
├── main.py
├── record.txt
└── requirements.txt
/.gitignore:
--------------------------------------------------------------------------------
1 | /exp/
2 | /data/
--------------------------------------------------------------------------------
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 | # Datasource local storage ignored files
5 | /dataSources/
6 | /dataSources.local.xml
7 | # Editor-based HTTP Client requests
8 | /httpRequests/
9 |
--------------------------------------------------------------------------------
/.idea/code.iml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/deployment.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # OpenTensor
2 |
3 | This is the code implementation of our paper "OpenTensor: Reproducing Faster Matrix Multiplication Discovering Algorithms", presented at the 37th Conference on Neural Information Processing Systems (NeurIPS 2023) Workshop.
4 |
5 | We provide code to generate synthetic tensors, train OpenTensor, and perform tensor decomposition.
6 |
7 | ## Config
8 |
9 | All configuration options are collected in a YAML file. We provide some config templates in the `./config` folder. For example, `./config/S_4.yaml` is the config file for decomposing the $4 \times 4 \times 4$ matrix multiplication tensor, which is equivalent to discovering a $2 \times 2$ matrix multiplication algorithm.
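
For orientation, the keys of a config file roughly mirror the hyperparameters exposed by the code (e.g. `S_size`, `T`, `N_steps`, `coefficients`, `N_samples` in `codes/net/network.py`, and `R_limit` in `codes/env/environment.py`). The sketch below is illustrative only; the key names are assumptions based on those constructor arguments, so consult the provided templates for the actual schema:

```yaml
# Illustrative sketch, not the actual template schema.
S_size: 4              # side length of the S x S x S target tensor
T: 7                   # history length fed to the network
N_steps: 6             # tokens per action
coefficients: [0, 1, -1]
N_samples: 32          # action samples drawn by the policy head
R_limit: 12            # step limit per episode
```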
10 |
11 | ## Generating synthetic data
12 |
13 | ```
14 | mkdir data
15 | python main.py --config ./config/S_4.yaml --mode generate_data
16 | ```
17 |
18 | This command generates 100000 synthetic tensors and saves them to the `./data` folder.
19 |
20 | ## Training OpenTensor
21 |
22 | ```
23 | mkdir exp
24 | python main.py --config ./config/S_4.yaml --mode train
25 | ```
26 |
27 | The model parameters and the tensorboard log files are all saved in the subfolders of `./exp`.
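
Training can be monitored by pointing TensorBoard at the experiment folder, e.g.:

```
tensorboard --logdir ./exp
```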
28 |
29 | ## Testing OpenTensor
30 |
31 | ```
32 | python main.py --config ./config/S_4.yaml --mode infer --run_dir $run_dir
33 | ```
34 |
35 | where `$run_dir` is a subfolder of `./exp` that contains the model parameters of OpenTensor. This command discovers a decomposition of the matrix multiplication tensor with the OpenTensor model.
36 |
37 |
38 | ## Citing us
39 |
40 | If our work has been helpful to you, please feel free to cite us:
41 |
42 | ```latex
43 | @article{sun2024opentensor,
44 | title={OpenTensor: Reproducing Faster Matrix Multiplication Discovering Algorithms},
45 | author={Sun, Yiwen and Li, Wenye},
46 | journal={arXiv preprint arXiv:2405.20748},
47 | year={2024}
48 | }
49 | ```
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YiwenAI/OpenTensor/9727d20732ceed34eac108c76e23316df32c4981/__init__.py
--------------------------------------------------------------------------------
/codes/dataset/__init__.py:
--------------------------------------------------------------------------------
1 | from codes.dataset.dataset import *
2 | from codes.dataset.dataloader import *
--------------------------------------------------------------------------------
/codes/dataset/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YiwenAI/OpenTensor/9727d20732ceed34eac108c76e23316df32c4981/codes/dataset/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/codes/dataset/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YiwenAI/OpenTensor/9727d20732ceed34eac108c76e23316df32c4981/codes/dataset/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/codes/dataset/__pycache__/dataloader.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YiwenAI/OpenTensor/9727d20732ceed34eac108c76e23316df32c4981/codes/dataset/__pycache__/dataloader.cpython-38.pyc
--------------------------------------------------------------------------------
/codes/dataset/__pycache__/dataloader.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YiwenAI/OpenTensor/9727d20732ceed34eac108c76e23316df32c4981/codes/dataset/__pycache__/dataloader.cpython-39.pyc
--------------------------------------------------------------------------------
/codes/dataset/__pycache__/dataset.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YiwenAI/OpenTensor/9727d20732ceed34eac108c76e23316df32c4981/codes/dataset/__pycache__/dataset.cpython-38.pyc
--------------------------------------------------------------------------------
/codes/dataset/__pycache__/dataset.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YiwenAI/OpenTensor/9727d20732ceed34eac108c76e23316df32c4981/codes/dataset/__pycache__/dataset.cpython-39.pyc
--------------------------------------------------------------------------------
/codes/dataset/dataloader.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.utils.data
3 |
4 |
5 | class MultiEpochsDataLoader(torch.utils.data.DataLoader):
6 |
7 | def __init__(self, *args, **kwargs):
8 | super().__init__(*args, **kwargs)
9 | self._DataLoader__initialized = False
10 | self.batch_sampler = _RepeatSampler(self.batch_sampler)
11 | self._DataLoader__initialized = True
12 | self.iterator = super().__iter__()
13 |
14 | def __len__(self):
15 | return len(self.batch_sampler.sampler)
16 |
17 | def __iter__(self):
18 | for i in range(len(self)):
19 | yield next(self.iterator)
20 |
21 |
22 | class _RepeatSampler(object):
23 | """ Sampler that repeats forever.
24 | Args:
25 | sampler (Sampler)
26 | """
27 |
28 | def __init__(self, sampler):
29 | self.sampler = sampler
30 |
31 | def __iter__(self):
32 | while True:
33 | # for i in range(len(self.sampler)):
34 | yield from iter(self.sampler)
--------------------------------------------------------------------------------
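
A minimal usage sketch for `MultiEpochsDataLoader` (the toy dataset is a stand-in, not part of this repo): because the wrapped batch sampler repeats forever, the same underlying iterator, and hence the same worker processes, are reused across epochs.

```python
# Sketch only: demonstrates the epoch-persistent iterator.
import torch
from torch.utils.data import TensorDataset

toy = TensorDataset(torch.arange(12).float().unsqueeze(1))
loader = MultiEpochsDataLoader(toy, batch_size=4, shuffle=True)

for epoch in range(2):        # both epochs draw from loader.iterator
    for (batch,) in loader:   # 3 batches of shape [4, 1] per epoch
        print(epoch, batch.shape)
```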
/codes/dataset/dataset.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from torch.utils.data import Dataset
4 | from tqdm import tqdm
5 |
6 | import sys
7 | import os
8 | import copy
9 | import itertools
10 | sys.path.append(os.path.abspath(os.path.join("..")))
11 | sys.path.append(os.path.abspath(os.path.join(".")))
12 | from codes.utils import *
13 |
14 | class TupleDataset(Dataset):
15 |
16 | def __init__(self,
17 | T,
18 | S_size,
19 | N_steps,
20 | coefficients,
21 | self_data=[],
22 | synthetic_data=[],
23 | debug=False,
24 | save_type="traj"):
25 | # examples: A list of episodes, including:
26 | # 1. state (Network input)
27 | # 1.1 tensors (np)
28 | # 1.2 scalars (np)
29 | # 2. action (np)
30 | # 3. reward
31 |
32 | self.T = T
33 | self.S_size = S_size
34 | self.N_steps = N_steps
35 | self.coefficients = coefficients
36 | token_len = 3 * S_size // N_steps
37 | self.N_logits = len(coefficients) ** token_len
38 | self.ct = 0
39 | self.save_type = save_type
40 |
41 | print("Preprocessing dataset...")
42 | self.self_data = self_data
43 | self.synthetic_data = synthetic_data
44 | self.data = self_data + synthetic_data
45 | self.data_iterer = itertools.cycle(self.data)
46 | self.self_examples = []
47 | self.synthetic_examples = []
48 |
49 | #TODO: Randomize sign permutation.
50 |         #TODO: Reformulate data format.
51 | # Canonicalize actions & to logits.
52 | if save_type == "tuple":
53 | for episode in tqdm(synthetic_data):
54 | state, action, reward = episode
55 | action = self.action_to_logits(canonicalize_action(action))
56 | self.synthetic_examples.append([state, action, reward])
57 | for episode in tqdm(self_data):
58 | state, action, reward = episode
59 | action = self.action_to_logits(canonicalize_action(action))
60 | self.self_examples.append([state, action, reward])
61 | self.examples = self.self_examples + self.synthetic_examples
62 |
63 | else: # Traj format data.
64 | self._prepare_examples_from_trajs()
65 |
66 | def _prepare_examples_from_trajs(self):
67 | '''
68 |         Permute self.xxx_data (without modifying it in place)
69 |         and build the corresponding examples.
70 | '''
71 | S_size = self.S_size
72 | T = self.T
73 |
74 | self_examples, synthetic_examples = [], []
75 |
76 | for traj in tqdm(self.self_data):
77 | # Reverse the order. From decompose order to synthesis order.
78 | _traj = copy.deepcopy(traj)
79 | _states, _actions, _rewards = _traj
80 | _states.reverse(), _actions.reverse(), _rewards.reverse()
81 | _traj = [_states, _actions, _rewards]
82 |
83 | new_traj = self.permutate_traj(_traj)
84 | self_examples.extend(self.traj_to_episode(new_traj))
85 |
86 | for traj in tqdm(self.synthetic_data):
87 | new_traj = self.permutate_traj(traj)
88 | synthetic_examples.extend(self.traj_to_episode(new_traj))
89 |
90 | self.self_examples, self.synthetic_examples = self_examples, synthetic_examples
91 | self.examples = self_examples + synthetic_examples
92 |
93 | def _permutate_traj(self, trajs_n=5000):
94 | assert self.save_type == "traj"
95 | print("Permutate!")
96 | for _ in range(trajs_n):
97 | self_traj = next(self.data_iterer)
98 | new_traj = self.permutate_traj(self_traj)
99 | new_episodes = self.traj_to_episode(new_traj)
100 | n = len(new_episodes)
101 | self.examples = self.examples[n:] + new_episodes
102 |
103 | def __len__(self):
104 | return len(self.examples)
105 |
106 | def __getitem__(self, idx):
107 | state, action, reward = self.examples[idx]
108 | tensor, scalar = state
109 | action = self.logits_to_action(action)
110 | tensor, action = self.random_sign_permutation(tensor, action) # Data aug.
111 | action = canonicalize_action(action) #FIXME: Is it needed?
112 | action = self.action_to_logits(action)
113 | # self._permutate_traj() # Permutate traj.
114 | return [tensor, scalar], action, reward
115 |
116 | def traj_to_episode(self, traj):
117 | results = []
118 | T, S_size = self.T, self.S_size
119 | states, actions, rewards = traj
120 | states = list(reversed(states)); actions = list(reversed(actions)); rewards = list(reversed(rewards))
121 | actions_tensor = [action2tensor(action) for action in actions]
122 | for idx, state in enumerate(states):
123 | tensors = np.zeros((T, S_size, S_size, S_size), dtype=np.int32)
124 | tensors[0] = state # state.
125 | if idx != 0:
126 | # History actions.
127 | tensors[1:(idx+1)] = np.stack(reversed(actions_tensor[max(idx-(T-1), 0):idx]), axis=0)
128 |             scalars = np.array([idx, idx, idx]) #FIXME: Haven't decided on the scalars.
129 |
130 | cur_state = [tensors, scalars]
131 | action = self.action_to_logits(canonicalize_action(actions[idx]))
132 | reward = rewards[idx]
133 | results.append([cur_state, action, reward])
134 | return results
135 |
136 | def permutate_traj(self, traj):
137 | S_size = self.S_size
138 | states, actions, rewards = traj # [T, S, S, S], [T, 3, S], [T], synthesis order.
139 |         final_state = states[0] - action2tensor(actions[0]) # For synthetic data, final_state = 0 and rewards[0] = -1.
140 |
141 | # Shuffle the traj.
142 | new_actions = actions.copy()
143 | np.random.shuffle(new_actions)
144 | new_states = []
145 | new_rewards = copy.deepcopy(rewards)
146 | # sample = np.zeros((S_size, S_size, S_size), dtype=np.int32)
147 | sample = final_state
148 | for action in new_actions:
149 | sample = sample + action2tensor(action)
150 | new_states.append(sample.copy())
151 | new_traj = [new_states, new_actions, new_rewards] # synthesis order.
152 | return new_traj
153 |
154 | def action_to_logits(self,
155 | action):
156 | '''
157 | action: A [3, S_size] array.
158 | '''
159 |
160 | # Break action into tokens.
161 | token_len = 3 * self.S_size // self.N_steps
162 | coefficients = self.coefficients
163 | action = action.reshape((-1, token_len)) # [N_steps, token_len]
164 |
165 | # Get logits.
166 | logits = [] # Start sign.
167 | for token in action: # Get one logit.
168 | # token = token.to_list()
169 | logit = 0
170 | if torch.is_tensor(token):
171 | token = torch.flip(token, dims=(0,))
172 | else:
173 | token = token[::-1]
174 | for idx, v in enumerate(token):
175 | logit += coefficients.index(v) * (len(coefficients) ** idx)
176 | logits.append(logit)
177 |
178 | return np.array(logits, dtype=np.int32)
179 |
180 | def logits_to_action(self, logits):
181 | '''
182 | logit: N_steps values of {0, 1, ..., N_logits - 1}.
183 | e.g.:
184 | If:
185 | token_len = 2
186 | coefficients = [0, 1, -1]
187 | N_steps = 6
188 | Then:
189 | [0, 1, 2, 3, 4, 5] -> [0 0 | 0 1 | 0 -1 | 1 0 | 1 1 | 1 -1 ]
190 | '''
191 | token_len = 3 * self.S_size // self.N_steps
192 | coefficients = self.coefficients
193 | action = []
194 | for logit in logits: # Get one action
195 | token = []
196 | if logit == self.N_logits:
197 |                 raise ValueError("Start sign found in the middle of an action.")
198 | for _ in range(token_len): # Get one token
199 | idx = logit % len(coefficients)
200 | token.append(coefficients[idx])
201 | logit = logit // len(coefficients)
202 | token.reverse()
203 | action.extend(token)
204 |
205 | action = np.array(action, dtype=np.int32).reshape((3, -1))
206 | return action
207 |
208 | def random_sign_permutation(self,
209 | tensor, action):
210 | trans_1, trans_2, trans_3 = \
211 | (np.random.binomial(1, .5, self.S_size) * 2 - 1).astype(np.int32), \
212 | (np.random.binomial(1, .5, self.S_size) * 2 - 1).astype(np.int32), \
213 | (np.random.binomial(1, .5, self.S_size) * 2 - 1).astype(np.int32)
214 | tensor = np.einsum('i, j, k, bijk -> bijk', trans_1, trans_2, trans_3, tensor,
215 | dtype=np.int32)
216 | action = np.stack([action[0]*trans_1, action[1]*trans_2, action[2]*trans_3], axis=0)
217 | return tensor, action
218 |
219 |
220 | if __name__ == '__main__':
221 | dataset = TupleDataset(T=7,
222 | S_size=4,
223 | N_steps=6,
224 | coefficients=[0, 1, -1],
225 | synthetic_data=np.load("data/traj_data/100000_S4T7_scalar3.npy", allow_pickle=True).tolist(),
226 | debug=True)
227 | # from torch.utils.data import DataLoader
228 | # dataloader = DataLoader(dataset, batch_size=64, shuffle=True)
229 | # res = next(iter(dataloader))
230 | import pdb; pdb.set_trace()
--------------------------------------------------------------------------------
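
A small round-trip sketch of the logit encoding in `action_to_logits` / `logits_to_action`, using the S_4 setting (`S_size=4`, `N_steps=6`, `coefficients=[0, 1, -1]`, so `token_len = 2`); passing empty data lists skips preprocessing:

```python
# Each row of an action (u, v, w) is split into length-2 tokens, and each
# token is read as a base-3 number over the coefficient alphabet.
import numpy as np

ds = TupleDataset(T=7, S_size=4, N_steps=6, coefficients=[0, 1, -1],
                  self_data=[], synthetic_data=[])
action = np.array([[0, 1, 1, -1],
                   [-1, 0, 0, 1],
                   [1, 1, -1, 0]], dtype=np.int32)   # rows: u, v, w
logits = ds.action_to_logits(action)                 # six logits in {0..8}
assert np.array_equal(ds.logits_to_action(logits), action)
```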
/codes/env/__init__.py:
--------------------------------------------------------------------------------
1 | from codes.env.environment import *
--------------------------------------------------------------------------------
/codes/env/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YiwenAI/OpenTensor/9727d20732ceed34eac108c76e23316df32c4981/codes/env/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/codes/env/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YiwenAI/OpenTensor/9727d20732ceed34eac108c76e23316df32c4981/codes/env/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/codes/env/__pycache__/environment.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YiwenAI/OpenTensor/9727d20732ceed34eac108c76e23316df32c4981/codes/env/__pycache__/environment.cpython-38.pyc
--------------------------------------------------------------------------------
/codes/env/__pycache__/environment.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YiwenAI/OpenTensor/9727d20732ceed34eac108c76e23316df32c4981/codes/env/__pycache__/environment.cpython-39.pyc
--------------------------------------------------------------------------------
/codes/env/env_wrapper.py:
--------------------------------------------------------------------------------
1 | # vector env
--------------------------------------------------------------------------------
/codes/env/environment.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from math import sqrt
3 |
4 | import sys
5 | import os
6 | sys.path.append(os.path.abspath(os.path.join("..")))
7 | sys.path.append(os.path.abspath(os.path.join(".")))
8 | from codes.utils import *
9 |
10 | # TODO: Consider refactoring into the standard gym interface: init, reset, step.
11 | # TODO: Prefer fewer functions and more plain variables.
12 |
13 | class Environment():
14 | '''
15 |     Defines the game's actions, states and rewards.
16 |     state: np.ndarray, [4, 4, 4]
17 |     action: np.ndarray, [3, 4] (representing u, v, w)
18 |     results: [[s_1, a_1, pi_1, r_1], ...]
19 | 
20 |     Provided functionality:
21 |     play: run one game
22 |     generate synthetic tensors
23 |     other operations on states, actions and rewards
24 | '''
25 |
26 | def __init__(self,
27 | S_size,
28 | R_limit,
29 | init_state=None,
30 | T=7,
31 | **kwargs):
32 | '''
33 |         S_size: dimension of u, v, w
34 |         R_limit: upper limit on the number of game steps
35 |         '''
36 |         # Parameters
37 | self.S_size = S_size
38 | self.R_limit = R_limit
39 | self.T = T
40 |         # Environment variables
41 | if init_state is None:
42 | init_state = self.get_init_state(S_size)
43 | self.cur_state = init_state
44 | self.accumulate_reward = 0
45 | self.step_ct = 0
46 |
47 |         # History variables
48 | self.hist_actions = [np.zeros_like(self.cur_state) for _ in range(self.T-1)]
49 |
50 |
51 | def get_init_state(self,
52 | S_size,
53 | no_base_change=False):
54 | '''
55 |         Get an initial state: state
56 |         S_size: dimension of u, v, w
57 |         Returns: state.
58 | '''
59 | #####
60 |         # Note: basis-change data augmentation could be added here.
61 | #####
62 |
63 | def one_hot(idx):
64 | temp = np.zeros((S_size, ), dtype=np.int32)
65 | temp[idx] = 1
66 | return temp
67 |
68 | # 1. Get the original Matmul-Tensor.
69 | init_state = np.zeros((S_size, S_size, S_size), dtype=np.int32)
70 | n = round(sqrt(S_size))
71 |
72 |         for i in range(n): # Build the vectors in the natural basis
73 | for j in range(n):
74 | z_idx = i * n + j
75 | z = one_hot(z_idx) # C_{i,j} = c_{i*n + j}
76 | for k in range(n):
77 | x_idx = i * n + k # A_{i,k} = a_{i*n + k}
78 | y_idx = k * n + j # B_{k,j} = b_{k*n + j}
79 | x, y = one_hot(x_idx), one_hot(y_idx)
80 | init_state += outer(x, y, z)
81 |
82 | # 2. Change of Basis.
83 | #FIXME: We haven't applied "basis change" operation.
84 | # Randomly get a transform matrix.
85 | if not no_base_change:
86 | p0 = .985
87 | P, L = np.random.choice([0, 1, -1], size=(S_size, S_size), p=[p0, (1-p0)/2, (1-p0)/2], replace=True),\
88 | np.random.choice([0, 1, -1], size=(S_size, S_size), p=[p0, (1-p0)/2, (1-p0)/2], replace=True)
89 | for i in range(S_size):
90 | P[i, i] = np.random.choice([1, -1], size=(1,), p=[.5, .5])
91 | L[i, i] = np.random.choice([1, -1], size=(1,), p=[.5, .5])
92 | P, L = np.triu(P), np.tril(L)
93 | trans_mat = np.matmul(P, L)
94 |
95 | init_state = change_basis_tensor(tensor=init_state,
96 | trans_mat=trans_mat)
97 | # import pdb; pdb.set_trace()
98 |
99 | return init_state
100 |
101 | def step(self,
102 | action):
103 | '''
104 |         Apply the state transition, update the reward, and return whether the game has ended.
105 | '''
106 | u, v, w = action
107 | self.cur_state -= outer(u, v, w)
108 | self.accumulate_reward -= 1
109 | self.step_ct += 1
110 | self.hist_actions.append(action2tensor(action))
111 |         # Check for termination
112 | if self.is_terminate():
113 | return True
114 | if self.step_ct >= self.R_limit:
115 | self.accumulate_reward += self.terminate_reward()
116 | return True
117 | return False
118 |
119 | def terminate_reward(self):
120 | '''
121 |         Penalty received when the game is truncated.
122 |         Returns: reward
123 | '''
124 | state = self.cur_state
125 | # terminate_reward = 0
126 | # for z_idx in range(self.S_size):
127 | # terminate_reward -= np.linalg.matrix_rank(np.mat(state[..., z_idx], dtype=np.int32))
128 | terminate_reward = -terminate_rank_approx(state)
129 | return terminate_reward
130 |
131 | def is_terminate(self):
132 | '''
133 |         Check whether cur_state is zero.
134 |         Returns: bool
135 | '''
136 | return is_zero_tensor(self.cur_state)
137 |
138 | def reset(self,
139 | init_state=None,
140 | no_base_change=False):
141 | '''
142 |         Reset the environment.
143 | '''
144 | if init_state is None:
145 | init_state = self.get_init_state(self.S_size, no_base_change)
146 | self.cur_state = init_state
147 | self.accumulate_reward = 0
148 | self.step_ct = 0
149 | self.hist_actions = [np.zeros_like(self.cur_state) for _ in range(self.T-1)]
150 |
151 | def get_network_input(self):
152 | '''
153 |         Organize the variables into the network input format.
154 | '''
155 | T = self.T
156 | S_size = self.S_size
157 | hist_actions = self.hist_actions[-(T-1):]
158 | hist_actions.reverse()
159 | tensors = np.zeros((T, S_size, S_size, S_size), dtype=np.int32)
160 | tensors[0] = self.cur_state
161 | tensors[1:] = np.stack(hist_actions, axis=0)
162 |         scalars = np.array([self.step_ct, self.step_ct, self.step_ct]) #FIXME: Haven't decided on the scalars.
163 |
164 | return tensors, scalars
165 |
166 |
167 | if __name__ == '__main__':
168 | test_env = Environment(S_size=4,
169 | R_limit=8)
170 | test_action = np.array([
171 | [0, 0, 1, 0],
172 | [1, 1, 0, 0],
173 | [0, 1, 0, 0]
174 | ])
175 | for _ in range(8):
176 | print(test_env.step(test_action))
177 | print(test_env.accumulate_reward)
178 | import pdb; pdb.set_trace()
--------------------------------------------------------------------------------
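
As a sanity check, the tensor built by `get_init_state` (before the random change of basis) encodes matrix multiplication: contracting it with the flattened factors reproduces the product. A sketch, assuming `outer(u, v, w)` from `codes.utils` is the rank-1 product `np.einsum('i,j,k->ijk', u, v, w)` as its use above suggests:

```python
# For S_size = 4 (n = 2): c_z = sum_{x,y} T[x, y, z] * a_x * b_y,
# where a, b, c are the row-major flattenings of A, B and C = A @ B.
import numpy as np

env = Environment(S_size=4, R_limit=8)
T = env.get_init_state(S_size=4, no_base_change=True)
A, B = np.random.randn(2, 2), np.random.randn(2, 2)
C = np.einsum('xyz,x,y->z', T, A.flatten(), B.flatten()).reshape(2, 2)
assert np.allclose(C, A @ B)
```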
/codes/mcts/MCTS.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import math
4 | import networkx as nx
5 | import matplotlib.pyplot as plt
6 | from tqdm import tqdm
7 |
8 | import sys
9 | import os
10 | import copy
11 | from typing import Tuple
12 | sys.path.append(os.path.abspath(os.path.join("..")))
13 | sys.path.append(os.path.abspath(os.path.join(".")))
14 | from codes.env import Environment
15 | from codes.net import Net
16 | from codes.utils import *
17 |
18 |
19 | class Node():
20 | '''
21 |     A node of the MCTS tree.
22 | '''
23 |
24 | def __init__(self,
25 | state,
26 | parent,
27 | pre_action,
28 | pre_action_idx,
29 | is_terminal=False):
30 | #####
31 |         # Initialize a node here,
32 |         # including Q, N, and its children and parent.
33 | # e.g.:
34 | # self.Q = 0
35 | # ......
36 | #####
37 |
38 | self.parent = parent # parent: A Node instance (or None).
39 | self.pre_action = pre_action # pre_action: Action (or None).
40 | self.pre_action_idx = pre_action_idx
41 | self.is_leaf = True
42 | self.is_terminal = is_terminal
43 | self.state = state # state: Tensor.
44 |
45 | self.actions = [] # A list for actions.
46 | self.children = [] # A list for nodes.
47 | self.N = [] # A list for visit counts.
48 | self.Q = [] # A list for action value.
49 | self.pi = [] # A list for empirical policy probability.
50 | self.children_n = 0
51 |
52 | node = self
53 | depth = 0
54 | while node.parent:
55 | depth += 1
56 | node = node.parent
57 | self.depth = depth
58 |
59 |
60 | def expand(self,
61 | net: Net,
62 | noise=False,
63 | network_output=None,
64 | R_limit=12):
65 | '''
66 | Expand this node.
67 | Return the value of this state.
68 | '''
69 | # 1. Check terminal situation.
70 | if not self.is_leaf:
71 | raise Exception("This node has been expanded.")
72 | self.is_leaf = False
73 |
74 |         if self.is_terminal: # The state is terminal; only back-propagate the value.
75 | node = self
76 | node.is_leaf = True
77 | if is_zero_tensor(node.state):
78 | value = 0
79 | else:
80 | value = -1 * terminate_rank_approx(node.state)
81 | while node.parent is not None:
82 | action_idx = node.pre_action_idx
83 | node = node.parent
84 | node.N[action_idx] += 1
85 | v = (value + -1 * (self.depth - node.depth))
86 | node.Q[action_idx] = v / node.N[action_idx] +\
87 | node.Q[action_idx] * (node.N[action_idx] - 1) / node.N[action_idx]
88 |
89 | return
90 |
91 | #FIXME: Here we can apply a transposition table.
92 | # 2. Get network output.
93 | # 2.1. If use network to infer:
94 | if network_output is None:
95 | tensors, scalars = self.get_network_input(net)
96 |
97 | net.set_mode("infer")
98 | with torch.no_grad():
99 | output = net([tensors, scalars])
100 | _, value, policy, prob = *net.value(output), *net.policy(output) # policy: [1, N_samples, 3, S_size]
101 | del output, tensors, scalars
102 | value, policy = value[0], policy[0]
103 | policy = [canonicalize_action(action) for action in policy]
104 |
105 | # 2.2. If we already have network output:
106 | else:
107 | value, policy = network_output
108 |
109 | # 2.3. Add noise for root node's expand.
110 | if noise:
111 | noise_actions = [canonicalize_action(random_action()) for _ in range(len(policy) // 4)]
112 | policy = policy + noise_actions
113 |
114 | # 3. Get empirical policy probability.
115 | N_samples = net.N_samples
116 | rec = [False for _ in range(N_samples)] # "True" represents having been recorded.
117 | actions = []
118 | pi = []
119 | for pos in range(N_samples): # Naive loop.
120 | action = policy[pos]
121 | if not rec[pos]:
122 | # Count.
123 | actions.append(action)
124 | rec[pos] = True
125 | ct = 1
126 | for i in range(pos+1, N_samples):
127 | if rec[i]: # Have been counted.
128 | continue
129 | if is_equal(policy[i], action):
130 | ct += 1
131 | rec[i] = True
132 | pi.append(ct / N_samples)
133 | # import pdb; pdb.set_trace()
134 | self.actions = actions
135 | self.pi = pi
136 | self.children_n = len(actions)
137 |
138 | # 4. Init records.
139 | self.N = [0 for _ in range(len(actions))]
140 | self.Q = [0 for _ in range(len(actions))]
141 |
142 | # 5. Expand the children nodes.
143 | for idx, action in enumerate(actions):
144 | child_state = self.state - action2tensor(action)
145 | child_depth = self.depth + 1
146 | child_node = Node(state=child_state,
147 | parent=self,
148 | pre_action=action,
149 | pre_action_idx=idx,
150 | is_terminal=(is_zero_tensor(child_state) or child_depth >= R_limit))
151 | self.children.append(child_node)
152 |
153 | # 6. Backward propagate.
154 | node = self
155 | while node.parent is not None:
156 | action_idx = node.pre_action_idx
157 | node = node.parent
158 | node.N[action_idx] += 1
159 | v = (value + -1 * (self.depth - node.depth))
160 | node.Q[action_idx] = v / node.N[action_idx] +\
161 | node.Q[action_idx] * (node.N[action_idx] - 1) / node.N[action_idx]
162 |
163 |
164 | def select(self, c=None):
165 | '''
166 | Choose the best child.
167 | Return the chosen node.
168 | '''
169 | if self.is_leaf:
170 | raise Exception("Cannot choose a leaf node.")
171 |
172 | if c is None:
173 | c = 1.25 + math.log((1+19652+sum(self.N)) / 19652)
174 |
175 | scores = [self.Q[i] + c * self.pi[i] * math.sqrt(sum(self.N)) / (1 + self.N[i])
176 | for i in range(self.children_n)]
177 |
178 | return self.children[np.argmax(scores)], scores
179 |
180 |
181 | def get_network_input(self, net):
182 | # Get state for net evaluation.
183 | # State:
184 | # Tensors: [cur_state, last t=1 action, last t=2 action, ... last t=T-1 action]
185 | # Scalars: [depth(step_ct)]
186 | T = net.T
187 | tensors = np.zeros([T, *self.state.shape], dtype=np.int32)
188 | tensors[0] = self.state # Current state.
189 | node = self
190 | for t in range(1, T):
191 | if node.parent is None:
192 | break
193 | tensors[t] = action2tensor(node.pre_action)
194 | node = node.parent
195 |         scalars = np.array([self.depth, self.depth, self.depth]) #FIXME: Haven't decided on the scalars.
196 |
197 | return tensors, scalars
198 |
199 |
200 |
201 | class MCTS():
202 | '''
203 |     Monte Carlo tree search.
204 | '''
205 |
206 | def __init__(self,
207 | init_state,
208 | simulate_times=400,
209 | R_limit=12,
210 | **kwargs):
211 | '''
212 |         Pass in the hyperparameters.
213 | '''
214 |
215 | self.simulate_times = simulate_times
216 | self.R_limit = R_limit
217 | if init_state is not None:
218 | self.root_node = Node(state=init_state,
219 | parent=None,
220 | pre_action=None,
221 | pre_action_idx=None)
222 |
223 |
224 | def __call__(self,
225 | state,
226 | net: Net,
227 | log=False,
228 | verbose=False,
229 | noise=False):
230 | '''
231 |         Run one MCTS search.
232 |         Returns: action, actions, visit_pi
233 | '''
234 |
235 | assert is_equal(state, self.root_node.state), "State is inconsistent."
236 | iter_item = range(self.simulate_times) if verbose else tqdm(range(self.simulate_times))
237 | R_limit = self.R_limit
238 | for simu in iter_item:
239 | # Select a leaf node.
240 | node = self.root_node
241 | while not node.is_leaf:
242 | node, scores = node.select() #FIXME: Need to control the factor c.
243 | node.expand(net, noise=noise, R_limit=R_limit)
244 |
245 | actions = self.root_node.actions
246 | N = self.root_node.N
247 | visit_ratio = (np.array(N) / sum(N)).tolist()
248 | action = actions[np.argmax(visit_ratio)]
249 |
250 | if log:
251 | log_txt = self.log()
252 | return action, actions, visit_ratio, log_txt
253 |
254 | return action, actions, visit_ratio
255 |
256 |
257 | def move(self,
258 | action):
259 | '''
260 |         Advance the MCTS tree by one step.
261 | '''
262 | assert not self.root_node.is_leaf, "Cannot move a leaf node."
263 |
264 | # Get the action idx.
265 | action_idx = None
266 | for idx, child_action in enumerate(self.root_node.actions):
267 | if is_equal(child_action, action):
268 | action_idx = idx
269 |
270 | # Delete other children and move.
271 | self.root_node.children = [self.root_node.children[action_idx]]
272 | self.root_node = self.root_node.children[0]
273 |
274 |
275 | def reset(self,
276 | state,
277 | simulate_times=None,
278 | R_limit=None):
279 | '''
280 | Reset MCTS.
281 | '''
282 | if simulate_times is not None:
283 | self.simulate_times = simulate_times
284 | if R_limit is not None:
285 | self.R_limit = R_limit
286 | self.root_node = Node(state=state,
287 | parent=None,
288 | pre_action=None,
289 | pre_action_idx=None)
290 |
291 |
292 | def visualize(self):
293 | '''
294 | visualize the tree.
295 | '''
296 | # Create a graph.
297 | graph = nx.DiGraph()
298 | close_set = [self.root_node]
299 |
300 | while close_set != []:
301 | node = close_set.pop()
302 | if not node.is_leaf:
303 | [graph.add_edge(
304 | node,
305 | child
306 | ) for child in node.children]
307 | [close_set.append(child) for child in node.children]
308 |
309 | nx.draw(graph, with_labels=True, font_weight='bold')
310 | raise NotImplementedError
311 | plt.show()
312 |
313 |
314 | def log(self):
315 | '''
316 | Get the log text.
317 | '''
318 | node = self.root_node
319 | state_txt = str(node.state) # state.
320 | _, scores = node.select()
321 | N, Q, scores, children = np.array(node.N), np.array(node.Q), np.array(scores), np.array(node.actions)
322 | top_k_idx = np.argsort(N)[-5:]
323 | N, Q, scores, children = N[top_k_idx], Q[top_k_idx], scores[top_k_idx], children[top_k_idx]
324 |
325 | N_txt, Q_txt, scores_txt, children_txt = str(N), str(Q), str(scores), str(children)
326 |
327 | log_txt = "\n".join(
328 | ["\nCur state: \n", state_txt,
329 | "\nDepth: \n", str(node.depth),
330 | "\nchildren: \n", children_txt,
331 | "\nscores: \n", scores_txt,
332 | "\nQ: \n", Q_txt,
333 | "\nN: \n", N_txt,]
334 | )
335 |
336 | return log_txt
337 |
338 |
339 | if __name__ == '__main__':
340 |
341 | init_state = np.array(
342 | [[[1, 0, 0, 0],
343 | [0, 1, 0, 0],
344 | [0, 0, 0, 0],
345 | [0, 0, 0, 0]],
346 |
347 | [[0, 0, 0, 0],
348 | [0, 0, 0, 0],
349 | [1, 0, 0, 0],
350 | [0, 1, 0, 0]],
351 |
352 | [[0, 0, 1, 0],
353 | [0, 0, 0, 1],
354 | [0, 0, 0, 0],
355 | [0, 0, 0, 0]],
356 |
357 | [[0, 0, 0, 0],
358 | [0, 0, 0, 0],
359 | [0, 0, 1, 0],
360 | [0, 0, 0, 1]]])
361 | root_node = Node(state=init_state,
362 | parent=None,
363 | pre_action=None,
364 | pre_action_idx=None)
365 |
366 | from net import Net
367 | net = Net(N_samples=4) # For debugging.
368 |
369 | ############ Debug for Node ############
370 | # import pdb; pdb.set_trace()
371 | # root_node.expand(net)
372 | # children_node = root_node.select()
373 | # children_node.expand(net)
374 | # import pdb; pdb.set_trace()
375 |
376 |     ############ Debug for MCTS ############
377 | mcts = MCTS(init_state=init_state,
378 | simulate_times=20)
379 | import pdb; pdb.set_trace()
380 | action, actions, pi = mcts(init_state, net)
381 | import pdb; pdb.set_trace()
382 | mcts.move(action)
383 | state = init_state - action2tensor(action)
384 | import pdb; pdb.set_trace()
385 | action, actions, pi = mcts(state, net)
386 | import pdb; pdb.set_trace()
--------------------------------------------------------------------------------
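
For reference, `Node.select` above implements the AlphaZero-style PUCT rule: in the code's notation (visit counts `N`, action values `Q`, empirical policy `pi`), the selected child maximizes

```latex
\mathrm{score}(a) = Q(a) + c\,\pi(a)\,\frac{\sqrt{\sum_b N(b)}}{1 + N(a)},
\qquad
c = 1.25 + \log\frac{1 + 19652 + \sum_b N(b)}{19652}.
```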
/codes/mcts/__init__.py:
--------------------------------------------------------------------------------
1 | from codes.mcts.MCTS import *
--------------------------------------------------------------------------------
/codes/mcts/__pycache__/MCTS.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YiwenAI/OpenTensor/9727d20732ceed34eac108c76e23316df32c4981/codes/mcts/__pycache__/MCTS.cpython-38.pyc
--------------------------------------------------------------------------------
/codes/mcts/__pycache__/MCTS.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YiwenAI/OpenTensor/9727d20732ceed34eac108c76e23316df32c4981/codes/mcts/__pycache__/MCTS.cpython-39.pyc
--------------------------------------------------------------------------------
/codes/mcts/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YiwenAI/OpenTensor/9727d20732ceed34eac108c76e23316df32c4981/codes/mcts/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/codes/mcts/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YiwenAI/OpenTensor/9727d20732ceed34eac108c76e23316df32c4981/codes/mcts/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/codes/multi_runner/MCTSF.py:
--------------------------------------------------------------------------------
1 | # MCTSF: Monte Carlo Tree Search Forest.
2 | from codes.mcts import *
3 | from typing import List
4 |
5 | class MCTSF():
6 | '''
7 |     Batched simulation over multiple Monte Carlo trees.
8 | '''
9 |
10 | def __init__(self,
11 | mcts_list: List[MCTS],
12 | simulate_times=400):
13 |
14 | self.mcts_list = mcts_list
15 | self.mcts_n = len(mcts_list)
16 | self.simulate_times = simulate_times
17 | self.R_limit = mcts_list[0].R_limit
18 |
19 |
20 | def __call__(self,
21 | state_list,
22 | net: Net,
23 | noise=False):
24 |
25 | R_limit = self.R_limit
26 |
27 | # Check the states.
28 | for mcts, state in zip(self.mcts_list, state_list):
29 | assert is_equal(state, mcts.root_node.state), "State is inconsistent."
30 |
31 | # Simu.
32 | for simu in tqdm(range(self.simulate_times)):
33 | # Select the leaf nodes.
34 | node_list = []
35 | for mcts in self.mcts_list:
36 | node = mcts.root_node
37 | while not node.is_leaf:
38 | node, scores = node.select()
39 | node_list.append(node)
40 |
41 | # Get network input to expand...
42 | node: Node
43 | batch_tensors, batch_scalars = [], []
44 | for node in node_list:
45 | tensors, scalars = node.get_network_input(net)
46 | batch_tensors.append(tensors); batch_scalars.append(scalars)
47 | batch_tensors, batch_scalars = np.array(batch_tensors), np.array(batch_scalars)
48 | # Infer...
49 | net.set_mode("infer")
50 | with torch.no_grad():
51 | batch_output = net([batch_tensors, batch_scalars])
52 | _, batch_value, batch_policy, prob = *net.value(batch_output), *net.policy(batch_output)
53 | # batch_value: [B,] batch_policy: [B, N_samples, 3, S_size]
54 | del batch_output, batch_tensors, batch_scalars
55 | batch_policy = [[canonicalize_action(action) for action in policy] for policy in batch_policy]
56 |
57 | # Expand...
58 | for node, value, policy in zip(node_list, batch_value, batch_policy):
59 | node.expand(net, network_output=(value, policy), R_limit=R_limit, noise=noise)
60 |
61 | # Get results.
62 | actions_list = [mcts.root_node.actions for mcts in self.mcts_list]
63 | N_list = [mcts.root_node.N for mcts in self.mcts_list]
64 | visit_ratio_list = [(np.array(N) / sum(N)).tolist() for N in N_list]
65 | action_list = [actions_list[idx][np.argmax(visit_ratio_list[idx])] for idx in range(self.mcts_n)]
66 |
67 | return action_list, actions_list, visit_ratio_list
68 |
69 |
70 | def reset(self,
71 | state_list):
72 |
73 | for mcts, state in zip(self.mcts_list, state_list):
74 | mcts.reset(state, simulate_times=self.simulate_times)
75 |
76 |
77 | def move(self,
78 | action_list):
79 |
80 | for mcts, action in zip(self.mcts_list, action_list):
81 | mcts.move(action)
--------------------------------------------------------------------------------
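
A minimal wiring sketch for the forest (assumes an already-constructed `Net` instance `net` and a list of start tensors `states`, e.g. from `Environment.get_init_state`):

```python
# One search step over several trees: the forest batches all leaf
# evaluations into a single network call per simulation round.
forest = MCTSF([MCTS(init_state=s, simulate_times=100) for s in states],
               simulate_times=100)
action_list, actions_list, visit_ratio_list = forest(states, net)
forest.move(action_list)   # advance every tree along its chosen action
```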
/codes/multi_runner/__init__.py:
--------------------------------------------------------------------------------
1 | from codes.multi_runner.envs import *
2 | from codes.multi_runner.MCTSF import *
--------------------------------------------------------------------------------
/codes/multi_runner/__pycache__/MCTSF.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YiwenAI/OpenTensor/9727d20732ceed34eac108c76e23316df32c4981/codes/multi_runner/__pycache__/MCTSF.cpython-39.pyc
--------------------------------------------------------------------------------
/codes/multi_runner/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YiwenAI/OpenTensor/9727d20732ceed34eac108c76e23316df32c4981/codes/multi_runner/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/codes/multi_runner/__pycache__/envs.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YiwenAI/OpenTensor/9727d20732ceed34eac108c76e23316df32c4981/codes/multi_runner/__pycache__/envs.cpython-39.pyc
--------------------------------------------------------------------------------
/codes/multi_runner/envs.py:
--------------------------------------------------------------------------------
1 | from codes.env import *
2 | from typing import List
3 |
4 | class ENVS():
5 |
6 | def __init__(self,
7 | env_list: List[Environment]):
8 |
9 | self.env_list = env_list
10 | self.env_n = len(env_list)
11 | self.terminate_list = [False] * self.env_n
12 |
13 |
14 | def reset(self):
15 |
16 | for env in self.env_list:
17 | env.reset()
18 | self.terminate_list = [False] * self.env_n
19 |
20 |
21 | def step(self,
22 | action_list):
23 |
24 | for idx, (env, action) in enumerate(zip(self.env_list, action_list)):
25 | if not self.terminate_list[idx]:
26 | self.terminate_list[idx] = env.step(action)
27 |
28 |
29 | def get_curstates(self):
30 |
31 | state_list = []
32 | for env in self.env_list:
33 | state_list.append(env.cur_state.copy())
34 |
35 | return state_list
36 |
37 |
38 | def is_all_terminated(self):
39 |
40 | flag = True
41 | for sub_flag in self.terminate_list:
42 | flag = flag & sub_flag
43 |
44 | return flag
45 |
46 |
47 | def get_rewards(self):
48 |
49 | reward_list = []
50 | for env in self.env_list:
51 | reward_list.append(env.accumulate_reward)
52 |
53 | return reward_list
54 |
55 |
56 | def get_stepcts(self):
57 |
58 | step_ct_list = []
59 | for env in self.env_list:
60 | step_ct_list.append(env.step_ct)
61 |
62 | return step_ct_list
--------------------------------------------------------------------------------
/codes/net/__init__.py:
--------------------------------------------------------------------------------
1 | from codes.net.network import *
--------------------------------------------------------------------------------
/codes/net/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YiwenAI/OpenTensor/9727d20732ceed34eac108c76e23316df32c4981/codes/net/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/codes/net/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YiwenAI/OpenTensor/9727d20732ceed34eac108c76e23316df32c4981/codes/net/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/codes/net/__pycache__/network.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YiwenAI/OpenTensor/9727d20732ceed34eac108c76e23316df32c4981/codes/net/__pycache__/network.cpython-38.pyc
--------------------------------------------------------------------------------
/codes/net/__pycache__/network.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YiwenAI/OpenTensor/9727d20732ceed34eac108c76e23316df32c4981/codes/net/__pycache__/network.cpython-39.pyc
--------------------------------------------------------------------------------
/codes/net/network.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import math
6 |
7 | import sys
8 | import os
9 | sys.path.append(os.path.abspath(os.path.join("..")))
10 | sys.path.append(os.path.abspath(os.path.join(".")))
11 | from codes.env import *
12 | from codes.mcts import *
13 | from codes.utils import *
14 |
15 | # Input:
16 | # Tensors of shape of [B,T,S,S,S]. First one is current tensor.
17 | # Scalars of shape of [B,s].
18 |
19 |
20 | class Torso(nn.Module):
21 | '''
22 |     Network torso (encoder).
23 | '''
24 |
25 | def __init__(self,
26 | S_size=4,
27 | channel=3,
28 | T=7,
29 | scalar_size=3,
30 | n_attentive=8,
31 | mode="train",
32 | device='cuda',
33 | **kwargs):
34 | super(Torso, self).__init__()
35 | self.S_size = S_size
36 | self.channel = channel
37 | self.scalar_size = scalar_size
38 | self.T = T
39 | self.n_attentive = n_attentive
40 | self.mode = mode
41 | self.device = device
42 |
43 | self.attentive_modes = nn.ModuleList([AttentiveModes(S_size, channel) for _ in range(n_attentive)])
44 | self.scalar2grid = nn.ModuleList([nn.Linear(scalar_size, S_size**2) for _ in range(3)]) # s -> S*S
45 | self.grid2grid = nn.ModuleList([nn.Linear(S_size**2*(T*S_size+1), S_size**2*channel) for _ in range(3)]) # S*S*TS+1 -> S*S*c
46 | #WARNING: grid2grid should be T*S_size+1 -> channel!
47 |
48 |
49 | def forward(self, x):
50 |
51 | # Input:
52 | # Tensors of shape of [B,T,S,S,S]. First one is current tensor. (numpy)
53 | # Scalars of shape of [B,s]. (numpy)
54 |
55 | S_size = self.S_size
56 | T = self.T
57 | channel = self.channel
58 | n_attentive = self.n_attentive
59 |
60 | input_t, input_s = x # Tensor input and Scalar input.
61 | if not torch.is_tensor(input_t):
62 | input_t, input_s = torch.from_numpy(input_t).float().to(self.device), torch.from_numpy(input_s).float().to(self.device)
63 | else:
64 | input_t, input_s = input_t.float().to(self.device), input_s.float().to(self.device)
65 | if self.mode == "infer":
66 | # assert len(input_t.shape) == 4 and len(input_s.shape) == 1, \
67 | # "Infer mode does not support batch."
68 | if len(input_t.shape) == 4 and len(input_s.shape) == 1:
69 | input_t = input_t[None]; input_s = input_s[None] # Add a batch dim.
70 | batch_size = input_t.shape[0]
71 |
72 | # 1. Project to grids.
73 | x1 = torch.reshape(torch.permute(input_t, (0,2,3,4,1)), (batch_size, S_size, S_size, T*S_size)) # [B,S,S,TS]
74 | x2 = torch.reshape(torch.permute(input_t, (0,4,2,3,1)), (batch_size, S_size, S_size, T*S_size))
75 | x3 = torch.reshape(torch.permute(input_t, (0,3,4,2,1)), (batch_size, S_size, S_size, T*S_size))
76 | g = [x1, x2, x3]
77 |
78 | # 2. To grids.
79 | for idx in range(3):
80 | p = torch.reshape(self.scalar2grid[idx](input_s), (batch_size, S_size, S_size, 1))
81 | g[idx] = torch.concat([g[idx], p], dim=-1)
82 | g[idx] = torch.reshape(self.grid2grid[idx](torch.reshape(g[idx], (batch_size, -1))), (batch_size, S_size, S_size, channel)) # [B,S,S,c]
83 |
84 | # 3. Attentive modes.
85 | x1, x2, x3 = g
86 | for idx in range(n_attentive):
87 | x1, x2, x3 = self.attentive_modes[idx]([x1, x2, x3])
88 |
89 | # 4. Final stack.
90 | e = torch.reshape(torch.stack([x1, x2, x3], axis=1), (batch_size, 3*S_size**2, channel)) # [B, 3*S**2, c]
91 |
92 | return e
93 |
94 |
95 | def set_mode(self, mode):
96 | assert mode in ["train", "infer"]
97 | self.mode = mode
98 |
99 |
100 | class AttentiveModes(nn.Module):
101 |
102 | '''
103 |     Open question:
104 |     should the attention modules share parameters across the forward pass?
105 | '''
106 |
107 | def __init__(self,
108 | S_size=4,
109 | channel=3,
110 | device='cuda'):
111 | super(AttentiveModes, self).__init__()
112 | self.channel = channel
113 | self.S_size = S_size
114 |
115 | # self.attentions = nn.ModuleList(*[Attention(channel,
116 | # channel,
117 | # 2*S_size,
118 | # 2*S_size,
119 | # False,
120 | # device=device) for _ in range(S_size)])
121 |
122 | self.attention = Attention(channel,
123 | channel,
124 | 2*S_size,
125 | 2*S_size,
126 | False, device=device)
127 |
128 | def forward(self, x):
129 |
130 | # Input:
131 | # [x1, x2, x3]. Each of them is shaped of [B, S, S, c]
132 |
133 | S_size = self.S_size
134 | channel = self.channel
135 |
136 | for m1, m2 in [(0,1), (2,0), (1,2)]:
137 | a = torch.cat([x[m1], x[m2].transpose(1,2)], axis=2) # a: [B, S, 2S, c]
138 | # for idx in range(S_size): #FIXME: Parallel loop?
139 | # c = self.attentions[idx]([a[:,idx],])
140 | # x[m1][:,idx] = c[:, :S_size, :]
141 | # x[m2][:,idx] = c[:, S_size:, :]
142 |
143 | a = a.reshape((-1, 2*S_size, channel))
144 | c = self.attention([a, ]) # c: [B*S, 2S, c]
145 | c = c.reshape((-1, S_size, 2*S_size, channel)) # c: [B, S, 2S, c]
146 | x[m1] = c[:, :, :S_size, :]
147 | x[m2] = c[:, :, S_size:, :]
148 |
149 | return x
150 |
151 |
152 | class Attention(nn.Module):
153 |
154 | def __init__(self,
155 | x_channel=3,
156 | y_channel=3,
157 | N_x=8, # 2S
158 | N_y=8, # 2S
159 | causal_mask=False,
160 | N_heads=16,
161 | d=32,
162 | w=4,
163 | device='cuda'):
164 |
165 | super(Attention, self).__init__()
166 | self.x_channel, self.y_channel = x_channel, y_channel
167 | self.N_x, self.N_y = N_x, N_y
168 | self.causal_mask = causal_mask
169 | self.N_heads = N_heads
170 | self.d, self.w = d, w
171 | self.device = device
172 |
173 | self.x_layer_norm = nn.LayerNorm(x_channel)
174 | self.y_layer_norm = nn.LayerNorm(y_channel)
175 | self.final_layer_norm = nn.LayerNorm(x_channel)
176 | self.W_Q = nn.Linear(x_channel, d * N_heads, bias=False)
177 | self.W_K = nn.Linear(y_channel, d * N_heads, bias=False)
178 | self.W_V = nn.Linear(y_channel, d * N_heads, bias=False)
179 | self.linear_1 = nn.Linear(d * N_heads, x_channel)
180 | self.linear_2 = nn.Linear(x_channel, x_channel * w)
181 | self.linear_3 = nn.Linear(x_channel * w, x_channel)
182 | self.gelu = nn.GELU()
183 |
184 |
185 | def forward(self, x):
186 |
187 | # Input:
188 | # [x, (y)]. If y is missed, y=x.
189 |
190 | N_heads = self.N_heads
191 | N_x, N_y = self.N_x, self.N_y
192 | d, w = self.d, self.w
193 |
194 | if len(x) == 1:
195 | x = x[0]
196 | y = x.clone()
197 | else:
198 | x, y = x
199 |
200 | batch_size = x.shape[0]
201 |
202 | x_norm = self.x_layer_norm(x) # [B, N_x, c_x]
203 | y_norm = self.y_layer_norm(y) # [B, N_y, c_y]
204 |
205 | q_s = self.W_Q(x_norm).view(batch_size, -1, N_heads, d).transpose(1, 2) # [batch_size, N_heads, N_x, d]
206 | k_s = self.W_K(y_norm).view(batch_size, -1, N_heads, d).transpose(1, 2) # [batch_size, N_heads, N_y, d]
207 | v_s = self.W_V(y_norm).view(batch_size, -1, N_heads, d).transpose(1, 2) # [batch_size, N_heads, N_y, d]
208 |
209 | scores = torch.matmul(q_s, k_s.transpose(-1, -2)) / np.sqrt(d) # [batch_size, N_heads, N_x, N_y]
210 | # scores = F.softmax(scores, dim=-1)
211 | if self.causal_mask:
212 | mask = (torch.from_numpy(np.triu(np.ones([batch_size, N_heads, N_x, N_y]), k=1)) == 0).to(self.device)
213 | scores = scores.masked_fill(mask == 0, -1e9)
214 | # scores = scores * torch.from_numpy(np.triu(np.ones([batch_size, N_heads, N_x, N_y]), k=1)).float().to(self.device)
215 | scores = F.softmax(scores, dim=-1)
216 |
217 | o_s = torch.matmul(scores, v_s) # [batch_size, N_heads, N_x, d]
218 | o_s = o_s.transpose(1, 2).contiguous().view(batch_size, -1, N_heads*d) # [batch_size, N_x, N_heads*d]
219 | x = x.reshape(batch_size*N_x, -1) + self.linear_1(o_s.reshape(-1, N_heads*d)) # [batch_size*N_x, c_x]
220 |
221 | x = (x + self.linear_3(self.gelu(self.linear_2(self.final_layer_norm(x))))).reshape(batch_size, N_x, -1) # [batch_size, N_x, c_x]
222 |
223 | return x
224 |
225 |
226 | class PolicyHead(nn.Module):
227 |
228 | def __init__(self,
229 | N_steps=6,
230 | N_logits=9,
231 | N_features=64,
232 | N_heads=32,
233 | N_layers=2,
234 | N_samples=32,
235 | torso_feature_shape=(3*4**2, 3),
236 | mode='train',
237 | device='cuda'):
238 |
239 | super(PolicyHead, self).__init__()
240 | self.N_steps = N_steps
241 | self.N_logits = N_logits
242 | self.N_features = N_features
243 | self.N_heads = N_heads
244 | self.N_layers = N_layers
245 | self.torso_feature_shape = torso_feature_shape
246 | self.N_samples = N_samples
247 | self.mode = mode
248 | self.device = device
249 |
250 | self.linear_1 = nn.Linear(N_logits+1, N_features * N_heads, bias=False)
251 | # self.pos_embed = nn.Linear(1, N_features * N_heads)
252 | self.pos_embed = nn.Sequential(
253 | nn.Linear(1, 512),
254 | nn.ReLU(),
255 | nn.Linear(512, N_features * N_heads)
256 | )
257 | self.self_layer_norms = nn.ModuleList([nn.LayerNorm(N_features * N_heads) for _ in range(N_layers)])
258 | self.cross_layer_norms = nn.ModuleList([nn.LayerNorm(N_features * N_heads) for _ in range(N_layers)]) #FIXME: How to choose layer norm's channel?
259 | self.self_attentions = nn.ModuleList([Attention(x_channel=N_features * N_heads,
260 | y_channel=N_features * N_heads,
261 | N_x=N_steps+1,
262 | N_y=N_steps+1,
263 | causal_mask=True,
264 | N_heads=N_heads,
265 | device=device) for _ in range(N_layers)])
266 | self.cross_attentions = nn.ModuleList([Attention(x_channel=N_features * N_heads,
267 | y_channel=torso_feature_shape[1],
268 | N_x=N_steps+1,
269 | N_y=torso_feature_shape[0],
270 | causal_mask=False,
271 | N_heads=N_heads,
272 | device=device) for _ in range(N_layers)])
273 | self.self_dropouts = nn.ModuleList([nn.Dropout() for _ in range(N_layers)])
274 | self.cross_dropouts = nn.ModuleList([nn.Dropout() for _ in range(N_layers)])
275 | self.relu = nn.ReLU()
276 | self.linear_2 = nn.Linear(N_features * N_heads, N_logits+1)
277 |
278 |
279 | def forward(self, x):
280 |
281 | # Input:
282 | # [e, (g)]. e is the features extracted by torso, g is groundtruth (available in train mode)
283 | # e: [B, m, c]
284 | # g: {0,1,... N_logits-1} ^ [B, N_steps], [B, N_steps]
285 |
286 | N_steps = self.N_steps
287 | N_logits = self.N_logits
288 | N_samples = self.N_samples
289 | N_features = self.N_features
290 | N_heads = self.N_heads
291 | device = self.device
292 | assert self.mode in ['train', 'infer']
293 |
294 | if self.mode == 'train':
295 | e, g = x # g: {0,1,... N_logits-1} ^ [B, N_steps+1], [B, N_steps+1]
296 | if not torch.is_tensor(g):
297 | g = torch.tensor(g).long()
298 | g_onehot = one_hot(g, num_classes=N_logits, shift=True).float().to(device) # [B, N_steps+1, N_logits+1]
299 | #FIXME: We haven't applied "shift" operation.
300 | # # Apply "start" sign.
301 | # g_onehot = torch.cat([torch.zeros_like(g_onehot[:, :1, :]), g_onehot], dim=1) # [B, N_steps+1, N_logits+1]
302 | o, z = self.predict_action_logits(g_onehot, e) # o: [B, N_steps+1, N_logits+1]; z: [B, N_steps+1, N_features*N_heads]
303 | return o[:, :-1, :], z[:, 0] # o: [B, N_steps+1, N_logits+1]; z[:, 0]: [B, N_features*N_heads]
304 |
305 | elif self.mode == 'infer':
306 | e = x[0] # e: [B, m, c]
307 | batch_size = e.shape[0]
308 | a = -2 * torch.ones((batch_size * N_samples, N_steps+1)).long().to(device) # a: {-2,-1,0,1, ... N_logits-1} ^ [N_samples, N_steps+1]
309 | a[:, 0] = -1 # Start sign.
310 | p = torch.ones((batch_size * N_samples,)).float().to(device)
311 |
312 | # We sample N_samples batchly.
313 | e = e.repeat_interleave(N_samples, dim=0) # e: [B*N_samples, m, c]
314 | for i in range(1, N_steps+1):
315 | o, z = self.predict_action_logits(one_hot(a, num_classes=N_logits).float().to(device), e) # z: [B*N_samples, N_steps+1, N_features*N_heads]
316 | prob = F.softmax(o[:, i-1, :-1], dim=1) # o[:, i-1, :-1]: [B*N_samples, N_logits]
317 | sampled_a = torch.multinomial(prob, 1).view(-1) # sampled_a = [B*N_samples,]
318 | # for s in range(batch_size * N_samples):
319 | # p[s] = p[s] * prob[s, sampled_a[s]]
320 | a[:, i] = sampled_a
321 |
322 | if i == 1:
323 | # z1 = z[0, 0].clone() # [N_features*N_heads]
324 | z1 = z.reshape((-1, N_samples, N_steps+1, N_features*N_heads))
325 | z1 = z1[:, 0, 0, :].clone() # [B, N_features*N_heads]
326 | return a[:, 1:].reshape((-1, N_samples, N_steps)),\
327 | p.reshape((-1, N_samples)),\
328 | z1 # [B, N_samples, N_steps], [B, N_samples], [B, N_features*N_heads]
329 |
330 |
331 | def set_mode(self, mode):
332 |
333 | assert mode in ["train", "infer"]
334 | self.mode = mode
335 |
336 |
337 | def set_samples_n(self,
338 | N_samples):
339 | self.N_samples = N_samples
340 |
341 |
342 | def predict_action_logits(self,
343 | a, e):
344 |
345 | N_steps = self.N_steps
346 | N_logits = self.N_logits
347 | N_features = self.N_features
348 | N_heads = self.N_heads
349 | N_layers = self.N_layers
350 | torso_feature_shape = self.torso_feature_shape
351 | device = self.device
352 |
353 | batch_size = a.shape[0]
354 |
355 |         x = self.linear_1(a.reshape(batch_size*(N_steps+1), -1))      # [batch_size*(N_steps+1), N_features*N_heads]
356 |         x = self.pos_embed(torch.arange(0, (N_steps+1)).repeat(batch_size).float().view((-1,1)).to(device)) + x      # [batch_size*(N_steps+1), N_features*N_heads]
357 |         x = x.reshape(batch_size, (N_steps+1), -1)                   # [batch_size, N_steps+1, N_features*N_heads]
358 |
359 |         for layer in range(N_layers):
360 |             x = self.self_layer_norms[layer](x)                      # [batch_size, N_steps+1, N_features*N_heads]
361 |             c = self.self_attentions[layer]([x,])                    # [batch_size, N_steps+1, N_features*N_heads]
362 |             # if self.mode == 'train':
363 |             #     c = self.self_dropouts[layer](c)
364 |             # else:
365 |             #     c = self.self_dropouts[layer].eval()(c)
366 |             c = self.self_dropouts[layer](c)
367 |             x = x + c                                        # [batch_size, N_steps+1, N_features*N_heads]
368 | 
369 |             x = self.cross_layer_norms[layer](x)             # [batch_size, N_steps+1, N_features*N_heads]
370 |             c = self.cross_attentions[layer]([x, e])         # [batch_size, N_steps+1, N_features*N_heads]
371 |             # if self.mode == 'train':
372 |             #     c = self.cross_dropouts[layer](c)
373 |             # else:
374 |             #     c = self.cross_dropouts[layer].eval()(c)
375 |             c = self.cross_dropouts[layer](c)
376 |             x = x + c                                        # [batch_size, N_steps+1, N_features*N_heads]
377 | 
378 |         o = self.linear_2(self.relu(x.reshape(-1, N_features*N_heads))).reshape(batch_size, (N_steps+1), N_logits+1)     # [batch_size, N_steps+1, N_logits+1]
379 |
380 | return o, x
381 |
382 |
383 | class ValueHead(nn.Module):
384 |
385 | def __init__(self,
386 | N_layers=3,
387 | in_channel=2048,
388 | inter_channel=512,
389 | out_channel=8):
390 |
391 | super(ValueHead, self).__init__()
392 | self.out_channel = out_channel
393 | self.in_channel = in_channel
394 | self.inter_channel = inter_channel
395 | self.N_layers = N_layers
396 | self.mode = "train"
397 |
398 | self.in_linear = nn.Linear(in_channel, inter_channel)
399 |         self.linears = nn.ModuleList([nn.Linear(inter_channel, inter_channel) for _ in range(N_layers-1)])
400 | self.out_linear = nn.Linear(inter_channel, out_channel)
401 | self.relus = nn.ModuleList([nn.ReLU() for _ in range(N_layers)])
402 |
403 |
404 | def forward(self, x):
405 |
406 | # Input:
407 | # [z]. z: The feature in policy head.
408 |
409 | N_layers = self.N_layers
410 |
411 | z = x[0]
412 | # if self.mode == "infer":
413 | # z = z[None]
414 | z = self.relus[0](self.in_linear(z))
415 | for layer in range(N_layers-1):
416 |             z = self.relus[layer+1](self.linears[layer](z))
417 |
418 | q = self.out_linear(z)
419 | # if self.mode == "infer":
420 | # q = q[0]
421 | return q
422 |
423 |
424 | def set_mode(self, mode):
425 |
426 | assert mode in ["train", "infer"]
427 | self.mode = mode
428 |
429 |
430 | class Net(nn.Module):
431 | '''
432 |     The network.
433 | '''
434 |
435 | def __init__(self,
436 | T=7,
437 | S_size=4,
438 | N_steps=6,
439 | coefficients=[0, 1, -1],
440 | N_samples=32,
441 | n_attentive=8,
442 | N_heads=32,
443 | N_features=64,
444 | policy_layers=2,
445 | device='cuda',
446 | channel=3,
447 | scalar_size=3,
448 | value_layers=3,
449 | inter_channel=512,
450 | out_channel=8,
451 | **kwargs):
452 | '''
453 |         Initialization.
454 | '''
455 |
456 | super(Net, self).__init__()
457 | # Parameters.
458 | self.T = T
459 | self.S_size = S_size
460 | self.N_steps = N_steps
461 | self.coefficients = coefficients
462 | token_len = 3 * S_size // N_steps
463 | N_logits = len(coefficients) ** token_len # len(F) ^ len(token)
464 | self.N_logits = N_logits
465 | self.token_len = token_len
466 | self.N_samples = N_samples
467 |
468 | # Network.
469 | self.torso = Torso(S_size=S_size, T=T,
470 | n_attentive=n_attentive,
471 | device=device,
472 | channel=channel,
473 | scalar_size=scalar_size)
474 | self.policy_head = PolicyHead(N_steps=N_steps,
475 | N_logits=N_logits,
476 | N_samples=N_samples,
477 | N_heads=N_heads,
478 | N_features=N_features,
479 | device=device,
480 | N_layers=policy_layers,
481 | torso_feature_shape=(3*S_size**2, channel))
482 | self.value_head = ValueHead(in_channel=N_features*N_heads,
483 | N_layers=value_layers,
484 | inter_channel=inter_channel,
485 | out_channel=out_channel)
486 | self.mode = "train"
487 |
488 |
489 | def forward(self, x):
490 |
491 | # Input:
492 | # If train mode:
493 |         #       Tensors of shape [B, T, S, S, S]. The first slice is the current tensor. (numpy)
494 |         #       Scalars of shape [B, s]. (numpy)
495 |         #       (Ground truth) of shape [B, N_steps].
496 |         # Elif infer mode:
497 |         #       Tensors of shape [T, S, S, S]
498 |         #       Scalars of shape [s,]
499 |
500 | if self.mode == 'train':
501 | states, scalars, g = x
502 | self.policy_head.set_mode("train")
503 | self.value_head.set_mode("train")
504 | self.torso.set_mode("train")
505 |
506 | e = self.torso([states, scalars])
507 | o, z1 = self.policy_head([e, g])
508 | q = self.value_head([z1])
509 |
510 | return o, q # o: [B, N_steps, N_logits]; q: [B, out_channels]
511 |
512 | elif self.mode == 'infer':
513 | states, scalars = x
514 | self.policy_head.set_mode("infer")
515 | self.value_head.set_mode("infer")
516 | self.torso.set_mode("infer")
517 |
518 | e = self.torso([states, scalars])
519 | a, p, z1 = self.policy_head([e])
520 | q = self.value_head([z1])
521 |
522 |             return a, p, q         #FIXME: Need to process q.
523 | # a: {0,1,..., N_logits-1} ^ [B, N_samples, N_steps]; p: [B, N_samples,]; q: [B, out_channels,]
524 |
525 |
526 | def set_mode(self, mode):
527 |
528 | assert mode in ["train", "infer"]
529 | self.mode = mode
530 |
531 |
532 | def set_samples_n(self,
533 | N_samples):
534 | self.N_samples = N_samples
535 | self.policy_head.set_samples_n(N_samples)
536 |
537 |
538 | def logits_to_action(self, logits):
539 | '''
540 |         logits: N_steps values from {0, 1, ..., N_logits - 1}.
541 | e.g.:
542 | If:
543 | token_len = 2
544 | coefficients = [0, 1, -1]
545 | N_steps = 6
546 | Then:
547 | [0, 1, 2, 3, 4, 5] -> [0 0 | 0 1 | 0 -1 | 1 0 | 1 1 | 1 -1 ]
548 | '''
549 | token_len = self.token_len
550 | coefficients = self.coefficients
551 | action = []
552 | for logit in logits: # Get one action
553 | token = []
554 | if logit == self.N_logits:
555 |                 raise ValueError("The start-sign logit cannot be decoded into an action.")
556 | for _ in range(token_len): # Get one token
557 | idx = logit % len(coefficients)
558 | token.append(coefficients[idx])
559 | logit = logit // len(coefficients)
560 | token.reverse()
561 | action.extend(token)
562 |
563 | action = np.array(action, dtype=np.int32).reshape((3, -1))
564 | return action
565 |
566 |
567 | def action_to_logits(self,
568 | action):
569 | '''
570 | action: A [3, S_size] array.
571 | '''
572 |
573 | # Break action into tokens.
574 | token_len = self.token_len
575 | coefficients = self.coefficients
576 | action = action.reshape((-1, token_len)) # [N_steps, token_len]
577 |
578 | # Get logits.
579 | logits = []
580 | for token in action: # Get one logit.
581 | # token = token.to_list()
582 | logit = 0
583 | if torch.is_tensor(token):
584 | token = torch.flip(token, dims=(0,))
585 | else:
586 | token = token[::-1]
587 | for idx, v in enumerate(token):
588 | logit += coefficients.index(v) * (len(coefficients) ** idx)
589 | logits.append(logit)
590 |
591 | return np.array(logits)
592 |
593 |
594 | def value(self, output, u_q=.75):
595 | '''
596 |         Compute the utility value from the network output.
597 | output: output = net(x)
598 | '''
599 | q = output[-1]
600 | q = q.detach().cpu().numpy()
601 | batch_size, out_channels = q.shape[0], q.shape[1]
602 |
603 | j = math.ceil(u_q * out_channels)
604 | return q, q[:, (j-1):].mean(axis=1)
605 |
606 |
607 | def policy(self, output):
608 | '''
609 |         Recover the sampled actions and their probabilities from the network output.
610 | output: output = net(x)
611 | '''
612 | assert len(output) == 3, "We need the output from infer mode."
613 |
614 | a, p, _ = output
615 | a, p = a.detach().cpu().numpy(), p.detach().cpu().numpy()
616 |
617 | batch_actions = []
618 | for batch in a:
619 | actions = []
620 | for logits in batch:
621 | actions.append(self.logits_to_action(logits))
622 | actions = np.stack(actions, axis=0)
623 | batch_actions.append(actions)
624 | batch_actions = np.stack(batch_actions, axis=0)
625 |
626 | return batch_actions, p
627 |
628 |
629 |
630 | if __name__ == '__main__':
631 | # torso = Torso().to('cuda')
632 | # test_input = [np.random.randint(-1, 1, (64, 7, 4, 4, 4)), np.random.randint(-1, 1, (64, 3))]
633 | # e = torso(test_input)
634 | # import pdb; pdb.set_trace()
635 |
636 | # policy_head = PolicyHead().to('cuda')
637 | # test_g = torch.tensor([0,1,2,3,4,5]).repeat(64).reshape(64, 6).to('cuda')
638 | # train_output = policy_head([e, test_g])
639 | # import pdb; pdb.set_trace()
640 |
641 | # value_head = ValueHead().to('cuda')
642 | # value_output = value_head([train_output[1]])
643 | # import pdb; pdb.set_trace()
644 |
645 | # net = Net().to('cuda')
646 | # test_tensor, test_scalar = np.random.randint(-1, 1, (7, 4, 4, 4)), np.random.randint(-1, 1, (3))
647 | # test_input = [test_tensor[None], test_scalar[None]]
648 | # test_g = torch.tensor([-1,0,1,2,3,4,5]).repeat(1).reshape(1, 7).to('cuda')
649 | # train_output = net([*test_input, test_g])
650 | # # import pdb; pdb.set_trace()
651 |
652 | # net.set_mode("infer")
653 | # test_input = [test_tensor, test_scalar]
654 | # infer_output = net(test_input)
655 | # import pdb; pdb.set_trace()
656 |
657 | # _, value = net.value(infer_output)
658 | # policy = net.policy(infer_output)
659 | # import pdb; pdb.set_trace()
660 |
661 | net = Net()
662 | logits = np.array([-1,0,1,2,2,1,1])
663 | action = net.logits_to_action(logits)
664 | logits = net.action_to_logits(action)
665 | import pdb; pdb.set_trace()
--------------------------------------------------------------------------------
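A note on the logit/token encoding used by `Net.logits_to_action` and `Net.action_to_logits` above: each logit is read as a base-`len(coefficients)` number whose digits index into `coefficients`, most-significant digit first. A minimal standalone sketch of the decoding, assuming `token_len = 2` and `coefficients = [0, 1, -1]` as in the docstring example:

def decode_logit(logit, coefficients=(0, 1, -1), token_len=2):
    # Peel off base-len(coefficients) digits (least-significant first),
    # then reverse so the most-significant digit leads the token.
    token = []
    for _ in range(token_len):
        token.append(coefficients[logit % len(coefficients)])
        logit //= len(coefficients)
    return token[::-1]

# Reproduces the docstring table: [0..5] -> [0 0 | 0 1 | 0 -1 | 1 0 | 1 1 | 1 -1]
assert [decode_logit(l) for l in range(6)] == [[0, 0], [0, 1], [0, -1], [1, 0], [1, 1], [1, -1]]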
/codes/scripts/check_redundancy.py:
--------------------------------------------------------------------------------
1 | import os
2 | from tqdm import tqdm
3 | import numpy as np
4 |
5 | S_size = 4
6 |
7 | data_name = "100000_S4T7_scalar3.npy"
8 | data_path = os.path.join(".", "data", "traj_data", data_name)
9 |
10 | data = np.load(data_path, allow_pickle=True)
11 | filtered_data = []; ct = 0
12 |
13 | for traj_idx, traj in tqdm(enumerate(data)):
14 | _, actions, _ = traj
15 | raw_r = len(actions)
16 | flag = False
17 | for (i, j) in [[0,1], [1,2], [2,0]]:
18 | _mat = np.zeros((S_size ** 2, raw_r), dtype=np.int32)
19 | for idx, action in enumerate(actions):
20 | _mat[:, idx] = np.outer(action[i], action[j]).reshape((-1,))
21 | if np.linalg.matrix_rank(_mat) < raw_r:
22 | ct += 1
23 | print("Found redundancy...")
24 | flag = True
25 | break
26 |
27 | if not flag:
28 | filtered_data.append(traj)
29 |
30 | import pdb; pdb.set_trace()  # Inspect or save filtered_data interactively; ct counts the redundant trajectories.
--------------------------------------------------------------------------------
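The script above flags a trajectory as redundant when, for some pair of factor slots (i, j), the flattened outer products of its actions are linearly dependent, i.e. the S_size^2 x r matrix has rank below r. A toy sketch of the same test with hypothetical S_size = 2 factors, where the third outer product is twice the first:

import numpy as np

us = [np.array([1, 0]), np.array([0, 1]), np.array([2, 0])]
vs = [np.array([1, 0]), np.array([0, 1]), np.array([1, 0])]

# Column k is the flattened outer product u_k v_k^T, as in the script's loop.
mat = np.stack([np.outer(u, v).reshape(-1) for u, v in zip(us, vs)], axis=1)

# The third column is 2x the first, so rank (2) < number of factors (3):
# such a trajectory would be counted as redundant and dropped.
assert np.linalg.matrix_rank(mat) == 2 < mat.shape[1]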
/codes/trainer/Player.py:
--------------------------------------------------------------------------------
1 | import copy
2 | from torch.utils.tensorboard import SummaryWriter
3 | import sys
4 | import os
5 | sys.path.append(os.path.abspath(os.path.join("..")))
6 | sys.path.append(os.path.abspath(os.path.join(".")))
7 | from codes.trainer.loss import QuantileLoss
8 | from codes.env import *
9 | from codes.mcts import *
10 | from codes.utils import *
11 | from codes.dataset import *
12 | from codes.multi_runner import *
13 |
14 |
15 | class Player():
16 |
17 | def __init__(self,
18 | net, env, mcts,
19 | exp_dir,
20 | simu_times=400,
21 | play_times=10,
22 | num_workers=256,
23 | device="cuda:1",
24 | noise=False):
25 |
26 | self.net = net
27 | self.env = env
28 | self.mcts = mcts
29 |
30 | net.to(device)
31 |
32 | self.exp_dir = exp_dir
33 | self.trainer_logger = SummaryWriter(log_dir=os.path.join(exp_dir, "log"))
34 |
35 | self.simu_times = simu_times
36 | self.play_times = play_times
37 | self.num_workers = num_workers
38 | self.device = device
39 | self.noise = noise
40 |
41 | self.call_ct = 0
42 |
43 |
44 | def play(self, warm_up=False) -> list:
45 | '''
46 |         Play rounds of Tensor Game and collect the play records.
47 |         Returns: results
48 | '''
49 |
50 | num_workers = self.num_workers
51 | simu_times = self.simu_times
52 | play_times = self.play_times
53 | noise = self.noise
54 |
55 | if warm_up:
56 | simu_times = 40
57 | play_times = 1
58 | num_workers = 10
59 |
60 | results = []
61 | avg_steps = 0
62 | net = self.net
63 | env = self.env
64 | mcts = self.mcts
65 | net.set_mode("infer")
66 | net.eval()
67 |
68 | # wkdir.
69 | save_dir = os.path.join(self.exp_dir, "data")
70 | os.makedirs(save_dir, exist_ok=True)
71 | save_path = os.path.join(save_dir, "self_data.npy")
72 |
73 |         # Load ckpt. The trainer may still be writing latest.pth, so retry until it loads.
74 | latest_path = os.path.join(self.exp_dir, "ckpt", "latest.pth")
75 | ct = 0
76 | while True:
77 | if ct > 10000000:
78 | raise Exception
79 | try:
80 | self.load_model(ckpt_path=latest_path, to_device=self.device)
81 | break
82 |             except Exception:  # Likely a partially written checkpoint; try again.
83 | ct += 1
84 | continue
85 |
86 | # Multi runner.
87 | env_list = [copy.deepcopy(env) for _ in range(num_workers)]
88 | mcts_list = [copy.deepcopy(mcts) for _ in range(num_workers)]
89 | envs = ENVS(env_list)
90 | mctsf = MCTSF(mcts_list, simulate_times=simu_times)
91 |
92 | for game in (range(play_times)):
93 |
94 | envs.reset()
95 | state_list = envs.get_curstates()
96 | mctsf.reset(state_list)
97 | trajs = []
98 |
99 | while True:
100 | state_list = envs.get_curstates()
101 | action_list, _, _ = mctsf(state_list, net, noise=noise)
102 | envs.step(action_list)
103 | terminate_flag = envs.is_all_terminated()
104 | mctsf.move(action_list) # Move MCTS forward.
105 | # one_result.append([state_list, action_list]) # Record. (s, a).
106 | trajs.append([[state, action] for state, action in zip(state_list, action_list)])
107 |
108 | if terminate_flag:
109 | reward_list = envs.get_rewards()
110 | # for step in range(env.step_ct):
111 | # one_result[step] += [final_reward + step] # Final results. (s, a, r(s)).
112 | # # Note:
113 | # # a is not included in the history actions of s.
114 | for step, _trajs in enumerate(trajs):
115 | for idx, traj_state in enumerate(_trajs): # traj_state: [state, action]
116 | traj_state += [reward_list[idx] + (step)]
117 |
118 | # Prepare per traj.
119 | for idx in range(num_workers):
120 | one_traj = [_trajs[idx] for _trajs in trajs] # [ [s1, a1, r1], [s2, a2, r2] ... ]
121 | one_traj = [episode for episode in one_traj if not is_zero_tensor(episode[0])] # Filter the invalid state.
122 |
123 | states, actions, rewards = \
124 | [episode[0] for episode in one_traj], \
125 | [episode[1] for episode in one_traj], \
126 | [episode[2] for episode in one_traj]
127 |
128 | one_traj = [states, actions, rewards]
129 | results.append(one_traj)
130 |
131 | step_ct_list = envs.get_stepcts()
132 | batch_avg_step_ct = np.array(step_ct_list).mean()
133 |
134 | avg_steps = (game / (game+1)) * avg_steps + batch_avg_step_ct / (game+1)
135 | break
136 |
137 | # Insert a bubble.
138 | # results: (n_trajs, 3, traj_length*)
139 | one_state, one_action, one_reward = results[-1][0][0], results[-1][1][0], results[-1][2][0]
140 | bubble_traj = [[one_state,], [one_action,], [one_reward,]]
141 | results.append(bubble_traj)
142 | if save_path is not None:
143 | results = np.array(results, dtype=object) # (n_trajs, 3, traj_length*). *traj_length is not fixed. Decompose order...
144 | np.save(save_path, results)
145 | print("Avg step is: %f" % avg_steps)
146 |
147 | self.trainer_logger.add_text("Self-play", str(one_traj), global_step=self.call_ct)
148 | self.call_ct += 1
149 |
150 | return results, one_traj
151 |
152 |
153 | def run(self):
154 |
155 | self.play(warm_up=True)
156 | while True:
157 | self.play()
158 | print("Finish playing!")
159 |
160 |
161 | def load_model(self, ckpt_path, only_weight=False, to_device="cuda:0"):
162 |         ckpt = torch.load(ckpt_path, map_location=to_device)
163 | self.net.load_state_dict(ckpt['model'])
164 | return ckpt['iter']
--------------------------------------------------------------------------------
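For reference, a sketch of reading back the self-play file that `Player.play` writes (the path below is illustrative): each entry of the object array is one trajectory `[states, actions, rewards]`, and the last entry is the length-1 "bubble" trajectory appended above.

import numpy as np

results = np.load("exp/S4T7_selfplay/data/self_data.npy", allow_pickle=True)
states, actions, rewards = results[0]           # one trajectory
assert len(states) == len(actions) == len(rewards)
bubble = results[-1]                            # the appended length-1 bubble
assert len(bubble[0]) == 1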
/codes/trainer/Trainer.py:
--------------------------------------------------------------------------------
1 | import time
2 | import yaml
3 | import copy
4 | import random
5 | import shutil
6 | from tqdm import tqdm
7 | import torch
8 | from torch.utils.data import DataLoader, Subset
9 | from torch.utils.tensorboard import SummaryWriter
10 | import numpy as np
11 |
12 | import sys
13 | import os
14 | sys.path.append(os.path.abspath(os.path.join("..")))
15 | sys.path.append(os.path.abspath(os.path.join(".")))
16 | from codes.trainer.loss import *
17 | from codes.trainer.Player import *
18 | from codes.env import *
19 | from codes.mcts import *
20 | from codes.utils import *
21 | from codes.dataset import *
22 | from codes.multi_runner import *
23 |
24 | class Trainer():
25 | '''
26 |     Trains the network.
27 | '''
28 |
29 | def __init__(self,
30 | net: Net,
31 | env: Environment,
32 | mcts: MCTS,
33 | S_size=4,
34 | T=7,
35 | coefficients=[0, 1, -1],
36 | batch_size=1024,
37 | iters_n=50000,
38 | exp_dir="exp",
39 | exp_name="debug",
40 | device="cuda:0",
41 | self_play_device="cuda:1",
42 | lr=5e-3,
43 | weight_decay=1e-5,
44 | step_size=40000,
45 | gamma=.1,
46 | a_weight=.5,
47 | v_weight=.5,
48 | save_freq=10000,
49 | temp_save_freq=2500,
50 | self_play_freq=10,
51 | self_play_buffer=100000,
52 | grad_clip=4.0,
53 | val_freq=2000,
54 | all_kwargs=None):
55 | '''
56 |         Initialize a Trainer.
57 |         It holds the net, the env and the MCTS.
58 | '''
59 | self.env = env
60 | self.net = net
61 | self.mcts = mcts
62 | self.S_size = S_size
63 | self.T = T
64 | self.coefficients = coefficients
65 |
66 | self.self_examples = []
67 | self.synthetic_examples = []
68 |
69 | self.entropy_loss = torch.nn.CrossEntropyLoss()
70 | self.quantile_loss = QuantileLoss()
71 | self.a_weight = a_weight
72 | self.v_weight = v_weight
73 |
74 | self.optimizer_a = torch.optim.AdamW(net.parameters(),
75 | weight_decay=weight_decay,
76 | lr=lr)
77 | self.scheduler_a = torch.optim.lr_scheduler.StepLR(self.optimizer_a,
78 | step_size=step_size,
79 | gamma=gamma)
80 | self.optimizer_v = torch.optim.AdamW(net.parameters(),
81 | weight_decay=weight_decay,
82 | lr=lr)
83 | self.scheduler_v = torch.optim.lr_scheduler.StepLR(self.optimizer_v,
84 | step_size=step_size,
85 | gamma=gamma)
86 |
87 | self.batch_size = batch_size
88 | self.iters_n = iters_n
89 | self.grad_clip = grad_clip
90 | self.save_freq = save_freq
91 | self.temp_save_freq = temp_save_freq
92 | self.self_play_freq = self_play_freq
93 | self.self_play_buffer = self_play_buffer
94 | self.val_freq = val_freq
95 |
96 | self.exp_dir = exp_dir
97 | self.save_dir = os.path.join(exp_dir, exp_name, str(int(time.time())))
98 | self.log_dir = os.path.join(self.save_dir, "log")
99 | self.data_dir = os.path.join(self.save_dir, "data")
100 |
101 | self.device = device
102 | self.self_play_device = self_play_device
103 | self.net.to(device)
104 | self.all_kwargs = all_kwargs
105 |
106 |
107 | def generate_synthetic_examples(self,
108 | prob=[.8, .1, .1],
109 | samples_n=10000,
110 | R_limit=12,
111 | save_path=None,
112 | save_type="traj") -> list:
113 | '''
114 |         Generate synthetic tensor examples.
115 |         Returns: results
116 | '''
117 | assert save_type in ["traj", "tuple"]
118 |
119 | S_size = self.S_size
120 | coefficients = self.coefficients
121 | T = self.T
122 |
123 | total_results = []
124 | for _ in tqdm(range(samples_n)):
125 | R = random.randint(1, R_limit)
126 | for _ in range(10000):
127 | sample = np.zeros((S_size, S_size, S_size), dtype=np.int32)
128 | states = []
129 | actions = []
130 | rewards = []
131 | for r in range(1, (R+1)):
132 | ct = 0
133 | while True:
134 | u = np.random.choice(coefficients, size=(S_size,), p=prob, replace=True)
135 | v = np.random.choice(coefficients, size=(S_size,), p=prob, replace=True)
136 | w = np.random.choice(coefficients, size=(S_size,), p=prob, replace=True)
137 | ct += 1
138 | if not is_zero_tensor(outer(u, v, w)):
139 | break
140 | if ct > 100000:
141 | raise Exception("Oh my god...")
142 | sample = sample + outer(u, v, w)
143 | action = np.stack([u, v, w], axis=0)
144 | actions.append(canonicalize_action(action))
145 | states.append(sample.copy())
146 | rewards.append(-r)
147 |
148 | # Check redundancy.
149 | red_flag = False
150 | for (i, j) in [[0,1], [1,2], [2,0]]:
151 | _mat = np.zeros((S_size ** 2, R), dtype=np.int32)
152 | for idx, action in enumerate(actions):
153 | _mat[:, idx] = np.outer(action[i], action[j]).reshape((-1,))
154 | if np.linalg.matrix_rank(_mat) < R:
155 | red_flag = True
156 | break
157 |
158 | if red_flag:
159 | continue
160 | break
161 |
162 | # Reformulate the results.
163 | if save_type == "tuple":
164 | states.reverse(); actions.reverse(); rewards.reverse()
165 | actions_tensor = [action2tensor(action) for action in actions]
166 | for idx, state in enumerate(states):
167 | tensors = np.zeros((T, S_size, S_size, S_size), dtype=np.int32)
168 | tensors[0] = state # state.
169 | if idx != 0:
170 | # History actions.
171 | tensors[1:(idx+1)] = np.stack(reversed(actions_tensor[max(idx-(T-1), 0):idx]), axis=0)
172 |                     scalars = np.array([idx, idx, idx]) #FIXME: Haven't decided the scalars yet.
173 |
174 | cur_state = [tensors, scalars]
175 | action = actions[idx]
176 | reward = rewards[idx]
177 | total_results.append([cur_state, action, reward])
178 |
179 | else:
180 | traj = [states, actions, rewards] # Note: Synthesis order...
181 | total_results.append(traj)
182 |
183 | if save_path is not None:
184 | np.save(save_path, np.array(total_results, dtype=object))
185 |
186 | return total_results
187 |
188 |
189 | def learn_one_batch(self,
190 | batch_example) -> torch.autograd.Variable:
191 | '''
192 |         Learn from one batch of (state, action, value) tuples.
193 | '''
194 |
195 | # Groundtruth.
196 | s, a_gt, v_gt = batch_example # s: [tensor, scalar]
197 | a_gt = a_gt.long().to(self.device)
198 | v_gt = v_gt.float().to(self.device)
199 |
200 | # Network infer.
201 | self.net.set_mode("train")
202 | output = self.net([*s, a_gt])
203 | o, q = output # o: [batch_size, N_steps, N_logits], q: [batch_size, N_quantiles]
204 |
205 | # Losses.
206 | v_loss = self.quantile_loss(q, v_gt) # v_gt: [batch_size,]
207 | o = o.transpose(1,2) # o: [batch_size, N_logits, N_steps]
208 | a_loss = self.entropy_loss(o, a_gt) # a_gt: [batch_size, N_steps], o: [batch_size, N_logits, N_steps]
209 | loss = self.v_weight * v_loss + self.a_weight * a_loss
210 |
211 | del a_gt, v_gt
212 |
213 | return loss, v_loss, a_loss
214 |
215 |
216 | def val_one_episode(self,
217 | episode):
218 | '''
219 |         Validate on one episode and build the log text.
220 | '''
221 |
222 | state, action, reward = episode
223 | tensor, scalar = state
224 |
225 | self.net.set_mode("infer")
226 | self.net.set_samples_n(4)
227 | output = self.net(state)
228 | a, p, _ = output
229 | a = a.detach().cpu().numpy()
230 | q, v = self.net.value(output)
231 | policy, p = self.net.policy(output)
232 |
233 | policy, p, q, v = policy[0], p[0], q[0], v[0]
234 |
235 | self.net.set_mode("train")
236 | tensor, scalar, action = tensor[None], scalar[None], action[None]
237 | state = [tensor, scalar]
238 | o, _ = self.net([*state, action])
239 | o = o.detach().cpu().numpy()[0] # o: [N_steps, N_logits]
240 |
241 | log_txt = "\n".join(
242 | ["\nState: \n", str(state[0][0, 0]),
243 | "\nGt action: \n", str(self.net.logits_to_action(action[0])),
244 | "\nGt logit: \n", str(action[0]),
245 | "\nInfer actions: \n", str(policy),
246 | "\nInfer logits: \n", str(a),
247 | "\nprob: \n", str(p),
248 | "\nGt value: \n", str(reward),
249 | "\nInfer value: \n", str(v),
250 | "\nquantile: \n", str(q),
251 | *["\nTop 5 logit for step %d\n: " % step + str(np.argsort(o[step])[-5:]) for step in range(self.net.N_steps)]]
252 | )
253 |
254 | del a, o, p, _, output, q, v
255 |
256 | return log_txt
257 |
258 |
259 | def learn(self,
260 | resume=None,
261 | only_weight=False,
262 | example_path=None,
263 | self_example_path=None,
264 | save_type="traj",
265 | self_play=False):
266 | '''
267 |         Main training loop.
268 | '''
269 | optimizer_a = self.optimizer_a
270 | scheduler_a = self.scheduler_a
271 | optimizer_v = self.optimizer_v
272 | scheduler_v = self.scheduler_v
273 | batch_size = self.batch_size
274 | self_play_freq = self.self_play_freq
275 | self_play_buffer = self.self_play_buffer
276 |
277 | # Tensorboard.
278 | os.makedirs(self.save_dir)
279 | os.makedirs(self.log_dir)
280 | self.log_writer = SummaryWriter(self.log_dir)
281 |
282 | # Save config.
283 | all_kwargs = self.all_kwargs
284 | cfg_path = os.path.join(self.save_dir, "config.yaml")
285 | with open(cfg_path, 'w') as f:
286 | yaml.dump(all_kwargs, f)
287 |
288 | if resume is not None:
289 | # Load model.
290 | old_iter = self.load_model(resume, only_weight)
291 | # Copy log file.
292 | old_exp_dir = os.path.join(os.path.dirname(resume), '..')
293 | # os.system("cp -r %s %s" % (os.path.join(old_exp_dir, 'log', '*'), self.log_dir))
294 | for log_f in os.listdir(os.path.join(old_exp_dir, "log")):
295 | shutil.copy(os.path.join(old_exp_dir, "log", log_f), self.log_dir)
296 | else:
297 | old_iter = 0
298 |
299 | # Save ckpt.
300 | ckpt_name = "latest.pth"
301 | self.save_model(ckpt_name, old_iter)
302 |
303 | # 1. Get synthetic examples.
304 | if example_path is not None:
305 | self.synthetic_examples.extend(self.load_examples(example_path))
306 | else:
307 | self.synthetic_examples.extend(self.generate_synthetic_examples(samples_n=3000))
308 |
309 | if self_example_path is not None:
310 | self.self_examples.extend(self.load_examples(self_example_path))
311 |
312 | # Dataloader.
313 | dataset = TupleDataset(T=self.T,
314 | S_size=self.S_size,
315 | N_steps=self.net.N_steps,
316 | coefficients=self.coefficients,
317 | self_data=self.self_examples,
318 | synthetic_data=self.synthetic_examples)
319 | dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0)
320 | loader = iter(dataloader)
321 | epoch_ct = 0
322 |
323 | for i in tqdm(range(old_iter, self.iters_n)):
324 |
325 | # 2. self-play for data.
326 | # if i % self_play_freq == 0:
327 | # self.self_examples.extend(self.play(200 if i < 50000 else 800))
328 |
329 | try:
330 | batch_example = next(loader)
331 | except StopIteration:
332 | dataloader.dataset._permutate_traj()
333 | if self_play:
334 | self_examples = self.get_self_examples() # New self-play data.
335 | if self_examples is not None:
336 |                         print("Detected new self-play data!")
337 | self.self_examples.extend(self_examples)
338 | self.self_examples = self.self_examples[-self_play_buffer:]
339 | np.save(os.path.join(self.data_dir, "total_self_data.npy"), np.array(self.self_examples, dtype=object)) # Whole buffer.
340 | synthetic_examples_n = 2000 if i > 50000 else 100000
341 | dataset = TupleDataset(T=self.T,
342 | S_size=self.S_size,
343 | N_steps=self.net.N_steps,
344 | coefficients=self.coefficients,
345 | self_data=self.self_examples,
346 | synthetic_data=random.sample(self.synthetic_examples, synthetic_examples_n))
347 | dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0)
348 | else:
349 |                     print("No new self-play data detected...")
350 |
351 | loader = iter(dataloader)
352 | batch_example = next(loader)
353 | print("Epoch: %d finish." % epoch_ct)
354 | epoch_ct += 1
355 |
356 |             # Multi-process optimization could happen here.
357 |             # TODO: when should the network parameters be updated?
358 | optimizer_a.zero_grad()
359 | optimizer_v.zero_grad()
360 | loss, v_loss, a_loss = self.learn_one_batch(batch_example)
361 | loss.backward()
362 |             torch.nn.utils.clip_grad_norm_(self.net.parameters(),
363 |                                            max_norm=self.grad_clip)
364 | optimizer_a.step()
365 | optimizer_v.step()
366 | scheduler_a.step()
367 | scheduler_v.step()
368 |
369 |             # Logging.
370 | if i % 20 == 0:
371 | # print("Loss: %f, v_loss: %f, a_loss: %f" %
372 | # (loss.detach().cpu().item(), v_loss.detach().cpu().item(), a_loss.detach().cpu().item()))
373 | self.log_writer.add_scalar("loss", loss.detach().cpu().item(), global_step=i)
374 | self.log_writer.add_scalar("v_loss", v_loss.detach().cpu().item(), global_step=i)
375 | self.log_writer.add_scalar("a_loss", a_loss.detach().cpu().item(), global_step=i)
376 |
377 | if i % self.save_freq == 0:
378 | ckpt_name = "it%07d.pth" % i
379 | self.save_model(ckpt_name, i)
380 |
381 | if i % self.temp_save_freq == 0:
382 | ckpt_name = "latest.pth"
383 | self.save_model(ckpt_name, i)
384 |
385 | if i % self.val_freq == 0:
386 | val_episode = dataset[random.randint(0, len(dataset)-1)]
387 | log_txt = self.val_one_episode(val_episode)
388 | self.log_writer.add_text("Infer", log_txt, global_step=i)
389 |
390 | self.save_model("final.pth", i)
391 |
392 |
393 | def infer(self,
394 | init_state=None,
395 | no_base_change=True,
396 | mcts_simu_times=10000,
397 | mcts_samples_n=16,
398 | step_limit=12,
399 | resume=None,
400 | vis=False,
401 | noise=False,
402 | log=True):
403 |
404 | log_actions = []
405 |
406 | assert resume is not None, "No meaning for random init infer."
407 | self.load_model(resume)
408 | if log:
409 | exp_dir = os.path.join(os.path.dirname(resume), '..')
410 | infer_log_dir = os.path.join(exp_dir, "infer")
411 | os.makedirs(infer_log_dir, exist_ok=True)
412 | infer_log_f = os.path.join(infer_log_dir, str(int(time.time()))+'.txt')
413 |
414 | net = self.net
415 | env = self.env
416 | mcts = self.mcts
417 |
418 | env.reset(init_state, no_base_change)
419 | net.set_mode("infer")
420 | net.set_samples_n(mcts_samples_n)
421 | net.eval()
422 | mcts.reset(env.cur_state, simulate_times=mcts_simu_times, R_limit=step_limit)
423 | env.R_limit = step_limit + 1
424 |
425 | step_ct = 0
426 | for step in tqdm(range(step_limit)):
427 | print("Current state is (step%d):" % step)
428 | print(env.cur_state)
429 |
430 | action, actions, pi, log_txt = mcts(env.cur_state, net, log=True, noise=noise)
431 | if vis:
432 | mcts.visualize()
433 | print("We choose action(step%d):" % step)
434 | print(action)
435 | terminate_flag = env.step(action) # Will change self.cur_state.
436 | mcts.move(action) # Move MCTS forward.
437 | log_actions.append(action)
438 |
439 | if log:
440 | with open(infer_log_f, "a") as f:
441 | f.write(log_txt)
442 | f.write("\n\n\n")
443 |
444 | if terminate_flag:
445 | step_ct = step + 1
446 | print("We get to the end!")
447 | break
448 |
449 |
450 | print("Final result:")
451 | print(env.cur_state)
452 |
453 | print("Actions are:")
454 | print(np.stack(log_actions, axis=0))
455 |
456 | if log:
457 | with open(infer_log_f, "a") as f:
458 | f.write("\n\n\n")
459 | f.write("\nFinal result:\n")
460 | f.write("\n" + str(env.cur_state) + "\n")
461 | f.write("\nActions are:\n")
462 | f.write("\n" + str(np.stack(log_actions, axis=0)) + "\n")
463 | f.write("\n\n\n")
464 | f.write("\nStep ct: %d\n" % step_ct)
465 |
466 | return step_ct
467 |
468 |
469 | def filter_train_data(self,
470 | n=100,
471 | example_path=None,
472 | mcts_simu_times=10000,
473 | mcts_samples_n=16,
474 | step_limit=12,
475 | resume=None):
476 |
477 | assert resume is not None, "No meaning for random init infer."
478 | assert example_path is not None
479 |
480 | synthetic_examples = self.load_examples(example_path)
481 | test_data = random.sample(list(synthetic_examples), n)
482 |
483 | exp_dir = os.path.join(os.path.dirname(resume), '..')
484 | infer_log_dir = os.path.join(exp_dir, "infer")
485 | os.makedirs(infer_log_dir, exist_ok=True)
486 | infer_log_f = os.path.join(infer_log_dir, str(int(time.time()))+'.txt')
487 |
488 | for traj in (test_data):
489 | # import pdb; pdb.set_trace()
490 | states, _, _ = traj
491 | raw_r = len(states)
492 | init_state = states[-1]
493 |
494 | result_r = self.infer(init_state=init_state,
495 | no_base_change=True,
496 | mcts_samples_n=mcts_samples_n,
497 | mcts_simu_times=mcts_simu_times,
498 | step_limit=step_limit,
499 | resume=resume,
500 | log=False)
501 |
502 | with open(infer_log_f, "a") as f:
503 | f.write("%d %d\n" % (raw_r, result_r))
504 |
505 |
506 |
507 | def save_model(self, ckpt_name, iter):
508 | save_dir = os.path.join(self.save_dir, "ckpt")
509 | os.makedirs(save_dir, exist_ok=True)
510 | save_path = os.path.join(save_dir, ckpt_name)
511 | torch.save({'model': self.net.state_dict(),
512 | 'iter': iter,
513 | 'optimizer_a': self.optimizer_a.state_dict(),
514 | 'optimizer_v': self.optimizer_v.state_dict(),
515 | 'scheduler_a': self.scheduler_a.state_dict(),
516 | 'scheduler_v': self.scheduler_v.state_dict()}, save_path)
517 |
518 |
519 | def load_model(self, ckpt_path, only_weight=False, to_device="cuda:0"):
520 |         ckpt = torch.load(ckpt_path, map_location=to_device)
521 |         self.net.load_state_dict(ckpt['model'])
522 |         if not only_weight:
523 |             self.optimizer_a.load_state_dict(ckpt['optimizer_a'])
524 |             self.optimizer_v.load_state_dict(ckpt['optimizer_v'])
525 |             self.scheduler_a.load_state_dict(ckpt['scheduler_a'])
526 |             self.scheduler_v.load_state_dict(ckpt['scheduler_v'])
527 | return ckpt['iter']
528 |
529 |
530 | def load_examples(self, example_path):
531 | return np.load(example_path, allow_pickle=True)
532 |
533 |
534 | def get_self_examples(self):
535 | newest_data_path = os.path.join(self.data_dir, "self_data.npy") # New data from player.
536 | old_data_path = os.path.join(self.data_dir, "self_data_old.npy")
537 | if os.path.exists(newest_data_path):
538 | ct = 0
539 | while True:
540 | if ct > 10000000:
541 | raise Exception
542 | try:
543 | self_examples = self.load_examples(newest_data_path)
544 | os.system("mv %s %s" % (newest_data_path, old_data_path))
545 | return self_examples
546 |                 except Exception:  # The player may still be writing the file; retry.
547 | ct += 1
548 | continue
549 | return None
550 |
551 |
552 |
553 | if __name__ == '__main__':
554 |
555 | conf_path = "./config/my_conf.yaml"
556 | with open(conf_path, 'r', encoding="utf-8") as f:
557 | kwargs = yaml.load(f.read(), Loader=yaml.FullLoader)
558 |
559 | net = Net(**kwargs["net"])
560 | mcts = MCTS(**kwargs["mcts"],
561 | init_state=None)
562 | env = Environment(**kwargs["env"],
563 | init_state=None)
564 | trainer = Trainer(**kwargs["trainer"],
565 | net=net, env=env, mcts=mcts,
566 | all_kwargs=kwargs)
567 |
568 |
569 | # import pdb; pdb.set_trace()
570 | # res = trainer.play()
571 | # res = trainer.generate_synthetic_examples()
572 | # trainer.learn(example_path="./data/100000_T5_scalar3.npy")
573 | # import pdb; pdb.set_trace()
574 | # trainer.load_model("./exp/debug/1680630182/ckpt/it0020000.pth")
575 | # trainer.infer()
576 | # trainer.infer(resume="./exp/debug/1680764978/ckpt/it0002000.pth")
577 | # import pdb; pdb.set_trace()
578 | # trainer.generate_synthetic_examples(samples_n=3000, save_path="./data/3000_T5_scalar3.npy")
579 | # trainer.learn(resume=None,
580 | # example_path="./data/100000_T5_scalar3.npy")
581 | # trainer.learn(resume=None)
582 | trainer.learn(resume=None,
583 | example_path="./data/3000_T5_scalar3.npy")
--------------------------------------------------------------------------------
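A sketch of the trajectory layout that `generate_synthetic_examples` writes with save_type="traj" (the file name is illustrative): states, actions and rewards are stored in synthesis order, so the last state is the full tensor and the rewards run -1, -2, ..., -R.

import numpy as np

data = np.load("./data/3000_T5_scalar3.npy", allow_pickle=True)
states, actions, rewards = data[0]
S_size = np.asarray(states[0]).shape[0]
assert np.asarray(states[0]).shape == (S_size, S_size, S_size)  # partial-sum tensor
assert np.asarray(actions[0]).shape == (3, S_size)              # one (u, v, w) triple
assert list(rewards) == list(range(-1, -len(rewards) - 1, -1))  # -1, -2, ..., -R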
/codes/trainer/__init__.py:
--------------------------------------------------------------------------------
1 | from codes.trainer.Trainer import *
2 | from codes.trainer.Player import *
--------------------------------------------------------------------------------
/codes/trainer/__pycache__/Player.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YiwenAI/OpenTensor/9727d20732ceed34eac108c76e23316df32c4981/codes/trainer/__pycache__/Player.cpython-39.pyc
--------------------------------------------------------------------------------
/codes/trainer/__pycache__/Trainer.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YiwenAI/OpenTensor/9727d20732ceed34eac108c76e23316df32c4981/codes/trainer/__pycache__/Trainer.cpython-38.pyc
--------------------------------------------------------------------------------
/codes/trainer/__pycache__/Trainer.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YiwenAI/OpenTensor/9727d20732ceed34eac108c76e23316df32c4981/codes/trainer/__pycache__/Trainer.cpython-39.pyc
--------------------------------------------------------------------------------
/codes/trainer/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YiwenAI/OpenTensor/9727d20732ceed34eac108c76e23316df32c4981/codes/trainer/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/codes/trainer/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YiwenAI/OpenTensor/9727d20732ceed34eac108c76e23316df32c4981/codes/trainer/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/codes/trainer/__pycache__/loss.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YiwenAI/OpenTensor/9727d20732ceed34eac108c76e23316df32c4981/codes/trainer/__pycache__/loss.cpython-38.pyc
--------------------------------------------------------------------------------
/codes/trainer/__pycache__/loss.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YiwenAI/OpenTensor/9727d20732ceed34eac108c76e23316df32c4981/codes/trainer/__pycache__/loss.cpython-39.pyc
--------------------------------------------------------------------------------
/codes/trainer/loss.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from torch.autograd import Variable
6 |
7 | class CrossEntropyLoss(nn.Module):
8 |     pass  # Placeholder; the Trainer uses torch.nn.CrossEntropyLoss directly.
9 |
10 |
11 | class QuantileLoss(nn.Module):
12 |
13 | def __init__(self,
14 | delta=1,
15 | device='cuda'):
16 | super(QuantileLoss, self).__init__()
17 | self.delta = delta
18 | self.device = device
19 |
20 | def forward(self, out, label) -> Variable:
21 | '''
22 | out: q -> [batch_size, N_quantiles]
23 | label: g -> [batch_size, ].
24 | '''
25 | n = out.shape[1]
26 | batch_size = out.shape[0]
27 | # label = torch.ones_like(out) * label
28 | label = label.reshape((batch_size, 1)).repeat(1, n).to(self.device)
29 |         tau = ((torch.arange(n) + .5) / n).repeat(batch_size, 1).to(self.device)
30 | d = label - out
31 | h = F.huber_loss(out, label, reduction='none', delta=self.delta)
32 | k = torch.abs(tau - (d < 0).float())
33 | loss = torch.mean(k * h)
34 | return loss
--------------------------------------------------------------------------------
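For reference, the forward pass above matches the standard quantile-regression Huber objective with midpoint quantile levels (written out here as a sketch; $\delta$ defaults to 1):

\[
\mathcal{L}(q, g) \;=\; \frac{1}{B n} \sum_{b=1}^{B} \sum_{i=1}^{n}
\left| \tau_i - \mathbf{1}\{\, g_b - q_{b,i} < 0 \,\} \right| \,
H_\delta\!\left(g_b - q_{b,i}\right),
\qquad \tau_i = \frac{i - 0.5}{n},
\]

where $H_\delta(d) = \tfrac{1}{2} d^2$ for $|d| \le \delta$ and $\delta\,(|d| - \tfrac{\delta}{2})$ otherwise.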
/codes/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from codes.utils.util_functions import *
--------------------------------------------------------------------------------
/codes/utils/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YiwenAI/OpenTensor/9727d20732ceed34eac108c76e23316df32c4981/codes/utils/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/codes/utils/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YiwenAI/OpenTensor/9727d20732ceed34eac108c76e23316df32c4981/codes/utils/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/codes/utils/__pycache__/util_functions.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YiwenAI/OpenTensor/9727d20732ceed34eac108c76e23316df32c4981/codes/utils/__pycache__/util_functions.cpython-38.pyc
--------------------------------------------------------------------------------
/codes/utils/__pycache__/util_functions.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YiwenAI/OpenTensor/9727d20732ceed34eac108c76e23316df32c4981/codes/utils/__pycache__/util_functions.cpython-39.pyc
--------------------------------------------------------------------------------
/codes/utils/util_functions.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 | '''
5 | Note that:
6 |     All functions use numpy arrays as input and output (except one_hot).
7 | '''
8 |
9 | def numpy_cvt(a):
10 | if torch.is_tensor(a):
11 | return a.numpy()
12 | return a
13 |
14 | def outer(x, y, z):
15 |     # Rank-3 tensor whose three axes correspond to x, y and z.
16 | return np.einsum('i,j,k->ijk', x, y, z, dtype=np.int32, casting="same_kind")
17 |
18 | def action2tensor(action):
19 | u, v, w = action
20 | return np.einsum('i,j,k->ijk', u, v, w, dtype=np.int32, casting="same_kind")
21 |
22 | def is_zero_tensor(tensor):
23 | return np.all(tensor == 0)
24 |
25 | def is_equal(a, b):
26 | assert a.shape == b.shape
27 | return np.all((a - b) == 0)
28 |
29 | def canonicalize_action(action):
30 | u, v, w = action
31 | flag_u, flag_v = 1, 1
32 | for e in u:
33 | if e != 0:
34 | flag_u = ((e > 0) * 2 - 1)
35 | u = (u * flag_u).astype(np.int32)
36 | break
37 | for e in v:
38 | if e != 0:
39 | flag_v = ((e > 0) * 2 - 1)
40 | v = (v * flag_v).astype(np.int32)
41 | break
42 | w = (w * flag_v * flag_u).astype(np.int32)
43 | return np.stack([u, v, w])
44 |
45 | def one_hot(a_s, num_classes, shift=False):
46 | '''
47 |     Note: The returned one-hot size is num_classes+1; the extra last slot encodes the start sign.
48 | '''
49 | if len(a_s.shape) == 1:
50 | result = torch.zeros((a_s.shape[0], num_classes+1)).long()
51 | for idx, a in enumerate(a_s):
52 |             if a == -2:  # -2 marks an empty slot; leave the row all-zero.
53 | continue
54 | result[idx, a] = 1
55 | if shift: # Append SOS.
56 | result = torch.cat([torch.zeros((1, num_classes+1)).long(), result], dim=0)
57 | result[0, -1] = 1
58 | return result
59 | elif len(a_s.shape) == 2:
60 | result = torch.zeros((a_s.shape[0], a_s.shape[1], num_classes+1)).long()
61 | for batch, a_batch in enumerate(a_s):
62 | for idx, a in enumerate(a_batch):
63 |                 if a == -2:  # -2 marks an empty slot; leave the row all-zero.
64 | continue
65 | result[batch, idx, a] = 1
66 | if shift: # Append SOS.
67 | result = torch.cat([torch.zeros((a_s.shape[0], 1, num_classes+1)).long(), result], dim=1)
68 | result[:, 0, -1] = 1
69 | return result
70 |
71 | def change_basis_tensor(tensor,
72 | trans_mat):
73 | return np.einsum('ij, kl, mn, jln -> ikm',
74 | trans_mat, trans_mat, trans_mat, tensor, casting="same_kind", dtype=np.int32)
75 |
76 | def random_action(coefficients=[0,1,-1],
77 | prob=[.8,.1,.1],
78 | S_size=4):
79 | ct = 0
80 | while True:
81 | u = np.random.choice(coefficients, size=(S_size,), p=prob, replace=True)
82 | v = np.random.choice(coefficients, size=(S_size,), p=prob, replace=True)
83 | w = np.random.choice(coefficients, size=(S_size,), p=prob, replace=True)
84 | ct += 1
85 | if not is_zero_tensor(outer(u, v, w)):
86 | break
87 | if ct > 100000:
88 | raise Exception("Oh my god...")
89 | return np.stack([u, v, w], axis=0)
90 |
91 | def terminate_rank_approx(tensor):
92 | assert len(tensor.shape) == 3
93 | rank_approx = 0
94 | for z_idx in range(tensor.shape[-1]):
95 |         rank_approx += np.linalg.matrix_rank(tensor[..., z_idx].astype(np.int32))
96 |
97 | return rank_approx
--------------------------------------------------------------------------------
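A quick sanity check for `canonicalize_action` above: flipping the signs of u and v and folding both flips into w leaves the rank-1 tensor unchanged, while the leading nonzero entries of the canonical u and v become positive. A sketch, assuming the package imports resolve as in the repo:

import numpy as np
from codes.utils.util_functions import canonicalize_action, outer

u = np.array([-1, 0, 1, 0]); v = np.array([0, -1, 0, 1]); w = np.array([1, 1, 0, 0])
canon = canonicalize_action(np.stack([u, v, w]))

assert np.array_equal(outer(*canon), outer(u, v, w))   # same tensor
assert canon[0][np.nonzero(canon[0])[0][0]] > 0        # canonical u leads with +
assert canon[1][np.nonzero(canon[1])[0][0]] > 0        # canonical v leads with +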
/config/S_4.yaml:
--------------------------------------------------------------------------------
1 | # base:
2 | # S_size: 4
3 | # T: 5
4 | # coefficients:
5 | # - 0
6 | # - 1
7 | # - -1
8 | # exp_dir: "exp"
9 | # exp_name: "debug"
10 | # device: "cuda"
11 |
12 | trainer:
13 | T: 7
14 | S_size: 4
15 | coefficients:
16 | - 0
17 | - 1
18 | - -1
19 | exp_dir: "exp"
20 | exp_name: "S4T7_selfplay"
21 | device: "cuda:0"
22 | self_play_device: "cuda:1"
23 | lr: 0.0001
24 | weight_decay: 0.00001
25 | gamma: .1
26 | step_size: 10000000
27 | iters_n: 3000000
28 | batch_size: 1024
29 | a_weight: .1
30 | v_weight: .9
31 | save_freq: 25000
32 | temp_save_freq: 2500
33 | self_play_buffer: 10000
34 | self_play_freq: 25
35 | grad_clip: 4.0
36 | val_freq: 1000
37 |
38 | net:
39 | T: 7
40 | S_size: 4
41 | coefficients:
42 | - 0
43 | - 1
44 | - -1
45 | device: 'cuda:0'
46 |
47 | N_steps: 3
48 | N_samples: 32
49 | n_attentive: 6
50 | N_heads: 32
51 | N_features: 64
52 | channel: 3
53 | scalar_size: 3
54 | policy_layers: 2
55 | value_layers: 3
56 | inter_channel: 512
57 | out_channel: 8
58 |
59 | mcts:
60 | simulate_times: 400
61 | R_limit: 8
62 |
63 | env:
64 | R_limit: 8
65 | T: 7
66 | S_size: 4
67 |
--------------------------------------------------------------------------------
/config/S_4_remote.yaml:
--------------------------------------------------------------------------------
1 | base: &base
2 | S_size: 4
3 | T: 7
4 | coefficients:
5 | - 0
6 | - 1
7 | - -1
8 | exp_dir: "exp"
9 | exp_name: "S4T7_selfplay"
10 | device: "cuda:0"
11 | self_play_device: "cuda:1"
12 |
13 | trainer:
14 | <<: *base
15 | lr: 0.0001
16 | weight_decay: 0.00001
17 | gamma: .1
18 | step_size: 10000000
19 | iters_n: 3000000
20 | batch_size: 1024
21 | a_weight: .1
22 | v_weight: .9
23 | save_freq: 25000
24 | temp_save_freq: 2500
25 | self_play_buffer: 10000
26 | self_play_freq: 25
27 | grad_clip: 4.0
28 | val_freq: 1000
29 |
30 | net:
31 | <<: *base
32 | N_steps: 3
33 | N_samples: 32
34 | n_attentive: 6
35 | N_heads: 32
36 | N_features: 64
37 | channel: 3
38 | scalar_size: 3
39 | policy_layers: 2
40 | value_layers: 3
41 | inter_channel: 512
42 | out_channel: 8
43 |
44 | mcts:
45 | <<: *base
46 | simulate_times: 400
47 |
48 | env:
49 | <<: *base
50 | R_limit: 8
51 |
--------------------------------------------------------------------------------
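S_4_remote.yaml relies on a YAML anchor (`base: &base`) plus merge keys (`<<: *base`), so the shared fields are expanded into every section at load time; PyYAML's FullLoader resolves the merges. A small check:

import yaml

with open("./config/S_4_remote.yaml", "r", encoding="utf-8") as f:
    kwargs = yaml.load(f.read(), Loader=yaml.FullLoader)

# Fields from `base` are merged into every section.
assert kwargs["net"]["S_size"] == kwargs["env"]["S_size"] == 4
assert kwargs["trainer"]["T"] == 7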
/config/S_9.yaml:
--------------------------------------------------------------------------------
1 | # base:
2 | # S_size: 4
3 | # T: 5
4 | # coefficients:
5 | # - 0
6 | # - 1
7 | # - -1
8 | # exp_dir: "exp"
9 | # exp_name: "debug"
10 | # device: "cuda"
11 |
12 | trainer:
13 | T: 7
14 | S_size: 9
15 | coefficients:
16 | - 0
17 | - 1
18 | - -1
19 | exp_dir: "exp"
20 | exp_name: "first_exp"
21 | device: "cuda"
22 |
23 | lr: 0.0001
24 | weight_decay: 0.00001
25 | gamma: .1
26 | step_size: 10000000
27 | iters_n: 3000000
28 | batch_size: 2048
29 | a_weight: .5
30 | v_weight: .5
31 | save_freq: 50000
32 | self_play_buffer: 100000
33 | self_play_freq: 10
34 | grad_clip: 4.0
35 | val_freq: 1000
36 |
37 | net:
38 | T: 7
39 | S_size: 9
40 | coefficients:
41 | - 0
42 | - 1
43 | - -1
44 | device: 'cuda'
45 |
46 | N_steps: 12
47 | N_samples: 32
48 | n_attentive: 4
49 | N_heads: 16
50 | N_features: 32
51 | channel: 3
52 | scalar_size: 3
53 | policy_layers: 2
54 | value_layers: 3
55 | inter_channel: 256
56 | out_channel: 8
57 |
58 | mcts:
59 | simulate_times: 400
60 |
61 | env:
62 | R_limit: 30
63 | T: 7
64 | S_size: 9
65 |
--------------------------------------------------------------------------------
/config/my_conf.yaml:
--------------------------------------------------------------------------------
1 | # base:
2 | # S_size: 4
3 | # T: 5
4 | # coefficients:
5 | # - 0
6 | # - 1
7 | # - -1
8 | # exp_dir: "exp"
9 | # exp_name: "debug"
10 | # device: "cuda"
11 |
12 | trainer:
13 | T: 5
14 | S_size: 4
15 | coefficients:
16 | - 0
17 | - 1
18 | - -1
19 | exp_dir: "exp"
20 | exp_name: "first_exp"
21 | device: "cuda"
22 |
23 | lr: 0.0001
24 | weight_decay: 0.00001
25 | gamma: .1
26 | step_size: 10000000
27 | iters_n: 3000000
28 | batch_size: 2048
29 | a_weight: .5
30 | v_weight: .5
31 | save_freq: 50000
32 | self_play_buffer: 100000
33 | self_play_freq: 10
34 | grad_clip: 4.0
35 | val_freq: 1000
36 |
37 | net:
38 | T: 5
39 | S_size: 4
40 | coefficients:
41 | - 0
42 | - 1
43 | - -1
44 | device: 'cuda'
45 |
46 | N_steps: 3
47 | N_samples: 32
48 | n_attentive: 4
49 | N_heads: 16
50 | N_features: 16
51 | channel: 3
52 | scalar_size: 3
53 | policy_layers: 2
54 | value_layers: 3
55 | inter_channel: 256
56 | out_channel: 8
57 |
58 | mcts:
59 | simulate_times: 400
60 |
61 | env:
62 | R_limit: 12
63 | T: 5
64 | S_size: 4
65 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import yaml
2 | import argparse
3 |
4 | from codes.env import Environment
5 | from codes.mcts import MCTS
6 | from codes.net import Net
7 | from codes.trainer import Trainer, Player
8 |
9 | def parse():
10 | parser = argparse.ArgumentParser(description="OpenTensor")
11 | parser.add_argument('--config', type=str, default="./config/S_4.yaml")
12 | parser.add_argument('--mode', type=str, default="train", help="three modes: [generate_data, train, infer]")
13 | parser.add_argument('--resume', default=None, help="resume ckpt path")
14 | parser.add_argument('--run_dir', default="./exp/S4T7_selfplay/1685590684", help="The run dir to infer")
15 | args = parser.parse_args()
16 | return args
17 |
18 |
19 | if __name__ == '__main__':
20 |
21 | args = parse()
22 | conf_path = args.config
23 | mode = args.mode
24 | resume = args.resume
25 |
26 | with open(conf_path, 'r', encoding="utf-8") as f:
27 | kwargs = yaml.load(f.read(), Loader=yaml.FullLoader)
28 |
29 | net = Net(**kwargs["net"])
30 | mcts = MCTS(**kwargs["mcts"],
31 | init_state=None)
32 | env = Environment(**kwargs["env"],
33 | init_state=None)
34 | trainer = Trainer(**kwargs["trainer"],
35 | net=net, env=env, mcts=mcts,
36 | all_kwargs=kwargs)
37 |
38 | S_size = kwargs["env"]["S_size"]
39 | T = kwargs["env"]["T"]
40 | if mode == "generate_data":
41 | trainer.generate_synthetic_examples(samples_n=100000,
42 | save_path="./data/100000_S%dT%d_scalar3_filtered.npy" % (S_size, T))
43 |
44 | elif mode == "train":
45 | trainer.learn(resume=resume,
46 | example_path="./data/100000_S%dT%d_scalar3_filtered.npy" % (S_size, T),
47 | self_example_path=None)
48 |
49 | elif mode == "infer":
50 | self_play_net = Net(**kwargs["net"])
51 | player = Player(net=self_play_net,
52 | env=env,
53 | mcts=mcts,
54 | exp_dir=args.run_dir,
55 | simu_times=800,
56 | play_times=1,
57 | num_workers=64,
58 | device="cuda:0",
59 | noise=True)
60 | player.run() # Running forever...
--------------------------------------------------------------------------------
/record.txt:
--------------------------------------------------------------------------------
1 | 1. Formal description of how the network predicts actions
2 | 2. Feeding an all-zero tensor at prediction time
3 | 3. Formal description of MCTS
4 | 4. Which scalars to choose as network input
5 | 5. Should MCTS run without an upper limit?
6 |
7 | need to check:
8 | 1. Causal mask?
9 | 2. Canonicalization bug (fixed)
10 | 3. Free MCTS memory? (GPU memory keeps shrinking the longer it runs) (by deleting the other child nodes?) (also seems fixed)
11 | 4. examples reformulate
12 | 5. Change of basis and random sign permutation
13 |
14 | To do:
15 | 1. Change of basis
16 | 2. random permutation
17 | 3. data aug
18 | 4. self-play train (data format)
19 | 5. dataloader (ignored for now)
20 | 6. config files (a simple version is done)
21 | *7. Batched sampling at infer time? (seems fixed)
22 | 8. How should the scalars be fed in?
23 | 9. Run n trees in parallel? (multi-threading)
24 | 10. transition table
25 | 11. Exploration temperature
26 |
27 |
28 | 4/7:
29 | 1. Revised the attentive module (in parallel)
30 | 2. Revised the file layout
31 | 3. Added a dataloader (no obvious speedup)
32 | Everything related to data formulation goes into dataset!
33 | 4. Finished random sign permutation (needs debugging)
34 | 4/9:
35 | 1. Added val
36 | 2. Hypothesis: infer is where things go wrong -- does the causal mask need checking?? the sampling?? the position embedding??
37 | 3. Check the code against a reference transformer
38 | 4/10:
39 | 1. Added the start sign
40 | (N_steps+1, N_logits+1; the last logit slot encodes the start sign; the loss is computed over all N_steps+1 positions during training; only non-start logits are sampled at infer time)
41 | 2. Fixed a network bug (dropout)
42 | 3. Difference in dropout behavior between train and infer?
43 | 4/11:
44 | 1. Revised the LayerNorm
45 | 2. At infer time the first token is the start sign; the rest are padded with zeros
46 | 3. Changed the causal mask
47 | 4. The loss drops too fast... is g_action leaking again?
48 | 5. Removed the bias from some linear layers
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | --extra-index-url https://pypi.ngc.nvidia.com
2 |
3 | nvidia-cuda-runtime-cu11
4 | nvidia-cuda-cupti-cu11
5 | nvidia-cuda-nvcc-cu11
6 | nvidia-nvml-dev-cu11
7 | nvidia-cuda-nvrtc-cu11
8 | nvidia-nvtx-cu11
9 | nvidia-cuda-sanitizer-api-cu11
10 | nvidia-cublas-cu11
11 | nvidia-cufft-cu11
12 | nvidia-curand-cu11
13 | nvidia-cusolver-cu11
14 | nvidia-cusparse-cu11
15 | nvidia-npp-cu11
16 | nvidia-nvjpeg-cu11
--------------------------------------------------------------------------------