├── README.md ├── agents.py ├── data.py ├── env.py ├── eval.py ├── metrics.py ├── policy.py ├── requirements.txt ├── train.py └── utils.py /README.md: -------------------------------------------------------------------------------- 1 | # RL Portfolio (Attention-based DRL) 2 | 3 | Reusable, production-style implementation of a daily rebalancing portfolio agent with a temporal encoder + cross-sectional attention policy trained by PPO/A2C/REINFORCE. Long-only, weights sum to 1 via a Dirichlet head. 4 | 5 | ## Structure 6 | ``` 7 | rl_portfolio/ 8 | env.py # Daily rebalancing environment (PortfolioEnv) 9 | policy.py # Temporal encoder, cross-sectional attention, Dirichlet policy 10 | agents.py # PPO/A2C/REINFORCE trainers, buffers, configs 11 | data.py # Data loading and panel construction 12 | metrics.py # Metrics and drawdown utilities 13 | utils.py # Math helpers, z-scoring, reproducibility 14 | train.py # Training entrypoints and CLI parsing 15 | eval.py # Evaluation helpers and baselines 16 | __init__.py 17 | scripts/ 18 | train_baseline.py 19 | eval_baseline.py 20 | ``` 21 | 22 | ## Quickstart 23 | ```bash 24 | pip install -r requirements.txt 25 | python -m rl_portfolio.train --csv /path/to/sp500_daily_features.csv --algo ppo 26 | ``` 27 | 28 | ## Notes 29 | - All stochastic test-time sampling is disabled (policies use the mean of the Dirichlet). 30 | - Turnover cost is set to 5 bps by default. 31 | - Reproducibility: set_seeds(42) is used across components. 32 | -------------------------------------------------------------------------------- /agents.py: -------------------------------------------------------------------------------- 1 | # agents.py 2 | from typing import Optional, Tuple 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | from torch.distributions import Dirichlet 7 | class PPOConfig: 8 | gamma: float = 0.99 9 | gae_lambda: float = 0.95 10 | clip_coef: float = 0.2 11 | vf_coef: float = 0.5 12 | ent_coef: float = 0.0 13 | max_grad_norm: float = 0.5 14 | learning_rate: float = 3e-4 15 | update_epochs: int = 5 16 | minibatch_size: int = 256 17 | minibatch_micro: int = 32 # micro-batch inside PPO update to limit peak memory 18 | 19 | 20 | class RolloutBuffer: 21 | def __init__(self, T: int, B: int, obs_shape: tuple, n_assets: int, sector_ids: torch.Tensor, 22 | device: str = "cpu", has_mkt: bool = False, k_mkt: int = 0): 23 | self.T, self.B = T, B 24 | self.device = torch.device(device) 25 | W, N, Fdim = obs_shape 26 | self.obs = torch.zeros(T, B, W, N, Fdim, device=self.device) 27 | self.tradable = torch.ones(T, B, N, dtype=torch.bool, device=self.device) 28 | self.mkt = None if not has_mkt else torch.zeros(T, B, k_mkt, device=self.device) 29 | if sector_ids.dim() == 1: 30 | sector_ids = sector_ids.view(1, -1).expand(B, -1) 31 | self.sector_ids = sector_ids.to(self.device) # kept for signature compatibility 32 | self.actions: List[Any] = [None] * T # FlatAction or any 33 | self.rewards = torch.zeros(T, B, device=self.device) 34 | self.dones = torch.zeros(T, B, dtype=torch.bool, device=self.device) 35 | self.values = torch.zeros(T, B, device=self.device) 36 | self.logprobs = torch.zeros(T, B, device=self.device) 37 | self.weights = torch.zeros(T, B, n_assets + 1, device=self.device) 38 | self.advantages = torch.zeros(T, B, device=self.device) 39 | self.returns = torch.zeros(T, B, device=self.device) 40 | self._step = 0 41 | 42 | def add(self, obs_t, tradable_t, mkt_t, action: Any, weight_t, reward_t, done_t, value_t, logp_t): 43 | t = self._step 44 | 
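A note on `PPOConfig` above: `train.py` constructs it with keyword arguments (`PPOConfig(gamma=0.99, ...)`), which only works if the class is a dataclass, and `agents.py` also uses `List`, `Any` and `FlatAction` (defined in `policy.py`) without importing them. A minimal sketch of the assumed declarations, mirroring the defaults shown above:

```python
from dataclasses import dataclass
from typing import Any, List

from policy import FlatAction  # container the trainers re-instantiate when moving actions between devices


@dataclass
class PPOConfig:
    gamma: float = 0.99
    gae_lambda: float = 0.95
    clip_coef: float = 0.2
    vf_coef: float = 0.5
    ent_coef: float = 0.0
    max_grad_norm: float = 0.5
    learning_rate: float = 3e-4
    update_epochs: int = 5
    minibatch_size: int = 256
    minibatch_micro: int = 32  # micro-batch inside the PPO update to limit peak memory
```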
self.obs[t] = obs_t 45 | self.tradable[t] = tradable_t 46 | if self.mkt is not None and mkt_t is not None: 47 | self.mkt[t] = mkt_t 48 | self.actions[t] = action 49 | self.weights[t] = weight_t 50 | self.rewards[t] = reward_t 51 | self.dones[t] = done_t 52 | self.values[t] = value_t 53 | self.logprobs[t] = logp_t 54 | self._step += 1 55 | 56 | def compute_gae(self, last_value: torch.Tensor, gamma: float, lam: float): 57 | T, B = self.T, self.B 58 | adv = torch.zeros(B, device=self.device) 59 | for t in reversed(range(T)): 60 | next_nonterminal = (~self.dones[t]).float() 61 | next_value = last_value if t == T - 1 else self.values[t + 1] 62 | delta = self.rewards[t] + gamma * next_value * next_nonterminal - self.values[t] 63 | adv = delta + gamma * lam * next_nonterminal * adv 64 | self.advantages[t] = adv 65 | self.returns = self.advantages + self.values 66 | 67 | def iter_minibatches(self, batch_size: int): 68 | T, B = self.T, self.B 69 | total = T * B 70 | obs = self.obs.view(T * B, *self.obs.shape[2:]) 71 | trad = self.tradable.view(T * B, self.tradable.shape[-1]) 72 | mkt = None if self.mkt is None else self.mkt.view(T * B, self.mkt.shape[-1]) 73 | sector_ids = self.sector_ids 74 | logprobs = self.logprobs.view(T * B) 75 | advantages = self.advantages.view(T * B) 76 | returns = self.returns.view(T * B) 77 | values = self.values.view(T * B) 78 | 79 | idx = torch.randperm(total, device=self.device) 80 | for start in range(0, total, batch_size): 81 | mb_idx = idx[start:start + batch_size] 82 | mb_obs = obs[mb_idx] 83 | mb_trad = trad[mb_idx] 84 | mb_mkt = None if mkt is None else mkt[mb_idx] 85 | b_idx = (mb_idx % B).long() 86 | mb_sector = sector_ids[b_idx] 87 | mb_actions = [] 88 | for i in mb_idx.tolist(): 89 | t = i // B 90 | mb_actions.append(self.actions[t]) 91 | yield { 92 | 'obs': mb_obs, 93 | 'tradable': mb_trad, 94 | 'mkt': mb_mkt, 95 | 'sector_ids': mb_sector, 96 | 'actions': mb_actions, 97 | 'logprobs': logprobs[mb_idx], 98 | 'advantages': advantages[mb_idx], 99 | 'returns': returns[mb_idx], 100 | 'values': values[mb_idx] 101 | } 102 | 103 | 104 | class PPOTrainer: 105 | def __init__(self, policy: nn.Module, cfg: PPOConfig, device: str = "cpu"): 106 | self.policy = policy.to(device) 107 | self.cfg = cfg 108 | self.device = torch.device(device) 109 | self.optimizer = torch.optim.Adam(self.policy.parameters(), lr=cfg.learning_rate) 110 | 111 | def _move_action_to_device(self, action: Any) -> Any: 112 | # FlatAction 113 | if hasattr(action, "p_all"): 114 | return FlatAction( 115 | p_all=action.p_all.to(self.device), 116 | mask_full=action.mask_full.to(self.device) 117 | ) 118 | # Fallback (keep other action types compatible) 119 | return action 120 | 121 | def update(self, buffer: RolloutBuffer): 122 | clip_coef = self.cfg.clip_coef 123 | vf_coef = self.cfg.vf_coef 124 | ent_coef = self.cfg.ent_coef 125 | max_grad_norm = self.cfg.max_grad_norm 126 | total_loss = 0.0 127 | 128 | micro = max(1, int(self.cfg.minibatch_micro)) 129 | 130 | for _ in range(self.cfg.update_epochs): 131 | for mb in buffer.iter_minibatches(self.cfg.minibatch_size): 132 | MB = mb['obs'].shape[0] 133 | for start in range(0, MB, micro): 134 | end = min(start + micro, MB) 135 | 136 | obs = mb['obs'][start:end].to(self.device) 137 | trad = mb['tradable'][start:end].to(self.device) 138 | mkt = None if mb['mkt'] is None else mb['mkt'][start:end].to(self.device) 139 | sector_ids = mb['sector_ids'][start:end].to(self.device) # kept for signature 140 | old_logp = mb['logprobs'][start:end].to(self.device) 141 | adv = 
mb['advantages'][start:end].to(self.device) 142 | ret = mb['returns'][start:end].to(self.device) 143 | old_value = mb['values'][start:end].to(self.device) 144 | 145 | adv = (adv - adv.mean()) / (adv.std() + 1e-8) 146 | 147 | # recompute logp & value per sample (action-aware) 148 | logp_list, value_list = [], [] 149 | msize = obs.shape[0] 150 | for i in range(msize): 151 | act = self._move_action_to_device(mb['actions'][start + i]) 152 | o_i = obs[i:i+1]; tr_i = trad[i:i+1] 153 | mk_i = None if mkt is None else mkt[i:i+1] 154 | sec_i = sector_ids[i:i+1] 155 | logp_i, value_i = self.policy.evaluate_actions(o_i, sec_i, tr_i, mk_i, act) 156 | logp_list.append(logp_i); value_list.append(value_i) 157 | 158 | logp_new = torch.cat(logp_list, dim=0) 159 | value_new = torch.cat(value_list, dim=0).squeeze(-1) 160 | 161 | ratio = (logp_new - old_logp).exp() 162 | pg1 = ratio * adv 163 | pg2 = torch.clamp(ratio, 1.0 - clip_coef, 1.0 + clip_coef) * adv 164 | policy_loss = -torch.min(pg1, pg2).mean() 165 | 166 | v_pred_clipped = old_value + torch.clamp(value_new - old_value, -clip_coef, clip_coef) 167 | v_loss1 = (value_new - ret) ** 2 168 | v_loss2 = (v_pred_clipped - ret) ** 2 169 | value_loss = 0.5 * torch.max(v_loss1, v_loss2).mean() 170 | 171 | entropy_loss = torch.tensor(0.0, device=self.device) 172 | loss = policy_loss + vf_coef * value_loss - ent_coef * entropy_loss 173 | 174 | self.optimizer.zero_grad(set_to_none=True) 175 | loss.backward() 176 | nn.utils.clip_grad_norm_(self.policy.parameters(), max_grad_norm) 177 | self.optimizer.step() 178 | 179 | total_loss += float(loss.detach().cpu()) 180 | return total_loss 181 | 182 | 183 | class A2CConfig: 184 | def __init__(self, gamma=0.99, gae_lambda=0.95, vf_coef=0.5, ent_coef=0.0, 185 | lr=1e-4, update_epochs=3, minibatch_size=256, minibatch_micro=32, max_grad_norm=0.5): 186 | self.gamma=gamma; self.gae_lambda=gae_lambda; self.vf_coef=vf_coef; self.ent_coef=ent_coef 187 | self.lr=lr; self.update_epochs=update_epochs; self.minibatch_size=minibatch_size 188 | self.minibatch_micro=minibatch_micro; self.max_grad_norm=max_grad_norm 189 | 190 | 191 | class A2CTrainer: 192 | def __init__(self, policy, cfg: A2CConfig, device="cpu"): 193 | self.policy=policy.to(device); self.device=torch.device(device) 194 | self.opt=torch.optim.Adam(self.policy.parameters(), lr=cfg.lr) 195 | self.cfg=cfg 196 | def update(self, buffer: 'RolloutBuffer'): 197 | total_loss=0.0; micro=max(1,int(self.cfg.minibatch_micro)) 198 | for _ in range(self.cfg.update_epochs): 199 | for mb in buffer.iter_minibatches(self.cfg.minibatch_size): 200 | MB = mb['obs'].shape[0] 201 | for start in range(0, MB, micro): 202 | end=min(start+micro, MB) 203 | obs=mb['obs'][start:end].to(self.device) 204 | trad=mb['tradable'][start:end].to(self.device) 205 | mkt=None if mb['mkt'] is None else mb['mkt'][start:end].to(self.device) 206 | sid=mb['sector_ids'][start:end].to(self.device) 207 | adv=mb['advantages'][start:end].to(self.device) 208 | ret=mb['returns'][start:end].to(self.device) 209 | adv=(adv-adv.mean())/(adv.std()+1e-8) 210 | logp_list, value_list=[], [] 211 | for i in range(obs.shape[0]): 212 | o_i, tr_i = obs[i:i+1], trad[i:i+1] 213 | mk_i = None if mkt is None else mkt[i:i+1] 214 | sid_i = sid[i:i+1] 215 | logp_i, value_i = self.policy.evaluate_actions(o_i, sid_i, tr_i, mk_i, mb['actions'][start+i]) 216 | logp_list.append(logp_i); value_list.append(value_i) 217 | logp_new=torch.cat(logp_list, dim=0) 218 | value_new=torch.cat(value_list, dim=0).squeeze(-1) 219 | 
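For reference, the loss assembled in the next few lines is the standard advantage actor-critic objective; the code approximates the entropy term with the negative log-probability of the sampled action, and `ent_coef` defaults to 0 so the term is inert:

```latex
L(\theta, \phi) \;=\; -\,\mathbb{E}_t\big[\hat{A}_t \,\log \pi_\theta(a_t \mid s_t)\big]
\;+\; c_v \cdot \tfrac{1}{2}\,\mathbb{E}_t\big[(V_\phi(s_t) - \hat{R}_t)^2\big]
\;-\; c_e \,\mathbb{E}_t\big[\mathcal{H}\!\left(\pi_\theta(\cdot \mid s_t)\right)\big]
```

with `c_v = vf_coef`, `c_e = ent_coef`, `\hat{A}_t` the normalized GAE advantage and `\hat{R}_t` the GAE return.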
policy_loss=-(logp_new*adv).mean() 220 | value_loss =0.5*(value_new-ret).pow(2).mean() 221 | entropy =-logp_new.mean() 222 | loss=policy_loss + self.cfg.vf_coef*value_loss - self.cfg.ent_coef*entropy 223 | self.opt.zero_grad(set_to_none=True) 224 | loss.backward() 225 | nn.utils.clip_grad_norm_(self.policy.parameters(), self.cfg.max_grad_norm) 226 | self.opt.step() 227 | total_loss += float(loss.detach().cpu()) 228 | return total_loss 229 | 230 | 231 | class REINFORCETrainer: 232 | def __init__(self, policy, lr=1e-4, vf_coef=0.5, device="cpu", max_grad_norm=0.5): 233 | self.policy = policy.to(device) 234 | self.device = torch.device(device) 235 | self.opt = torch.optim.Adam(self.policy.parameters(), lr=lr) 236 | self.vf_coef = vf_coef 237 | self.max_grad_norm = max_grad_norm 238 | 239 | def _sid_single_row(self, buffer) -> torch.Tensor: 240 | if getattr(buffer, "sector_ids", None) is None: 241 | n_assets = buffer.tradable.shape[-1] 242 | return torch.zeros((1, n_assets), dtype=torch.long, device=self.device) 243 | sid = buffer.sector_ids 244 | if not torch.is_tensor(sid): sid = torch.as_tensor(sid) 245 | sid = sid.to(self.device) 246 | if sid.dim() == 3: sid = sid.squeeze(0).squeeze(0) 247 | elif sid.dim() == 2: sid = sid[:1] 248 | elif sid.dim() == 1: sid = sid.unsqueeze(0) 249 | else: sid = sid.reshape(1, -1) 250 | return sid # (1, n_assets) 251 | 252 | def update(self, buffer: 'RolloutBuffer'): 253 | obs = buffer.obs.view(buffer.T * buffer.B, *buffer.obs.shape[2:]).to(self.device) 254 | trad = buffer.tradable.view(buffer.T * buffer.B, buffer.tradable.shape[-1]).to(self.device) 255 | mkt = None if buffer.mkt is None else buffer.mkt.view(buffer.T * buffer.B, buffer.mkt.shape[-1]).to(self.device) 256 | s1 = self._sid_single_row(buffer) 257 | returns = buffer.returns.view(-1).to(self.device) 258 | actions = buffer.actions 259 | logp_list, value_list = [], [] 260 | for t in range(obs.shape[0]): 261 | o_t = obs[t:t+1]; tr_t = trad[t:t+1]; mk_t = None if mkt is None else mkt[t:t+1] 262 | a_t = actions[t % buffer.T] 263 | logp_t, v_t = self.policy.evaluate_actions(o_t, s1, tr_t, mk_t, a_t) 264 | logp_list.append(logp_t); value_list.append(v_t) 265 | logp = torch.cat(logp_list, dim=0) 266 | value = torch.cat(value_list, dim=0).squeeze(-1) 267 | adv = returns - value.detach() 268 | loss_pi = -(logp * adv).mean() 269 | loss_v = 0.5 * (value - returns).pow(2).mean() 270 | loss = loss_pi + self.vf_coef * loss_v 271 | self.opt.zero_grad(set_to_none=True) 272 | loss.backward() 273 | nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm) 274 | self.opt.step() 275 | return float(loss.detach().cpu()) 276 | 277 | 278 | class DDPGConfig: 279 | def __init__(self, gamma=0.99, tau=0.005, lr_actor=1e-4, lr_critic=1e-3, 280 | batch_size=64, explore_alpha=0.3, updates_per_step=1): 281 | self.gamma=gamma; self.tau=tau 282 | self.lr_actor=lr_actor; self.lr_critic=lr_critic 283 | self.batch_size=batch_size; self.explore_alpha=explore_alpha 284 | self.updates_per_step=updates_per_step 285 | 286 | 287 | class Replay: 288 | def __init__(self, capacity:int): 289 | self.s, self.trad, self.a, self.r, self.s2, self.trad2, self.done = [],[],[],[],[],[],[] 290 | self.capacity=capacity 291 | def push(self, s, trad, a, r, s2, trad2, done): 292 | if len(self.s) >= self.capacity: 293 | for lst in [self.s,self.trad,self.a,self.r,self.s2,self.trad2,self.done]: lst.pop(0) 294 | self.s.append(s); self.trad.append(trad); self.a.append(a); self.r.append(r) 295 | self.s2.append(s2); self.trad2.append(trad2); 
self.done.append(done) 296 | def sample(self, n, device): 297 | idx = np.random.choice(len(self.s), size=min(n, len(self.s)), replace=False) 298 | pack = lambda L,catdim: torch.from_numpy(np.stack([L[i] for i in idx], axis=0)).to(device).float() 299 | return (pack(self.s,0), pack(self.trad,0).bool(), 300 | pack(self.a,0), torch.tensor([self.r[i] for i in idx], device=device, dtype=torch.float32), 301 | pack(self.s2,0), pack(self.trad2,0).bool(), 302 | torch.tensor([self.done[i] for i in idx], device=device, dtype=torch.float32)) 303 | 304 | 305 | class DDPGSoftmaxTrainer: 306 | """ 307 | Deterministic actor: policy.get_action_and_value(..., sample=False) → weights on simplex. 308 | Critic: Q(z, w) with z = global state embedding from temporal+cross encoders. 309 | Exploration: mix deterministic weights with a Dirichlet(explore_alpha * (N+1) * weights). 310 | """ 311 | def __init__(self, policy, d_model, action_dim, cfg: DDPGConfig, device="cpu"): 312 | self.policy=policy.to(device); self.device=torch.device(device) 313 | self.actor_opt = torch.optim.Adam(self.policy.parameters(), lr=cfg.lr_actor) 314 | self.critic = QCritic(d_model, action_dim).to(device) 315 | self.critic_t = QCritic(d_model, action_dim).to(device) 316 | self.critic_t.load_state_dict(self.critic.state_dict()) 317 | self.critic_opt = torch.optim.Adam(self.critic.parameters(), lr=cfg.lr_critic) 318 | self.cfg=cfg 319 | 320 | def _encode_global(self, x, trad): 321 | # replicate policy._backbone to get global embedding 322 | with torch.no_grad(): 323 | asset_tokens = self.policy.temporal(x) # (B,N,d) 324 | tokens = self.policy.cross(asset_tokens) # (B,1+N,d) 325 | return tokens[:,0,:] # (B,d) 326 | 327 | @torch.no_grad() 328 | def _actor_weights(self, x, trad, deterministic=True, explore_alpha=0.3): 329 | sid = torch.zeros(x.shape[2], dtype=torch.long, device=x.device) # unused 330 | w,_,_,_,_ = self.policy.get_action_and_value(x, sid.unsqueeze(0), trad, None, sample=not deterministic) 331 | if deterministic and explore_alpha>0: 332 | alpha = explore_alpha * (w.shape[-1]) * w 333 | dist = torch.distributions.Dirichlet(alpha) 334 | w_noise = dist.sample() 335 | w = 0.9*w + 0.1*w_noise 336 | return w 337 | 338 | def soft_update(self, target: nn.Module, source: nn.Module, tau: float): 339 | with torch.no_grad(): 340 | for p_t, p in zip(target.parameters(), source.parameters()): 341 | p_t.data.mul_(1.0 - tau).add_(tau * p.data) 342 | 343 | def update(self, replay: Replay): 344 | if len(replay.s) < self.cfg.batch_size: return 0.0 345 | total = 0.0 346 | for _ in range(self.cfg.updates_per_step): 347 | s, trad, a, r, s2, trad2, done = replay.sample(self.cfg.batch_size, self.device) 348 | z = self._encode_global(s, trad) 349 | z2 = self._encode_global(s2, trad2) 350 | with torch.no_grad(): 351 | a2 = self._actor_weights(s2, trad2, deterministic=True, explore_alpha=0.0) 352 | q_target = r + (1.0 - done) * self.cfg.gamma * self.critic_t(z2, a2) 353 | q = self.critic(z, a) 354 | critic_loss = nn.functional.mse_loss(q, q_target) 355 | self.critic_opt.zero_grad(set_to_none=True) 356 | critic_loss.backward(); self.critic_opt.step() 357 | 358 | # actor update: maximize Q(s, actor(s)) → minimize -Q 359 | a_det = self._actor_weights(s, trad, deterministic=True, explore_alpha=0.0) 360 | z_det= self._encode_global(s, trad) 361 | actor_loss = - self.critic(z_det, a_det).mean() 362 | self.actor_opt.zero_grad(set_to_none=True) 363 | actor_loss.backward() 364 | nn.utils.clip_grad_norm_(self.policy.parameters(), 0.5) 365 | self.actor_opt.step() 366 
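`train.py`'s `ddpg` branch is cut off at the end of this listing, so as an illustration, here is one way the replay buffer and trainer above might be wired together. This is a hypothetical sketch, assuming an existing `PortfolioEnv` instance `env`, a policy already on `device`, and `d_model=64` matching the policy:

```python
import torch

replay = Replay(capacity=10_000)
cfg = DDPGConfig(batch_size=64)
trainer = DDPGSoftmaxTrainer(policy, d_model=64, action_dim=env.N + 1,  # cash + N assets (include_cash=True)
                             cfg=cfg, device=device)

env.reset()
for step in range(1_000):  # number of interaction steps is arbitrary here
    s, m = env._get_state(), env.mask[env.t]
    x = torch.from_numpy(s).unsqueeze(0).to(device)
    trad = torch.from_numpy(m).unsqueeze(0).to(device)
    w = trainer._actor_weights(x, trad, deterministic=True, explore_alpha=cfg.explore_alpha)
    logits = torch.log(w.clamp_min(1e-8))[0].cpu().numpy()  # env expects logits
    _, r, done, _, _ = env.step(logits)
    replay.push(s, m, w[0].cpu().numpy(), r, env._get_state(), env.mask[env.t], float(done))
    trainer.update(replay)  # no-op until the replay holds batch_size transitions
    if done:
        env.reset()
```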
| 367 | self.soft_update(self.critic_t, self.critic, self.cfg.tau) 368 | total += float((critic_loss + actor_loss).detach().cpu()) 369 | return total 370 | 371 | 372 | class QCritic(nn.Module): 373 | """Simple Q(s,w) on top of (global) state embedding + full weight vector.""" 374 | def __init__(self, d_model: int, action_dim: int, hidden=256): 375 | super().__init__() 376 | self.net = nn.Sequential( 377 | nn.Linear(d_model + action_dim, hidden), nn.ReLU(), 378 | nn.Linear(hidden, hidden), nn.ReLU(), 379 | nn.Linear(hidden, 1) 380 | ) 381 | def forward(self, z, a): 382 | x = torch.cat([z, a], dim=-1); return self.net(x).squeeze(-1) 383 | -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | # data.py 2 | from typing import List, Dict 3 | import numpy as np 4 | import pandas as pd 5 | def load_panel_from_csv(csv_path: str, window: int): 6 | df = pd.read_csv(csv_path, low_memory=False) 7 | assert {"Date", "ticker", "Close"}.issubset(df.columns), "CSV must include Date, ticker, Close" 8 | df["Date"] = pd.to_datetime(df["Date"]) 9 | df = df.sort_values(["Date","ticker"]).reset_index(drop=True) 10 | 11 | # choose features (add/remove freely) 12 | default_feats = [ 13 | "Open","High","Low","Close","Volume", 14 | "macd","macd_signal","macd_hist", 15 | "macdboll","ubboll","lb","rsi","30cci", 16 | "plus_di_14","minus_di_14","dx_14","dx30", 17 | "close30_sma","close60_sma" 18 | ] 19 | feature_cols = [c for c in default_feats if c in df.columns] 20 | tickers = df["ticker"].unique() 21 | dates = df["Date"].drop_duplicates().sort_values() 22 | 23 | close_mat = df.pivot(index="Date", columns="ticker", values="Close").reindex(index=dates, columns=tickers) 24 | if "log_return" in df.columns: 25 | logret_mat = df.pivot(index="Date", columns="ticker", values="log_return").reindex(index=dates, columns=tickers) 26 | else: 27 | logret_mat = np.log(close_mat / close_mat.shift(1)) 28 | prices_rel = np.nan_to_num(np.exp(logret_mat.values) - 1.0, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32) 29 | 30 | if not feature_cols: 31 | raise ValueError("No feature columns found — include at least one numeric feature.") 32 | feats = [] 33 | for c in feature_cols: 34 | mat = df.pivot(index="Date", columns="ticker", values=c).reindex(index=dates, columns=tickers).values 35 | feats.append(mat) 36 | features_raw = np.stack(feats, axis=-1).astype(np.float32) # (T,N,F) 37 | features = cross_sectional_zscore(features_raw) 38 | 39 | tradable_mask = (~np.isnan(close_mat.values)) & np.isfinite(close_mat.values) 40 | 41 | # sectors kept for signature compatibility (flat policy ignores) 42 | if "GICS_Sector" in df.columns: 43 | sector_series = (df.groupby(["ticker","GICS_Sector"]).size().reset_index(name="n") 44 | .sort_values(["ticker","n"], ascending=[True, False]) 45 | .drop_duplicates(subset=["ticker"]).set_index("ticker")["GICS_Sector"] 46 | .reindex(tickers)) 47 | sector_ids = pd.Categorical(sector_series).codes 48 | if (sector_ids < 0).any(): 49 | max_id = sector_ids[sector_ids >= 0].max() if (sector_ids >= 0).any() else -1 50 | sector_ids = np.where(sector_ids < 0, max_id + 1, sector_ids) 51 | num_sectors = int(sector_ids.max() + 1) if len(sector_ids) else 1 52 | else: 53 | sector_ids = np.zeros(len(tickers), dtype=np.int64) 54 | num_sectors = 1 55 | 56 | T, N, Fdim = features.shape 57 | if T <= window: 58 | raise ValueError(f"Not enough rows for window={window}. 
Got T={T}.") 59 | 60 | return dict( 61 | prices_rel=prices_rel, features=features, tradable_mask=tradable_mask.astype(bool), 62 | sector_ids=sector_ids.astype(np.int64), num_sectors=num_sectors, N=N, Fdim=Fdim, T=T 63 | ) 64 | -------------------------------------------------------------------------------- /env.py: -------------------------------------------------------------------------------- 1 | # env.py 2 | import numpy as np 3 | import gymnasium as gym 4 | from gymnasium import spaces 5 | from typing import Optional, Dict, Any 6 | 7 | 8 | class PortfolioEnv(gym.Env): 9 | """Daily rebalance environment. 10 | Action: logits for (cash + N assets); softmax → weights; applies tradability mask & costs. 11 | """ 12 | metadata = {"render_modes": []} 13 | def __init__(self, 14 | prices_rel: np.ndarray, # (T,N) simple returns 15 | features: np.ndarray, # (T,N,Fdim) 16 | sector_ids: np.ndarray, # (N,) (not used by flat policy; kept for compatibility) 17 | tradable_mask: np.ndarray, # (T,N) bool 18 | market_factors: Optional[np.ndarray] = None, # (T,K) or None 19 | tc_cost: float = 5e-4, 20 | window: int = 30, 21 | sector_caps: Optional[dict] = None, 22 | include_cash: bool = True): 23 | super().__init__() 24 | assert prices_rel.shape[:2] == features.shape[:2] 25 | T, N = prices_rel.shape 26 | self.pr = prices_rel.astype(np.float32) 27 | self.X = features.astype(np.float32) 28 | self.mask = tradable_mask.astype(bool) 29 | self.sec = sector_ids.astype(np.int64) # harmless if unused 30 | self.mkt = None if market_factors is None else market_factors.astype(np.float32) 31 | self.tc = tc_cost 32 | self.window = window 33 | self.N = N 34 | self.include_cash = include_cash 35 | self.M = N + (1 if include_cash else 0) 36 | self.sector_caps = sector_caps or {} # ignored if empty 37 | self.t0 = window 38 | self.t = None 39 | self.prev_w = None 40 | 41 | obs_dim = (window, N, self.X.shape[-1]) 42 | self.observation_space = spaces.Box(low=-10, high=10, shape=obs_dim, dtype=np.float32) 43 | self.action_space = spaces.Box(low=-10, high=10, shape=(self.M,), dtype=np.float32) 44 | 45 | def _get_state(self) -> np.ndarray: 46 | return self.X[self.t-self.window:self.t] 47 | 48 | def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None): 49 | super().reset(seed=seed) 50 | self.t = self.t0 51 | if self.include_cash: 52 | self.prev_w = np.zeros(self.M, dtype=np.float32); self.prev_w[0] = 1.0 53 | else: 54 | w = np.ones(self.N, dtype=np.float32)/self.N 55 | self.prev_w = w 56 | return self._get_state(), {} 57 | 58 | def step(self, action_logits: np.ndarray): 59 | # Softmax → weights 60 | exps = np.exp(action_logits - np.max(action_logits)) 61 | w = exps / np.sum(exps) 62 | 63 | if self.include_cash: 64 | cash = float(w[0]); stock_w = w[1:].copy() 65 | else: 66 | cash = 0.0; stock_w = w.copy() 67 | 68 | # Mask untradable names today, renormalize into stock slice 69 | tradable = self.mask[self.t].astype(bool) 70 | stock_w[~tradable] = 0.0 71 | total = stock_w.sum() 72 | if total > 1e-12: 73 | stock_w = stock_w * (1.0 - cash) / total 74 | else: 75 | cash = 1.0; stock_w[:] = 0.0 76 | 77 | # Optional sector caps projection (usually empty for flat policy) 78 | if self.sector_caps: 79 | stock_w = _capped_simplex_projection(stock_w, self.sec, self.sector_caps) 80 | cash = max(0.0, 1.0 - float(stock_w.sum())) 81 | 82 | w_new = np.concatenate([[cash], stock_w]).astype(np.float32) if self.include_cash else stock_w.astype(np.float32) 83 | 84 | # Pre-trade drift weights (mark-to-market) 85 | r = self.pr[self.t] 
# (N,) 86 | gross = np.concatenate([[1.0], (1.0 + r)]).astype(np.float32) if self.include_cash else (1.0 + r).astype(np.float32) 87 | prev_gross_val = float((self.prev_w * gross).sum()) 88 | w_pre = (self.prev_w * gross) / (prev_gross_val + 1e-12) 89 | 90 | turnover = float(np.abs(w_new - w_pre).sum()) 91 | cost = self.tc * turnover 92 | 93 | port_ret = float((w_new * gross).sum() - 1.0 - cost) 94 | reward = float(np.log(max(1e-8, 1.0 + port_ret))) # log utility 95 | 96 | self.prev_w = w_new 97 | self.t += 1 98 | terminated = (self.t >= self.pr.shape[0]-1) 99 | return self._get_state(), reward, terminated, False, {"turnover": turnover, "raw_ret": port_ret} 100 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | # eval.py 2 | import numpy as np 3 | import torch 4 | from typing import Dict, Optional, Tuple 5 | from env import PortfolioEnv; from policy import StochasticHierarchicalDualAttentionPolicy 6 | def eval_policy_full_dataset(policy: StochasticHierarchicalDualAttentionPolicy, 7 | prices_rel: np.ndarray, 8 | features: np.ndarray, 9 | tradable_mask: np.ndarray, 10 | sector_ids: np.ndarray, 11 | market_factors: Optional[np.ndarray], 12 | window: int, 13 | device: str): 14 | """ 15 | Runs a deterministic evaluation (Dirichlet means) over the full dataset. 16 | Returns: dict with daily_returns, cum_wealth, metrics, and saved turnovers. 17 | """ 18 | env_eval = PortfolioEnv(prices_rel=prices_rel, 19 | features=features, 20 | sector_ids=sector_ids, 21 | tradable_mask=tradable_mask, 22 | market_factors=market_factors, 23 | tc_cost=0.0, # evaluate raw ability without trading cost impact if desired 24 | window=window, 25 | include_cash=True) 26 | # If you want AFTER-COST evaluation, set tc_cost to the same as training.
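`compute_metrics` (called below) is not defined in the `metrics.py` shown, and the evaluation code also reads `metrics["cum_wealth"]`. A plausible definition, assuming it simply augments the return-based metrics with the cumulative wealth path:

```python
import numpy as np

from metrics import compute_metrics_from_returns


def compute_metrics(daily_rets: np.ndarray) -> dict:
    """Hypothetical helper: return-based metrics plus the cumulative-wealth series."""
    out = compute_metrics_from_returns(daily_rets)
    out["cum_wealth"] = np.cumprod(1.0 + np.asarray(daily_rets, dtype=np.float64))
    return out
```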
27 | 28 | sector_ids_t = torch.from_numpy(sector_ids).long().to(device) 29 | obs, _ = env_eval.reset() 30 | rets = [] 31 | turns = [] 32 | 33 | with torch.no_grad(): 34 | while True: 35 | x_t = torch.from_numpy(env_eval._get_state()).unsqueeze(0).to(device) # (1,W,N,F) 36 | trad_t = torch.from_numpy(env_eval.mask[env_eval.t]).unsqueeze(0).to(device) # (1,N) 37 | mkt_t = None if env_eval.mkt is None else torch.from_numpy(env_eval.mkt[env_eval.t]).unsqueeze(0).to(device) 38 | 39 | # Deterministic: use Dirichlet mean (sample=False) 40 | weights, _, _, _, _ = policy.get_action_and_value( 41 | x=x_t, 42 | sector_ids=sector_ids_t.unsqueeze(0), 43 | tradable_mask=trad_t, 44 | market_factors=mkt_t, 45 | sample=False 46 | ) 47 | logits = torch.log(weights.clamp_min(1e-12))[0].cpu().numpy() 48 | _, _, done, _, info = env_eval.step(logits) 49 | rets.append(info["raw_ret"]) 50 | turns.append(info["turnover"]) 51 | if done: 52 | break 53 | 54 | daily_rets = np.array(rets, dtype=np.float32) 55 | metrics = compute_metrics(daily_rets) 56 | return { 57 | "daily_returns": daily_rets, 58 | "turnover": np.array(turns, dtype=np.float32), 59 | "cum_wealth": metrics["cum_wealth"], 60 | "metrics": {k: v for k, v in metrics.items() if k != "cum_wealth"} 61 | } 62 | 63 | 64 | def load_policy(ckpt_path, device): 65 | # Robust to PyTorch 2.6 defaults 66 | try: 67 | ckpt = torch.load(ckpt_path, map_location=device, weights_only=False) 68 | except Exception: 69 | try: 70 | from torch.serialization import add_safe_globals 71 | from torch.nn.parameter import UninitializedParameter 72 | add_safe_globals([UninitializedParameter]) 73 | except Exception: 74 | pass 75 | ckpt = torch.load(ckpt_path, map_location=device, weights_only=True) 76 | 77 | d_model = int(ckpt["config"]["d_model"]) 78 | window = int(ckpt["config"]["window"]) 79 | tc_cost = float(ckpt["config"]["tc_cost"]) 80 | 81 | policy = StochasticHierarchicalDualAttentionPolicy( 82 | d_model=d_model, 83 | n_heads_time=2, n_layers_time=1, 84 | n_heads_cross=2, n_layers_cross=1, 85 | include_cash=True, temporal_asset_chunk=64, temporal_use_ckpt=True 86 | ).to(device) 87 | policy.load_state_dict(ckpt["state_dict"]) 88 | policy.eval() 89 | return policy, window, tc_cost 90 | 91 | 92 | def buy_and_hold_baseline(panel, start_index): 93 | R = panel["prices_rel"] # (T, N) 94 | mask = panel["mask"].astype(bool) # (T, N) 95 | if start_index >= R.shape[0]: 96 | raise ValueError("start_index beyond panel length.") 97 | tradable = np.where(mask[start_index])[0] 98 | if tradable.size == 0: 99 | raise ValueError("No tradable assets at baseline start.") 100 | w0 = np.zeros(R.shape[1], dtype=np.float32) 101 | w0[tradable] = 1.0 / tradable.size 102 | G = np.cumprod(1.0 + np.nan_to_num(R[start_index:], nan=0.0), axis=0) # (T_test, N) 103 | wealth = (G * w0[None, :]).sum(axis=1).astype(np.float32) 104 | rets = np.zeros_like(wealth, dtype=np.float32) 105 | rets[1:] = wealth[1:] / np.clip(wealth[:-1], 1e-12, None) - 1.0 106 | ext = compute_metrics_from_returns(rets) 107 | return wealth, ext 108 | 109 | 110 | def plot_equity(cw, dates, label, fname): 111 | d = dates[WINDOW+1:WINDOW+1+len(cw)] 112 | plt.figure(figsize=(8,4)) 113 | plt.plot(d, cw, label=label) 114 | plt.xlabel("Date"); plt.ylabel("Cumulative wealth"); plt.legend(); plt.tight_layout() 115 | plt.savefig(OUT_DIR / fname, dpi=150); plt.close() 116 | -------------------------------------------------------------------------------- /metrics.py: -------------------------------------------------------------------------------- 1 | # 
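`metrics.py` below annualizes with a `TDAYS` constant that is never defined in the file as shown; presumably it is a module-level trading-day count along the lines of:

```python
TDAYS = 252  # approximate number of trading days per year, used for annualization
```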
metrics.py 2 | import numpy as np 3 | def _drawdown_path(wealth: np.ndarray): 4 | """Return drawdown series and the longest drawdown duration in days.""" 5 | peak = np.maximum.accumulate(wealth) 6 | dd = wealth / np.maximum(peak, 1e-12) - 1.0 7 | # duration: consecutive days below peak 8 | under = wealth < peak 9 | longest = cur = 0 10 | for u in under: 11 | cur = cur + 1 if u else 0 12 | if cur > longest: longest = cur 13 | return dd, longest 14 | 15 | 16 | def compute_metrics_from_returns(rets: np.ndarray): 17 | """Compute a comprehensive set of metrics from daily returns.""" 18 | rets = np.asarray(rets, dtype=np.float64) 19 | wealth = np.cumprod((1.0 + rets).clip(0)) 20 | years = max(1e-8, len(rets) / TDAYS) 21 | ann_mu = rets.mean() * TDAYS 22 | ann_vol = (rets.std(ddof=0) + 1e-12) * np.sqrt(TDAYS) 23 | sharpe = ann_mu / ann_vol if ann_vol > 0 else 0.0 24 | 25 | # Sortino: use downside deviation relative to MAR=0 26 | downside = np.clip(rets, None, 0.0) 27 | dd_dev = np.sqrt(np.mean(downside**2) + 1e-18) * np.sqrt(TDAYS) 28 | sortino = ann_mu / dd_dev if dd_dev > 0 else 0.0 29 | 30 | # Drawdowns 31 | dd, max_dd_duration = _drawdown_path(wealth) 32 | mdd = float(dd.min()) if wealth.size > 0 else 0.0 33 | 34 | # CAGR 35 | cagr = (wealth[-1] ** (1.0 / years) - 1.0) if wealth.size > 0 else 0.0 36 | 37 | # Calmar = CAGR / |MDD| 38 | calmar = cagr / abs(mdd) if mdd < 0 else np.nan 39 | 40 | # Hit rate, avg win/loss 41 | pos = rets[rets > 0] 42 | neg = rets[rets < 0] 43 | hit_rate = (rets > 0).mean() if rets.size > 0 else np.nan 44 | avg_win = pos.mean() if pos.size > 0 else np.nan 45 | avg_loss = neg.mean() if neg.size > 0 else np.nan 46 | 47 | # Moments 48 | # avoid scipy; use unbiased-like formulas (but OK with sample skew/kurt via pandas if desired) 49 | mean = rets.mean() 50 | std = rets.std(ddof=0) + 1e-18 51 | skew = np.mean(((rets - mean) / std) ** 3) 52 | kurt = np.mean(((rets - mean) / std) ** 4) # raw (Fisher+3) 53 | 54 | # Tail risk: historical VaR/CVaR at 5% 55 | if rets.size > 0: 56 | var5 = np.percentile(rets, 5) # 5% quantile (loss is negative) 57 | cvar5 = rets[rets <= var5].mean() if (rets <= var5).any() else var5 58 | else: 59 | var5, cvar5 = np.nan, np.nan 60 | 61 | # Tail ratio (95th gain / |5th loss|) 62 | p95 = np.percentile(rets, 95) if rets.size > 0 else np.nan 63 | p05 = np.percentile(rets, 5) if rets.size > 0 else np.nan 64 | tail_ratio = (p95 / abs(p05)) if (not np.isnan(p95) and not np.isnan(p05) and p05 < 0) else np.nan 65 | 66 | return { 67 | "sharpe": float(sharpe), 68 | "sortino": float(sortino), 69 | "calmar": float(calmar), 70 | "ann_return": float(ann_mu), 71 | "ann_vol": float(ann_vol), 72 | "cagr": float(cagr), 73 | "mdd": float(mdd), 74 | "max_dd_duration_days": int(max_dd_duration), 75 | "hit_rate": float(hit_rate), 76 | "avg_win": float(avg_win), 77 | "avg_loss": float(avg_loss), 78 | "skew": float(skew), 79 | "kurtosis": float(kurt), 80 | "VaR_5": float(var5), 81 | "CVaR_5": float(cvar5), 82 | "tail_ratio": float(tail_ratio), 83 | "terminal_wealth": float(wealth[-1]) if wealth.size > 0 else np.nan, 84 | } 85 | -------------------------------------------------------------------------------- /policy.py: -------------------------------------------------------------------------------- 1 | # policy.py 2 | import math 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as Fnn 7 | from torch.distributions import Dirichlet 8 | class TemporalEncoder(nn.Module): 9 | """Transformer over time PER ASSET with 
asset-chunking to bound memory.""" 10 | def __init__(self, d_model=64, n_heads=2, n_layers=1, dropout=0.1, pool="last", 11 | asset_chunk: int = 64, use_ckpt: bool = True): 12 | super().__init__() 13 | self.in_proj = nn.LazyLinear(d_model) 14 | enc_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=n_heads, dropout=dropout, batch_first=True) 15 | self.encoder = nn.TransformerEncoder(enc_layer, num_layers=n_layers) 16 | self.pool = pool 17 | self.asset_chunk = max(1, int(asset_chunk)) 18 | self.use_ckpt = use_ckpt 19 | 20 | def _encode(self, xi: torch.Tensor) -> torch.Tensor: 21 | if self.use_ckpt and xi.requires_grad and self.encoder.num_layers > 0: 22 | for layer in self.encoder.layers: 23 | xi = torch.utils.checkpoint.checkpoint(layer, xi) 24 | if self.encoder.norm is not None: 25 | xi = self.encoder.norm(xi) 26 | return xi 27 | else: 28 | return self.encoder(xi) 29 | 30 | def forward(self, x: torch.Tensor) -> torch.Tensor: 31 | # x: (B, W, N, Fdim) 32 | B, W, N, Fdim = x.shape 33 | x = self.in_proj(x) # (B,W,N,d) 34 | pe = sinusoidal_time_encoding(W, x.size(-1), x.device) 35 | x = x + pe.view(1, W, 1, -1) 36 | x = x.permute(0, 2, 1, 3) # (B,N,W,d) 37 | 38 | outs = [] 39 | for start in range(0, N, self.asset_chunk): 40 | end = min(start + self.asset_chunk, N) 41 | xi = x[:, start:end, :, :].reshape(B * (end - start), W, -1) # (B*chunk, W, d) 42 | yi = self._encode(xi) # " 43 | yi = yi[:, -1, :] if self.pool == "last" else yi.mean(dim=1) # (B*chunk, d) 44 | yi = yi.view(B, end - start, -1) # (B, chunk, d) 45 | outs.append(yi) 46 | del xi, yi 47 | x = torch.cat(outs, dim=1) # (B, N, d) 48 | return x 49 | 50 | 51 | class CrossSectionalAttention(nn.Module): 52 | """Transformer across assets with a global token (no group tokens).""" 53 | def __init__(self, d_model=64, n_heads=2, n_layers=1, dropout=0.1): 54 | super().__init__() 55 | self.global_token = nn.Parameter(torch.randn(1, 1, d_model) * 0.02) 56 | enc_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=n_heads, dropout=dropout, batch_first=True) 57 | self.encoder = nn.TransformerEncoder(enc_layer, num_layers=n_layers) 58 | 59 | def forward(self, asset_tokens: torch.Tensor) -> torch.Tensor: 60 | B = asset_tokens.size(0) 61 | g = self.global_token.expand(B, -1, -1) 62 | tokens = torch.cat([g, asset_tokens], dim=1) # (B, 1+N, d) 63 | return self.encoder(tokens) 64 | 65 | 66 | class FlatAction: 67 | p_all: torch.Tensor # (B, N+1) pre-mask Dirichlet sample 68 | mask_full: torch.Tensor # (B, N+1) bool; True=tradable (cash True) 69 | 70 | 71 | class HierAction: 72 | p_sec: torch.Tensor # (B, S+1) 73 | q_list: List[List[torch.Tensor]] # list[S][B] variable lengths 74 | idx_list: List[List[torch.Tensor]] # list[B][S] indices 75 | sectors_missing: torch.Tensor # (B,S) bool 76 | 77 | 78 | class StochasticHierarchicalDualAttentionPolicy(nn.Module): 79 | """ 80 | NOTE: name kept for drop-in compatibility, but this is now a FLAT Dirichlet actor-critic. 
81 | - Temporal encoder per asset 82 | - Cross-asset attention with a global token 83 | - Single Dirichlet over [cash + all assets] 84 | """ 85 | def __init__(self, 86 | d_model: int = 64, 87 | n_heads_time: int = 2, 88 | n_layers_time: int = 1, 89 | n_heads_cross: int = 2, 90 | n_layers_cross: int = 1, 91 | num_sectors: int = 0, # ignored in flat 92 | num_regions: int = 0, # ignored in flat 93 | include_cash: bool = True, 94 | dropout: float = 0.1, 95 | min_alpha: float = 0.05, 96 | temporal_asset_chunk: int = 64, 97 | temporal_use_ckpt: bool = True): 98 | super().__init__() 99 | self.include_cash = include_cash 100 | self.min_alpha = min_alpha 101 | 102 | self.temporal = TemporalEncoder(d_model=d_model, n_heads=n_heads_time, n_layers=n_layers_time, 103 | dropout=dropout, pool="last", asset_chunk=temporal_asset_chunk, 104 | use_ckpt=temporal_use_ckpt) 105 | self.cross = CrossSectionalAttention(d_model=d_model, n_heads=n_heads_cross, n_layers=n_layers_cross, 106 | dropout=dropout) 107 | # Heads 108 | self.cash_head = nn.Sequential(nn.LayerNorm(d_model), nn.Linear(d_model, 1)) 109 | self.asset_head = nn.Sequential(nn.LayerNorm(d_model), nn.Linear(d_model, 1)) 110 | self.value_head = nn.Sequential(nn.LayerNorm(d_model), nn.Linear(d_model, d_model), nn.GELU(), nn.Linear(d_model, 1)) 111 | 112 | self.mkt_proj = nn.LazyLinear(d_model) 113 | self.use_mkt = False 114 | 115 | def _backbone(self, 116 | x: torch.Tensor, # (B,W,N,Fdim) 117 | tradable_mask: Optional[torch.Tensor], # (B,N) bool 118 | market_factors: Optional[torch.Tensor] # (B,K) or None 119 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 120 | B, W, N, Fdim = x.shape 121 | asset_tokens = self.temporal(x) # (B,N,d) 122 | tokens = self.cross(asset_tokens) # (B,1+N,d) 123 | g_out = tokens[:, 0, :] # (B,d) 124 | a_out = tokens[:, 1:, :] # (B,N,d) 125 | 126 | if market_factors is not None: 127 | if not self.use_mkt: 128 | self.use_mkt = True 129 | g_out = g_out + self.mkt_proj(market_factors) 130 | 131 | cash_logit = self.cash_head(g_out).squeeze(-1) # (B,) 132 | asset_logits = self.asset_head(a_out).squeeze(-1) # (B,N) 133 | value = self.value_head(g_out).squeeze(-1) # (B,) 134 | return cash_logit, asset_logits, value 135 | 136 | @torch.no_grad() 137 | def get_action_and_value(self, 138 | x: torch.Tensor, # (B,W,N,Fdim) 139 | sector_ids: torch.Tensor, # unused; kept for signature compatibility 140 | tradable_mask: Optional[torch.Tensor], # (B,N) 141 | market_factors: Optional[torch.Tensor], # (B,K) 142 | sample: bool = True 143 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, FlatAction, Any]: 144 | device = x.device 145 | B, W, N, Fdim = x.shape 146 | 147 | cash_logit, asset_logits, value = self._backbone(x, tradable_mask, market_factors) 148 | logits_full = torch.cat([cash_logit.unsqueeze(-1), asset_logits], dim=-1) # (B, N+1) 149 | 150 | # Build full mask (cash always tradable) 151 | if tradable_mask is None: 152 | mask_full = torch.ones(B, N+1, dtype=torch.bool, device=device) 153 | else: 154 | mask_full = torch.cat([torch.ones(B,1, dtype=torch.bool, device=device), tradable_mask], dim=-1) 155 | 156 | masked_logits = torch.where(mask_full, logits_full, torch.full_like(logits_full, -30.0)) 157 | alpha = _dirichlet_alpha_from_logits(masked_logits, self.min_alpha) # (B,N+1) 158 | 159 | dist = Dirichlet(alpha) 160 | if sample: 161 | p_all = dist.rsample() 162 | logp = dist.log_prob(p_all) 163 | else: 164 | p_all = alpha / alpha.sum(dim=-1, keepdim=True) 165 | logp = dist.log_prob(p_all) 166 | 167 | # Zero out non-tradables 
and renormalize into feasible names 168 | p = p_all * mask_full.float() 169 | s = p.sum(dim=-1, keepdim=True).clamp_min(1e-8) 170 | weights = p / s # (B,N+1) 171 | 172 | action = FlatAction(p_all=p_all, mask_full=mask_full) 173 | aux = {"logits_full": logits_full, "alpha": alpha} 174 | return weights, logp, value, action, aux 175 | 176 | def evaluate_actions(self, 177 | x: torch.Tensor, 178 | sector_ids: torch.Tensor, # unused 179 | tradable_mask: Optional[torch.Tensor], 180 | market_factors: Optional[torch.Tensor], 181 | action: FlatAction) -> Tuple[torch.Tensor, torch.Tensor]: 182 | # Recompute alpha and log-prob at current params. 183 | B, W, N, Fdim = x.shape 184 | device = x.device 185 | 186 | cash_logit, asset_logits, value = self._backbone(x, tradable_mask, market_factors) 187 | logits_full = torch.cat([cash_logit.unsqueeze(-1), asset_logits], dim=-1) # (B,N+1) 188 | 189 | mask_full = action.mask_full.to(device) 190 | masked_logits = torch.where(mask_full, logits_full, torch.full_like(logits_full, -30.0)) 191 | alpha = _dirichlet_alpha_from_logits(masked_logits, self.min_alpha) 192 | 193 | dist = Dirichlet(alpha) 194 | logp = dist.log_prob(action.p_all.to(device)) 195 | return logp, value 196 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | pandas 3 | torch 4 | gymnasium 5 | matplotlib 6 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # train.py 2 | import argparse 3 | import numpy as np 4 | import torch 5 | from typing import List 6 | from env import PortfolioEnv 7 | from policy import StochasticHierarchicalDualAttentionPolicy 8 | from agents import PPOConfig, PPOTrainer, RolloutBuffer, A2CTrainer, REINFORCETrainer 9 | from utils import set_seeds, cs_zscore 10 | from data import load_panel_from_csv 11 | from metrics import compute_metrics_from_returns 12 | def run_training(csv_path: str, 13 | window: int = 30, 14 | tc_cost: float = 5e-4, 15 | d_model: int = 64, 16 | heads_time: int = 2, 17 | layers_time: int = 1, 18 | heads_cross: int = 2, 19 | layers_cross: int = 1, 20 | buffer_steps: int = 128, 21 | num_updates: int = 5, 22 | lr: float = 3e-4, 23 | minibatch_size: int = 256, 24 | update_epochs: int = 3, 25 | device: str = "auto", 26 | temporal_asset_chunk: int = 64, 27 | temporal_use_ckpt: bool = True, 28 | minibatch_micro: int = 32, 29 | out_dir: str = "./runs", 30 | seed: int = 42): 31 | set_seeds(seed) 32 | out_dir = Path(out_dir) 33 | out_dir.mkdir(parents=True, exist_ok=True) 34 | 35 | if device == "auto": 36 | device = "cuda" if torch.cuda.is_available() else "cpu" 37 | 38 | panel = load_panel_from_csv(csv_path, window) 39 | prices_rel = panel["prices_rel"]; features = panel["features"]; tradable_mask = panel["tradable_mask"] 40 | sector_ids = panel["sector_ids"]; num_sectors = panel["num_sectors"] 41 | N, Fdim = panel["N"], panel["Fdim"]; market_factors = panel["market_factors"] 42 | dates = panel["dates"] 43 | 44 | # Build env 45 | env = PortfolioEnv(prices_rel=prices_rel, 46 | features=features, 47 | sector_ids=sector_ids, 48 | tradable_mask=tradable_mask, 49 | market_factors=market_factors, 50 | tc_cost=tc_cost, 51 | window=window, 52 | sector_caps=None, 53 | include_cash=True) 54 | 55 | # Policy (flat single-Dirichlet in the provided module) 56 | policy = 
StochasticHierarchicalDualAttentionPolicy( 57 | d_model=d_model, 58 | n_heads_time=heads_time, n_layers_time=layers_time, 59 | n_heads_cross=heads_cross, n_layers_cross=layers_cross, 60 | num_sectors=num_sectors, # ignored by flat policy (kept for signature) 61 | include_cash=True, 62 | dropout=0.1, 63 | temporal_asset_chunk=temporal_asset_chunk, 64 | temporal_use_ckpt=temporal_use_ckpt 65 | ).to(device) 66 | 67 | cfg = PPOConfig(gamma=0.99, gae_lambda=0.95, clip_coef=0.2, vf_coef=0.5, 68 | ent_coef=0.0, max_grad_norm=0.5, learning_rate=lr, 69 | update_epochs=update_epochs, minibatch_size=minibatch_size, 70 | minibatch_micro=minibatch_micro) 71 | trainer = PPOTrainer(policy, cfg, device=device) 72 | 73 | sector_ids_torch = torch.from_numpy(sector_ids).long().to(device) 74 | obs_shape = (window, N, Fdim) 75 | 76 | # Logging containers 77 | loss_hist = [] 78 | mean_reward_hist = [] 79 | eval_cum_wealth_hist = [] 80 | eval_metrics_hist = [] 81 | 82 | for update in range(num_updates): 83 | # keep rollout buffer on CPU → avoids GPU OOM 84 | buffer = RolloutBuffer(T=buffer_steps, B=1, obs_shape=obs_shape, 85 | n_assets=N, sector_ids=sector_ids_torch, 86 | device="cpu", 87 | has_mkt=(env.mkt is not None), 88 | k_mkt=(0 if env.mkt is None else env.mkt.shape[-1])) 89 | buffer._step = 0 90 | obs, _ = env.reset() 91 | 92 | for t in range(buffer_steps): 93 | x_t = torch.from_numpy(env._get_state()).unsqueeze(0).to(device) # (1,W,N,F) 94 | trad_t= torch.from_numpy(env.mask[env.t]).unsqueeze(0).to(device) # (1,N) bool 95 | mkt_t = None if env.mkt is None else torch.from_numpy(env.mkt[env.t]).unsqueeze(0).to(device) 96 | with torch.no_grad(): 97 | weights, logp, value, action, _ = policy.get_action_and_value( 98 | x=x_t, 99 | sector_ids=sector_ids_torch.unsqueeze(0), 100 | tradable_mask=trad_t, 101 | market_factors=mkt_t, 102 | sample=True 103 | ) 104 | # Env expects logits; map weights→logits stably 105 | logits = torch.log(weights.clamp_min(1e-8))[0].cpu().numpy() 106 | next_obs, reward, done, truncated, info = env.step(logits) 107 | 108 | buffer.add( 109 | obs_t=x_t[0].cpu(), 110 | tradable_t=trad_t[0].cpu(), 111 | mkt_t=None if mkt_t is None else mkt_t[0].cpu(), 112 | action=action, 113 | weight_t=weights[0].cpu(), 114 | reward_t=torch.tensor(reward, dtype=torch.float32), 115 | done_t=torch.tensor(done, dtype=torch.bool), 116 | value_t=value[0].detach().cpu(), 117 | logp_t=logp[0].detach().cpu() 118 | ) 119 | obs = next_obs 120 | if done: 121 | obs, _ = env.reset() 122 | 123 | # Bootstrap advantages 124 | x_last = torch.from_numpy(env._get_state()).unsqueeze(0).to(device) 125 | trad_last= torch.from_numpy(env.mask[env.t]).unsqueeze(0).to(device) 126 | mkt_last = None if env.mkt is None else torch.from_numpy(env.mkt[env.t]).unsqueeze(0).to(device) 127 | with torch.no_grad(): 128 | _, _, last_value, _, _ = policy.get_action_and_value(x_last, sector_ids_torch.unsqueeze(0), trad_last, mkt_last, sample=False) 129 | buffer.compute_gae(last_value.squeeze(-1).detach().cpu(), gamma=cfg.gamma, lam=cfg.gae_lambda) 130 | 131 | # PPO update 132 | loss = trainer.update(buffer) 133 | loss_hist.append(loss) 134 | mean_reward_hist.append(float(buffer.rewards.mean().cpu().numpy())) 135 | 136 | # Deterministic evaluation over full dataset 137 | eval_out = eval_policy_full_dataset(policy, 138 | prices_rel, features, tradable_mask, 139 | sector_ids, market_factors, 140 | window, device) 141 | eval_cum_wealth_hist.append(eval_out["cum_wealth"]) 142 | eval_metrics_hist.append(eval_out["metrics"]) 143 | 144 | m = 
eval_out["metrics"] 145 | print(f"[{update+1}/{num_updates}] loss={loss:.4f} | meanR={mean_reward_hist[-1]:.6f} | " 146 | f"Eval: Sharpe={m['sharpe']:.3f} CAGR={m['cagr']:.3%} MDD={m['mdd']:.2%}") 147 | 148 | # Save CSV log 149 | rows = [] 150 | for i, (loss_i, mr_i, em_i) in enumerate(zip(loss_hist, mean_reward_hist, eval_metrics_hist), start=1): 151 | rows.append({ 152 | "update": i, 153 | "ppo_loss": loss_i, 154 | "mean_buffer_reward": mr_i, 155 | "eval_sharpe": em_i["sharpe"], 156 | "eval_cagr": em_i["cagr"], 157 | "eval_mdd": em_i["mdd"], 158 | }) 159 | df_log = pd.DataFrame(rows) 160 | out_dir.mkdir(parents=True, exist_ok=True) 161 | df_log.to_csv(out_dir / "training_log.csv", index=False) 162 | 163 | # Save plots (training curves + last equity curve) 164 | save_training_plots(out_dir, loss_hist, mean_reward_hist, eval_cum_wealth_hist, dates, window) 165 | 166 | print(f"\nSaved: {out_dir/'training_log.csv'}, {out_dir/'training_curves.png'}, " 167 | f"{out_dir/'eval_equity_curve.png'}") 168 | 169 | 170 | def train_and_save(csv_path: str, 171 | save_path: str = "./checkpoints/policy.pt", 172 | window: int = 30, 173 | tc_cost: float = 5e-4, 174 | d_model: int = 64, 175 | heads_time: int = 2, 176 | layers_time: int = 1, 177 | heads_cross: int = 2, 178 | layers_cross: int = 1, 179 | buffer_steps: int = 128, 180 | num_updates: int = 5, 181 | lr: float = 3e-4, 182 | minibatch_size: int = 256, 183 | update_epochs: int = 3, 184 | device: str = "auto", 185 | temporal_asset_chunk: int = 64, 186 | temporal_use_ckpt: bool = True, 187 | minibatch_micro: int = 32, 188 | seed: int = 42): 189 | 190 | set_seeds(seed) 191 | if device == "auto": 192 | device = "cuda" if torch.cuda.is_available() else "cpu" 193 | 194 | panel = load_panel_from_csv(csv_path, window) 195 | prices_rel = panel["prices_rel"]; features = panel["features"]; tradable_mask = panel["tradable_mask"] 196 | sector_ids = panel["sector_ids"]; num_sectors = panel["num_sectors"] 197 | N, Fdim = panel["N"], panel["Fdim"] 198 | 199 | env = PortfolioEnv(prices_rel=prices_rel, 200 | features=features, 201 | sector_ids=sector_ids, 202 | tradable_mask=tradable_mask, 203 | market_factors=None, 204 | tc_cost=tc_cost, 205 | window=window, 206 | include_cash=True) 207 | 208 | # Flat single-Dirichlet policy (class name kept for compatibility) 209 | policy = StochasticHierarchicalDualAttentionPolicy( 210 | d_model=d_model, 211 | n_heads_time=heads_time, n_layers_time=layers_time, 212 | n_heads_cross=heads_cross, n_layers_cross=layers_cross, 213 | num_sectors=num_sectors, # ignored by flat actor 214 | include_cash=True, 215 | dropout=0.1, 216 | temporal_asset_chunk=temporal_asset_chunk, 217 | temporal_use_ckpt=temporal_use_ckpt 218 | ).to(device) 219 | 220 | cfg = PPOConfig(gamma=0.99, gae_lambda=0.95, clip_coef=0.2, vf_coef=0.5, 221 | ent_coef=0.0, max_grad_norm=0.5, learning_rate=lr, 222 | update_epochs=update_epochs, minibatch_size=minibatch_size, 223 | minibatch_micro=minibatch_micro) 224 | trainer = PPOTrainer(policy, cfg, device=device) 225 | 226 | obs_shape = (window, N, Fdim) 227 | sector_ids_torch = torch.from_numpy(sector_ids).long().to(device) 228 | 229 | for update in range(num_updates): 230 | buffer = RolloutBuffer(T=buffer_steps, B=1, obs_shape=obs_shape, 231 | n_assets=N, sector_ids=sector_ids_torch, 232 | device="cpu", has_mkt=False, k_mkt=0) 233 | buffer._step = 0 234 | env.reset() 235 | 236 | # collect rollout 237 | for t in range(buffer_steps): 238 | x_t = torch.from_numpy(env._get_state()).unsqueeze(0).to(device) # (1,W,N,F) 239 | 
trad_t= torch.from_numpy(env.mask[env.t]).unsqueeze(0).to(device) # (1,N) bool 240 | with torch.no_grad(): 241 | weights, logp, value, action, _ = policy.get_action_and_value( 242 | x=x_t, 243 | sector_ids=sector_ids_torch.unsqueeze(0), 244 | tradable_mask=trad_t, 245 | market_factors=None, 246 | sample=True 247 | ) 248 | # Env expects logits; map weights→logits stably 249 | logits = torch.log(weights.clamp_min(1e-8))[0].cpu().numpy() 250 | _, reward, done, _, _ = env.step(logits) 251 | 252 | buffer.add( 253 | obs_t=x_t[0].cpu(), 254 | tradable_t=trad_t[0].cpu(), 255 | mkt_t=None, 256 | action=action, 257 | weight_t=weights[0].cpu(), 258 | reward_t=torch.tensor(reward, dtype=torch.float32), 259 | done_t=torch.tensor(done, dtype=torch.bool), 260 | value_t=value[0].detach().cpu(), 261 | logp_t=logp[0].detach().cpu() 262 | ) 263 | if done: 264 | env.reset() 265 | 266 | # bootstrap 267 | x_last = torch.from_numpy(env._get_state()).unsqueeze(0).to(device) 268 | trad_last= torch.from_numpy(env.mask[env.t]).unsqueeze(0).to(device) 269 | with torch.no_grad(): 270 | _, _, last_value, _, _ = policy.get_action_and_value( 271 | x_last, sector_ids_torch.unsqueeze(0), trad_last, None, sample=False 272 | ) 273 | buffer.compute_gae(last_value.squeeze(-1).detach().cpu(), gamma=cfg.gamma, lam=cfg.gae_lambda) 274 | 275 | # update 276 | loss = trainer.update(buffer) 277 | print(f"[{update+1}/{num_updates}] PPO loss = {loss:.4f}") 278 | 279 | # -------- save checkpoint -------- 280 | save_path = Path(save_path) 281 | save_path.parent.mkdir(parents=True, exist_ok=True) 282 | ckpt = { 283 | "state_dict": policy.state_dict(), 284 | "config": { 285 | "window": window, 286 | "tc_cost": tc_cost, 287 | "d_model": d_model, 288 | "heads_time": heads_time, 289 | "layers_time": layers_time, 290 | "heads_cross": heads_cross, 291 | "layers_cross": layers_cross, 292 | "temporal_asset_chunk": temporal_asset_chunk, 293 | "temporal_use_ckpt": temporal_use_ckpt, 294 | "include_cash": True, 295 | } 296 | } 297 | torch.save(ckpt, save_path) 298 | print(f"Saved model to: {str(save_path.resolve())}") 299 | 300 | 301 | def parse_args(): 302 | p = argparse.ArgumentParser(description="Train PPO/A2C/REINFORCE/DDPG on portfolio policy.") 303 | p.add_argument("--csv", type=str, required=True) 304 | p.add_argument("--algo", type=str, default="ppo", choices=["ppo","a2c","reinforce","ddpg"]) 305 | p.add_argument("--window", type=int, default=30) 306 | p.add_argument("--tc", type=float, default=5e-4) 307 | p.add_argument("--d_model", type=int, default=64) 308 | p.add_argument("--heads_time", type=int, default=2) 309 | p.add_argument("--layers_time", type=int, default=1) 310 | p.add_argument("--heads_cross", type=int, default=2) 311 | p.add_argument("--layers_cross", type=int, default=1) 312 | p.add_argument("--buffer_steps", type=int, default=128) 313 | p.add_argument("--updates", type=int, default=5) 314 | p.add_argument("--lr", type=float, default=3e-4) 315 | p.add_argument("--mb", type=int, default=256) 316 | p.add_argument("--epochs", type=int, default=3) 317 | p.add_argument("--device", type=str, default="auto") 318 | p.add_argument("--asset_chunk", type=int, default=64) 319 | p.add_argument("--no_ckpt", action="store_true") 320 | p.add_argument("--mb_micro", type=int, default=32) 321 | p.add_argument("--out_dir", type=str, default="./runs_multi") 322 | p.add_argument("--seed", type=int, default=42) 323 | return p.parse_args() 324 | 325 | 326 | def run(args): 327 | set_seeds(args.seed) 328 | device = "cuda" if args.device=="auto" and 
torch.cuda.is_available() else args.device 329 | 330 | # load & split 331 | df = pd.read_csv(args.csv, low_memory=False) 332 | assert {"Date","ticker","Close"}.issubset(df.columns) 333 | df_train, df_test, split_date = date_split_20y(df, years=20) 334 | print(f"Train {df_train['Date'].min().date()} → {df_train['Date'].max().date()} | " 335 | f"Test {df_test['Date'].min().date()} → {df_test['Date'].max().date()} | split={split_date.date()}") 336 | 337 | # features 338 | default_feats = [ 339 | "Open","High","Low","Close","Volume", 340 | "macd","macd_signal","macd_hist", 341 | "macdboll","ubboll","lb","rsi","30cci", 342 | "plus_di_14","minus_di_14","dx_14","dx30", 343 | "close30_sma","close60_sma" 344 | ] 345 | feature_cols = [c for c in default_feats if c in df.columns] 346 | if not feature_cols: raise ValueError("No numeric feature columns found.") 347 | 348 | # panels 349 | panel_tr = build_panel(df_train, feature_cols) 350 | panel_te = build_panel(df_test, feature_cols) 351 | 352 | # env for training 353 | env = PortfolioEnv(prices_rel=panel_tr["prices_rel"], features=panel_tr["features"], 354 | sector_ids=panel_tr["sector_ids"], tradable_mask=panel_tr["mask"], 355 | market_factors=None, tc_cost=args.tc, window=args.window, include_cash=True) 356 | 357 | N = panel_tr["features"].shape[1]; Fdim = panel_tr["features"].shape[2] 358 | obs_shape = (args.window, N, Fdim) 359 | 360 | # actor (shared) 361 | policy = StochasticHierarchicalDualAttentionPolicy( 362 | d_model=args.d_model, n_heads_time=args.heads_time, n_layers_time=args.layers_time, 363 | n_heads_cross=args.heads_cross, n_layers_cross=args.layers_cross, 364 | include_cash=True, temporal_asset_chunk=args.asset_chunk, temporal_use_ckpt=(not args.no_ckpt) 365 | ).to(device) 366 | 367 | out_dir = Path(args.out_dir) / args.algo 368 | out_dir.mkdir(parents=True, exist_ok=True) 369 | 370 | # ----- select algo ----- 371 | loss_hist, mean_reward_hist = [], [] 372 | if args.algo == "ppo": 373 | cfg = PPOConfig(gamma=0.99, gae_lambda=0.95, clip_coef=0.2, vf_coef=0.5, 374 | ent_coef=0.0, max_grad_norm=0.5, learning_rate=args.lr, 375 | update_epochs=args.epochs, minibatch_size=args.mb, minibatch_micro=args.mb_micro) 376 | trainer = PPOTrainer(policy, cfg, device=device) 377 | sid_t = torch.from_numpy(panel_tr["sector_ids"]).long().to(device) 378 | for upd in range(args.updates): 379 | buffer = RolloutBuffer(T=args.buffer_steps, B=1, obs_shape=obs_shape, 380 | n_assets=N, sector_ids=sid_t, device="cpu", 381 | has_mkt=False, k_mkt=0) 382 | buffer._step = 0; env.reset() 383 | for t in range(args.buffer_steps): 384 | x = torch.from_numpy(env._get_state()).unsqueeze(0).to(device) 385 | trad = torch.from_numpy(env.mask[env.t]).unsqueeze(0).to(device) 386 | with torch.no_grad(): 387 | w, logp, v, a, _ = policy.get_action_and_value(x, sid_t.unsqueeze(0), trad, None, sample=True) 388 | logits = torch.log(w.clamp_min(1e-8))[0].cpu().numpy() 389 | _, r, done, _, _ = env.step(logits) 390 | buffer.add(x[0].cpu(), trad[0].cpu(), None, a, w[0].cpu(), 391 | torch.tensor(r, dtype=torch.float32), torch.tensor(done, dtype=torch.bool), 392 | v[0].detach().cpu(), logp[0].detach().cpu()) 393 | if done: env.reset() 394 | with torch.no_grad(): 395 | xl = torch.from_numpy(env._get_state()).unsqueeze(0).to(device) 396 | tradl = torch.from_numpy(env.mask[env.t]).unsqueeze(0).to(device) 397 | _, _, lv, _, _ = policy.get_action_and_value(xl, sid_t.unsqueeze(0), tradl, None, sample=False) 398 | buffer.compute_gae(lv.squeeze(-1).detach().cpu(), cfg.gamma, cfg.gae_lambda) 399 
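`run()` above relies on `date_split_20y` and `build_panel`, neither of which appears in the listing. Minimal sketches under stated assumptions (the split direction and the returned panel keys are guesses consistent with how `run()` and `buy_and_hold_baseline` use them):

```python
from typing import Dict, List, Tuple

import numpy as np
import pandas as pd

from utils import cross_sectional_zscore  # same helper data.py relies on (sketched at the end of this listing)


def date_split_20y(df: pd.DataFrame, years: int = 20) -> Tuple[pd.DataFrame, pd.DataFrame, pd.Timestamp]:
    """Assumption: the first `years` of history form the training split, the remainder the test split."""
    df = df.copy()
    df["Date"] = pd.to_datetime(df["Date"])
    split_date = df["Date"].min() + pd.DateOffset(years=years)
    return df[df["Date"] < split_date], df[df["Date"] >= split_date], split_date


def build_panel(df: pd.DataFrame, feature_cols: List[str]) -> Dict[str, np.ndarray]:
    """Assumption: mirrors data.load_panel_from_csv but returns the keys run() expects ("mask" instead of "tradable_mask")."""
    df = df.sort_values(["Date", "ticker"])
    tickers = df["ticker"].unique()
    dates = df["Date"].drop_duplicates().sort_values()
    close = df.pivot(index="Date", columns="ticker", values="Close").reindex(index=dates, columns=tickers)
    prices_rel = np.nan_to_num(close.values / close.shift(1).values - 1.0,
                               nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32)
    feats = np.stack([df.pivot(index="Date", columns="ticker", values=c)
                        .reindex(index=dates, columns=tickers).values
                      for c in feature_cols], axis=-1).astype(np.float32)
    return {
        "prices_rel": prices_rel,
        "features": cross_sectional_zscore(feats),
        "mask": np.isfinite(close.values),
        "sector_ids": np.zeros(len(tickers), dtype=np.int64),  # flat policy ignores sectors
    }
```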
| loss = trainer.update(buffer); loss_hist.append(loss) 400 | mean_reward_hist.append(float(buffer.rewards.mean().cpu().numpy())) 401 | print(f"[PPO {upd+1}/{args.updates}] loss={loss:.4f} | meanR={mean_reward_hist[-1]:.6f}") 402 | 403 | elif args.algo == "a2c": 404 | cfg = A2CConfig(lr=args.lr, update_epochs=args.epochs, minibatch_size=args.mb, 405 | minibatch_micro=args.mb_micro, vf_coef=0.5, ent_coef=0.0) 406 | trainer = A2CTrainer(policy, cfg, device=device) 407 | sid_t = torch.from_numpy(panel_tr["sector_ids"]).long().to(device) 408 | for upd in range(args.updates): 409 | buffer = RolloutBuffer(T=args.buffer_steps, B=1, obs_shape=obs_shape, 410 | n_assets=N, sector_ids=sid_t, device="cpu", has_mkt=False, k_mkt=0) 411 | buffer._step = 0; env.reset() 412 | for t in range(args.buffer_steps): 413 | x = torch.from_numpy(env._get_state()).unsqueeze(0).to(device) 414 | trad = torch.from_numpy(env.mask[env.t]).unsqueeze(0).to(device) 415 | with torch.no_grad(): 416 | w, logp, v, a, _ = policy.get_action_and_value(x, sid_t.unsqueeze(0), trad, None, sample=True) 417 | logits = torch.log(w.clamp_min(1e-8))[0].cpu().numpy() 418 | _, r, done, _, _ = env.step(logits) 419 | buffer.add(x[0].cpu(), trad[0].cpu(), None, a, w[0].cpu(), 420 | torch.tensor(r, dtype=torch.float32), torch.tensor(done, dtype=torch.bool), 421 | v[0].detach().cpu(), logp[0].detach().cpu()) 422 | if done: env.reset() 423 | with torch.no_grad(): 424 | xl = torch.from_numpy(env._get_state()).unsqueeze(0).to(device) 425 | tradl = torch.from_numpy(env.mask[env.t]).unsqueeze(0).to(device) 426 | _, _, lv, _, _ = policy.get_action_and_value(xl, sid_t.unsqueeze(0), tradl, None, sample=False) 427 | buffer.compute_gae(lv.squeeze(-1).detach().cpu(), cfg.gamma, cfg.gae_lambda) 428 | loss = trainer.update(buffer); loss_hist.append(loss) 429 | mean_reward_hist.append(float(buffer.rewards.mean().cpu().numpy())) 430 | print(f"[A2C {upd+1}/{args.updates}] loss={loss:.4f} | meanR={mean_reward_hist[-1]:.6f}") 431 | 432 | elif args.algo == "reinforce": 433 | trainer = REINFORCETrainer(policy, lr=args.lr, vf_coef=0.5, device=device, max_grad_norm=0.5) 434 | sid_t = torch.from_numpy(panel_tr["sector_ids"]).long().to(device) 435 | for upd in range(args.updates): 436 | buffer = RolloutBuffer(T=args.buffer_steps, B=1, obs_shape=obs_shape, 437 | n_assets=N, sector_ids=sid_t, device="cpu", has_mkt=False, k_mkt=0) 438 | buffer._step = 0; env.reset() 439 | for t in range(args.buffer_steps): 440 | x = torch.from_numpy(env._get_state()).unsqueeze(0).to(device) 441 | trad = torch.from_numpy(env.mask[env.t]).unsqueeze(0).to(device) 442 | with torch.no_grad(): 443 | w, logp, v, a, _ = policy.get_action_and_value(x, sid_t.unsqueeze(0), trad, None, sample=True) 444 | logits = torch.log(w.clamp_min(1e-8))[0].cpu().numpy() 445 | _, r, done, _, _ = env.step(logits) 446 | buffer.add(x[0].cpu(), trad[0].cpu(), None, a, w[0].cpu(), 447 | torch.tensor(r, dtype=torch.float32), torch.tensor(done, dtype=torch.bool), 448 | v[0].detach().cpu(), logp[0].detach().cpu()) 449 | if done: env.reset() 450 | # Monte Carlo returns (no bootstrap): use buffer.rewards cumulative with gamma 451 | # We can reuse compute_gae with lam=1 and next_value=0: 452 | buffer.compute_gae(torch.zeros(1), gamma=0.99, lam=1.0) 453 | loss = trainer.update(buffer); loss_hist.append(loss) 454 | mean_reward_hist.append(float(buffer.rewards.mean().cpu().numpy())) 455 | print(f"[REINFORCE {upd+1}/{args.updates}] loss={loss:.4f} | meanR={mean_reward_hist[-1]:.6f}") 456 | 457 | elif args.algo == "ddpg": 458 | 
# short off-policy run on a rolling window of the training panel 459 | cfg = DDPGConfig(gamma=0.99, tau=0.005, lr_actor=1e-4, lr_critic=1e-3, 460 | batch_size=64, explore_alpha=0.3, updates_per_step=1) 461 | trainer = DDPGSoftmaxTrainer(policy, d_model=args.d_model, action_dim=N+1, cfg=cfg, device=device) 462 | replay = Replay(capacity=5000) 463 | env.reset() 464 | sid_t = torch.from_numpy(panel_tr["sector_ids"]).long().to(device) # unused 465 | for step in range(args.buffer_steps * args.updates): 466 | x = torch.from_numpy(env._get_state()).unsqueeze(0).to(device) 467 | trad = torch.from_numpy(env.mask[env.t]).unsqueeze(0).to(device) 468 | # exploration action 469 | with torch.no_grad(): 470 | w = trainer._actor_weights(x, trad, deterministic=True, explore_alpha=cfg.explore_alpha) 471 | logits = torch.log(w.clamp_min(1e-8))[0].cpu().numpy() 472 | prev_state = env._get_state().copy(); prev_trad = env.mask[env.t].copy() 473 | _, r, done, _, _ = env.step(logits) 474 | next_state = env._get_state().copy(); next_trad = env.mask[env.t].copy() 475 | replay.push(prev_state, prev_trad, w[0].detach().cpu().numpy(), r, next_state, next_trad, float(done)) 476 | loss = trainer.update(replay) 477 | if done: env.reset() 478 | # DDPG doesn't have per-update loss history; store last total loss 479 | loss_hist = [loss] if isinstance(loss, float) else [] 480 | 481 | else: 482 | raise ValueError(f"Unknown algo: {args.algo}") 483 | 484 | # ----- save model ----- 485 | out_dir.mkdir(parents=True, exist_ok=True) 486 | ckpt_path = out_dir / f"{args.algo}_policy.pth" 487 | torch.save({"state_dict": policy.state_dict(), 488 | "config": {"d_model": args.d_model, "window": args.window, "tc_cost": args.tc}}, 489 | ckpt_path) 490 | print(f"Saved model → {ckpt_path}") 491 | 492 | # ----- evaluation & plots ----- 493 | train_daily, train_cw, train_metrics = eval_on_panel(policy, panel_tr, args.window, args.tc, device, deterministic=True) 494 | test_daily, test_cw, test_metrics = eval_on_panel(policy, panel_te, args.window, args.tc, device, deterministic=True) 495 | print(f"Train: Sharpe={train_metrics['sharpe']:.3f} CAGR={train_metrics['cagr']:.2%} MDD={train_metrics['mdd']:.2%}") 496 | print(f"Test : Sharpe={test_metrics['sharpe']:.3f} CAGR={test_metrics['cagr']:.2%} MDD={test_metrics['mdd']:.2%}") 497 | 498 | # save CSV log and plots 499 | pd.DataFrame({"update": np.arange(1, len(loss_hist)+1), 500 | "loss": loss_hist, 501 | "mean_reward": mean_reward_hist[:len(loss_hist)] if len(mean_reward_hist)>=len(loss_hist) else None} 502 | ).to_csv(out_dir/"training_log.csv", index=False) 503 | save_plots(out_dir, loss_hist, mean_reward_hist, train_cw, test_cw, panel_tr["dates"], panel_te["dates"], args.window) 504 | print(f"Saved plots & logs → {out_dir}") 505 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | # utils.py 2 | import numpy as np 3 | import torch 4 | import random 5 | def set_seeds(seed: int = 42): 6 | np.random.seed(seed) 7 | torch.manual_seed(seed) 8 | random.seed(seed) 9 | def project_to_simplex(v: np.ndarray) -> np.ndarray: 10 | # Duchi et al. 
(2008) projection onto the probability simplex
11 |     n = v.shape[-1]
12 |     u = np.sort(v)[::-1]
13 |     cssv = np.cumsum(u) - 1
14 |     ind = np.arange(1, n + 1)
15 |     cond = u - cssv / ind > 0
16 |     if not np.any(cond):
17 |         return np.ones_like(v) / n
18 |     rho = ind[cond][-1]
19 |     theta = cssv[cond][-1] / rho
20 |     w = np.maximum(v - theta, 0)
21 |     return w
22 | 
23 | 
24 | def _capped_simplex_projection(weights: np.ndarray, sector_ids: np.ndarray, sector_caps: dict) -> np.ndarray:
25 |     # Simple iterative projection to enforce per-sector caps:
26 |     #  scale down any sector above its cap, then re-project onto the simplex.
27 |     w = weights.copy()
28 |     for _ in range(10):
29 |         for s, cap in sector_caps.items():
30 |             idx = (sector_ids == s)
31 |             total = w[idx].sum()
32 |             if total > cap + 1e-12 and total > 0:
33 |                 w[idx] *= cap / total
34 |         w = project_to_simplex(w)
35 |     return w
36 | 
37 | 
38 | def sinusoidal_time_encoding(L: int, d_model: int, device: torch.device) -> torch.Tensor:
39 |     pe = torch.zeros(L, d_model, device=device)
40 |     position = torch.arange(0, L, device=device, dtype=torch.float32).unsqueeze(1)
41 |     # np.log keeps this self-contained (this module does not import math)
42 |     div_term = torch.exp(torch.arange(0, d_model, 2, device=device, dtype=torch.float32) * (-np.log(10000.0) / d_model))
43 |     pe[:, 0::2] = torch.sin(position * div_term)
44 |     pe[:, 1::2] = torch.cos(position * div_term)
45 |     return pe
46 | 
47 | 
48 | def _dirichlet_alpha_from_logits(logits: torch.Tensor, min_alpha: float = 0.05) -> torch.Tensor:
49 |     # fully qualified call: torch.nn.functional is reachable via the existing `import torch`
50 |     return torch.nn.functional.softplus(logits) + min_alpha
51 | 
52 | 
53 | def cs_zscore(X: np.ndarray) -> np.ndarray:
54 |     mu = np.nanmean(X, axis=1, keepdims=True)
55 |     sd = np.nanstd(X, axis=1, keepdims=True) + 1e-8
56 |     Z = (X - mu) / sd
57 |     return np.nan_to_num(Z, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32)
58 | 
59 | 
60 | def cross_sectional_zscore(features_raw: np.ndarray) -> np.ndarray:
61 |     # same computation as cs_zscore; kept under a separate name for callers that use it
62 |     mu = np.nanmean(features_raw, axis=1, keepdims=True)
63 |     sd = np.nanstd(features_raw, axis=1, keepdims=True) + 1e-8
64 |     z = (features_raw - mu) / sd
65 |     return np.nan_to_num(z, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32)
66 | 
--------------------------------------------------------------------------------
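
The simplex and z-score helpers in utils.py are easiest to verify with a quick numerical check. The sketch below is illustrative only and not part of the repository: it assumes the package is importable as `rl_portfolio` (matching the README quickstart), and the script name, sector ids, and per-sector caps are made-up values.

```python
# sanity_check_utils.py -- illustrative script, not shipped with the repo
import numpy as np

from rl_portfolio.utils import (
    project_to_simplex,
    _capped_simplex_projection,
    cs_zscore,
    set_seeds,
)

set_seeds(42)

# 1) Simplex projection: arbitrary scores -> non-negative weights summing to 1.
v = np.random.randn(8)
w = project_to_simplex(v)
assert np.all(w >= 0) and abs(w.sum() - 1.0) < 1e-6

# 2) Sector caps: weights should be pulled (approximately) under each cap
#    while still summing to 1 after the final simplex projection.
sector_ids = np.array([0, 0, 0, 1, 1, 2, 2, 2])
caps = {0: 0.40, 1: 0.40, 2: 0.40}  # hypothetical 40% per-sector caps
w_capped = _capped_simplex_projection(w, sector_ids, caps)
for s, cap in caps.items():
    print(f"sector {s}: weight={w_capped[sector_ids == s].sum():.3f} (cap {cap})")
assert abs(w_capped.sum() - 1.0) < 1e-6

# 3) Cross-sectional z-score: mean ~0 and std ~1 across assets (axis=1) per day.
X = np.random.randn(5, 8, 3)  # (days, assets, features)
Z = cs_zscore(X)
print(Z.mean(axis=1).round(3))
print(Z.std(axis=1).round(3))
```

Note that the cap projection is only approximate: the final `project_to_simplex` call inside each pass can re-introduce small cap violations, which is why the routine runs several passes rather than a single scale-and-project step.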