├── TFT.pdf ├── data └── electricity-future.csv ├── architecture.md └── README.md /TFT.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sid3503/TFT-Forecasting/main/TFT.pdf -------------------------------------------------------------------------------- /data/electricity-future.csv: -------------------------------------------------------------------------------- 1 | unique_id,ds,Exogenous1,Exogenous2 2 | BE,12/31/2016 0:00,64108,70318 3 | BE,12/31/2016 1:00,62492,67898 4 | BE,12/31/2016 2:00,61571,68379 5 | BE,12/31/2016 3:00,60381,64972 6 | BE,12/31/2016 4:00,60298,62900 7 | BE,12/31/2016 5:00,60339,62364 8 | BE,12/31/2016 6:00,62576,64242 9 | BE,12/31/2016 7:00,63732,65884 10 | BE,12/31/2016 8:00,66235,68217 11 | BE,12/31/2016 9:00,66801,69921 12 | BE,12/31/2016 10:00,66964,72069 13 | BE,12/31/2016 11:00,66667,72328 14 | BE,12/31/2016 12:00,66145,72493 15 | BE,12/31/2016 13:00,65052,74228 16 | BE,12/31/2016 14:00,63905,70369 17 | BE,12/31/2016 15:00,63626,67987 18 | BE,12/31/2016 16:00,63817,66692 19 | BE,12/31/2016 17:00,65908,68634 20 | BE,12/31/2016 18:00,69289,72362 21 | BE,12/31/2016 19:00,70329,73957 22 | BE,12/31/2016 20:00,69121,72544 23 | BE,12/31/2016 21:00,66647,69451 24 | BE,12/31/2016 22:00,65886,67823 25 | BE,12/31/2016 23:00,66846,72876 26 | DE,12/31/2017 0:00,1392.25,14744.5 27 | DE,12/31/2017 1:00,1289,13906 28 | DE,12/31/2017 2:00,1206.25,13579 29 | DE,12/31/2017 3:00,1120.5,13302.75 30 | DE,12/31/2017 4:00,1053.25,13232.5 31 | DE,12/31/2017 5:00,1039.75,12631.75 32 | DE,12/31/2017 6:00,1062,10718.25 33 | DE,12/31/2017 7:00,1101.25,10753.75 34 | DE,12/31/2017 8:00,1163.5,12495.75 35 | DE,12/31/2017 9:00,1262.25,14335.5 36 | DE,12/31/2017 10:00,1364,15218 37 | DE,12/31/2017 11:00,1455.75,16021.75 38 | DE,12/31/2017 12:00,1538.75,16549.75 39 | DE,12/31/2017 13:00,1578.25,15791.75 40 | DE,12/31/2017 14:00,1578.25,15209 41 | DE,12/31/2017 15:00,1537,14991.25 42 | DE,12/31/2017 16:00,1482.25,15601.75 43 | DE,12/31/2017 17:00,1422.25,16954.5 44 | DE,12/31/2017 18:00,1352,17005.75 45 | DE,12/31/2017 19:00,1300,16601 46 | DE,12/31/2017 20:00,1261,15977.75 47 | DE,12/31/2017 21:00,1243.25,15715 48 | DE,12/31/2017 22:00,1250.5,15876 49 | DE,12/31/2017 23:00,1270.75,15130 50 | FR,12/31/2016 0:00,64108,70318 51 | FR,12/31/2016 1:00,62492,67898 52 | FR,12/31/2016 2:00,61571,68379 53 | FR,12/31/2016 3:00,60381,64972 54 | FR,12/31/2016 4:00,60298,62900 55 | FR,12/31/2016 5:00,60339,62364 56 | FR,12/31/2016 6:00,62576,64242 57 | FR,12/31/2016 7:00,63732,65884 58 | FR,12/31/2016 8:00,66235,68217 59 | FR,12/31/2016 9:00,66801,69921 60 | FR,12/31/2016 10:00,66964,72069 61 | FR,12/31/2016 11:00,66667,72328 62 | FR,12/31/2016 12:00,66145,72493 63 | FR,12/31/2016 13:00,65052,74228 64 | FR,12/31/2016 14:00,63905,70369 65 | FR,12/31/2016 15:00,63626,67987 66 | FR,12/31/2016 16:00,63817,66692 67 | FR,12/31/2016 17:00,65908,68634 68 | FR,12/31/2016 18:00,69289,72362 69 | FR,12/31/2016 19:00,70329,73957 70 | FR,12/31/2016 20:00,69121,72544 71 | FR,12/31/2016 21:00,66647,69451 72 | FR,12/31/2016 22:00,65886,67823 73 | FR,12/31/2016 23:00,66846,72876 74 | NP,12/24/2018 0:00,49119,461 75 | NP,12/24/2018 1:00,48115,484 76 | NP,12/24/2018 2:00,47727,497 77 | NP,12/24/2018 3:00,47673,509 78 | NP,12/24/2018 4:00,47848,510 79 | NP,12/24/2018 5:00,48673,511 80 | NP,12/24/2018 6:00,50071,526 81 | NP,12/24/2018 7:00,52259,528 82 | NP,12/24/2018 8:00,54395,659 83 | NP,12/24/2018 9:00,56071,897 84 | NP,12/24/2018 10:00,56989,1140 85 | 
NP,12/24/2018 11:00,56985,1342
86 | NP,12/24/2018 12:00,56391,1547
87 | NP,12/24/2018 13:00,55922,1618
88 | NP,12/24/2018 14:00,55826,1664
89 | NP,12/24/2018 15:00,56216,1768
90 | NP,12/24/2018 16:00,56661,1964
91 | NP,12/24/2018 17:00,55479,2092
92 | NP,12/24/2018 18:00,53741,2243
93 | NP,12/24/2018 19:00,52591,2544
94 | NP,12/24/2018 20:00,51787,2785
95 | NP,12/24/2018 21:00,51488,2919
96 | NP,12/24/2018 22:00,50928,3119
97 | NP,12/24/2018 23:00,49889,3306
98 | PJM,12/24/2018 0:00,84069,10276
99 | PJM,12/24/2018 1:00,81849,9809
100 | PJM,12/24/2018 2:00,80671,9501
101 | PJM,12/24/2018 3:00,80229,9337
102 | PJM,12/24/2018 4:00,80734,9283
103 | PJM,12/24/2018 5:00,82735,9369
104 | PJM,12/24/2018 6:00,85937,9618
105 | PJM,12/24/2018 7:00,89189,9987
106 | PJM,12/24/2018 8:00,91449,10195
107 | PJM,12/24/2018 9:00,93233,10378
108 | PJM,12/24/2018 10:00,94025,10578
109 | PJM,12/24/2018 11:00,93853,10677
110 | PJM,12/24/2018 12:00,93122,10686
111 | PJM,12/24/2018 13:00,92263,10631
112 | PJM,12/24/2018 14:00,91663,10586
113 | PJM,12/24/2018 15:00,91472,10547
114 | PJM,12/24/2018 16:00,93134,10635
115 | PJM,12/24/2018 17:00,96723,11154
116 | PJM,12/24/2018 18:00,96387,11464
117 | PJM,12/24/2018 19:00,94939,11246
118 | PJM,12/24/2018 20:00,94035,11070
119 | PJM,12/24/2018 21:00,92923,10963
120 | PJM,12/24/2018 22:00,90970,10802
121 | PJM,12/24/2018 23:00,88037,10419
122 | 
--------------------------------------------------------------------------------
/architecture.md:
--------------------------------------------------------------------------------
1 | # 🧠 Temporal Fusion Transformer (TFT) Architecture – Deep Dive
2 | 
3 | This section walks you through the full **TFT architecture** (shown in the diagram below) using:
4 | 
5 | - 📊 Visual explanation (linked to components in the diagram)
6 | - 📚 Plain language with **real-world examples**
7 | - 🧮 Simplified **math formulas and intuitions**
8 | - 🔍 Dummy data to **see what's happening**
9 | 
10 | ---
11 | 
12 | ## 📌 Overview of the Flow
13 | 
14 | ![Image](https://github.com/user-attachments/assets/b757eda9-37da-410e-93cd-98861f9ede59)
15 | 
16 | The model is split into **three parts**:
17 | 1. **Input Encoding**
18 | 2. **Temporal Fusion Decoder**
19 | 3. **Forecasting Head (Quantile Predictions)**
20 | 
21 | The Mermaid flowchart below gives a compact visual of the **TFT architecture**, inspired by the official diagram and aligned with the expanded explanation that follows.
22 | 
23 | It renders on platforms that support Mermaid (e.g., GitHub, Obsidian, MkDocs).
24 | 
25 | ---
26 | 
27 | ### 🌐 Mermaid Flowchart: Temporal Fusion Transformer (TFT)
28 | 
29 | 
30 | 📈 Click to expand TFT architecture flowchart (Mermaid) 31 | 32 | ```mermaid 33 | flowchart TD 34 | Inputs[Inputs] --> Encoder[LSTM Encoder] 35 | Inputs --> Decoder[LSTM Decoder] 36 | Inputs --> Static[Static Metadata Encoder] 37 | 38 | Encoder --> VarSelPast[Variable Selection - Past] 39 | Decoder --> VarSelFuture[Variable Selection - Future] 40 | Static --> StaticEnrich[Static Enrichment] 41 | 42 | VarSelPast --> GRNPast[GRN + Add & Norm - Past] 43 | VarSelFuture --> GRNFuture[GRN + Add & Norm - Future] 44 | StaticEnrich --> StaticGRN[Inject Static Context] 45 | 46 | GRNPast --> Attention[Masked Multi-head Attention] 47 | GRNFuture --> Attention 48 | StaticGRN --> Attention 49 | 50 | Attention --> PostGRN[GRN + Add & Norm - Attention] 51 | PostGRN --> FF[Feed-Forward Layer] 52 | FF --> Output[Quantile Forecast Head] 53 | 54 | Output --> P10[p10 Prediction] 55 | Output --> P50[p50 Prediction] 56 | Output --> P90[p90 Prediction] 57 | 58 | ``` 59 | 60 |
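To make the three parts above concrete in code: with the Darts library used in this repo's `README.md`, they map onto a handful of `TFTModel` arguments. A minimal sketch (parameter values are illustrative, borrowed from the README's dry run; naming `QuantileRegression` explicitly mirrors what Darts already uses for the TFT by default):

```python
from darts.models import TFTModel
from darts.utils.likelihood_models import QuantileRegression

model = TFTModel(
    input_chunk_length=96,    # 1. Input Encoding: how many past steps the encoder sees
    output_chunk_length=24,   # 2. Temporal Fusion Decoder: how many future steps are decoded
    likelihood=QuantileRegression(quantiles=[0.1, 0.5, 0.9]),  # 3. Forecasting Head: p10/p50/p90
)
```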
61 | 62 | --- 63 | It learns: 64 | - What features matter (Feature Selection) 65 | - When it should pay attention (Temporal Attention) 66 | - How to interpret uncertainty (Quantile Regression) 67 | 68 | --- 69 | 70 | ## 1️⃣ Variable Selection Network 71 | 72 | ![Image](https://github.com/user-attachments/assets/b830ff1f-e073-4def-a168-88ced5461133) 73 | 74 | ### 🔍 What It Does 75 | Learns which input variables to focus on **at each time step** — automatically. 76 | 77 | ### 🧮 Intuition (Simplified) 78 | 79 | Let’s say you input: 80 | 81 | ```python 82 | x_t = [temperature, humidity, holiday, load_zone1, load_zone2] 83 | ``` 84 | 85 | But for the hour `t = 12:00 PM`, only `temperature` and `load_zone1` really matter. 86 | 87 | The Variable Selection Network learns **weights** like: 88 | 89 | ``` 90 | [0.8, 0.05, 0.05, 0.9, 0.05] → Softmax → attention-like scores 91 | ``` 92 | 93 | It **filters** irrelevant inputs dynamically. 94 | 95 | ### 🧪 Dummy Example 96 | 97 | Imagine: 98 | 99 | ``` 100 | At 8 AM → electricity load depends more on: [day, holiday, load_zone1] 101 | At 2 PM → it's more about: [temperature, humidity] 102 | ``` 103 | 104 | TFT learns this mapping automatically using: 105 | 106 | ```math 107 | αₜ = Softmax(W₁ GRN₁(xₜ), ..., Wₙ GRNₙ(xₜ)) 108 | ``` 109 | 110 | Where: 111 | - Each variable is passed through a **Gated Residual Network** 112 | - Their outputs are weighted and summed 113 | - The softmax layer highlights the **most important variables** 114 | 115 | --- 116 | 117 | ## 2️⃣ Gated Residual Network (GRN) 118 | 119 | ![Image](https://github.com/user-attachments/assets/e5877b0a-fbef-4e39-b18b-d7209e318cf1) 120 | 121 | ### 🔍 What It Does 122 | A smart block that decides how much of the transformed signal to **pass through or suppress**. 123 | 124 | ### 🧠 Analogy 125 | 126 | Imagine a **valve in a pipe**: water (information) flows in. The GRN learns **how much to open or close the valve** depending on whether that information is useful. 127 | 128 | ### 🧮 Formula (Simplified) 129 | 130 | ```math 131 | GRN(x) = Gate(x) ⊙ (LayerNorm(Dense₂(ELU(Dense₁(x))))) 132 | ``` 133 | 134 | - `Dense₁`, `Dense₂` = Linear transformations 135 | - `ELU` = Activation function (non-linear twist) 136 | - `Gate(x)` = Learnable sigmoid-based switch (outputs 0–1) 137 | 138 | ### 🧪 Dummy Example 139 | 140 | Input = `temperature = 35°C` 141 | 142 | If temperature isn’t relevant now, the GRN might learn: 143 | 144 | ``` 145 | Gate(temperature) = 0.1 → suppress 146 | ``` 147 | 148 | If temperature is very predictive: 149 | 150 | ``` 151 | Gate(temperature) = 0.9 → amplify 152 | ``` 153 | 154 | --- 155 | 156 | ## 3️⃣ Static Covariate Encoders 157 | 158 | ![Image](https://github.com/user-attachments/assets/0e1f2468-9c63-4d62-8c5d-50991654e0a4) 159 | 160 | ### 🔍 What It Does 161 | Takes features that don’t change over time — like store ID, region, or type — and injects them into the entire model. 162 | 163 | - Helps personalize the model across entities. 
164 | - For example: *“Region A always peaks at 6 PM, Region B at 8 PM”*
165 | 
166 | ---
167 | 
168 | ## 4️⃣ LSTM Encoder-Decoder
169 | 
170 | ![Image](https://github.com/user-attachments/assets/3916662c-61e0-44bc-bd63-14223c6f6dcd)
171 | 
172 | ### 🔁 Purpose
173 | These are **sequence models** that:
174 | - **Encode past time steps**
175 | - **Decode future known inputs** to predict the target
176 | 
177 | ### 🧠 Think of it like:
178 | > LSTM Encoder: “Here's what happened in the last 4 days.”
179 | > LSTM Decoder: “Given that and what's planned (e.g., calendar), here’s what might happen tomorrow.”
180 | 
181 | ---
182 | 
183 | ## 5️⃣ Static Enrichment
184 | 
185 | ![Image](https://github.com/user-attachments/assets/b470c018-3610-49c2-9946-8c84da7d4e95)
186 | 
187 | - Combines static features with each temporal step.
188 | - Ensures that **personalization** affects both past and future modeling.
189 | 
190 | ---
191 | 
192 | ## 6️⃣ Temporal Self-Attention (Masked Multi-head)
193 | 
194 | ![Image](https://github.com/user-attachments/assets/4d15acf7-115b-4991-a7fe-0149b6361807)
195 | 
196 | ### 🔍 What It Does
197 | Lets the model **attend to important past steps** across the time sequence.
198 | 
199 | ### 🧠 Example
200 | 
201 | Imagine predicting electricity usage for `t+1`.
202 | 
203 | - Attention might find:
204 |   - `t-1` → useful (recent trend)
205 |   - `t-24` → useful (same hour yesterday)
206 |   - `t-168` → very useful (same hour last week)
207 | 
208 | ### 🧮 Attention Score
209 | 
210 | For each time step `t`:
211 | ```math
212 | AttentionScore(i) = Qₜ · Kᵢᵀ / sqrt(d_k)
213 | ```
214 | 
215 | - `Q`, `K` = query/key vectors from temporal inputs
216 | - Masking ensures it doesn’t peek into the future
217 | 
218 | ---
219 | 
220 | ## 7️⃣ Position-wise Feed-Forward Layers
221 | 
222 | ![image](https://github.com/user-attachments/assets/6f152b61-dedc-49a8-9279-c07f2f541c9a)
223 | 
224 | - Applies dense transformations to each time step
225 | - Helps model **non-linear interactions** over time
226 | 
227 | ---
228 | 
229 | ## 8️⃣ Output Layer: Quantile Forecasts
230 | 
231 | ![Image](https://github.com/user-attachments/assets/9afd876a-e313-4867-9b57-e9354395c0f7)
232 | 
233 | The model predicts a **range** instead of just a point.
234 | 
235 | ### 🔍 Example
236 | 
237 | ```
238 | For t+1 → 10th percentile: 52.1
239 |          50th percentile (median): 58.4
240 |          90th percentile: 64.3
241 | ```
242 | 
243 | This gives you a **confidence band** rather than a single guess.
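Before moving on to the loss that trains this head, here is a tiny NumPy sketch that ties sections 1️⃣ and 2️⃣ together: each variable passes through its own toy GRN (Dense → ELU → Dense, scaled by a sigmoid gate), and a softmax over the resulting scores produces the variable-selection weights. The weights below come from random, untrained matrices, so this only illustrates the mechanics, not the paper's exact wiring.

```python
import numpy as np

rng = np.random.default_rng(0)

def elu(z):
    return np.where(z > 0, z, np.exp(z) - 1)

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def toy_grn(x, W1, W2, Wg):
    """Toy GRN: Dense -> ELU -> Dense, scaled by a sigmoid gate (section 2)."""
    h = W2 @ elu(W1 @ x)
    gate = sigmoid(Wg @ x)      # ~0 suppresses the signal, ~1 passes it through
    return gate * h

variables = ["temperature", "humidity", "holiday", "load_zone1", "load_zone2"]
x_t = np.array([0.70, 0.60, 0.0, 0.64, 0.70])   # one (roughly normalised) time step
d_in, d_hidden = len(x_t), 4

# One scalar-output toy GRN per variable produces an "importance" score
scores = np.array([
    toy_grn(x_t,
            rng.normal(size=(d_hidden, d_in)),
            rng.normal(size=(1, d_hidden)),
            rng.normal(size=(1, d_in)))[0]
    for _ in variables
])

# Softmax over the per-variable scores -> selection weights (section 1)
weights = np.exp(scores - scores.max())
weights /= weights.sum()

for name, w in zip(variables, weights):
    print(f"{name:>11s}: {w:.2f}")
```

In the trained model, weights of this kind are what the TFT's interpretability outputs report as variable importance.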
244 | 245 | ### 🧮 Quantile Loss Function 246 | 247 | For quantile `q`: 248 | 249 | ```math 250 | QL(y, ŷ, q) = q * max(y - ŷ, 0) + (1 - q) * max(ŷ - y, 0) 251 | ``` 252 | 253 | - Penalizes underestimates more if `q` is high (e.g., p90) 254 | - Helps model learn **risk-aware** forecasts 255 | 256 | --- 257 | 258 | ## 🧾 Full Dummy Forecast Example 259 | 260 | Input: 261 | - Past `y`: [50, 52, 55, 60, 58] 262 | - Future Exogenous1 (temperature): [35, 36, 37] 263 | - Future Exogenous2 (hour): [14, 15, 16] 264 | 265 | Output: 266 | 267 | | Time | p10 | p50 | p90 | 268 | |--------|------|------|------| 269 | | t+1 | 55.2 | 58.4 | 62.7 | 270 | | t+2 | 56.8 | 60.1 | 64.9 | 271 | | t+3 | 57.5 | 61.2 | 65.5 | 272 | 273 | --- 274 | 275 | ## ✅ Summary: Why TFT Excels 276 | 277 | | Feature | Benefit | 278 | |----------------------------|-----------------------------------------------| 279 | | Variable Selection | Learns *which inputs* matter when | 280 | | Gated Residual Networks | Learns *how much* of each signal to use | 281 | | Temporal Attention | Focuses on *important time steps* | 282 | | LSTM Encoder-Decoder | Understands *short-term patterns* | 283 | | Static Enrichment | Adapts to *individual series/entities* | 284 | | Quantile Forecasts | Captures *uncertainty* in predictions | 285 | 286 | --- 287 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Temporal Fusion Transformer (TFT) – An Intuitive Guide 2 | 3 | ## Overview 4 | 5 | **Temporal Fusion Transformer (TFT)** is a deep learning model designed for **interpretable multi-horizon time series forecasting**. It outperforms traditional models and gives insights into **what**, **when**, and **why** a prediction was made. 6 | 7 | This guide walks through: 8 | - A clear breakdown of the TFT architecture 9 | - A dry-run of how data flows through the model 10 | - Example datasets and input formatting 11 | - Role of exogenous variables 12 | - Simplified explanation of key formulas from the research paper 13 | 14 | --- 15 | 16 | ## Architecture Explained Simply 17 | 18 | TFT combines powerful sequence modeling (like LSTM and attention) with interpretability tools. It's structured like a forecasting control center: 19 | 20 | ### Key Components 21 | 22 | | Component | What it Does | 23 | |-----------------------------|------------------------------------------------------------------------------| 24 | | **Gating (GRN)** | Turns off unnecessary parts of the network (like circuit breakers). | 25 | | **Variable Selection** | Learns which variables are useful at each step. | 26 | | **Static Covariate Encoder**| Adds information about each entity (store ID, patient, etc.) everywhere. | 27 | | **Sequence Encoder (LSTM)** | Captures recent patterns (e.g., last 4 days of demand). | 28 | | **Multi-head Attention** | Focuses on important time steps. | 29 | | **Quantile Regression** | Predicts a range of possible future outcomes (p10, p50, p90). 
| 30 | 31 | --- 32 | 33 | ## Data Flow (Dry Run Walkthrough) 34 | 35 | ### Let's Forecast Electricity Demand 36 | 37 | You have: 38 | - `y`: past electricity usage (target) 39 | - `Exogenous1`, `Exogenous2`: extra features like temp or external load 40 | - `hour`, `day`, `month`: known future features 41 | 42 | ```python 43 | # Convert to Darts TimeSeries 44 | series = TimeSeries.from_dataframe(df, value_cols='y') 45 | future_covariates = TimeSeries.from_dataframe(df[['Exogenous1', 'Exogenous2', 'hour', 'day']]) 46 | ``` 47 | 48 | 1. **TFT Looks Back** 4 days (input_chunk_length=96): sees past `y`, `Exogenous1/2`, time features 49 | 2. **TFT Looks Forward** 1 day (output_chunk_length=24): uses known future exogenous values 50 | 3. **Learns What Matters**: dynamically selects which features to use 51 | 4. **Predicts Quantiles**: p10 (lower), p50 (median), p90 (upper) 52 | 53 | --- 54 | 55 | ## Dataset Format 56 | 57 | | ds | y | Exogenous1 | Exogenous2 | hour | day | 58 | |---------------------|--------|------------|------------|------|-----| 59 | | 2023-01-01 00:00:00 | 70 | 49593 | 57253 | 0 | 1 | 60 | | 2023-01-01 01:00:00 | 65 | 48000 | 55000 | 1 | 1 | 61 | 62 | ### Exogenous Variables: 63 | - Provide external context 64 | - Help improve accuracy 65 | - Must be known into the future (e.g., weather forecasts, holidays) 66 | 67 | --- 68 | 69 | ## 🔧 Data Processing – Why It Matters 70 | 71 | TFT requires **structured, complete, and aligned** time series inputs. Simply throwing raw CSV data at it won’t work. Proper preprocessing ensures: 72 | - ✅ No duplicate timestamps (which would break temporal order) 73 | - ✅ All timestamps are aligned (missing hours are filled) 74 | - ✅ Past and future features are formatted correctly 75 | - ✅ Exogenous features (covariates) are available where needed 76 | 77 | This section highlights how each transformation feeds into the model and **why it's essential**. 78 | 79 | --- 80 | 81 | ## 📥 Processing Flow and Intermediate Objects 82 | 83 | ### Step-by-step Breakdown with Output 84 | 85 | #### 📌 Step 1: Load and clean the main data 86 | 87 | ```python 88 | df = pd.read_csv("/kaggle/input/dummydata/electricity.csv", index_col='ds', parse_dates=True) 89 | df = df[~df.index.duplicated(keep='first')] 90 | ``` 91 | 92 | - **Purpose**: Ensures time index is clean, with unique, ordered timestamps. 93 | - Without this, Darts' resampling (`fill_missing_dates=True`) would fail. 94 | 95 | --- 96 | 97 | #### 📌 Step 2: Create the main target `series` 98 | 99 | ```python 100 | series = TimeSeries.from_dataframe(df, value_cols='y', fill_missing_dates=True, freq='h') 101 | ``` 102 | 103 | - **Object:** `series` 104 | - **Contains:** The time series of electricity usage (`y`) as the target variable. 105 | - **Why important?** This is what the model learns to predict. 106 | - **Internally:** Looks like: 107 | 108 | ``` 109 | TimeSeries 110 | start: 2016-10-22 00:00:00 111 | end: 2016-12-30 23:00:00 112 | data: 113 | y 114 | time 115 | 2016-10-22 00:00:00 70.00 116 | 2016-10-22 01:00:00 37.10 117 | ... ... 118 | ``` 119 | 120 | --- 121 | 122 | #### 📌 Step 3: Extract past exogenous features (covariates) 123 | 124 | ```python 125 | X_past = df[['Exogenous1', 'Exogenous2']] 126 | covariates = TimeSeries.from_dataframe(X_past, fill_missing_dates=True, freq='h') 127 | ``` 128 | 129 | - **Object:** `covariates` 130 | - **Contains:** Features known in the past that may influence `y` (e.g., load in other regions). 
131 | - **Why important?** Helps the model learn correlations like: 132 | > “When `Exogenous1` goes up, `y` tends to increase.” 133 | 134 | - **Internally:** 135 | 136 | ``` 137 | TimeSeries 138 | columns: ['Exogenous1', 'Exogenous2'] 139 | data: 140 | Exogenous1 Exogenous2 141 | time 142 | 2016-10-22 00:00:00 49593 57253 143 | 2016-10-22 01:00:00 46073 51887 144 | ... 145 | ``` 146 | 147 | --- 148 | 149 | #### 📌 Step 4: Prepare known future inputs 150 | 151 | ```python 152 | future_df = pd.read_csv('/kaggle/input/dummydata/electricity-future.csv', index_col='ds', parse_dates=True) 153 | future_df = future_df[~future_df.index.duplicated(keep='first')] 154 | X_future = future_df[['Exogenous1', 'Exogenous2']] 155 | ``` 156 | 157 | - **Why important?** 158 | - You don’t know `y` for the future (that’s what you're predicting). 159 | - But you *do* know things like calendar info or scheduled events. 160 | - These are essential for multi-step forecasting. 161 | 162 | --- 163 | 164 | #### 📌 Step 5: Combine past and future into `future_covariates` 165 | 166 | ```python 167 | X = pd.concat([X_past, X_future]) 168 | future_covariates = TimeSeries.from_dataframe(X, fill_missing_dates=True, freq='H') 169 | ``` 170 | 171 | - **Object:** `future_covariates` 172 | - **Contains:** Covariates that are available from past *and* known in future. 173 | - **Why important?** 174 | - TFT uses these for the decoder (to condition predictions). 175 | 176 | - **Internally:** 177 | 178 | ``` 179 | TimeSeries 180 | columns: ['Exogenous1', 'Exogenous2'] 181 | data: 182 | Exogenous1 Exogenous2 183 | time 184 | ... ... ... 185 | 2016-12-31 00:00:00 64108 70318 186 | 2016-12-31 01:00:00 62492 67898 187 | ... 188 | ``` 189 | 190 | --- 191 | 192 | ## 🎯 Summary: Why These Objects Matter 193 | 194 | | Object | Purpose | Used for | 195 | |----------------------|-------------------------------------------|------------------| 196 | | `series` | Main target values to predict (`y`) | Training & eval | 197 | | `covariates` | Past context features (Exogenous1/2) | Encoder | 198 | | `future_covariates` | Known future values (Exogenous1/2) | Decoder | 199 | 200 | - Without these structures, TFT cannot: 201 | - Learn relationships between variables 202 | - Predict into the future 203 | - Handle multiple horizons properly 204 | 205 | --- 206 | 207 | ## Learning Objective (Formula Simplified) 208 | 209 | From the paper: 210 | 211 | ``` 212 | ŷ(q, t, τ) = f_q(τ, y_{t-k:t}, z_{t-k:t}, x_{t-k:t+τ}, s) 213 | ``` 214 | 215 | Where: 216 | - `y`: past target values 217 | - `z`: past observed features (Exogenous1, etc.) 218 | - `x`: known future inputs (e.g., hour, day) 219 | - `s`: static info (store ID, etc.) 220 | - `q`: quantile (like 10%, 50%, 90%) 221 | - `τ`: forecast horizon (steps ahead) 222 | 223 | > The model learns to estimate a range of likely future values based on all this information. 
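Before fitting the model in the dry run below, it is worth checking that the three objects from the summary table actually line up in time, since misaligned covariates are a common source of silent errors. A rough sketch (object names follow the walkthrough above; these checks approximate, rather than reproduce, Darts' own validation):

```python
horizon = 24  # forecast length used in the dry run (output_chunk_length)

print(series.start_time(), series.end_time(), series.freq_str)

# Past covariates should cover the full history of the target series
assert covariates.start_time() <= series.start_time()
assert covariates.end_time() >= series.end_time()

# Future covariates must extend at least `horizon` steps beyond the target,
# since the decoder needs known inputs over the forecast window
assert future_covariates.end_time() >= series.end_time() + horizon * series.freq
```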
224 | 
225 | ---
226 | 
227 | ## Final Dry Run Example
228 | 
229 | ```python
230 | # Train the tuned model; `future_covariates` (built in Step 5) already contains the past values of Exogenous1/2, so it conditions both the encoder and the decoder
231 | model = TFTModel(
232 |     input_chunk_length=96,
233 |     output_chunk_length=24,
234 |     hidden_size=64,
235 |     lstm_layers=2,
236 |     dropout=0.1,
237 |     num_attention_heads=4,
238 |     batch_size=32,
239 |     n_epochs=50,
240 | )
241 | 
242 | model.fit(series, future_covariates=future_covariates)
243 | forecast = model.predict(n=24, future_covariates=future_covariates)  # pass num_samples>1 to sample the quantile likelihood and read off p10/p50/p90 bands
244 | ```
245 | 
246 | ---
247 | 
248 | ## Quantile Loss Explained
249 | 
250 | ```math
251 | QL(y, ŷ, q) = q * max(y - ŷ, 0) + (1 - q) * max(ŷ - y, 0)
252 | ```
253 | 
254 | - Penalizes over- and under-prediction differently
255 | - Optimizing several quantiles at once yields prediction ranges, not just point forecasts
256 | 
257 | ---
258 | 
259 | ## When to Use TFT
260 | 
261 | | Scenario | Is TFT a Good Fit? |
262 | |-------------------------------|---------------------|
263 | | Multi-step forecasting | Yes |
264 | | External/known future inputs | Yes |
265 | | Need for interpretability | Yes |
266 | | Irregular or missing data | Not ideal |
267 | 
268 | ---
269 | 
270 | ## References
271 | 
272 | > Bryan Lim et al., *Temporal Fusion Transformers for Interpretable Multi-horizon Time Series Forecasting* - https://arxiv.org/abs/1912.09363
273 | 
274 | ## 👤 Author
275 | 
276 | Made by [@Siddharth Mishra](https://github.com/Sid3503). For any questions or issues, please open an issue on GitHub.
277 | 
278 | ---
279 | 
280 | 

281 | Made with ❤️ and lots of ☕ 282 |

283 | --------------------------------------------------------------------------------