├── TFT.pdf
├── data
│   └── electricity-future.csv
├── architecture.md
└── README.md
/TFT.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sid3503/TFT-Forecasting/main/TFT.pdf
--------------------------------------------------------------------------------
/data/electricity-future.csv:
--------------------------------------------------------------------------------
1 | unique_id,ds,Exogenous1,Exogenous2
2 | BE,12/31/2016 0:00,64108,70318
3 | BE,12/31/2016 1:00,62492,67898
4 | BE,12/31/2016 2:00,61571,68379
5 | BE,12/31/2016 3:00,60381,64972
6 | BE,12/31/2016 4:00,60298,62900
7 | BE,12/31/2016 5:00,60339,62364
8 | BE,12/31/2016 6:00,62576,64242
9 | BE,12/31/2016 7:00,63732,65884
10 | BE,12/31/2016 8:00,66235,68217
11 | BE,12/31/2016 9:00,66801,69921
12 | BE,12/31/2016 10:00,66964,72069
13 | BE,12/31/2016 11:00,66667,72328
14 | BE,12/31/2016 12:00,66145,72493
15 | BE,12/31/2016 13:00,65052,74228
16 | BE,12/31/2016 14:00,63905,70369
17 | BE,12/31/2016 15:00,63626,67987
18 | BE,12/31/2016 16:00,63817,66692
19 | BE,12/31/2016 17:00,65908,68634
20 | BE,12/31/2016 18:00,69289,72362
21 | BE,12/31/2016 19:00,70329,73957
22 | BE,12/31/2016 20:00,69121,72544
23 | BE,12/31/2016 21:00,66647,69451
24 | BE,12/31/2016 22:00,65886,67823
25 | BE,12/31/2016 23:00,66846,72876
26 | DE,12/31/2017 0:00,1392.25,14744.5
27 | DE,12/31/2017 1:00,1289,13906
28 | DE,12/31/2017 2:00,1206.25,13579
29 | DE,12/31/2017 3:00,1120.5,13302.75
30 | DE,12/31/2017 4:00,1053.25,13232.5
31 | DE,12/31/2017 5:00,1039.75,12631.75
32 | DE,12/31/2017 6:00,1062,10718.25
33 | DE,12/31/2017 7:00,1101.25,10753.75
34 | DE,12/31/2017 8:00,1163.5,12495.75
35 | DE,12/31/2017 9:00,1262.25,14335.5
36 | DE,12/31/2017 10:00,1364,15218
37 | DE,12/31/2017 11:00,1455.75,16021.75
38 | DE,12/31/2017 12:00,1538.75,16549.75
39 | DE,12/31/2017 13:00,1578.25,15791.75
40 | DE,12/31/2017 14:00,1578.25,15209
41 | DE,12/31/2017 15:00,1537,14991.25
42 | DE,12/31/2017 16:00,1482.25,15601.75
43 | DE,12/31/2017 17:00,1422.25,16954.5
44 | DE,12/31/2017 18:00,1352,17005.75
45 | DE,12/31/2017 19:00,1300,16601
46 | DE,12/31/2017 20:00,1261,15977.75
47 | DE,12/31/2017 21:00,1243.25,15715
48 | DE,12/31/2017 22:00,1250.5,15876
49 | DE,12/31/2017 23:00,1270.75,15130
50 | FR,12/31/2016 0:00,64108,70318
51 | FR,12/31/2016 1:00,62492,67898
52 | FR,12/31/2016 2:00,61571,68379
53 | FR,12/31/2016 3:00,60381,64972
54 | FR,12/31/2016 4:00,60298,62900
55 | FR,12/31/2016 5:00,60339,62364
56 | FR,12/31/2016 6:00,62576,64242
57 | FR,12/31/2016 7:00,63732,65884
58 | FR,12/31/2016 8:00,66235,68217
59 | FR,12/31/2016 9:00,66801,69921
60 | FR,12/31/2016 10:00,66964,72069
61 | FR,12/31/2016 11:00,66667,72328
62 | FR,12/31/2016 12:00,66145,72493
63 | FR,12/31/2016 13:00,65052,74228
64 | FR,12/31/2016 14:00,63905,70369
65 | FR,12/31/2016 15:00,63626,67987
66 | FR,12/31/2016 16:00,63817,66692
67 | FR,12/31/2016 17:00,65908,68634
68 | FR,12/31/2016 18:00,69289,72362
69 | FR,12/31/2016 19:00,70329,73957
70 | FR,12/31/2016 20:00,69121,72544
71 | FR,12/31/2016 21:00,66647,69451
72 | FR,12/31/2016 22:00,65886,67823
73 | FR,12/31/2016 23:00,66846,72876
74 | NP,12/24/2018 0:00,49119,461
75 | NP,12/24/2018 1:00,48115,484
76 | NP,12/24/2018 2:00,47727,497
77 | NP,12/24/2018 3:00,47673,509
78 | NP,12/24/2018 4:00,47848,510
79 | NP,12/24/2018 5:00,48673,511
80 | NP,12/24/2018 6:00,50071,526
81 | NP,12/24/2018 7:00,52259,528
82 | NP,12/24/2018 8:00,54395,659
83 | NP,12/24/2018 9:00,56071,897
84 | NP,12/24/2018 10:00,56989,1140
85 | NP,12/24/2018 11:00,56985,1342
86 | NP,12/24/2018 12:00,56391,1547
87 | NP,12/24/2018 13:00,55922,1618
88 | NP,12/24/2018 14:00,55826,1664
89 | NP,12/24/2018 15:00,56216,1768
90 | NP,12/24/2018 16:00,56661,1964
91 | NP,12/24/2018 17:00,55479,2092
92 | NP,12/24/2018 18:00,53741,2243
93 | NP,12/24/2018 19:00,52591,2544
94 | NP,12/24/2018 20:00,51787,2785
95 | NP,12/24/2018 21:00,51488,2919
96 | NP,12/24/2018 22:00,50928,3119
97 | NP,12/24/2018 23:00,49889,3306
98 | PJM,12/24/2018 0:00,84069,10276
99 | PJM,12/24/2018 1:00,81849,9809
100 | PJM,12/24/2018 2:00,80671,9501
101 | PJM,12/24/2018 3:00,80229,9337
102 | PJM,12/24/2018 4:00,80734,9283
103 | PJM,12/24/2018 5:00,82735,9369
104 | PJM,12/24/2018 6:00,85937,9618
105 | PJM,12/24/2018 7:00,89189,9987
106 | PJM,12/24/2018 8:00,91449,10195
107 | PJM,12/24/2018 9:00,93233,10378
108 | PJM,12/24/2018 10:00,94025,10578
109 | PJM,12/24/2018 11:00,93853,10677
110 | PJM,12/24/2018 12:00,93122,10686
111 | PJM,12/24/2018 13:00,92263,10631
112 | PJM,12/24/2018 14:00,91663,10586
113 | PJM,12/24/2018 15:00,91472,10547
114 | PJM,12/24/2018 16:00,93134,10635
115 | PJM,12/24/2018 17:00,96723,11154
116 | PJM,12/24/2018 18:00,96387,11464
117 | PJM,12/24/2018 19:00,94939,11246
118 | PJM,12/24/2018 20:00,94035,11070
119 | PJM,12/24/2018 21:00,92923,10963
120 | PJM,12/24/2018 22:00,90970,10802
121 | PJM,12/24/2018 23:00,88037,10419
122 |
--------------------------------------------------------------------------------
/architecture.md:
--------------------------------------------------------------------------------
1 | # 🧠 Temporal Fusion Transformer (TFT) Architecture – Deep Dive
2 |
3 | This section walks you through the full **TFT architecture** using:
4 |
5 | - 📊 Visual explanation (a Mermaid flowchart linking the components)
6 | - 📚 Plain language with **real-world examples**
7 | - 🧮 Simplified **math formulas and intuitions**
8 | - 🔍 Dummy data to **see what's happening**
9 |
10 | ---
11 |
12 | ## 📌 Overview of the Flow
13 |
14 | 
15 |
16 | The model is split into **three parts**:
17 | 1. **Input Encoding**
18 | 2. **Temporal Fusion Decoder**
19 | 3. **Forecasting Head (Quantile Predictions)**
20 |
21 | Here’s a **Mermaid flowchart** that visually represents the **Temporal Fusion Transformer (TFT) architecture**, inspired by the official diagram and aligned with the expanded explanation below.
22 |
23 | You can embed this in a markdown file (`architecture.md`) or your `README.md` if your platform supports Mermaid (e.g., GitHub, Obsidian, MkDocs, etc.).
24 |
25 | ---
26 |
27 | ### 🌐 Mermaid Flowchart: Temporal Fusion Transformer (TFT)
28 |
29 |
30 |
31 |
32 | ```mermaid
33 | flowchart TD
34 | Inputs[Inputs] --> Encoder[LSTM Encoder]
35 | Inputs --> Decoder[LSTM Decoder]
36 | Inputs --> Static[Static Metadata Encoder]
37 |
38 | Encoder --> VarSelPast[Variable Selection - Past]
39 | Decoder --> VarSelFuture[Variable Selection - Future]
40 | Static --> StaticEnrich[Static Enrichment]
41 |
42 | VarSelPast --> GRNPast[GRN + Add & Norm - Past]
43 | VarSelFuture --> GRNFuture[GRN + Add & Norm - Future]
44 | StaticEnrich --> StaticGRN[Inject Static Context]
45 |
46 | GRNPast --> Attention[Masked Multi-head Attention]
47 | GRNFuture --> Attention
48 | StaticGRN --> Attention
49 |
50 | Attention --> PostGRN[GRN + Add & Norm - Attention]
51 | PostGRN --> FF[Feed-Forward Layer]
52 | FF --> Output[Quantile Forecast Head]
53 |
54 | Output --> P10[p10 Prediction]
55 | Output --> P50[p50 Prediction]
56 | Output --> P90[p90 Prediction]
57 |
58 | ```
59 |
60 |
61 |
62 | ---
63 | Taken together, the model learns:
64 | - What features matter (Feature Selection)
65 | - When it should pay attention (Temporal Attention)
66 | - How to interpret uncertainty (Quantile Regression)
67 |
68 | ---
69 |
70 | ## 1️⃣ Variable Selection Network
71 |
72 | 
73 |
74 | ### 🔍 What It Does
75 | Learns which input variables to focus on **at each time step** — automatically.
76 |
77 | ### 🧮 Intuition (Simplified)
78 |
79 | Let’s say you input:
80 |
81 | ```python
82 | x_t = [temperature, humidity, holiday, load_zone1, load_zone2]
83 | ```
84 |
85 | But for the hour `t = 12:00 PM`, only `temperature` and `load_zone1` really matter.
86 |
87 | The Variable Selection Network learns **weights** like:
88 |
89 | ```
90 | [0.8, 0.05, 0.05, 0.9, 0.05] → Softmax → attention-like scores
91 | ```
92 |
93 | It **filters** irrelevant inputs dynamically.
94 |
95 | ### 🧪 Dummy Example
96 |
97 | Imagine:
98 |
99 | ```
100 | At 8 AM → electricity load depends more on: [day, holiday, load_zone1]
101 | At 2 PM → it's more about: [temperature, humidity]
102 | ```
103 |
104 | TFT learns this mapping automatically using:
105 |
106 | ```math
107 | \alpha_t = \mathrm{Softmax}\big(W_1\,\mathrm{GRN}_1(x_t), \ldots, W_n\,\mathrm{GRN}_n(x_t)\big)
108 | ```
109 |
110 | Where:
111 | - Each variable is passed through a **Gated Residual Network**
112 | - Their outputs are weighted and summed
113 | - The softmax layer highlights the **most important variables**
114 |
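To make this concrete, here is a minimal NumPy sketch of the idea (toy numbers and fixed vectors stand in for the learned TFT networks): each variable gets its own transformed embedding, a softmax turns learned logits into selection weights, and the weighted sum becomes the input for that time step.

```python
import numpy as np

def softmax(v):
    e = np.exp(v - v.max())
    return e / e.sum()

# Toy per-variable "GRN" outputs (in the real model these are learned networks)
grn_outputs = {
    "temperature": np.array([0.9, 0.2]),
    "humidity":    np.array([0.1, 0.0]),
    "holiday":     np.array([0.0, 0.1]),
    "load_zone1":  np.array([0.8, 0.7]),
    "load_zone2":  np.array([0.1, 0.1]),
}

# Toy selection logits (learned in practice) -> softmax -> selection weights
logits = np.array([2.0, -1.0, -1.0, 2.5, -1.0])
weights = softmax(logits)

# Weighted sum of the per-variable embeddings = the "selected" input at this step
selected = sum(w * emb for w, emb in zip(weights, grn_outputs.values()))

print({name: round(float(w), 3) for name, w in zip(grn_outputs, weights)})
print(selected.round(3))
```
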
115 | ---
116 |
117 | ## 2️⃣ Gated Residual Network (GRN)
118 |
119 | 
120 |
121 | ### 🔍 What It Does
122 | A smart block that decides how much of the transformed signal to **pass through or suppress**.
123 |
124 | ### 🧠 Analogy
125 |
126 | Imagine a **valve in a pipe**: water (information) flows in. The GRN learns **how much to open or close the valve** depending on whether that information is useful.
127 |
128 | ### 🧮 Formula (Simplified)
129 |
130 | ```math
131 | \mathrm{GRN}(x) = \mathrm{Gate}(x) \odot \mathrm{LayerNorm}\big(\mathrm{Dense}_2(\mathrm{ELU}(\mathrm{Dense}_1(x)))\big)
132 | ```
133 |
134 | - `Dense₁`, `Dense₂` = Linear transformations
135 | - `ELU` = Activation function (non-linear twist)
136 | - `Gate(x)` = Learnable sigmoid-based switch (outputs 0–1)
137 |
138 | ### 🧪 Dummy Example
139 |
140 | Input = `temperature = 35°C`
141 |
142 | If temperature isn’t relevant now, the GRN might learn:
143 |
144 | ```
145 | Gate(temperature) = 0.1 → suppress
146 | ```
147 |
148 | If temperature is very predictive:
149 |
150 | ```
151 | Gate(temperature) = 0.9 → amplify
152 | ```
153 |
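A toy NumPy sketch of the simplified formula above (random weights, no residual connection, just to show the gating mechanics):

```python
import numpy as np

def elu(x):
    return np.where(x > 0, x, np.exp(x) - 1)

def sigmoid(x):
    return 1 / (1 + np.exp(-x))

def layer_norm(x, eps=1e-5):
    return (x - x.mean()) / (x.std() + eps)

def grn(x, W1, b1, W2, b2, Wg, bg):
    """Simplified GRN: Gate(x) * LayerNorm(Dense2(ELU(Dense1(x)))), no residual."""
    hidden = elu(W1 @ x + b1)
    transformed = layer_norm(W2 @ hidden + b2)
    gate = sigmoid(Wg @ x + bg)   # values in (0, 1): how far the "valve" opens
    return gate * transformed

rng = np.random.default_rng(0)
d = 4
W1, W2, Wg = (rng.normal(scale=0.5, size=(d, d)) for _ in range(3))
b1, b2, bg = (rng.normal(scale=0.1, size=d) for _ in range(3))

x = rng.normal(size=d)            # a toy input embedding
print(grn(x, W1, b1, W2, b2, Wg, bg))
```
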
154 | ---
155 |
156 | ## 3️⃣ Static Covariate Encoders
157 |
158 | 
159 |
160 | ### 🔍 What It Does
161 | Takes features that don’t change over time — like store ID, region, or type — and injects them into the entire model.
162 |
163 | - Helps personalize the model across entities.
164 | - For example: *“Region A always peaks at 6 PM, Region B at 8 PM”*
165 |
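If you use Darts (as in the README), the static information can be attached directly to each series. A hedged sketch, assuming the `with_static_covariates` API; note that string-valued statics usually need encoding (e.g. Darts' `StaticCovariatesTransformer`) before training:

```python
import pandas as pd
from darts import TimeSeries

# Toy series for one market, with its static "region" id attached
df = pd.DataFrame({
    "ds": pd.date_range("2016-12-31", periods=24, freq="h"),
    "y": [float(i) for i in range(24)],
})
series = TimeSeries.from_dataframe(df, time_col="ds", value_cols="y")
series = series.with_static_covariates(pd.Series({"region": "BE"}))

print(series.static_covariates)
```
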
166 | ---
167 |
168 | ## 4️⃣ LSTM Encoder-Decoder
169 |
170 | 
171 |
172 | ### 🔁 Purpose
173 | These are **sequence models** that:
174 | - **Encode past time steps**
175 | - **Decode future known inputs** to predict the target
176 |
177 | ### 🧠 Think of it like:
178 | > LSTM Encoder: “Here's what happened in the last 4 days.”
179 | > LSTM Decoder: “Given that and what's planned (e.g., calendar), here’s what might happen tomorrow.”
180 |
181 | ---
182 |
183 | ## 5️⃣ Static Enrichment
184 |
185 | 
186 |
187 | - Combines static features with each temporal step.
188 | - Ensures that **personalization** affects both past and future modeling.
189 |
190 | ---
191 |
192 | ## 6️⃣ Temporal Self-Attention (Masked Multi-head)
193 |
194 | 
195 |
196 | ### 🔍 What It Does
197 | Lets the model **attend to important past steps** across the time sequence.
198 |
199 | ### 🧠 Example
200 |
201 | Imagine predicting electricity usage for `t+1`.
202 |
203 | - Attention might find:
204 | - `t-1` → useful (recent trend)
205 | - `t-24` → useful (same hour yesterday)
206 | - `t-168` → very useful (same hour last week)
207 |
208 | ### 🧮 Attention Score
209 |
210 | For each time step `t`:
211 | ```math
212 | \mathrm{AttentionScore}(i) = \frac{Q_t \cdot K_i^{\top}}{\sqrt{d_k}}
213 | ```
214 |
215 | - `Q`, `K` = query/key vectors from temporal inputs
216 | - Masking ensures it doesn’t peek into the future
217 |
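A single-head NumPy sketch of this step (the actual TFT uses an interpretable multi-head variant, but the masking and scaling work the same way):

```python
import numpy as np

def masked_self_attention(Q, K, V):
    """Scaled dot-product attention with a causal mask: step t only sees steps <= t."""
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)
    T = scores.shape[0]
    future = np.triu(np.ones((T, T), dtype=bool), k=1)   # True = future position
    scores = np.where(future, -np.inf, scores)
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ V, weights

rng = np.random.default_rng(0)
T, d = 6, 8                                  # 6 time steps, 8-dim embeddings
x = rng.normal(size=(T, d))
out, attn = masked_self_attention(x, x, x)   # self-attention: Q = K = V
print(attn.round(2))                         # lower-triangular: no weight on the future
```
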
218 | ---
219 |
220 | ## 7️⃣ Position-wise Feed-Forward Layers
221 |
222 | 
223 |
224 | - Applies dense transformations to each time step
225 | - Helps model **non-linear interactions** over time
226 |
227 | ---
228 |
229 | ## 8️⃣ Output Layer: Quantile Forecasts
230 |
231 | 
232 |
233 | The model predicts a **range** instead of just a point.
234 |
235 | ### 🔍 Example
236 |
237 | ```
238 | For t+1 → 10th percentile: 52.1
239 | 50th percentile (median): 58.4
240 | 90th percentile: 64.3
241 | ```
242 |
243 | This gives you a **confidence band** rather than a single guess.
244 |
245 | ### 🧮 Quantile Loss Function
246 |
247 | For quantile `q`:
248 |
249 | ```math
250 | \mathrm{QL}(y, \hat{y}, q) = q \cdot \max(y - \hat{y}, 0) + (1 - q) \cdot \max(\hat{y} - y, 0)
251 | ```
252 |
253 | - Penalizes underestimates more if `q` is high (e.g., p90)
254 | - Helps model learn **risk-aware** forecasts
255 |
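The same loss in a few lines of NumPy (a sketch of the formula above, averaged over time steps):

```python
import numpy as np

def quantile_loss(y, y_hat, q):
    """q * max(y - y_hat, 0) + (1 - q) * max(y_hat - y, 0), averaged over steps."""
    diff = y - y_hat
    return float(np.mean(np.maximum(q * diff, (q - 1) * diff)))

y     = np.array([55.0, 60.0, 58.0])   # actual values
y_hat = np.array([58.4, 60.1, 61.2])   # illustrative p50-style predictions

for q in (0.1, 0.5, 0.9):
    print(f"q={q}: {quantile_loss(y, y_hat, q):.3f}")
```
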
256 | ---
257 |
258 | ## 🧾 Full Dummy Forecast Example
259 |
260 | Input:
261 | - Past `y`: [50, 52, 55, 60, 58]
262 | - Future Exogenous1 (temperature): [35, 36, 37]
263 | - Future Exogenous2 (hour): [14, 15, 16]
264 |
265 | Output:
266 |
267 | | Time | p10 | p50 | p90 |
268 | |--------|------|------|------|
269 | | t+1 | 55.2 | 58.4 | 62.7 |
270 | | t+2 | 56.8 | 60.1 | 64.9 |
271 | | t+3 | 57.5 | 61.2 | 65.5 |
272 |
273 | ---
274 |
275 | ## ✅ Summary: Why TFT Excels
276 |
277 | | Feature | Benefit |
278 | |----------------------------|-----------------------------------------------|
279 | | Variable Selection | Learns *which inputs* matter when |
280 | | Gated Residual Networks | Learns *how much* of each signal to use |
281 | | Temporal Attention | Focuses on *important time steps* |
282 | | LSTM Encoder-Decoder | Understands *short-term patterns* |
283 | | Static Enrichment | Adapts to *individual series/entities* |
284 | | Quantile Forecasts | Captures *uncertainty* in predictions |
285 |
286 | ---
287 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Temporal Fusion Transformer (TFT) – An Intuitive Guide
2 |
3 | ## Overview
4 |
5 | **Temporal Fusion Transformer (TFT)** is a deep learning model designed for **interpretable multi-horizon time series forecasting**. It outperforms traditional models and gives insights into **what**, **when**, and **why** a prediction was made.
6 |
7 | This guide walks through:
8 | - A clear breakdown of the TFT architecture
9 | - A dry-run of how data flows through the model
10 | - Example datasets and input formatting
11 | - Role of exogenous variables
12 | - Simplified explanation of key formulas from the research paper
13 |
14 | ---
15 |
16 | ## Architecture Explained Simply
17 |
18 | TFT combines powerful sequence modeling (like LSTM and attention) with interpretability tools. It's structured like a forecasting control center:
19 |
20 | ### Key Components
21 |
22 | | Component | What it Does |
23 | |-----------------------------|------------------------------------------------------------------------------|
24 | | **Gating (GRN)** | Turns off unnecessary parts of the network (like circuit breakers). |
25 | | **Variable Selection** | Learns which variables are useful at each step. |
26 | | **Static Covariate Encoder**| Adds information about each entity (store ID, patient, etc.) everywhere. |
27 | | **Sequence Encoder (LSTM)** | Captures recent patterns (e.g., last 4 days of demand). |
28 | | **Multi-head Attention** | Focuses on important time steps. |
29 | | **Quantile Regression** | Predicts a range of possible future outcomes (p10, p50, p90). |
30 |
31 | ---
32 |
33 | ## Data Flow (Dry Run Walkthrough)
34 |
35 | ### Let's Forecast Electricity Demand
36 |
37 | You have:
38 | - `y`: past electricity usage (target)
39 | - `Exogenous1`, `Exogenous2`: extra features like temp or external load
40 | - `hour`, `day`, `month`: known future features
41 |
42 | ```python
43 | from darts import TimeSeries  # convert the DataFrame to Darts TimeSeries objects
44 | series = TimeSeries.from_dataframe(df, value_cols='y')
45 | future_covariates = TimeSeries.from_dataframe(df[['Exogenous1', 'Exogenous2', 'hour', 'day']])
46 | ```
47 |
48 | 1. **TFT Looks Back** 4 days (input_chunk_length=96): sees past `y`, `Exogenous1/2`, time features
49 | 2. **TFT Looks Forward** 1 day (output_chunk_length=24): uses known future exogenous values
50 | 3. **Learns What Matters**: dynamically selects which features to use
51 | 4. **Predicts Quantiles**: p10 (lower), p50 (median), p90 (upper)
52 |
53 | ---
54 |
55 | ## Dataset Format
56 |
57 | | ds | y | Exogenous1 | Exogenous2 | hour | day |
58 | |---------------------|--------|------------|------------|------|-----|
59 | | 2023-01-01 00:00:00 | 70 | 49593 | 57253 | 0 | 1 |
60 | | 2023-01-01 01:00:00 | 65 | 48000 | 55000 | 1 | 1 |
61 |
62 | ### Exogenous Variables:
63 | - Provide external context
64 | - Help improve accuracy
65 | - Must be known into the future (e.g., weather forecasts, holidays)
66 |
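The calendar features can be derived directly from the timestamp column, as in this short sketch (it reuses `data/electricity-future.csv` from this repo; the `hour`/`day` names match the table above):

```python
import pandas as pd

df = pd.read_csv("data/electricity-future.csv", parse_dates=["ds"])

# Known-future calendar features, derived from the timestamp
df["hour"] = df["ds"].dt.hour
df["day"] = df["ds"].dt.day

print(df.head())
```
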
67 | ---
68 |
69 | ## 🔧 Data Processing – Why It Matters
70 |
71 | TFT requires **structured, complete, and aligned** time series inputs. Simply throwing raw CSV data at it won’t work. Proper preprocessing ensures:
72 | - ✅ No duplicate timestamps (which would break temporal order)
73 | - ✅ All timestamps are aligned (missing hours are filled)
74 | - ✅ Past and future features are formatted correctly
75 | - ✅ Exogenous features (covariates) are available where needed
76 |
77 | This section highlights how each transformation feeds into the model and **why it's essential**.
78 |
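A quick way to check the first two requirements on the raw frame, before building any `TimeSeries` (a sketch; the path mirrors Step 1 below):

```python
import pandas as pd

df = pd.read_csv("/kaggle/input/dummydata/electricity.csv", index_col="ds", parse_dates=True)

print("duplicate timestamps:", df.index.duplicated().sum())

full_index = pd.date_range(df.index.min(), df.index.max(), freq="h")
print("missing hours:", len(full_index.difference(df.index)))
```
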
79 | ---
80 |
81 | ## 📥 Processing Flow and Intermediate Objects
82 |
83 | ### Step-by-step Breakdown with Output
84 |
85 | #### 📌 Step 1: Load and clean the main data
86 |
87 | ```python
88 | df = pd.read_csv("/kaggle/input/dummydata/electricity.csv", index_col='ds', parse_dates=True)
89 | df = df[~df.index.duplicated(keep='first')]
90 | ```
91 |
92 | - **Purpose**: Ensures time index is clean, with unique, ordered timestamps.
93 | - Without this, constructing the `TimeSeries` with `fill_missing_dates=True` would fail on the duplicated time index.
94 |
95 | ---
96 |
97 | #### 📌 Step 2: Create the main target `series`
98 |
99 | ```python
100 | series = TimeSeries.from_dataframe(df, value_cols='y', fill_missing_dates=True, freq='h')
101 | ```
102 |
103 | - **Object:** `series`
104 | - **Contains:** The time series of electricity usage (`y`) as the target variable.
105 | - **Why important?** This is what the model learns to predict.
106 | - **Internally:** Looks like:
107 |
108 | ```
109 | TimeSeries
110 | start: 2016-10-22 00:00:00
111 | end: 2016-12-30 23:00:00
112 | data:
113 | y
114 | time
115 | 2016-10-22 00:00:00 70.00
116 | 2016-10-22 01:00:00 37.10
117 | ... ...
118 | ```
119 |
120 | ---
121 |
122 | #### 📌 Step 3: Extract past exogenous features (covariates)
123 |
124 | ```python
125 | X_past = df[['Exogenous1', 'Exogenous2']]
126 | covariates = TimeSeries.from_dataframe(X_past, fill_missing_dates=True, freq='h')
127 | ```
128 |
129 | - **Object:** `covariates`
130 | - **Contains:** Features known in the past that may influence `y` (e.g., load in other regions).
131 | - **Why important?** Helps the model learn correlations like:
132 | > “When `Exogenous1` goes up, `y` tends to increase.”
133 |
134 | - **Internally:**
135 |
136 | ```
137 | TimeSeries
138 | columns: ['Exogenous1', 'Exogenous2']
139 | data:
140 | Exogenous1 Exogenous2
141 | time
142 | 2016-10-22 00:00:00 49593 57253
143 | 2016-10-22 01:00:00 46073 51887
144 | ...
145 | ```
146 |
147 | ---
148 |
149 | #### 📌 Step 4: Prepare known future inputs
150 |
151 | ```python
152 | future_df = pd.read_csv('/kaggle/input/dummydata/electricity-future.csv', index_col='ds', parse_dates=True)
153 | future_df = future_df[~future_df.index.duplicated(keep='first')]
154 | X_future = future_df[['Exogenous1', 'Exogenous2']]
155 | ```
156 |
157 | - **Why important?**
158 | - You don’t know `y` for the future (that’s what you're predicting).
159 | - But you *do* know things like calendar info or scheduled events.
160 | - These are essential for multi-step forecasting.
161 |
162 | ---
163 |
164 | #### 📌 Step 5: Combine past and future into `future_covariates`
165 |
166 | ```python
167 | X = pd.concat([X_past, X_future])
168 | future_covariates = TimeSeries.from_dataframe(X, fill_missing_dates=True, freq='h')
169 | ```
170 |
171 | - **Object:** `future_covariates`
172 | - **Contains:** Covariates that are available from past *and* known in future.
173 | - **Why important?**
174 | - TFT uses these for the decoder (to condition predictions).
175 |
176 | - **Internally:**
177 |
178 | ```
179 | TimeSeries
180 | columns: ['Exogenous1', 'Exogenous2']
181 | data:
182 | Exogenous1 Exogenous2
183 | time
184 | ... ... ...
185 | 2016-12-31 00:00:00 64108 70318
186 | 2016-12-31 01:00:00 62492 67898
187 | ...
188 | ```
189 |
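A quick sanity check (sketch) that the combined covariates reach far enough into the future for a 24-step forecast:

```python
n = 24  # forecast horizon in hours
assert future_covariates.end_time() >= series.end_time() + n * series.freq, \
    "future_covariates must cover every step you intend to predict"
```
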
190 | ---
191 |
192 | ## 🎯 Summary: Why These Objects Matter
193 |
194 | | Object | Purpose | Used for |
195 | |----------------------|-------------------------------------------|------------------|
196 | | `series` | Main target values to predict (`y`) | Training & eval |
197 | | `covariates` | Past context features (Exogenous1/2) | Encoder |
198 | | `future_covariates` | Known future values (Exogenous1/2) | Decoder |
199 |
200 | - Without these structures, TFT cannot:
201 | - Learn relationships between variables
202 | - Predict into the future
203 | - Handle multiple horizons properly
204 |
205 | ---
206 |
207 | ## Learning Objective (Formula Simplified)
208 |
209 | From the paper:
210 |
211 | ```
212 | ŷ(q, t, τ) = f_q(τ, y_{t-k:t}, z_{t-k:t}, x_{t-k:t+τ}, s)
213 | ```
214 |
215 | Where:
216 | - `y`: past target values
217 | - `z`: past observed features (Exogenous1, etc.)
218 | - `x`: known future inputs (e.g., hour, day)
219 | - `s`: static info (store ID, etc.)
220 | - `q`: quantile (like 10%, 50%, 90%)
221 | - `τ`: forecast horizon (steps ahead)
222 |
223 | > The model learns to estimate a range of likely future values based on all this information.
224 |
225 | ---
226 |
227 | ## Final Dry Run Example
228 |
229 | ```python
230 | from darts.models import TFTModel  # train the best model after tuning
231 | model = TFTModel(
232 | input_chunk_length=96,
233 | output_chunk_length=24,
234 | hidden_size=64,
235 | lstm_layers=2,
236 | dropout=0.1,
237 | num_attention_heads=4,
238 | batch_size=32,
239 | n_epochs=50,
240 | )
241 |
242 | model.fit(series, future_covariates=future_covariates)
243 | forecast = model.predict(n=24, future_covariates=future_covariates)
244 | ```
245 |
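To read off the p10/p50/p90 bands described earlier, sample the probabilistic forecast and extract quantiles. A sketch using Darts' sampling API (`TFTModel` is probabilistic by default via quantile regression; `quantile_timeseries` is available on recent Darts versions):

```python
# Sample trajectories so the forecast carries uncertainty, then read off the bands
prob_forecast = model.predict(
    n=24,
    future_covariates=future_covariates,
    num_samples=200,
)

p10 = prob_forecast.quantile_timeseries(0.1)
p50 = prob_forecast.quantile_timeseries(0.5)
p90 = prob_forecast.quantile_timeseries(0.9)
```
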
246 | ---
247 |
248 | ## Quantile Loss Explained
249 |
250 | ```math
251 | \mathrm{QL}(y, \hat{y}, q) = q \cdot \max(y - \hat{y}, 0) + (1 - q) \cdot \max(\hat{y} - y, 0)
252 | ```
253 |
254 | - Penalizes over- and under-prediction differently, depending on the quantile `q`
255 | - Training it at several quantiles (p10, p50, p90) yields ranges instead of single point forecasts
256 |
257 | ---
258 |
259 | ## When to Use TFT
260 |
261 | | Scenario | Is TFT a Good Fit? |
262 | |-------------------------------|---------------------|
263 | | Multi-step forecasting | Yes |
264 | | External/known future inputs | Yes |
265 | | Need for interpretability | Yes |
266 | | Irregular or missing data | Not ideal |
267 |
268 | ---
269 |
270 | ## References
271 |
272 | > Bryan Lim et al., *Temporal Fusion Transformers for Interpretable Multi-horizon Time Series Forecasting* - https://arxiv.org/abs/1912.09363
273 |
274 | ## 👤 Author
275 |
276 | Maintained by [Siddharth Mishra](https://github.com/Sid3503). For questions or issues, please open an issue on GitHub.
277 |
278 | ---
279 |
280 |
281 | Made with ❤️ and lots of ☕
282 |
283 |
--------------------------------------------------------------------------------