├── src ├── __init__.py ├── downloader.py └── common.py ├── .gitignore ├── api_keys.json.example ├── requirements.txt ├── 中文說明.md ├── README.md ├── crypto_screener.py ├── stock_screener.py ├── crypto_trend_screener.py └── crypto_historical_trend_finder.py /src/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | .idea/ 3 | -------------------------------------------------------------------------------- /api_keys.json.example: -------------------------------------------------------------------------------- 1 | { 2 | "stocksymbol": "", 3 | "polygon": "" 4 | } 5 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pytz 2 | pandas 3 | requests 4 | stocksymbol 5 | python-binance 6 | polygon-api-client 7 | dtaidistance 8 | matplotlib 9 | shapedtw -------------------------------------------------------------------------------- /中文說明.md: -------------------------------------------------------------------------------- 1 | 2 | # 股票和加密貨幣走勢篩選工具使用指南 3 | 4 | ## 安裝設置 5 | ```bash 6 | pip3 install -r requirements.txt 7 | ``` 8 | 對於美股功能,需要 Polygon 和 Stocksymbol 的 API 金鑰(加密貨幣功能不需要)。將 `api_keys.json.example` 重命名為 `api_keys.json` 並填入 API 金鑰。 9 | 10 | ## 1. 加密貨幣篩選器 (Crypto Screener) 11 | 12 | ### 用途 13 | 識別當前市場中表現強勢的加密貨幣,通過計算其與多個移動平均線的相對關係來量化強度。 14 | 15 | ### 運作流程 16 | 1. 連接到加密貨幣交易所 API 獲取數據 17 | 2. 計算每個加密貨幣的移動平均線和 ATR 技術指標 18 | 3. 應用相對強度計算公式,評估價格與移動平均線的關係 19 | 4. 對結果進行排序,生成強勢標的列表 20 | 5. 輸出兼容 TradingView 的文件格式 21 | 22 | ### 關鍵參數設定 23 | - **時間框架 (-t)**: (5m, 15m, 1h, 4h) 24 | - **計算持續時間 (-d)**: 控制歷史數據量 25 | - **程式內部參數**: 26 | - 相對強度公式: 考量價格與多個移動平均線之間的關係 27 | - ATR標準化: 確保不同價格範圍的資產可以公平比較 28 | 29 | ### 執行方法 30 | ```bash 31 | python3 crypto_screener.py [選項] 32 | ``` 33 | 34 | 選項: 35 | - `-t, --timeframe <值>`: 時間框架 (Optional; 可選值: 5m, 15m, 30m, 1h, 2h, 4h, 8h, 1d; 默認: "15m") 36 | - `-d, --days <值>`: 計算持續時間,單位為天 (Optional; 默認: 3) 37 | 38 | 例如: 39 | ```bash 40 | python3 crypto_screener.py 41 | python3 crypto_screener.py -t "1h" -d 5 42 | ``` 43 | 44 | ## 2. 股票篩選器 (Stock Screener) 45 | 46 | ### 用途 47 | 識別相對於大盤強勢且符合技術走勢條件的股票,通過篩選提供強勢選標的。 48 | 49 | ### 運作流程 50 | 1. 首先計算 SPY 指數的基準相對強度 51 | 2. 並行獲取股票歷史數據(日線和小時級別) 52 | 3. 執行成交額和走勢條件篩選 53 | 4. 計算每個股票的相對強度得分 54 | 5. 與 SPY 相對強度比較,篩選出優於大盤的標的 55 | 6. 輸出排序結果 56 | 57 | ### 關鍵參數設定 58 | - **每日回溯期 (_1D_OF_DAYS_TRACEBACK)**: 默認252 59 | - **小時回溯期 (_1H_OF_DAYS_TRACEBACK)**: 默認126 60 | - **最低成交額 (MIN_TURNOVER)**: 篩選具流動性的股票,默認10,000,000 61 | - **Minervini 條件 (-g 選項控制)**: 62 | - 啟用: 應用額外的技術條件篩選(如價格與移動平均線的關係) 63 | - 停用: 僅依據相對強度評分,不考慮技術條件 64 | 65 | ### 執行方法 66 | ```bash 67 | python3 stock_screener.py [選項] 68 | ``` 69 | 70 | 選項: 71 | - `-a, --all`: 在輸出中包含所有強勢標的而不僅是前980個 (Optional; 默認: 僅包含前980個,因為TradingView有導入限制) 72 | - `-g`: 忽略 Minervini 走勢模板條件,僅計算相對強度得分 (Optional; 默認: 應用Minervini條件) 73 | 74 | 例如: 75 | ```bash 76 | python3 stock_screener.py 77 | python3 stock_screener.py -a -g 78 | ``` 79 | 80 | ## 3. 加密貨幣走勢篩選器 (Crypto Trend Screener) 81 | 82 | ### 用途 83 | 基於使用者定義的參考走勢,使用動態時間規整(DTW)算法識別當前市場中發現潛在交易機會。 84 | 85 | ### 運作流程 86 | 1. 載入或獲取預定義的參考走勢數據 87 | 2. 為每個待分析標的下載歷史數據 88 | 3. 通過DTW計算標的與參考走勢的相似度 89 | - 先用dtaidistance dtw 快速篩選 (有C可以加速) 90 | - 再用shapedtw,做到更好的對齊 91 | - 以上都是進行價格走勢對比,再分析移動平均線關係對比 92 | 4. 綜合計算相似度得分 93 | 5. 
生成視覺化報告及TradingView觀察列表 94 | 95 | ### 關鍵參數設定 96 | - **-f, --file `<文件路徑>`**: 強勢標的文件路徑 (Optional; 默認: 空,使用所有可用標的) 97 | - **--asset `<類型>`**: 資產類型 (Optional; 可選值為 'crypto' 或 'stock'; 默認: 'crypto') 98 | - **-nv, --no_visualize**: 禁用 DTW 對齊的可視化 (Optional; 默認: 啟用可視化) 99 | - **-k, --topk `<數值>`**: 每個參考走勢記錄的最匹配標的數量 (Optional; 默認: 6) 100 | - **-s, --sleep `<秒數>`**: API 請求之間的休眠時間,單位為秒 (Optional; 默認: 0.5) 101 | - **TIMEFRAMES_TO_ANALYZE**: 分析的時間框架陣列,默認 ["15m", "30m", "1h", "2h", "4h"] 102 | - **TIMEZONE**: 時區設置,默認 "America/Los_Angeles",可修改為您所在地區如 "Asia/Taipei" 103 | - **DTW_WINDOW_RATIO**: 控制水平匹配範圍,默認0.2 104 | - **DTW_MAX_POINT_DISTANCE**: 控制垂直匹配限制,默認0.66 105 | - **PRICE_WEIGHT 和 DIFF_WEIGHT**: 價格與移動平均線關係匹配的權重分配,默認0.4和0.6 106 | - **REFERENCE_TRENDS**: 參考走勢配置,格式為 `[開始日期時間, 結束日期時間, 時間框架, 標籤]` 107 | 108 | ### 執行方法 109 | ```bash 110 | python3 crypto_trend_screener.py [選項] 111 | ``` 112 | 113 | 例如: 114 | ```bash 115 | python3 crypto_trend_screener.py 116 | python3 crypto_trend_screener.py --asset=crypto -f strong_targets.txt -k 10 -s 1.0 117 | ``` 118 | 119 | ### 如何修改參考走勢 120 | 在程式碼中修改 `REFERENCE_TRENDS` 字典來定義您想搜尋的參考走勢: 121 | 122 | ```python 123 | REFERENCE_TRENDS = { 124 | "AVAX": [ 125 | [datetime(2023, 11, 9, 12, 0), datetime(2023, 11, 14, 15, 0), "1h", "standard"], 126 | ], 127 | "SOL": [ 128 | [datetime(2023, 9, 23, 0, 0), datetime(2023, 10, 15, 21, 0), "4h", "standard"] 129 | ], 130 | # 添加您自己的參考走勢 131 | } 132 | ``` 133 | 134 | ### 相似度視覺化 135 | 腳本會生成兩種視覺化圖表: 136 | 1. 價格/SMA對齊圖: 顯示參考走勢和目標走勢的形態對比 137 | 2. SMA差異圖: 顯示移動平均線關係的匹配情況 138 | 139 | ## 4. 歷史相似走勢分析器 (Crypto Historical Trend Finder) 140 | 141 | ### 用途 142 | 搜尋歷史數據中與參考走勢相似的走勢,並分析這些走勢的後續價格走向,提供統計性預測參考。 143 | 144 | ### 運作流程 145 | 1. 在多個時間框架(15m, 30m, 1h, 2h, 4h)搜尋歷史相似走勢 146 | 2. 使用DTW算法進行走勢匹配 147 | 3. 分析每個匹配走勢的後續價格走向 148 | 4. 計算統計數據(如上漲/下跌比例) 149 | 5. 生成包含過去、走勢、未來的完整視覺化報告 150 | 151 | ### 關鍵參數設定 152 | - `-k, --topk <數值>`: 保留的最佳匹配數量 (Optional; 默認: 300) 153 | - `-s, --sleep <秒數>`: API 請求間隔時間,單位為秒 (Optional; 默認: 15) 154 | - **TOP_K**: 保留的最佳匹配數量,默認300 155 | - **API_SLEEP_SECONDS**: API請求間隔時間,默認15秒 156 | - **EXTENSION_FACTORS_FOR_STATS**: 延伸係數統計範圍,默認[0.25, 0.5, 0.75, 1.0, 1.5, 2.0, 2.5] 157 | - **VIS_EXTENSION_FUTURE_LENGTH_FACTOR**: 未來視覺化延伸係數,默認2.0 158 | - **REFERENCE_TRENDS**: 參考走勢配置,格式為 `[開始日期時間, 結束日期時間, 時間框架, 標籤]`,建議選擇的參考走勢高低點的K棒數量要相似,走勢結尾應該是理想的進場點位 159 | - **GLOBAL_OVERLAP_FILTERING**: 重疊過濾策略,默認True 160 | - **True (全域過濾)**: 在所有標的和時間框架中確保沒有任何時間重疊的走勢。如果發現兩個相似走勢在時間上有重疊(即使來自不同標的),只保留相似度最高的那個。這確保每個時間段只有一個最佳匹配走勢,適合需要高品質、無重複的分析樣本。 161 | - **False (個別標的過濾)**: 只在同一標的內避免時間重疊,允許不同標的在相同時間段都有匹配走勢。例如BTC和ETH可以在同一時間段都被識別為相似走勢。這會產生更多的分析樣本,但可能包含市場整體走勢的影響。 162 | - **TIMEZONE**: 時區設置,默認 "America/Los_Angeles",影響日期時間顯示和參考走勢時間解析 163 | - **HISTORICAL_START_DATE**: 歷史數據起始日期,默認datetime(2021, 1, 1) 164 | - **TIMEFRAMES_TO_ANALYZE**: 分析的時間框架,默認["15m", "30m", "1h", "2h", "4h"] 165 | 166 | ### 上漲/下跌認定與統計分析說明 167 | 168 | #### 走勢方向判定機制 169 | 系統使用兩種不同的延伸係數來進行分析: 170 | 171 | 1. **視覺化標記用** - **VIS_EXTENSION_FUTURE_LENGTH_FACTOR** (默認2.0) 172 | - 用於判定每個找到的歷史走勢是「上漲」或「下跌」 173 | - 決定視覺化圖表中檔案名稱的標記(如 `_rise.png`、`_fall.png`) 174 | - 計算方式:走勢長度 × 2.0 = 觀察的未來期間 175 | - 例如:30根K棒的走勢 × 2.0 = 觀察未來60根K棒 176 | - 比較走勢結束時收盤價 vs 觀察期間結束時收盤價來判定漲跌 177 | 178 | 2. 
**統計分析用** - **EXTENSION_FACTORS_FOR_STATS** [0.25, 0.5, 0.75, 1.0, 1.5, 2.0, 2.5] 179 | - 用於生成多維度統計報告 180 | - 提供不同時間長度的預測成功率分析 181 | - 與視覺化標記獨立運作 182 | 183 | #### 統計分析目的與方法 184 | 程式對所有找到的歷史相似走勢進行統計分析,目的是回答「如果現在出現類似的走勢,未來價格走向的機率是多少?」 185 | 186 | **多時間範圍統計**: 187 | - **短期預測** (0.25x): 觀察走勢長度1/4的未來期間 188 | - **中期預測** (1.0x): 觀察與走勢等長的未來期間 189 | - **長期預測** (2.5x): 觀察走勢長度2.5倍的未來期間 190 | 191 | **統計指標計算**: 192 | 對每個延伸係數分別計算: 193 | - **上漲比例**: 未來價格上漲的走勢數量 / 總有效走勢數量 194 | - **下跌比例**: 未來價格下跌的走勢數量 / 總有效走勢數量 195 | - **數據不足比例**: 未來數據不充分的走勢比例 196 | - **無數據比例**: 完全沒有未來數據的走勢比例 197 | 198 | **統計報告範例**: 199 | ``` 200 | Extension Factor Analysis: 201 | 0.25x: Rise 198(43.5%) | Fall 254(55.8%) | Insufficient 3(0.7%) 202 | 0.5x: Rise 191(42.0%) | Fall 259(56.9%) | Insufficient 5(1.1%) 203 | 0.75x: Rise 186(40.9%) | Fall 260(57.1%) | Insufficient 9(2.0%) 204 | 1.0x: Rise 169(37.1%) | Fall 275(60.4%) | Insufficient 11(2.4%) 205 | 1.5x: Rise 166(36.5%) | Fall 268(58.9%) | Insufficient 21(4.6%) 206 | 2.0x: Rise 159(34.9%) | Fall 267(58.7%) | Insufficient 29(6.4%) 207 | 2.5x: Rise 167(36.7%) | Fall 250(54.9%) | Insufficient 38(8.4%) 208 | ``` 209 | 210 | **特殊情況處理**: 211 | - **insufficient_data**: 未來數據不足指定觀察期間,但仍有部分數據可分析 212 | - **no_future_data**: 走勢結束後完全沒有未來數據可供分析 213 | 214 | ### 執行方法 215 | ```bash 216 | python3 crypto_historical_trend_finder.py [選項] 217 | ``` 218 | 219 | 例如: 220 | ```bash 221 | python3 crypto_historical_trend_finder.py 222 | python3 crypto_historical_trend_finder.py -k 500 -s 10 223 | ``` 224 | 225 | ### 輸出報告 226 | - **整體統計摘要**: 所有時間框架的綜合統計 227 | - **時間框架分析**: 各時間框架的詳細結果 228 | - **三段式視覺化**: 過去+走勢+未來的完整分析圖表 229 | - **走勢統計**: 例如「75%的相似走勢在2倍走勢長度內出現上漲」 230 | 231 | ## 結果解讀與應用 232 | 233 | ### 加密貨幣和股票篩選結果 234 | - 高分標的表示相對強度高,多數情況下表現優於市場平均 235 | - 可以定期運行篩選器追蹤強度變化走勢 236 | 237 | ### 相似走勢匹配結果 238 | - Distance越小越好,Score越大越好 239 | - 高匹配分數表示當前走勢與歷史走勢高度相似 240 | - 視覺化圖表中的連接線顯示時間軸上的具體對齊點 241 | - 可以研究參考走勢後續發展來評估當前標的的潛在走勢 242 | 243 | ### 歷史走勢分析結果 244 | - 統計報告顯示相似走勢的歷史表現,如「漲跌比例」 245 | - 可以根據統計結果評估當前相似走勢的潛在走勢 246 | - 不同延伸係數提供短期和長期的預測參考 247 | - 數據充足性指標幫助評估預測的可靠程度 248 | 249 | 所有篩選工具的結果文件都兼容TradingView格式,便於進一步技術分析和決策。 250 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Screener 2 | 3 | 下載美股和加密貨幣的歷史數據並透過自己的策略去找到強勢標的 4 | 5 | The purpose of this project is to download historical data for US stocks and cryptocurrencies, and use different strategies to identify strong performing assets. 6 | 7 | ## Installation 8 | 9 | ```bash 10 | pip3 install -r requirements.txt 11 | ``` 12 | 13 | API keys are needed for [Polygon](https://polygon.io) and [Stocksymbol](https://stock-symbol.herokuapp.com) (Not a requirement for Crypto usage) 14 | 15 | Rename `api_keys.json.example` to `api_keys.json` 16 | 17 | ## Usage 18 | 19 | ### 1. Crypto Screener 20 | 21 | Identify strong performing cryptocurrencies by comparing them with SMA-30, SMA-45 and SMA-60. 22 | 23 | ```bash 24 | python3 crypto_screener.py 25 | python3 crypto_screener.py -t "15m" -d 3 26 | ``` 27 | 28 | Options: 29 | * `-t` Time frame (5m, 15m, 30m, 1h, 2h, 4h, 8h, 1d) (default: 15m) 30 | * `-d` Calculation duration in days (default: 3 days) 31 | 32 | The script will generate a TXT file in `output//` directory that can be imported into [TradingView](https://www.tradingview.com/)'s watchlist. 33 | 34 | ### 2. Stock Screener 35 | 36 | Analyze US stocks using relative strength calculation and Minervini trend template conditions. 
37 | 38 | ```bash 39 | python3 stock_screener.py 40 | python3 stock_screener.py -a -g 41 | ``` 42 | 43 | Options: 44 | * `-a, --all` Include all strong targets in output instead of just top 980 (TradingView has import limit) 45 | * `-g` Ignore Minervini trend template conditions and calculate RS score only 46 | 47 | The script will generate a TXT file in `output//` directory that can be imported into [TradingView](https://www.tradingview.com/)'s watchlist. 48 | 49 | ### 3. Crypto Trend Screener 50 | 51 | Find cryptocurrencies with price patterns similar to predefined reference trends using Dynamic Time Warping (DTW) algorithms. 52 | 53 | ```bash 54 | python3 crypto_trend_screener.py 55 | python3 crypto_trend_screener.py --asset=crypto -f strong_targets.txt -k 10 -s 1.0 56 | ``` 57 | 58 | Options: 59 | * `-f, --file` Path to strong target file (default: use all available symbols) 60 | * `--asset` Asset type: 'crypto' or 'stock' (default: crypto) 61 | * `-nv, --no_visualize` Disable DTW alignment visualizations 62 | * `-k, --topk` Number of top symbols to record per reference trend (default: 6) 63 | * `-s, --sleep` Sleep time between API requests in seconds (default: 0.5) 64 | 65 | **Key Parameters:** 66 | - **REFERENCE_TRENDS**: Define reference patterns in format `[start_datetime, end_datetime, timeframe, label]` 67 | - **TIMEFRAMES_TO_ANALYZE**: Analysis timeframes, default ["15m", "30m", "1h", "2h", "4h"] 68 | - **DTW_WINDOW_RATIO**: Controls horizontal matching flexibility (default: 0.2) 69 | - **PRICE_WEIGHT/DIFF_WEIGHT**: Balance price vs moving average matching (default: 0.4/0.6) 70 | 71 | The script generates TXT files and visualization charts in `similarity_output//` directory. 72 | 73 | ### 4. Crypto Historical Trend Finder 74 | 75 | Search historical cryptocurrency data for patterns similar to reference trends and analyze their future price movements with statistical insights. 76 | 77 | ```bash 78 | python3 crypto_historical_trend_finder.py 79 | python3 crypto_historical_trend_finder.py -k 500 -s 10 80 | ``` 81 | 82 | Options: 83 | * `-k, --topk` Number of top matches to keep per reference trend (default: 300) 84 | * `-s, --sleep` Sleep time between API requests in seconds (default: 15) 85 | 86 | **Key Features:** 87 | - Multi-timeframe pattern matching (15m, 30m, 1h, 2h, 4h) 88 | - Future trend prediction with rise/fall statistics 89 | - Three-panel visualizations (past + pattern + future) 90 | - Statistical analysis with multiple extension factors (0.25x to 2.5x) 91 | - Overlap filtering for clean analysis samples 92 | 93 | **Key Parameters:** 94 | - **REFERENCE_TRENDS**: Define reference patterns in format `[start_datetime, end_datetime, timeframe, label]`. Choose patterns with similar high/low bar counts and end at ideal entry points 95 | - **TIMEZONE**: Timezone setting for datetime display and reference trend parsing (default: "America/Los_Angeles") 96 | - **HISTORICAL_START_DATE**: Starting date for historical data collection (default: 2021-01-01) 97 | - **TIMEFRAMES_TO_ANALYZE**: Analysis timeframes for pattern searching (default: ["15m", "30m", "1h", "2h", "4h"]) 98 | - **VIS_EXTENSION_FUTURE_LENGTH_FACTOR**: Determines how far into the future to observe after pattern completion for rise/fall classification. 
For example, 2.0 means observing 2x the pattern length into the future (default: 2.0) 99 | - **EXTENSION_FACTORS_FOR_STATS**: Multiple factors for comprehensive statistical analysis [0.25, 0.5, 0.75, 1.0, 1.5, 2.0, 2.5] 100 | - **GLOBAL_OVERLAP_FILTERING**: 101 | - `True`: No overlapping patterns across all symbols (stricter, higher quality) 102 | - `False`: Allow overlaps between different symbols (more samples) 103 | 104 | 105 | **Statistical Output Example:** 106 | ``` 107 | Extension Factor Analysis: 108 | 0.25x: Rise 198(43.5%) | Fall 254(55.8%) | Insufficient 3(0.7%) 109 | 0.5x: Rise 191(42.0%) | Fall 259(56.9%) | Insufficient 5(1.1%) 110 | 0.75x: Rise 186(40.9%) | Fall 260(57.1%) | Insufficient 9(2.0%) 111 | 1.0x: Rise 169(37.1%) | Fall 275(60.4%) | Insufficient 11(2.4%) 112 | 1.5x: Rise 166(36.5%) | Fall 268(58.9%) | Insufficient 21(4.6%) 113 | 2.0x: Rise 159(34.9%) | Fall 267(58.7%) | Insufficient 29(6.4%) 114 | 2.5x: Rise 167(36.7%) | Fall 250(54.9%) | Insufficient 38(8.4%) 115 | ``` 116 | 117 | Results are saved in `past_similar_trends_report//` directory with comprehensive visualizations and statistical reports. 118 | 119 | ## Dynamic Time Warping (DTW) 120 | 121 | DTW is an algorithm for measuring similarity between two temporal sequences that may vary in speed. Unlike Euclidean distance, DTW can handle sequences of unequal lengths and is invariant to time shifts, making it ideal for financial pattern matching. 122 | 123 | **Applications in this project:** 124 | - Price pattern similarity detection 125 | - Moving average relationship matching 126 | - Historical trend analysis with future outcome prediction 127 | 128 | ## Relative Strength Formula 129 | 130 | The RS score is calculated using a weighted sum of relative strength indicators: 131 | 132 | $$ bars = \text{total bars (depend on time frame, e.g. 
} 4 \times 24 \times days \text{ for 15m})$$ 133 | 134 | $$ W_i = e^{2 \times \ln(2) \times i / bars} $$ 135 | 136 | $$ \begin{align*} 137 | N_i & = \frac{(P_i - MA30_i) + (P_i - MA45_i) + (P_i - MA60_i) + (MA30_i - MA45_i) + (MA30_i - MA60_i) + (MA45_i - MA60_i)}{ATR_i}\\ 138 | \end{align*} $$ 139 | 140 | $$ Score = \frac{\sum_{i=1}^{bars} N_i \times W_i}{\sum_{i=1}^{bars} W_i} $$ 141 | 142 | Where: 143 | - Weight is calculated as an exponential function to make the weight at the midpoint (L/2) is exactly half of the weight at the endpoint (L) 144 | - ATR (Average True Range) is used for normalization 145 | 146 | ## Configuration 147 | 148 | ### Reference Trends (for DTW-based tools) 149 | 150 | Define reference patterns in the `REFERENCE_TRENDS` dictionary: 151 | 152 | ```python 153 | REFERENCE_TRENDS = { 154 | "AVAX": [ 155 | [datetime(2023, 11, 9, 12, 0), datetime(2023, 11, 14, 15, 0), "1h", "standard"], 156 | ], 157 | "SOL": [ 158 | [datetime(2023, 9, 23, 0, 0), datetime(2023, 10, 15, 21, 0), "4h", "standard"] 159 | ], 160 | } 161 | ``` 162 | 163 | **Tips:** 164 | - Choose patterns with similar high/low bar counts for better normalization 165 | - End patterns at ideal entry points 166 | - Consider different timeframes for various market conditions 167 | 168 | ## Output Files 169 | 170 | Scripts save results with the following patterns: 171 | - **Crypto Screener**: `output//_crypto__strong_targets.txt` 172 | - **Stock Screener**: `output//_stock_strong_targets_.txt` 173 | - **Crypto Trend Screener**: `similarity_output//_similar_trend_tradingview.txt` + visualization PNG files 174 | - **Historical Trend Finder**: `past_similar_trends_report//overall_summary.txt` + detailed analysis + visualization PNG files 175 | 176 | **Visualization Features:** 177 | - Candlestick charts with volume bars 178 | - Moving average overlays (SMA-30, SMA-45, SMA-60) 179 | - DTW alignment connection lines 180 | - Normalized price scales for fair comparison 181 | - Three-panel analysis (past/pattern/future) for historical finder 182 | 183 | All TXT output files are compatible with [TradingView](https://www.tradingview.com/) watchlist import format. 
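For reference, the watchlist files are plain text with `###` section headers followed by comma-separated symbols. A sketch of what `crypto_screener.py` writes (the symbols after the `Targets` header are placeholders; the actual list is sorted by RS score):

```
###BTCETH
BINANCE:BTCUSDT.P,BINANCE:ETHUSDT
###Targets (Sort by score)
BINANCE:SOLUSDT.P,BINANCE:AVAXUSDT.P,BINANCE:LINKUSDT.P
```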
184 | 185 | ## License 186 | 187 | [MIT](https://choosealicense.com/licenses/mit/) 188 | -------------------------------------------------------------------------------- /crypto_screener.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import time 3 | import os 4 | import numpy as np 5 | from datetime import datetime 6 | import multiprocessing as mp 7 | from concurrent.futures import ProcessPoolExecutor, as_completed 8 | from src.downloader import CryptoDownloader 9 | 10 | 11 | def calc_total_bars(time_interval, days): 12 | bars_dict = { 13 | "5m": 12 * 24 * days, 14 | "15m": 4 * 24 * days, 15 | "30m": 2 * 24 * days, 16 | "1h": 24 * days, 17 | "2h": 12 * days, 18 | "4h": 6 * days, 19 | "8h": 3 * days, 20 | } 21 | return bars_dict.get(time_interval) 22 | 23 | 24 | def calculate_rs_score(crypto_data, required_bars): 25 | """ 26 | Calculate RS score for cryptocurrency 27 | 28 | Args: 29 | crypto_data: DataFrame with cryptocurrency data 30 | required_bars: Number of bars required for calculation 31 | 32 | Returns: 33 | tuple[bool, float, str]: Success flag, RS score, error message 34 | """ 35 | # Check if we have enough data 36 | if len(crypto_data) < required_bars: 37 | return False, 0, f"Insufficient data: {len(crypto_data)} < {required_bars}" 38 | 39 | # Create a copy to avoid modifying the original data 40 | data = crypto_data.copy() 41 | 42 | # Take the most recent required_bars data points 43 | data = data.tail(required_bars).reset_index(drop=True) 44 | 45 | # Calculate RS Score 46 | rs_score = 0.0 47 | total_weight = 0.0 48 | 49 | # Calculate for each data point 50 | for i in range(required_bars): 51 | # Current data point values 52 | current_close = data['close'].iloc[i] 53 | moving_average_30 = data['sma_30'].iloc[i] 54 | moving_average_45 = data['sma_45'].iloc[i] 55 | moving_average_60 = data['sma_60'].iloc[i] 56 | current_atr = data['atr'].iloc[i] 57 | 58 | # Calculate relative strength numerator 59 | numerator = ((current_close - moving_average_30) + 60 | (current_close - moving_average_45) + 61 | (current_close - moving_average_60) + 62 | (moving_average_30 - moving_average_45) + 63 | (moving_average_30 - moving_average_60) + 64 | (moving_average_45 - moving_average_60)) 65 | 66 | # Use ATR as denominator with small epsilon to avoid division by zero 67 | denominator = current_atr + 0.0000000000000000001 68 | # denominator = (moving_average_30 + moving_average_45 + moving_average_60) / 3 69 | 70 | # Calculate relative strength for this point 71 | relative_strength = numerator / denominator 72 | 73 | # Gives higher importance to newer data 74 | # weight = i 75 | k = 2 * np.log(2) / required_bars 76 | weight = np.exp(k * i) # Exponential weight where w(L/2) * 2 = w(L) 77 | 78 | 79 | # Add to weighted sum 80 | rs_score += relative_strength * weight 81 | total_weight += weight 82 | 83 | # Normalize the final score by total weight 84 | if total_weight > 0: 85 | rs_score = rs_score / total_weight 86 | else: 87 | return False, 0, "Weight calculation error" 88 | 89 | return True, rs_score, "" 90 | 91 | 92 | def process_crypto(symbol, timeframe, days): 93 | """Process a single cryptocurrency and calculate its RS score""" 94 | try: 95 | cd = CryptoDownloader() 96 | 97 | # Calculate required bars 98 | required_bars = calc_total_bars(timeframe, days) 99 | 100 | # Calculate start timestamp with some buffer (20% more time to ensure we get enough data) 101 | buffer_factor = 1.2 102 | now = int(time.time()) 103 | 104 | # Estimate interval 
seconds based on timeframe 105 | if "m" in timeframe: 106 | minutes = int(timeframe.replace("m", "")) 107 | interval_seconds = minutes * 60 108 | elif "h" in timeframe: 109 | hours = int(timeframe.replace("h", "")) 110 | interval_seconds = hours * 3600 111 | elif "d" in timeframe: 112 | days = int(timeframe.replace("d", "")) 113 | interval_seconds = days * 24 * 3600 114 | else: 115 | # Default to 1h if unknown format 116 | interval_seconds = 3600 117 | 118 | start_ts = now - int(required_bars * interval_seconds * buffer_factor) 119 | 120 | # Get crypto data 121 | success, data = cd.get_data(symbol, start_ts=start_ts, end_ts=now, timeframe=timeframe, atr=True) 122 | 123 | if not success or data.empty: 124 | error_msg = "Failed to get data or empty dataset" 125 | print(f"{symbol} -> Error: {error_msg}") 126 | return {"crypto": symbol, "status": "failed", "reason": error_msg} 127 | 128 | # Calculate RS score 129 | success, rs_score, error = calculate_rs_score(data, required_bars) 130 | if not success: 131 | print(f"{symbol} -> Error: {error}") 132 | return {"crypto": symbol, "status": "failed", "reason": error} 133 | 134 | print(f"{symbol} -> Successfully calculated RS Score: {rs_score}") 135 | return { 136 | "crypto": symbol, 137 | "status": "success", 138 | "rs_score": rs_score 139 | } 140 | 141 | except Exception as e: 142 | error_msg = str(e) 143 | print(f"{symbol} -> Error: {error_msg}") 144 | return {"crypto": symbol, "status": "failed", "reason": error_msg} 145 | 146 | 147 | if __name__ == '__main__': 148 | parser = argparse.ArgumentParser() 149 | parser.add_argument('-t', '--timeframe', type=str, help='Time frame (5m, 15m, 30m, 1h, 2h, 4h, 8h, 1d)', default="15m") 150 | parser.add_argument('-d', '--days', type=int, help='Calculation duration in days (default 3 days)', default=3) 151 | args = parser.parse_args() 152 | timeframe = args.timeframe 153 | days = args.days 154 | 155 | # Initialize crypto downloader 156 | crypto_downloader = CryptoDownloader() 157 | 158 | # Get list of all symbols 159 | all_cryptos = crypto_downloader.get_all_symbols() 160 | print(f"Total cryptos to process: {len(all_cryptos)}") 161 | 162 | # Process all cryptos using ProcessPoolExecutor 163 | num_cores = min(4, mp.cpu_count()) # Use maximum 4 cores, binance rest api has rate limit 164 | print(f"Using {num_cores} processes") 165 | with ProcessPoolExecutor(max_workers=num_cores) as executor: 166 | futures = {executor.submit(process_crypto, crypto, timeframe, days): crypto for crypto in all_cryptos} 167 | results = [] 168 | 169 | for future in as_completed(futures): 170 | crypto = futures[future] 171 | try: 172 | result = future.result() 173 | results.append(result) 174 | except Exception as e: 175 | print(f"{crypto} -> Error: {str(e)}") 176 | results.append({"crypto": crypto, "status": "failed", "reason": str(e)}) 177 | 178 | # Process results 179 | failed_targets = [] # Failed to download data or error happened 180 | target_score = {} 181 | 182 | for result in results: 183 | if result["status"] == "success": 184 | target_score[result["crypto"]] = result["rs_score"] 185 | else: 186 | failed_targets.append((result["crypto"], result["reason"])) 187 | 188 | # Sort by RS score 189 | targets = [x for x in target_score.keys()] 190 | targets.sort(key=lambda x: target_score[x], reverse=True) 191 | 192 | # Print results 193 | print(f"\nAnalysis Results:") 194 | print(f"Total cryptos processed: {len(all_cryptos)}") 195 | print(f"Failed cryptos: {len(failed_targets)}") 196 | print(f"Successfully calculated: 
{len(targets)}") 197 | 198 | print("\n=========================== Target : Score (TOP 20) ===========================") 199 | for idx, crypto in enumerate(targets[:20], 1): 200 | score = target_score[crypto] 201 | print(f"{idx}. {crypto}: {score:.6f}") 202 | print("===============================================================================") 203 | 204 | # Save results 205 | full_date_str = datetime.now().strftime("%Y-%m-%d_%H-%M") 206 | date_str = datetime.now().strftime("%Y-%m-%d") 207 | txt_content = "###BTCETH\nBINANCE:BTCUSDT.P,BINANCE:ETHUSDT\n###Targets (Sort by score)\n" 208 | 209 | # Add all targets 210 | if targets: 211 | txt_content += ",".join([f"BINANCE:{crypto}.P" for crypto in targets]) 212 | 213 | # Create output/ directory structure 214 | base_folder = "output" 215 | date_folder = os.path.join(base_folder, date_str) 216 | os.makedirs(date_folder, exist_ok=True) 217 | 218 | # Save the file with full timestamp in filename 219 | output_file = f"{full_date_str}_crypto_{timeframe}_strong_targets.txt" 220 | file_path = os.path.join(date_folder, output_file) 221 | with open(file_path, "w") as f: 222 | f.write(txt_content) 223 | 224 | # Save failed cryptos for analysis 225 | # failed_file = f"{full_date_str}_failed_cryptos_{timeframe}.txt" 226 | # failed_path = os.path.join(date_folder, failed_file) 227 | # with open(failed_path, "w") as f: 228 | # for crypto, reason in failed_targets: 229 | # f.write(f"{crypto}: {reason}\n") 230 | 231 | print(f"\nResults saved to {file_path}") 232 | # print(f"Failed cryptos saved to {failed_path}") 233 | -------------------------------------------------------------------------------- /stock_screener.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | from concurrent.futures import ProcessPoolExecutor, as_completed, TimeoutError 3 | import multiprocessing as mp 4 | import pandas as pd 5 | import time 6 | import os 7 | import numpy as np 8 | import argparse 9 | from src.downloader import StockDownloader 10 | 11 | 12 | #================= CONFIGURATIONS =================# 13 | _1D_OF_DAYS_TRACEBACK = 252 14 | _1H_OF_DAYS_TRACEBACK = 126 15 | MIN_TURNOVER = 10000000 # Minimum turnover for stock to be considered (volume * close) 16 | #==================================================# 17 | 18 | 19 | def calculate_rs_score(hourly_data: pd.DataFrame, ticker: str = "unknown") -> tuple[bool, float, str]: 20 | """ 21 | Calculate RS score from hourly data without using Z-score normalization. 22 | 23 | The RS score is calculated as a weighted sum of relative strength indicators, 24 | with newer data given higher weight. ATR is used for normalization to allow 25 | comparison across different stocks. 
26 | """ 27 | # Define fixed parameters 28 | required_bars = _1H_OF_DAYS_TRACEBACK * 8 29 | 30 | # Check if we have enough data 31 | if len(hourly_data) < required_bars: 32 | return False, 0, f"Insufficient hourly data: {len(hourly_data)} < {required_bars}" 33 | 34 | # Create a copy to avoid modifying the original data 35 | data = hourly_data.copy() 36 | 37 | # Take the most recent required_bars data points 38 | data = data.tail(required_bars).reset_index(drop=True) 39 | 40 | # Calculate RS Score 41 | rs_score = 0.0 42 | total_weight = 0.0 43 | 44 | # Calculate for each data point 45 | for i in range(required_bars): 46 | # Current data point values 47 | current_close = data['close'].iloc[i] 48 | moving_average_30 = data['sma_30'].iloc[i] 49 | moving_average_45 = data['sma_45'].iloc[i] 50 | moving_average_60 = data['sma_60'].iloc[i] 51 | current_atr = data['atr'].iloc[i] 52 | 53 | # Calculate relative strength numerator 54 | numerator = ((current_close - moving_average_30) + 55 | (current_close - moving_average_45) + 56 | (current_close - moving_average_60) + 57 | (moving_average_30 - moving_average_45) + 58 | (moving_average_30 - moving_average_60) + 59 | (moving_average_45 - moving_average_60)) 60 | 61 | # Use ATR as denominator with small epsilon to avoid division by zero 62 | denominator = current_atr + 0.001 63 | 64 | # Calculate relative strength for this point 65 | relative_strength = numerator / denominator 66 | 67 | # Gives higher importance to newer data 68 | # weight = i 69 | k = 2 * np.log(2) / required_bars 70 | weight = np.exp(k * i) # Exponential weight where w(L/2) * 2 = w(L) 71 | 72 | # Add to weighted sum 73 | rs_score += relative_strength * weight 74 | total_weight += weight 75 | 76 | # Normalize the final score by total weight 77 | if total_weight > 0: 78 | rs_score = rs_score / total_weight 79 | else: 80 | return False, 0, "Weight calculation error" 81 | 82 | return True, rs_score, "" 83 | 84 | 85 | def calculate_spy_rs_score() -> float: 86 | 87 | sd = StockDownloader() 88 | print("Processing SPY RS score calculation...") 89 | 90 | # Request more data than needed to ensure we have enough after filtering 91 | now = int(time.time()) 92 | buffer_days = int(_1H_OF_DAYS_TRACEBACK * 3) # 200% buffer for safety 93 | hourly_start_ts = now - (buffer_days * 24 * 3600) 94 | 95 | success, hourly_data = sd.get_data("SPY", hourly_start_ts, end_ts=now, timeframe="1h", atr=True) 96 | if not success or hourly_data is None: 97 | raise ValueError("Failed to get hourly data for SPY") 98 | 99 | success, rs_score, error = calculate_rs_score(hourly_data, "SPY") 100 | if not success: 101 | raise ValueError(f"Failed to calculate SPY RS score: {error}") 102 | 103 | print(f"Finished SPY -> RS Score {rs_score}") 104 | return rs_score 105 | 106 | 107 | def calc_relative_strength(ticker: str, use_template: bool) -> dict: 108 | """ 109 | Calculate relative strength and check trend template conditions for a given stock ticker. 110 | 111 | Args: 112 | ticker (str): The stock ticker symbol. 113 | use_template (bool): Flag to determine whether to apply Minervini trend template conditions. 114 | 115 | Returns: 116 | dict: A dictionary containing the stock ticker, status of the calculation ('success' or 'failed'), 117 | and additional information such as the reason for failure or the calculated RS score. 118 | 119 | The function performs the following steps: 120 | 1. Downloads daily stock data for the given ticker. 121 | 2. Checks if there is sufficient daily data. 122 | 3. 
Verifies if the stock meets the minimum turnover requirement. 123 | 4. If `use_template` is True, checks the stock against Minervini trend template conditions. 124 | 5. Downloads hourly stock data for the given ticker. 125 | 6. Calculates the relative strength (RS) score using the hourly data. 126 | 7. Returns the result with the RS score if successful, or the reason for failure. 127 | """ 128 | 129 | print(f"Processing {ticker}...") 130 | sd = StockDownloader() 131 | now = int(time.time()) 132 | 133 | # Request more data than needed for daily timeframe 134 | buffer_days = int(_1D_OF_DAYS_TRACEBACK * 2) 135 | daily_start_ts = now - (buffer_days * 24 * 3600) 136 | success, daily_data = sd.get_data(ticker, daily_start_ts, end_ts=now, timeframe="1d", dropna=False, atr=False) 137 | 138 | if not success or daily_data is None: 139 | msg = "No daily data" 140 | print(f"Finished {ticker} -> Failed: {msg}") 141 | return {"stock": ticker, "status": "failed", "reason": msg} 142 | 143 | # Take the most recent required days 144 | if len(daily_data) < _1D_OF_DAYS_TRACEBACK: 145 | msg = f"Insufficient daily data: {len(daily_data)} < {_1D_OF_DAYS_TRACEBACK}" 146 | print(f"Finished {ticker} -> Failed: {msg}") 147 | return {"stock": ticker, "status": "failed", "reason": msg} 148 | 149 | daily_data = daily_data.tail(_1D_OF_DAYS_TRACEBACK) 150 | 151 | # Check turnover 152 | last_10_days = daily_data.tail(10) 153 | average_turnover = (last_10_days['volume'] * last_10_days['close']).mean() 154 | if average_turnover < MIN_TURNOVER: 155 | msg = "Insufficient turnover" 156 | print(f"Finished {ticker} -> Failed: {msg}") 157 | return {"stock": ticker, "status": "failed", "reason": msg} 158 | 159 | # Get required values for trend template 160 | current_close = daily_data['close'].values[-1] 161 | moving_average_50 = daily_data['sma_50'].values[-1] 162 | moving_average_60 = daily_data['sma_60'].values[-1] 163 | moving_average_150 = daily_data['sma_150'].values[-1] 164 | moving_average_200 = daily_data['sma_200'].values[-1] 165 | 166 | # Calculate high/low using configured lookback period 167 | low_of_period = daily_data["close"].min() 168 | high_of_period = daily_data["close"].max() 169 | 170 | # Check Minervini trend template conditions 171 | if use_template: 172 | conditions = [ 173 | (current_close > moving_average_150 and current_close > moving_average_200), # Condition 1 174 | moving_average_150 > moving_average_200, # Condition 2 175 | True, # Condition 3 (assumed true as per original) 176 | moving_average_50 > moving_average_150 > moving_average_200, # Condition 4 177 | True, # Condition 5 (assumed true as per original) 178 | current_close > low_of_period * 1.3, # Condition 6 179 | current_close > high_of_period * 0.75, # Condition 7 180 | True, # Condition 8 (assumed true as per original) 181 | current_close >= 10 # Condition 9 182 | ] 183 | 184 | if not all(conditions): 185 | failed_conditions = [i + 1 for i, cond in enumerate(conditions) if not cond] 186 | msg = f"Failed conditions: {failed_conditions}" 187 | print(f"Finished {ticker} -> Failed: {msg}") 188 | return {"stock": ticker, "status": "failed", "reason": msg} 189 | 190 | # Get hourly data with buffer for RS score calculation 191 | buffer_days = int(_1H_OF_DAYS_TRACEBACK * 3) 192 | hourly_start_ts = now - (buffer_days * 24 * 3600) 193 | success, hourly_data = sd.get_data(ticker, hourly_start_ts, end_ts=now, timeframe="1h", atr=True) 194 | 195 | if not success or hourly_data is None: 196 | msg = "No hourly data" 197 | print(f"Finished {ticker} -> 
Failed: {msg}") 198 | return {"stock": ticker, "status": "failed", "reason": msg} 199 | 200 | success, rs_score, error = calculate_rs_score(hourly_data, ticker) 201 | if not success: 202 | print(f"Finished {ticker} -> Failed: {error}") 203 | return {"stock": ticker, "status": "failed", "reason": error} 204 | 205 | print(f"Finished {ticker} -> RS Score {rs_score}") 206 | return { 207 | "stock": ticker, 208 | "status": "success", 209 | "rs_score": rs_score 210 | } 211 | 212 | 213 | if __name__ == '__main__': 214 | # Parse command line arguments 215 | parser = argparse.ArgumentParser(description='Stock Trend Analysis') 216 | parser.add_argument('-a', '--all', action='store_true', help='Include all strong targets in output') 217 | parser.add_argument('-g', action='store_true', help='Ignore Minerivini conditions and calculate RS score only') 218 | args = parser.parse_args() 219 | 220 | # Initialize stock downloader 221 | sd = StockDownloader() 222 | 223 | # Get list of all tickers 224 | all_tickers = sd.get_all_tickers() 225 | print(f"Total tickers to process: {len(all_tickers)}") 226 | 227 | # Calculate SPY's RS score first 228 | try: 229 | spy_rs_score = calculate_spy_rs_score() 230 | print(f"SPY RS Score: {spy_rs_score}") 231 | except Exception as e: 232 | print(f"Failed to calculate SPY RS score: {e}") 233 | exit(1) 234 | 235 | # Process all tickers using ProcessPoolExecutor 236 | num_cores = min(36, mp.cpu_count()) 237 | print(f"Using {num_cores} processes") 238 | 239 | # Process results 240 | strong_targets = [] 241 | target_rs_score = {} 242 | failed_tickers = [] 243 | use_template = not args.g 244 | 245 | with ProcessPoolExecutor(max_workers=num_cores) as executor: 246 | futures = {executor.submit(calc_relative_strength, ticker, use_template): ticker for ticker in all_tickers} 247 | 248 | for future in as_completed(futures): 249 | ticker = futures[future] 250 | try: 251 | result = future.result(timeout=10) 252 | if result["status"] == "success": 253 | if result["rs_score"] >= spy_rs_score: 254 | strong_targets.append(ticker) 255 | target_rs_score[ticker] = result["rs_score"] 256 | else: 257 | failed_tickers.append((ticker, result["reason"])) 258 | except TimeoutError: 259 | print(f"{ticker} took too long to process") 260 | except Exception as e: 261 | failed_tickers.append((ticker, str(e))) 262 | 263 | # Sort by RS score 264 | strong_targets.sort(key=lambda x: target_rs_score[x], reverse=True) 265 | 266 | # Print results 267 | total_analyzed = len(all_tickers) - len(failed_tickers) 268 | success_rate = len(strong_targets) / total_analyzed * 100 if total_analyzed > 0 else 0 269 | 270 | print(f"\nAnalysis Results:") 271 | print(f"Total tickers processed: {len(all_tickers)}") 272 | print(f"Failed tickers: {len(failed_tickers)}") 273 | print(f"Found {len(strong_targets)} stocks that meet requirements and are stronger than SPY") 274 | print(f"Success rate: {success_rate:.2f}%") 275 | 276 | print(f"\nStrong targets: {', '.join(strong_targets[:50])}") # Show top 50 only in console 277 | 278 | print("\n====== Top 50 Targets by RS Score ======") 279 | for ticker in strong_targets[:50]: 280 | score = target_rs_score[ticker] 281 | print(f"{ticker}: {score}") 282 | print("=======================================") 283 | 284 | # Save results 285 | full_date_str = datetime.now().strftime("%Y-%m-%d_%H-%M") 286 | date_str = datetime.now().strftime("%Y-%m-%d") 287 | txt_content = "###INDEX\nSPY,QQQ,DJI\n###TARGETS\n" 288 | 289 | # Use all strong targets or just top 980 based on --all flag 290 | 
output_targets = strong_targets if args.all else strong_targets[:980] 291 | txt_content += ",".join(output_targets) 292 | 293 | # Create output/ directory structure 294 | base_folder = "output" 295 | date_folder = os.path.join(base_folder, date_str) 296 | os.makedirs(date_folder, exist_ok=True) 297 | 298 | # Create output files with full timestamp in filename 299 | without_conditions = "_no_conditions" if args.g else "" 300 | for_tv = "all" if args.all else "top980" 301 | output_file = f"{full_date_str}_stock_{for_tv}{without_conditions}_strong_targets.txt" 302 | file_path = os.path.join(date_folder, output_file) 303 | with open(file_path, "w") as f: 304 | f.write(txt_content) 305 | 306 | # Save failed tickers for analysis 307 | # failed_file = f"{full_date_str}_failed_tickers.txt" 308 | # failed_path = os.path.join(date_folder, failed_file) 309 | # with open(failed_path, "w") as f: 310 | # for ticker, reason in failed_tickers: 311 | # f.write(f"{ticker}: {reason}\n") 312 | 313 | print(f"\nResults saved to {file_path}") 314 | # print(f"Failed tickers saved to {failed_path}") 315 | print(f"Included {'all' if args.all else 'top 980'} strong targets in output file") -------------------------------------------------------------------------------- /src/downloader.py: -------------------------------------------------------------------------------- 1 | import json 2 | import re 3 | import os 4 | import time 5 | import pandas as pd 6 | import numpy as np 7 | from datetime import datetime 8 | from pytz import timezone 9 | from stocksymbol import StockSymbol 10 | from polygon import RESTClient 11 | from urllib3.util.retry import Retry 12 | from binance import Client 13 | from pathlib import Path 14 | 15 | 16 | # Configurable parameters 17 | STOCK_SMA = [20, 30, 45, 50, 60, 150, 200] 18 | CRYPTO_SMA = [30, 45, 60] 19 | ATR_PERIOD = 60 20 | 21 | 22 | def calculate_atr(df, period=ATR_PERIOD): 23 | """ 24 | Calculate Average True Range (ATR) for the given dataframe 25 | 26 | Args: 27 | df: DataFrame containing 'high', 'low', 'close' columns 28 | period: Period for ATR calculation (default: ATR_PERIOD) 29 | 30 | Returns: 31 | Series containing ATR values 32 | """ 33 | high = df['high'] 34 | low = df['low'] 35 | close = df['close'] 36 | 37 | # Calculate True Range 38 | tr1 = high - low 39 | tr2 = abs(high - close.shift()) 40 | tr3 = abs(low - close.shift()) 41 | 42 | # Get the maximum of the three price ranges 43 | tr = pd.concat([tr1, tr2, tr3], axis=1).max(axis=1) 44 | 45 | # Calculate ATR as the simple moving average of True Range 46 | atr = tr.rolling(window=period).mean() 47 | 48 | return atr 49 | 50 | 51 | def parse_time_string(time_string): 52 | pattern_with_number = r"(\d+)([mhdMHD])$" 53 | pattern_without_number = r"([dD])$" 54 | match_with_number = re.match(pattern_with_number, time_string) 55 | match_without_number = re.match(pattern_without_number, time_string) 56 | 57 | if match_with_number: 58 | number = int(match_with_number.group(1)) 59 | unit = match_with_number.group(2) 60 | elif match_without_number: 61 | number = 1 62 | unit = match_without_number.group(1) 63 | else: 64 | raise ValueError("Invalid time format. 
Only formats like '15m', '4h', 'd' are allowed.") 65 | 66 | unit = unit.lower() 67 | unit_match = { 68 | "m": "minute", 69 | "h": "hour", 70 | "d": "day" 71 | } 72 | return number, unit_match[unit] 73 | 74 | 75 | class StockDownloader: 76 | def __init__(self, api_file: str = "api_keys.json"): 77 | with open(api_file) as f: 78 | self.api_keys = json.load(f) 79 | 80 | retry_strategy = Retry( 81 | total=3, 82 | backoff_factor=0.5, 83 | status_forcelist=[413, 429, 499, 500, 502, 503, 504], 84 | allowed_methods=["HEAD", "GET", "PUT", "DELETE", "OPTIONS", "TRACE", "POST"], 85 | raise_on_status=False, 86 | respect_retry_after_header=True 87 | ) 88 | 89 | self.client = RESTClient( 90 | api_key=self.api_keys["polygon"], 91 | num_pools=100, 92 | connect_timeout=1.0, 93 | read_timeout=1.0, 94 | retries=10 95 | ) 96 | 97 | def _validate_data_quality(self, df: pd.DataFrame) -> bool: 98 | """ 99 | Validate data quality 100 | - Check if latest data is within a week 101 | - Check for stale prices (same closing price for 10+ consecutive periods) 102 | """ 103 | if df.empty: 104 | return False 105 | 106 | # Check data freshness 107 | latest_ts = df['timestamp'].max() 108 | week_ago = time.time() - (7 * 24 * 3600) 109 | if latest_ts < week_ago: 110 | return False 111 | 112 | # Check for stale prices 113 | consecutive_same_price = df['close'].rolling(window=10).apply( 114 | lambda x: len(set(x)) == 1 115 | ) 116 | if consecutive_same_price.any(): 117 | return False 118 | 119 | return True 120 | 121 | def get_data(self, ticker: str, start_ts: int, end_ts: int = None, timeframe: str = "1d", dropna=True, atr=True, validate=True) -> tuple[bool, pd.DataFrame]: 122 | """ 123 | Get stock data with SMA calculation and data quality validation 124 | Args: 125 | ticker: Stock symbol 126 | start_ts: Start timestamp 127 | end_ts: End timestamp (default: current time) 128 | timeframe: Time interval ("1d" or "1h") 129 | dropna: Whether to drop NA values 130 | atr: Whether to calculate ATR (default: True) 131 | Returns: 132 | (success, DataFrame) 133 | """ 134 | # Calculate extended start for SMA calculation 135 | max_sma = max(STOCK_SMA) 136 | fc = 1.3 if timeframe == "1d" else 0.6 137 | extension = np.int64(max_sma * 24 * 3600 * fc) 138 | extended_start = np.int64(start_ts - extension) 139 | 140 | # Get current time if end_ts not provided 141 | if end_ts is None: 142 | end_ts = np.int64(time.time()) 143 | 144 | # Parse timeframe 145 | multiplier, timespan = parse_time_string(timeframe) 146 | 147 | # Request data from Polygon 148 | aggs = self.client.list_aggs( 149 | ticker, 150 | multiplier, 151 | timespan, 152 | np.int64(extended_start * 1000), 153 | np.int64(end_ts * 1000), 154 | limit=10000 155 | ) 156 | 157 | if not aggs: 158 | return False, pd.DataFrame() 159 | 160 | # Convert to DataFrame with timestamp 161 | df = pd.DataFrame([{ 162 | 'timestamp': np.int64(agg.timestamp // 1000), 163 | 'open': np.float64(agg.open), 164 | 'close': np.float64(agg.close), 165 | 'high': np.float64(agg.high), 166 | 'low': np.float64(agg.low), 167 | 'volume': np.float64(agg.volume) 168 | } for agg in aggs]) 169 | 170 | if df.empty: 171 | return False, df 172 | 173 | # Sort by timestamp 174 | df = df.sort_values('timestamp') 175 | 176 | # Filter market hours (9:00 AM - 4:00 PM NY time) 177 | if timespan == "hour" or timespan == "minute": 178 | # Create temporary datetime column in NY timezone for filtering 179 | ny_tz = timezone('America/New_York') 180 | temp_dt = pd.to_datetime(df['timestamp'], unit='s', utc=True).dt.tz_convert(ny_tz) 181 
| 182 | # Create filter based on NY market hours 183 | if timespan == "hour": 184 | market_hours_filter = temp_dt.dt.time.between( 185 | pd.to_datetime('09:00').time(), 186 | pd.to_datetime('16:00').time(), 187 | inclusive='left' 188 | ) 189 | else: # minute timeframe 190 | market_hours_filter = temp_dt.dt.time.between( 191 | pd.to_datetime('09:30').time(), 192 | pd.to_datetime('16:00').time(), 193 | inclusive='left' 194 | ) 195 | 196 | # Apply filter and drop temporary column 197 | df = df[market_hours_filter] 198 | 199 | # Validate data quality 200 | if validate and not self._validate_data_quality(df): 201 | return False, pd.DataFrame() 202 | 203 | # Calculate SMAs 204 | for period in STOCK_SMA: 205 | df[f'sma_{period}'] = df['close'].rolling(window=period).mean().astype(np.float64) 206 | 207 | # Calculate ATR if requested 208 | if atr: 209 | df['atr'] = calculate_atr(df, period=ATR_PERIOD).astype(np.float64) 210 | 211 | # Drop rows with NaN values 212 | if dropna: 213 | df = df.dropna() 214 | 215 | # Filter to requested time range and reset index 216 | df = df[(df['timestamp'] >= start_ts) & (df['timestamp'] <= end_ts)] 217 | df = df.reset_index(drop=True) 218 | 219 | return True, df 220 | 221 | def get_all_tickers(self): 222 | """Get all stock symbols from both StockSymbol and Polygon""" 223 | # Get symbols from StockSymbol 224 | ss = StockSymbol(self.api_keys["stocksymbol"]) 225 | stock_symbol_list = [x for x in ss.get_symbol_list(market="US", symbols_only=True) 226 | if "." not in x] 227 | 228 | # Get symbols from Polygon 229 | polygon_stocks = self.client.list_tickers( 230 | market="stocks", 231 | # type="CS", 232 | active=True, 233 | limit=1000 234 | ) 235 | polygon_common_stocks = [ticker.ticker for ticker in polygon_stocks] 236 | 237 | # Merge and return unique symbols 238 | all_symbols = sorted(set(stock_symbol_list).union(set(polygon_common_stocks))) 239 | print(f"Found {len(all_symbols)} unique stock symbols") 240 | return all_symbols 241 | 242 | 243 | class CryptoDownloader: 244 | def __init__(self): 245 | self.binance_client = Client(requests_params={"timeout": 300}) 246 | 247 | def get_all_symbols(self): 248 | """ 249 | Get all USDT pairs in binance 250 | """ 251 | binance_response = self.binance_client.futures_exchange_info() 252 | binance_symbols = set() 253 | for item in binance_response["symbols"]: 254 | symbol_name = item["pair"] 255 | if symbol_name[-4:] == "USDT": 256 | binance_symbols.add(symbol_name) 257 | return sorted(list(binance_symbols)) 258 | 259 | def _validate_data_quality(self, df: pd.DataFrame) -> bool: 260 | """ 261 | Validate crypto data quality 262 | - Check if latest data is within a week 263 | - Check for stale prices (same closing price for 10+ consecutive periods) 264 | """ 265 | if df.empty: 266 | return False 267 | 268 | # Check data freshness 269 | latest_ts = df['timestamp'].max() 270 | week_ago = time.time() - (7 * 24 * 3600) 271 | if latest_ts < week_ago: 272 | return False 273 | 274 | # Check for stale prices 275 | consecutive_same_price = df['close'].rolling(window=10).apply( 276 | lambda x: len(set(x)) == 1 277 | ) 278 | if consecutive_same_price.any(): 279 | return False 280 | 281 | return True 282 | 283 | def get_data(self, crypto, start_ts=None, end_ts=None, timeframe="4h", dropna=True, atr=True, validate=True) -> tuple[bool, pd.DataFrame]: 284 | """ 285 | Get cryptocurrency data with SMA calculation and data quality validation 286 | Args: 287 | crypto: Cryptocurrency symbol 288 | start_ts: Start timestamp (default: None, fetches latest 
1500 datapoints) 289 | end_ts: End timestamp (default: current time) 290 | timeframe: Time interval (e.g., "5m", "15m", "1h", "4h") 291 | dropna: Whether to drop NA values (default: True) 292 | atr: Whether to calculate ATR (default: True) 293 | Returns: 294 | (success, DataFrame) 295 | """ 296 | try: 297 | # Default end_ts to current time if not provided 298 | if end_ts is None: 299 | end_ts = np.int64(time.time()) 300 | 301 | # Convert to milliseconds for Binance API 302 | end_ts_ms = np.int64(end_ts * 1000) 303 | 304 | if start_ts is None: 305 | # Fetch only the latest 1500 datapoints 306 | response = self.binance_client.futures_klines( 307 | symbol=crypto, 308 | interval=timeframe, 309 | limit=1500 310 | ) 311 | else: 312 | # Calculate extended start for SMA calculation 313 | max_sma = max(CRYPTO_SMA) 314 | 315 | # Calculate number of time intervals in max_sma 316 | num_intervals, unit = parse_time_string(timeframe) 317 | if unit == "minute": 318 | interval_seconds = np.int64(num_intervals * 60) 319 | elif unit == "hour": 320 | interval_seconds = np.int64(num_intervals * 3600) 321 | else: # day 322 | interval_seconds = np.int64(num_intervals * 86400) 323 | 324 | # Calculate extension in milliseconds (number of bars needed for max SMA) 325 | extension_ms = np.int64(max_sma * interval_seconds * 1000 * 1.2) # 20% buffer 326 | 327 | # Extended start timestamp with buffer for SMA calculation 328 | extended_start_ts_ms = np.int64(start_ts * 1000 - extension_ms) 329 | 330 | # Fetch historical data from the extended start date 331 | all_data = [] 332 | current_timestamp = extended_start_ts_ms 333 | 334 | while current_timestamp < end_ts_ms: 335 | response = self.binance_client.futures_klines( 336 | symbol=crypto, 337 | interval=timeframe, 338 | startTime=np.int64(current_timestamp), 339 | endTime=np.int64(end_ts_ms), 340 | limit=1500 341 | ) 342 | 343 | if not response: 344 | break 345 | 346 | all_data.extend(response) 347 | 348 | # Update current timestamp to the last received data point + 1 349 | if response: 350 | current_timestamp = np.int64(response[-1][6]) + 1 351 | else: 352 | break 353 | 354 | if not all_data: 355 | print(f"{crypto} -> No data retrieved") 356 | return False, pd.DataFrame() 357 | 358 | response = all_data 359 | 360 | # Convert to DataFrame with timestamp as primary field 361 | df = pd.DataFrame(response, 362 | columns=["Datetime", "Open Price", "High Price", "Low Price", "Close Price", 363 | "Volume", "Close Time", "Quote Volume", "Number of Trades", 364 | "Taker buy base asset volume", "Taker buy quote asset volume", "Ignore"]) 365 | 366 | # Check if DataFrame is empty 367 | if df.empty: 368 | print(f"{crypto} -> Empty DataFrame after initial conversion") 369 | return False, pd.DataFrame() 370 | 371 | # Convert datetime to timestamp (in seconds) using int64 372 | df['timestamp'] = df['Datetime'].values.astype(np.int64) // 1000 373 | 374 | # Rename columns to match stock df format (lowercase) using float64 375 | df['open'] = df['Open Price'].astype(np.float64) 376 | df['high'] = df['High Price'].astype(np.float64) 377 | df['low'] = df['Low Price'].astype(np.float64) 378 | df['close'] = df['Close Price'].astype(np.float64) 379 | df['volume'] = df['Volume'].astype(np.float64) 380 | 381 | # Drop duplicate timestamps 382 | df = df.drop_duplicates(subset=['timestamp'], keep='first') 383 | 384 | # Sort by timestamp 385 | df = df.sort_values('timestamp') 386 | 387 | # Validate data quality 388 | if validate and not self._validate_data_quality(df): 389 | print(f"{crypto} -> 
Failed data quality validation") 390 | return False, pd.DataFrame() 391 | 392 | # Calculate SMAs 393 | for duration in CRYPTO_SMA: 394 | df[f"sma_{duration}"] = df['close'].rolling(window=duration).mean().astype(np.float64) 395 | 396 | # Calculate ATR if requested 397 | if atr: 398 | df['atr'] = calculate_atr(df, period=ATR_PERIOD).astype(np.float64) 399 | 400 | # Drop NaN values if requested 401 | if dropna: 402 | df = df.dropna() 403 | 404 | # Filter to requested time range (only after calculating SMAs) 405 | if start_ts is not None: 406 | df = df[(df['timestamp'] >= start_ts) & (df['timestamp'] <= end_ts)] 407 | 408 | # Final check if we have any data left 409 | if df.empty: 410 | print(f"{crypto} -> No data left after filtering") 411 | return False, pd.DataFrame() 412 | 413 | # Reset index 414 | df = df.reset_index(drop=True) 415 | 416 | # Keep only necessary columns 417 | columns_to_keep = ['timestamp', 'open', 'high', 'low', 'close', 'volume'] 418 | 419 | # Add SMA columns 420 | columns_to_keep += [f'sma_{period}' for period in CRYPTO_SMA] 421 | 422 | # Add ATR column if calculated 423 | if atr: 424 | columns_to_keep.append('atr') 425 | 426 | df = df[columns_to_keep] 427 | 428 | print(f"{crypto} -> Get data from binance successfully ({len(df)} rows from {datetime.fromtimestamp(df['timestamp'].iloc[0])} to {datetime.fromtimestamp(df['timestamp'].iloc[-1])})") 429 | return True, df 430 | 431 | except Exception as e: 432 | print(f"{crypto} -> Error: {e}") 433 | return False, pd.DataFrame() 434 | -------------------------------------------------------------------------------- /crypto_trend_screener.py: -------------------------------------------------------------------------------- 1 | """ 2 | Trend Similarity Analyzer 3 | ----------------------------------------------- 4 | This script identifies assets currently forming price patterns similar to historical successful 5 | setups using Dynamic Time Warping (DTW) and Shape Dynamic Time Warping (ShapeDTW). By analyzing price 6 | and moving average relationships across multiple timeframes, it discovers trading opportunities 7 | where the user may consider entering a position. 
8 | 9 | Key Features: 10 | - Customizable reference trends with visual alignment reports 11 | - TradingView-compatible output for efficient tracking 12 | - Score system for ranking similar candidate trends (higher is better) 13 | - Enhanced shape detection using ShapeDTW for SMA differences 14 | 15 | Scoring Methodology: 16 | - Price score = 1 / (1 + price_distance) 17 | - SMA difference score = 1 / (1 + diff_distance * BALANCE_PD_RATIO) 18 | - Overall score = (price_score * PRICE_WEIGHT + diff_score * DIFF_WEIGHT) 19 | 20 | Usage: 21 | python crypto_trend_screener.py [options] 22 | 23 | Configuration: 24 | - Set your timezone (TIMEZONE), reference patterns (REFERENCE_TRENDS), and analysis timeframes 25 | - Define reference trends with format [start_datetime, end_datetime, timeframe, label] 26 | - End reference patterns at ideal entry points and balance high/low point distribution 27 | 28 | Notes: 29 | - Stock asset type is not yet supported 30 | - Results saved to similarity_output/[timestamp] directory 31 | - TradingView watchlists only display the first occurrence of duplicate symbols 32 | """ 33 | 34 | import os 35 | import time 36 | import argparse 37 | import pandas as pd 38 | import numpy as np 39 | from datetime import datetime 40 | from multiprocessing import Pool, cpu_count 41 | import matplotlib.pyplot as plt 42 | import matplotlib.dates as mdates 43 | from matplotlib.patches import ConnectionPatch 44 | from src.downloader import StockDownloader, CryptoDownloader 45 | from src.common import ( 46 | TrendAnalysisConfig, 47 | DataNormalizer, 48 | DTWCalculator, 49 | FileManager, 50 | ReferenceDataManager, 51 | BaseDataProcessor, 52 | calculate_timeframe_seconds, 53 | format_dt_with_tz, 54 | parse_target_symbols, 55 | create_output_directory, 56 | plot_candlesticks_with_volume 57 | ) 58 | 59 | 60 | # =========== Reference Trend Configuration =========== 61 | # Define reference trends in datetime format (will be converted to timestamps) 62 | # Format: [start_datetime, end_datetime, timeframe, label] 63 | REFERENCE_TRENDS = { 64 | "AVAX": [ 65 | [datetime(2023, 11, 9, 12, 0), datetime(2023, 11, 14, 15, 0), "1h", "standard"], 66 | ], 67 | "MKR" :[ 68 | [datetime(2023, 6, 26, 13, 0), datetime(2023, 7, 17, 12, 0), "4h", "standard"], 69 | ], 70 | "CRV": [ 71 | [datetime(2024, 11, 4, 0, 0), datetime(2024, 11, 21, 0, 0), "4h", "uptrend"], 72 | [datetime(2024, 11, 4, 0, 0), datetime(2024, 11, 28, 0, 0), "4h", "uptrend_2"], 73 | 74 | ], 75 | "GMT": [ 76 | [datetime(2022, 3, 26, 9, 0), datetime(2022, 4, 14, 21, 0), "4h", "uptrend"] 77 | ], 78 | "SOL": [ 79 | [datetime(2023, 9, 23, 0, 0), datetime(2023, 10, 15, 21, 0), "4h", "standard"] 80 | ], 81 | "LQTY": [ 82 | [datetime(2025, 5, 7, 5, 0), datetime(2025, 5, 9, 21, 0), "30m", "standard"] 83 | ], 84 | "MOODENG":[ 85 | [datetime(2025, 5, 8, 0, 0), datetime(2025, 5, 11, 1, 0), "1h", "standard"] 86 | ], 87 | } 88 | 89 | # ================ Configuration ================ 90 | # Timezone for datetime conversion 91 | TIMEZONE = "America/Los_Angeles" 92 | 93 | # Output directory path 94 | 95 | OUTPUT_DIR = "similarity_output" 96 | 97 | # Timeframes to analyze 98 | TIMEFRAMES_TO_ANALYZE = ["15m", "30m", "1h", "2h", "4h"] 99 | 100 | # Top K symbols to record per reference trend 101 | TOP_K = 10 102 | 103 | # DTW window constraint 104 | DTW_WINDOW_RATIO = 0.2 # Range: 0.0 to 1.0 105 | DTW_WINDOW_RATIO_FOR_DIFF = 0.1 # Range: 0.0 to 1.0 106 | 107 | # Maximum distance limit for DTW matching points 108 | DTW_MAX_POINT_DISTANCE = 0.66 # Range: 0.0 to 2.0 109 
| DTW_MAX_POINT_DISTANCE_FOR_DIFF = 0.5 # Range: 0.0 to 2.0 110 | 111 | # Weights for price and difference features in Shape DTW calculation 112 | SHAPEDTW_BALANCE_PD_RATIO = 4 # Ratio of price distance to SMA difference distance 113 | PRICE_WEIGHT = 0.4 114 | DIFF_WEIGHT = 0.6 115 | 116 | # Shape descriptor parameters for SMA differences 117 | SLOPE_WINDOW_SIZE = 5 # Window size for slope descriptor 118 | PAA_WINDOW_SIZE = 5 # Window size for PAA descriptor 119 | 120 | # Constants for similarity calculation 121 | SMA_PERIODS = [30, 45, 60] 122 | DTW_WINDOW_FACTORS = [0.9, 0.95, 1.0, 1.05, 1.1] # Window size factors for DTW 123 | MIN_QUERY_LENGTH = 60 # Minimum number of data points required for query trend 124 | 125 | # BINANCE API request interval parameter 126 | API_SLEEP_SECONDS = 0.5 # Sleep time between requests (seconds) 127 | 128 | # Request buffer ratio 129 | REQUEST_TIME_BUFFER_RATIO = 1.2 130 | 131 | 132 | # ================ Data Processing Classes ================ 133 | 134 | class DataProcessor(BaseDataProcessor): 135 | """Data processor for both stocks and cryptocurrencies""" 136 | 137 | def __init__(self, asset_type: str, config: TrendAnalysisConfig = None): 138 | """Initialize appropriate downloader based on asset type""" 139 | super().__init__(asset_type, config.sma_periods if config else None) 140 | 141 | if asset_type == "crypto": 142 | self.downloader = CryptoDownloader() 143 | else: 144 | self.downloader = StockDownloader(save_dir=".", api_file="api_keys.json") 145 | 146 | self.config = config or TrendAnalysisConfig() 147 | 148 | def get_data(self, symbol: str, timeframe: str, start_ts: int, end_ts: int, 149 | is_crypto: bool = True, include_buffer: bool = True, 150 | is_reference: bool = False) -> pd.DataFrame: 151 | """Get data with buffer period for SMA calculation""" 152 | if include_buffer: 153 | # Calculate buffer period for SMA calculation 154 | interval = end_ts - start_ts 155 | buffer_start_ts = start_ts - interval 156 | else: 157 | buffer_start_ts = start_ts 158 | 159 | if is_crypto is None: 160 | is_crypto = (self.asset_type == "crypto") 161 | 162 | # Get data using the appropriate downloader 163 | if is_crypto: 164 | # For crypto, add USDT if not already there 165 | if not symbol.endswith("USDT"): 166 | symbol_full = f"{symbol}USDT" 167 | else: 168 | symbol_full = symbol 169 | 170 | # Set validate=False for reference trends, otherwise use default (True) 171 | success, df = self.downloader.get_data( 172 | symbol_full, 173 | buffer_start_ts, 174 | end_ts, 175 | validate=not is_reference, # Disable validation for reference trends 176 | timeframe=timeframe 177 | ) 178 | else: # stock 179 | success, df = self.downloader.get_data( 180 | symbol, 181 | buffer_start_ts, 182 | end_ts, 183 | validate=not is_reference, # Disable validation for reference trends 184 | timeframe=timeframe 185 | ) 186 | 187 | if not success or df is None or df.empty: 188 | print(f"Failed to get data for {symbol} ({timeframe})") 189 | return pd.DataFrame() 190 | 191 | # Filter to requested time range 192 | start_time = pd.Timestamp.fromtimestamp(start_ts) 193 | end_time = pd.Timestamp.fromtimestamp(end_ts) 194 | 195 | # Use the processor from common to prepare the dataframe 196 | df = self.processor.prepare_dataframe(df) 197 | 198 | # Filter to requested time range after preparation 199 | df = df[(df.index >= start_time) & (df.index <= end_time)] 200 | 201 | return df 202 | 203 | 204 | # ================ DTW Similarity Calculator ================ 205 | 206 | class DTWSimilarityCalculator: 
207 | """Calculate similarity using DTW and ShapeDTW algorithms""" 208 | 209 | def __init__(self, config: TrendAnalysisConfig): 210 | """Initialize DTW calculator with configuration""" 211 | self.config = config 212 | self.dtw_calc = DTWCalculator(config) 213 | 214 | def find_best_similarity_window(self, query_df: pd.DataFrame, target_df: pd.DataFrame) -> dict: 215 | """Find best similarity window based on price and difference features""" 216 | query_len = len(query_df) 217 | 218 | # Check if target sequence is long enough 219 | if len(target_df) < query_len * min(self.config.window_scale_factors): 220 | return { 221 | "similarity": 0.0, 222 | "price_distance": float('inf'), 223 | "diff_distance": float('inf'), 224 | "price_path": None, 225 | "diff_path": None, 226 | "window_data": None, 227 | "window_info": None 228 | } 229 | 230 | best_similarity = -1 231 | best_price_distance = float('inf') 232 | best_diff_distance = float('inf') 233 | best_price_path = None 234 | best_diff_path = None 235 | best_window_data = None 236 | best_window_info = None 237 | 238 | # Pre-process reference sequence features 239 | query_price_norm, query_diff_norm = self.dtw_calc.normalize_features(query_df) 240 | 241 | # Define shape descriptors 242 | price_descriptor, diff_descriptor = self.dtw_calc.create_shape_descriptors() 243 | 244 | # Try different window sizes, but fix right boundary at the last time point of target sequence 245 | for factor in self.config.window_scale_factors: 246 | window_size = int(query_len * factor) 247 | 248 | # Skip if window size exceeds target sequence length 249 | if window_size > len(target_df): 250 | continue 251 | 252 | # Calculate window start index, fixing right boundary at the latest data point 253 | start_idx = len(target_df) - window_size 254 | 255 | # Extract window data 256 | window = target_df.iloc[start_idx:len(target_df)] 257 | 258 | # Confirm window length is correct 259 | if len(window) != window_size: 260 | print(f"Warning: Window size mismatch. 
Expected {window_size}, got {len(window)}") 261 | continue 262 | 263 | # Normalize target window features - based on current window data 264 | window_price_norm, window_diff_norm = self.dtw_calc.normalize_features(window) 265 | 266 | # Calculate DTW for price features (using dtaidistance) - for initial screening 267 | _, dtw_price_distance, _ = self.dtw_calc.calculate_dtw_similarity( 268 | query_price_norm, window_price_norm, self.config.dtw_window_ratio, self.config.dtw_max_point_distance 269 | ) 270 | 271 | # If no valid path found for price features (distance is inf), continue to next factor 272 | if np.isinf(dtw_price_distance): 273 | print(f" Factor {factor}: No valid price path found due to max_step constraint") 274 | continue 275 | 276 | # Calculate DTW for SMA difference features (using dtaidistance) - for initial screening 277 | _, dtw_diff_distance, _ = self.dtw_calc.calculate_dtw_similarity( 278 | query_diff_norm, window_diff_norm, self.config.dtw_window_ratio_diff, self.config.dtw_max_point_distance_diff 279 | ) 280 | 281 | # If no valid path found for SMA difference features (distance is inf), continue to next factor 282 | if np.isinf(dtw_diff_distance): 283 | print(f" Factor {factor}: No valid SMA diff path found due to max_step constraint") 284 | continue 285 | 286 | # Use dynamic subsequence width based on window factor 287 | subsequence_width = max(2, min(5, int(factor * 3))) 288 | 289 | # Calculate ShapeDTW for price features 290 | price_shape_dist, price_shape_path = self.dtw_calc.calculate_shapedtw( 291 | query_price_norm, window_price_norm, price_descriptor, self.config.dtw_window_ratio, subsequence_width 292 | ) 293 | 294 | # If no valid path found for price features, continue to next factor 295 | if np.isinf(price_shape_dist): 296 | print(f" Factor {factor}: No valid shape path found for price features") 297 | continue 298 | 299 | # Calculate ShapeDTW for difference features 300 | diff_shape_dist, diff_shape_path = self.dtw_calc.calculate_shapedtw( 301 | query_diff_norm, window_diff_norm, diff_descriptor, self.config.dtw_window_ratio_diff, subsequence_width 302 | ) 303 | 304 | # If no valid path found for difference features, continue to next factor 305 | if np.isinf(diff_shape_dist): 306 | print(f" Factor {factor}: No valid shape path found for diff features") 307 | continue 308 | 309 | # Calculate overall similarity using arithmetic mean 310 | price_score = 1 / (1 + price_shape_dist) 311 | sma_diff_score = 1 / (1 + diff_shape_dist * self.config.shapedtw_balance_pd_ratio) 312 | similarity = (price_score * self.config.price_weight) + (sma_diff_score * self.config.diff_weight) 313 | 314 | # If similarity is higher, update best results 315 | if similarity > best_similarity: 316 | best_similarity = similarity 317 | best_price_distance = price_shape_dist 318 | best_diff_distance = diff_shape_dist 319 | best_price_path = price_shape_path 320 | best_diff_path = diff_shape_path 321 | best_window_data = window 322 | best_window_info = (factor, start_idx, len(target_df)) 323 | 324 | print(f" Factor {factor}: similarity={similarity:.4f}, price_shape_dist={price_shape_dist:.4f}, " 325 | f"diff_shape_dist={diff_shape_dist:.4f}, window={start_idx}:{len(target_df)}, " 326 | f"period={window.index[0]} to {window.index[-1]}") 327 | 328 | if best_window_data is not None: 329 | print(f" Best window: {best_window_data.index[0]} to {best_window_data.index[-1]}") 330 | 331 | # Return all relevant information as a dictionary 332 | return { 333 | "similarity": best_similarity, 334 | 
"price_distance": best_price_distance, 335 | "diff_distance": best_diff_distance, 336 | "price_path": best_price_path, 337 | "diff_path": best_diff_path, 338 | "window_data": best_window_data, 339 | "window_info": best_window_info 340 | } 341 | 342 | 343 | # ================ Utility Functions ================ 344 | 345 | def process_symbol_dtw(args: tuple) -> dict: 346 | """Process a single symbol DTW comparison (designed for multiprocessing)""" 347 | target_symbol, target_df, timeframe, ref_symbol, ref_idx, ref_df, ref_timeframe, ref_label, config = args 348 | 349 | print(f"Processing DTW for {target_symbol} [{timeframe}] against {ref_symbol} reference #{ref_idx} ({ref_label}) [{ref_timeframe}]...") 350 | 351 | # Calculate similarity 352 | dtw_calculator = DTWSimilarityCalculator(config) 353 | similarity_result = dtw_calculator.find_best_similarity_window( 354 | ref_df, target_df 355 | ) 356 | 357 | return { 358 | "symbol": target_symbol, 359 | "timeframe": timeframe, 360 | "ref_symbol": ref_symbol, 361 | "ref_idx": ref_idx, 362 | "ref_timeframe": ref_timeframe, 363 | "ref_label": ref_label, 364 | "score": similarity_result["similarity"], 365 | "price_distance": similarity_result["price_distance"], 366 | "diff_distance": similarity_result["diff_distance"], 367 | "price_path": similarity_result["price_path"], 368 | "diff_path": similarity_result["diff_path"], 369 | "window_data": similarity_result["window_data"], 370 | "window_info": similarity_result["window_info"] 371 | } 372 | 373 | 374 | # ================ Visualization Functions ================ 375 | 376 | def visualize_dtw_alignment(query_df, window_df, price_path, ref_symbol, target_symbol, 377 | timeframe, similarity, save_dir, ref_label, price_distance, diff_distance): 378 | """Visualize DTW alignment using candlestick charts with volume and connection lines""" 379 | if not price_path: 380 | print(f"No warping path available for visualization of {target_symbol}") 381 | return 382 | 383 | # Create figure with normalized data for plotting 384 | fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(20, 16), sharex=False, gridspec_kw={'height_ratios': [1, 1]}) 385 | 386 | ref_normalized_df, _ = DataNormalizer.normalize_ohlc_dataframe(query_df, include_volume=True) 387 | target_normalized_df, _ = DataNormalizer.normalize_ohlc_dataframe(window_df, include_volume=True) 388 | 389 | # Plot reference sequence with candlesticks and volume 390 | plot_candlesticks_with_volume(ax1, ref_normalized_df, volume_ratio=0.12) 391 | ax1.plot(ref_normalized_df.index, ref_normalized_df['SMA_30'], 'blue', linewidth=1, alpha=0.7, label='SMA30') 392 | ax1.plot(ref_normalized_df.index, ref_normalized_df['SMA_45'], 'orange', linewidth=1, alpha=0.7, label='SMA45') 393 | ax1.plot(ref_normalized_df.index, ref_normalized_df['SMA_60'], 'purple', linewidth=1, alpha=0.7, label='SMA60') 394 | ax1.set_ylabel('Normalized Price') 395 | ax1.set_title(f'{ref_symbol} Reference Trend - {ref_label} (Length: {len(query_df)} points)') 396 | ax1.grid(True, alpha=0.3) 397 | ax1.legend(loc='upper left') 398 | 399 | # Plot target sequence with candlesticks and volume 400 | plot_candlesticks_with_volume(ax2, target_normalized_df) 401 | ax2.plot(target_normalized_df.index, target_normalized_df['SMA_30'], 'blue', linewidth=1, alpha=0.7, label='SMA30') 402 | ax2.plot(target_normalized_df.index, target_normalized_df['SMA_45'], 'orange', linewidth=1, alpha=0.7, label='SMA45') 403 | ax2.plot(target_normalized_df.index, target_normalized_df['SMA_60'], 'purple', linewidth=1, alpha=0.7, 
label='SMA60') 404 | ax2.set_ylabel('Normalized Price') 405 | ax2.set_title(f'{target_symbol} {timeframe} Trend (Length: {len(window_df)} points, Latest: {format_dt_with_tz(window_df.index[-1], TIMEZONE)})') 406 | ax2.grid(True, alpha=0.3) 407 | ax2.legend(loc='upper left') 408 | 409 | # Format dates 410 | for ax in [ax1, ax2]: 411 | ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d')) 412 | plt.setp(ax.xaxis.get_majorticklabels(), rotation=45, ha='right') 413 | 414 | price_cols = ['Open', 'High', 'Low', 'Close'] 415 | ref_price_min = ref_normalized_df[price_cols].values.min() 416 | ref_price_max = ref_normalized_df[price_cols].values.max() 417 | target_price_min = target_normalized_df[price_cols].values.min() 418 | target_price_max = target_normalized_df[price_cols].values.max() 419 | 420 | step_size = max(1, len(price_path) // 100) 421 | 422 | for idx, (i, j) in enumerate(price_path): 423 | if idx % step_size == 0: 424 | ref_x = query_df.index[i] 425 | ref_y = ref_normalized_df['Close'].iloc[i] 426 | target_x = window_df.index[j] 427 | target_y = target_normalized_df['Close'].iloc[j] 428 | 429 | if (ref_price_min <= ref_y <= ref_price_max and 430 | target_price_min <= target_y <= target_price_max): 431 | 432 | con = ConnectionPatch( 433 | xyA=(mdates.date2num(ref_x), ref_y), coordsA=ax1.transData, 434 | xyB=(mdates.date2num(target_x), target_y), coordsB=ax2.transData, 435 | color='gray', alpha=0.4, linewidth=0.7, linestyle='-', 436 | zorder=1 437 | ) 438 | 439 | fig.add_artist(con) 440 | 441 | # Add title with parameters and metrics information 442 | plt.suptitle(f'Price/SMA Alignment with Volume\n {ref_symbol}({ref_label}) vs {target_symbol} ({timeframe})\n' 443 | f'Score: {similarity:.4f}, Price Distance: {price_distance:.4f}, SMA Diff Distance: {diff_distance:.4f}', 444 | fontsize=12) 445 | 446 | # Add info text 447 | info_text = (f"Reference Period: {format_dt_with_tz(query_df.index[0], TIMEZONE)} to {format_dt_with_tz(query_df.index[-1], TIMEZONE)}\n" 448 | f"Target Period: {format_dt_with_tz(window_df.index[0], TIMEZONE)} to {format_dt_with_tz(window_df.index[-1], TIMEZONE)}") 449 | 450 | plt.figtext(0.02, 0.02, info_text, fontsize=10, 451 | bbox=dict(facecolor='white', alpha=0.8, boxstyle='round,pad=0.5')) 452 | 453 | # Ensure layout is compact 454 | plt.tight_layout() 455 | plt.subplots_adjust(top=0.9, hspace=0.3) 456 | 457 | # Save image 458 | FileManager.ensure_directories(save_dir) 459 | filename = f"score_{similarity:.4f}_{target_symbol}_{timeframe}.png" 460 | filepath = os.path.join(save_dir, filename) 461 | plt.savefig(filepath, dpi=150, bbox_inches='tight') 462 | plt.close(fig) 463 | 464 | print(f"Saved visualization to {filepath}") 465 | 466 | 467 | def visualize_sma_differences(query_df, window_df, diff_path, ref_symbol, target_symbol, 468 | timeframe, similarity, save_dir, ref_label, price_distance, diff_distance): 469 | """Visualize SMA differences between reference and target sequences (no volume)""" 470 | if not diff_path: 471 | print(f"No SMA diff warping path available for visualization of {target_symbol}") 472 | return 473 | 474 | # Create figure 475 | fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(20, 16), sharex=False, gridspec_kw={'height_ratios': [1, 1]}) 476 | 477 | # Define SMA difference columns 478 | diff_cols = ['SMA30_SMA45', 'SMA30_SMA60', 'SMA45_SMA60'] 479 | 480 | # Extract and normalize reference SMA differences for plotting 481 | ref_diffs = query_df[diff_cols].values 482 | ref_normalized = DataNormalizer.normalize_to_range(ref_diffs) 483 | 
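Both visualization routines draw their gray warping-path connectors the same way: sample the ShapeDTW path with a stride of about len(path) // 100, then add a ConnectionPatch in data coordinates between the upper and lower axes. A self-contained sketch of that matplotlib pattern, with synthetic series and an invented path purely for illustration:

```python
import numpy as np
import matplotlib
matplotlib.use("Agg")                      # headless backend for the example
import matplotlib.pyplot as plt
from matplotlib.patches import ConnectionPatch

# Synthetic, already-normalized series and an invented warping path (illustration only).
ref = np.sin(np.linspace(0, 3, 60))
tgt = np.sin(np.linspace(0.2, 3.2, 80))
path = [(i, min(int(i * 80 / 60), 79)) for i in range(60)]

fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 6))
ax1.plot(ref, color="steelblue")
ax2.plot(tgt, color="darkorange")

step = max(1, len(path) // 20)             # thin the path so the figure stays readable
for i, j in path[::step]:
    con = ConnectionPatch(xyA=(i, ref[i]), coordsA=ax1.transData,
                          xyB=(j, tgt[j]), coordsB=ax2.transData,
                          color="gray", alpha=0.4, linewidth=0.7)
    fig.add_artist(con)                    # owned by the figure so it can span both axes

fig.savefig("warping_path_demo.png", dpi=120)
plt.close(fig)
```

In the screener itself the x-coordinates are real timestamps, hence the mdates.date2num calls in the surrounding functions.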
484 | # Reshape back to format needed for plotting 485 | ref_normalized_reshaped = ref_normalized.reshape(query_df[diff_cols].shape) 486 | ref_normalized_df = pd.DataFrame( 487 | ref_normalized_reshaped, 488 | index=query_df.index, 489 | columns=diff_cols 490 | ) 491 | 492 | # Extract and normalize target SMA differences for plotting 493 | target_diffs = window_df[diff_cols].values 494 | target_normalized = DataNormalizer.normalize_to_range(target_diffs) 495 | 496 | # Reshape back to format needed for plotting 497 | target_normalized_reshaped = target_normalized.reshape(window_df[diff_cols].shape) 498 | target_normalized_df = pd.DataFrame( 499 | target_normalized_reshaped, 500 | index=window_df.index, 501 | columns=diff_cols 502 | ) 503 | 504 | # Define line styles and colors for better visual distinction 505 | line_styles = [ 506 | {'color': 'green', 'linestyle': '-', 'linewidth': 2, 'label': 'SMA30-SMA45'}, 507 | {'color': 'blue', 'linestyle': '-', 'linewidth': 2, 'label': 'SMA30-SMA60'}, 508 | {'color': 'orange', 'linestyle': '-', 'linewidth': 2, 'label': 'SMA45-SMA60'} 509 | ] 510 | 511 | # Plot reference sequence differences 512 | for i, col in enumerate(diff_cols): 513 | ax1.plot(ref_normalized_df.index, ref_normalized_df[col], **line_styles[i]) 514 | ax1.set_ylabel('Normalized SMA Differences') 515 | ax1.set_title(f'{ref_symbol} Reference SMA Diff - {ref_label} (Length: {len(query_df)} points)') 516 | ax1.grid(True, alpha=0.3) 517 | ax1.legend(loc='upper left') 518 | 519 | # Plot target sequence differences 520 | for i, col in enumerate(diff_cols): 521 | ax2.plot(target_normalized_df.index, target_normalized_df[col], **line_styles[i]) 522 | ax2.set_ylabel('Normalized SMA Differences') 523 | ax2.set_title(f'{target_symbol} {timeframe} SMA Diff (Length: {len(window_df)} points, Latest: {window_df.index[-1].strftime("%Y-%m-%d")})') 524 | ax2.grid(True, alpha=0.3) 525 | ax2.legend(loc='upper left') 526 | 527 | # Format dates 528 | for ax in [ax1, ax2]: 529 | ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d')) 530 | plt.setp(ax.xaxis.get_majorticklabels(), rotation=45, ha='right') 531 | 532 | # Set y-axis range for each chart with padding for aesthetics 533 | ref_min = ref_normalized_df.values.min() 534 | ref_max = ref_normalized_df.values.max() 535 | ref_padding = (ref_max - ref_min) * 0.1 536 | ax1.set_ylim(ref_min - ref_padding, ref_max + ref_padding) 537 | 538 | target_min = target_normalized_df.values.min() 539 | target_max = target_normalized_df.values.max() 540 | target_padding = (target_max - target_min) * 0.1 541 | ax2.set_ylim(target_min - target_padding, target_max + target_padding) 542 | 543 | # Add warping path visualization 544 | # Determine number of connection points (reduce density of connections) 545 | step_size = max(1, len(diff_path) // 100) 546 | 547 | # Add markers at connection points - focus on SMA30-SMA45 as the primary line 548 | connected_ref_indices = [i for i, _ in diff_path[::step_size]] 549 | connected_target_indices = [j for _, j in diff_path[::step_size]] 550 | 551 | # Place small markers at connection points on SMA30-SMA45 line 552 | ax1.scatter(query_df.index[connected_ref_indices], 553 | ref_normalized_df['SMA30_SMA45'].iloc[connected_ref_indices], 554 | color='darkgreen', s=15, alpha=0.6, zorder=5) 555 | ax2.scatter(window_df.index[connected_target_indices], 556 | target_normalized_df['SMA30_SMA45'].iloc[connected_target_indices], 557 | color='darkgreen', s=15, alpha=0.6, zorder=5) 558 | 559 | # Draw connecting lines using ConnectionPatch 560
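# Note: the loop below mirrors the price chart above: it walks the ShapeDTW warping path for the
# SMA differences, keeps every step_size-th pair, and draws a gray ConnectionPatch in data
# coordinates from the SMA30-SMA45 curve on ax1 to its matched point on ax2.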
| for idx, (i, j) in enumerate(diff_path): 561 | if idx % step_size == 0: 562 | ref_x = query_df.index[i] 563 | ref_y = ref_normalized_df['SMA30_SMA45'].iloc[i] 564 | target_x = window_df.index[j] 565 | target_y = target_normalized_df['SMA30_SMA45'].iloc[j] 566 | 567 | # Create a connection patch between the two points 568 | con = ConnectionPatch( 569 | xyA=(mdates.date2num(ref_x), ref_y), coordsA=ax1.transData, 570 | xyB=(mdates.date2num(target_x), target_y), coordsB=ax2.transData, 571 | color='gray', alpha=0.4, linewidth=0.7, linestyle='-', 572 | zorder=1 573 | ) 574 | 575 | fig.add_artist(con) 576 | 577 | # Add title with parameters and metrics information 578 | plt.suptitle(f'SMA Differences\n {ref_symbol}({ref_label}) vs {target_symbol} ({timeframe})\n' 579 | f'Score: {similarity:.4f}, Price Distance: {price_distance:.4f}, SMA Diff Distance: {diff_distance:.4f}', 580 | fontsize=12) 581 | 582 | # Add info text 583 | info_text = (f"Reference Period: {format_dt_with_tz(query_df.index[0], TIMEZONE)} to {format_dt_with_tz(query_df.index[-1], TIMEZONE)}\n" 584 | f"Target Period: {format_dt_with_tz(window_df.index[0], TIMEZONE)} to {format_dt_with_tz(window_df.index[-1], TIMEZONE)}") 585 | 586 | # Add text box 587 | plt.figtext(0.02, 0.02, info_text, fontsize=10, 588 | bbox=dict(facecolor='white', alpha=0.8, boxstyle='round,pad=0.5')) 589 | 590 | # Ensure layout is compact 591 | plt.tight_layout() 592 | plt.subplots_adjust(top=0.9, hspace=0.3) 593 | 594 | # Save image with modified filename including the diff_distance 595 | FileManager.ensure_directories(save_dir) 596 | filename = f"diff_distance_{diff_distance:.4f}_{target_symbol}_{timeframe}.png" 597 | filepath = os.path.join(save_dir, filename) 598 | plt.savefig(filepath, dpi=150, bbox_inches='tight') 599 | plt.close(fig) 600 | 601 | print(f"Saved SMA differences visualization to {filepath}") 602 | 603 | 604 | # ================ Main Function ================ 605 | 606 | def main(): 607 | """Main function to run the DTW similarity analysis""" 608 | # Record start time 609 | start_time = time.time() 610 | 611 | parser = argparse.ArgumentParser(description='Analyze trend similarity for crypto or stock symbols') 612 | parser.add_argument('-f', '--file', required=False, default="", help='Path to strong target file') 613 | parser.add_argument('--asset', choices=['crypto', 'stock'], default='crypto', help='Asset type of input file (default: crypto)') 614 | parser.add_argument('-nv', '--no_visualize', action='store_true', default=False, help='Disable visualization of DTW alignments') 615 | parser.add_argument('-k', '--topk', type=int, default=TOP_K, help=f'Number of top symbols to record per reference trend (default: {TOP_K})') 616 | parser.add_argument('-s', '--sleep', type=float, default=API_SLEEP_SECONDS, help=f'Sleep time between API requests in seconds (default: {API_SLEEP_SECONDS})') 617 | args = parser.parse_args() 618 | 619 | # Create configuration from script constants 620 | config = TrendAnalysisConfig() 621 | config.sma_periods = SMA_PERIODS 622 | config.dtw_window_ratio = DTW_WINDOW_RATIO 623 | config.dtw_window_ratio_diff = DTW_WINDOW_RATIO_FOR_DIFF 624 | config.dtw_max_point_distance = DTW_MAX_POINT_DISTANCE 625 | config.dtw_max_point_distance_diff = DTW_MAX_POINT_DISTANCE_FOR_DIFF 626 | config.shapedtw_balance_pd_ratio = SHAPEDTW_BALANCE_PD_RATIO 627 | config.price_weight = PRICE_WEIGHT 628 | config.diff_weight = DIFF_WEIGHT 629 | config.slope_window_size = SLOPE_WINDOW_SIZE 630 | config.paa_window_size = PAA_WINDOW_SIZE 631 |
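# window_scale_factors (assigned next) drives the candidate-window arithmetic in
# find_best_similarity_window: each factor tries a window of int(query_len * factor) bars with the
# right edge pinned to the latest bar, so a 100-bar reference scanned against 500 bars of history
# checks windows 410:500, 405:500, 400:500, 395:500 and 390:500.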
config.window_scale_factors = DTW_WINDOW_FACTORS 632 | config.min_query_length = MIN_QUERY_LENGTH 633 | config.api_sleep_seconds = args.sleep 634 | config.request_time_buffer_ratio = REQUEST_TIME_BUFFER_RATIO 635 | 636 | # Use parameters 637 | enable_visualization = not args.no_visualize 638 | top_k = args.topk 639 | api_sleep_seconds = args.sleep 640 | 641 | print(f"\nAnalysis Configuration:") 642 | print(f"Input Asset Type: {args.asset}") 643 | print(f"Reference Trends: {len(REFERENCE_TRENDS)} symbols with {sum(len(trends) for trends in REFERENCE_TRENDS.values())} total trends") 644 | print(f"Timeframes to Analyze: {TIMEFRAMES_TO_ANALYZE}") 645 | print(f"DTW Window Ratio: {config.dtw_window_ratio}") 646 | print(f"DTW Window Ratio for DIFF: {config.dtw_window_ratio_diff}") 647 | print(f"DTW Max Point Distance: {config.dtw_max_point_distance}") 648 | print(f"DTW Max Point Distance for DIFF: {config.dtw_max_point_distance_diff}") 649 | print(f"Slope Window Size: {config.slope_window_size}") 650 | print(f"PAA Window Size: {config.paa_window_size}") 651 | print(f"Price Weight: {config.price_weight}, Diff Weight: {config.diff_weight}") 652 | print(f"Using ShapeDTW for both price and difference features") 653 | print(f"Request Time Buffer Ratio: {config.request_time_buffer_ratio}") 654 | print(f"API Sleep Time: {api_sleep_seconds} seconds") 655 | print(f"Visualization Enabled: {enable_visualization}") 656 | print(f"Top K Symbols: {top_k}\n") 657 | 658 | # Initialize data processor 659 | data_processor = DataProcessor(args.asset, config) 660 | 661 | # Get target symbols - either from file or all available 662 | if args.file: 663 | target_symbols = parse_target_symbols(args.file) 664 | if not target_symbols: 665 | # If file provided but parsing failed, use all symbols 666 | if args.asset == "crypto": 667 | target_symbols = data_processor.downloader.get_all_symbols() 668 | # Remove USDT suffix for clean symbol names 669 | target_symbols = [s.replace('USDT', '') for s in target_symbols] 670 | else: 671 | target_symbols = data_processor.downloader.get_all_tickers() 672 | else: 673 | # No file provided, use all symbols 674 | if args.asset == "crypto": 675 | target_symbols = data_processor.downloader.get_all_symbols() 676 | # Remove USDT suffix for clean symbol names 677 | target_symbols = [s.replace('USDT', '') for s in target_symbols] 678 | else: 679 | target_symbols = data_processor.downloader.get_all_tickers() 680 | 681 | print(f"Found {len(target_symbols)} targets to analyze") 682 | 683 | # Create output directory 684 | output_dir = create_output_directory(OUTPUT_DIR) 685 | 686 | # Store all results 687 | all_results = {} 688 | 689 | # Load or retrieve all reference trend data using unified manager 690 | reference_data = {} 691 | for ref_symbol, ref_trends in REFERENCE_TRENDS.items(): 692 | for ref_idx, ref_trend_info in enumerate(ref_trends): 693 | start_datetime, end_datetime, ref_timeframe, ref_label = ref_trend_info 694 | 695 | ref_df = ReferenceDataManager.load_or_fetch_reference_data( 696 | ref_symbol, start_datetime, end_datetime, ref_timeframe, ref_label, 697 | OUTPUT_DIR, TIMEZONE, data_processor, config 698 | ) 699 | 700 | if ref_df is not None: 701 | reference_data[(ref_symbol, ref_idx)] = { 702 | 'df': ref_df, 703 | 'timeframe': ref_timeframe, 704 | 'label': ref_label 705 | } 706 | 707 | # Find longest sequence length (data point count) in reference data 708 | max_ref_length = 0 709 | for ref_info in reference_data.values(): 710 | ref_df = ref_info['df'] 711 | max_ref_length = 
max(max_ref_length, len(ref_df)) 712 | 713 | print(f"Maximum reference trend length: {max_ref_length} data points") 714 | 715 | # Process each timeframe for target symbols 716 | for timeframe in TIMEFRAMES_TO_ANALYZE: 717 | print(f"\nProcessing timeframe: {timeframe}") 718 | 719 | # Initialize results for this timeframe 720 | all_results[timeframe] = {} 721 | 722 | # Calculate seconds corresponding to current timeframe 723 | timeframe_seconds = calculate_timeframe_seconds(timeframe) 724 | 725 | # Calculate history duration to request (seconds) 726 | # Use longest reference sequence length * timeframe seconds * max_window_factor * buffer ratio 727 | history_seconds = int(max_ref_length * timeframe_seconds * max(config.window_scale_factors) * config.request_time_buffer_ratio) 728 | 729 | # Get current time as end_ts 730 | end_ts = int(datetime.now().timestamp()) 731 | start_ts = end_ts - history_seconds 732 | 733 | print(f"Current timestamp: {end_ts}, date: {datetime.fromtimestamp(end_ts)}") 734 | print(f"Calculated history duration: {history_seconds} seconds ({history_seconds/86400:.1f} days)") 735 | print(f"Start timestamp: {start_ts}, date: {datetime.fromtimestamp(start_ts)}") 736 | 737 | # For crypto, pre-fetch all target data (has API request delay) 738 | if args.asset == "crypto": 739 | print(f"Getting data for all crypto symbols in timeframe {timeframe}...") 740 | target_data = {} 741 | 742 | for symbol in target_symbols: 743 | print(f"Getting data for {symbol} [{timeframe}]...") 744 | 745 | # Get data for this symbol 746 | df = data_processor.get_data( 747 | symbol, 748 | timeframe, 749 | start_ts, 750 | end_ts, 751 | is_crypto=True 752 | ) 753 | 754 | if not df.empty and len(df) > 0: 755 | print(f" Got data from {df.index[0]} to {df.index[-1]}, {len(df)} points") 756 | target_data[symbol] = df 757 | 758 | # Sleep to avoid triggering API rate limits 759 | time.sleep(api_sleep_seconds) 760 | 761 | print(f"Successfully retrieved data for {len(target_data)} out of {len(target_symbols)} symbols") 762 | else: 763 | # For stocks, we'll fetch data on-demand in parallel processing 764 | target_data = None 765 | 766 | # Process each reference trend, comparing with current time range 767 | for (ref_symbol, ref_idx), ref_info in reference_data.items(): 768 | ref_df = ref_info['df'] 769 | ref_timeframe = ref_info['timeframe'] 770 | label = ref_info['label'] 771 | 772 | print(f"\nAnalyzing {ref_symbol} reference #{ref_idx} ({label}, {ref_timeframe}) against target timeframe {timeframe}:") 773 | 774 | # For crypto, use pre-fetched data and process in parallel 775 | if args.asset == "crypto": 776 | # Filter symbols with sufficient data points 777 | valid_symbols = [] 778 | valid_dfs = [] 779 | 780 | for symbol, df in target_data.items(): 781 | if len(df) >= len(ref_df) * min(config.window_scale_factors): 782 | valid_symbols.append(symbol) 783 | valid_dfs.append(df) 784 | print(f" {symbol}: data period {df.index[0]} to {df.index[-1]}, {len(df)} points") 785 | 786 | # Prepare multiprocessing arguments 787 | process_args = [ 788 | (symbol, df, timeframe, ref_symbol, ref_idx, ref_df, ref_timeframe, label, config) 789 | for symbol, df in zip(valid_symbols, valid_dfs) 790 | ] 791 | 792 | # Process DTW calculations in parallel 793 | with Pool(processes=min(cpu_count()-1, len(valid_symbols))) if len(valid_symbols) > 1 else Pool(processes=1) as pool: 794 | results = pool.map(process_symbol_dtw, process_args) 795 | 796 | # Process results 797 | target_scores = {} 798 | for result in results: 799 | 
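# Each result from process_symbol_dtw bundles one symbol's best match: its weighted similarity
# score, both ShapeDTW distances, the two warping paths, and the matched window DataFrame;
# keying them by symbol lets them be ranked and visualized below.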
target_scores[result["symbol"]] = result 800 | if result["window_data"] is not None: 801 | window_start = result["window_data"].index[0] 802 | window_end = result["window_data"].index[-1] 803 | print(f" {result['symbol']}: score={result['score']:.4f}, " 804 | f"price_distance={result['price_distance']:.4f}, " 805 | f"diff_distance={result['diff_distance']:.4f}, " 806 | f"window={window_start} to {window_end}") 807 | 808 | else: # For stocks, fetch data and process in one step (not implemented here) 809 | # This part needs additional code to handle stocks 810 | continue 811 | 812 | # Sort targets by score 813 | sorted_targets = sorted( 814 | target_scores.keys(), 815 | key=lambda x: target_scores[x]["score"], 816 | reverse=True 817 | ) 818 | 819 | # Store results 820 | key = f"{ref_symbol}_{ref_idx}_{ref_timeframe}" 821 | all_results[timeframe][key] = { 822 | 'ref_symbol': ref_symbol, 823 | 'ref_idx': ref_idx, 824 | 'ref_timeframe': ref_timeframe, 825 | 'label': label, 826 | 'targets': sorted_targets[:top_k], 827 | 'results': target_scores 828 | } 829 | 830 | # If enabled, generate visualizations for top K symbols 831 | if enable_visualization: 832 | print(f"Generating visualizations for top {top_k} symbols...") 833 | # Use updated folder naming format with reference timeframe 834 | vis_dir = f"{output_dir}/vis_{timeframe}_{ref_symbol}_{ref_timeframe}_{label}" 835 | for symbol in sorted_targets[:top_k]: 836 | result = target_scores[symbol] 837 | if result["price_path"] and result["window_data"] is not None: 838 | visualize_dtw_alignment( 839 | ref_df, 840 | result["window_data"], 841 | result["price_path"], 842 | ref_symbol, 843 | symbol, 844 | timeframe, 845 | result["score"], 846 | vis_dir, 847 | label, 848 | result["price_distance"], 849 | result["diff_distance"] 850 | ) 851 | 852 | visualize_sma_differences( 853 | ref_df, 854 | result["window_data"], 855 | result["diff_path"], 856 | ref_symbol, 857 | symbol, 858 | timeframe, 859 | result["score"], 860 | vis_dir, 861 | label, 862 | result["price_distance"], 863 | result["diff_distance"] 864 | ) 865 | 866 | # Create a string list for the detailed results (optimization) 867 | summary = [] 868 | summary.append("\n============= DETAILED RESULTS =============") 869 | summary.append(f"DTW Similarity Analysis") 870 | summary.append(f"Date: {datetime.now().strftime('%Y-%m-%d %H:%M')}") 871 | summary.append(f"Asset Type: {args.asset}") 872 | summary.append(f"DTW Window Ratio: {config.dtw_window_ratio}") 873 | summary.append(f"DTW Window Ratio for DIFF: {config.dtw_window_ratio_diff}") 874 | summary.append(f"DTW Max Point Distance: {config.dtw_max_point_distance}") 875 | summary.append(f"DTW Max Point Distance for DIFF: {config.dtw_max_point_distance_diff}") 876 | summary.append(f"Slope Window Size: {config.slope_window_size}") 877 | summary.append(f"PAA Window Size: {config.paa_window_size}") 878 | summary.append(f"Price Weight: {config.price_weight}, Diff Weight: {config.diff_weight}") 879 | summary.append(f"Scoring Method: Arithmetic mean of price and SMA difference similarities") 880 | summary.append(f"ShapeDTW enabled for both price and SMA difference features") 881 | summary.append(f"Request Time Buffer Ratio: {config.request_time_buffer_ratio}\n") 882 | 883 | for timeframe in TIMEFRAMES_TO_ANALYZE: 884 | if timeframe not in all_results: 885 | continue 886 | 887 | summary.append(f"\n{'='*40}") 888 | summary.append(f"TIMEFRAME: {timeframe}") 889 | summary.append(f"{'='*40}") 890 | 891 | for key, results in all_results[timeframe].items(): 892 
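# One report section per reference trend; only symbols whose best window yielded finite ShapeDTW
# distances and a positive score are listed.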
| ref_symbol = results['ref_symbol'] 893 | ref_idx = results['ref_idx'] 894 | ref_timeframe = results['ref_timeframe'] 895 | label = results['label'] 896 | 897 | summary.append(f"\n--- {ref_symbol} Reference #{ref_idx} ({label}, {ref_timeframe}) ---") 898 | 899 | # Filter out results with infinite distance or score <= 0 900 | valid_targets = [symbol for symbol in results['targets'] 901 | if not np.isinf(results['results'][symbol]["price_distance"]) 902 | and not np.isinf(results['results'][symbol]["diff_distance"]) 903 | and results['results'][symbol]["score"] > 0] 904 | 905 | if len(valid_targets) > 0: 906 | summary.append("Top Similarity Scores:") 907 | 908 | for symbol in valid_targets: 909 | score = results['results'][symbol]["score"] 910 | price_distance = results['results'][symbol]["price_distance"] 911 | diff_distance = results['results'][symbol]["diff_distance"] 912 | window_data = results['results'][symbol]["window_data"] 913 | if window_data is not None: 914 | window_period = f"{format_dt_with_tz(window_data.index[0], TIMEZONE)} to {format_dt_with_tz(window_data.index[-1], TIMEZONE)}" 915 | else: 916 | window_period = "N/A" 917 | summary.append(f"{symbol}: Score={score:.4f}, Price Dist={price_distance:.4f}, SMA Diff Dist={diff_distance:.4f}, Window={window_period}") 918 | 919 | # Join all summary lines into a single string 920 | summary_text = '\n'.join(summary) 921 | 922 | # Print to console 923 | print(summary_text) 924 | 925 | # Save detailed results to file 926 | detail_file = f"{output_dir}/similarity_search_report.txt" 927 | with open(detail_file, "w") as f: 928 | f.write(summary_text) 929 | 930 | print(f"\nDetailed results saved to: {detail_file}") 931 | 932 | # Save TradingView format (excluding results with score -1) 933 | tv_file = f"{output_dir}/{datetime.now().strftime('%Y-%m-%d_%H-%M')}_similar_trend_tradingview.txt" 934 | with open(tv_file, "w") as f: 935 | for timeframe in TIMEFRAMES_TO_ANALYZE: 936 | if timeframe not in all_results: 937 | continue 938 | 939 | for key, results in all_results[timeframe].items(): 940 | ref_symbol = results['ref_symbol'] 941 | ref_idx = results['ref_idx'] 942 | ref_timeframe = results['ref_timeframe'] 943 | label = results['label'] 944 | 945 | f.write(f"\n###{timeframe}_{ref_symbol}_{ref_idx}_{label}\n") 946 | 947 | # Filter symbols with score greater than 0 (exclude -1 score results) 948 | valid_targets = [s for s in results['targets'] if results['results'][s]["score"] > 0] 949 | 950 | if args.asset == "crypto": 951 | symbols_str = ','.join([f"BINANCE:{s}USDT.P" for s in valid_targets]) 952 | else: 953 | symbols_str = ','.join(valid_targets) 954 | 955 | f.write(symbols_str) 956 | 957 | print(f"TradingView format saved to: {tv_file}") 958 | 959 | # Calculate and output total runtime 960 | end_time = time.time() 961 | total_time = end_time - start_time 962 | print(f"\nTotal execution time: {total_time:.2f} seconds ({total_time/60:.2f} minutes)") 963 | 964 | 965 | if __name__ == "__main__": 966 | main() -------------------------------------------------------------------------------- /src/common.py: -------------------------------------------------------------------------------- 1 | """ 2 | Common utilities for trend similarity analysis 3 | Shared functions and classes for both crypto and stock analysis 4 | Supports multiple asset types and provides foundation for volume analysis 5 | """ 6 | 7 | import os 8 | import time 9 | import pickle 10 | import numpy as np 11 | import pandas as pd 12 | from datetime import datetime 13 | from pytz 
import timezone 14 | from abc import ABC, abstractmethod 15 | from typing import Dict, List, Tuple 16 | from dtaidistance import dtw, dtw_ndim 17 | from shapedtw.shapedtw import shape_dtw 18 | from shapedtw.shapeDescriptors import SlopeDescriptor, PAADescriptor, CompoundDescriptor, RawSubsequenceDescriptor 19 | import matplotlib.pyplot as plt 20 | import matplotlib.dates as mdates 21 | from matplotlib.patches import Rectangle 22 | 23 | 24 | # ================ Configuration Management ================ 25 | 26 | class TrendAnalysisConfig: 27 | """Configuration manager for trend analysis parameters""" 28 | 29 | def __init__(self): 30 | # Default SMA periods 31 | self.sma_periods = [30, 45, 60] 32 | 33 | # Default DTW parameters 34 | self.dtw_window_ratio = 0.2 35 | self.dtw_window_ratio_diff = 0.1 36 | self.dtw_max_point_distance = 0.66 37 | self.dtw_max_point_distance_diff = 0.5 38 | 39 | # Default ShapeDTW parameters 40 | self.shapedtw_balance_pd_ratio = 4 41 | self.price_weight = 0.4 42 | self.diff_weight = 0.6 43 | self.slope_window_size = 5 44 | self.paa_window_size = 5 45 | 46 | # Default window scaling factors 47 | self.window_scale_factors = [0.9, 0.95, 1.0, 1.05, 1.1] 48 | 49 | # Minimum query length 50 | self.min_query_length = 60 51 | 52 | # API settings 53 | self.api_sleep_seconds = 0.5 54 | self.request_time_buffer_ratio = 1.2 55 | 56 | 57 | # ================ Time and Date Utilities ================ 58 | 59 | def calculate_timeframe_seconds(timeframe: str) -> int: 60 | """Calculate seconds for timeframe string (e.g., "15m", "1h", "4h")""" 61 | if 'm' in timeframe: 62 | minutes = int(timeframe.replace('m', '')) 63 | return minutes * 60 64 | elif 'h' in timeframe: 65 | hours = int(timeframe.replace('h', '')) 66 | return hours * 3600 67 | elif 'd' in timeframe: 68 | days = int(timeframe.replace('d', '')) 69 | return days * 86400 70 | else: 71 | print(f"Unknown timeframe format: {timeframe}, defaulting to 1 hour") 72 | return 3600 73 | 74 | 75 | def convert_datetime_to_timestamp(dt_obj: datetime, tz_name: str) -> int: 76 | """Convert datetime object to timestamp with timezone consideration""" 77 | tz = timezone(tz_name) 78 | dt_with_tz = tz.localize(dt_obj) 79 | return int(dt_with_tz.timestamp()) 80 | 81 | 82 | def format_dt_with_tz(dt: pd.Timestamp, tz_name: str = "UTC") -> str: 83 | """Format datetime with timezone consideration""" 84 | if dt.tz is None: 85 | # If no timezone info, assume UTC and convert 86 | dt = dt.tz_localize('UTC') 87 | 88 | # Convert to target timezone 89 | tz = timezone(tz_name) 90 | dt_local = dt.tz_convert(tz) 91 | 92 | return dt_local.strftime('%Y-%m-%d %H:%M') 93 | 94 | 95 | # ================ Data Normalization ================ 96 | 97 | class DataNormalizer: 98 | """Handles data normalization with support for different asset types""" 99 | 100 | @staticmethod 101 | def normalize_to_range(data_array: np.ndarray, target_range: Tuple[float, float] = (-1, 1)) -> np.ndarray: 102 | """Normalize data using Z-score followed by min-max scaling to target range""" 103 | # Convert to numpy array if needed 104 | if isinstance(data_array, (pd.Series, pd.DataFrame)): 105 | data_array = data_array.values 106 | 107 | # Flatten array to compute global statistics 108 | flat_data = data_array.flatten() 109 | 110 | # Step 1: Z-score normalization 111 | global_mean = np.mean(flat_data) 112 | global_std = np.std(flat_data) 113 | 114 | # Handle zero standard deviation 115 | if global_std > 0: 116 | z_scored = (data_array - global_mean) / global_std 117 | else: 118 | # Fallback 
to min-max if std is zero 119 | global_min = np.min(flat_data) 120 | global_max = np.max(flat_data) 121 | if global_max > global_min: 122 | z_scored = (data_array - global_min) / (global_max - global_min) 123 | else: 124 | z_scored = np.zeros_like(data_array) 125 | 126 | # Step 2: Min-max scaling to target range 127 | z_min = np.min(z_scored) 128 | z_max = np.max(z_scored) 129 | 130 | # Handle case where min equals max 131 | if z_max > z_min: 132 | target_min, target_max = target_range 133 | normalized = target_min + (z_scored - z_min) * (target_max - target_min) / (z_max - z_min) 134 | else: 135 | # If all values are the same, set to the middle of the target range 136 | target_min, target_max = target_range 137 | normalized = np.full_like(z_scored, (target_min + target_max) / 2) 138 | 139 | return normalized 140 | 141 | @staticmethod 142 | def calculate_normalization_params(data_array: np.ndarray, target_range: Tuple[float, float] = (-1, 1)) -> Dict: 143 | """Calculate normalization parameters from data without applying normalization""" 144 | # Convert to numpy array if needed 145 | if isinstance(data_array, (pd.Series, pd.DataFrame)): 146 | data_array = data_array.values 147 | 148 | # Flatten array to compute global statistics 149 | flat_data = data_array.flatten() 150 | 151 | # Step 1: Z-score normalization parameters 152 | global_mean = np.mean(flat_data) 153 | global_std = np.std(flat_data) 154 | 155 | # Handle zero standard deviation case 156 | if global_std > 0: 157 | z_scored = (data_array - global_mean) / global_std 158 | else: 159 | # Fallback to min-max if std is zero 160 | global_min = np.min(flat_data) 161 | global_max = np.max(flat_data) 162 | if global_max > global_min: 163 | z_scored = (data_array - global_min) / (global_max - global_min) 164 | else: 165 | z_scored = np.zeros_like(data_array) 166 | 167 | # Step 2: Min-max scaling parameters 168 | z_min = np.min(z_scored) 169 | z_max = np.max(z_scored) 170 | 171 | target_min, target_max = target_range 172 | 173 | # Return normalization parameters 174 | norm_params = { 175 | 'global_mean': global_mean, 176 | 'global_std': global_std, 177 | 'z_min': z_min, 178 | 'z_max': z_max, 179 | 'target_min': target_min, 180 | 'target_max': target_max 181 | } 182 | 183 | return norm_params 184 | 185 | @staticmethod 186 | def apply_normalization_params(data_array: np.ndarray, norm_params: Dict) -> np.ndarray: 187 | """Apply normalization parameters to data""" 188 | # Convert to numpy array if needed 189 | if isinstance(data_array, (pd.Series, pd.DataFrame)): 190 | data_array = data_array.values 191 | 192 | # Extract parameters 193 | global_mean = norm_params['global_mean'] 194 | global_std = norm_params['global_std'] 195 | z_min = norm_params['z_min'] 196 | z_max = norm_params['z_max'] 197 | target_min = norm_params['target_min'] 198 | target_max = norm_params['target_max'] 199 | 200 | # Apply Z-score normalization using stored parameters 201 | if global_std > 0: 202 | z_scored = (data_array - global_mean) / global_std 203 | else: 204 | # This is a fallback case that should rarely happen 205 | z_scored = np.zeros_like(data_array) 206 | 207 | # Apply min-max scaling using stored parameters 208 | if z_max > z_min: 209 | normalized = target_min + (z_scored - z_min) * (target_max - target_min) / (z_max - z_min) 210 | else: 211 | # This is a fallback case that should rarely happen 212 | normalized = np.full_like(z_scored, (target_min + target_max) / 2) 213 | 214 | return normalized 215 | 216 | @staticmethod 217 | def 
normalize_ohlc_dataframe(df: pd.DataFrame, include_volume: bool = False) -> Tuple[pd.DataFrame, Dict]: 218 | """Normalize OHLC data in a DataFrame to range [-1, 1]""" 219 | # Extract OHLC columns 220 | ohlc_columns = ['Open', 'High', 'Low', 'Close'] 221 | 222 | # Get values for normalization 223 | ohlc_values = df[ohlc_columns].values 224 | 225 | # Calculate normalization parameters 226 | norm_params = DataNormalizer.calculate_normalization_params(ohlc_values, (-1, 1)) 227 | 228 | # Apply normalization to OHLC 229 | normalized_values = DataNormalizer.apply_normalization_params(ohlc_values, norm_params) 230 | 231 | # Create a copy of the original dataframe 232 | normalized_df = df.copy() 233 | 234 | # Replace OHLC values with normalized ones 235 | for i, column in enumerate(ohlc_columns): 236 | normalized_df[column] = normalized_values[:, i] 237 | 238 | if include_volume and 'Volume' in df.columns: 239 | volume_values = df['Volume'].values.reshape(-1, 1) 240 | volume_norm_params = DataNormalizer.calculate_normalization_params(volume_values, (0, 1)) 241 | normalized_volume = DataNormalizer.apply_normalization_params(volume_values, volume_norm_params) 242 | normalized_df['Volume'] = normalized_volume.flatten() 243 | norm_params['volume_norm_params'] = volume_norm_params 244 | 245 | # Also normalize SMA columns using OHLC parameters 246 | sma_columns = ['SMA_30', 'SMA_45', 'SMA_60'] 247 | for column in sma_columns: 248 | if column in normalized_df.columns: 249 | normalized_df[column] = DataNormalizer.apply_normalization_params( 250 | normalized_df[column].values.reshape(-1, 1), norm_params 251 | ).flatten() 252 | 253 | return normalized_df, norm_params 254 | 255 | 256 | # ================ Time Series Processing ================ 257 | 258 | class TimeSeriesProcessor: 259 | """Handles time series data preprocessing and feature calculation""" 260 | 261 | def __init__(self, sma_periods: List[int] = None): 262 | """Initialize processor with SMA periods""" 263 | self.sma_periods = sma_periods or [30, 45, 60] 264 | 265 | def prepare_dataframe(self, df: pd.DataFrame, include_volume: bool = True) -> pd.DataFrame: 266 | """Prepare and standardize dataframe for analysis""" 267 | # Convert timestamp to datetime if needed 268 | if 'datetime' not in df.columns and 'timestamp' in df.columns: 269 | df['datetime'] = pd.to_datetime(df['timestamp'], unit='s') 270 | 271 | # Set datetime as index if not already done 272 | if not isinstance(df.index, pd.DatetimeIndex): 273 | df = df.set_index('datetime') 274 | 275 | # Calculate SMAs if not already present 276 | for period in self.sma_periods: 277 | sma_column = f'sma_{period}' 278 | if sma_column not in df.columns: 279 | df[sma_column] = df['close'].rolling(window=period).mean() 280 | 281 | # Standardize column names 282 | column_mapping = { 283 | 'close': 'Close', 284 | 'open': 'Open', 285 | 'high': 'High', 286 | 'low': 'Low', 287 | 'sma_30': 'SMA_30', 288 | 'sma_45': 'SMA_45', 289 | 'sma_60': 'SMA_60' 290 | } 291 | 292 | # Add volume mapping if needed and available 293 | if include_volume and 'volume' in df.columns: 294 | column_mapping['volume'] = 'Volume' 295 | 296 | df = df.rename(columns=column_mapping) 297 | 298 | # Calculate SMA difference features 299 | df['SMA30_SMA45'] = df['SMA_30'] - df['SMA_45'] 300 | df['SMA30_SMA60'] = df['SMA_30'] - df['SMA_60'] 301 | df['SMA45_SMA60'] = df['SMA_45'] - df['SMA_60'] 302 | 303 | # Calculate price-SMA differences 304 | df['Close_SMA30'] = df['Close'] - df['SMA_30'] 305 | df['Close_SMA45'] = df['Close'] - df['SMA_45'] 
306 | df['Close_SMA60'] = df['Close'] - df['SMA_60'] 307 | 308 | return df 309 | 310 | def calculate_sma_features(self, df: pd.DataFrame) -> pd.DataFrame: 311 | """Calculate SMA and related features for the dataframe""" 312 | for period in self.sma_periods: 313 | sma_col = f'SMA_{period}' 314 | if sma_col not in df.columns: 315 | df[sma_col] = df['Close'].rolling(window=period).mean() 316 | 317 | return df 318 | 319 | 320 | # ================ DTW Calculation Engine ================ 321 | 322 | class DTWCalculator: 323 | """Core DTW calculation engine supporting multiple asset types""" 324 | 325 | def __init__(self, config: TrendAnalysisConfig = None): 326 | """Initialize DTW calculator with configuration""" 327 | self.config = config or TrendAnalysisConfig() 328 | self.c_available = self._check_c_availability() 329 | 330 | def _check_c_availability(self) -> bool: 331 | """Check if C implementation of dtaidistance is available""" 332 | try: 333 | return dtw.try_import_c(verbose=False) 334 | except Exception as e: 335 | print(f"Error checking C availability: {e}") 336 | return False 337 | 338 | def normalize_features(self, df: pd.DataFrame) -> Tuple[np.ndarray, np.ndarray]: 339 | """Normalize price and difference features""" 340 | # Define feature columns 341 | price_columns = ['Close', 'SMA_30', 'SMA_45', 'SMA_60'] 342 | diff_columns = ['SMA30_SMA45', 'SMA30_SMA60', 'SMA45_SMA60'] 343 | 344 | # Normalize price features 345 | price_features = DataNormalizer.normalize_to_range(df[price_columns].values) 346 | 347 | # Normalize difference features 348 | diff_features = DataNormalizer.normalize_to_range(df[diff_columns].values) 349 | 350 | return price_features, diff_features 351 | 352 | def calculate_dtw_similarity(self, query_series: np.ndarray, target_series: np.ndarray, 353 | window_ratio: float, max_point_distance: float) -> Tuple[float, float, List]: 354 | """Calculate DTW similarity between two series""" 355 | # Calculate window size 356 | window_size = int(max(len(query_series), len(target_series)) * window_ratio) 357 | 358 | # Calculate DTW 359 | try: 360 | distance, paths = dtw_ndim.warping_paths( 361 | query_series, 362 | target_series, 363 | window=window_size, 364 | use_c=self.c_available, 365 | max_step=max_point_distance 366 | ) 367 | 368 | # If distance is inf (no valid path), return low similarity 369 | if np.isinf(distance): 370 | return 0.0, float('inf'), [] 371 | 372 | # Find best path 373 | path = dtw.best_path(paths) 374 | 375 | # Convert distance to similarity score 376 | similarity = 1 / (1 + distance) 377 | 378 | return similarity, distance, path 379 | 380 | except Exception as e: 381 | print(f"Error in DTW calculation: {e}") 382 | return 0.0, float('inf'), [] 383 | 384 | def calculate_shapedtw(self, query_series: np.ndarray, target_series: np.ndarray, 385 | shape_descriptor, window_ratio: float, subsequence_width: int = 5) -> Tuple[float, List]: 386 | """Calculate Shape DTW between two series""" 387 | try: 388 | # Calculate window size 389 | window_size = int(max(len(query_series), len(target_series)) * window_ratio) 390 | 391 | # Calculate ShapeDTW 392 | shape_dtw_results = shape_dtw( 393 | x=query_series, 394 | y=target_series, 395 | step_pattern="asymmetric", 396 | open_begin=False, 397 | subsequence_width=subsequence_width, 398 | shape_descriptor=shape_descriptor, 399 | multivariate_version="dependent", 400 | window_type="sakoechiba", 401 | window_args={"window_size": window_size}, 402 | ) 403 | 404 | # Extract results 405 | distance = 
shape_dtw_results.shape_normalized_distance 406 | path = list(zip(shape_dtw_results.index1, shape_dtw_results.index2)) 407 | 408 | return distance, path 409 | 410 | except Exception as e: 411 | print(f"Error in ShapeDTW calculation: {e}") 412 | return float('inf'), [] 413 | 414 | def create_shape_descriptors(self) -> Tuple: 415 | """Create shape descriptors for price and difference features""" 416 | # For price features - emphasis on raw shape 417 | price_descriptor = CompoundDescriptor( 418 | [RawSubsequenceDescriptor(), SlopeDescriptor(slope_window=self.config.slope_window_size)], 419 | descriptors_weights=[4.0, 1.0] 420 | ) 421 | 422 | # For difference features - emphasis on slope and patterns 423 | diff_descriptor = CompoundDescriptor( 424 | [SlopeDescriptor(slope_window=self.config.slope_window_size), 425 | PAADescriptor(piecewise_aggregation_window=self.config.paa_window_size)], 426 | descriptors_weights=[3.0, 1.0] 427 | ) 428 | 429 | return price_descriptor, diff_descriptor 430 | 431 | 432 | # ================ Visualization Functions ================ 433 | 434 | def plot_candlesticks_with_volume(ax: plt.Axes, df: pd.DataFrame, width_factor: float = 0.6, volume_ratio: float = 0.15): 435 | """Plot candlestick chart with volume bars in same subplot""" 436 | if len(df) <= 1: 437 | print("Not enough data points to plot candlesticks") 438 | return 439 | 440 | # Ensure required columns exist 441 | required_cols = ['Open', 'High', 'Low', 'Close'] 442 | if not all(col in df.columns for col in required_cols): 443 | print("Missing required OHLC columns") 444 | return 445 | 446 | # Calculate appropriate width for candlesticks in days 447 | time_diff = (df.index[1] - df.index[0]).total_seconds() / 86400 # Convert to days 448 | width = time_diff * width_factor 449 | 450 | # Define colors (green up, red down) 451 | up_color = 'green' 452 | down_color = 'red' 453 | 454 | # Get price range for proper scaling 455 | price_min = df[['Low']].min().iloc[0] 456 | price_max = df[['High']].max().iloc[0] 457 | price_range = price_max - price_min 458 | 459 | # Calculate volume range and normalization if volume exists 460 | has_volume = 'Volume' in df.columns 461 | if has_volume: 462 | volume_min = 0 # 因為Volume已經標準化到[0,1],所以最小值是0 463 | volume_max = df['Volume'].max() 464 | if volume_max > volume_min: 465 | volume_height = price_range * volume_ratio 466 | volume_base = price_min - price_range * 0.1 # Gap from price data 467 | scaled_volume = df['Volume'] * volume_height 468 | else: 469 | has_volume = False 470 | 471 | # Plot candlesticks 472 | for timestamp, row in df.iterrows(): 473 | open_price = row['Open'] 474 | high_price = row['High'] 475 | low_price = row['Low'] 476 | close_price = row['Close'] 477 | 478 | # Determine if it's an up or down candle 479 | is_upward_candle = close_price >= open_price 480 | color = up_color if is_upward_candle else down_color 481 | 482 | # Draw the high-low line (wick) 483 | ax.plot([timestamp, timestamp], [low_price, high_price], 484 | color=color, linewidth=1, alpha=0.8) 485 | 486 | # Calculate rectangle coordinates in timestamp space 487 | half_width_timedelta = pd.Timedelta(days=width/2) 488 | 489 | # Draw the open-close rectangle (body) 490 | if is_upward_candle: 491 | rect_bottom = open_price 492 | rect_height = close_price - open_price 493 | else: 494 | rect_bottom = close_price 495 | rect_height = open_price - close_price 496 | 497 | # Create rectangle for the body 498 | rect = Rectangle((timestamp - half_width_timedelta, rect_bottom), 499 | pd.Timedelta(days=width), 
rect_height, 500 | facecolor=color, edgecolor=color, alpha=0.8) 501 | ax.add_patch(rect) 502 | 503 | # Plot volume bars if available 504 | if has_volume: 505 | volume_value = scaled_volume.loc[timestamp] 506 | volume_rect = Rectangle((timestamp - half_width_timedelta, volume_base), 507 | pd.Timedelta(days=width), volume_value, 508 | facecolor=color, edgecolor=color, alpha=0.5) 509 | ax.add_patch(volume_rect) 510 | 511 | # Set y-axis limits to accommodate both price and volume 512 | if has_volume: 513 | y_bottom = volume_base - volume_height * 0.05 514 | else: 515 | y_bottom = price_min - price_range * 0.05 516 | 517 | y_top = price_max + price_range * 0.05 518 | ax.set_ylim(y_bottom, y_top) 519 | 520 | # ================ Data Management Classes ================ 521 | 522 | class DataCacheManager: 523 | """Manages data caching for timeframe data""" 524 | 525 | @staticmethod 526 | def get_timeframe_cache_path(output_dir: str, timeframe: str) -> str: 527 | """Generate cache filename for timeframe data""" 528 | return os.path.join(output_dir, f"all_symbols_{timeframe}.pkl") 529 | 530 | @staticmethod 531 | def download_timeframe_data(timeframe: str, output_dir: str, config: TrendAnalysisConfig, 532 | historical_start_date: datetime, data_processor) -> dict: 533 | """Download and cache data for all symbols of a specific timeframe""" 534 | # Check for cached data 535 | cache_file = DataCacheManager.get_timeframe_cache_path(output_dir, timeframe) 536 | 537 | if os.path.exists(cache_file): 538 | print(f"Loading cached data for timeframe {timeframe}...") 539 | data_dict = FileManager.load_from_cache(cache_file) 540 | if data_dict is not None: 541 | print(f"Loaded cached data for {len(data_dict)} symbols in timeframe {timeframe}") 542 | return data_dict 543 | 544 | print(f"No cached data found for timeframe {timeframe}, downloading...") 545 | 546 | # Get all available symbols 547 | all_symbols = data_processor.downloader.get_all_symbols() 548 | symbols = [s.replace('USDT', '') for s in all_symbols] 549 | print(f"Found {len(symbols)} available symbols") 550 | 551 | # Convert start date to timestamp 552 | start_timestamp = int(historical_start_date.timestamp()) 553 | 554 | # Get current time as end timestamp 555 | end_timestamp = int(time.time()) 556 | 557 | # Download data for all symbols 558 | data_dict = {} 559 | 560 | print(f"Downloading data for timeframe {timeframe} from {historical_start_date} to now...") 561 | 562 | for symbol in symbols: 563 | print(f"Downloading data for {symbol} ({timeframe})...") 564 | 565 | # Get data 566 | df = data_processor.get_data( 567 | symbol, 568 | timeframe, 569 | start_timestamp, 570 | end_timestamp 571 | ) 572 | 573 | if not df.empty: 574 | data_dict[symbol] = df 575 | print(f"Downloaded {len(df)} data points for {symbol} ({timeframe})") 576 | else: 577 | print(f"Failed to download data for {symbol} ({timeframe})") 578 | data_dict[symbol] = None 579 | 580 | # Sleep to avoid API rate limits 581 | time.sleep(config.api_sleep_seconds) 582 | 583 | # Save to cache 584 | FileManager.save_to_cache(data_dict, cache_file) 585 | 586 | print(f"Saved data for timeframe {timeframe} to cache") 587 | 588 | return data_dict 589 | 590 | 591 | class ReferenceDataManager: 592 | """Manages reference trend data loading and visualization""" 593 | 594 | @staticmethod 595 | def get_reference_cache_path(output_dir: str, symbol: str, timeframe: str, label: str, start_ts: int, end_ts: int) -> str: 596 | """Generate cache path for reference data""" 597 | reference_dir = 
os.path.join(output_dir, "reference") 598 | return os.path.join(reference_dir, f"ref_{symbol}_{timeframe}_{label}_{start_ts}_{end_ts}.pkl") 599 | 600 | @staticmethod 601 | def load_or_fetch_reference_data(symbol: str, start_datetime: datetime, end_datetime: datetime, 602 | timeframe: str, label: str, output_dir: str, timezone_name: str, 603 | data_processor, config: TrendAnalysisConfig) -> pd.DataFrame: 604 | """Load or fetch reference trend data with unified caching""" 605 | print(f"Loading reference trend for {symbol} ({timeframe}) from {start_datetime} to {end_datetime}...") 606 | 607 | # Convert datetime to timestamp 608 | start_ts = convert_datetime_to_timestamp(start_datetime, timezone_name) 609 | end_ts = convert_datetime_to_timestamp(end_datetime, timezone_name) 610 | 611 | # Create reference directory 612 | reference_dir = os.path.join(output_dir, "reference") 613 | FileManager.ensure_directories(reference_dir) 614 | 615 | # Cache file path 616 | cache_file = ReferenceDataManager.get_reference_cache_path(output_dir, symbol, timeframe, label, start_ts, end_ts) 617 | 618 | # Check if cache exists 619 | reference_data = FileManager.load_from_cache(cache_file) 620 | if reference_data is not None: 621 | print(f"Loading cached reference data for {symbol}...") 622 | 623 | # Create reference visualization if not already done 624 | viz_file = os.path.join(reference_dir, f"ref_{symbol}_{timeframe}_{label}_{start_ts}_{end_ts}.png") 625 | if not os.path.exists(viz_file): 626 | ReferenceDataManager.create_reference_visualization( 627 | reference_data['df'], 628 | reference_data.get('past_df'), 629 | reference_data.get('future_df'), 630 | symbol, timeframe, label, viz_file, 631 | timezone_name 632 | ) 633 | 634 | return reference_data['df'] # Return only reference data for comparison 635 | 636 | # Calculate extended period for past and future 637 | time_difference = end_ts - start_ts 638 | extended_start_ts = start_ts - time_difference * 1.0 # 1x past 639 | extended_end_ts = end_ts + time_difference * 2.0 # 2x future 640 | 641 | # Get extended data (past + reference + future) 642 | extended_df = data_processor.get_data( 643 | symbol, 644 | timeframe, 645 | extended_start_ts, 646 | extended_end_ts, 647 | include_buffer=False, 648 | is_reference=True 649 | ) 650 | 651 | if extended_df.empty: 652 | print(f"Failed to get extended data for {symbol} at {timeframe}") 653 | return None 654 | 655 | # Split data into past, reference, and future 656 | reference_start_time = pd.Timestamp.fromtimestamp(start_ts, tz='UTC') 657 | reference_end_time = pd.Timestamp.fromtimestamp(end_ts, tz='UTC') 658 | 659 | if extended_df.index.tz is None: 660 | extended_df.index = pd.to_datetime(extended_df.index, utc=True) 661 | elif extended_df.index.tz != pd.Timestamp.now(tz='UTC').tz: 662 | extended_df.index = extended_df.index.tz_convert('UTC') 663 | 664 | past_df = extended_df[extended_df.index < reference_start_time] 665 | reference_df = extended_df[(extended_df.index >= reference_start_time) & 666 | (extended_df.index <= reference_end_time)] 667 | future_df = extended_df[extended_df.index > reference_end_time] 668 | 669 | # Create reference data object to cache 670 | reference_data = { 671 | 'df': reference_df, # Only reference data for comparison 672 | 'past_df': past_df if not past_df.empty else None, 673 | 'future_df': future_df if not future_df.empty else None 674 | } 675 | 676 | # Save to cache 677 | FileManager.save_to_cache(reference_data, cache_file) 678 | 679 | # Create reference visualization with past + 
reference + future 680 | viz_file = os.path.join(reference_dir, f"ref_{symbol}_{timeframe}_{label}_{start_ts}_{end_ts}.png") 681 | ReferenceDataManager.create_reference_visualization( 682 | reference_df, past_df, future_df, symbol, timeframe, label, viz_file 683 | ) 684 | 685 | print(f"Saved reference data for {symbol} with {len(reference_df)} data points") 686 | return reference_df # Return only reference data for comparison 687 | 688 | 689 | @staticmethod 690 | def create_reference_visualization(reference_df: pd.DataFrame, past_df: pd.DataFrame, 691 | future_df: pd.DataFrame, symbol: str, timeframe: str, 692 | label: str, output_path: str, timezone_name: str = "UTC"): 693 | """Create reference visualization with two subplots: reference only and past + reference + future""" 694 | try: 695 | # Create figure with two subplots 696 | fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(20, 16), gridspec_kw={'height_ratios': [1, 1]}) 697 | 698 | # Normalize reference data (include volume) 699 | reference_normalized_df, _ = DataNormalizer.normalize_ohlc_dataframe(reference_df, include_volume=True) 700 | 701 | # Plot 1: Reference trend only with volume 702 | plot_candlesticks_with_volume(ax1, reference_normalized_df, volume_ratio=0.12) 703 | ax1.plot(reference_normalized_df.index, reference_normalized_df['SMA_30'], 'blue', linewidth=2, alpha=0.8, label='SMA30') 704 | ax1.plot(reference_normalized_df.index, reference_normalized_df['SMA_45'], 'orange', linewidth=2, alpha=0.8, label='SMA45') 705 | ax1.plot(reference_normalized_df.index, reference_normalized_df['SMA_60'], 'purple', linewidth=2, alpha=0.8, label='SMA60') 706 | ax1.set_title(f'Reference Trend: {symbol} ({timeframe}, {label})', fontsize=14) 707 | ax1.set_ylabel('Normalized Price [-1, 1]', fontsize=12) 708 | ax1.set_ylim(-1.2, 1.2) 709 | ax1.legend(loc='upper left') 710 | ax1.grid(True, alpha=0.3) 711 | 712 | # Plot 2: Past + Reference + Future 713 | # Combine all available data 714 | combined_parts = [] 715 | if past_df is not None and not past_df.empty: 716 | # Take last 1x reference length of past data 717 | past_length = len(reference_df) 718 | if len(past_df) >= past_length: 719 | past_data = past_df.iloc[-past_length:] 720 | else: 721 | past_data = past_df 722 | combined_parts.append(past_data) 723 | 724 | combined_parts.append(reference_df) 725 | 726 | if future_df is not None and not future_df.empty: 727 | # Take first 2x reference length of future data 728 | future_length = len(reference_df) * 2 729 | if len(future_df) >= future_length: 730 | future_data = future_df.iloc[:future_length] 731 | else: 732 | future_data = future_df 733 | combined_parts.append(future_data) 734 | 735 | if combined_parts: 736 | combined_df = pd.concat(combined_parts) 737 | 738 | # Use reference normalization parameters for entire combined data 739 | ref_ohlc = reference_df[['Open', 'High', 'Low', 'Close']].values 740 | ref_norm_params = DataNormalizer.calculate_normalization_params(ref_ohlc, (-1, 1)) 741 | 742 | # Apply normalization to combined OHLC data 743 | combined_ohlc = combined_df[['Open', 'High', 'Low', 'Close']].values 744 | combined_normalized = DataNormalizer.apply_normalization_params(combined_ohlc, ref_norm_params) 745 | 746 | combined_normalized_df = combined_df.copy() 747 | for i, column in enumerate(['Open', 'High', 'Low', 'Close']): 748 | combined_normalized_df[column] = combined_normalized[:, i] 749 | 750 | # Normalize SMA columns using reference parameters 751 | sma_columns = ['SMA_30', 'SMA_45', 'SMA_60'] 752 | for column in sma_columns: 
753 | if column in combined_normalized_df.columns: 754 | combined_normalized_df[column] = DataNormalizer.apply_normalization_params( 755 | combined_normalized_df[column].values.reshape(-1, 1), ref_norm_params 756 | ).flatten() 757 | 758 | # Separately normalize Volume 759 | if 'Volume' in combined_df.columns: 760 | volume_values = combined_df['Volume'].values.reshape(-1, 1) 761 | volume_norm_params = DataNormalizer.calculate_normalization_params(volume_values, (0, 1)) 762 | normalized_volume = DataNormalizer.apply_normalization_params(volume_values, volume_norm_params) 763 | combined_normalized_df['Volume'] = normalized_volume.flatten() 764 | 765 | plot_candlesticks_with_volume(ax2, combined_normalized_df, volume_ratio=0.12) 766 | ax2.plot(combined_normalized_df.index, combined_normalized_df['SMA_30'], 'blue', linewidth=2, alpha=0.8, label='SMA30') 767 | ax2.plot(combined_normalized_df.index, combined_normalized_df['SMA_45'], 'orange', linewidth=2, alpha=0.8, label='SMA45') 768 | ax2.plot(combined_normalized_df.index, combined_normalized_df['SMA_60'], 'purple', linewidth=2, alpha=0.8, label='SMA60') 769 | 770 | # Add vertical lines to mark reference boundaries 771 | reference_start = reference_df.index[0] 772 | reference_end = reference_df.index[-1] 773 | ax2.axvline(x=reference_start, color='blue', linestyle='--', linewidth=2, alpha=0.8, label='Reference Start') 774 | ax2.axvline(x=reference_end, color='red', linestyle='--', linewidth=2, alpha=0.8, label='Reference End') 775 | 776 | ax2.set_title(f'Extended View: {symbol} - Past + Reference + Future', fontsize=14) 777 | ax2.set_ylabel('Normalized Price (ref range: [-1, 1])', fontsize=12) 778 | ax2.legend(loc='upper left') 779 | ax2.grid(True, alpha=0.3) 780 | 781 | # Set y-axis range dynamically for combined data 782 | combined_values = combined_normalized_df[['Open', 'High', 'Low', 'Close']].values.flatten() 783 | y_min, y_max = np.min(combined_values), np.max(combined_values) 784 | y_padding = (y_max - y_min) * 0.1 if y_max > y_min else 0.1 785 | ax2.set_ylim(y_min - y_padding, y_max + y_padding) 786 | else: 787 | ax2.text(0.5, 0.5, 'No extended data available', 788 | ha='center', va='center', transform=ax2.transAxes, fontsize=14) 789 | ax2.set_title('Extended View: No Data', fontsize=14) 790 | 791 | # Format date ticks 792 | for ax in [ax1, ax2]: 793 | ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d')) 794 | plt.setp(ax.xaxis.get_majorticklabels(), rotation=45, ha='right') 795 | ax.tick_params(axis='both', which='major', labelsize=10) 796 | 797 | # Add info textbox 798 | past_length = len(past_data) if past_df is not None and not past_df.empty else 0 799 | future_length = len(future_data) if future_df is not None and not future_df.empty else 0 800 | 801 | info_text = ( 802 | f"Symbol: {symbol}\n" 803 | f"Timeframe: {timeframe}\n" 804 | f"Label: {label}\n" 805 | f"Reference Period: {format_dt_with_tz(reference_df.index[0], timezone_name)} to {format_dt_with_tz(reference_df.index[-1], timezone_name)}\n" 806 | f"Data Points: {len(reference_df)}\n" 807 | f"Extended View:\n" 808 | f" Past: {past_length} bars\n" 809 | f" Reference: {len(reference_df)} bars\n" 810 | f" Future: {future_length} bars" 811 | ) 812 | 813 | plt.figtext(0.02, 0.02, info_text, fontsize=10, 814 | bbox=dict(facecolor='white', alpha=0.8, boxstyle='round,pad=0.5')) 815 | 816 | # Adjust layout 817 | plt.tight_layout(rect=[0, 0.08, 1, 0.95]) 818 | plt.subplots_adjust(hspace=0.2) 819 | 820 | # Save and close 821 | plt.savefig(output_path, dpi=200, 
bbox_inches='tight') 822 | plt.close(fig) 823 | 824 | print(f"Saved reference visualization to {output_path}") 825 | except Exception as e: 826 | print(f"Error in reference visualization: {e}") 827 | 828 | 829 | # ================ File and Directory Management ================ 830 | 831 | class FileManager: 832 | """Handles file operations and directory management""" 833 | 834 | @staticmethod 835 | def ensure_directories(*dirs: str) -> None: 836 | """Create directories if they don't exist""" 837 | for directory in dirs: 838 | os.makedirs(directory, exist_ok=True) 839 | 840 | @staticmethod 841 | def get_cache_filename(base_dir: str, prefix: str, **kwargs) -> str: 842 | """Generate cache filename with parameters""" 843 | parts = [prefix] 844 | for key, value in kwargs.items(): 845 | parts.append(f"{key}_{value}") 846 | 847 | filename = "_".join(parts) + ".pkl" 848 | return os.path.join(base_dir, filename) 849 | 850 | @staticmethod 851 | def save_to_cache(data: any, filepath: str) -> None: 852 | """Save data to cache file""" 853 | os.makedirs(os.path.dirname(filepath), exist_ok=True) 854 | with open(filepath, 'wb') as f: 855 | pickle.dump(data, f) 856 | 857 | @staticmethod 858 | def load_from_cache(filepath: str) -> any: 859 | """Load data from cache file""" 860 | if os.path.exists(filepath): 861 | with open(filepath, 'rb') as f: 862 | return pickle.load(f) 863 | return None 864 | 865 | 866 | # ================ Abstract Base Classes ================ 867 | 868 | class BaseDataDownloader(ABC): 869 | """Abstract base class for data downloaders""" 870 | 871 | @abstractmethod 872 | def get_data(self, symbol: str, start_timestamp: int, end_timestamp: int, 873 | timeframe: str = "1h", **kwargs) -> Tuple[bool, pd.DataFrame]: 874 | """Get data for a symbol""" 875 | pass 876 | 877 | @abstractmethod 878 | def get_all_symbols(self) -> List[str]: 879 | """Get all available symbols""" 880 | pass 881 | 882 | 883 | class BaseDataProcessor(ABC): 884 | """Abstract base class for data processors""" 885 | 886 | def __init__(self, asset_type: str, sma_periods: List[int] = None): 887 | """Initialize data processor""" 888 | self.asset_type = asset_type 889 | self.processor = TimeSeriesProcessor(sma_periods) 890 | 891 | @abstractmethod 892 | def get_data(self, symbol: str, timeframe: str, start_ts: int, end_ts: int, **kwargs) -> pd.DataFrame: 893 | """Get and process data for a symbol""" 894 | pass 895 | 896 | 897 | # ================ Utility Functions ================ 898 | 899 | def parse_target_symbols(filepath: str, target_section: str = "###TARGETS") -> List[str]: 900 | """Parse target symbols from file""" 901 | if not os.path.exists(filepath): 902 | print(f"Target file not found: {filepath}") 903 | return [] 904 | 905 | targets = [] 906 | with open(filepath, 'r') as f: 907 | lines = f.readlines() 908 | 909 | target_section_found = False 910 | for line in lines: 911 | if target_section in line: 912 | target_section_found = True 913 | continue 914 | if target_section_found and line.strip(): 915 | symbols = line.strip().split(',') 916 | targets.extend([s.split(':')[-1].replace('USDT.P', '').strip() 917 | for s in symbols if s.strip()]) 918 | 919 | if not targets: 920 | print(f"No valid targets found in {filepath}") 921 | return [] 922 | 923 | return targets 924 | 925 | 926 | def create_output_directory(base_dir: str, prefix: str = "") -> str: 927 | """Create timestamped output directory""" 928 | timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") 929 | dir_name = f"{prefix}_{timestamp}" if prefix else timestamp 
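# Illustrative example of the naming scheme: prefix="run" yields something like "run_20240101_123000"; with no prefix, just the timestamp string. Callers use this as a per-run report directory.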
930 | output_path = os.path.join(base_dir, dir_name) 931 | FileManager.ensure_directories(output_path) 932 | return output_path 933 | 934 | 935 | def get_period_overlap(period1: Tuple[datetime, datetime], period2: Tuple[datetime, datetime]) -> bool: 936 | """Check if two time periods overlap""" 937 | start1, end1 = period1 938 | start2, end2 = period2 939 | 940 | return (start1 <= end2) and (start2 <= end1) 941 | 942 | 943 | def filter_non_overlapping_results(results: List[Dict], global_filtering: bool = True) -> List[Dict]: 944 | """Filter results to keep only non-overlapping periods with best scores""" 945 | if not results: 946 | return [] 947 | 948 | if global_filtering: 949 | # Global filtering: no overlaps across all symbols 950 | sorted_results = sorted(results, key=lambda x: x['similarity'], reverse=True) 951 | selected_results = [] 952 | 953 | for result in sorted_results: 954 | window_data = result.get('window_data') 955 | if window_data is None: 956 | continue 957 | 958 | current_period = (window_data.index[0], window_data.index[-1]) 959 | 960 | # Check overlap with already selected periods 961 | has_overlap = False 962 | for selected_result in selected_results: 963 | selected_data = selected_result.get('window_data') 964 | if selected_data is None: 965 | continue 966 | selected_period = (selected_data.index[0], selected_data.index[-1]) 967 | 968 | if get_period_overlap(current_period, selected_period): 969 | has_overlap = True 970 | break 971 | 972 | # If no overlap, add to selected 973 | if not has_overlap: 974 | selected_results.append(result) 975 | 976 | return selected_results 977 | 978 | else: 979 | # Per-symbol filtering: filter within each symbol independently 980 | symbol_groups = {} 981 | for result in results: 982 | symbol = result.get('symbol', 'unknown') 983 | if symbol not in symbol_groups: 984 | symbol_groups[symbol] = [] 985 | symbol_groups[symbol].append(result) 986 | 987 | # Filter each symbol group independently 988 | all_selected_results = [] 989 | for symbol, symbol_results in symbol_groups.items(): 990 | sorted_symbol_results = sorted(symbol_results, key=lambda x: x['similarity'], reverse=True) 991 | symbol_selected_results = [] 992 | 993 | for result in sorted_symbol_results: 994 | window_data = result.get('window_data') 995 | if window_data is None: 996 | continue 997 | 998 | current_period = (window_data.index[0], window_data.index[-1]) 999 | 1000 | # Check overlap with already selected periods for this symbol 1001 | has_overlap = False 1002 | for selected_result in symbol_selected_results: 1003 | selected_data = selected_result.get('window_data') 1004 | if selected_data is None: 1005 | continue 1006 | selected_period = (selected_data.index[0], selected_data.index[-1]) 1007 | 1008 | if get_period_overlap(current_period, selected_period): 1009 | has_overlap = True 1010 | break 1011 | 1012 | # If no overlap, add to selected for this symbol 1013 | if not has_overlap: 1014 | symbol_selected_results.append(result) 1015 | 1016 | # Add all selected results from this symbol to the overall list 1017 | all_selected_results.extend(symbol_selected_results) 1018 | 1019 | # Sort the final list by similarity score 1020 | all_selected_results.sort(key=lambda x: x['similarity'], reverse=True) 1021 | 1022 | return all_selected_results 1023 | 1024 | 1025 | # ================ Export Functions ================ 1026 | 1027 | __all__ = [ 1028 | 'TrendAnalysisConfig', 1029 | 'DataNormalizer', 1030 | 'TimeSeriesProcessor', 1031 | 'DTWCalculator', 1032 | 'FileManager', 1033 | 
'DataCacheManager', 1034 | 'ReferenceDataManager', 1035 | 'BaseDataDownloader', 1036 | 'BaseDataProcessor', 1037 | 'plot_candlesticks_with_volume', 1038 | 'calculate_timeframe_seconds', 1039 | 'convert_datetime_to_timestamp', 1040 | 'format_dt_with_tz', 1041 | 'parse_target_symbols', 1042 | 'create_output_directory', 1043 | 'get_period_overlap', 1044 | 'filter_non_overlapping_results', 1045 | ] -------------------------------------------------------------------------------- /crypto_historical_trend_finder.py: -------------------------------------------------------------------------------- 1 | """ 2 | Cryptocurrency Similar Pattern Finder using Dynamic Time Warping (DTW) 3 | ====================================================================== 4 | This script identifies cryptocurrency price patterns that are similar to predefined reference trends 5 | using Dynamic Time Warping (DTW) and Shape-based DTW algorithms. It searches historical data 6 | across multiple timeframes to find patterns matching reference trends, then analyzes their future 7 | price movements to provide statistical insights for trading strategies. 8 | 9 | KEY FEATURES: 10 | - Pattern Matching: Uses DTW and ShapeDTW algorithms to find similar price patterns 11 | - Multi-timeframe Analysis: Searches across 15m, 30m, 1h, 2h, 4h timeframes 12 | - Future Trend Prediction: Analyzes price movements after pattern completion (rise/fall statistics) 13 | - Comprehensive Visualization: Generates detailed candlestick charts with volume and moving averages 14 | - Statistical Analysis: Provides detailed statistics for different extension factors (0.25x to 2.5x) 15 | - Parallel Processing: Uses multiprocessing for efficient computation 16 | - Data Caching: Caches downloaded data by timeframe to avoid redundant API calls 17 | - Non-overlapping Filtering: Removes overlapping patterns for cleaner analysis 18 | - Extensible Configuration: Easy to add new reference trends and adjust parameters 19 | 20 | USAGE: 21 | # Basic usage with default parameters 22 | python crypto_historical_trend_finder.py 23 | 24 | # Custom parameters 25 | python crypto_historical_trend_finder.py -k 500 -s 10 26 | 27 | Example workflow: 28 | 1. Define reference trends (e.g., AVAX uptrend from Nov 9-14, 2023) 29 | 2. Script downloads historical data for 200+ cryptocurrencies 30 | 3. Finds similar patterns using DTW similarity matching 31 | 4. Analyzes future price movements after each pattern 32 | 5. Generates visualizations and statistical reports 33 | 6. 
Output: "75% of similar patterns resulted in price rises within 2x pattern length" 34 | """ 35 | 36 | import os 37 | import time 38 | import numpy as np 39 | import pandas as pd 40 | from datetime import datetime 41 | from multiprocessing import Pool, cpu_count 42 | import matplotlib.pyplot as plt 43 | import matplotlib.dates as mdates 44 | import argparse 45 | from src.downloader import CryptoDownloader 46 | from src.common import ( 47 | TrendAnalysisConfig, 48 | DataNormalizer, 49 | DTWCalculator, 50 | FileManager, 51 | ReferenceDataManager, 52 | DataCacheManager, 53 | BaseDataProcessor, 54 | create_output_directory, 55 | filter_non_overlapping_results, 56 | plot_candlesticks_with_volume, 57 | format_dt_with_tz 58 | ) 59 | 60 | # ================ Configuration ================ 61 | # Reference trends definition 62 | REFERENCE_TRENDS = { 63 | "AVAX": [ 64 | [datetime(2023, 11, 9, 12, 0), datetime(2023, 11, 14, 18, 0), "1h", "standard"], 65 | ], 66 | "MKR": [ 67 | [datetime(2023, 6, 24, 9, 0), datetime(2023, 7, 18, 5, 0), "4h", "standard"], 68 | ], 69 | "CRV": [ 70 | [datetime(2024, 10, 23, 1, 0), datetime(2024, 11, 24, 0, 0), "4h", "uptrend"], 71 | [datetime(2024, 11, 4, 0, 0), datetime(2024, 11, 21, 0, 0), "4h", "uptrend_1"], 72 | [datetime(2024, 11, 4, 0, 0), datetime(2024, 11, 29, 0, 0), "4h", "uptrend_2"], 73 | ], 74 | "GMT": [ 75 | [datetime(2022, 3, 26, 9, 0), datetime(2022, 4, 14, 21, 0), "4h", "uptrend"] 76 | ], 77 | "SOL": [ 78 | [datetime(2023, 9, 21, 1, 0), datetime(2023, 10, 15, 21, 0), "4h", "standard"] 79 | ], 80 | "LQTY": [ 81 | [datetime(2025, 5, 7, 5, 0), datetime(2025, 5, 9, 21, 0), "30m", "standard"] 82 | ], 83 | "MOODENG": [ 84 | [datetime(2025, 5, 8, 0, 0), datetime(2025, 5, 11, 1, 0), "1h", "standard"] 85 | ], 86 | } 87 | 88 | # Historical starting point (used only if no cached data exists) 89 | HISTORICAL_START_DATE = datetime(2021, 1, 1) 90 | 91 | # Timezone for datetime conversion 92 | TIMEZONE = "America/Los_Angeles" 93 | 94 | # Timeframes to analyze 95 | TIMEFRAMES_TO_ANALYZE = ["15m", "30m", "1h", "2h", "4h"] 96 | 97 | # Main output directory 98 | OUTPUT_DIR = "historical_trend_finder_reports" 99 | 100 | # Top K results to keep per reference trend 101 | TOP_K = 300 102 | 103 | # API request parameters 104 | API_SLEEP_SECONDS = 15 105 | 106 | # Overlap filtering strategy 107 | # True: Global filtering - no overlaps across all symbols 108 | # False: Per-symbol filtering - allow overlaps between different symbols 109 | GLOBAL_OVERLAP_FILTERING = True 110 | 111 | # DTW parameters 112 | DTW_WINDOW_RATIO = 0.12 113 | DTW_MAX_POINT_DISTANCE = 0.6 114 | DTW_WINDOW_RATIO_FOR_DIFF = 0.1 115 | DTW_MAX_POINT_DISTANCE_FOR_DIFF = 0.5 116 | 117 | # ShapeDTW parameters 118 | SHAPEDTW_BALANCE_PD_RATIO = 4 119 | PRICE_WEIGHT = 0.6 120 | DIFF_WEIGHT = 0.4 121 | SLOPE_WINDOW_SIZE = 5 122 | PAA_WINDOW_SIZE = 5 123 | 124 | # Window scaling factors to test 125 | WINDOW_SCALE_FACTORS = [0.9, 0.95, 1.0, 1.05, 1.1] 126 | 127 | # SMA periods for comparison 128 | SMA_PERIODS = [30, 45, 60] 129 | 130 | # Sliding window step size (as fraction of reference trend length) 131 | SLIDING_WINDOW_STEP_RATIO = 0.11 132 | 133 | # Minimum similarity score to consider (score threshold) 134 | MIN_SIMILARITY_SCORE = 0.25 135 | 136 | # Extension visualization parameters 137 | VIS_EXTENSION_PAST_LENGTH_FACTOR = 1.0 138 | VIS_EXTENSION_FUTURE_LENGTH_FACTOR = 2.0 139 | 140 | # Extension factors for statistics 141 | EXTENSION_FACTORS_FOR_STATS = [0.25, 0.5, 0.75, 1.0, 1.5, 2.0, 2.5] 142 | 143 | # 
================ Data Processing Classes ================ 144 | 145 | class DataProcessor(BaseDataProcessor): 146 | """Data processor for cryptocurrency analysis with caching support""" 147 | 148 | def __init__(self, config: TrendAnalysisConfig = None): 149 | """Initialize data processor for crypto analysis""" 150 | super().__init__("crypto", config.sma_periods if config else None) 151 | self.downloader = CryptoDownloader() 152 | self.config = config or TrendAnalysisConfig() 153 | 154 | def get_data(self, symbol: str, timeframe: str, start_ts: int, end_ts: int, 155 | include_buffer: bool = True, is_reference: bool = False) -> pd.DataFrame: 156 | """Get data with buffer period for SMA calculation""" 157 | if include_buffer: 158 | # Calculate buffer period for SMA calculation 159 | interval = end_ts - start_ts 160 | buffer_start_ts = start_ts - interval 161 | else: 162 | buffer_start_ts = start_ts 163 | 164 | # For crypto, add USDT if not already there 165 | if not symbol.endswith("USDT"): 166 | symbol_full = f"{symbol}USDT" 167 | else: 168 | symbol_full = symbol 169 | 170 | # Set validate=False for reference trends, otherwise use default (True) 171 | success, df = self.downloader.get_data( 172 | symbol_full, 173 | buffer_start_ts, 174 | end_ts, 175 | validate=not is_reference, # Disable validation for reference trends 176 | timeframe=timeframe 177 | ) 178 | 179 | if not success or df is None or df.empty: 180 | print(f"Failed to get data for {symbol} ({timeframe})") 181 | return pd.DataFrame() 182 | 183 | # Filter to requested time range 184 | start_time = pd.Timestamp.fromtimestamp(start_ts) 185 | end_time = pd.Timestamp.fromtimestamp(end_ts) 186 | 187 | # Use the processor from common to prepare the dataframe 188 | df = self.processor.prepare_dataframe(df) 189 | 190 | # Filter to requested time range after preparation 191 | df = df[(df.index >= start_time) & (df.index <= end_time)] 192 | 193 | return df 194 | 195 | 196 | # ================ DTW Similarity Calculator ================ 197 | 198 | class DTWSimilarityFinder: 199 | """Class to find similarity using DTW and ShapeDTW""" 200 | 201 | def __init__(self, config: TrendAnalysisConfig): 202 | """Initialize similarity finder with configuration""" 203 | self.config = config 204 | self.dtw_calc = DTWCalculator(config) 205 | 206 | def find_similarity_in_window(self, reference_df: pd.DataFrame, target_df: pd.DataFrame, 207 | window_start_index: int, window_size: int) -> dict: 208 | """Find similarity between reference trend and target window""" 209 | # Check window boundaries 210 | if window_start_index + window_size > len(target_df): 211 | return { 212 | "similarity": 0.0, 213 | "price_distance": float('inf'), 214 | "diff_distance": float('inf'), 215 | "price_path": None, 216 | "diff_path": None, 217 | "window_data": None, 218 | "window_info": None 219 | } 220 | 221 | # Extract window 222 | window = target_df.iloc[window_start_index:window_start_index + window_size] 223 | 224 | # Normalize reference and window features 225 | reference_price_normalized, reference_diff_normalized = self.dtw_calc.normalize_features(reference_df) 226 | window_price_normalized, window_diff_normalized = self.dtw_calc.normalize_features(window) 227 | 228 | # Initial DTW screening for price 229 | _, price_dtw_distance, _ = self.dtw_calc.calculate_dtw_similarity( 230 | reference_price_normalized, window_price_normalized, 231 | self.config.dtw_window_ratio, self.config.dtw_max_point_distance 232 | ) 233 | 234 | # If price distance is too high, return early 235 | if 
np.isinf(price_dtw_distance): 236 | return { 237 | "similarity": 0.0, 238 | "price_distance": float('inf'), 239 | "diff_distance": float('inf'), 240 | "price_path": None, 241 | "diff_path": None, 242 | "window_data": None, 243 | "window_info": None 244 | } 245 | 246 | # Initial DTW screening for difference 247 | _, diff_dtw_distance, _ = self.dtw_calc.calculate_dtw_similarity( 248 | reference_diff_normalized, window_diff_normalized, 249 | self.config.dtw_window_ratio_diff, self.config.dtw_max_point_distance_diff 250 | ) 251 | 252 | # If difference distance is too high, return early 253 | if np.isinf(diff_dtw_distance): 254 | return { 255 | "similarity": 0.0, 256 | "price_distance": float('inf'), 257 | "diff_distance": float('inf'), 258 | "price_path": None, 259 | "diff_path": None, 260 | "window_data": None, 261 | "window_info": None 262 | } 263 | 264 | # Define shape descriptors 265 | price_descriptor, diff_descriptor = self.dtw_calc.create_shape_descriptors() 266 | 267 | # Calculate ShapeDTW for price 268 | price_shape_distance, price_shape_path = self.dtw_calc.calculate_shapedtw( 269 | reference_price_normalized, window_price_normalized, price_descriptor, self.config.dtw_window_ratio 270 | ) 271 | 272 | # If no valid path found, return early 273 | if np.isinf(price_shape_distance): 274 | return { 275 | "similarity": 0.0, 276 | "price_distance": float('inf'), 277 | "diff_distance": float('inf'), 278 | "price_path": None, 279 | "diff_path": None, 280 | "window_data": None, 281 | "window_info": None 282 | } 283 | 284 | # Calculate ShapeDTW for difference 285 | diff_shape_distance, diff_shape_path = self.dtw_calc.calculate_shapedtw( 286 | reference_diff_normalized, window_diff_normalized, diff_descriptor, self.config.dtw_window_ratio_diff 287 | ) 288 | 289 | # If no valid path found, return early 290 | if np.isinf(diff_shape_distance): 291 | return { 292 | "similarity": 0.0, 293 | "price_distance": float('inf'), 294 | "diff_distance": float('inf'), 295 | "price_path": None, 296 | "diff_path": None, 297 | "window_data": None, 298 | "window_info": None 299 | } 300 | 301 | # Calculate final scores 302 | price_score = 1 / (1 + price_shape_distance) 303 | diff_score = 1 / (1 + diff_shape_distance * self.config.shapedtw_balance_pd_ratio) 304 | similarity = (price_score * self.config.price_weight) + (diff_score * self.config.diff_weight) 305 | 306 | return { 307 | "similarity": similarity, 308 | "price_distance": price_shape_distance, 309 | "diff_distance": diff_shape_distance, 310 | "price_path": price_shape_path, 311 | "diff_path": diff_shape_path, 312 | "window_data": window, 313 | "window_info": (window_start_index, window_size) 314 | } 315 | 316 | def process_target(self, args: tuple) -> dict: 317 | """Process a single target symbol (for multiprocessing)""" 318 | reference_df, target_df, symbol, timeframe, reference_symbol, reference_timeframe, reference_label = args 319 | 320 | if target_df is None or len(target_df) < len(reference_df): 321 | return { 322 | "symbol": symbol, 323 | "timeframe": timeframe, 324 | "ref_symbol": reference_symbol, 325 | "ref_timeframe": reference_timeframe, 326 | "ref_label": reference_label, 327 | "result": None 328 | } 329 | 330 | print(f"Processing {symbol} ({timeframe}) against {reference_symbol} ({reference_timeframe}, {reference_label})...") 331 | 332 | # Calculate sliding window parameters 333 | reference_length = len(reference_df) 334 | step_size = max(1, int(reference_length * SLIDING_WINDOW_STEP_RATIO)) 335 | 336 | best_result = None 337 | 
best_similarity = -1 338 | 339 | # Sliding window approach 340 | max_start_index = len(target_df) - reference_length 341 | 342 | for start_index in range(max_start_index, 0, -step_size): 343 | # Try different window scaling factors 344 | for factor in self.config.window_scale_factors: 345 | window_size = int(reference_length * factor) 346 | 347 | # Skip if window size exceeds available data 348 | if start_index + window_size > len(target_df): 349 | continue 350 | 351 | # Calculate similarity 352 | result = self.find_similarity_in_window(reference_df, target_df, start_index, window_size) 353 | 354 | # Keep only the best result that meets minimum similarity threshold 355 | if result["similarity"] >= MIN_SIMILARITY_SCORE and result["similarity"] > best_similarity: 356 | best_similarity = result["similarity"] 357 | best_result = result 358 | 359 | # Log progress for this result 360 | window_period = ( 361 | f"{format_dt_with_tz(result['window_data'].index[0], TIMEZONE)} to {format_dt_with_tz(result['window_data'].index[-1], TIMEZONE)}" 362 | if result['window_data'] is not None else "N/A" 363 | ) 364 | 365 | print(f" New best match for {symbol}: score={result['similarity']:.4f}, " 366 | f"price_distance={result['price_distance']:.4f}, " 367 | f"diff_distance={result['diff_distance']:.4f}, " 368 | f"window={window_period}, factor={factor}") 369 | 370 | return { 371 | "symbol": symbol, 372 | "timeframe": timeframe, 373 | "ref_symbol": reference_symbol, 374 | "ref_timeframe": reference_timeframe, 375 | "ref_label": reference_label, 376 | "result": best_result 377 | } 378 | 379 | 380 | # ================ Analysis Functions ================ 381 | 382 | def analyze_future_trend(pattern_df: pd.DataFrame, target_df: pd.DataFrame, 383 | extension_factors: list = None) -> dict: 384 | """Analyze future trend for different extension factors""" 385 | if extension_factors is None: 386 | extension_factors = EXTENSION_FACTORS_FOR_STATS 387 | 388 | # Find pattern end in target data 389 | pattern_end_date = pattern_df.index[-1] 390 | future_data = target_df[target_df.index > pattern_end_date] 391 | 392 | if len(future_data) == 0: 393 | return {factor: {'trend': 'no_future_data', 'data_points': 0, 'insufficient_data': False} for factor in extension_factors} 394 | 395 | pattern_length = len(pattern_df) 396 | pattern_last_close = pattern_df['Close'].iloc[-1] 397 | 398 | results = {} 399 | 400 | for factor in extension_factors: 401 | future_length = int(pattern_length * factor) 402 | 403 | if future_length < 1: 404 | results[factor] = {'trend': 'invalid_factor', 'data_points': 0, 'insufficient_data': False} 405 | continue 406 | 407 | if future_length > len(future_data): 408 | # If requested length exceeds available data, but some data is available 409 | if len(future_data) > 0: 410 | # Use all available future data 411 | future_sample = future_data 412 | future_last_close = future_sample['Close'].iloc[-1] 413 | trend = 'rise' if future_last_close > pattern_last_close else 'fall' 414 | results[factor] = { 415 | 'trend': trend, 416 | 'data_points': len(future_sample), 417 | 'price_change': future_last_close - pattern_last_close, 418 | 'price_change_pct': ((future_last_close - pattern_last_close) / pattern_last_close) * 100, 419 | 'insufficient_data': True, # Mark insufficient data 420 | 'requested_length': future_length, 421 | 'available_length': len(future_data) 422 | } 423 | else: 424 | results[factor] = {'trend': 'no_future_data', 'data_points': 0, 'insufficient_data': False} 425 | continue 426 | 427 | 
future_sample = future_data.iloc[:future_length] 428 | future_last_close = future_sample['Close'].iloc[-1] 429 | 430 | trend = 'rise' if future_last_close > pattern_last_close else 'fall' 431 | results[factor] = { 432 | 'trend': trend, 433 | 'data_points': len(future_sample), 434 | 'price_change': future_last_close - pattern_last_close, 435 | 'price_change_pct': ((future_last_close - pattern_last_close) / pattern_last_close) * 100, 436 | 'insufficient_data': False # Sufficient data 437 | } 438 | 439 | return results 440 | 441 | 442 | def calculate_trend_statistics(results: list, data_dict: dict, extension_factors: list = None) -> dict: 443 | """Calculate trend statistics for a list of results""" 444 | if extension_factors is None: 445 | extension_factors = EXTENSION_FACTORS_FOR_STATS 446 | 447 | if not results: 448 | return { 449 | 'total_results': 0, 450 | 'default_factor_stats': {'rise': 0, 'fall': 0, 'insufficient_data': 0, 'no_future_data': 0}, 451 | 'extension_factor_stats': {factor: {'rise': 0, 'fall': 0, 'insufficient_data': 0, 'no_future_data': 0} for factor in extension_factors} 452 | } 453 | 454 | # Statistics for default extension factor 455 | default_stats = {'rise': 0, 'fall': 0, 'insufficient_data': 0, 'no_future_data': 0} 456 | 457 | # Statistics for different extension factors 458 | extension_stats = {factor: {'rise': 0, 'fall': 0, 'insufficient_data': 0, 'no_future_data': 0} for factor in extension_factors} 459 | 460 | for result in results: 461 | if result['window_data'] is None: 462 | continue 463 | 464 | symbol = result['symbol'] 465 | target_df = data_dict.get(symbol) 466 | 467 | if target_df is None: 468 | continue 469 | 470 | # Analyze trend for default extension factor (use VIS_EXTENSION_FUTURE_LENGTH_FACTOR) 471 | pattern_df = result['window_data'] 472 | default_trend = analyze_future_trend(pattern_df, target_df, [VIS_EXTENSION_FUTURE_LENGTH_FACTOR]) 473 | 474 | if VIS_EXTENSION_FUTURE_LENGTH_FACTOR in default_trend: 475 | trend_info = default_trend[VIS_EXTENSION_FUTURE_LENGTH_FACTOR] 476 | trend_result = trend_info['trend'] 477 | 478 | if trend_result in ['rise', 'fall']: 479 | if trend_info.get('insufficient_data', False): 480 | default_stats['insufficient_data'] += 1 481 | else: 482 | default_stats[trend_result] += 1 483 | elif trend_result == 'no_future_data': 484 | default_stats['no_future_data'] += 1 485 | else: 486 | default_stats['no_future_data'] += 1 # Other unknown cases categorized as no_future_data 487 | 488 | # Analyze trend for all extension factors 489 | all_trends = analyze_future_trend(pattern_df, target_df, extension_factors) 490 | 491 | for factor in extension_factors: 492 | if factor in all_trends: 493 | trend_info = all_trends[factor] 494 | trend_result = trend_info['trend'] 495 | 496 | if trend_result in ['rise', 'fall']: 497 | if trend_info.get('insufficient_data', False): 498 | extension_stats[factor]['insufficient_data'] += 1 499 | else: 500 | extension_stats[factor][trend_result] += 1 501 | elif trend_result == 'no_future_data': 502 | extension_stats[factor]['no_future_data'] += 1 503 | else: 504 | extension_stats[factor]['no_future_data'] += 1 # Other unknown cases 505 | 506 | return { 507 | 'total_results': len(results), 508 | 'default_factor_stats': default_stats, 509 | 'extension_factor_stats': extension_stats 510 | } 511 | 512 | 513 | def format_trend_statistics(stats: dict, factor_name: str = "Default") -> list: 514 | """Format trend statistics into readable text""" 515 | lines = [] 516 | total = stats['total_results'] 517 | 518 | 
if total == 0: 519 | lines.append(f"{factor_name}: No results available") 520 | return lines 521 | 522 | # Default factor statistics 523 | default_stats = stats['default_factor_stats'] 524 | rise_count = default_stats['rise'] 525 | fall_count = default_stats['fall'] 526 | insufficient_data_count = default_stats.get('insufficient_data', 0) 527 | no_future_data_count = default_stats.get('no_future_data', 0) 528 | 529 | rise_percentage = (rise_count / total) * 100 if total > 0 else 0 530 | fall_percentage = (fall_count / total) * 100 if total > 0 else 0 531 | insufficient_percentage = (insufficient_data_count / total) * 100 if total > 0 else 0 532 | no_future_percentage = (no_future_data_count / total) * 100 if total > 0 else 0 533 | 534 | lines.append(f"{factor_name} Extension Factor ({VIS_EXTENSION_FUTURE_LENGTH_FACTOR}x):") 535 | lines.append(f" Rise: {rise_count}/{total} ({rise_percentage:.1f}%)") 536 | lines.append(f" Fall: {fall_count}/{total} ({fall_percentage:.1f}%)") 537 | if insufficient_data_count > 0: 538 | lines.append(f" Insufficient Future Data: {insufficient_data_count}/{total} ({insufficient_percentage:.1f}%)") 539 | if no_future_data_count > 0: 540 | lines.append(f" No Future Data: {no_future_data_count}/{total} ({no_future_percentage:.1f}%)") 541 | 542 | # Extension factor statistics 543 | extension_stats = stats['extension_factor_stats'] 544 | lines.append(f"\nExtension Factor Analysis:") 545 | 546 | for factor in sorted(extension_stats.keys()): 547 | factor_stats = extension_stats[factor] 548 | rise_count = factor_stats['rise'] 549 | fall_count = factor_stats['fall'] 550 | insufficient_count = factor_stats.get('insufficient_data', 0) 551 | no_future_count = factor_stats.get('no_future_data', 0) 552 | 553 | rise_percentage = (rise_count / total) * 100 if total > 0 else 0 554 | fall_percentage = (fall_count / total) * 100 if total > 0 else 0 555 | insufficient_percentage = (insufficient_count / total) * 100 if total > 0 else 0 556 | no_future_percentage = (no_future_count / total) * 100 if total > 0 else 0 557 | 558 | line = f" {factor}x: Rise {rise_count}({rise_percentage:.1f}%) | Fall {fall_count}({fall_percentage:.1f}%)" 559 | if insufficient_count > 0: 560 | line += f" | Insufficient {insufficient_count}({insufficient_percentage:.1f}%)" 561 | if no_future_count > 0: 562 | line += f" | No Future {no_future_count}({no_future_percentage:.1f}%)" 563 | lines.append(line) 564 | 565 | return lines 566 | 567 | 568 | def get_trend_direction(pattern_df: pd.DataFrame, target_df: pd.DataFrame, 569 | extension_factor: float = None) -> str: 570 | """Get the trend direction for a specific extension factor""" 571 | if extension_factor is None: 572 | extension_factor = VIS_EXTENSION_FUTURE_LENGTH_FACTOR 573 | 574 | trend_analysis = analyze_future_trend(pattern_df, target_df, [extension_factor]) 575 | 576 | if extension_factor in trend_analysis: 577 | trend_info = trend_analysis[extension_factor] 578 | trend_result = trend_info['trend'] 579 | 580 | if trend_result in ['rise', 'fall']: 581 | if trend_info.get('insufficient_data', False): 582 | return f"{trend_result}_insufficient" 583 | else: 584 | return trend_result 585 | else: 586 | return trend_result 587 | 588 | return 'unknown' 589 | 590 | 591 | # ================ Visualization Functions ================ 592 | 593 | def create_full_analysis_chart(reference_df: pd.DataFrame, window_df: pd.DataFrame, target_df: pd.DataFrame, 594 | symbol: str, reference_symbol: str, timeframe: str, reference_timeframe: str, 595 | reference_label: str, 
similarity: float, price_distance: float, diff_distance: float, 596 | visualization_dir: str) -> str: 597 | """Create comprehensive visualization with three subplots, all with volume""" 598 | try: 599 | # Calculate extension periods 600 | pattern_length = len(window_df) 601 | past_length = int(pattern_length * VIS_EXTENSION_PAST_LENGTH_FACTOR) 602 | future_length = int(pattern_length * VIS_EXTENSION_FUTURE_LENGTH_FACTOR) 603 | 604 | # Get pattern period information 605 | pattern_start_date = window_df.index[0] 606 | pattern_end_date = window_df.index[-1] 607 | 608 | # Get past data (before pattern) 609 | past_df = target_df[target_df.index < pattern_start_date] 610 | if len(past_df) >= past_length: 611 | past_data = past_df.iloc[-past_length:] 612 | else: 613 | past_data = past_df 614 | 615 | # Get future data (after pattern) 616 | future_df = target_df[target_df.index > pattern_end_date] 617 | 618 | # Analyze future trend to determine file name suffix 619 | trend_analysis = analyze_future_trend(window_df, target_df, [VIS_EXTENSION_FUTURE_LENGTH_FACTOR]) 620 | 621 | if VIS_EXTENSION_FUTURE_LENGTH_FACTOR in trend_analysis: 622 | trend_info = trend_analysis[VIS_EXTENSION_FUTURE_LENGTH_FACTOR] 623 | trend_result = trend_info['trend'] 624 | 625 | if trend_result in ['rise', 'fall']: 626 | if trend_info.get('insufficient_data', False): 627 | trend_suffix = f"_{trend_result}_insufficient" 628 | else: 629 | trend_suffix = f"_{trend_result}" 630 | elif trend_result == 'no_future_data': 631 | trend_suffix = "_no_future" 632 | else: 633 | trend_suffix = "_unknown" 634 | else: 635 | trend_suffix = "_unknown" 636 | 637 | if len(future_df) >= future_length: 638 | future_data = future_df.iloc[:future_length] 639 | else: 640 | future_data = future_df 641 | 642 | # Combine all data for extended view 643 | extended_parts = [] 644 | if not past_data.empty: 645 | extended_parts.append(past_data) 646 | extended_parts.append(window_df) 647 | if not future_data.empty: 648 | extended_parts.append(future_data) 649 | 650 | extended_df = pd.concat(extended_parts) if extended_parts else window_df 651 | 652 | # Normalize data independently for each subplot 653 | reference_normalized_df, _ = DataNormalizer.normalize_ohlc_dataframe(reference_df, include_volume=True) 654 | window_normalized_df, _ = DataNormalizer.normalize_ohlc_dataframe(window_df, include_volume=True) 655 | 656 | # For extended view, use pattern normalization parameters for OHLC 657 | pattern_ohlc = window_df[['Open', 'High', 'Low', 'Close']].values 658 | extended_norm_params = DataNormalizer.calculate_normalization_params(pattern_ohlc, (-1, 1)) 659 | 660 | # Apply pattern normalization to extended OHLC data 661 | extended_ohlc = extended_df[['Open', 'High', 'Low', 'Close']].values 662 | extended_normalized = DataNormalizer.apply_normalization_params(extended_ohlc, extended_norm_params) 663 | 664 | extended_normalized_df = extended_df.copy() 665 | for i, column in enumerate(['Open', 'High', 'Low', 'Close']): 666 | extended_normalized_df[column] = extended_normalized[:, i] 667 | 668 | if 'Volume' in extended_df.columns: 669 | volume_values = extended_df['Volume'].values.reshape(-1, 1) 670 | volume_norm_params = DataNormalizer.calculate_normalization_params(volume_values, (0, 1)) 671 | normalized_volume = DataNormalizer.apply_normalization_params(volume_values, volume_norm_params) 672 | extended_normalized_df['Volume'] = normalized_volume.flatten() 673 | 674 | # Also normalize SMA columns for extended view 675 | sma_columns = ['SMA_30', 'SMA_45', 
'SMA_60'] 676 | for column in sma_columns: 677 | if column in extended_normalized_df.columns: 678 | extended_normalized_df[column] = DataNormalizer.apply_normalization_params( 679 | extended_normalized_df[column].values.reshape(-1, 1), 680 | extended_norm_params 681 | ).flatten() 682 | 683 | fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(20, 40)) 684 | 685 | # Plot 1: Reference trend with volume 686 | plot_candlesticks_with_volume(ax1, reference_normalized_df, volume_ratio=0.12) 687 | ax1.plot(reference_normalized_df.index, reference_normalized_df['SMA_30'], 'blue', linewidth=1.1, alpha=0.9, label='SMA30') 688 | ax1.plot(reference_normalized_df.index, reference_normalized_df['SMA_45'], 'orange', linewidth=1.1, alpha=0.9, label='SMA45') 689 | ax1.plot(reference_normalized_df.index, reference_normalized_df['SMA_60'], 'purple', linewidth=1.1, alpha=0.9, label='SMA60') 690 | ax1.set_title(f'Reference Trend: {reference_symbol} ({reference_timeframe}, {reference_label})', fontsize=14) 691 | ax1.set_ylabel('Normalized Price [-1, 1]') 692 | ax1.set_ylim(-1.2, 1.2) 693 | ax1.legend(loc='upper left', fontsize=14) 694 | ax1.grid(True, alpha=0.1) 695 | 696 | # Plot 2: Target pattern with volume 697 | plot_candlesticks_with_volume(ax2, window_normalized_df, volume_ratio=0.12) 698 | ax2.plot(window_normalized_df.index, window_normalized_df['SMA_30'], 'blue', linewidth=1.1, alpha=0.9, label='SMA30') 699 | ax2.plot(window_normalized_df.index, window_normalized_df['SMA_45'], 'orange', linewidth=1.1, alpha=0.9, label='SMA45') 700 | ax2.plot(window_normalized_df.index, window_normalized_df['SMA_60'], 'purple', linewidth=1.1, alpha=0.9, label='SMA60') 701 | ax2.set_title(f'Target Pattern: {symbol} ({timeframe})', fontsize=14) 702 | ax2.set_ylabel('Normalized Price [-1, 1]') 703 | ax2.set_ylim(-1.2, 1.2) 704 | ax2.legend(loc='upper left', fontsize=14) 705 | ax2.grid(True, alpha=0.1) 706 | 707 | # Plot 3: Extended view (past + pattern + future) with volume 708 | plot_candlesticks_with_volume(ax3, extended_normalized_df, volume_ratio=0.12) 709 | ax3.plot(extended_normalized_df.index, extended_normalized_df['SMA_30'], 'blue', linewidth=1.1, alpha=0.9, label='SMA30') 710 | ax3.plot(extended_normalized_df.index, extended_normalized_df['SMA_45'], 'orange', linewidth=1.1, alpha=0.9, label='SMA45') 711 | ax3.plot(extended_normalized_df.index, extended_normalized_df['SMA_60'], 'purple', linewidth=1.1, alpha=0.9, label='SMA60') 712 | 713 | # Add vertical lines to mark pattern boundaries in extended view 714 | ax3.axvline(x=pattern_start_date, color='blue', linestyle='--', linewidth=1, alpha=0.7, label='Pattern Start') 715 | ax3.axvline(x=pattern_end_date, color='red', linestyle='--', linewidth=1, alpha=0.7, label='Pattern End') 716 | 717 | ax3.set_title(f'Extended Analysis: {symbol} ({timeframe}) - Past + Pattern + Future', fontsize=14) 718 | ax3.set_xlabel('Date') 719 | ax3.set_ylabel('Normalized Price (pattern range: [-1, 1])') 720 | ax3.legend(loc='upper left', fontsize=14) 721 | ax3.grid(True, alpha=0.1) 722 | 723 | # Format date ticks for all axes 724 | for ax in [ax1, ax2, ax3]: 725 | ax.tick_params(axis='x', rotation=45) 726 | ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d')) 727 | 728 | # Set appropriate number of ticks based on data length 729 | reference_days = (reference_df.index[-1] - reference_df.index[0]).days 730 | window_days = (window_df.index[-1] - window_df.index[0]).days 731 | extended_days = (extended_df.index[-1] - extended_df.index[0]).days 732 | 733 | if reference_days > 0: 734 | 
ax1.xaxis.set_major_locator(mdates.DayLocator(interval=max(1, reference_days // 5))) 735 | if window_days > 0: 736 | ax2.xaxis.set_major_locator(mdates.DayLocator(interval=max(1, window_days // 5))) 737 | if extended_days > 0: 738 | ax3.xaxis.set_major_locator(mdates.DayLocator(interval=max(1, extended_days // 8))) 739 | 740 | # Add info text including future data status 741 | actual_future_length = len(future_data) 742 | expected_future_length = future_length 743 | 744 | future_status = "" 745 | if actual_future_length == 0: 746 | future_status = " (No future data available)" 747 | elif actual_future_length < expected_future_length: 748 | future_status = f" (Only {actual_future_length}/{expected_future_length} available)" 749 | 750 | info_text = ( 751 | f"Similarity Score: {similarity:.4f}\n" 752 | f"Price Distance: {price_distance:.4f}\n" 753 | f"SMA Diff Distance: {diff_distance:.4f}\n" 754 | f"Pattern Period: {format_dt_with_tz(pattern_start_date, TIMEZONE)} to {format_dt_with_tz(pattern_end_date, TIMEZONE)}\n" 755 | f"Extended View:\n" 756 | f" Past Factor: {VIS_EXTENSION_PAST_LENGTH_FACTOR}x ({len(past_data)} bars)\n" 757 | f" Pattern: 1.0x ({len(window_df)} bars)\n" 758 | f" Future Factor: {VIS_EXTENSION_FUTURE_LENGTH_FACTOR}x ({actual_future_length} bars{future_status})" 759 | ) 760 | 761 | plt.figtext(0.02, 0.04, info_text, fontsize=14, 762 | bbox=dict(facecolor='white', alpha=0.8, boxstyle='round,pad=0.5')) 763 | 764 | # Add main title 765 | fig.suptitle(f"Trend Analysis: {reference_symbol} vs {symbol}", fontsize=20, y=0.95) 766 | 767 | # Adjust layout 768 | plt.tight_layout(rect=[0, 0.08, 1, 0.95]) 769 | plt.subplots_adjust(hspace=0.12) 770 | 771 | # Generate output filename with trend direction 772 | score = similarity 773 | timestamp = window_df.index[0].strftime("%Y%m%d") 774 | output_filename = f"score_{score:.4f}_{symbol}_{timestamp}{trend_suffix}.png" 775 | output_path = os.path.join(visualization_dir, output_filename) 776 | 777 | # Save figure 778 | FileManager.ensure_directories(visualization_dir) 779 | plt.savefig(output_path, dpi=150, bbox_inches='tight') 780 | plt.close(fig) 781 | 782 | print(f"Saved full analysis visualization to {output_path}") 783 | return output_path 784 | 785 | except Exception as e: 786 | print(f"Error in full analysis visualization: {e}") 787 | # Return a default path in case of error 788 | score = similarity 789 | timestamp = window_df.index[0].strftime("%Y%m%d") if not window_df.empty else "unknown" 790 | return os.path.join(visualization_dir, f"score_{score:.4f}_{symbol}_{timestamp}_error.png") 791 | 792 | 793 | def create_visualizations_parallel(args: tuple): 794 | """Worker function for parallel visualization""" 795 | target_df, result, reference_df, symbol, timeframe, reference_symbol, reference_timeframe, reference_label, visualization_dir = args 796 | 797 | if result is None or result["window_data"] is None: 798 | return None 799 | 800 | # Create visualization directory if it doesn't exist 801 | FileManager.ensure_directories(visualization_dir) 802 | 803 | # Create full analysis visualization 804 | output_path = create_full_analysis_chart( 805 | reference_df, 806 | result["window_data"], 807 | target_df, 808 | symbol, 809 | reference_symbol, 810 | timeframe, 811 | reference_timeframe, 812 | reference_label, 813 | result["similarity"], 814 | result["price_distance"], 815 | result["diff_distance"], 816 | visualization_dir 817 | ) 818 | 819 | return {'analysis_path': output_path} 820 | 821 | 822 | # ================ Main Function 
================ 823 | 824 | def main(): 825 | """Main function to run the trend similarity analysis""" 826 | # Parse command line arguments 827 | parser = argparse.ArgumentParser(description='Find similar trends in cryptocurrency data') 828 | parser.add_argument('-k', '--topk', type=int, default=TOP_K, help=f'Number of top matches to keep (default: {TOP_K})') 829 | parser.add_argument('-s', '--sleep', type=float, default=API_SLEEP_SECONDS, help=f'Sleep time between API requests (default: {API_SLEEP_SECONDS})') 830 | args = parser.parse_args() 831 | 832 | # Record start time 833 | start_time = time.time() 834 | 835 | # Create configuration from script constants 836 | config = TrendAnalysisConfig() 837 | config.sma_periods = SMA_PERIODS 838 | config.dtw_window_ratio = DTW_WINDOW_RATIO 839 | config.dtw_window_ratio_diff = DTW_WINDOW_RATIO_FOR_DIFF 840 | config.dtw_max_point_distance = DTW_MAX_POINT_DISTANCE 841 | config.dtw_max_point_distance_diff = DTW_MAX_POINT_DISTANCE_FOR_DIFF 842 | config.shapedtw_balance_pd_ratio = SHAPEDTW_BALANCE_PD_RATIO 843 | config.price_weight = PRICE_WEIGHT 844 | config.diff_weight = DIFF_WEIGHT 845 | config.slope_window_size = SLOPE_WINDOW_SIZE 846 | config.paa_window_size = PAA_WINDOW_SIZE 847 | config.window_scale_factors = WINDOW_SCALE_FACTORS 848 | config.min_query_length = 60 # Not used in this script but kept for consistency 849 | config.api_sleep_seconds = args.sleep 850 | config.request_time_buffer_ratio = 1.2 # Not used in this script but kept for consistency 851 | 852 | # Create run directory with timestamp 853 | run_directory = create_output_directory(OUTPUT_DIR) 854 | 855 | print(f"Configuration:") 856 | print(f" Past Extension Factor: {VIS_EXTENSION_PAST_LENGTH_FACTOR}x") 857 | print(f" Future Extension Factor: {VIS_EXTENSION_FUTURE_LENGTH_FACTOR}x") 858 | print(f" Extension Factors for Stats: {EXTENSION_FACTORS_FOR_STATS}") 859 | print(f" Overlap Filtering: {'Global' if GLOBAL_OVERLAP_FILTERING else 'Per-Symbol'}") 860 | print(f" Top K Results: {args.topk}") 861 | print(f" API Sleep: {args.sleep}s") 862 | 863 | # Initialize data processor 864 | data_processor = DataProcessor(config) 865 | 866 | # Process reference trends using unified manager 867 | reference_trends = [] 868 | reference_data = {} 869 | 870 | for reference_symbol, trends in REFERENCE_TRENDS.items(): 871 | for i, trend_info in enumerate(trends): 872 | start_datetime, end_datetime, reference_timeframe, reference_label = trend_info 873 | 874 | # Load reference trend using unified manager 875 | reference_df = ReferenceDataManager.load_or_fetch_reference_data( 876 | reference_symbol, start_datetime, end_datetime, reference_timeframe, reference_label, 877 | OUTPUT_DIR, TIMEZONE, data_processor, config 878 | ) 879 | 880 | if reference_df is not None and not reference_df.empty: 881 | reference_key = (reference_symbol, reference_timeframe, reference_label) 882 | reference_data[reference_key] = reference_df 883 | reference_trends.append((reference_symbol, reference_timeframe, reference_label)) 884 | print(f"Loaded reference trend for {reference_symbol} ({reference_timeframe}, {reference_label}) with {len(reference_df)} data points") 885 | 886 | if not reference_trends: 887 | print("No valid reference trends found. 
Exiting.") 888 | return 889 | 890 | # Process each timeframe separately 891 | all_results = {} 892 | all_reference_statistics = {} # Store statistics for overall summary 893 | 894 | for timeframe in TIMEFRAMES_TO_ANALYZE: 895 | print(f"\n{'='*80}") 896 | print(f"Processing timeframe: {timeframe}") 897 | print(f"{'='*80}\n") 898 | 899 | # Create timeframe results directory 900 | timeframe_results_directory = os.path.join(run_directory, f"{timeframe}_results") 901 | FileManager.ensure_directories(timeframe_results_directory) 902 | 903 | # Download data for this timeframe using unified cache manager 904 | data_dict = DataCacheManager.download_timeframe_data(timeframe, OUTPUT_DIR, config, HISTORICAL_START_DATE, data_processor) 905 | 906 | # Initialize similarity finder 907 | similarity_finder = DTWSimilarityFinder(config) 908 | 909 | # Dictionary to store results for this timeframe 910 | timeframe_results = {} 911 | 912 | # Process each reference trend - regardless of its timeframe 913 | for reference_symbol, reference_timeframe, reference_label in reference_trends: 914 | reference_key = (reference_symbol, reference_timeframe, reference_label) 915 | reference_df = reference_data[reference_key] 916 | 917 | print(f"\nProcessing reference trend: {reference_symbol} ({reference_timeframe}, {reference_label})") 918 | 919 | # Create result directory for this reference 920 | reference_result_directory = os.path.join(timeframe_results_directory, f"{reference_symbol}_{reference_timeframe}_{reference_label}") 921 | FileManager.ensure_directories(reference_result_directory) 922 | 923 | # Prepare arguments for parallel processing 924 | process_arguments = [] 925 | valid_symbols = [] 926 | 927 | for symbol, target_df in data_dict.items(): 928 | # Skip reference symbol itself 929 | if symbol == reference_symbol: 930 | continue 931 | 932 | # Check if we have enough data 933 | if target_df is not None and len(target_df) >= len(reference_df): 934 | process_arguments.append((reference_df, target_df, symbol, timeframe, reference_symbol, reference_timeframe, reference_label)) 935 | valid_symbols.append(symbol) 936 | 937 | print(f"Processing {len(valid_symbols)} valid symbols for {timeframe}...") 938 | 939 | # Process in parallel 940 | with Pool(processes=min(cpu_count()-1, len(valid_symbols))) if len(valid_symbols) > 1 else Pool(processes=1) as pool: 941 | symbol_results = pool.map(similarity_finder.process_target, process_arguments) 942 | 943 | # Collect all valid results from all symbols 944 | all_symbol_results = [] 945 | for result in symbol_results: 946 | symbol = result["symbol"] 947 | if result["result"] is not None and result["result"]["similarity"] > 0: 948 | # Add symbol information to the result 949 | final_result = result["result"].copy() 950 | final_result["symbol"] = symbol 951 | all_symbol_results.append(final_result) 952 | 953 | print(f"Found {len(all_symbol_results)} valid results before filtering...") 954 | 955 | # Apply overlap filtering based on configuration 956 | filtered_results = filter_non_overlapping_results(all_symbol_results, GLOBAL_OVERLAP_FILTERING) 957 | print(f"After {'global' if GLOBAL_OVERLAP_FILTERING else 'per-symbol'} filtering: {len(filtered_results)} results") 958 | 959 | # Sort by similarity (descending) 960 | filtered_results.sort(key=lambda x: x["similarity"], reverse=True) 961 | 962 | # Get top K results 963 | top_results = filtered_results[:args.topk] 964 | 965 | # Calculate statistics for this reference trend in this timeframe 966 | timeframe_statistics = 
967 |
968 |             # Store statistics for overall summary
969 |             if reference_key not in all_reference_statistics:
970 |                 all_reference_statistics[reference_key] = []
971 |             all_reference_statistics[reference_key].extend(top_results)
972 |
973 |             # Create summary for this reference
974 |             reference_summary = []
975 |             reference_summary.append(f"Reference: {reference_symbol} ({reference_timeframe}, {reference_label})")
976 |             reference_summary.append(f"Reference Period: {format_dt_with_tz(reference_df.index[0], TIMEZONE)} to {format_dt_with_tz(reference_df.index[-1], TIMEZONE)}")
977 |             reference_summary.append(f"Number of data points: {len(reference_df)}")
978 |             reference_summary.append(f"Filtering Strategy: {'Global' if GLOBAL_OVERLAP_FILTERING else 'Per-Symbol'}")
979 |             reference_summary.append("-" * 50)
980 |
981 |             # Add trend statistics
982 |             trend_statistics_lines = format_trend_statistics(timeframe_statistics, f"Timeframe {timeframe}")
983 |             reference_summary.extend(trend_statistics_lines)
984 |             reference_summary.append("-" * 50)
985 |
986 |             if top_results:
987 |                 # Generate visualizations for top results
988 |                 print(f"\nGenerating visualizations for top {len(top_results)} matches...")
989 |                 visualization_directory = os.path.join(reference_result_directory, "visualizations")
990 |                 FileManager.ensure_directories(visualization_directory)
991 |
992 |                 # Prepare visualization arguments
993 |                 visualization_arguments = []
994 |                 for result in top_results:
995 |                     symbol = result["symbol"]
996 |                     visualization_arguments.append((
997 |                         data_dict[symbol],        # Full target dataframe
998 |                         result,                   # Result with window data
999 |                         reference_df,             # Reference dataframe
1000 |                         symbol,                   # Symbol
1001 |                         timeframe,                # Timeframe
1002 |                         reference_symbol,         # Reference symbol
1003 |                         reference_timeframe,      # Reference timeframe
1004 |                         reference_label,          # Reference label
1005 |                         visualization_directory   # Visualization directory
1006 |                     ))
1007 |
1008 |                 # Process visualizations in parallel (always keep at least one worker)
1009 |                 with Pool(processes=max(1, min(cpu_count() - 1, len(visualization_arguments)))) as pool:
1010 |                     pool.map(create_visualizations_parallel, visualization_arguments)
1011 |
1012 |                 # Add results to summary
1013 |                 reference_summary.append("Top Results:")
1014 |                 for i, result in enumerate(top_results):
1015 |                     symbol = result["symbol"]
1016 |                     score = result["similarity"]
1017 |                     price_distance = result["price_distance"]
1018 |                     diff_distance = result["diff_distance"]
1019 |                     window_data = result["window_data"]
1020 |
1021 |                     # Get trend direction for this result
1022 |                     target_df = data_dict.get(symbol)
1023 |                     trend_direction = get_trend_direction(window_data, target_df) if target_df is not None else 'unknown'
1024 |
1025 |                     window_period = (
1026 |                         f"{format_dt_with_tz(window_data.index[0], TIMEZONE)} to {format_dt_with_tz(window_data.index[-1], TIMEZONE)}"
1027 |                         if window_data is not None else "N/A"
1028 |                     )
1029 |
1030 |                     reference_summary.append(f"{i+1}. {symbol} ({trend_direction.upper()})")
1031 |                     reference_summary.append(f"  Score: {score:.4f}")
1032 |                     reference_summary.append(f"  Price Distance: {price_distance:.4f}")
1033 |                     reference_summary.append(f"  SMA Diff Distance: {diff_distance:.4f}")
1034 |                     reference_summary.append(f"  Period: {window_period}")
1035 |                     reference_summary.append("")
1036 |             else:
1037 |                 reference_summary.append("No matching trends found")
1038 |
1039 |             # Save reference summary
1040 |             reference_summary_text = '\n'.join(reference_summary)
1041 |             reference_summary_file = os.path.join(reference_result_directory, "results_summary.txt")
1042 |             with open(reference_summary_file, 'w') as f:
1043 |                 f.write(reference_summary_text)
1044 |
1045 |             # Store results
1046 |             timeframe_results[reference_key] = {
1047 |                 "top_results": top_results,
1048 |                 "all_results": filtered_results,
1049 |                 "statistics": timeframe_statistics
1050 |             }
1051 |
1052 |             # Print summary
1053 |             print(f"\n{reference_summary_text}")
1054 |
1055 |         # Store results for this timeframe
1056 |         all_results[timeframe] = timeframe_results
1057 |
1058 |     # Calculate overall statistics for each reference trend
1059 |     overall_reference_statistics = {}
1060 |     for reference_key, all_reference_results in all_reference_statistics.items():
1061 |         # Combine data from all timeframes for this reference (later timeframes overwrite earlier data for the same symbol)
1062 |         combined_data_dictionary = {}
1063 |         for timeframe in TIMEFRAMES_TO_ANALYZE:
1064 |             timeframe_data = DataCacheManager.download_timeframe_data(timeframe, OUTPUT_DIR, config, HISTORICAL_START_DATE, data_processor)
1065 |             combined_data_dictionary.update(timeframe_data)
1066 |
1067 |         overall_reference_statistics[reference_key] = calculate_trend_statistics(all_reference_results, combined_data_dictionary, EXTENSION_FACTORS_FOR_STATS)
1068 |
1069 |     # Create overall summary
1070 |     overall_summary = []
1071 |     overall_summary.append(f"Trend Similarity Analysis Report")
1072 |     overall_summary.append(f"Date: {datetime.now().strftime('%Y-%m-%d %H:%M')}")
1073 |     overall_summary.append(f"Past Extension Factor: {VIS_EXTENSION_PAST_LENGTH_FACTOR}x")
1074 |     overall_summary.append(f"Future Extension Factor: {VIS_EXTENSION_FUTURE_LENGTH_FACTOR}x")
1075 |     overall_summary.append(f"Extension Factors for Stats: {EXTENSION_FACTORS_FOR_STATS}")
1076 |     overall_summary.append(f"Overlap Filtering: {'Global' if GLOBAL_OVERLAP_FILTERING else 'Per-Symbol'}")
1077 |     overall_summary.append(f"{'='*50}\n")
1078 |
1079 |     # Add overall statistics for each reference trend
1080 |     overall_summary.append("OVERALL STATISTICS (All Timeframes Combined)")
1081 |     overall_summary.append("="*50)
1082 |     for reference_key in reference_trends:
1083 |         if reference_key in overall_reference_statistics:
1084 |             reference_symbol, reference_timeframe, reference_label = reference_key
1085 |             overall_summary.append(f"\nReference: {reference_symbol} ({reference_timeframe}, {reference_label})")
1086 |             overall_summary.append("-" * 40)
1087 |
1088 |             overall_statistics_lines = format_trend_statistics(overall_reference_statistics[reference_key], "Overall")
1089 |             overall_summary.extend(overall_statistics_lines)
1090 |     overall_summary.append("\n" + "="*50)
1091 |
1092 |     # Include summary for each timeframe (WITHOUT detailed top matches)
1093 |     for timeframe, timeframe_results in all_results.items():
1094 |         if not timeframe_results:
1095 |             continue
1096 |
1097 |         overall_summary.append(f"\n{'='*50}")
1098 |         overall_summary.append(f"TIMEFRAME: {timeframe}")
1099 |         overall_summary.append(f"{'='*50}\n")
1100 |
1101 |         for reference_key, results in timeframe_results.items():
1102 |             reference_symbol, reference_timeframe, reference_label = reference_key
1103 |             top_results = results["top_results"]
1104 |             timeframe_statistics = results["statistics"]
1105 |
1106 |             overall_summary.append(f"Reference: {reference_symbol} ({reference_timeframe}, {reference_label})")
1107 |             overall_summary.append(f"{'-'*40}")
1108 |
1109 |             # Add timeframe-specific statistics
1110 |             timeframe_statistics_lines = format_trend_statistics(timeframe_statistics, f"Timeframe {timeframe}")
1111 |             overall_summary.extend(timeframe_statistics_lines)
1112 |
1113 |             # Add summary count only (no detailed matches)
1114 |             if top_results:
1115 |                 overall_summary.append(f"\nFound {len(top_results)} matching patterns")
1116 |
1117 |                 # Count trends by direction
1118 |                 data_dict = DataCacheManager.download_timeframe_data(timeframe, OUTPUT_DIR, config, HISTORICAL_START_DATE, data_processor)
1119 |                 rise_count = 0
1120 |                 fall_count = 0
1121 |                 insufficient_count = 0
1122 |                 no_future_count = 0
1123 |                 unknown_count = 0
1124 |
1125 |                 for result in top_results:
1126 |                     symbol = result["symbol"]
1127 |                     window_data = result["window_data"]
1128 |                     target_df = data_dict.get(symbol)
1129 |                     trend_direction = get_trend_direction(window_data, target_df) if target_df is not None else 'unknown'
1130 |
1131 |                     if trend_direction in ['rise', 'fall']:
1132 |                         if trend_direction == 'rise':
1133 |                             rise_count += 1
1134 |                         else:
1135 |                             fall_count += 1
1136 |                     elif 'insufficient' in trend_direction:
1137 |                         insufficient_count += 1
1138 |                     elif trend_direction == 'no_future_data':
1139 |                         no_future_count += 1
1140 |                     else:
1141 |                         unknown_count += 1
1142 |
1143 |                 trend_dist_line = f"Trend Distribution: {rise_count} RISE, {fall_count} FALL"
1144 |                 if insufficient_count > 0:
1145 |                     trend_dist_line += f", {insufficient_count} INSUFFICIENT"
1146 |                 if no_future_count > 0:
1147 |                     trend_dist_line += f", {no_future_count} NO_FUTURE"
1148 |                 if unknown_count > 0:
1149 |                     trend_dist_line += f", {unknown_count} UNKNOWN"
1150 |
1151 |                 overall_summary.append(trend_dist_line)
1152 |             else:
1153 |                 overall_summary.append(f"\nNo matching trends found")
1154 |
1155 |             overall_summary.append("")
1156 |
1157 |     # Save overall summary
1158 |     overall_summary_text = '\n'.join(overall_summary)
1159 |     overall_summary_file = os.path.join(run_directory, "overall_summary.txt")
1160 |     with open(overall_summary_file, 'w') as f:
1161 |         f.write(overall_summary_text)
1162 |
1163 |     # Print overall summary
1164 |     print("\n" + overall_summary_text)
1165 |     print(f"\nOverall summary saved to: {overall_summary_file}")
1166 |     print(f"Results saved to: {run_directory}")
1167 |
1168 |     # Calculate and output runtime
1169 |     end_time = time.time()
1170 |     runtime = end_time - start_time
1171 |     print(f"\nTotal runtime: {runtime:.2f} seconds ({runtime/60:.2f} minutes)")
1172 |
1173 |
1174 | if __name__ == "__main__":
1175 |     main()
1176 |
--------------------------------------------------------------------------------