├── MERA ├── Framework.png ├── configs │ ├── config_transformer_moe_attn.yaml │ └── config_transformer_moe_vote.yaml ├── main_500.py ├── poster.pdf ├── readme.md ├── run.sh ├── src │ ├── custom_moe_layer.py │ ├── dataset.py │ ├── model_moe_attn.py │ ├── model_moe_vote.py │ ├── noisy_gate.py │ └── noisy_gate_vmoe.py └── vote.sh └── poster.pdf /MERA/Framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chenchen1104/MERA/b614fa6747bffa0488468c8d2f136cf0ea940ab4/MERA/Framework.png -------------------------------------------------------------------------------- /MERA/configs/config_transformer_moe_attn.yaml: -------------------------------------------------------------------------------- 1 | qlib_init: 2 | provider_uri: "~/.qlib/qlib_data/cn_data" 3 | region: cn 4 | 5 | 6 | model_config: &model_config 7 | input_size: 351 8 | hidden_size: 128 9 | num_layers: 2 10 | num_heads: 4 11 | use_attn: False 12 | dropout: 0.3 13 | num_expert: 8 14 | topk: 4 15 | gate_dim: 16 16 | 17 | num_states: &num_states 1 18 | 19 | tra_config: &tra_config 20 | num_states: *num_states 21 | hidden_size: 64 22 | tau: 1.0 23 | src_info: LR_TPE 24 | 25 | task: 26 | model: 27 | class: TRAModel 28 | module_path: src/model_moe_attn.py 29 | kwargs: 30 | lr: 0.0001 31 | n_epochs: 500 32 | max_steps_per_epoch: 200 33 | early_stop: 20 34 | seed: 1 35 | logdir: output/500/attn/seed1 36 | model_type: Transformer 37 | model_config: *model_config 38 | tra_config: *tra_config 39 | lamb: 1.0 40 | rho: 0.99 41 | freeze_model: False 42 | model_init_state: -------------------------------------------------------------------------------- /MERA/configs/config_transformer_moe_vote.yaml: -------------------------------------------------------------------------------- 1 | qlib_init: 2 | provider_uri: "~/.qlib/qlib_data/cn_data" 3 | region: cn 4 | 5 | 6 | model_config: &model_config 7 | input_size: 351 8 | hidden_size: 128 9 | num_layers: 2 10 | num_heads: 4 11 | use_attn: False 12 | dropout: 0.3 13 | num_expert: 32 14 | topk: 8 15 | gate_dim: 16 16 | 17 | num_states: &num_states 1 18 | 19 | tra_config: &tra_config 20 | num_states: *num_states 21 | hidden_size: 64 22 | tau: 1.0 23 | src_info: LR_TPE 24 | 25 | task: 26 | model: 27 | class: TRAModel 28 | module_path: src/model_moe_vote.py 29 | kwargs: 30 | lr: 0.0001 31 | n_epochs: 500 32 | max_steps_per_epoch: 200 33 | early_stop: 20 34 | seed: 1 35 | logdir: output/500/vote/seed1 36 | model_type: Transformer 37 | model_config: *model_config 38 | tra_config: *tra_config 39 | lamb: 1.0 40 | rho: 0.99 41 | freeze_model: False 42 | model_init_state: -------------------------------------------------------------------------------- /MERA/main_500.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import qlib 4 | from ruamel.yaml import YAML 5 | from qlib.utils import init_instance_by_config 6 | from src.dataset import MinDataset,collate_fn 7 | import h5py 8 | import numpy as np 9 | import torch 10 | 11 | def main(seed, config_file="./configs/config_transformer.yaml"): 12 | # set random seed 13 | with open(config_file) as f: 14 | yaml = YAML(typ='safe', pure=True) 15 | config=yaml.load(f) 16 | 17 | # seed_suffix = "/seed1000" if "init" in config_file else f"/seed{seed}" 18 | seed_suffix = "" 19 | config["task"]["model"]["kwargs"].update( 20 | {"seed": seed, "logdir": config["task"]["model"]["kwargs"]["logdir"] + seed_suffix} 21 | ) 22 | 23 | # initialize workflow 24 
| qlib.init( 25 | provider_uri=config["qlib_init"]["provider_uri"], 26 | region=config["qlib_init"]["region"], 27 | ) 28 | # dataset = init_instance_by_config(config["task"]["dataset"]) 29 | model = init_instance_by_config(config["task"]["model"]) 30 | 31 | f_feature = h5py.File('./500_processed/500_features.h5') 32 | f_label_raw = h5py.File('./500_processed/500_label_raw.h5') 33 | f_label_norm = h5py.File('./500_processed/500_label_norm.h5') 34 | f_similar = h5py.File('./500_processed/similars50.h5') 35 | 36 | features = {} 37 | similars = {} 38 | yraws = {} 39 | ynorms = {} 40 | dates = [] 41 | for key in f_feature.keys(): 42 | dates.append(key) 43 | feature = np.array(f_feature[key]) 44 | features[key] = feature 45 | 46 | similar = np.array(f_similar[key]) 47 | similars[key] = similar 48 | 49 | yraw = np.array(f_label_raw[key]) 50 | yraws[key] = yraw 51 | 52 | ynorm = np.array(f_label_norm[key]) 53 | ynorms[key] = ynorm 54 | 55 | train_dataset = MinDataset(dates=dates[:970], yraw=yraws, similar=similars, ynorm=ynorms, feature=features) 56 | valid_dataset = MinDataset(dates=dates[970:1116], yraw=yraws, similar=similars, ynorm=ynorms, feature=features) 57 | test_dataset = MinDataset(dates=dates[1116:],yraw=yraws, similar=similars, ynorm=ynorms, feature=features) 58 | 59 | model.fit(train_dataset, valid_dataset, test_dataset) 60 | # model.predict(test_dataset) 61 | 62 | 63 | if __name__ == "__main__": 64 | # set params from cmd 65 | parser = argparse.ArgumentParser(allow_abbrev=False) 66 | parser.add_argument("--seed", type=int, default=1000, help="random seed") 67 | parser.add_argument("--config_file", type=str, default="./configs/config_transformer.yaml", help="config file") 68 | args = parser.parse_args() 69 | main(**vars(args)) 70 | -------------------------------------------------------------------------------- /MERA/poster.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chenchen1104/MERA/b614fa6747bffa0488468c8d2f136cf0ea940ab4/MERA/poster.pdf -------------------------------------------------------------------------------- /MERA/readme.md: -------------------------------------------------------------------------------- 1 | # [WWW2025: Mixture of Experts with Retrieval-Augmented Representation for Modeling Diversified Stock Patterns] 2 | 3 | This repository contains the official implementation of our WWW 2025 paper: 4 | [[Mixture of Experts with Retrieval-Augmented Representation for Modeling Diversified Stock Patterns]](https://dl.acm.org/doi/10.1145/3701716.3715513) 5 | 6 | ## 📌 Abstract 7 | 8 | Successful quantitative investment relies on accurate predictions of the future stock price. Deep learning-based solutions have recently demonstrated a superior ability to capture the intricate and nonlinear interactions among various market variables. However, most existing methods use the same parameters to fit all samples, without considering that the real stock market often exhibits multiple patterns. To alleviate this issue, we propose a novel module called Mixture of Experts with Retrieval-Augmented Representation (MERA). Essentially, MERA consists of a set of independent experts for differentiated modeling as well as a GateNet that dynamically allocates data of different patterns to the most suitable experts. The model backbone is responsible for learning the coarse-grained representations for all stock patterns. 
Then, each expert in the MERA module focuses on a specific pattern and performs a more fine-grained analysis. However, accurate data allocation remains challenging due to the lack of explicit pattern identifiers. To overcome this, MERA retrieves relevant samples using high-level representations from self-supervised pre-training. The label information of these neighbor samples provides promising discriminative signals for identifying the target stock pattern. Extensive experiments on real-world stock markets show significant improvements.
9 | 
10 | ## 📌 Overview
11 | 
12 | ![Framework](Framework.png)
13 | 
14 | 
15 | ## 📌 Citation
16 | 
17 | ```bibtex
18 | @inproceedings{10.1145/3701716.3715513,
19 |   author = {Liu, YuJun and Song, Chen-Hui and Liu, Peiyuan and Li, Naiqi and Dai, Tao and Bao, Jigang and Jiang, Yong and Xia, Shu-Tao},
20 |   title = {MERA: Mixture of Experts with Retrieval-Augmented Representation for Modeling Diversified Stock Patterns},
21 |   year = {2025},
22 |   booktitle = {Companion Proceedings of the ACM on Web Conference 2025},
23 |   pages = {1148–1152},
24 |   numpages = {5},
25 |   keywords = {mixture of expert, retrieval-augmented representation, stock prediction},
26 |   location = {Sydney NSW, Australia},
27 |   series = {WWW '25}
28 | }
29 | ```
--------------------------------------------------------------------------------
/MERA/run.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=1 python main_500.py --config_file configs/config_transformer_moe_attn.yaml --seed 1
--------------------------------------------------------------------------------
/MERA/src/custom_moe_layer.py:
--------------------------------------------------------------------------------
1 | r"""
2 | Adaptation to act as the MLP layer using an MoE MLP layer in a transformer.
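Routing input is decoupled from the token stream: the gate receives a separate
representation (see FMoETransformerMLP.forward_moe below), which the MERA models
build from retrieved neighbor labels.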
3 | """ 4 | import torch 5 | import torch.nn as nn 6 | from fmoe.layers import FMoE, _fmoe_general_global_forward 7 | from fmoe.linear import FMoELinear 8 | from functools import partial 9 | import tree 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | 14 | from fmoe.functions import prepare_forward, ensure_comm 15 | from fmoe.functions import MOEScatter, MOEGather 16 | from fmoe.functions import AllGather, Slice 17 | from fmoe.gates import NaiveGate 18 | 19 | from src.noisy_gate import NoisyGate 20 | from src.noisy_gate_vmoe import NoisyGate_VMoE 21 | 22 | from pdb import set_trace 23 | import numpy as np 24 | 25 | 26 | def knn(x, k): 27 | inner = -2 * torch.matmul(x.transpose(2, 1), x) 28 | xx = torch.sum(x ** 2, dim=1, keepdim=True) 29 | pairwise_distance = -xx - inner - xx.transpose(2, 1) 30 | 31 | idx = pairwise_distance.topk(k=k, dim=-1)[1] 32 | return idx 33 | 34 | 35 | def get_graph_feature(x, k=20, idx=None): 36 | batch_size = x.size(0) 37 | num_points = x.size(2) 38 | x = x.view(batch_size, -1, num_points) 39 | if idx is None: 40 | idx = knn(x, k=k) 41 | device = torch.device('cuda') 42 | 43 | idx_base = torch.arange(0, batch_size, device=device).view(-1, 1, 1) * num_points 44 | 45 | idx = idx + idx_base # torch.Size([1, 64, 20]) 46 | 47 | idx = idx.view(-1) 48 | 49 | _, num_dims, _ = x.size() 50 | 51 | x = x.transpose(2, 1).contiguous() 52 | feature = x.view(batch_size * num_points, -1)[idx, :] # torch.Size([1280, 384]) 类似gather 53 | feature = feature.view(batch_size, num_points, k, num_dims) # torch.Size([1, 64, 20, 384]) 54 | x = x.view(batch_size, num_points, 1, num_dims).repeat(1, 1, k, 1) 55 | 56 | feature = torch.cat((feature - x, x), dim=3).permute(0, 3, 1, 2).contiguous() # torch.Size([1, 768, 64, 20]) 57 | 58 | return feature 59 | 60 | 61 | class _Expert(nn.Module): 62 | r""" 63 | An expert using 2 FMoELinear modules to speed up the computation of experts 64 | within one worker. 65 | """ 66 | 67 | def __init__(self, num_expert, d_model, d_hidden, activation, rank=0): 68 | super().__init__() 69 | self.htoh4 = FMoELinear(num_expert, d_model, d_hidden, bias=True, rank=rank) 70 | self.h4toh = FMoELinear(num_expert, d_hidden, d_model, bias=True, rank=rank) 71 | self.activation = activation 72 | 73 | def forward(self, inp, fwd_expert_count): 74 | r""" 75 | First expand input to 4h (the hidden size is variable, but is called h4 76 | for convenience). Then perform activation. Finally shirink back to h. 
77 | """ 78 | x = self.htoh4(inp, fwd_expert_count) 79 | x = self.activation(x) 80 | x = self.h4toh(x, fwd_expert_count) 81 | return x 82 | 83 | 84 | class Attention(nn.Module): 85 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 86 | super().__init__() 87 | self.num_heads = num_heads 88 | head_dim = dim // num_heads 89 | 90 | self.scale = qk_scale or head_dim ** -0.5 91 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 92 | self.attn_drop = nn.Dropout(attn_drop) 93 | self.proj = nn.Linear(dim, dim) 94 | self.proj_drop = nn.Dropout(proj_drop) 95 | 96 | def forward(self, x): 97 | B, N, C = x.shape 98 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 99 | q, k, v = qkv[0], qkv[1], qkv[2] 100 | 101 | attn = (q @ k.transpose(-2, -1)) * self.scale 102 | attn = attn.softmax(dim=-1) 103 | attn = self.attn_drop(attn) 104 | 105 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 106 | x = self.proj(x) 107 | x = self.proj_drop(x) 108 | return x 109 | 110 | 111 | class Mlp(nn.Module): 112 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0., norm_layer= partial(nn.LayerNorm, eps=1e-6)): 113 | super().__init__() 114 | # out_features = out_features or in_features 115 | # hidden_features = hidden_features or in_features 116 | self.fc1 = nn.Linear(in_features, hidden_features) 117 | self.act = act_layer() 118 | self.fc2 = nn.Linear(hidden_features, out_features) 119 | self.drop = nn.Dropout(drop) 120 | self.norm = norm_layer(out_features) 121 | 122 | def forward(self, x): 123 | x = self.fc1(x) 124 | x = self.act(x) 125 | x = self.drop(x) 126 | x = self.fc2(x) 127 | x = self.drop(x) 128 | x = self.norm(x) 129 | return x 130 | 131 | 132 | class SelfAttention(nn.Module): 133 | def __init__(self, input_dim, hidden_dim): 134 | super(SelfAttention, self).__init__() 135 | self.input_dim = input_dim 136 | self.hidden_dim = hidden_dim 137 | self.query = nn.Linear(input_dim, hidden_dim) 138 | self.key = nn.Linear(input_dim, hidden_dim) 139 | self.value = nn.Linear(input_dim, hidden_dim) 140 | 141 | def forward(self, x): 142 | # 计算查询、键和值 143 | query = self.query(x) 144 | key = self.key(x) 145 | value = self.value(x) 146 | 147 | # 计算注意力得分 148 | scores = torch.matmul(query, key.transpose(-2, -1)) 149 | scores = scores / torch.sqrt(torch.tensor(self.hidden_dim).float()) 150 | attention_weights = torch.softmax(scores, dim=-1) 151 | # 计算加权和 152 | weighted_values = torch.matmul(attention_weights, value) 153 | output = torch.sum(weighted_values, dim=1) # 对第二维进行求和,得到32*hidden_dim的表示 154 | 155 | return output, attention_weights 156 | 157 | 158 | 159 | class FMoETransformerMLP(FMoE): 160 | r""" 161 | A complete MoE MLP module in a Transformer block. 162 | * `activation` is the activation function to be used in MLP in each expert. 163 | * `d_hidden` is the dimension of the MLP layer. 
164 | """ 165 | 166 | def __init__( 167 | self, 168 | num_expert=32, 169 | d_model=1024, 170 | d_gate=1024, 171 | d_hidden=4096, 172 | activation=torch.nn.GELU(), 173 | expert_dp_comm="none", 174 | expert_rank=0, 175 | gate=NaiveGate, 176 | world_size=1, 177 | top_k=2, 178 | vmoe_noisy_std=1, 179 | gate_return_decoupled_activation=False, 180 | gate_task_specific_dim=-1, 181 | multi_gate=False, 182 | regu_experts_fromtask = False, 183 | num_experts_pertask = -1, 184 | num_tasks = -1, 185 | regu_sem = False, 186 | sem_force = False, 187 | regu_subimage = False, 188 | expert_prune = False, 189 | prune_threshold = 0.1, 190 | **kwargs 191 | ): 192 | super().__init__(num_expert=num_expert, d_model=d_model, gate=gate, world_size=world_size, top_k=top_k, **kwargs) 193 | self.our_d_gate = d_gate 194 | self.our_d_model = d_model 195 | 196 | self.num_expert = num_expert 197 | self.regu_experts_fromtask = regu_experts_fromtask 198 | self.num_experts_pertask = num_experts_pertask 199 | self.num_tasks = num_tasks 200 | self.regu_sem = regu_sem 201 | self.sem_force = sem_force 202 | self.regu_subimage = regu_subimage 203 | self.expert_prune = expert_prune 204 | self.prune_threshold = prune_threshold 205 | if self.sem_force: 206 | self.force_id=[[0],[1,17,18,19,20],[2,12,13,14,15,16],[3,9,10,11],[4,5],[6,7,8,38],[21,22,23,24,25,26,39],[27,28,29,30,31,32,33,34,35,36,37]] 207 | if self.regu_experts_fromtask: 208 | self.start_experts_id=[] 209 | start_id = 0 210 | for i in range(self.num_tasks): 211 | start_id = start_id + int(i* (self.num_expert-self.num_experts_pertask)/(self.num_tasks-1)) 212 | self.start_experts_id.append(start_id) 213 | print('self.start_experts_id',self.start_experts_id) 214 | 215 | # self.experts = _Expert( 216 | # num_expert, d_model, d_hidden, activation, rank=expert_rank 217 | # ) 218 | self.e=nn.GRU( 219 | input_size=d_model, 220 | hidden_size=d_model, 221 | num_layers=2, 222 | batch_first=True, 223 | ) 224 | self.experts = nn.ModuleList([self.e for i in range(num_expert)]) 225 | # self.experts = nn.ModuleList([nn.GRU( 226 | # input_size=d_model, 227 | # hidden_size=d_model, 228 | # num_layers=2, 229 | # batch_first=True, 230 | # ) for i in range(num_expert)]) 231 | 232 | self.gate_task_specific_dim = gate_task_specific_dim 233 | self.multi_gate = multi_gate 234 | 235 | print('multi_gate',self.multi_gate) 236 | if gate == NoisyGate: 237 | if self.multi_gate: 238 | self.gate = nn.ModuleList([ 239 | gate(d_gate, num_expert, world_size, top_k, 240 | return_decoupled_activation=gate_return_decoupled_activation, regu_experts_fromtask = self.regu_experts_fromtask, 241 | num_experts_pertask = self.num_experts_pertask,num_tasks = self.num_tasks, regu_sem=self.regu_sem,sem_force = self.sem_force) 242 | for i in range(self.our_d_gate-self.our_d_model)]) 243 | else: 244 | self.gate = gate(d_gate, num_expert, world_size, top_k, 245 | return_decoupled_activation=gate_return_decoupled_activation, regu_experts_fromtask = self.regu_experts_fromtask, 246 | num_experts_pertask = self.num_experts_pertask,num_tasks = self.num_tasks) 247 | elif gate == NoisyGate_VMoE: 248 | if self.multi_gate: 249 | self.gate = nn.ModuleList([ 250 | gate(d_gate, num_expert, world_size, top_k, 251 | return_decoupled_activation=gate_return_decoupled_activation, 252 | noise_std=vmoe_noisy_std,regu_experts_fromtask = self.regu_experts_fromtask, 253 | num_experts_pertask=self.num_experts_pertask, num_tasks=self.num_tasks,regu_sem=self.regu_sem,sem_force = self.sem_force, regu_subimage=self.regu_subimage) 254 | for i in 
range(self.our_d_gate-self.our_d_model)])
255 |             else:
256 |                 self.gate = gate(d_gate, num_expert, world_size, top_k,
257 |                     return_decoupled_activation=gate_return_decoupled_activation,
258 |                     noise_std=vmoe_noisy_std,regu_experts_fromtask = self.regu_experts_fromtask,
259 |                     num_experts_pertask = self.num_experts_pertask, num_tasks = self.num_tasks,regu_sem=self.regu_sem,sem_force = self.sem_force, regu_subimage=self.regu_subimage)
260 | 
261 |         else:
262 |             raise ValueError("No such gating type")
263 |         self.mark_parallel_comm(expert_dp_comm)
264 | 
265 |         self.count = [0]*num_expert
266 |         self.score = torch.zeros(1, num_expert).cuda()
267 | 
268 | 
269 |     def my_expert_fn(self, inp, fwd_expert_count):
270 | 
271 |         if isinstance(fwd_expert_count, torch.Tensor):
272 |             fwd_expert_count_cpu = fwd_expert_count.cpu().numpy()
273 |         outputs = []
274 |         base_idx = 0
275 |         for i in range(self.num_expert):
276 |             batch_size = fwd_expert_count_cpu[i]
277 |             if batch_size == 0:
278 |                 continue
279 |             inp_slice = inp[base_idx : base_idx + batch_size]
280 |             out,_=self.experts[i](inp_slice)
281 |             # out=self.experts[i](inp_slice)
282 |             outputs.append(out)
283 |             base_idx += batch_size
284 |         return torch.cat(outputs, dim=0)
285 | 
286 | 
287 |     def forward(self, inp: torch.Tensor, src_mask=None, is_causal=None, src_key_padding_mask=None, gate_inp=None, task_id = None, task_specific_feature = None, sem=None):
288 |         r"""
289 |         This module wraps up the FMoE module with reshape, residual and layer
290 |         normalization.
291 |         """
292 |         if (task_id is not None) and (task_specific_feature is not None):
293 |             assert self.multi_gate is False
294 |             size = gate_inp.shape[0]
295 |             gate_inp = torch.cat((gate_inp,task_specific_feature.repeat(size,1)),dim=-1)
296 |         output = self.forward_moe(gate_inp=gate_inp, moe_inp=inp, task_id=task_id, sem=sem)
297 |         return output
298 | 
299 | 
300 |     def forward_moe(self, gate_inp, moe_inp, task_id=None, sem=None):
301 |         r"""
302 |         The FMoE module first computes the gate output, then conducts the MoE
303 |         forward pass according to the gate. The score the gate assigns to each
304 |         selected expert is multiplied with that expert's output tensor as a weight.
305 |         """
306 |         moe_inp_batch_size = tree.flatten(tree.map_structure(lambda tensor: tensor.shape[0], moe_inp))
307 |         assert all(
308 |             [batch_size == moe_inp_batch_size[0] for batch_size in moe_inp_batch_size]
309 |         ), "MoE inputs must have the same batch size"
310 | 
311 |         if self.world_size > 1:
312 |             def ensure_comm_func(tensor):
313 |                 ensure_comm(tensor, self.moe_group)
314 |             tree.map_structure(ensure_comm_func, moe_inp)
315 |             tree.map_structure(ensure_comm_func, gate_inp)
316 |         if self.slice_size > 1:
317 |             def slice_func(tensor):
318 |                 return Slice.apply(tensor, self.slice_rank, self.slice_size, self.slice_group)
319 |             moe_inp = tree.map_structure(slice_func, moe_inp)
320 | 
321 |         if (task_id is not None) and self.multi_gate:
322 |             # print('in custom moe_layer,task_id',task_id)
323 |             gate_top_k_idx, gate_score = self.gate[task_id](gate_inp)
324 |         else:
325 |             gate_top_k_idx, gate_score = self.gate(gate_inp)
326 |         # print(gate_top_k_idx)
327 |         # print(gate_score)
328 |         # self.score += torch.sum(gate_score, dim=0, keepdim=True)
329 | 
330 |         # track how often each expert is used
331 |         # from collections import Counter
332 |         # counts = Counter(gate_top_k_idx.reshape(-1,1))
333 |         # for num, count in counts.items():
334 |         #     self.count[num]+=count
335 | 
336 |         if self.expert_prune:
337 |             gate_score = torch.where(gate_score>self.prune_threshold,gate_score,0.)
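            # share of gate entries zeroed out by the threshold above, reported as prune_prob below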
338 | prune_prob = 1-torch.nonzero(gate_score).shape[0]/torch.cumprod(torch.tensor(gate_score.shape),dim=0)[-1] 339 | print('prune_prob',prune_prob) 340 | 341 | if self.sem_force and (sem is not None): 342 | batch = sem.shape[0] 343 | gate_top_k_idx = gate_top_k_idx.reshape(batch,-1,self.top_k) 344 | sem = sem.reshape(batch,-1) 345 | for k in range(batch): 346 | for i in range(sem.shape[-1]): 347 | for j in range(len(self.force_id)): 348 | if sem[k,i] in self.force_id[j]: 349 | gate_top_k_idx[k,i+1,:]=[j*2,j*2+1] 350 | gate_top_k_idx = gate_top_k_idx.reshape(-1,self.top_k) 351 | gate_score = torch.ones((gate_score.shape[0],self.top_k),device=gate_score.device)*0.5 352 | 353 | if self.regu_experts_fromtask and (task_id is not None): 354 | # print('task_id',self.start_experts_id[task_id],task_id) 355 | gate_top_k_idx = gate_top_k_idx + self.start_experts_id[task_id] 356 | 357 | if self.gate_hook is not None: 358 | self.gate_hook(gate_top_k_idx, gate_score, None) 359 | 360 | # delete masked tensors 361 | if self.mask is not None and self.mask_dict is not None: 362 | # TODO: to fix 363 | def delete_mask_func(tensor): 364 | # to: (BxL') x d_model 365 | tensor = tensor[mask == 0, :] 366 | return tensor 367 | 368 | mask = self.mask.view(-1) 369 | moe_inp = tree.map_structure(delete_mask_func, moe_inp) 370 | gate_top_k_idx = gate_top_k_idx[mask == 0, :] 371 | 372 | # fwd = _fmoe_general_global_forward(moe_inp, gate_top_k_idx, self.expert_fn, self.num_expert, self.world_size) 373 | fwd = _fmoe_general_global_forward(moe_inp, gate_top_k_idx, self.my_expert_fn, self.num_expert, self.world_size) 374 | 375 | # recover deleted tensors 376 | if self.mask is not None and self.mask_dict is not None: 377 | def recover_func(tensor): 378 | # to: (BxL') x top_k x dim 379 | dim = tensor.shape[-1] 380 | tensor = tensor.view(-1, self.top_k, dim) 381 | # to: (BxL) x top_k x d_model 382 | x = torch.zeros( 383 | mask.shape[0], 384 | self.top_k, 385 | dim, 386 | device=tensor.device, 387 | dtype=tensor.dtype, 388 | ) 389 | # recover 390 | x[mask == 0] = tensor 391 | for k, v in self.mask_dict.items(): 392 | x[mask == k] = v 393 | return x 394 | moe_outp = tree.map_structure(recover_func, fwd) 395 | else: 396 | def view_func(tensor): 397 | tensor = tensor.view(moe_inp_batch_size[0], self.top_k, tensor.shape[1], tensor.shape[2]) 398 | return tensor 399 | moe_outp = tree.map_structure(view_func, fwd) 400 | 401 | gate_score = gate_score.view(-1, self.top_k, 1, 1) 402 | 403 | moe_outp = torch.sum(moe_outp*gate_score, dim=1) 404 | 405 | if self.slice_size > 1: 406 | 407 | def all_gather_func(tensor): 408 | return AllGather.apply(tensor, self.slice_rank, self.slice_size, self.slice_group) 409 | 410 | moe_outp = tree.map_structure(all_gather_func, moe_outp) 411 | 412 | moe_outp_batch_size = tree.flatten(tree.map_structure(lambda tensor: tensor.shape[0], moe_outp)) 413 | assert all( 414 | [batch_size == moe_outp_batch_size[0] for batch_size in moe_outp_batch_size] 415 | ), "MoE outputs must have the same batch size" 416 | return moe_outp 417 | -------------------------------------------------------------------------------- /MERA/src/dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch.utils.data import Dataset 4 | 5 | device = "cuda" if torch.cuda.is_available() else "cpu" 6 | 7 | def collate_fn(batch): 8 | date = [] 9 | feature = [] 10 | similar = [] 11 | yraw = [] 12 | ynorm = [] 13 | 14 | for sample in batch: 15 | date += 
sample["date"] 16 | feature.extend(sample["feature"]) 17 | similar.extend(sample["similar"]) 18 | yraw.extend(sample["yraw"]) 19 | ynorm.extend(sample["ynorm"]) 20 | 21 | return torch.FloatTensor(np.stack(feature, axis=0)), torch.FloatTensor(np.stack(similar, axis=0)), torch.FloatTensor(np.stack(yraw, axis=0)), torch.FloatTensor(np.stack(ynorm, axis=0)), date 22 | 23 | 24 | class MinDataset(Dataset): 25 | def __init__(self, dates, yraw, ynorm, feature, similar): 26 | self.dates = dates 27 | self.yraw = yraw 28 | self.ynorm = ynorm 29 | self.feature = feature 30 | self.similar = similar 31 | 32 | def __len__(self): 33 | return len(self.dates) 34 | 35 | def __getitem__(self, idx): 36 | date = self.dates[idx] 37 | feature = torch.FloatTensor(self.feature[date]) 38 | similar = torch.FloatTensor(self.similar[date]) 39 | yraw = torch.FloatTensor(self.yraw[date]) 40 | ynorm = torch.FloatTensor(self.ynorm[date]) 41 | 42 | return { 43 | "date": [date for _ in range(len(yraw))], 44 | "feature": feature, 45 | "similar": similar, 46 | "yraw": yraw, 47 | "ynorm": ynorm, 48 | } 49 | -------------------------------------------------------------------------------- /MERA/src/model_moe_attn.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | import os 5 | import copy 6 | import math 7 | import json 8 | import collections 9 | import numpy as np 10 | import pandas as pd 11 | 12 | import torch 13 | import torch.nn as nn 14 | import torch.optim as optim 15 | import torch.nn.functional as F 16 | 17 | from tqdm import tqdm 18 | from torch.utils.data import DataLoader, random_split 19 | 20 | from qlib.utils import get_or_create_path 21 | from qlib.log import get_module_logger 22 | from qlib.model.base import Model 23 | 24 | from src.custom_moe_layer import FMoETransformerMLP 25 | from src.noisy_gate import NoisyGate 26 | from src.noisy_gate_vmoe import NoisyGate_VMoE 27 | 28 | from src.dataset import collate_fn 29 | 30 | device = "cuda" if torch.cuda.is_available() else "cpu" 31 | 32 | 33 | class TRAModel(Model): 34 | def __init__( 35 | self, 36 | model_config, 37 | tra_config, 38 | model_type="LSTM", 39 | lr=1e-3, 40 | n_epochs=500, 41 | early_stop=50, 42 | smooth_steps=5, 43 | max_steps_per_epoch=None, 44 | freeze_model=False, 45 | model_init_state=None, 46 | lamb=0.0, 47 | rho=0.99, 48 | seed=None, 49 | logdir=None, 50 | eval_train=True, 51 | eval_test=True, 52 | avg_params=True, 53 | **kwargs, 54 | ): 55 | np.random.seed(seed) 56 | torch.manual_seed(seed) 57 | 58 | self.logger = get_module_logger("TRA") 59 | self.logger.info("TRA Model...") 60 | 61 | self.model = eval(model_type)(**model_config).to(device) 62 | if model_init_state: 63 | self.model.load_state_dict(torch.load(model_init_state, map_location="cuda")["model"],strict=False) 64 | if freeze_model: 65 | # for param in self.model.parameters(): 66 | # param.requires_grad_(False) 67 | for name, param in self.model.named_parameters(): 68 | if 'experts' in name or 'gate' in name or 'bn' in name or 'norm' in name or 'input_proj' in name or 'embedding' in name: 69 | print(name) 70 | param.requires_grad_(True) 71 | else: 72 | param.requires_grad_(False) 73 | 74 | else: 75 | self.logger.info("# model params: %d" % sum([p.numel() for p in self.model.parameters()])) 76 | 77 | self.tra = TRA(self.model.output_size, **tra_config).to(device) 78 | self.logger.info("# tra params: %d" % sum([p.numel() for p in self.tra.parameters()])) 79 | 80 | 
self.optimizer = optim.Adam(list(self.model.parameters()) + list(self.tra.parameters()), lr=lr)
81 | 
82 |         self.model_config = model_config
83 |         self.tra_config = tra_config
84 |         self.lr = lr
85 |         self.n_epochs = n_epochs
86 |         self.early_stop = early_stop
87 |         self.smooth_steps = smooth_steps
88 |         self.max_steps_per_epoch = max_steps_per_epoch
89 |         self.lamb = lamb
90 |         self.rho = rho
91 |         self.seed = seed
92 |         self.logdir = logdir
93 |         self.eval_train = eval_train
94 |         self.eval_test = eval_test
95 |         self.avg_params = avg_params
96 | 
97 |         if self.tra.num_states > 1 and not self.eval_train:
98 |             self.logger.warn("`eval_train` will be ignored when using TRA")
99 | 
100 |         if self.logdir is not None:
101 |             if os.path.exists(self.logdir):
102 |                 self.logger.warn(f"logdir {self.logdir} is not empty")
103 |             os.makedirs(self.logdir, exist_ok=True)
104 | 
105 |         self.fitted = False
106 |         self.global_step = -1
107 | 
108 |     def collect_noisy_gating_loss(self, model, weight):
109 |         loss = 0
110 |         for module in model.modules():
111 |             # print(module)
112 |             if (isinstance(module, NoisyGate) or isinstance(module, NoisyGate_VMoE)) and module.has_loss:
113 |                 loss += module.get_loss()
114 |         return loss * weight
115 | 
116 |     def train_epoch(self, data_set):
117 |         self.model.train()
118 |         self.tra.train()
119 | 
120 |         count = 0
121 |         total_loss = 0
122 |         total_count = 0
123 | 
124 |         for i, batch in enumerate(data_set):
125 | 
126 |             self.global_step += 1
127 | 
128 |             feature, similar, label_raw, label_norm, index = batch
129 | 
130 |             feature = feature.to(device)
131 |             similar = similar.to(device)
132 |             label_raw = label_raw.to(device)
133 |             label_norm = label_norm.to(device)
134 | 
135 |             feature = feature.permute(0, 2, 1)
136 |             hidden = self.model(feature, similar)
137 |             pred, all_preds, prob = self.tra(hidden)
138 | 
139 |             mask = ~torch.isnan(label_norm)
140 | 
141 |             # filter out NaN values, keeping only the predictions and labels at non-NaN positions
142 |             filtered_pred = pred[mask]
143 |             filtered_label_norm = label_norm[mask]
144 | 
145 |             loss = (filtered_pred - filtered_label_norm).pow(2).mean()
146 |             # loss = (pred - label_norm).pow(2).mean()
147 | 
148 |             gate_loss = self.collect_noisy_gating_loss(self.model, 0.01)
149 |             loss += gate_loss
150 |             loss.backward()
151 |             self.optimizer.step()
152 |             self.optimizer.zero_grad()
153 | 
154 |             total_loss += loss.item()
155 |             total_count += len(pred)
156 | 
157 |         total_loss /= total_count
158 |         print(total_loss)
159 |         return total_loss
160 | 
161 |     def test_epoch(self, data_set, return_pred=False):
162 |         self.model.eval()
163 |         self.tra.eval()
164 | 
165 |         preds = []
166 |         metrics = []
167 |         for i, batch in enumerate(data_set):
168 |             # print(i)
169 |             feature, similar, label_raw, label_norm, index = batch
170 | 
171 |             feature = feature.to(device)
172 |             similar=similar.to(device)
173 |             label_raw = label_raw.to(device)
174 |             label_norm = label_norm.to(device)
175 | 
176 |             feature = feature.permute(0, 2, 1)
177 | 
178 |             with torch.no_grad():
179 |                 hidden = self.model(feature, similar)
180 |                 pred, all_preds, prob = self.tra(hidden)
181 |             X = np.c_[
182 |                 pred.cpu().numpy(),
183 |                 label_raw.cpu().numpy(),
184 |             ]
185 | 
186 |             columns = ["score", "label"]
187 |             pred = pd.DataFrame(X, index = index, columns = columns)
188 | 
189 |             metrics.append(evaluate(pred))
190 | 
191 |             if return_pred:
192 |                 preds.append(pred)
193 | 
194 |         metrics = pd.DataFrame(metrics)
195 |         metrics = {
196 |             "MSE": metrics.MSE.mean(),
197 |             "MAE": metrics.MAE.mean(),
198 |             "IC": metrics.IC.mean(),
199 |             "ICIR": metrics.IC.mean() / metrics.IC.std(),
200 |         }
201 | 
202 |         if return_pred:
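            # stitch the per-day prediction frames into one date-indexed frame before returning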
203 | preds = pd.concat(preds, axis=0) 204 | preds.sort_index(inplace=True) 205 | 206 | return metrics, preds 207 | 208 | def fit(self, train_dataset, val_dataset, test_dataset, evals_result=dict()): 209 | 210 | train_loader = DataLoader(train_dataset, batch_size = 2, collate_fn = collate_fn, shuffle = True) 211 | val_loader = DataLoader(val_dataset, batch_size = 1, collate_fn = collate_fn, shuffle = False) 212 | test_loader = DataLoader(test_dataset, batch_size = 1, collate_fn = collate_fn, shuffle = False) 213 | 214 | best_score = -1 215 | best_epoch = 0 216 | stop_rounds = 0 217 | best_params = { 218 | "model": copy.deepcopy(self.model.state_dict()), 219 | "tra": copy.deepcopy(self.tra.state_dict()), 220 | } 221 | params_list = { 222 | "model": collections.deque(maxlen=self.smooth_steps), 223 | "tra": collections.deque(maxlen=self.smooth_steps), 224 | } 225 | evals_result["train"] = [] 226 | evals_result["valid"] = [] 227 | evals_result["test"] = [] 228 | 229 | # train 230 | self.fitted = True 231 | self.global_step = -1 232 | 233 | for epoch in range(self.n_epochs): 234 | self.logger.info("Epoch %d:", epoch) 235 | 236 | self.logger.info("training...") 237 | self.train_epoch(train_loader) 238 | 239 | self.logger.info("evaluating...") 240 | # average params for inference 241 | params_list["model"].append(copy.deepcopy(self.model.state_dict())) 242 | params_list["tra"].append(copy.deepcopy(self.tra.state_dict())) 243 | self.model.load_state_dict(average_params(params_list["model"])) 244 | self.tra.load_state_dict(average_params(params_list["tra"])) 245 | 246 | # NOTE: during evaluating, the whole memory will be refreshed 247 | if self.tra.num_states > 1 or self.eval_train: 248 | train_metrics = self.test_epoch(train_loader)[0] 249 | evals_result["train"].append(train_metrics) 250 | self.logger.info("\ttrain metrics: %s" % train_metrics) 251 | 252 | valid_metrics = self.test_epoch(val_loader)[0] 253 | evals_result["valid"].append(valid_metrics) 254 | self.logger.info("\tvalid metrics: %s" % valid_metrics) 255 | 256 | if self.eval_test: 257 | test_metrics = self.test_epoch(test_loader)[0] 258 | evals_result["test"].append(test_metrics) 259 | self.logger.info("\ttest metrics: %s" % test_metrics) 260 | 261 | if valid_metrics["IC"] > best_score: 262 | best_score = valid_metrics["IC"] 263 | stop_rounds = 0 264 | best_epoch = epoch 265 | best_params = { 266 | "model": copy.deepcopy(self.model.state_dict()), 267 | "tra": copy.deepcopy(self.tra.state_dict()), 268 | } 269 | else: 270 | stop_rounds += 1 271 | if stop_rounds >= self.early_stop: 272 | self.logger.info("early stop @ %s" % epoch) 273 | break 274 | 275 | # restore parameters 276 | self.model.load_state_dict(params_list["model"][-1]) 277 | self.tra.load_state_dict(params_list["tra"][-1]) 278 | 279 | self.logger.info("best score: %.6lf @ %d" % (best_score, best_epoch)) 280 | self.model.load_state_dict(best_params["model"]) 281 | self.tra.load_state_dict(best_params["tra"]) 282 | 283 | metrics, preds = self.test_epoch(test_loader, return_pred=True) 284 | self.logger.info("test metrics: %s" % metrics) 285 | 286 | if self.logdir: 287 | self.logger.info("save model & pred to local directory") 288 | 289 | pd.concat({name: pd.DataFrame(evals_result[name]) for name in evals_result}, axis=1).to_csv( 290 | self.logdir + "/logs.csv", index=False 291 | ) 292 | 293 | torch.save(best_params, self.logdir + "/model.bin") 294 | 295 | preds.to_pickle(self.logdir + "/pred.pkl") 296 | 297 | info = { 298 | "config": { 299 | "model_config": self.model_config, 
300 | "tra_config": self.tra_config, 301 | "lr": self.lr, 302 | "n_epochs": self.n_epochs, 303 | "early_stop": self.early_stop, 304 | "smooth_steps": self.smooth_steps, 305 | "lamb": self.lamb, 306 | "rho": self.rho, 307 | "seed": self.seed, 308 | "logdir": self.logdir, 309 | }, 310 | "best_eval_metric": -best_score, # NOTE: minux -1 for minimize 311 | "metric": metrics, 312 | } 313 | with open(self.logdir + "/info.json", "w") as f: 314 | json.dump(info, f) 315 | 316 | def predict(self, test_dataset, model_path): 317 | test_loader = DataLoader(test_dataset, batch_size = 1, collate_fn = collate_fn, shuffle = False) 318 | best_params = torch.load(model_path) 319 | self.model.load_state_dict(best_params["model"]) 320 | self.tra.load_state_dict(best_params["tra"]) 321 | metrics, preds = self.test_epoch(test_loader, return_pred=True) 322 | print(self.model.moe.score) 323 | self.logger.info("test metrics: %s" % metrics) 324 | return preds 325 | 326 | 327 | class PositionalEncoding(nn.Module): 328 | # reference: https://pytorch.org/tutorials/beginner/transformer_tutorial.html 329 | def __init__(self, d_model, dropout=0.1, max_len=5000): 330 | super(PositionalEncoding, self).__init__() 331 | self.dropout = nn.Dropout(p=dropout) 332 | 333 | pe = torch.zeros(max_len, d_model) 334 | position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) 335 | div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) 336 | pe[:, 0::2] = torch.sin(position * div_term) 337 | pe[:, 1::2] = torch.cos(position * div_term) 338 | pe = pe.unsqueeze(0).transpose(0, 1) 339 | self.register_buffer("pe", pe) 340 | self.pe = self.pe.to(device) 341 | 342 | def forward(self, x): 343 | x = x + self.pe[: x.size(0), :] 344 | return self.dropout(x) 345 | 346 | 347 | class Transformer(nn.Module): 348 | 349 | def __init__( 350 | self, 351 | input_size=16, 352 | hidden_size=64, 353 | num_layers=2, 354 | num_heads=2, 355 | dropout=0.0, 356 | input_drop=0.0, 357 | noise_level=0.0, 358 | topk=1, 359 | num_expert=4, 360 | gate_dim=16, 361 | moe_gate_type='noisy_vmoe', 362 | vmoe_noisy_std=1, 363 | **kwargs, 364 | ): 365 | super().__init__() 366 | 367 | self.input_size = input_size 368 | self.hidden_size = hidden_size 369 | self.num_layers = num_layers 370 | self.num_heads = num_heads 371 | self.noise_level = noise_level 372 | self.gate_dim=gate_dim 373 | 374 | # self.input_drop = nn.Dropout(input_drop) 375 | 376 | self.input_proj = nn.Linear(input_size, hidden_size) 377 | 378 | self.pe = PositionalEncoding(hidden_size, dropout) 379 | 380 | if moe_gate_type == "noisy": 381 | moe_gate_fun = NoisyGate 382 | elif moe_gate_type == "noisy_vmoe": 383 | moe_gate_fun = NoisyGate_VMoE 384 | else: 385 | raise ValueError("unknow gate type of {}".format(moe_gate_type)) 386 | 387 | act_layer=nn.GELU 388 | activation = nn.Sequential( 389 | act_layer(), 390 | nn.Dropout(dropout) 391 | ) 392 | 393 | blocks = [] 394 | for i in range(num_layers): 395 | blocks.append(nn.TransformerEncoderLayer(nhead = num_heads, dropout = dropout, d_model = hidden_size, dim_feedforward = hidden_size * 4, norm_first=True)) 396 | 397 | self.moe = FMoETransformerMLP(num_expert = num_expert, d_model = hidden_size, d_gate = gate_dim, gate = moe_gate_fun, 398 | top_k = topk, activation = activation, vmoe_noisy_std=vmoe_noisy_std, expert_prune=False) 399 | 400 | self.encoder = nn.Sequential(*blocks) 401 | 402 | self.output_size = hidden_size 403 | # self.norm = nn.LayerNorm(hidden_size) 404 | # self.moe_norm = nn.LayerNorm(hidden_size//8) 405 
405 |         self.bn = nn.BatchNorm1d(input_size)
406 | 
407 |         self.embedding = nn.Embedding(10, gate_dim)
408 | 
409 |     def forward(self, x, similars):
410 |         shape = x.shape
411 |         x = x.reshape(-1,self.input_size)
412 |         x = self.bn(x)
413 |         x = x.reshape(shape)
414 | 
415 |         x = x.permute(1, 0, 2).contiguous()  # the first dim needs to be the sequence dim
416 | 
417 |         x = self.input_proj(x)
418 |         x = self.pe(x)
419 | 
420 |         for i, layer in enumerate(self.encoder):
421 |             x = layer(x)
422 | 
423 |         out = x
424 |         x = x[-1]
425 |         x = x.unsqueeze(1)
426 | 
427 |         similar_features = similars[:,:,:-2]
428 |         attn_weights_feature = F.softmax(torch.bmm(x, similar_features.transpose(1, 2)), dim=-1)
429 | 
430 |         similar_label = similars[:,:,-2]
431 | 
432 |         similar_label = self.embedding(similar_label.int())
433 |         weighted_similar_label = torch.bmm(attn_weights_feature, similar_label)  # torch.Size([bs, 1, 128])
434 |         weighted_similar_label = weighted_similar_label.squeeze(1)
435 |         output = self.moe(inp = out.permute(1, 0, 2), gate_inp = weighted_similar_label)
436 | 
437 |         return output[:,-1]
438 | 
439 | class TRA(nn.Module):
440 | 
441 |     """Temporal Routing Adaptor (TRA)
442 | 
443 |     TRA takes historical prediction errors & latent representation as inputs,
444 |     then routes the input sample to a specific predictor for training & inference.
445 | 
446 |     Args:
447 |         input_size (int): input size (RNN/Transformer's hidden size)
448 |         num_states (int): number of latent states (i.e., trading patterns)
449 |             If `num_states=1`, then TRA falls back to traditional methods
450 |         hidden_size (int): hidden size of the router
451 |         tau (float): gumbel softmax temperature
452 |     """
453 | 
454 |     def __init__(self, input_size, num_states=1, hidden_size=8, tau=1.0, src_info="LR_TPE"):
455 |         super().__init__()
456 | 
457 |         self.num_states = num_states
458 |         self.tau = tau
459 |         self.src_info = src_info
460 | 
461 |         if num_states > 1:
462 |             self.router = nn.LSTM(
463 |                 input_size=num_states,
464 |                 hidden_size=hidden_size,
465 |                 num_layers=1,
466 |                 batch_first=True,
467 |             )
468 |             self.fc = nn.Linear(hidden_size + input_size, num_states)
469 | 
470 |         self.predictors1 = nn.Linear(input_size, hidden_size)
471 |         self.predictors2 = nn.Linear(hidden_size, num_states)
472 |         self.relu=nn.ReLU()
473 | 
474 |     def forward(self, hidden, hist_loss=None):
475 |         preds = self.predictors2(self.relu(self.predictors1(hidden)))
476 |         return preds.squeeze(-1), preds, None
477 | 
478 | 
479 | def evaluate(pred):
480 |     # pred = pred.rank(pct=True)  # transform into percentiles
481 |     pred = pred.dropna(subset=['label'])
482 |     score = pred.score
483 |     label = pred.label
484 |     diff = score - label
485 |     MSE = (diff**2).mean()
486 |     MAE = (diff.abs()).mean()
487 |     IC = score.corr(label)
488 |     # return {"MSE": MSE.astype(np.float64), "MAE": MAE.astype(np.float64), "IC": IC.astype(np.float64)}
489 |     return {"MSE": MSE, "MAE": MAE, "IC": IC}
490 | 
491 | 
492 | def average_params(params_list):
493 |     assert isinstance(params_list, (tuple, list, collections.deque))
494 |     n = len(params_list)
495 |     if n == 1:
496 |         return params_list[0]
497 |     new_params = collections.OrderedDict()
498 |     keys = None
499 |     for i, params in enumerate(params_list):
500 |         if keys is None:
501 |             keys = params.keys()
502 |         for k, v in params.items():
503 |             if k not in keys:
504 |                 raise ValueError("the %d-th model has different params" % i)
505 |             if k not in new_params:
506 |                 new_params[k] = v / n
507 |             else:
508 |                 new_params[k] += v / n
509 |     return new_params
510 | 
511 | 
512 | def shoot_infs(inp_tensor):
513 |     """Replaces
inf by maximum of tensor""" 514 | mask_inf = torch.isinf(inp_tensor) 515 | ind_inf = torch.nonzero(mask_inf, as_tuple=False) 516 | if len(ind_inf) > 0: 517 | for ind in ind_inf: 518 | if len(ind) == 2: 519 | inp_tensor[ind[0], ind[1]] = 0 520 | elif len(ind) == 1: 521 | inp_tensor[ind[0]] = 0 522 | m = torch.max(inp_tensor) 523 | for ind in ind_inf: 524 | if len(ind) == 2: 525 | inp_tensor[ind[0], ind[1]] = m 526 | elif len(ind) == 1: 527 | inp_tensor[ind[0]] = m 528 | return inp_tensor 529 | 530 | 531 | def sinkhorn(Q, n_iters=3, epsilon=0.01): 532 | # epsilon should be adjusted according to logits value's scale 533 | with torch.no_grad(): 534 | Q = shoot_infs(Q) 535 | Q = torch.exp(Q / epsilon) 536 | for i in range(n_iters): 537 | Q /= Q.sum(dim=0, keepdim=True) 538 | Q /= Q.sum(dim=1, keepdim=True) 539 | return Q 540 | -------------------------------------------------------------------------------- /MERA/src/model_moe_vote.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | import os 5 | import copy 6 | import math 7 | import json 8 | import collections 9 | import numpy as np 10 | import pandas as pd 11 | 12 | import torch 13 | import torch.nn as nn 14 | import torch.optim as optim 15 | import torch.nn.functional as F 16 | 17 | from tqdm import tqdm 18 | from torch.utils.data import DataLoader, random_split 19 | 20 | from qlib.utils import get_or_create_path 21 | from qlib.log import get_module_logger 22 | from qlib.model.base import Model 23 | 24 | from src.custom_moe_layer import FMoETransformerMLP 25 | from src.noisy_gate import NoisyGate 26 | from src.noisy_gate_vmoe import NoisyGate_VMoE 27 | 28 | from src.dataset import collate_fn 29 | 30 | device = "cuda" if torch.cuda.is_available() else "cpu" 31 | 32 | 33 | class TRAModel(Model): 34 | def __init__( 35 | self, 36 | model_config, 37 | tra_config, 38 | model_type="LSTM", 39 | lr=1e-3, 40 | n_epochs=500, 41 | early_stop=50, 42 | smooth_steps=5, 43 | max_steps_per_epoch=None, 44 | freeze_model=False, 45 | model_init_state=None, 46 | lamb=0.0, 47 | rho=0.99, 48 | seed=None, 49 | logdir=None, 50 | eval_train=True, 51 | eval_test=True, 52 | avg_params=True, 53 | **kwargs, 54 | ): 55 | np.random.seed(seed) 56 | torch.manual_seed(seed) 57 | 58 | self.logger = get_module_logger("TRA") 59 | self.logger.info("TRA Model...") 60 | 61 | self.model = eval(model_type)(**model_config).to(device) 62 | if model_init_state: 63 | self.model.load_state_dict(torch.load(model_init_state, map_location="cuda")["model"],strict=False) 64 | if freeze_model: 65 | # for param in self.model.parameters(): 66 | # param.requires_grad_(False) 67 | for name, param in self.model.named_parameters(): 68 | if 'experts' in name or 'gate' in name or 'bn' in name or 'norm' in name or 'input_proj' in name or 'embedding' in name: 69 | print(name) 70 | param.requires_grad_(True) 71 | else: 72 | param.requires_grad_(False) 73 | 74 | else: 75 | self.logger.info("# model params: %d" % sum([p.numel() for p in self.model.parameters()])) 76 | 77 | self.tra = TRA(self.model.output_size, **tra_config).to(device) 78 | self.logger.info("# tra params: %d" % sum([p.numel() for p in self.tra.parameters()])) 79 | 80 | self.optimizer = optim.Adam(list(self.model.parameters()) + list(self.tra.parameters()), lr=lr) 81 | 82 | self.model_config = model_config 83 | self.tra_config = tra_config 84 | self.lr = lr 85 | self.n_epochs = n_epochs 86 | self.early_stop = early_stop 
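        # `smooth_steps`: number of recent epoch checkpoints averaged for smoothed evaluation (see `params_list` in `fit`)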
87 |         self.smooth_steps = smooth_steps
88 |         self.max_steps_per_epoch = max_steps_per_epoch
89 |         self.lamb = lamb
90 |         self.rho = rho
91 |         self.seed = seed
92 |         self.logdir = logdir
93 |         self.eval_train = eval_train
94 |         self.eval_test = eval_test
95 |         self.avg_params = avg_params
96 | 
97 |         if self.tra.num_states > 1 and not self.eval_train:
98 |             self.logger.warn("`eval_train` will be ignored when using TRA")
99 | 
100 |         if self.logdir is not None:
101 |             if os.path.exists(self.logdir):
102 |                 self.logger.warn(f"logdir {self.logdir} is not empty")
103 |             os.makedirs(self.logdir, exist_ok=True)
104 | 
105 |         self.fitted = False
106 |         self.global_step = -1
107 | 
108 |     def collect_noisy_gating_loss(self, model, weight):
109 |         loss = 0
110 |         for module in model.modules():
111 |             # print(module)
112 |             if (isinstance(module, NoisyGate) or isinstance(module, NoisyGate_VMoE)) and module.has_loss:
113 |                 loss += module.get_loss()
114 |         return loss * weight
115 | 
116 |     def train_epoch(self, data_set):
117 |         self.model.train()
118 |         self.tra.train()
119 | 
120 |         count = 0
121 |         total_loss = 0
122 |         total_count = 0
123 | 
124 |         for i, batch in enumerate(data_set):
125 | 
126 |             self.global_step += 1
127 | 
128 |             feature, similar, label_raw, label_norm, index = batch
129 | 
130 |             feature = feature.to(device)
131 |             similar = similar.to(device)
132 |             label_raw = label_raw.to(device)
133 |             label_norm = label_norm.to(device)
134 | 
135 |             feature = feature.permute(0, 2, 1)
136 |             hidden = self.model(feature, similar)
137 |             pred, all_preds, prob = self.tra(hidden)
138 | 
139 |             mask = ~torch.isnan(label_norm)
140 | 
141 |             # filter out NaN values, keeping only the predictions and labels at non-NaN positions
142 |             filtered_pred = pred[mask]
143 |             filtered_label_norm = label_norm[mask]
144 | 
145 |             loss = (filtered_pred - filtered_label_norm).pow(2).mean()
146 |             # loss = (pred - label_norm).pow(2).mean()
147 | 
148 |             gate_loss = self.collect_noisy_gating_loss(self.model, 0.01)
149 |             loss += gate_loss
150 |             loss.backward()
151 |             self.optimizer.step()
152 |             self.optimizer.zero_grad()
153 | 
154 |             total_loss += loss.item()
155 |             total_count += len(pred)
156 | 
157 |         total_loss /= total_count
158 |         print(total_loss)
159 |         return total_loss
160 | 
161 |     def test_epoch(self, data_set, return_pred=False):
162 |         self.model.eval()
163 |         self.tra.eval()
164 | 
165 |         preds = []
166 |         metrics = []
167 |         for i, batch in enumerate(data_set):
168 |             # print(i)
169 |             feature, similar, label_raw, label_norm, index = batch
170 | 
171 |             feature = feature.to(device)
172 |             similar=similar.to(device)
173 |             label_raw = label_raw.to(device)
174 |             label_norm = label_norm.to(device)
175 | 
176 |             feature = feature.permute(0, 2, 1)
177 | 
178 |             with torch.no_grad():
179 |                 hidden = self.model(feature, similar)
180 |                 pred, all_preds, prob = self.tra(hidden)
181 |             X = np.c_[
182 |                 pred.cpu().numpy(),
183 |                 label_raw.cpu().numpy(),
184 |             ]
185 | 
186 |             columns = ["score", "label"]
187 |             pred = pd.DataFrame(X, index = index, columns = columns)
188 | 
189 |             metrics.append(evaluate(pred))
190 | 
191 |             if return_pred:
192 |                 preds.append(pred)
193 | 
194 |         metrics = pd.DataFrame(metrics)
195 |         metrics = {
196 |             "MSE": metrics.MSE.mean(),
197 |             "MAE": metrics.MAE.mean(),
198 |             "IC": metrics.IC.mean(),
199 |             "ICIR": metrics.IC.mean() / metrics.IC.std(),
200 |         }
201 | 
202 |         if return_pred:
203 |             preds = pd.concat(preds, axis=0)
204 |             preds.sort_index(inplace=True)
205 | 
206 |         return metrics, preds
207 | 
208 |     def fit(self, train_dataset, val_dataset, test_dataset, evals_result=dict()):
209 | 
210 |         train_loader = DataLoader(train_dataset,
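        # date-level batching: each dataset item holds all stocks of one trading day, so batch_size counts days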
batch_size = 2, collate_fn = collate_fn, shuffle = True) 211 | val_loader = DataLoader(val_dataset, batch_size = 1, collate_fn = collate_fn, shuffle = False) 212 | test_loader = DataLoader(test_dataset, batch_size = 1, collate_fn = collate_fn, shuffle = False) 213 | 214 | best_score = -1 215 | best_epoch = 0 216 | stop_rounds = 0 217 | best_params = { 218 | "model": copy.deepcopy(self.model.state_dict()), 219 | "tra": copy.deepcopy(self.tra.state_dict()), 220 | } 221 | params_list = { 222 | "model": collections.deque(maxlen=self.smooth_steps), 223 | "tra": collections.deque(maxlen=self.smooth_steps), 224 | } 225 | evals_result["train"] = [] 226 | evals_result["valid"] = [] 227 | evals_result["test"] = [] 228 | 229 | # train 230 | self.fitted = True 231 | self.global_step = -1 232 | 233 | for epoch in range(self.n_epochs): 234 | self.logger.info("Epoch %d:", epoch) 235 | 236 | self.logger.info("training...") 237 | self.train_epoch(train_loader) 238 | 239 | self.logger.info("evaluating...") 240 | # average params for inference 241 | params_list["model"].append(copy.deepcopy(self.model.state_dict())) 242 | params_list["tra"].append(copy.deepcopy(self.tra.state_dict())) 243 | self.model.load_state_dict(average_params(params_list["model"])) 244 | self.tra.load_state_dict(average_params(params_list["tra"])) 245 | 246 | # NOTE: during evaluating, the whole memory will be refreshed 247 | if self.tra.num_states > 1 or self.eval_train: 248 | train_metrics = self.test_epoch(train_loader)[0] 249 | evals_result["train"].append(train_metrics) 250 | self.logger.info("\ttrain metrics: %s" % train_metrics) 251 | 252 | valid_metrics = self.test_epoch(val_loader)[0] 253 | evals_result["valid"].append(valid_metrics) 254 | self.logger.info("\tvalid metrics: %s" % valid_metrics) 255 | 256 | if self.eval_test: 257 | test_metrics = self.test_epoch(test_loader)[0] 258 | evals_result["test"].append(test_metrics) 259 | self.logger.info("\ttest metrics: %s" % test_metrics) 260 | 261 | if valid_metrics["IC"] > best_score: 262 | best_score = valid_metrics["IC"] 263 | stop_rounds = 0 264 | best_epoch = epoch 265 | best_params = { 266 | "model": copy.deepcopy(self.model.state_dict()), 267 | "tra": copy.deepcopy(self.tra.state_dict()), 268 | } 269 | else: 270 | stop_rounds += 1 271 | if stop_rounds >= self.early_stop: 272 | self.logger.info("early stop @ %s" % epoch) 273 | break 274 | 275 | # restore parameters 276 | self.model.load_state_dict(params_list["model"][-1]) 277 | self.tra.load_state_dict(params_list["tra"][-1]) 278 | 279 | self.logger.info("best score: %.6lf @ %d" % (best_score, best_epoch)) 280 | self.model.load_state_dict(best_params["model"]) 281 | self.tra.load_state_dict(best_params["tra"]) 282 | 283 | metrics, preds = self.test_epoch(test_loader, return_pred=True) 284 | self.logger.info("test metrics: %s" % metrics) 285 | 286 | if self.logdir: 287 | self.logger.info("save model & pred to local directory") 288 | 289 | pd.concat({name: pd.DataFrame(evals_result[name]) for name in evals_result}, axis=1).to_csv( 290 | self.logdir + "/logs.csv", index=False 291 | ) 292 | 293 | torch.save(best_params, self.logdir + "/model.bin") 294 | 295 | preds.to_pickle(self.logdir + "/pred.pkl") 296 | 297 | info = { 298 | "config": { 299 | "model_config": self.model_config, 300 | "tra_config": self.tra_config, 301 | "lr": self.lr, 302 | "n_epochs": self.n_epochs, 303 | "early_stop": self.early_stop, 304 | "smooth_steps": self.smooth_steps, 305 | "lamb": self.lamb, 306 | "rho": self.rho, 307 | "seed": self.seed, 308 | 
"logdir": self.logdir, 309 | }, 310 | "best_eval_metric": -best_score, # NOTE: minux -1 for minimize 311 | "metric": metrics, 312 | } 313 | with open(self.logdir + "/info.json", "w") as f: 314 | json.dump(info, f) 315 | 316 | def predict(self, test_dataset, model_path): 317 | test_loader = DataLoader(test_dataset, batch_size = 1, collate_fn = collate_fn, shuffle = False) 318 | best_params = torch.load(model_path) 319 | self.model.load_state_dict(best_params["model"]) 320 | self.tra.load_state_dict(best_params["tra"]) 321 | metrics, preds = self.test_epoch(test_loader, return_pred=True) 322 | print(self.model.moe.score) 323 | self.logger.info("test metrics: %s" % metrics) 324 | return preds 325 | 326 | 327 | class PositionalEncoding(nn.Module): 328 | # reference: https://pytorch.org/tutorials/beginner/transformer_tutorial.html 329 | def __init__(self, d_model, dropout=0.1, max_len=5000): 330 | super(PositionalEncoding, self).__init__() 331 | self.dropout = nn.Dropout(p=dropout) 332 | 333 | pe = torch.zeros(max_len, d_model) 334 | position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) 335 | div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) 336 | pe[:, 0::2] = torch.sin(position * div_term) 337 | pe[:, 1::2] = torch.cos(position * div_term) 338 | pe = pe.unsqueeze(0).transpose(0, 1) 339 | self.register_buffer("pe", pe) 340 | self.pe = self.pe.to(device) 341 | 342 | def forward(self, x): 343 | x = x + self.pe[: x.size(0), :] 344 | return self.dropout(x) 345 | 346 | 347 | class Transformer(nn.Module): 348 | 349 | def __init__( 350 | self, 351 | input_size=16, 352 | hidden_size=64, 353 | num_layers=2, 354 | num_heads=2, 355 | dropout=0.0, 356 | input_drop=0.0, 357 | noise_level=0.0, 358 | topk=1, 359 | num_expert=4, 360 | gate_dim=16, 361 | moe_gate_type='noisy_vmoe', 362 | vmoe_noisy_std=1, 363 | **kwargs, 364 | ): 365 | super().__init__() 366 | 367 | self.input_size = input_size 368 | self.hidden_size = hidden_size 369 | self.num_layers = num_layers 370 | self.num_heads = num_heads 371 | self.noise_level = noise_level 372 | self.gate_dim=gate_dim 373 | 374 | # self.input_drop = nn.Dropout(input_drop) 375 | 376 | self.input_proj = nn.Linear(input_size, hidden_size) 377 | 378 | self.pe = PositionalEncoding(hidden_size, dropout) 379 | 380 | if moe_gate_type == "noisy": 381 | moe_gate_fun = NoisyGate 382 | elif moe_gate_type == "noisy_vmoe": 383 | moe_gate_fun = NoisyGate_VMoE 384 | else: 385 | raise ValueError("unknow gate type of {}".format(moe_gate_type)) 386 | 387 | act_layer=nn.GELU 388 | activation = nn.Sequential( 389 | act_layer(), 390 | nn.Dropout(dropout) 391 | ) 392 | 393 | blocks = [] 394 | for i in range(num_layers): 395 | blocks.append(nn.TransformerEncoderLayer(nhead = num_heads, dropout = dropout, d_model = hidden_size, dim_feedforward = hidden_size * 4, norm_first=True)) 396 | 397 | self.moe = FMoETransformerMLP(num_expert = num_expert, d_model = hidden_size, d_gate = gate_dim, gate = moe_gate_fun, 398 | top_k = topk, activation = activation, vmoe_noisy_std=vmoe_noisy_std, expert_prune=False) 399 | 400 | self.encoder = nn.Sequential(*blocks) 401 | 402 | self.output_size = hidden_size 403 | self.bn = nn.BatchNorm1d(input_size) 404 | 405 | self.embedding = nn.Embedding(10, gate_dim) 406 | 407 | def forward(self, x, similars): 408 | shape = x.shape 409 | x = x.reshape(-1,self.input_size) 410 | x = self.bn(x) 411 | x = x.reshape(shape) 412 | 413 | x = x.permute(1, 0, 2).contiguous() # the first dim need to be sequence 414 | 415 | x 
= self.input_proj(x) 416 | x = self.pe(x) 417 | 418 | for i, layer in enumerate(self.encoder): 419 | x = layer(x) 420 | 421 | out = x 422 | x = x[-1] 423 | 424 | similar_label = similars[:,:,-2] 425 | 426 | mode_values = torch.mode(similar_label, dim=1).values # torch.Size([908]) 427 | similar_label = self.embedding(mode_values.int()) # torch.Size([890, 16]) 428 | 429 | output = self.moe(inp = out.permute(1, 0, 2), gate_inp = similar_label) 430 | return output[:,-1] 431 | 432 | class TRA(nn.Module): 433 | 434 | """Temporal Routing Adaptor (TRA) 435 | 436 | TRA takes historical prediction errors & latent representation as inputs, 437 | then routes the input sample to a specific predictor for training & inference. 438 | 439 | Args: 440 | input_size (int): input size (RNN/Transformer's hidden size) 441 | num_states (int): number of latent states (i.e., trading patterns) 442 | If `num_states=1`, then TRA falls back to traditional methods 443 | hidden_size (int): hidden size of the router 444 | tau (float): gumbel softmax temperature 445 | """ 446 | 447 | def __init__(self, input_size, num_states=1, hidden_size=8, tau=1.0, src_info="LR_TPE"): 448 | super().__init__() 449 | 450 | self.num_states = num_states 451 | self.tau = tau 452 | self.src_info = src_info 453 | 454 | if num_states > 1: 455 | self.router = nn.LSTM( 456 | input_size=num_states, 457 | hidden_size=hidden_size, 458 | num_layers=1, 459 | batch_first=True, 460 | ) 461 | self.fc = nn.Linear(hidden_size + input_size, num_states) 462 | 463 | self.predictors1 = nn.Linear(input_size, hidden_size) 464 | self.predictors2 = nn.Linear(hidden_size, num_states) 465 | self.relu=nn.ReLU() 466 | 467 | def forward(self, hidden, hist_loss=None): 468 | preds = self.predictors2(self.relu(self.predictors1(hidden))) 469 | return preds.squeeze(-1), preds, None 470 | 471 | 472 | def evaluate(pred): 473 | # pred = pred.rank(pct=True) # transform into percentiles 474 | pred = pred.dropna(subset=['label']) 475 | score = pred.score 476 | label = pred.label 477 | diff = score - label 478 | MSE = (diff**2).mean() 479 | MAE = (diff.abs()).mean() 480 | IC = score.corr(label) 481 | # return {"MSE": MSE.astype(np.float64), "MAE": MAE.astype(np.float64), "IC": IC.astype(np.float64)} 482 | return {"MSE": MSE, "MAE": MAE, "IC": IC} 483 | 484 | 485 | def average_params(params_list): 486 | assert isinstance(params_list, (tuple, list, collections.deque)) 487 | n = len(params_list) 488 | if n == 1: 489 | return params_list[0] 490 | new_params = collections.OrderedDict() 491 | keys = None 492 | for i, params in enumerate(params_list): 493 | if keys is None: 494 | keys = params.keys() 495 | for k, v in params.items(): 496 | if k not in keys: 497 | raise ValueError("the %d-th model has different params" % i) 498 | if k not in new_params: 499 | new_params[k] = v / n 500 | else: 501 | new_params[k] += v / n 502 | return new_params 503 | 504 | 505 | def shoot_infs(inp_tensor): 506 | """Replaces inf by maximum of tensor""" 507 | mask_inf = torch.isinf(inp_tensor) 508 | ind_inf = torch.nonzero(mask_inf, as_tuple=False) 509 | if len(ind_inf) > 0: 510 | for ind in ind_inf: 511 | if len(ind) == 2: 512 | inp_tensor[ind[0], ind[1]] = 0 513 | elif len(ind) == 1: 514 | inp_tensor[ind[0]] = 0 515 | m = torch.max(inp_tensor) 516 | for ind in ind_inf: 517 | if len(ind) == 2: 518 | inp_tensor[ind[0], ind[1]] = m 519 | elif len(ind) == 1: 520 | inp_tensor[ind[0]] = m 521 | return inp_tensor 522 | 523 | 524 | def sinkhorn(Q, n_iters=3, epsilon=0.01): 525 | # epsilon should be adjusted 
according to logits value's scale 526 | with torch.no_grad(): 527 | Q = shoot_infs(Q) 528 | Q = torch.exp(Q / epsilon) 529 | for i in range(n_iters): 530 | Q /= Q.sum(dim=0, keepdim=True) 531 | Q /= Q.sum(dim=1, keepdim=True) 532 | return Q 533 | -------------------------------------------------------------------------------- /MERA/src/noisy_gate.py: -------------------------------------------------------------------------------- 1 | r""" 2 | Noisy gate for gshard and switch 3 | """ 4 | from fmoe.gates.base_gate import BaseGate 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from torch.distributions.normal import Normal 10 | import math 11 | 12 | from pdb import set_trace 13 | 14 | class NoisyGate(BaseGate): 15 | def __init__(self, d_model, num_expert, world_size, top_k=2, no_noise=False, return_decoupled_activation=False,regu_experts_fromtask=False,\ 16 | num_experts_pertask = -1,num_tasks = -1): 17 | super().__init__(num_expert, world_size) 18 | self.w_gate = nn.Parameter( 19 | torch.zeros(d_model, self.tot_expert), requires_grad=True 20 | ) 21 | self.w_noise = nn.Parameter( 22 | torch.zeros(d_model, self.tot_expert), requires_grad=True 23 | ) 24 | 25 | self.return_decoupled_activation = return_decoupled_activation 26 | if self.return_decoupled_activation: 27 | self.w_gate_aux = nn.Parameter( 28 | torch.zeros(d_model, self.tot_expert), requires_grad=True 29 | ) 30 | self.w_noise_aux = nn.Parameter( 31 | torch.zeros(d_model, self.tot_expert), requires_grad=True 32 | ) 33 | 34 | self.top_k = top_k 35 | self.no_noise = no_noise 36 | self.softplus = nn.Softplus() 37 | self.softmax = nn.Softmax(1) 38 | 39 | self.noise_epsilon = 1e-2 40 | 41 | self.activation = None 42 | self.select_idx = None 43 | self.regu_experts_fromtask= regu_experts_fromtask 44 | self.num_experts_pertask = num_experts_pertask 45 | self.num_tasks = num_tasks 46 | self.reset_parameters() 47 | 48 | def reset_parameters(self): 49 | # Approach is the same as in torch.nn.Linear 50 | # https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/linear.py#L88 51 | 52 | torch.nn.init.kaiming_uniform_(self.w_gate, a=math.sqrt(5)) 53 | torch.nn.init.kaiming_uniform_(self.w_noise, a=math.sqrt(5)) 54 | 55 | if self.return_decoupled_activation: 56 | torch.nn.init.kaiming_uniform_(self.w_gate_aux, a=math.sqrt(5)) 57 | torch.nn.init.kaiming_uniform_(self.w_noise_aux, a=math.sqrt(5)) 58 | 59 | def _gates_to_load(self, gates): 60 | """Compute the true load per expert, given the gates. 61 | The load is the number of examples for which the corresponding gate is >0. 62 | Args: 63 | gates: a `Tensor` of shape [batch_size, n] 64 | Returns: 65 | a float32 `Tensor` of shape [n] 66 | """ 67 | return (gates > 0).sum(0) 68 | 69 | def _prob_in_top_k( 70 | self, clean_values, noisy_values, noise_stddev, noisy_top_values 71 | ): 72 | """Helper function to NoisyTopKGating. 73 | Computes the probability that value is in top k, given different random noise. 74 | This gives us a way of backpropagating from a loss that balances the number 75 | of times each expert is in the top k experts per example. 76 | In the case of no noise, pass in None for noise_stddev, and the result will 77 | not be differentiable. 78 | Args: 79 | clean_values: a `Tensor` of shape [batch, n]. 80 | noisy_values: a `Tensor` of shape [batch, n]. Equal to clean values plus 81 | normally distributed noise with standard deviation noise_stddev. 
82 | noise_stddev: a `Tensor` of shape [batch, n], or None 83 | noisy_top_values: a `Tensor` of shape [batch, m]. 84 | "values" Output of tf.top_k(noisy_top_values, m). m >= k+1 85 | Returns: 86 | a `Tensor` of shape [batch, n]. 87 | """ 88 | 89 | batch = clean_values.size(0) 90 | m = noisy_top_values.size(1) 91 | top_values_flat = noisy_top_values.flatten() 92 | threshold_positions_if_in = ( 93 | torch.arange(batch, device=clean_values.device) * m + self.top_k 94 | ) 95 | threshold_if_in = torch.unsqueeze( 96 | torch.gather(top_values_flat, 0, threshold_positions_if_in), 1 97 | ) 98 | is_in = torch.gt(noisy_values, threshold_if_in) 99 | threshold_positions_if_out = threshold_positions_if_in - 1 100 | threshold_if_out = torch.unsqueeze( 101 | torch.gather(top_values_flat, 0, threshold_positions_if_out), 1 102 | ) 103 | # is each value currently in the top k. 104 | normal = Normal( 105 | torch.tensor([0.0], device=clean_values.device), 106 | torch.tensor([1.0], device=clean_values.device), 107 | ) 108 | 109 | prob_if_in = normal.cdf((clean_values - threshold_if_in) / noise_stddev) 110 | prob_if_out = normal.cdf((clean_values - threshold_if_out) / noise_stddev) 111 | prob = torch.where(is_in, prob_if_in, prob_if_out) 112 | return prob 113 | 114 | def cv_squared(self, x): 115 | """The squared coefficient of variation of a sample. 116 | Useful as a loss to encourage a positive distribution to be more uniform. 117 | Epsilons added for numerical stability. 118 | Returns 0 for an empty Tensor. 119 | Args: 120 | x: a `Tensor`. 121 | Returns: 122 | a `Scalar`. 123 | """ 124 | eps = 1e-10 125 | # if only num_expert = 1 126 | if x.shape[0] == 1: 127 | return torch.Tensor([0]) 128 | return x.float().var() / (x.float().mean() ** 2 + eps) 129 | 130 | def set_loss(self, loss): 131 | if self.loss is None: 132 | self.loss = loss 133 | else: 134 | self.loss += loss 135 | 136 | def forward(self, inp): 137 | shape_input = list(inp.shape) 138 | channel = shape_input[-1] 139 | other_dim = shape_input[:-1] 140 | inp = inp.reshape(-1, channel) 141 | 142 | clean_logits = inp @ self.w_gate 143 | raw_noise_stddev = inp @ self.w_noise 144 | noise_stddev = (self.softplus(raw_noise_stddev) + self.noise_epsilon) * self.training 145 | 146 | if self.no_noise: 147 | noise_stddev *= 0 148 | 149 | noisy_logits = clean_logits + (torch.randn_like(clean_logits) * noise_stddev) 150 | 151 | if self.select_idx is not None: 152 | assert len(self.select_idx) >= self.top_k 153 | noisy_logits = noisy_logits[:, self.select_idx] 154 | 155 | logits = noisy_logits 156 | 157 | if self.return_decoupled_activation: 158 | clean_logits_aux = inp @ self.w_gate_aux 159 | raw_noise_stddev_aux = inp @ self.w_noise_aux 160 | noise_stddev_aux = (self.softplus(raw_noise_stddev_aux) + self.noise_epsilon) * self.training 161 | 162 | if self.no_noise: 163 | noise_stddev_aux *= 0 164 | 165 | noisy_logits_aux = clean_logits_aux + (torch.randn_like(clean_logits_aux) * noise_stddev_aux) 166 | 167 | if self.select_idx is not None and len(self.select_idx) == self.top_k: 168 | top_k_gates, top_k_indices = logits.topk( 169 | min(self.top_k, self.tot_expert), dim=1 170 | ) 171 | 172 | return ( 173 | top_k_indices, 174 | top_k_gates, 175 | ) 176 | 177 | # calculate topk + 1 that will be needed for the noisy gates 178 | top_logits, top_indices = logits.topk( 179 | min(self.top_k + 1, self.tot_expert), dim=1 180 | ) 181 | 182 | top_k_logits = top_logits[:, : self.top_k] 183 | top_k_indices = top_indices[:, : self.top_k] 184 | top_k_gates = 
self.softmax(top_k_logits) 185 | 186 | zeros = torch.zeros_like(logits, requires_grad=True) 187 | gates = zeros.scatter(1, top_k_indices, top_k_gates) 188 | 189 | if self.training: 190 | if self.top_k < self.tot_expert: 191 | load = ( 192 | self._prob_in_top_k( 193 | clean_logits, noisy_logits, noise_stddev, top_logits 194 | ) 195 | ).sum(0) 196 | else: 197 | load = self._gates_to_load(gates) 198 | 199 | importance = gates.sum(0) 200 | loss = self.cv_squared(importance) + self.cv_squared(load) 201 | else: 202 | loss = 0 203 | 204 | self.set_loss(loss) 205 | self.activation = logits.reshape(other_dim + [-1,]).contiguous() 206 | 207 | # print("top_k_indices are {}".format(top_k_indices)) 208 | if self.return_decoupled_activation: 209 | # print("set activation as noisy_logits_aux") 210 | self.activation = noisy_logits_aux.reshape(other_dim + [-1, ]).contiguous() 211 | 212 | top_k_indices = top_k_indices.reshape(other_dim + [self.top_k]).contiguous() 213 | top_k_gates = top_k_gates.reshape(other_dim + [self.top_k]).contiguous() 214 | 215 | return ( 216 | top_k_indices, 217 | top_k_gates, 218 | ) 219 | 220 | def get_activation(self, clear=True): 221 | activation = self.activation 222 | if clear: 223 | self.activation = None 224 | return activation 225 | 226 | @property 227 | def has_activation(self): 228 | return self.activation is not None 229 | -------------------------------------------------------------------------------- /MERA/src/noisy_gate_vmoe.py: -------------------------------------------------------------------------------- 1 | r""" 2 | Noisy gate for gshard and switch 3 | """ 4 | from fmoe.gates.base_gate import BaseGate 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from torch.distributions.normal import Normal 10 | import math 11 | import numpy as np 12 | from collections import Counter 13 | from pdb import set_trace 14 | 15 | class NoisyGate_VMoE(BaseGate): 16 | def __init__(self, d_model, num_expert, world_size, top_k=2, noise_std=1, no_noise=False, 17 | return_decoupled_activation=False,regu_experts_fromtask=False,num_experts_pertask=-1,num_tasks=-1, 18 | regu_sem=False,sem_force = False,regu_subimage=False): 19 | super().__init__(num_expert, world_size) 20 | self.w_gate = nn.Parameter( 21 | torch.zeros(d_model, self.tot_expert), requires_grad=True 22 | ) 23 | 24 | self.return_decoupled_activation = return_decoupled_activation 25 | if self.return_decoupled_activation: 26 | self.w_gate_aux = nn.Parameter( 27 | torch.zeros(d_model, self.tot_expert), requires_grad=True 28 | ) 29 | 30 | self.top_k = top_k 31 | self.no_noise = no_noise 32 | self.noise_std = noise_std 33 | 34 | self.softmax = nn.Softmax(1) 35 | 36 | self.activation = None 37 | self.select_idx = None 38 | self.regu_experts_fromtask= regu_experts_fromtask 39 | self.num_experts_pertask = num_experts_pertask 40 | self.num_tasks = num_tasks 41 | self.regu_sem = regu_sem 42 | self.regu_subimage = regu_subimage 43 | self.patch_size = 16 44 | 45 | if self.regu_sem: 46 | from losses.loss_functions import SoftMaxwithLoss 47 | self.criterion = SoftMaxwithLoss() 48 | self.num_class = 40 49 | self.head = nn.Linear(num_expert, self.num_class) 50 | self.semregu_loss = 0.0 51 | if self.regu_subimage: 52 | self.regu_subimage_loss = 0.0 53 | self.subimage_tokens = 5 54 | if self.regu_experts_fromtask: 55 | self.start_experts_id=[] 56 | start_id = 0 57 | for i in range(self.num_tasks): 58 | start_id = start_id + int(i* (self.tot_expert-self.num_experts_pertask)/(self.num_tasks-1)) 59 | 
self.start_experts_id.append(start_id) 60 | print('self.start_experts_id',self.start_experts_id) 61 | self.reset_parameters() 62 | 63 | def reset_parameters(self): 64 | # Approach is the same as in torch.nn.Linear 65 | # https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/linear.py#L88 66 | 67 | torch.nn.init.kaiming_uniform_(self.w_gate, a=math.sqrt(5)) 68 | 69 | if self.return_decoupled_activation: 70 | torch.nn.init.kaiming_uniform_(self.w_gate_aux, a=math.sqrt(5)) 71 | 72 | def _gates_to_load(self, gates): 73 | """Compute the true load per expert, given the gates. 74 | The load is the number of examples for which the corresponding gate is >0. 75 | Args: 76 | gates: a `Tensor` of shape [batch_size, n] 77 | Returns: 78 | a float32 `Tensor` of shape [n] 79 | """ 80 | return (gates > 0).sum(0) 81 | 82 | def _prob_in_top_k( 83 | self, clean_values, noisy_values, noise_stddev, noisy_top_values 84 | ): 85 | """Helper function to NoisyTopKGating. 86 | Computes the probability that value is in top k, given different random noise. 87 | This gives us a way of backpropagating from a loss that balances the number 88 | of times each expert is in the top k experts per example. 89 | In the case of no noise, pass in None for noise_stddev, and the result will 90 | not be differentiable. 91 | Args: 92 | clean_values: a `Tensor` of shape [batch, n]. 93 | noisy_values: a `Tensor` of shape [batch, n]. Equal to clean values plus 94 | normally distributed noise with standard deviation noise_stddev. 95 | noise_stddev: a `Tensor` of shape [batch, n], or None 96 | noisy_top_values: a `Tensor` of shape [batch, m]. 97 | "values" Output of tf.top_k(noisy_top_values, m). m >= k+1 98 | Returns: 99 | a `Tensor` of shape [batch, n]. 100 | """ 101 | 102 | batch = clean_values.size(0) 103 | m = noisy_top_values.size(1) 104 | top_values_flat = noisy_top_values.flatten() 105 | threshold_positions_if_in = ( 106 | torch.arange(batch, device=clean_values.device) * m + self.top_k 107 | ) 108 | threshold_if_in = torch.unsqueeze( 109 | torch.gather(top_values_flat, 0, threshold_positions_if_in), 1 110 | ) 111 | is_in = torch.gt(noisy_values, threshold_if_in) 112 | threshold_positions_if_out = threshold_positions_if_in - 1 113 | threshold_if_out = torch.unsqueeze( 114 | torch.gather(top_values_flat, 0, threshold_positions_if_out), 1 115 | ) 116 | # is each value currently in the top k. 117 | normal = Normal( 118 | torch.tensor([0.0], device=clean_values.device), 119 | torch.tensor([1.0], device=clean_values.device), 120 | ) 121 | 122 | prob_if_in = normal.cdf((clean_values - threshold_if_in) / noise_stddev) 123 | prob_if_out = normal.cdf((clean_values - threshold_if_out) / noise_stddev) 124 | prob = torch.where(is_in, prob_if_in, prob_if_out) 125 | return prob 126 | 127 | def cv_squared(self, x): 128 | """The squared coefficient of variation of a sample. 129 | Useful as a loss to encourage a positive distribution to be more uniform. 130 | Epsilons added for numerical stability. 131 | Returns 0 for an empty Tensor. 132 | Args: 133 | x: a `Tensor`. 134 | Returns: 135 | a `Scalar`. 
136 | """ 137 | eps = 1e-10 138 | # if only num_expert = 1 139 | if x.shape[0] == 1: 140 | return torch.Tensor([0]) 141 | return x.float().var() / (x.float().mean() ** 2 + eps) 142 | 143 | def set_loss(self, loss): 144 | if self.loss is None: 145 | self.loss = loss 146 | else: 147 | self.loss += loss 148 | def get_semregu_loss(self): 149 | return self.semregu_loss 150 | 151 | def get_regu_subimage_loss(self): 152 | return self.regu_subimage_loss 153 | 154 | 155 | # def get_groundtruth_sem(self, sem): 156 | # batch = sem.shape[0] 157 | # hint = np.ones(batch,1,int(sem.shape[2]/self.patch_size),int(sem.shape[3]/self.patch_size))*255 158 | # idx = 0 159 | # for k in range(batch): 160 | # for i in range(int(sem.shape[2]/self.patch_size)): 161 | # for j in range(int(sem.shape[3]/self.patch_size)): 162 | # patch = sem[k][:,self.patch_size*i:self.patch_size*(i+1),self.patch_size*j:self.patch_size*(j+1)].cpu().numpy().flatten() 163 | # index , num=Counter(patch).most_common(1)[0] 164 | # if num>0.4*(self.patch_size*self.patch_size): 165 | # hint[k,:,i,j]=index 166 | # if index != 255: 167 | # idx = idx+1 168 | # print(idx/(batch*int(sem.shape[2]/self.patch_size)*int(sem.shape[3]/self.patch_size)),'percent token will be used') 169 | # return torch.tensor(hint, device=sem.device) 170 | 171 | def forward(self, inp, task_id=None,sem=None): 172 | shape_input = list(inp.shape) 173 | # print(shape_input) 174 | channel = shape_input[-1] 175 | other_dim = shape_input[:-1] 176 | inp = inp.reshape(-1, channel) 177 | 178 | if self.regu_experts_fromtask and (task_id is not None): 179 | clean_logits = inp @ self.w_gate[:,self.start_experts_id[task_id]:self.start_experts_id[task_id]+self.num_experts_pertask] 180 | raw_noise_stddev = self.noise_std / self.num_experts_pertask 181 | else: 182 | clean_logits = inp @ self.w_gate 183 | raw_noise_stddev = self.noise_std / self.tot_expert 184 | noise_stddev = raw_noise_stddev * self.training 185 | 186 | if self.regu_sem and (sem is not None): 187 | batch = sem.shape[0] 188 | prior_selection = clean_logits.reshape(batch,-1,self.num_expert)[:,1:,:] 189 | prior_selection = prior_selection.reshape(-1,self.num_expert) 190 | prior_out = self.head(prior_selection) 191 | prior_out = prior_out.reshape(batch,sem.shape[2],sem.shape[3],self.num_class) 192 | prior_out = prior_out.permute(0,3,1,2) 193 | # hint = self.get_groundtruth_sem(sem) 194 | semregu_loss = self.criterion(prior_out,sem) 195 | # print('during forward regu loss',semregu_loss) 196 | self.semregu_loss = semregu_loss 197 | # print('clean_logits',clean_logits.shape,sem.shape) 198 | 199 | if self.regu_subimage and (sem is not None): 200 | self.regu_subimage_loss = 0 201 | batch_size = sem.shape[0] 202 | prior_selection = clean_logits.reshape(batch_size,-1,self.num_expert)[:,1:,:] 203 | prior_selection = prior_selection.reshape(batch_size,30,40,self.num_expert) 204 | for k in range(batch_size): 205 | for i in range(int(30/self.subimage_tokens)): 206 | for j in range(int(40/self.subimage_tokens)): 207 | subimage_selection = prior_selection[k,self.subimage_tokens*i:self.subimage_tokens*(i+1),self.subimage_tokens*j:self.subimage_tokens*(j+1),:] 208 | # print(subimage_selection.shape) 209 | subimage_selection = subimage_selection.reshape(-1,self.num_expert) 210 | # print(torch.sum(subimage_selection, dim=0)) 211 | top_subimage_values,top_subimage_index = torch.topk(torch.sum(subimage_selection, dim=0),2) 212 | gt_logit = torch.zeros(self.num_expert,device=clean_logits.device) 213 | 
gt_logit[top_subimage_index[0]]=top_subimage_values[0] 214 | gt_logit[top_subimage_index[1]]=top_subimage_values[1] 215 | # print(top_subimage_values,top_subimage_index,gt_logit) 216 | gt_logit = gt_logit.repeat(subimage_selection.shape[0],1) 217 | # print('gt_logit',gt_logit.shape) 218 | # gt_logit = torch.softmax(gt_logit) 219 | kl1 = F.kl_div(subimage_selection.softmax(dim=-1).log(), gt_logit.softmax(dim=-1), reduction='batchmean') 220 | # kl2 = F.kl_div(gt_logit.softmax(dim=-1).log(), subimage_selection.softmax(dim=-1), reduction='batchmean') 221 | self.regu_subimage_loss=self.regu_subimage_loss+kl1 #(kl1+kl2)/2 222 | self.regu_subimage_loss = self.regu_subimage_loss/(batch_size*30*40/self.subimage_tokens/self.subimage_tokens) 223 | 224 | 225 | 226 | if self.no_noise: 227 | noise_stddev *= 0 228 | 229 | noisy_logits = clean_logits + (torch.randn_like(clean_logits) * noise_stddev) 230 | 231 | if self.select_idx is not None: 232 | assert len(self.select_idx) >= self.top_k 233 | noisy_logits = noisy_logits[:, self.select_idx] 234 | 235 | logits = noisy_logits 236 | 237 | if self.return_decoupled_activation: 238 | clean_logits_aux = inp @ self.w_gate_aux 239 | raw_noise_stddev = self.noise_std / self.tot_expert 240 | noise_stddev_aux = (torch.randn_like(clean_logits) * raw_noise_stddev) * self.training 241 | 242 | if self.no_noise: 243 | noise_stddev_aux *= 0 244 | 245 | noisy_logits_aux = clean_logits_aux + (torch.randn_like(clean_logits_aux) * noise_stddev_aux) 246 | 247 | if self.select_idx is not None and len(self.select_idx) == self.top_k: 248 | top_k_gates, top_k_indices = logits.topk( 249 | min(self.top_k, self.tot_expert), dim=1 250 | ) 251 | 252 | return ( 253 | top_k_indices, 254 | top_k_gates, 255 | ) 256 | 257 | # calculate topk + 1 that will be needed for the noisy gates 258 | logits = self.softmax(logits) 259 | top_logits, top_indices = logits.topk( 260 | min(self.top_k + 1, self.tot_expert), dim=1 261 | ) 262 | 263 | top_k_logits = top_logits[:, : self.top_k] 264 | top_k_indices = top_indices[:, : self.top_k] 265 | top_k_gates = top_k_logits 266 | 267 | zeros = torch.zeros_like(logits, requires_grad=True) 268 | gates = zeros.scatter(1, top_k_indices, top_k_logits) 269 | 270 | if self.training: 271 | if self.top_k < self.tot_expert and (not self.no_noise) and abs(noise_stddev) > 1e-6: 272 | # print("calculate load loss") 273 | load = ( 274 | self._prob_in_top_k( 275 | clean_logits, noisy_logits, noise_stddev, top_logits 276 | ) 277 | ).sum(0) 278 | else: 279 | load = self._gates_to_load(gates) 280 | 281 | importance = gates.sum(0) 282 | loss = self.cv_squared(importance) + self.cv_squared(load) 283 | else: 284 | loss = 0 285 | 286 | self.set_loss(loss) 287 | self.activation = logits.reshape(other_dim + [-1,]).contiguous() 288 | 289 | # print("top_k_indices are {}".format(top_k_indices)) 290 | if self.return_decoupled_activation: 291 | # print("set activation as noisy_logits_aux") 292 | self.activation = noisy_logits_aux.reshape(other_dim + [-1, ]).contiguous() 293 | 294 | top_k_indices = top_k_indices.reshape(other_dim + [self.top_k]).contiguous() 295 | top_k_gates = top_k_gates.reshape(other_dim + [self.top_k]).contiguous() 296 | # print('top_k_indices',top_k_indices.shape,top_k_gates.shape) 297 | return ( 298 | top_k_indices, 299 | top_k_gates, 300 | ) 301 | 302 | def get_activation(self, clear=True): 303 | activation = self.activation 304 | if clear: 305 | self.activation = None 306 | return activation 307 | 308 | @property 309 | def has_activation(self): 310 | return
self.activation is not None 311 | -------------------------------------------------------------------------------- /MERA/vote.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=3 python main_500.py --config_file configs/config_transformer_moe_vote.yaml --seed 1 -------------------------------------------------------------------------------- /poster.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chenchen1104/MERA/b614fa6747bffa0488468c8d2f136cf0ea940ab4/poster.pdf --------------------------------------------------------------------------------
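NoisyGate and NoisyGate_VMoE above subclass fMoE's BaseGate, so neither runs in isolation. The following standalone sketch uses plain PyTorch only; the toy sizes and the helper names noisy_topk_gate / cv_squared are illustrative, not part of this repo. It reproduces the routing math of NoisyGate (noisy logits, top-(k+1) selection, softmax over the top-k, scatter into a dense gate matrix, cv_squared balance term); NoisyGate_VMoE differs by softmaxing all logits before taking the top-k. The learned softplus noise and the _prob_in_top_k load estimate are deliberately omitted for brevity.

# Standalone sketch of noisy top-k routing (assumes only PyTorch; sizes are toys).
import torch

def cv_squared(x, eps=1e-10):
    # squared coefficient of variation; small when expert usage is uniform
    if x.shape[0] == 1:
        return torch.tensor(0.0)
    return x.float().var() / (x.float().mean() ** 2 + eps)

def noisy_topk_gate(inp, w_gate, top_k, noise_std=1.0, training=True):
    clean_logits = inp @ w_gate                                   # [batch, num_expert]
    noise = torch.randn_like(clean_logits) * noise_std * float(training)
    logits = clean_logits + noise
    # top-(k+1) values are what the full gate feeds to its load estimator
    top_logits, top_indices = logits.topk(min(top_k + 1, logits.size(1)), dim=1)
    top_k_gates = torch.softmax(top_logits[:, :top_k], dim=1)     # weights over selected experts
    top_k_indices = top_indices[:, :top_k]
    gates = torch.zeros_like(logits).scatter(1, top_k_indices, top_k_gates)
    balance_loss = cv_squared(gates.sum(0))                       # importance term only
    return top_k_indices, top_k_gates, balance_loss

if __name__ == "__main__":
    torch.manual_seed(0)
    d_model, num_expert, top_k = 16, 8, 4   # num_expert/topk mirror config_transformer_moe_attn.yaml
    w_gate = torch.randn(d_model, num_expert) * 0.1
    inp = torch.randn(32, d_model)
    idx, g, loss = noisy_topk_gate(inp, w_gate, top_k)
    print(idx.shape, g.shape, float(loss))  # -> torch.Size([32, 4]) torch.Size([32, 4]) <scalar>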