├── data ├── raw │ └── .gitkeep ├── default_signals_config.json └── README.md ├── tests └── .gitkeep ├── models ├── saved │ └── .gitkeep ├── __init__.py ├── static_transformer.py ├── utils.py └── residual_tft.py ├── saved_models ├── stage2_boost │ └── .gitkeep ├── .gitkeep ├── result_demo.webp ├── .gitignore └── README.md ├── src ├── __init__.py ├── data_loader.py ├── inference.py └── trainer.py ├── requirements.txt ├── configs ├── example_sst_config.json └── example_hst_config.json ├── .gitignore ├── LICENSE ├── setup.py ├── notebooks └── Train and run model with demo data and your own data with gradio interface.ipynb ├── examples └── quick_start.py ├── README_CN.md └── README.md /data/raw/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /models/saved/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /saved_models/stage2_boost/.gitkeep: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /saved_models/.gitkeep: -------------------------------------------------------------------------------- 1 | # This file ensures the saved_models folder is tracked by git even when empty 2 | -------------------------------------------------------------------------------- /saved_models/result_demo.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FTF1990/Industrial-digital-twin-by-transformer/HEAD/saved_models/result_demo.webp -------------------------------------------------------------------------------- /saved_models/.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore all files in this directory 2 | * 3 | 4 | # Except this .gitignore file and .gitkeep 5 | !.gitignore 6 | !.gitkeep 7 | !README.md 8 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Source modules for training, inference, and data loading 3 | """ 4 | 5 | from .data_loader import SensorDataLoader 6 | from .trainer import ModelTrainer 7 | from .inference import ModelInference 8 | 9 | __all__ = ['SensorDataLoader', 'ModelTrainer', 'ModelInference'] 10 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Industrial Digital Twin Models by Transformer 3 | 4 | This package contains Transformer-based models for industrial sensor prediction. 5 | 6 | Enhanced version with Stage2 Residual Boost training system. 
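Illustrative usage (a minimal sketch; the sensor counts below are arbitrary
placeholders, not values taken from this repository):

    >>> import torch
    >>> from models import StaticSensorTransformer
    >>> model = StaticSensorTransformer(num_boundary_sensors=4, num_target_sensors=2)
    >>> x = torch.randn(8, 4)      # batch of 8 samples, one reading per boundary sensor
    >>> model(x).shape             # predictions for the 2 target sensors
    torch.Size([8, 2])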
7 | """ 8 | 9 | from .static_transformer import StaticSensorTransformer, SST 10 | 11 | __all__ = [ 12 | 'StaticSensorTransformer', 13 | 'SST', 14 | ] 15 | 16 | __version__ = '1.0.0' # Enhanced with Stage2 Boost 17 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # Core Dependencies 2 | torch>=2.0.0 3 | numpy>=1.24.0 4 | pandas>=2.0.0 5 | scikit-learn>=1.3.0 6 | scipy>=1.7.0 7 | 8 | # Visualization 9 | matplotlib>=3.7.0 10 | seaborn>=0.12.0 11 | 12 | # Web Interface 13 | gradio>=4.0.0 14 | 15 | # Progress Bars 16 | tqdm>=4.62.0 17 | 18 | # Jupyter Support (Optional) 19 | jupyter>=1.0.0 20 | notebook>=6.4.0 21 | ipywidgets>=7.6.0 22 | 23 | # Optional: For GPU acceleration 24 | # If using CUDA, install appropriate PyTorch version: 25 | # Visit https://pytorch.org/get-started/locally/ 26 | # Example for CUDA 11.8: 27 | # pip install torch==2.1.0+cu118 -f https://download.pytorch.org/whl/torch_stable.html 28 | -------------------------------------------------------------------------------- /configs/example_sst_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_type": "SST", 3 | "description": "Example configuration for SST (StaticSensorTransformer) model", 4 | "signals": { 5 | "boundary": [ 6 | "sensor_temperature_1", 7 | "sensor_pressure_1", 8 | "sensor_flow_1", 9 | "sensor_speed_1" 10 | ], 11 | "target": [ 12 | "sensor_quality_1", 13 | "sensor_vibration_1" 14 | ] 15 | }, 16 | "data_split": { 17 | "test_size": 0.2, 18 | "val_size": 0.2, 19 | "random_state": 42 20 | }, 21 | "model_architecture": { 22 | "d_model": 128, 23 | "nhead": 8, 24 | "num_layers": 3, 25 | "dropout": 0.1 26 | }, 27 | "training": { 28 | "epochs": 100, 29 | "batch_size": 64, 30 | "lr": 0.001, 31 | "weight_decay": 1e-5, 32 | "grad_clip": 1.0, 33 | "early_stop_patience": 25 34 | }, 35 | "scheduler": { 36 | "patience": 10, 37 | "factor": 0.5 38 | } 39 | } 40 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Python 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | *.so 6 | .Python 7 | build/ 8 | develop-eggs/ 9 | dist/ 10 | downloads/ 11 | eggs/ 12 | .eggs/ 13 | lib/ 14 | lib64/ 15 | parts/ 16 | sdist/ 17 | var/ 18 | wheels/ 19 | *.egg-info/ 20 | .installed.cfg 21 | *.egg 22 | MANIFEST 23 | 24 | # Virtual Environment 25 | venv/ 26 | env/ 27 | ENV/ 28 | env.bak/ 29 | venv.bak/ 30 | 31 | # Jupyter Notebook 32 | .ipynb_checkpoints 33 | *.ipynb_checkpoints/ 34 | 35 | # PyCharm 36 | .idea/ 37 | 38 | # VSCode 39 | .vscode/ 40 | 41 | # Data files (don't commit large datasets) 42 | data/raw/*.csv 43 | data/processed/*.csv 44 | *.h5 45 | *.hdf5 46 | 47 | # Model checkpoints (optional - you may want to commit some) 48 | models/saved/*.pth 49 | models/saved/*.pt 50 | *.pth 51 | *.pt 52 | 53 | # Logs 54 | logs/ 55 | *.log 56 | 57 | # OS 58 | .DS_Store 59 | Thumbs.db 60 | 61 | # Temporary files 62 | *.tmp 63 | *.bak 64 | *.swp 65 | *~ 66 | 67 | # Gradio cache 68 | gradio_cached_examples/ 69 | flagged/ 70 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 FTF1990 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of 
this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /configs/example_hst_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_type": "HST", 3 | "description": "Example configuration for HST (HybridSensorTransformer) model", 4 | "signals": { 5 | "boundary": [ 6 | "sensor_temperature_1", 7 | "sensor_pressure_1", 8 | "sensor_flow_1", 9 | "sensor_speed_1" 10 | ], 11 | "target": [ 12 | "sensor_quality_1", 13 | "sensor_vibration_1", 14 | "sensor_temperature_internal" 15 | ], 16 | "temporal": [ 17 | "sensor_vibration_1", 18 | "sensor_temperature_internal" 19 | ] 20 | }, 21 | "data_split": { 22 | "test_size": 0.2, 23 | "val_size": 0.2, 24 | "random_state": 42 25 | }, 26 | "model_architecture": { 27 | "d_model": 64, 28 | "nhead": 4, 29 | "num_layers": 2, 30 | "dropout": 0.1, 31 | "gain": 0.1 32 | }, 33 | "hst_specific": { 34 | "context_window": 5, 35 | "apply_smoothing": true, 36 | "use_temporal": true 37 | }, 38 | "training": { 39 | "epochs": 100, 40 | "batch_size": 64, 41 | "lr": 0.001, 42 | "weight_decay": 1e-5, 43 | "grad_clip": 1.0, 44 | "early_stop_patience": 25 45 | }, 46 | "scheduler": { 47 | "patience": 10, 48 | "factor": 0.5 49 | } 50 | } 51 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | with open("README.md", "r", encoding="utf-8") as fh: 4 | long_description = fh.read() 5 | 6 | setup( 7 | name="industrial-digital-twin-transformer", 8 | version="1.0.0", 9 | author="FTF1990", 10 | author_email="ftf1990@users.noreply.github.com", 11 | description="Transformer-based models for industrial digital twin sensor prediction", 12 | long_description=long_description, 13 | long_description_content_type="text/markdown", 14 | url="https://github.com/FTF1990/Industrial-digital-twin-by-transformer", 15 | packages=find_packages(), 16 | classifiers=[ 17 | "Development Status :: 4 - Beta", 18 | "Intended Audience :: Science/Research", 19 | "Intended Audience :: Developers", 20 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 21 | "License :: OSI Approved :: MIT License", 22 | "Programming Language :: Python :: 3", 23 | "Programming Language :: Python :: 3.8", 24 | "Programming Language :: Python :: 3.9", 25 | "Programming Language :: Python :: 3.10", 26 | "Programming Language :: Python :: 3.11", 27 | ], 28 | 
python_requires=">=3.8", 29 | install_requires=[ 30 | "torch>=2.0.0", 31 | "numpy>=1.21.0", 32 | "pandas>=1.3.0", 33 | "scikit-learn>=1.0.0", 34 | "scipy>=1.7.0", 35 | "matplotlib>=3.4.0", 36 | "seaborn>=0.11.0", 37 | "gradio>=4.0.0", 38 | "tqdm>=4.62.0", 39 | ], 40 | extras_require={ 41 | "dev": [ 42 | "pytest>=6.0", 43 | "black>=21.0", 44 | "flake8>=3.9", 45 | "mypy>=0.910", 46 | ], 47 | "notebook": [ 48 | "jupyter>=1.0.0", 49 | "notebook>=6.4.0", 50 | "ipywidgets>=7.6.0", 51 | ], 52 | }, 53 | ) 54 | -------------------------------------------------------------------------------- /saved_models/README.md: -------------------------------------------------------------------------------- 1 | # Saved Models Directory 2 | 3 | This directory is used to store trained models, scalers, and inference configurations. 4 | 5 | ## Supported File Types 6 | 7 | - **`*.pth`** - PyTorch model checkpoint files 8 | - **`*_scalers.pkl`** - Scaler files (StandardScaler for input/output normalization) 9 | - **`*_inference.json`** - Inference configuration files 10 | 11 | ## Usage 12 | 13 | ### 1. Place Your Files Here 14 | 15 | Simply copy your trained model files into this directory: 16 | 17 | ```bash 18 | Stage1 models 19 | saved_models/ 20 | ├── my_sst_model.pth 21 | ├── my_sst_model_scalers.pkl 22 | └── my_sst_model_inference.json 23 | Stage2 models 24 | saved_models/stage2_boost 25 | ├── my_stage2_model.pth 26 | ├── my_stage2_scalers.pkl 27 | └── my_stage2_inference.json 28 | ``` 29 | 30 | ### 2. Load in Gradio Interface 31 | 32 | #### Tab 2: SST Model Training 33 | - Trained models are automatically saved here 34 | 35 | #### Tab 3: Residual Extraction 36 | Three loading options: 37 | - **Load from Inference Config**: Select `*_inference.json` file 38 | - **Load from Model File**: Select `*.pth` file 39 | - **Load Scalers**: Select `*_scalers.pkl` file 40 | 41 | Steps: 42 | 1. Open the Gradio interface 43 | 2. Navigate to Tab 3 (🔬 residual extraction) 44 | 3. Click refresh buttons (🔄) to scan this folder 45 | 4. Select files from dropdown menus 46 | 5. Click load buttons (📥) to load 47 | 48 | ### 3. 
File Naming Convention 49 | 50 | Recommended naming pattern: 51 | ``` 52 | .pth # Model checkpoint 53 | _scalers.pkl # Scalers 54 | _inference.json # Inference config 55 | ``` 56 | 57 | Example: 58 | ``` 59 | SST_20250102_143025.pth 60 | SST_20250102_143025_scalers.pkl 61 | SST_20250102_143025_inference.json 62 | ``` 63 | 64 | ## Notes 65 | 66 | - Files in this directory are **not tracked by git** (except this README) 67 | - You can organize files in subdirectories - the system will scan recursively 68 | - Model files can be large - the `.gitignore` ensures they won't be committed 69 | -------------------------------------------------------------------------------- /data/default_signals_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "boundary_signals": [ 3 | "Other_1", 4 | "Flow_6", 5 | "Flow_7", 6 | "Flow_8", 7 | "Flow_9", 8 | "Flow_10", 9 | "Pres_10", 10 | "Load_3", 11 | "Gas_1", 12 | "Pres_11", 13 | "Pres_24", 14 | "Pres_25", 15 | "Pres_26", 16 | "Humidity_1", 17 | "Speed_1", 18 | "Temp_42", 19 | "Temp_50", 20 | "Gas_2", 21 | "Temp_1" 22 | ], 23 | "target_signals": [ 24 | "Valve_1", 25 | "Valve_2", 26 | "Valve_3", 27 | "Valve_4", 28 | "Valve_5", 29 | "Load_1", 30 | "Temp_3", 31 | "Temp_4", 32 | "Temp_5", 33 | "Temp_6", 34 | "Temp_7", 35 | "DP_1", 36 | "DP_2", 37 | "DP_3", 38 | "DP_4", 39 | "DP_5", 40 | "DP_6", 41 | "DP_7", 42 | "Pres_1", 43 | "Pres_2", 44 | "Pres_3", 45 | "Pres_4", 46 | "Pres_5", 47 | "Pres_6", 48 | "Pres_7", 49 | "Pres_8", 50 | "Pres_9", 51 | "Flow_1", 52 | "Flow_2", 53 | "Flow_3", 54 | "Flow_4", 55 | "Posn_1", 56 | "Posn_2", 57 | "Air_1", 58 | "Pres_12", 59 | "Pres_13", 60 | "Pres_14", 61 | "Pres_15", 62 | "Pres_16", 63 | "Pres_17", 64 | "Pres_18", 65 | "Pres_19", 66 | "Pres_20", 67 | "Pres_21", 68 | "Vib_1", 69 | "Vib_2", 70 | "Posn_5", 71 | "Pres_27", 72 | "Pres_28", 73 | "Pres_29", 74 | "Temp_10", 75 | "Temp_12", 76 | "Temp_11", 77 | "Temp_13", 78 | "Temp_14", 79 | "Temp_15", 80 | "Temp_16", 81 | "Temp_17", 82 | "Temp_18", 83 | "Temp_19", 84 | "Temp_20", 85 | "Temp_21", 86 | "Temp_22", 87 | "Temp_23", 88 | "Temp_24", 89 | "Temp_25", 90 | "Temp_43", 91 | "Temp_47", 92 | "Temp_48", 93 | "Temp_49", 94 | "Temp_51", 95 | "Temp_58", 96 | "Temp_77", 97 | "Temp_78", 98 | "Temp_79", 99 | "Temp_80", 100 | "Temp_81", 101 | "Temp_82", 102 | "Temp_83", 103 | "Temp_84", 104 | "Temp_85", 105 | "Temp_86", 106 | "Temp_87", 107 | "Temp_88", 108 | "Temp_89", 109 | "Temp_90", 110 | "Temp_91", 111 | "Temp_92", 112 | "Temp_59" 113 | ] 114 | } 115 | -------------------------------------------------------------------------------- /data/README.md: -------------------------------------------------------------------------------- 1 | # Data Folder 2 | 3 | This folder contains the datasets for training and evaluating the digital twin models. 4 | 5 | ## Structure 6 | 7 | ``` 8 | data/ 9 | ├── raw/ # Place your raw CSV sensor data here 10 | └── README.md # This file 11 | ``` 12 | 13 | ## Data Format 14 | 15 | Your CSV file should follow this format: 16 | 17 | ### Example CSV Structure 18 | 19 | ```csv 20 | timestamp,sensor_1,sensor_2,sensor_3,sensor_4,sensor_5,...,sensor_n 21 | 2025-01-01 00:00:00,23.5,101.3,45.2,78.9,12.3,...,67.8 22 | 2025-01-01 00:00:01,23.6,101.4,45.1,79.0,12.4,...,67.9 23 | 2025-01-01 00:00:02,23.7,101.5,45.3,79.1,12.5,...,68.0 24 | ... 25 | ``` 26 | 27 | ### Requirements 28 | 29 | 1. **Format**: CSV (Comma-Separated Values) 30 | 2. **Encoding**: UTF-8 31 | 3. **Headers**: First row must contain sensor names 32 | 4. 
**Timestamp** (Optional): First column can be a timestamp 33 | - If present, it will be automatically excluded from training 34 | 5. **Sensor Columns**: All other columns should contain numeric sensor measurements 35 | 6. **Missing Values**: Handle missing values before uploading (use interpolation or fill methods) 36 | 37 | ### Recommended Data Characteristics 38 | 39 | - **Minimum Samples**: At least 1000 timesteps for meaningful training 40 | - **Sensor Count**: 41 | - Boundary sensors: 3-50 sensors 42 | - Target sensors: 1-30 sensors 43 | - **Sampling Rate**: Consistent sampling intervals 44 | - **Data Quality**: 45 | - Remove outliers if necessary 46 | - Normalize extreme values 47 | - Handle sensor failures appropriately 48 | 49 | ### Example Dataset Structure 50 | 51 | For a manufacturing process: 52 | 53 | | Sensor Type | Examples | 54 | |-------------|----------| 55 | | **Boundary Conditions** (Inputs) | Temperature setpoints, Flow rates, Pressure inputs, Motor speeds | 56 | | **Target Sensors** (Outputs) | Internal temperatures, Product quality metrics, Vibration levels, Energy consumption | 57 | 58 | ## Placing Your Data 59 | 60 | 1. Save your CSV file in the `data/raw/` folder 61 | 2. Use a descriptive filename (e.g., `manufacturing_sensors_2025.csv`) 62 | 3. Update the path in your training code: 63 | 64 | ```python 65 | data_path = 'data/raw/manufacturing_sensors_2025.csv' 66 | ``` 67 | 68 | ## Dataset Examples (To Be Added) 69 | 70 | We will provide example datasets in future releases. Until then, you can: 71 | 72 | 1. Use your own industrial sensor data 73 | 2. Generate synthetic data for testing 74 | 3. Contact us for sample datasets 75 | 76 | ## Data Privacy 77 | 78 | **Important**: Do not commit sensitive or proprietary data to version control! 79 | 80 | - The `.gitignore` file excludes `*.csv` files in `data/raw/` 81 | - Always anonymize data before sharing 82 | - Remove confidential information from sensor names 83 | 84 | ## Getting Help 85 | 86 | If you have questions about data format or preparation: 87 | 1. Check the example notebooks in `notebooks/` 88 | 2. Read the documentation in the main README.md 89 | 3. Open an issue on GitHub 90 | 91 | --- 92 | 93 | **Note**: This folder is set up to store datasets locally. Large datasets are not tracked by Git. 94 | -------------------------------------------------------------------------------- /src/data_loader.py: -------------------------------------------------------------------------------- 1 | """ 2 | Data loader for sensor datasets 3 | """ 4 | 5 | import pandas as pd 6 | import numpy as np 7 | from sklearn.preprocessing import StandardScaler 8 | from sklearn.model_selection import train_test_split 9 | import torch 10 | from torch.utils.data import DataLoader, TensorDataset 11 | 12 | 13 | class SensorDataLoader: 14 | """ 15 | Data loader for industrial sensor datasets 16 | 17 | This class handles loading, preprocessing, and splitting of sensor data 18 | for training digital twin models. 
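    Example (a minimal sketch; the CSV path and signal names are placeholders):

        >>> loader = SensorDataLoader(data_path='data/raw/your_sensor_data.csv')
        >>> splits = loader.prepare_data(
        ...     boundary_signals=['sensor_temperature_1', 'sensor_pressure_1'],
        ...     target_signals=['sensor_quality_1'])
        >>> train_loader, val_loader = loader.create_dataloaders(
        ...     splits['X_train'], splits['y_train'],
        ...     splits['X_val'], splits['y_val'], batch_size=64)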
19 | """ 20 | 21 | def __init__(self, data_path=None, df=None): 22 | """ 23 | Initialize the data loader 24 | 25 | Args: 26 | data_path (str, optional): Path to CSV file containing sensor data 27 | df (pd.DataFrame, optional): Pre-loaded DataFrame 28 | """ 29 | if df is not None: 30 | self.df = df.copy() 31 | elif data_path is not None: 32 | self.df = pd.read_csv(data_path) 33 | else: 34 | self.df = None 35 | 36 | self.scaler_X = None 37 | self.scaler_y = None 38 | 39 | def load_data(self, data_path): 40 | """ 41 | Load data from CSV file 42 | 43 | Args: 44 | data_path (str): Path to CSV file 45 | 46 | Returns: 47 | pd.DataFrame: Loaded data 48 | """ 49 | self.df = pd.read_csv(data_path) 50 | return self.df 51 | 52 | def get_available_signals(self): 53 | """ 54 | Get list of available sensor signals 55 | 56 | Returns: 57 | list: List of signal names 58 | """ 59 | if self.df is None: 60 | return [] 61 | 62 | cols = self.df.columns.tolist() 63 | 64 | # Remove timestamp columns 65 | if cols and (cols[0].startswith('2025') or 66 | 'timestamp' in cols[0].lower() or 67 | 'time' in cols[0].lower()): 68 | cols = cols[1:] 69 | 70 | return cols 71 | 72 | def prepare_data(self, boundary_signals, target_signals, 73 | test_size=0.2, val_size=0.2, random_state=42): 74 | """ 75 | Prepare and split data for training 76 | 77 | Args: 78 | boundary_signals (list): List of boundary condition signal names 79 | target_signals (list): List of target signal names 80 | test_size (float): Proportion of data for testing. Default: 0.2 81 | val_size (float): Proportion of training data for validation. Default: 0.2 82 | random_state (int): Random seed for reproducibility. Default: 42 83 | 84 | Returns: 85 | dict: Dictionary containing train/val/test splits and scalers 86 | """ 87 | if self.df is None: 88 | raise ValueError("No data loaded. Please load data first.") 89 | 90 | # Extract features and targets 91 | X = self.df[boundary_signals].values 92 | y = self.df[target_signals].values 93 | 94 | # Fit scalers 95 | self.scaler_X = StandardScaler() 96 | self.scaler_y = StandardScaler() 97 | 98 | X_scaled = self.scaler_X.fit_transform(X) 99 | y_scaled = self.scaler_y.fit_transform(y) 100 | 101 | # Split data 102 | X_train, X_test, y_train, y_test = train_test_split( 103 | X_scaled, y_scaled, test_size=test_size, random_state=random_state 104 | ) 105 | 106 | X_train, X_val, y_train, y_val = train_test_split( 107 | X_train, y_train, test_size=val_size, random_state=random_state 108 | ) 109 | 110 | return { 111 | 'X_train': X_train, 112 | 'X_val': X_val, 113 | 'X_test': X_test, 114 | 'y_train': y_train, 115 | 'y_val': y_val, 116 | 'y_test': y_test, 117 | 'scaler_X': self.scaler_X, 118 | 'scaler_y': self.scaler_y 119 | } 120 | 121 | def create_dataloaders(self, X_train, y_train, X_val, y_val, 122 | batch_size=64, shuffle=True): 123 | """ 124 | Create PyTorch DataLoaders 125 | 126 | Args: 127 | X_train (np.ndarray): Training features 128 | y_train (np.ndarray): Training targets 129 | X_val (np.ndarray): Validation features 130 | y_val (np.ndarray): Validation targets 131 | batch_size (int): Batch size. Default: 64 132 | shuffle (bool): Whether to shuffle training data. 
Default: True 133 | 134 | Returns: 135 | tuple: (train_loader, val_loader) 136 | """ 137 | train_dataset = TensorDataset( 138 | torch.FloatTensor(X_train), 139 | torch.FloatTensor(y_train) 140 | ) 141 | val_dataset = TensorDataset( 142 | torch.FloatTensor(X_val), 143 | torch.FloatTensor(y_val) 144 | ) 145 | 146 | train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle) 147 | val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False) 148 | 149 | return train_loader, val_loader 150 | 151 | def inverse_transform_predictions(self, y_pred_scaled): 152 | """ 153 | Inverse transform scaled predictions back to original scale 154 | 155 | Args: 156 | y_pred_scaled (np.ndarray): Scaled predictions 157 | 158 | Returns: 159 | np.ndarray: Predictions in original scale 160 | """ 161 | if self.scaler_y is None: 162 | raise ValueError("Scaler not fitted. Please prepare data first.") 163 | 164 | return self.scaler_y.inverse_transform(y_pred_scaled) 165 | 166 | def get_data_info(self): 167 | """ 168 | Get information about loaded data 169 | 170 | Returns: 171 | dict: Data information including shape and column names 172 | """ 173 | if self.df is None: 174 | return {"status": "No data loaded"} 175 | 176 | return { 177 | "shape": self.df.shape, 178 | "columns": self.df.columns.tolist(), 179 | "num_samples": len(self.df), 180 | "num_features": len(self.df.columns) 181 | } 182 | -------------------------------------------------------------------------------- /models/static_transformer.py: -------------------------------------------------------------------------------- 1 | """ 2 | StaticSensorTransformer (SST): Sensor Sequence Transformer 3 | 4 | This module implements a novel Transformer architecture that treats fixed sensor arrays 5 | as sequences, replacing traditional token/word sequences in NLP. This innovation enables 6 | spatial relationship learning between sensors in industrial digital twin applications. 7 | 8 | Key Innovation: 9 | - Sensors as Sequence Elements: Unlike NLP where tokens represent words, here each 10 | position represents a physical sensor with learned positional embeddings. 11 | - Spatial Attention: Multi-head attention captures complex sensor inter-dependencies. 12 | - Industrial-Specific Design: Optimized for boundary-to-target sensor mapping. 13 | """ 14 | 15 | import torch 16 | import torch.nn as nn 17 | 18 | 19 | class StaticSensorTransformer(nn.Module): 20 | """ 21 | StaticSensorTransformer (SST): Sensor Sequence Transformer 22 | 23 | Innovative architecture - replaces traditional Transformer token sequences with fixed sensor sequences 24 | 25 | Core Innovation: 26 | --------------- 27 | Traditional NLP Transformers: 28 | Input: [Token_1, Token_2, ..., Token_N] (words/subwords) 29 | Position: Learned position embeddings for word order 30 | Attention: Captures semantic relationships between words 31 | 32 | SST (This Model): 33 | Input: [Sensor_1, Sensor_2, ..., Sensor_N] (physical sensors) 34 | Position: Learned position embeddings for sensor locations 35 | Attention: Captures spatial relationships between sensors 36 | 37 | Key Differences from NLP: 38 | ------------------------- 39 | 1. Fixed Sequence Length: N sensors is predetermined by physical system 40 | 2. Spatial Semantics: Position embeddings encode sensor locations, not temporal order 41 | 3. Cross-Sensor Dependencies: Attention learns physical causality (e.g., temperature 42 | sensor affects pressure sensor in industrial processes) 43 | 4. 
Domain-Specific: Designed for industrial sensor arrays, not language 44 | 45 | Architecture Details: 46 | --------------------- 47 | - Sensor Embedding: Projects each scalar sensor reading to d_model dimensions 48 | - Positional Encoding: Learnable parameters encoding sensor spatial positions 49 | - Multi-Head Attention: Captures complex inter-sensor relationships 50 | - Global Pooling: Aggregates sensor sequence information 51 | - Output Projection: Maps to target sensor predictions 52 | 53 | This design enables Transformers to excel at industrial digital twin tasks by 54 | treating sensor arrays as "sentences" where each sensor is a "word" with spatial 55 | rather than temporal semantics. 56 | 57 | Args: 58 | num_boundary_sensors (int): Number of boundary condition sensors (input sequence length) 59 | num_target_sensors (int): Number of target sensors to predict (output features) 60 | d_model (int): Transformer model dimension. Default: 128 61 | nhead (int): Number of attention heads. Default: 8 62 | num_layers (int): Number of transformer encoder layers. Default: 3 63 | dropout (float): Dropout rate for regularization. Default: 0.1 64 | 65 | Example: 66 | >>> model = StaticSensorTransformer( 67 | ... num_boundary_sensors=10, # 10 sensors in input sequence 68 | ... num_target_sensors=5, # Predict 5 target sensors 69 | ... d_model=128, 70 | ... nhead=8 71 | ... ) 72 | >>> x = torch.randn(32, 10) # Batch of 32 samples, 10 sensor readings 73 | >>> predictions = model(x) # Output: (32, 5) target predictions 74 | """ 75 | 76 | def __init__(self, num_boundary_sensors, num_target_sensors, 77 | d_model=128, nhead=8, num_layers=3, dropout=0.1): 78 | super(StaticSensorTransformer, self).__init__() 79 | 80 | self.num_boundary_sensors = num_boundary_sensors 81 | self.num_target_sensors = num_target_sensors 82 | self.d_model = d_model 83 | 84 | # Boundary condition embedding 85 | self.boundary_embedding = nn.Linear(1, d_model) 86 | self.boundary_position_encoding = nn.Parameter(torch.randn(num_boundary_sensors, d_model)) 87 | 88 | # Transformer encoder 89 | encoder_layer = nn.TransformerEncoderLayer( 90 | d_model=d_model, 91 | nhead=nhead, 92 | dim_feedforward=d_model * 2, 93 | dropout=dropout, 94 | batch_first=True 95 | ) 96 | self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers) 97 | 98 | # Output layer 99 | self.output_projection = nn.Linear(d_model, num_target_sensors) 100 | self.global_pool = nn.AdaptiveAvgPool1d(1) 101 | 102 | self._init_weights() 103 | 104 | def _init_weights(self): 105 | """Initialize model weights using Xavier uniform initialization""" 106 | for p in self.parameters(): 107 | if p.dim() > 1: 108 | nn.init.xavier_uniform_(p) 109 | 110 | def forward(self, boundary_conditions): 111 | """ 112 | Forward pass of the model 113 | 114 | Args: 115 | boundary_conditions (torch.Tensor): Input tensor of shape (batch_size, num_boundary_sensors) 116 | 117 | Returns: 118 | torch.Tensor: Predicted target sensor values of shape (batch_size, num_target_sensors) 119 | """ 120 | batch_size = boundary_conditions.shape[0] 121 | 122 | # Embed boundary conditions 123 | x = boundary_conditions.unsqueeze(-1) # (batch, sensors, 1) 124 | x = self.boundary_embedding(x) + self.boundary_position_encoding.unsqueeze(0) 125 | 126 | # Transform 127 | x = self.transformer(x) # (batch, sensors, d_model) 128 | 129 | # Global pooling and projection 130 | x = x.permute(0, 2, 1) # (batch, d_model, sensors) 131 | x = self.global_pool(x).squeeze(-1) # (batch, d_model) 132 | predictions = 
self.output_projection(x) # (batch, num_target_sensors) 133 | 134 | return predictions 135 | 136 | 137 | # Aliases for backward compatibility and convenience 138 | SST = StaticSensorTransformer 139 | CompactSensorTransformer = StaticSensorTransformer # Alias used in notebook implementations 140 | 141 | -------------------------------------------------------------------------------- /notebooks/Train and run model with demo data and your own data with gradio interface.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": { 7 | "id": "sj9m-dofwm1x", 8 | "colab": { 9 | "base_uri": "https://localhost:8080/" 10 | }, 11 | "outputId": "bd60750e-6ee5-4761-fddd-372950acfe56" 12 | }, 13 | "outputs": [ 14 | { 15 | "output_type": "stream", 16 | "name": "stdout", 17 | "text": [ 18 | "Cloning into 'Industrial-digital-twin-by-transformer'...\n", 19 | "fatal: could not read Password for 'https://%7Btoken%7D@github.com': No such device or address\n", 20 | "[Errno 2] No such file or directory: 'Industrial-digital-twin-by-transformer'\n", 21 | "/content\n" 22 | ] 23 | } 24 | ], 25 | "source": [ 26 | "from google.colab import userdata\n", 27 | "import os\n", 28 | "\n", 29 | "\n", 30 | "!git clone https://github.com/FTF1990/Industrial-digital-twin-by-transformer.git\n", 31 | "\n", 32 | "%cd Industrial-digital-twin-by-transformer\n", 33 | "!pip install kaggle" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": null, 39 | "metadata": { 40 | "colab": { 41 | "base_uri": "https://localhost:8080/" 42 | }, 43 | "id": "mJ92oPgfyn0u", 44 | "outputId": "56dc91d5-c9f3-47ed-a112-8f46ad62a9f1" 45 | }, 46 | "outputs": [ 47 | { 48 | "output_type": "stream", 49 | "name": "stdout", 50 | "text": [ 51 | "PyTorch\u7248\u672c: 2.8.0+cu126\n", 52 | "CUDA\u53ef\u7528: True\n", 53 | "GPU: Tesla T4\n" 54 | ] 55 | } 56 | ], 57 | "source": [ 58 | "#Make sure that you connected to T4 or A100 GPU(CPU env was not tested)\n", 59 | "import torch\n", 60 | "print(f\"PyTorch\u7248\u672c: {torch.__version__}\")\n", 61 | "print(f\"CUDA\u53ef\u7528: {torch.cuda.is_available()}\")\n", 62 | "if torch.cuda.is_available():\n", 63 | " print(f\"GPU: {torch.cuda.get_device_name(0)}\")" 64 | ] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "execution_count": null, 69 | "metadata": { 70 | "colab": { 71 | "base_uri": "https://localhost:8080/" 72 | }, 73 | "id": "EcGjEVrP7AS7", 74 | "outputId": "d55f8913-c386-4861-ac25-6a175ba8328f" 75 | }, 76 | "outputs": [ 77 | { 78 | "output_type": "stream", 79 | "name": "stdout", 80 | "text": [ 81 | "Downloading from https://www.kaggle.com/api/v1/datasets/download/tianffan/power-gen-machine?dataset_version_number=1...\n" 82 | ] 83 | }, 84 | { 85 | "output_type": "stream", 86 | "name": "stderr", 87 | "text": [ 88 | "100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 159M/159M [00:00<00:00, 243MB/s]" 89 | ] 90 | }, 91 | { 92 | "output_type": "stream", 93 | "name": "stdout", 94 | "text": [ 95 | "Extracting files...\n" 96 | ] 97 | }, 98 | { 99 | "output_type": "stream", 100 | "name": "stderr", 101 | "text": [ 102 | "\n" 103 | ] 104 | }, 105 | { 106 | "output_type": "stream", 107 | "name": "stdout", 108 | "text": [ 109 | "Path to dataset files: /root/.cache/kagglehub/datasets/tianffan/power-gen-machine/versions/1\n" 110 | ] 111 | } 112 | ], 113 | "source": [ 114 | "# Download demo data from kaggle\n", 115 | "import kagglehub\n", 116 | "\n", 117 | "path = 
kagglehub.dataset_download(\"tianffan/power-gen-machine\")\n", 118 | "\n", 119 | "print(\"Path to dataset files:\", path)" 120 | ] 121 | }, 122 | { 123 | "cell_type": "code", 124 | "execution_count": null, 125 | "metadata": { 126 | "id": "EiG-wMvC8rpd" 127 | }, 128 | "outputs": [], 129 | "source": [ 130 | "# copy demo data to project if Path to dataset files: /root/.cache/kagglehub/datasets/tianffan/power-gen-machine/versions/1\n", 131 | "import os\n", 132 | "os.makedirs('data', exist_ok=True)\n", 133 | "!cp /root/.cache/kagglehub/datasets/tianffan/power-gen-machine/versions/1/data.csv data/" 134 | ] 135 | }, 136 | { 137 | "cell_type": "code", 138 | "execution_count": null, 139 | "metadata": { 140 | "id": "SkUqsIu2SFLW", 141 | "colab": { 142 | "base_uri": "https://localhost:8080/" 143 | }, 144 | "outputId": "c004a029-f826-426c-afa6-32fd61f55719" 145 | }, 146 | "outputs": [ 147 | { 148 | "output_type": "stream", 149 | "name": "stdout", 150 | "text": [ 151 | "cp: cannot stat '/kaggle/input/power-gen-machine/data.csv': No such file or directory\n" 152 | ] 153 | } 154 | ], 155 | "source": [ 156 | "# copy demo data to project if Path to dataset files: /kaggle/input/\n", 157 | "import os\n", 158 | "os.makedirs('data', exist_ok=True)\n", 159 | "!cp /kaggle/input/power-gen-machine/data.csv data/" 160 | ] 161 | }, 162 | { 163 | "cell_type": "code", 164 | "execution_count": null, 165 | "metadata": { 166 | "colab": { 167 | "base_uri": "https://localhost:8080/" 168 | }, 169 | "id": "Ffig-cLTy0hb", 170 | "outputId": "603ec7c3-f629-4587-ff48-4e69fed17101" 171 | }, 172 | "outputs": [ 173 | { 174 | "output_type": "stream", 175 | "name": "stdout", 176 | "text": [ 177 | "python3: can't open file '/content/gradio_residual_tft_app.py': [Errno 2] No such file or directory\n" 178 | ] 179 | } 180 | ], 181 | "source": [ 182 | "#run and test with demo data in gradio app\n", 183 | "!python gradio_sensor_transformer_app.py" 184 | ] 185 | }, 186 | { 187 | "cell_type": "markdown", 188 | "source": [ 189 | "If you need to try your own data, just copy your data to path: /content/Industrial-digital-twin-by-transformer/data.\n", 190 | "Note: pls keep your data same format as the demo data which is a very simple csv format." 191 | ], 192 | "metadata": { 193 | "id": "iFf0jAllsmWT" 194 | } 195 | } 196 | ], 197 | "metadata": { 198 | "accelerator": "GPU", 199 | "colab": { 200 | "gpuType": "T4", 201 | "provenance": [] 202 | }, 203 | "kernelspec": { 204 | "display_name": "Python 3", 205 | "name": "python3" 206 | }, 207 | "language_info": { 208 | "name": "python" 209 | } 210 | }, 211 | "nbformat": 4, 212 | "nbformat_minor": 0 213 | } -------------------------------------------------------------------------------- /examples/quick_start.py: -------------------------------------------------------------------------------- 1 | """ 2 | Quick Start Example - Industrial Digital Twin by Transformer 3 | 4 | This script demonstrates a minimal example of training and using 5 | the StaticSensorTransformer (SST) model for sensor prediction. 
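One way to launch it (a suggestion, assuming the repository root as the working
directory so that the package imports and relative paths used below resolve):

    python -m examples.quick_start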
6 | """ 7 | 8 | import torch 9 | import pandas as pd 10 | import numpy as np 11 | 12 | # Import our modules 13 | from models.static_transformer import StaticSensorTransformer 14 | from src.data_loader import SensorDataLoader 15 | from src.trainer import ModelTrainer 16 | from src.inference import ModelInference 17 | 18 | 19 | def main(): 20 | print("=" * 80) 21 | print("Industrial Digital Twin by Transformer - Quick Start") 22 | print("=" * 80) 23 | 24 | # Set device 25 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 26 | print(f"\nUsing device: {device}") 27 | 28 | # ======================================== 29 | # 1. Load Data 30 | # ======================================== 31 | print("\n" + "=" * 80) 32 | print("Step 1: Loading Data") 33 | print("=" * 80) 34 | 35 | # Replace with your actual data path 36 | data_path = 'data/raw/your_sensor_data.csv' 37 | 38 | # Or use a DataFrame if you already have one loaded 39 | # df = pd.read_csv(data_path) 40 | # data_loader = SensorDataLoader(df=df) 41 | 42 | try: 43 | data_loader = SensorDataLoader(data_path=data_path) 44 | print("✅ Data loaded successfully!") 45 | print(data_loader.get_data_info()) 46 | except FileNotFoundError: 47 | print(f"❌ File not found: {data_path}") 48 | print("\nPlease:") 49 | print("1. Place your CSV file in 'data/raw/' folder") 50 | print("2. Update the data_path variable in this script") 51 | print("\nFor now, creating synthetic data for demonstration...") 52 | 53 | # Create synthetic data for demonstration 54 | n_samples = 5000 55 | n_boundary = 5 56 | n_target = 3 57 | 58 | synthetic_data = { 59 | f'boundary_{i}': np.random.randn(n_samples) for i in range(n_boundary) 60 | } 61 | synthetic_data.update({ 62 | f'target_{i}': np.random.randn(n_samples) for i in range(n_target) 63 | }) 64 | 65 | df = pd.DataFrame(synthetic_data) 66 | data_loader = SensorDataLoader(df=df) 67 | print("✅ Synthetic data created for demonstration") 68 | 69 | # ======================================== 70 | # 2. Configure Signals 71 | # ======================================== 72 | print("\n" + "=" * 80) 73 | print("Step 2: Configuring Sensors") 74 | print("=" * 80) 75 | 76 | available_signals = data_loader.get_available_signals() 77 | 78 | # Select first 5 as boundary, next 3 as targets 79 | # Adjust based on your data 80 | boundary_signals = available_signals[:5] 81 | target_signals = available_signals[5:8] if len(available_signals) > 5 else available_signals[:3] 82 | 83 | print(f"\nBoundary Sensors ({len(boundary_signals)}):") 84 | for sig in boundary_signals: 85 | print(f" • {sig}") 86 | 87 | print(f"\nTarget Sensors ({len(target_signals)}):") 88 | for sig in target_signals: 89 | print(f" • {sig}") 90 | 91 | # ======================================== 92 | # 3. Prepare Data 93 | # ======================================== 94 | print("\n" + "=" * 80) 95 | print("Step 3: Preparing Data") 96 | print("=" * 80) 97 | 98 | data_splits = data_loader.prepare_data( 99 | boundary_signals=boundary_signals, 100 | target_signals=target_signals, 101 | test_size=0.2, 102 | val_size=0.2, 103 | random_state=42 104 | ) 105 | 106 | print(f"\nData Split:") 107 | print(f" Training: {len(data_splits['X_train'])} samples") 108 | print(f" Validation: {len(data_splits['X_val'])} samples") 109 | print(f" Test: {len(data_splits['X_test'])} samples") 110 | 111 | # ======================================== 112 | # 4. 
Create Model 113 | # ======================================== 114 | print("\n" + "=" * 80) 115 | print("Step 4: Creating StaticSensorTransformer (SST) Model") 116 | print("=" * 80) 117 | 118 | model = StaticSensorTransformer( 119 | num_boundary_sensors=len(boundary_signals), 120 | num_target_sensors=len(target_signals), 121 | d_model=128, 122 | nhead=8, 123 | num_layers=3, 124 | dropout=0.1 125 | ) 126 | 127 | print(f"\n✅ Model created") 128 | print(f" Parameters: {sum(p.numel() for p in model.parameters()):,}") 129 | 130 | # ======================================== 131 | # 5. Train Model 132 | # ======================================== 133 | print("\n" + "=" * 80) 134 | print("Step 5: Training Model") 135 | print("=" * 80) 136 | 137 | # Create data loaders 138 | train_loader, val_loader = data_loader.create_dataloaders( 139 | data_splits['X_train'], 140 | data_splits['y_train'], 141 | data_splits['X_val'], 142 | data_splits['y_val'], 143 | batch_size=64 144 | ) 145 | 146 | # Configure training 147 | config = { 148 | 'lr': 0.001, 149 | 'weight_decay': 1e-5, 150 | 'epochs': 10, # Use 100+ for real training 151 | 'batch_size': 64, 152 | 'grad_clip': 1.0, 153 | 'early_stop_patience': 25, 154 | 'scheduler_patience': 10, 155 | 'scheduler_factor': 0.5 156 | } 157 | 158 | # Train 159 | trainer = ModelTrainer(model, device=str(device), config=config) 160 | history = trainer.train(train_loader, val_loader, verbose=True) 161 | 162 | print(f"\n✅ Training completed!") 163 | print(f" Best validation loss: {history['best_val_loss']:.6f}") 164 | 165 | # ======================================== 166 | # 6. Evaluate Model 167 | # ======================================== 168 | print("\n" + "=" * 80) 169 | print("Step 6: Evaluating Model") 170 | print("=" * 80) 171 | 172 | # Prepare test data 173 | X_test_original = data_splits['scaler_X'].inverse_transform(data_splits['X_test']) 174 | y_test_original = data_splits['scaler_y'].inverse_transform(data_splits['y_test']) 175 | 176 | # Create inference engine 177 | inference = ModelInference( 178 | model=model, 179 | scaler_X=data_splits['scaler_X'], 180 | scaler_y=data_splits['scaler_y'], 181 | device=str(device) 182 | ) 183 | 184 | # Evaluate 185 | metrics = inference.evaluate(X_test_original, y_test_original, target_signals) 186 | inference.print_metrics(metrics) 187 | 188 | # ======================================== 189 | # 7. Save Model 190 | # ======================================== 191 | print("\n" + "=" * 80) 192 | print("Step 7: Saving Model") 193 | print("=" * 80) 194 | 195 | import os 196 | os.makedirs('models/saved', exist_ok=True) 197 | 198 | save_path = 'models/saved/quickstart_sst_model.pth' 199 | trainer.save_model(save_path) 200 | print(f"\n✅ Model saved to: {save_path}") 201 | 202 | print("\n" + "=" * 80) 203 | print("Quick Start Completed Successfully!") 204 | print("=" * 80) 205 | print("\nNext Steps:") 206 | print("1. Try the full tutorial in notebooks/train_and_inference.ipynb") 207 | print("2. Experiment with V4 Hybrid Transformer for temporal data") 208 | print("3. Use the Gradio interface: python gradio_app.py") 209 | print("4. 
Customize model architecture and hyperparameters") 210 | print("=" * 80) 211 | 212 | 213 | if __name__ == "__main__": 214 | main() 215 | -------------------------------------------------------------------------------- /src/inference.py: -------------------------------------------------------------------------------- 1 | """ 2 | Inference module for trained models 3 | """ 4 | 5 | import torch 6 | import numpy as np 7 | from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error 8 | import matplotlib.pyplot as plt 9 | 10 | 11 | class ModelInference: 12 | """ 13 | Inference engine for trained digital twin models 14 | 15 | This class handles model predictions, evaluation, and visualization 16 | for both V1 and V4 models. 17 | """ 18 | 19 | def __init__(self, model, scaler_X, scaler_y, device='cuda'): 20 | """ 21 | Initialize the inference engine 22 | 23 | Args: 24 | model (nn.Module): Trained model (V1 or V4) 25 | scaler_X (StandardScaler): Scaler for input features 26 | scaler_y (StandardScaler): Scaler for target values 27 | device (str): Device to use for inference. Default: 'cuda' 28 | """ 29 | self.model = model 30 | self.scaler_X = scaler_X 31 | self.scaler_y = scaler_y 32 | self.device = torch.device(device if torch.cuda.is_available() else 'cpu') 33 | self.model.to(self.device) 34 | self.model.eval() 35 | 36 | def predict(self, X): 37 | """ 38 | Make predictions on input data 39 | 40 | Args: 41 | X (np.ndarray): Input features (unscaled) 42 | 43 | Returns: 44 | np.ndarray: Predictions in original scale 45 | """ 46 | # Scale input 47 | X_scaled = self.scaler_X.transform(X) 48 | 49 | # Predict 50 | with torch.no_grad(): 51 | X_tensor = torch.FloatTensor(X_scaled).to(self.device) 52 | y_pred_scaled = self.model(X_tensor).cpu().numpy() 53 | 54 | # Inverse transform 55 | y_pred = self.scaler_y.inverse_transform(y_pred_scaled) 56 | 57 | return y_pred 58 | 59 | def evaluate(self, X, y_true, target_signal_names=None): 60 | """ 61 | Evaluate model performance 62 | 63 | Args: 64 | X (np.ndarray): Input features (unscaled) 65 | y_true (np.ndarray): True target values (unscaled) 66 | target_signal_names (list, optional): Names of target signals 67 | 68 | Returns: 69 | dict: Evaluation metrics for each target signal 70 | """ 71 | y_pred = self.predict(X) 72 | 73 | metrics = {} 74 | n_targets = y_true.shape[1] 75 | 76 | for i in range(n_targets): 77 | signal_name = target_signal_names[i] if target_signal_names else f"Signal_{i}" 78 | 79 | r2 = r2_score(y_true[:, i], y_pred[:, i]) 80 | rmse = np.sqrt(mean_squared_error(y_true[:, i], y_pred[:, i])) 81 | mae = mean_absolute_error(y_true[:, i], y_pred[:, i]) 82 | 83 | metrics[signal_name] = { 84 | 'R2': r2, 85 | 'RMSE': rmse, 86 | 'MAE': mae 87 | } 88 | 89 | # Overall metrics 90 | metrics['Overall'] = { 91 | 'R2': np.mean([m['R2'] for m in metrics.values() if m != metrics.get('Overall')]), 92 | 'RMSE': np.mean([m['RMSE'] for m in metrics.values() if m != metrics.get('Overall')]), 93 | 'MAE': np.mean([m['MAE'] for m in metrics.values() if m != metrics.get('Overall')]) 94 | } 95 | 96 | return metrics 97 | 98 | def plot_predictions(self, X, y_true, signal_indices=None, 99 | target_signal_names=None, start_idx=0, end_idx=None): 100 | """ 101 | Plot predictions vs actual values 102 | 103 | Args: 104 | X (np.ndarray): Input features (unscaled) 105 | y_true (np.ndarray): True target values (unscaled) 106 | signal_indices (list, optional): Indices of signals to plot 107 | target_signal_names (list, optional): Names of target signals 108 | 
start_idx (int): Start index for plotting. Default: 0 109 | end_idx (int, optional): End index for plotting 110 | 111 | Returns: 112 | matplotlib.figure.Figure: The generated figure 113 | """ 114 | y_pred = self.predict(X) 115 | 116 | if end_idx is None: 117 | end_idx = len(y_true) 118 | 119 | y_true_slice = y_true[start_idx:end_idx] 120 | y_pred_slice = y_pred[start_idx:end_idx] 121 | 122 | # Determine which signals to plot 123 | if signal_indices is None: 124 | signal_indices = range(min(3, y_true.shape[1])) # Plot first 3 by default 125 | 126 | n_signals = len(signal_indices) 127 | fig, axes = plt.subplots(n_signals, 3, figsize=(18, 5*n_signals)) 128 | 129 | if n_signals == 1: 130 | axes = axes.reshape(1, -1) 131 | 132 | for i, sig_idx in enumerate(signal_indices): 133 | signal_name = target_signal_names[sig_idx] if target_signal_names else f"Signal_{sig_idx}" 134 | 135 | y_true_sig = y_true_slice[:, sig_idx] 136 | y_pred_sig = y_pred_slice[:, sig_idx] 137 | residuals = y_true_sig - y_pred_sig 138 | 139 | # Time series plot 140 | ax1 = axes[i, 0] 141 | ax1.plot(range(len(y_true_sig)), y_true_sig, label='Actual', linewidth=2, alpha=0.8) 142 | ax1.plot(range(len(y_pred_sig)), y_pred_sig, label='Predicted', linewidth=2, alpha=0.8) 143 | ax1.set_title(f'{signal_name}\nPrediction vs Actual') 144 | ax1.set_xlabel('Time Step') 145 | ax1.set_ylabel('Value') 146 | ax1.legend() 147 | ax1.grid(True, alpha=0.3) 148 | 149 | # Residuals plot 150 | ax2 = axes[i, 1] 151 | ax2.plot(range(len(residuals)), residuals, color='red', linewidth=1.5, alpha=0.7) 152 | ax2.axhline(y=0, color='black', linestyle='--', linewidth=2) 153 | ax2.fill_between(range(len(residuals)), residuals, alpha=0.3, color='red') 154 | ax2.set_title(f'{signal_name}\nResiduals') 155 | ax2.set_xlabel('Time Step') 156 | ax2.set_ylabel('Residual (Actual - Predicted)') 157 | ax2.grid(True, alpha=0.3) 158 | 159 | # Scatter plot 160 | ax3 = axes[i, 2] 161 | ax3.scatter(y_true_sig, y_pred_sig, alpha=0.6, s=20) 162 | min_val = min(y_true_sig.min(), y_pred_sig.min()) 163 | max_val = max(y_true_sig.max(), y_pred_sig.max()) 164 | ax3.plot([min_val, max_val], [min_val, max_val], 'r--', linewidth=2, label='Perfect Prediction') 165 | r2 = r2_score(y_true_sig, y_pred_sig) 166 | ax3.set_title(f'{signal_name}\nAccuracy (R²={r2:.3f})') 167 | ax3.set_xlabel('Actual Value') 168 | ax3.set_ylabel('Predicted Value') 169 | ax3.legend() 170 | ax3.grid(True, alpha=0.3) 171 | 172 | plt.tight_layout() 173 | return fig 174 | 175 | def print_metrics(self, metrics): 176 | """ 177 | Print evaluation metrics in a formatted way 178 | 179 | Args: 180 | metrics (dict): Metrics dictionary from evaluate() 181 | """ 182 | print("=" * 60) 183 | print("Model Evaluation Metrics") 184 | print("=" * 60) 185 | 186 | for signal_name, metric in metrics.items(): 187 | if signal_name == 'Overall': 188 | print("\n" + "=" * 60) 189 | print("OVERALL PERFORMANCE") 190 | print("=" * 60) 191 | 192 | print(f"\n{signal_name}:") 193 | print(f" R² Score: {metric['R2']:.4f}") 194 | print(f" RMSE: {metric['RMSE']:.4f}") 195 | print(f" MAE: {metric['MAE']:.4f}") 196 | 197 | print("\n" + "=" * 60) 198 | -------------------------------------------------------------------------------- /src/trainer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Trainer module for V1 and V4 models 3 | """ 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.optim as optim 8 | from torch.utils.data import DataLoader, TensorDataset 9 | import numpy as np 10 | 
from tqdm import tqdm 11 | 12 | 13 | class ModelTrainer: 14 | """ 15 | Unified trainer for V1 and V4 Transformer models 16 | 17 | This class handles the training loop, validation, early stopping, 18 | and learning rate scheduling for digital twin models. 19 | """ 20 | 21 | def __init__(self, model, device='cuda', config=None): 22 | """ 23 | Initialize the trainer 24 | 25 | Args: 26 | model (nn.Module): The model to train (V1 or V4) 27 | device (str): Device to use for training. Default: 'cuda' 28 | config (dict, optional): Training configuration 29 | """ 30 | self.model = model 31 | self.device = torch.device(device if torch.cuda.is_available() else 'cpu') 32 | self.model.to(self.device) 33 | 34 | # Default configuration 35 | self.config = { 36 | 'lr': 0.001, 37 | 'weight_decay': 1e-5, 38 | 'epochs': 100, 39 | 'batch_size': 64, 40 | 'grad_clip': 1.0, 41 | 'early_stop_patience': 25, 42 | 'scheduler_patience': 10, 43 | 'scheduler_factor': 0.5 44 | } 45 | 46 | if config: 47 | self.config.update(config) 48 | 49 | self.optimizer = None 50 | self.scheduler = None 51 | self.criterion = nn.MSELoss() 52 | self.train_losses = [] 53 | self.val_losses = [] 54 | self.best_val_loss = float('inf') 55 | self.best_model_state = None 56 | 57 | def setup_optimizer(self): 58 | """Setup optimizer and learning rate scheduler""" 59 | self.optimizer = optim.AdamW( 60 | self.model.parameters(), 61 | lr=self.config['lr'], 62 | weight_decay=self.config['weight_decay'] 63 | ) 64 | 65 | self.scheduler = optim.lr_scheduler.ReduceLROnPlateau( 66 | self.optimizer, 67 | patience=self.config['scheduler_patience'], 68 | factor=self.config['scheduler_factor'] 69 | ) 70 | 71 | def train_epoch(self, train_loader): 72 | """ 73 | Train for one epoch 74 | 75 | Args: 76 | train_loader (DataLoader): Training data loader 77 | 78 | Returns: 79 | float: Average training loss for the epoch 80 | """ 81 | self.model.train() 82 | train_loss = 0.0 83 | 84 | for batch_X, batch_y in train_loader: 85 | batch_X, batch_y = batch_X.to(self.device), batch_y.to(self.device) 86 | 87 | self.optimizer.zero_grad() 88 | predictions = self.model(batch_X) 89 | loss = self.criterion(predictions, batch_y) 90 | 91 | loss.backward() 92 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), 93 | max_norm=self.config['grad_clip']) 94 | self.optimizer.step() 95 | 96 | train_loss += loss.item() 97 | 98 | return train_loss / len(train_loader) 99 | 100 | def validate(self, val_loader): 101 | """ 102 | Validate the model 103 | 104 | Args: 105 | val_loader (DataLoader): Validation data loader 106 | 107 | Returns: 108 | float: Average validation loss 109 | """ 110 | self.model.eval() 111 | val_loss = 0.0 112 | 113 | with torch.no_grad(): 114 | for batch_X, batch_y in val_loader: 115 | batch_X, batch_y = batch_X.to(self.device), batch_y.to(self.device) 116 | predictions = self.model(batch_X) 117 | val_loss += self.criterion(predictions, batch_y).item() 118 | 119 | return val_loss / len(val_loader) 120 | 121 | def train(self, train_loader, val_loader, verbose=True): 122 | """ 123 | Complete training loop 124 | 125 | Args: 126 | train_loader (DataLoader): Training data loader 127 | val_loader (DataLoader): Validation data loader 128 | verbose (bool): Whether to print training progress. 
Default: True 129 | 130 | Returns: 131 | dict: Training history and best model state 132 | """ 133 | if self.optimizer is None: 134 | self.setup_optimizer() 135 | 136 | patience_counter = 0 137 | 138 | if verbose: 139 | print(f"Starting training on {self.device}") 140 | print(f"Model parameters: {sum(p.numel() for p in self.model.parameters()):,}") 141 | print("=" * 80) 142 | 143 | for epoch in range(self.config['epochs']): 144 | # Training 145 | train_loss = self.train_epoch(train_loader) 146 | self.train_losses.append(train_loss) 147 | 148 | # Validation 149 | val_loss = self.validate(val_loader) 150 | self.val_losses.append(val_loss) 151 | 152 | # Learning rate scheduling 153 | self.scheduler.step(val_loss) 154 | current_lr = self.optimizer.param_groups[0]['lr'] 155 | 156 | # Early stopping check 157 | if val_loss < self.best_val_loss: 158 | self.best_val_loss = val_loss 159 | self.best_model_state = self.model.state_dict().copy() 160 | patience_counter = 0 161 | status_marker = "⭐" 162 | else: 163 | patience_counter += 1 164 | status_marker = " " 165 | 166 | # Print progress 167 | if verbose: 168 | print(f"{status_marker} Epoch [{epoch+1:3d}/{self.config['epochs']:3d}] | " 169 | f"Train: {train_loss:.6f} | Val: {val_loss:.6f} | " 170 | f"Best: {self.best_val_loss:.6f} | LR: {current_lr:.2e} | " 171 | f"Patience: {patience_counter}/{self.config['early_stop_patience']}") 172 | 173 | # Early stopping 174 | if patience_counter >= self.config['early_stop_patience']: 175 | if verbose: 176 | print(f"\n🛑 Early stopping at epoch {epoch+1}") 177 | break 178 | 179 | # Load best model 180 | if self.best_model_state is not None: 181 | self.model.load_state_dict(self.best_model_state) 182 | 183 | if verbose: 184 | print("=" * 80) 185 | print(f"✅ Training completed! 
Best validation loss: {self.best_val_loss:.6f}") 186 | 187 | return { 188 | 'train_losses': self.train_losses, 189 | 'val_losses': self.val_losses, 190 | 'best_val_loss': self.best_val_loss, 191 | 'best_model_state': self.best_model_state 192 | } 193 | 194 | def save_model(self, path): 195 | """ 196 | Save model checkpoint 197 | 198 | Args: 199 | path (str): Path to save the model 200 | """ 201 | torch.save({ 202 | 'model_state_dict': self.model.state_dict(), 203 | 'optimizer_state_dict': self.optimizer.state_dict(), 204 | 'train_losses': self.train_losses, 205 | 'val_losses': self.val_losses, 206 | 'best_val_loss': self.best_val_loss, 207 | 'config': self.config 208 | }, path) 209 | 210 | def load_model(self, path): 211 | """ 212 | Load model checkpoint 213 | 214 | Args: 215 | path (str): Path to the saved model 216 | """ 217 | checkpoint = torch.load(path, map_location=self.device) 218 | self.model.load_state_dict(checkpoint['model_state_dict']) 219 | 220 | if 'optimizer_state_dict' in checkpoint and self.optimizer is not None: 221 | self.optimizer.load_state_dict(checkpoint['optimizer_state_dict']) 222 | 223 | if 'train_losses' in checkpoint: 224 | self.train_losses = checkpoint['train_losses'] 225 | self.val_losses = checkpoint['val_losses'] 226 | self.best_val_loss = checkpoint['best_val_loss'] 227 | -------------------------------------------------------------------------------- /models/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utility functions for data processing and model evaluation 3 | """ 4 | 5 | import numpy as np 6 | from scipy.signal import savgol_filter 7 | from scipy.ndimage import maximum_filter1d 8 | 9 | 10 | def create_temporal_context_data(X, y, context_window=5): 11 | """ 12 | Create temporal context data 13 | 14 | Creates temporal context windows for time-series sensor data. Each sample 15 | is expanded to include surrounding timesteps for temporal analysis. 
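    For example (an illustrative sketch with arbitrary shapes), 100 timesteps of
    3 boundary and 2 target sensors with context_window=5 yield 90 valid samples:

        >>> import numpy as np
        >>> X = np.random.randn(100, 3)
        >>> y = np.random.randn(100, 2)
        >>> X_ctx, y_ctx, idx = create_temporal_context_data(X, y, context_window=5)
        >>> X_ctx.shape, y_ctx.shape, len(idx)
        ((90, 11, 3), (90, 2), 90)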
16 | 17 | Args: 18 | X (np.ndarray): Input features of shape (n_samples, n_features) 19 | y (np.ndarray): Target values of shape (n_samples, n_targets) 20 | context_window (int): Number of timesteps to include before and after each sample 21 | 22 | Returns: 23 | tuple: (X_context, y_context, valid_indices) 24 | - X_context: Array of shape (valid_samples, context_size, n_features) 25 | - y_context: Array of shape (valid_samples, n_targets) 26 | - valid_indices: List of original indices that have valid context windows 27 | """ 28 | n_samples, n_features = X.shape 29 | context_size = 2 * context_window + 1 30 | 31 | valid_start = context_window 32 | valid_end = n_samples - context_window 33 | valid_samples = valid_end - valid_start 34 | 35 | if valid_samples <= 0: 36 | raise ValueError(f"Insufficient data, need at least {2*context_window+1} samples") 37 | 38 | X_context = np.zeros((valid_samples, context_size, n_features)) 39 | y_context = np.zeros((valid_samples, y.shape[1])) 40 | valid_indices = [] 41 | 42 | for i in range(valid_samples): 43 | original_idx = valid_start + i 44 | start_idx = original_idx - context_window 45 | end_idx = original_idx + context_window + 1 46 | 47 | X_context[i] = X[start_idx:end_idx] 48 | y_context[i] = y[original_idx] 49 | valid_indices.append(original_idx) 50 | 51 | return X_context, y_context, valid_indices 52 | 53 | 54 | def apply_ifd_smoothing(y_data, target_sensors, ifd_sensor_names, 55 | window_length=15, polyorder=3): 56 | """ 57 | Apply smoothing filter to specified IFD sensors 58 | 59 | Applies Savitzky-Golay smoothing filter to specified sensors to reduce noise 60 | while preserving peak features. Particularly useful for IFD (Industrial Fault Detection) 61 | sensors with noisy signals. 62 | 63 | Args: 64 | y_data (np.ndarray): Target sensor data of shape (n_samples, n_sensors) 65 | target_sensors (list): List of all target sensor names 66 | ifd_sensor_names (list): List of sensor names to apply smoothing to 67 | window_length (int): Length of the filter window. Default: 15 68 | polyorder (int): Order of the polynomial used for filtering. Default: 3 69 | 70 | Returns: 71 | np.ndarray: Smoothed sensor data with same shape as input 72 | """ 73 | y_smoothed = y_data.copy() 74 | 75 | for sensor in ifd_sensor_names: 76 | if sensor in target_sensors: 77 | idx = target_sensors.index(sensor) 78 | original_signal = y_data[:, idx] 79 | 80 | # Adjust window length for short signals 81 | window_len = min(window_length, len(original_signal) // 4) 82 | if window_len % 2 == 0: 83 | window_len += 1 84 | 85 | if window_len >= 3: 86 | # Apply Savitzky-Golay filter 87 | smoothed_signal = savgol_filter(original_signal, window_len, polyorder) 88 | 89 | # Peak enhancement to preserve important features 90 | peaks = maximum_filter1d(original_signal, size=window_len//3) 91 | is_peak = (original_signal == peaks) & (original_signal > np.percentile(original_signal, 75)) 92 | 93 | enhanced_signal = smoothed_signal.copy() 94 | enhanced_signal[is_peak] = smoothed_signal[is_peak] * 0.8 + original_signal[is_peak] * 1.2 95 | 96 | y_smoothed[:, idx] = enhanced_signal 97 | 98 | return y_smoothed 99 | 100 | 101 | def handle_duplicate_columns(df): 102 | """ 103 | Handle duplicate column names in DataFrame by adding numbered suffixes 104 | 105 | Handles duplicate column names in a DataFrame by appending numeric suffixes 106 | to duplicated columns while preserving the original column order. 
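    Example (an illustrative sketch with made-up column names):

        >>> import pandas as pd
        >>> df = pd.DataFrame([[1, 2, 3]], columns=['Temp_1', 'Temp_1', 'Pres_1'])
        >>> df, dupes = handle_duplicate_columns(df)
        >>> list(df.columns)
        ['Temp_1', 'Temp_1_#2', 'Pres_1']
        >>> dupes
        {'Temp_1': 1}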
107 | 108 | Args: 109 | df (pd.DataFrame): Input DataFrame that may contain duplicate column names 110 | 111 | Returns: 112 | tuple: (df, duplicates) 113 | - df: DataFrame with unique column names 114 | - duplicates: Dictionary mapping original column names to duplicate counts 115 | """ 116 | cols = df.columns.tolist() 117 | new_cols = [] 118 | col_counts = {} 119 | 120 | for col in cols: 121 | if col not in col_counts: 122 | col_counts[col] = 0 123 | new_cols.append(col) 124 | else: 125 | col_counts[col] += 1 126 | new_cols.append(f"{col}_#{col_counts[col] + 1}") 127 | 128 | df.columns = new_cols 129 | 130 | # Return statistics of duplicates 131 | duplicates = {k: v for k, v in col_counts.items() if v > 0} 132 | return df, duplicates 133 | 134 | 135 | def get_available_signals(df): 136 | """ 137 | Get all available signals 138 | 139 | Extracts available sensor signal names from a DataFrame, excluding timestamp columns. 140 | 141 | Args: 142 | df (pd.DataFrame): Input DataFrame containing sensor data 143 | 144 | Returns: 145 | list: List of available signal names 146 | """ 147 | if df is None: 148 | return [] 149 | 150 | cols = df.columns.tolist() 151 | 152 | # Remove timestamp columns (assuming first column might be timestamp) 153 | if cols and (cols[0].startswith('2025') or 154 | 'timestamp' in cols[0].lower() or 155 | 'time' in cols[0].lower()): 156 | cols = cols[1:] 157 | 158 | return cols 159 | 160 | 161 | def validate_signal_exclusivity_v1(boundary_signals, target_signals): 162 | """ 163 | Validate signal exclusivity for V1 model 164 | 165 | Validates that boundary and target signals don't overlap for V1 model. 166 | 167 | Args: 168 | boundary_signals (list): List of boundary condition signal names 169 | target_signals (list): List of target signal names 170 | 171 | Returns: 172 | tuple: (is_valid, error_msg) 173 | - is_valid: Boolean indicating if validation passed 174 | - error_msg: Error message if validation failed, empty string otherwise 175 | """ 176 | if not boundary_signals or not target_signals: 177 | return True, "" 178 | 179 | boundary_set = set(boundary_signals) 180 | target_set = set(target_signals) 181 | overlap = boundary_set & target_set 182 | 183 | if overlap: 184 | overlap_list = list(overlap) 185 | error_msg = f"❌ Signal exclusivity error!\n\nThe following signals appear in both boundary and target:\n" 186 | 187 | for i, sig in enumerate(overlap_list[:10], 1): 188 | error_msg += f" {i}. {sig}\n" 189 | 190 | if len(overlap_list) > 10: 191 | error_msg += f" ... and {len(overlap_list)-10} more duplicate signals\n" 192 | 193 | error_msg += f"\nPlease remove these signals from one of the positions!" 194 | return False, error_msg 195 | 196 | return True, "" 197 | 198 | 199 | def validate_signal_exclusivity_v4(boundary_signals, target_signals, temporal_signals): 200 | """ 201 | Validate signal exclusivity for V4 model 202 | 203 | Validates signal selections for V4 model: 204 | 1. Boundary and target signals must not overlap 205 | 2. 
Temporal signals must be a subset of target signals 206 | 207 | Args: 208 | boundary_signals (list): List of boundary condition signal names 209 | target_signals (list): List of target signal names 210 | temporal_signals (list): List of temporal signal names 211 | 212 | Returns: 213 | tuple: (is_valid, error_msg) 214 | - is_valid: Boolean indicating if validation passed 215 | - error_msg: Error message if validation failed, empty string otherwise 216 | """ 217 | if not boundary_signals or not target_signals: 218 | return True, "" 219 | 220 | errors = [] 221 | 222 | # Check boundary-target overlap 223 | boundary_set = set(boundary_signals) 224 | target_set = set(target_signals) 225 | overlap_bt = boundary_set & target_set 226 | 227 | if overlap_bt: 228 | overlap_list = list(overlap_bt) 229 | error_msg = f"Boundary and target signals overlap ({len(overlap_list)} signals):\n" 230 | for i, sig in enumerate(overlap_list[:5], 1): 231 | error_msg += f" {i}. {sig}\n" 232 | if len(overlap_list) > 5: 233 | error_msg += f" ... and {len(overlap_list)-5} more\n" 234 | errors.append(error_msg) 235 | 236 | # Check temporal signals are subset of target signals 237 | if temporal_signals: 238 | temporal_set = set(temporal_signals) 239 | invalid_temporal = temporal_set - target_set 240 | 241 | if invalid_temporal: 242 | invalid_list = list(invalid_temporal) 243 | error_msg = f"Temporal signals must be in target signals ({len(invalid_list)} invalid):\n" 244 | for i, sig in enumerate(invalid_list[:5], 1): 245 | error_msg += f" {i}. {sig}\n" 246 | if len(invalid_list) > 5: 247 | error_msg += f" ... and {len(invalid_list)-5} more\n" 248 | errors.append(error_msg) 249 | 250 | if errors: 251 | full_error = "❌ Signal selection error!\n\n" + "\n".join(errors) + "\n Please fix before training!" 
252 | return False, full_error 253 | 254 | return True, "" 255 | -------------------------------------------------------------------------------- /README_CN.md: -------------------------------------------------------------------------------- 1 | # Industrial Digital Twin by Transformer (基于 Transformer 的工业数字孪生) 2 | 3 | **[English](README.md)** | **[中文](README_CN.md)** 4 | 5 | [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT) 6 | [![Python 3.8+](https://img.shields.io/badge/python-3.8+-blue.svg)](https://www.python.org/downloads/) 7 | [![PyTorch](https://img.shields.io/badge/PyTorch-2.0+-ee4c2c.svg)](https://pytorch.org/) 8 | 9 | > **一个创新的基于 Transformer 的框架,专为复杂系统中的工业数字孪生建模设计,使用序列传感器输出和先进的残差提升训练方法。** 10 | 11 | 本项目引入了 Transformer 架构和残差提升训练方法,专门设计用于预测工业数字孪生应用中的传感器输出。与传统方法不同,我们的模型利用复杂工业环境中**多传感器系统的序列特性**,通过多阶段优化实现更好的预测精度。 12 | 13 | --- 14 | 15 | **如果您觉得这个项目有帮助,请考虑给它一个 ⭐ star!您的支持帮助更多人发现这项工作,并激励项目持续发展。** 16 | 17 | --- 18 | 19 | ## 🌟 核心创新 20 | 21 | **使用 Transformer 进行序列传感器预测**:这个框架将 Transformer 架构应用于工业数字孪生中序列传感器输出预测问题的框架。该模型将多个传感器视为一个序列,捕获传感器之间的空间关系及其测量值的时间依赖性。 22 | 23 | ### 为什么这很重要 24 | 25 | 在复杂的工业系统(制造工厂、化工过程、发电等)中,传感器不是孤立运行的。它们的输出具有以下特征: 26 | - **空间相关性**:物理邻近性和工艺流程创建了依赖关系 27 | - **时间依赖性**:历史测量值影响当前和未来的读数 28 | - **层次结构**:一些传感器测量边界条件,而另一些测量内部状态 29 | 30 | 传统的机器学习方法独立对待传感器或使用简单的时间序列模型。我们基于 Transformer 的方法**捕获传感器相互关系的全部复杂性**。 31 | 32 | ## 🚀 功能特性 33 | 34 | ### 模型架构 35 | 36 | #### **StaticSensorTransformer (SST)** 37 | - **用途**:将边界条件传感器映射到目标传感器预测 38 | - **架构**:具有学习位置编码的传感器序列 Transformer 39 | - **创新点**:将固定传感器阵列视为序列(替代 NLP 中的词元序列) 40 | - **应用场景**:具有复杂传感器相互依赖关系的工业系统 41 | - **优势**: 42 | - 通过注意力机制捕获空间传感器关系 43 | - 快速训练和推理 44 | - 学习传感器之间的物理因果关系 45 | - 非常适合工业数字孪生应用 46 | 47 | ### 🆕 增强型残差提升训练系统 (v1.0) 48 | 49 | #### **Stage2 提升训练** 🚀 50 | - 在 SST 预测残差上训练第二阶段模型 51 | - 进一步优化预测以提高准确性 52 | - 可配置的架构和训练参数 53 | - 自动模型保存和版本控制 54 | 55 | #### **智能 Delta R² 阈值选择** 🎯 56 | - 计算每个信号的 Delta R² (R²_ensemble - R²_stage1) 57 | - 基于 Delta R² 阈值选择性地应用 Stage2 修正 58 | - 生成结合 SST + Stage2 的集成模型 59 | - 优化的性能/效率平衡 60 | - 仅对有显著改进的信号使用 Stage2 61 | 62 | #### **全面的推理对比** 📊 63 | - 比较集成模型与纯 SST 模型 64 | - 可视化所有输出信号的性能改进 65 | - 详细的逐信号指标分析(MAE、RMSE、R²) 66 | - CSV 导出包含预测值和 R² 分数 67 | - 交互式索引范围选择 68 | 69 | #### **全信号可视化** 📈 70 | - 每个输出信号的独立预测 vs 实际值对比 71 | - 动态布局适应信号数量 72 | - 每个信号显示 R² 分数 73 | - 轻松识别模型改进 74 | 75 | ### ⚡ 轻量化与边缘就绪架构 76 | 77 | #### **超轻量化 Transformer 设计** 78 | 尽管基于 Transformer 架构,我们的模型被设计为**超轻量化变体**,在最小化计算需求的同时保持良好性能: 79 | 80 | - **边缘设备优化**:在资源受限的硬件上训练和部署 81 | - **快速推理**:实时预测,延迟极低 82 | - **低内存占用**:适用于嵌入式系统的高效模型架构 83 | - **快速训练**:即使在有限算力下也能快速收敛 84 | 85 | #### **Digital Twin Anything:通用边缘部署** 🌐 86 | 87 | 我们的设计理念实现了**个性化的单体资产数字孪生**: 88 | 89 | - **单车数字孪生**:为每辆汽车建立专属模型 90 | - **单机监控**:为每台发动机建立个性化预测模型 91 | - **设备级定制**:任何在测试台架下有足够传感器数据的设备系统都可以拥有专属的轻量级数字孪生 92 | - **自动化边缘流程**:完整的训练和推理流程可部署在边缘设备上 93 | 94 | **愿景**:为**任何事物**创建自动化的轻量级数字孪生 - 从单个机器到整条生产线,全部运行在边缘硬件上并具备持续学习能力。 95 | 96 | #### **未来潜力:仿真模型代理** 🔬 97 | 98 | **面向计算效率的前瞻性应用展望**: 99 | 100 | 我们轻量化 Transformer 架构的特性开启了一个令人兴奋的未来可能性: 101 | - 将仿真中的每个网格区域视为虚拟"传感器" 102 | - 有潜力使用轻量级 Transformer 学习复杂的仿真行为 103 | - **可能以极低算力逆向构建昂贵的仿真模型**,计算成本有望降低数个数量级 104 | - 有望在保持高精度的同时实现实时仿真代理模型 105 | - 对 CFD、FEA 等计算密集型仿真具有应用前景 106 | 107 | 这一方法可能带来新的应用场景: 108 | - 设计迭代过程中的实时仿真 109 | - 普及高保真仿真的使用 110 | - 在边缘设备中嵌入复杂物理模型 111 | - 加速数字孪生开发周期 112 | 113 | *注:这代表了一个理论框架和未来研究方向,尚未在生产环境中得到充分验证。* 114 | 115 | ### 附加功能 116 | 117 | - ✅ **模块化设计**:易于扩展和定制 118 | - ✅ **全面的训练流程**:内置数据预处理、训练和评估 119 | - ✅ **交互式 Gradio 界面**:适用于所有训练阶段的用户友好型 Web 界面 120 | - ✅ **Jupyter 
Notebooks**:完整的教程和示例 121 | - ✅ **生产就绪**:可导出模型用于部署 122 | - ✅ **详尽的文档**:清晰的 API 文档和使用示例 123 | - ✅ **自动化模型管理**:智能模型保存和加载(含配置) 124 | 125 | ## 📊 使用场景 126 | 127 | 本框架非常适合: 128 | 129 | - **制造业数字孪生**:从传感器阵列预测设备状态 130 | - **化工过程监控**:建模反应器中的复杂传感器交互 131 | - **发电厂优化**:预测涡轮机和发电机状况 132 | - **HVAC 系统**:预测温度和压力分布 133 | - **预测性维护**:从传感器模式中早期检测异常 134 | - **质量控制**:从工艺传感器预测产品质量 135 | 136 | ## 🏗️ 架构概述 137 | ``` 138 | 联系作者以获得详细信息 139 | ``` 140 | 141 | ## 🔧 安装 142 | 143 | ### 使用 Google Colab 快速开始 144 | 145 | ```bash 146 | # 克隆仓库 147 | !git clone https://github.com/FTF1990/Industrial-digital-twin-by-transformer.git 148 | %cd Industrial-digital-twin-by-transformer 149 | 150 | # 安装依赖 151 | !pip install -r requirements.txt 152 | ``` 153 | 154 | ### 本地安装 155 | 156 | ```bash 157 | # 克隆仓库 158 | git clone https://github.com/FTF1990/Industrial-digital-twin-by-transformer.git 159 | cd Industrial-digital-twin-by-transformer 160 | 161 | # 创建虚拟环境(推荐) 162 | python -m venv venv 163 | source venv/bin/activate # Windows 系统: venv\Scripts\activate 164 | 165 | # 安装依赖 166 | pip install -r requirements.txt 167 | ``` 168 | 169 | ## 📚 快速入门 170 | 171 | ### 1. 准备数据 172 | 173 | 将您的 CSV 传感器数据文件放在 `data/raw/` 文件夹中。您的 CSV 应该具有: 174 | - 每行代表一个时间步 175 | - 每列代表一个传感器测量值 176 | - (可选)第一列可以是时间戳 177 | 178 | CSV 结构示例: 179 | ```csv 180 | timestamp,sensor_1,sensor_2,sensor_3,...,sensor_n 181 | 2025-01-01 00:00:00,23.5,101.3,45.2,...,78.9 182 | 2025-01-01 00:00:01,23.6,101.4,45.1,...,79.0 183 | ... 184 | ``` 185 | 186 | ### 2. 使用 Jupyter Notebook 训练 Stage1 模型(基础训练) 187 | 188 | 本节演示**基础 Stage1 (SST) 模型训练**,用于学习传感器预测建模的基础知识。 189 | 190 | **注意**:Notebook 提供了理解 SST 架构和基础训练过程的基础。如需完整的 Stage2 提升训练和集成模型生成功能,请使用增强型 Gradio 界面(第3节)。 191 | 192 | **可用的 Notebooks**: 193 | - `notebooks/Train and run model with demo data and your own data with gradio interface.ipynb` - 初学者快速入门教程 194 | - `notebooks/transformer_boost_Leap_final.ipynb` - 高级示例:在 LEAP 数据集上的完整 Stage1 + Stage2 训练(作者测试文件,注释为中文) 195 | 196 | **基础训练示例**(用于您自己的数据): 197 | 198 | ```python 199 | from models.static_transformer import StaticSensorTransformer 200 | from src.data_loader import SensorDataLoader 201 | from src.trainer import ModelTrainer 202 | 203 | # 加载数据 204 | data_loader = SensorDataLoader(data_path='data/raw/your_data.csv') 205 | 206 | # 配置信号 207 | boundary_signals = ['sensor_1', 'sensor_2', 'sensor_3'] # 输入 208 | target_signals = ['sensor_4', 'sensor_5'] # 要预测的输出 209 | 210 | # 准备数据 211 | data_splits = data_loader.prepare_data(boundary_signals, target_signals) 212 | 213 | # 创建和训练 Stage1 SST 模型 214 | model = StaticSensorTransformer( 215 | num_boundary_sensors=len(boundary_signals), 216 | num_target_sensors=len(target_signals) 217 | ) 218 | 219 | trainer = ModelTrainer(model, device='cuda') 220 | history = trainer.train(train_loader, val_loader) 221 | 222 | # 保存训练好的模型 223 | torch.save(model.state_dict(), 'saved_models/my_sst_model.pth') 224 | ``` 225 | 226 | **在 Stage1 中您将学到**: 227 | - 加载和预处理传感器数据 228 | - 配置边界传感器和目标传感器 229 | - 训练静态传感器 Transformer (SST) 230 | - 基础模型评估和预测 231 | 232 | **如需完整功能**(Stage2 提升 + 集成模型),请继续第3节。 233 | 234 | ### 3. 
使用增强型 Gradio 界面(完整 Stage1 + Stage2 训练) 235 | 236 | **Gradio UI 演示视频**:即将推出 237 | 238 | #### **Jupyter Notebook 入门教程** 239 | 240 | 有关分步指南,请参阅: 241 | - `notebooks/Train and run model with demo data and your own data with gradio interface.ipynb` 242 | 243 | 该 notebook 演示了: 244 | - 从 Kaggle 下载演示数据(power-gen-machine 数据集) 245 | - 设置 Gradio 界面 246 | - 使用演示数据或您自己的自定义数据进行训练 247 | 248 | 只需按照 notebook 步骤操作即可开始使用完整工作流程。 249 | 250 | #### **完整工作流程** 251 | 252 | 增强型界面提供**完整的端到端工作流程**: 253 | - 📊 **Tab 1: 数据加载** - 刷新并选择演示数据(`data.csv`)或上传您自己的 CSV 254 | - 🎯 **Tab 2: 信号配置与 Stage1 训练** - 刷新,加载信号配置,选择参数,训练基础 SST 模型 255 | - 🔬 **Tab 3: 残差提取** - 从 Stage1 模型中提取和分析预测误差 256 | - 🚀 **Tab 4: Stage2 提升训练** - 在残差上训练第二阶段模型进行误差修正 257 | - 🎯 **Tab 5: 集成模型生成** - 基于智能 Delta R² 阈值的模型组合 258 | - 📊 **Tab 6: 推理对比** - 比较 Stage1 SST vs. 集成模型性能并可视化 259 | - 💾 **Tab 7: 导出** - 自动模型保存(含完整配置) 260 | 261 | **这是体验框架完整功能的推荐方式**,包括: 262 | - 使用演示数据的自动化多阶段训练流程 263 | - 智能的逐信号 Stage2 选择 264 | - 全面的性能指标和可视化 265 | - 生产就绪的集成模型生成 266 | 267 | **使用您自己的数据**: 268 | 只需将您的 CSV 文件放在 `data/` 文件夹中,在 Tab 1 中刷新并选择您的文件。确保您的 CSV 遵循与演示数据相同的格式(时间步作为行,传感器作为列)。然后在 Tab 2 中配置您自己的输入/输出信号。 269 | 270 | **快速入门指南**:参见 `docs/QUICKSTART.md` 获取 5 分钟教程 271 | 272 | ## 📖 文档 273 | ``` 274 | 联系作者以获得详细信息 275 | ``` 276 | ## 🎯 性能 277 | 278 | ### 基准测试结果 279 | 280 | #### 🏭 工业旋转机械案例研究 281 | 282 | **数据集**:[发电机械传感器数据](https://www.kaggle.com/datasets/tianffan/power-gen-machine) 283 | 284 | **应用领域**:真实世界的尖端发电旋转机械 285 | - 复杂工业设备的多传感器系统监测 286 | - 生产环境的高频操作数据 287 | - 工业数字孪生应用的代表性案例 288 | 289 | **数据集特征**: 290 | - **来源**:真实工业设备传感器阵列 291 | - **复杂度**:高性能旋转系统中的多传感器相互依赖关系 292 | - **规模**:覆盖关键参数的完整传感器套件 293 | - **质量**:生产级传感器测量数据 294 | 295 | **性能结果**(测试集): 296 | 297 | | 指标 | Stage1 (SST) | Stage1+Stage2 集成 | 改进幅度 | 298 | |------|--------------|---------------------|----------| 299 | | **R²** | 0.8101 | **0.9014** | +11.3% | 300 | | **MAE** | 1.56 | **1.24** | -20.2% | 301 | | **RMSE** | 3.89 | **3.57** | -8.3% | 302 | 303 | **配置**: 304 | - **数据集**:89 个目标信号,21.7 万样本 305 | - **Stage1**:50 epochs,默认超参数 306 | - **Stage2**:选择性增强 36/89 个信号(Delta R² 阈值:0.03) 307 | - **硬件**:单卡 NVIDIA A100 GPU 308 | - **训练**:无数据增强,无特殊调参 309 | 310 | **训练推荐**(基于实践经验): 311 | 312 | 以上结果使用默认超参数获得。然而,通过以下参数调优策略**通常可以获得更好的性能**: 313 | - 📉 **更低的学习率**:较小的学习率(例如 0.00003 vs. 默认 0.0001)通常能带来更好的收敛 314 | - ⏱️ **更高的调度器耐心值**:增加学习率调度器耐心值(例如 8 vs. 默认 3)允许更稳定的训练 315 | - 📊 **更高的衰减因子**:更高的学习率衰减因子可减少激进的学习率下降 316 | - 🔄 **更多的训练轮数**:使用上述设置训练更多轮次通常能提高最终性能 317 | 318 | 这些调整有助于实现更平滑的收敛和更好的泛化能力,特别是对于复杂的工业传感器系统。 319 | 320 | **Stage2 智能选择**: 321 | - **36 个信号** 选择 Stage2 校正(观察到显著改进) 322 | - **53 个信号** 保持 Stage1 预测(已表现良好) 323 | - 自适应策略平衡性能提升与计算效率 324 | 325 | **信号改进示例**(Stage1 → 集成): 326 | - 振动传感器:R² -0.13 → 0.26,-0.55 → 0.47(挑战性信号) 327 | - 温度传感器:R² 0.35 → 0.59,0.68 → 0.93(中等改进) 328 | - 压力传感器:R² 0.08 → 0.47,0.42 → 0.63(显著提升) 329 | 330 |
331 | 📊 完整效果演示图(所有信号预测效果可视化) 332 | 333 |
334 | 335 | 下图展示了经过 Stage1 + Stage2 Boost 后,所有 89 个目标信号在测试集上的预测效果: 336 | 337 | ![所有信号预测效果演示](saved_models/result_demo.webp) 338 | 339 | **图片说明**: 340 | - 蓝色线条:真实值(Ground Truth) 341 | - 橙色线条:模型预测值(Prediction) 342 | - 每个子图代表一个传感器信号的预测效果 343 | - 可以看到大部分信号的预测曲线与真实值高度吻合 344 | 345 |
346 | 347 | **实用见解**: 348 | - ✅ **强劲的开箱即用基线**:Stage1 使用默认设置达到 R² = 0.81 349 | - ✅ **按需精炼**:Stage2 增强为挑战性信号提供针对性改进 350 | - ✅ **真实传感器数据**:在生产设备测量数据上展示有效性 351 | - ✅ **高效训练**:两个阶段都能在标准硬件上快速训练 352 | 353 | **训练模型**:[Kaggle Models 提供](https://www.kaggle.com/models/tianffan/industrial-digital-twin-by-transformer) 354 | 355 | **模型文件位置**: 356 | - **Stage1 模型**:三个文件(`.pth`、`_config.json`、`_scaler.pkl`)位于 `saved_models/` 目录下 357 | - **Stage2 模型**:位于 `saved_models/stage2_boost/` 目录下 358 | 359 | **关于基准测试的说明**: 360 | 这些结果作为特定数据集上的参考示例提供。本项目优先考虑**实用性和易部署性**,而非竞争性基准分数。性能将根据您的具体工业应用、传感器特性和数据质量而变化。我们鼓励用户在自己的应用场景中评估本框架。 361 | 362 | --- 363 | 364 | #### 🌍 大气物理仿真基准测试 365 | 366 | **数据集**:LEAP 大气物理仿真数据集 367 | 368 | **性能结果**: 369 | - **硬件**:单卡 NVIDIA A100 GPU(Google Colab) 370 | - **信号**:164 个输出信号(不包括 ptend_q 系列) 371 | - **Stage1 (SST)**:R² ≈ 0.56 372 | - **Stage2 Boost**:R² ≈ 0.58 373 | - **训练**:未应用数据增强 374 | 375 | **测试 Notebook**:参见 `notebooks/transformer_boost_Leap_final.ipynb`(作者测试文件,注释为中文) 376 | 377 | --- 378 | 379 | ### 📌 性能说明 380 | 381 | **变异因素**: 382 | 结果可能因以下因素而变化: 383 | - 数据集特征(传感器相关模式、噪声水平、信号复杂度) 384 | - 物理系统属性(传感器空间关系、时间动态) 385 | - 模型配置(架构大小、训练参数) 386 | - 应用领域(制造业、能源、化工过程等) 387 | 388 | **观察到的最佳结果**: 389 | - **高度相关的传感器系统**:R² > 0.80(如旋转机械) 390 | - **复杂多物理系统**:R² 0.55-0.65(如大气仿真) 391 | 392 | 当传感器输出具有**明确的物理相互依赖关系和空间关系**时,该框架表现出特别强的性能,这与其核心设计理念一致。 393 | 394 | --- 395 | 396 | ### 🤝 欢迎社区贡献 397 | 398 | 我们热烈鼓励用户分享基准测试结果!如果您已将此框架应用于您的领域,请贡献: 399 | - 您工业应用中的**脱敏数据集** 400 | - **性能指标**(R²、MAE、RMSE 等)和可视化 401 | - **应用案例描述**和领域见解 402 | 403 | 您的贡献有助于建立对框架在不同工业场景下能力的理解。请开启 [issue](https://github.com/FTF1990/Industrial-digital-twin-by-transformer/issues) 或提交 pull request! 404 | 405 | ## 🤝 贡献 406 | 407 | 感谢您对本项目的关注!我们非常重视社区的参与和反馈。 408 | 409 | **支持本项目的方式**: 410 | - ⭐ **给我们一个 star!** 这有助于更多人发现这项工作,并激励项目持续发展 411 | - 🐛 **Bug 报告或建议?** 欢迎开启 [issue](https://github.com/FTF1990/Industrial-digital-twin-by-transformer/issues) 412 | - 💬 **想法或问题?** 欢迎在 issue 或评论中讨论 413 | - 📊 **性能结果?** 分享您的脱敏数据和结果 - 这些特别有价值! 414 | 415 | **当前状态**:由于时间限制,作者可能无法立即审查和合并外部的 Pull Request。衷心感谢您的理解。 416 | 417 | **对于重大更改**:恳请您先开启 issue 讨论您的提议,然后再投入大量精力。 418 | 419 | ⏱️ **回复时间**:作者会在时间允许的情况下回复。非常感谢您的耐心。 420 | 421 | 非常感谢您的理解、耐心和贡献!🙏 422 | 423 | ### 开发设置 424 | 425 | ```bash 426 | # 克隆仓库 427 | git clone https://github.com/FTF1990/Industrial-digital-twin-by-transformer.git 428 | cd Industrial-digital-twin-by-transformer 429 | 430 | # 以开发模式安装 431 | pip install -e . 
432 | 433 | # 运行测试(如果可用) 434 | python -m pytest tests/ 435 | ``` 436 | 437 | ## 📄 许可证 438 | 439 | 本项目根据 MIT 许可证授权 - 详情请参阅 [LICENSE](LICENSE) 文件。 440 | 441 | ## 🙏 致谢 442 | 443 | - Transformer 架构基于 "Attention Is All You Need"(Vaswani et al., 2017) 444 | - 灵感来自工业自动化中的数字孪生应用 445 | - 使用 PyTorch、Gradio 和出色的开源社区构建 446 | 447 | ## 📞 联系方式 448 | 449 | 如有问题、议题或合作: 450 | - **GitHub Issues**:[创建 issue](https://github.com/FTF1990/Industrial-digital-twin-by-transformer/issues) 451 | - **电子邮件**:shvichenko11@gmail.com 452 | 453 | ## 🔗 引用 454 | 455 | 如果您在研究中使用此工作,请引用: 456 | 457 | ```bibtex 458 | @software{industrial_digital_twin_transformer, 459 | author = {FTF1990}, 460 | title = {Industrial Digital Twin by Transformer}, 461 | year = {2025}, 462 | url = {https://github.com/FTF1990/Industrial-digital-twin-by-transformer} 463 | } 464 | ``` 465 | 466 | ## 🗺️ 路线图 467 | 468 | ### v1.0(当前)✅ 469 | - [x] Stage2 提升训练系统 470 | - [x] 智能 R² 阈值选择 471 | - [x] 集成模型生成 472 | - [x] 推理对比工具 473 | - [x] 增强型 Gradio 界面 474 | 475 | ### v2.0(即将推出)🚀 476 | 477 | #### **Stage3 时序震荡增强系统** 🕐 478 | 下一代演进目标:时序震荡信号重构 479 | 480 | - **Stage3 时序震荡特征提取**: 481 | - 针对具有时序震荡特性的信号(高频脉动、振动等) 482 | - 当前的空间序列 Transformer 对时序高频震荡信号只能提取均值特征,无法还原时序震荡特征 483 | - 采用时序 ML 模型或时序 Transformer 进行纯时序特征提取 484 | - 增强并还原信号本身固有的时序震荡特征 485 | 486 | - **最终残差未来预测**: 487 | - 经过 Stage1 + Stage2 + Stage3 后,最终残差基本已不包含空间特征 488 | - 可对最终残差进行纯时序预测,实现未来时间步预测 489 | - 适用于需要前向预测能力的应用场景 490 | 491 | - **信号关联掩码编辑功能**(计划推出): 492 | - 最大限度利用 Transformer 的灵活性,编辑输入输出信号关联掩码 493 | - 运用真实工程经验对不直接关联的要素之间施加掩码屏蔽 494 | - 更好地还原真实系统行为,融入领域专家知识 495 | - 通过专家引导的特征关系提高模型准确性 496 | 497 | - **完整的空间-时间分解架构**: 498 | - **Stage1 (SST)**:空间传感器关系和跨传感器依赖性 499 | - **Stage2 (Boost)**:空间残差修正和次级空间模式 500 | - **Stage3 (Temporal)**:纯时序震荡特征和时间序列动态 501 | - **最终目标**:将空间和时间特征完全剥离并分层预测,除不可预测的噪音特征外,捕捉所有可预测模式,实现场景泛用化的数字孪生 502 | 503 | - **分层特征提取哲学**: 504 | - 第一层:主要空间传感器相关性(SST) 505 | - 第二层:残差空间模式(Stage2 提升) 506 | - 第三层:时序震荡特征(Stage3 时序) 507 | - 最终残差:不可约随机噪声 + 可选的未来预测 508 | 509 | 此设计旨在通过系统性地分解和捕获不同领域的所有可预测特征,实现**通用数字孪生建模**。 510 | --- 511 | 512 | **为工业 AI 社区精心打造 ❤️** 513 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Industrial Digital Twin by Transformer 2 | 3 | **[English](README.md)** | **[中文](README_CN.md)** 4 | 5 | [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT) 6 | [![Python 3.8+](https://img.shields.io/badge/python-3.8+-blue.svg)](https://www.python.org/downloads/) 7 | [![PyTorch](https://img.shields.io/badge/PyTorch-2.0+-ee4c2c.svg)](https://pytorch.org/) 8 | 9 | > **An innovative Transformer-based framework for industrial digital twin modeling using sequential sensor outputs from complex systems with advanced residual boost training.** 10 | 11 | This project introduces Transformer architectures and residual boost training methodology specifically designed for predicting sensor outputs in industrial digital twin applications. Unlike traditional approaches, our models leverage the **sequential nature of multi-sensor systems** in complex industrial environments to achieve improved prediction accuracy through multi-stage refinement. 12 | 13 | --- 14 | 15 | **If you find this project helpful, please consider giving it a ⭐ star! 
Your support helps others discover this work and motivates continued development.** 16 | 17 | --- 18 | 19 | ## 🌟 Key Innovation 20 | 21 | **Sequential Sensor Prediction using Transformers**: This framework applies Transformer architecture to the problem of predicting sequential sensor outputs in industrial digital twins. The model treats multiple sensors as a sequence, capturing both spatial relationships between sensors and temporal dependencies in their measurements. 22 | 23 | ### Why This Matters 24 | 25 | In complex industrial systems (manufacturing plants, chemical processes, power generation, etc.), sensors don't operate in isolation. Their outputs are: 26 | - **Spatially correlated**: Physical proximity and process flow create dependencies 27 | - **Temporally dependent**: Historical measurements influence current and future readings 28 | - **Hierarchically structured**: Some sensors measure boundary conditions while others measure internal states 29 | 30 | Traditional machine learning approaches treat sensors independently or use simple time-series models. Our Transformer-based approach **captures the full complexity of sensor interrelationships**. 31 | 32 | ## 🚀 Features 33 | 34 | ### Model Architecture 35 | 36 | #### **StaticSensorTransformer (SST)** 37 | - **Purpose**: Maps boundary condition sensors to target sensor predictions 38 | - **Architecture**: Sensor sequence Transformer with learned positional encodings 39 | - **Innovation**: Treats fixed sensor arrays as sequences (replacing NLP token sequences) 40 | - **Use Case**: Industrial systems with complex sensor inter-dependencies 41 | - **Advantages**: 42 | - Captures spatial sensor relationships through attention mechanism 43 | - Fast training and inference 44 | - Learns physical causality between sensors 45 | - Excellent for industrial digital twin applications 46 | 47 | ### 🆕 Enhanced Residual Boost Training System (v1.0) 48 | 49 | #### **Stage2 Boost Training** 🚀 50 | - Train secondary models on residuals from SST predictions 51 | - Further refine predictions for improved accuracy 52 | - Configurable architecture and training parameters 53 | - Automatic model saving and versioning 54 | 55 | #### **Intelligent Delta R² Threshold Selection** 🎯 56 | - Calculate Delta R² (R²_ensemble - R²_stage1) for each signal 57 | - Selectively apply Stage2 corrections based on Delta R² threshold 58 | - Generate ensemble models combining SST + Stage2 59 | - Optimized performance/efficiency balance 60 | - Only use Stage2 for signals where it provides significant improvement 61 | 62 | #### **Comprehensive Inference Comparison** 📊 63 | - Compare ensemble model vs. 
pure SST model 64 | - Visualize performance improvements for all output signals 65 | - Detailed per-signal metrics analysis (MAE, RMSE, R²) 66 | - CSV export with predictions and R² scores 67 | - Interactive index range selection 68 | 69 | #### **All-Signal Visualization** 📈 70 | - Individual prediction vs actual comparison for every output signal 71 | - Dynamic layout adapting to number of signals 72 | - R² scores displayed for each signal 73 | - Easy identification of model improvements 74 | 75 | ### ⚡ Lightweight & Edge-Ready Architecture 76 | 77 | #### **Ultra-Lightweight Transformer Design** 78 | Despite being Transformer-based, our models are designed as **ultra-lightweight variants** that maintain exceptional performance while minimizing computational requirements: 79 | 80 | - **Edge Device Optimized**: Train and deploy on resource-constrained hardware 81 | - **Fast Inference**: Real-time predictions with minimal latency 82 | - **Low Memory Footprint**: Efficient model architecture for embedded systems 83 | - **Rapid Training**: Quick model convergence even on limited compute 84 | 85 | #### **Digital Twin Anything: Universal Edge Deployment** 🌐 86 | 87 | Our design philosophy enables **personalized digital twins for individual assets**: 88 | 89 | - **Per-Vehicle Digital Twins**: Dedicated models for each car or vehicle 90 | - **Per-Engine Monitoring**: Individual engine-specific predictive models 91 | - **Device-Level Customization**: Any system with sufficient testbench sensor data can have its own lightweight digital twin 92 | - **Automated Edge Pipeline**: Complete training and inference pipeline deployable on edge devices 93 | 94 | **Vision**: Create an automated, lightweight digital twin for **anything** - from individual machines to entire production lines, all running on edge hardware with continuous learning capabilities. 
95 | 96 | #### **Future Potential: Simulation Model Surrogate** 🔬 97 | 98 | **Envisioned application for computational efficiency**: 99 | 100 | The lightweight nature of our Transformer architecture opens an exciting future possibility: 101 | - Treat each simulation mesh region as a virtual "sensor" 102 | - Potentially use lightweight Transformers to learn complex simulation behaviors 103 | - **Could reverse-engineer expensive simulations** with orders of magnitude less computational cost 104 | - May maintain high accuracy while enabling real-time simulation surrogate models 105 | - Promising for CFD, FEA, and other computationally intensive simulations 106 | 107 | This approach could unlock new possibilities: 108 | - Real-time simulation during design iterations 109 | - Democratizing access to high-fidelity simulations 110 | - Embedding complex physics models in edge devices 111 | - Accelerating digital twin development cycles 112 | 113 | *Note: This represents a theoretical framework and future research direction that has not yet been fully validated in production environments.* 114 | 115 | ### Additional Features 116 | 117 | - ✅ **Modular Design**: Easy to extend and customize 118 | - ✅ **Comprehensive Training Pipeline**: Built-in data preprocessing, training, and evaluation 119 | - ✅ **Interactive Gradio Interface**: User-friendly web interface for all training stages 120 | - ✅ **Jupyter Notebooks**: Complete tutorials and examples 121 | - ✅ **Production Ready**: Exportable models for deployment 122 | - ✅ **Extensive Documentation**: Clear API documentation and usage examples 123 | - ✅ **Automated Model Management**: Intelligent model saving and loading with configurations 124 | 125 | ## 📊 Use Cases 126 | 127 | This framework is ideal for: 128 | 129 | - **Manufacturing Digital Twins**: Predict equipment states from sensor arrays 130 | - **Chemical Process Monitoring**: Model complex sensor interactions in reactors 131 | - **Power Plant Optimization**: Forecast turbine and generator conditions 132 | - **HVAC Systems**: Predict temperature and pressure distributions 133 | - **Predictive Maintenance**: Early detection of anomalies from sensor patterns 134 | - **Quality Control**: Predict product quality from process sensors 135 | 136 | ## 🏗️ Architecture Overview 137 | ``` 138 | Please contact the author for detail information 139 | ``` 140 | 141 | ## 🔧 Installation 142 | 143 | ### Quick Start with Google Colab 144 | 145 | ```bash 146 | # Clone the repository 147 | !git clone https://github.com/FTF1990/Industrial-digital-twin-by-transformer.git 148 | %cd Industrial-digital-twin-by-transformer 149 | 150 | # Install dependencies 151 | !pip install -r requirements.txt 152 | ``` 153 | 154 | ### Local Installation 155 | 156 | ```bash 157 | # Clone the repository 158 | git clone https://github.com/FTF1990/Industrial-digital-twin-by-transformer.git 159 | cd Industrial-digital-twin-by-transformer 160 | 161 | # Create virtual environment (recommended) 162 | python -m venv venv 163 | source venv/bin/activate # On Windows: venv\Scripts\activate 164 | 165 | # Install dependencies 166 | pip install -r requirements.txt 167 | ``` 168 | 169 | ## 📚 Quick Start 170 | 171 | ### 1. Prepare Your Data 172 | 173 | Place your CSV sensor data file in the `data/raw/` folder. 
Your CSV should have: 174 | - Each row represents a timestep 175 | - Each column represents a sensor measurement 176 | - (Optional) First column can be a timestamp 177 | 178 | Example CSV structure: 179 | ```csv 180 | timestamp,sensor_1,sensor_2,sensor_3,...,sensor_n 181 | 2025-01-01 00:00:00,23.5,101.3,45.2,...,78.9 182 | 2025-01-01 00:00:01,23.6,101.4,45.1,...,79.0 183 | ... 184 | ``` 185 | 186 | ### 2. Train Stage1 Model Using Jupyter Notebook (Basic Training) 187 | 188 | This section demonstrates **basic Stage1 (SST) model training** for learning sensor prediction fundamentals. 189 | 190 | **Note**: The notebook provides a foundation for understanding the SST architecture and basic training process. For the complete Stage2 Boost training and ensemble model generation, please use the enhanced Gradio interface (Section 3). 191 | 192 | **Available Notebooks**: 193 | - `notebooks/Train and run model with demo data and your own data with gradio interface.ipynb` - Quick start tutorial for beginners 194 | - `notebooks/transformer_boost_Leap_final.ipynb` - Advanced example: Complete Stage1 + Stage2 training on LEAP dataset (Author's testing file, comments in Chinese) 195 | 196 | **Basic Training Example** (for your own data): 197 | 198 | ```python 199 | from models.static_transformer import StaticSensorTransformer 200 | from src.data_loader import SensorDataLoader 201 | from src.trainer import ModelTrainer 202 | 203 | # Load data 204 | data_loader = SensorDataLoader(data_path='data/raw/your_data.csv') 205 | 206 | # Configure signals 207 | boundary_signals = ['sensor_1', 'sensor_2', 'sensor_3'] # Inputs 208 | target_signals = ['sensor_4', 'sensor_5'] # Outputs to predict 209 | 210 | # Prepare data 211 | data_splits = data_loader.prepare_data(boundary_signals, target_signals) 212 | 213 | # Create and train Stage1 SST model 214 | model = StaticSensorTransformer( 215 | num_boundary_sensors=len(boundary_signals), 216 | num_target_sensors=len(target_signals) 217 | ) 218 | 219 | trainer = ModelTrainer(model, device='cuda') 220 | history = trainer.train(train_loader, val_loader) 221 | 222 | # Save trained model 223 | torch.save(model.state_dict(), 'saved_models/my_sst_model.pth') 224 | ``` 225 | 226 | **What you'll learn in Stage1**: 227 | - Loading and preprocessing sensor data 228 | - Configuring boundary and target sensors 229 | - Training the Static Sensor Transformer (SST) 230 | - Basic model evaluation and prediction 231 | 232 | **For complete functionality** (Stage2 Boost + Ensemble Models), proceed to Section 3. 233 | 234 | ### 3. Use Enhanced Gradio Interface (Complete Stage1 + Stage2 Training) 235 | 236 | **Gradio UI Demo Video**: Coming soon 237 | 238 | #### **Getting Started with Jupyter Notebook Tutorial** 239 | 240 | For a step-by-step guide, see: 241 | - `notebooks/Train and run model with demo data and your own data with gradio interface.ipynb` 242 | 243 | This notebook demonstrates: 244 | - Downloading demo data from Kaggle (power-gen-machine dataset) 245 | - Setting up the Gradio interface 246 | - Training with demo data or your own custom data 247 | 248 | Simply follow the notebook steps to get started with the complete workflow. 
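For readers who prefer a plain-Python view of what the interface automates, here is a minimal sketch of the Stage2 residual-boost step using the helpers in `models/residual_tft.py`. Everything in it is illustrative: `sst_model`, `scaler_X`, `scaler_y`, `X`, `y`, `train_loader`, and `val_loader` are placeholders for a trained Stage1 model, its fitted scalers, your boundary/target arrays, and standard PyTorch DataLoaders built from the residuals.

```python
import torch
from models.residual_tft import (
    GroupedMultiTargetTFT,
    compute_residuals_correctly,
    train_residual_tft,
)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# 1. Stage1 residuals, computed in the original (non-standardized) signal scale
residuals = compute_residuals_correctly(X, y, sst_model, scaler_X, scaler_y, device)

# 2. A secondary model that maps the boundary signals to those residuals
stage2_model = GroupedMultiTargetTFT(
    num_targets=residuals.shape[1],
    num_external_factors=X.shape[1],
).to(device)

# 3. Train on the residuals (train_loader / val_loader yield (boundary, residual) batches)
stage2_model, history = train_residual_tft(
    stage2_model, train_loader, val_loader,
    config={'epochs': 100, 'lr': 0.001}, device=device,
)

# 4. Ensemble prediction = Stage1 prediction + Stage2 residual correction,
#    applied per signal only where the Delta R² improvement exceeds the threshold
```

The Gradio tabs described below automate these steps, including the Delta R² threshold selection.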
249 | 250 | #### **Complete Workflow** 251 | 252 | The enhanced interface provides the **complete end-to-end workflow**: 253 | - 📊 **Tab 1: Data Loading** - Refresh and select demo data (`data.csv`) or upload your own CSV 254 | - 🎯 **Tab 2: Signal Configuration & Stage1 Training** - Refresh, load signal configuration, select parameters, and train base SST models 255 | - 🔬 **Tab 3: Residual Extraction** - Extract and analyze prediction errors from Stage1 models 256 | - 🚀 **Tab 4: Stage2 Boost Training** - Train secondary models on residuals for error correction 257 | - 🎯 **Tab 5: Ensemble Model Generation** - Intelligent Delta R² threshold-based model combination 258 | - 📊 **Tab 6: Inference Comparison** - Compare Stage1 SST vs. ensemble model performance with visualizations 259 | - 💾 **Tab 7: Export** - Automatic model saving with complete configurations 260 | 261 | **This is the recommended way to experience the full capabilities of the framework**, including: 262 | - Automated multi-stage training pipeline using demo data 263 | - Intelligent signal-wise Stage2 selection 264 | - Comprehensive performance metrics and visualizations 265 | - Production-ready ensemble model generation 266 | 267 | **Using Your Own Data**: 268 | Simply place your CSV file in the `data/` folder, refresh in Tab 1, and select your file. Ensure your CSV follows the same format as the demo data (timesteps as rows, sensors as columns). Then configure your own input/output signals in Tab 2. 269 | 270 | **Quick Start Guide**: See `docs/QUICKSTART.md` for a 5-minute tutorial 271 | 272 | ## 📖 Documentation 273 | 274 | ``` 275 | Please contact the author for detail information 276 | ``` 277 | 278 | ## 🎯 Performance 279 | 280 | ### Benchmark Results 281 | 282 | #### 🏭 Industrial Rotating Machinery Case Study 283 | 284 | **Dataset**: [Power Generation Machine Sensor Data](https://www.kaggle.com/datasets/tianffan/power-gen-machine) 285 | 286 | **Application Domain**: Real-world advanced rotating machinery for power generation 287 | - Multi-sensor system monitoring for complex industrial equipment 288 | - High-frequency operational data from production environment 289 | - Representative of industrial digital twin applications 290 | 291 | **Dataset Characteristics**: 292 | - **Source**: Real industrial equipment sensor array 293 | - **Complexity**: Multi-sensor interdependencies in high-performance rotating systems 294 | - **Scale**: Full operational sensor suite covering critical parameters 295 | - **Quality**: Production-grade sensor measurements 296 | 297 | **Performance Results** (Test Set): 298 | 299 | | Metric | Stage1 (SST) | Stage1+Stage2 Ensemble | Improvement | 300 | |--------|--------------|------------------------|-------------| 301 | | **R²** | 0.8101 | **0.9014** | +11.3% | 302 | | **MAE** | 1.56 | **1.24** | -20.2% | 303 | | **RMSE** | 3.89 | **3.57** | -8.3% | 304 | 305 | **Configuration**: 306 | - **Dataset**: 89 target signals, 217K samples 307 | - **Stage1**: 50 epochs, default hyperparameters 308 | - **Stage2**: Selective boost on 36/89 signals (Delta R² threshold: 0.03) 309 | - **Hardware**: Single NVIDIA A100 GPU 310 | - **Training**: No data augmentation, no special tuning 311 | 312 | **Training Recommendations** (Based on Practical Experience): 313 | 314 | The above results were achieved with default hyperparameters. However, **better performance can typically be obtained** with the following parameter tuning strategy: 315 | - 📉 **Lower learning rate**: Smaller learning rates (e.g., 0.00003 vs. 
default 0.0001) often lead to better convergence 316 | - ⏱️ **Higher scheduler patience**: Increased learning rate scheduler patience (e.g., 8 vs. default 3) allows more stable training 317 | - 📊 **Higher decay factor**: Higher learning rate decay factors reduce aggressive learning rate reductions 318 | - 🔄 **More epochs**: Training for more epochs with the above settings generally improves final performance 319 | 320 | These adjustments help achieve smoother convergence and better generalization, especially for complex industrial sensor systems. 321 | 322 | **Stage2 Intelligent Selection**: 323 | - **36 signals** selected for Stage2 correction (significant improvement observed) 324 | - **53 signals** kept Stage1-only predictions (already performing well) 325 | - Adaptive strategy balances performance gains with computational efficiency 326 | 327 | **Example Signal Improvements** (Stage1 → Ensemble): 328 | - Vibration sensors: R² -0.13 → 0.26, -0.55 → 0.47 (challenging signals) 329 | - Temperature sensors: R² 0.35 → 0.59, 0.68 → 0.93 (moderate improvements) 330 | - Pressure sensors: R² 0.08 → 0.47, 0.42 → 0.63 (significant gains) 331 | 332 |
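As a concrete reference, the tuning recommendations above can be written as a training-config dictionary. The keys follow the `config` dict accepted by `train_residual_tft` in `models/residual_tft.py`; the values are illustrative of the strategy, not the exact settings used for the benchmark.

```python
tuned_config = {
    'epochs': 150,             # train longer once the schedule is gentler
    'lr': 3e-5,                # lower learning rate for smoother convergence
    'scheduler_patience': 8,   # wait more epochs before reducing the learning rate
    'scheduler_factor': 0.7,   # milder decay per reduction (closer to 1.0)
    'weight_decay': 1e-5,
    'grad_clip': 1.0,
    'early_stop_patience': 25,
}
```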
333 | 📊 Full Results Visualization (All Signals Prediction Performance) 334 | 335 |
336 | 337 | The following image shows the prediction performance of all 89 target signals on the test set after Stage1 + Stage2 Boost: 338 | 339 | ![All Signals Prediction Results Demo](saved_models/result_demo.webp) 340 | 341 | **Figure Description**: 342 | - Blue line: Ground Truth 343 | - Orange line: Model Prediction 344 | - Each subplot represents the prediction performance of one sensor signal 345 | - Most signals show predictions closely matching ground truth values 346 | 347 |
348 | 349 | **Practical Insights**: 350 | - ✅ **Strong out-of-box baseline**: Stage1 achieves R² = 0.81 with default settings 351 | - ✅ **Refinement when needed**: Stage2 boost provides targeted improvements for challenging signals 352 | - ✅ **Real-world sensor data**: Demonstrates effectiveness on production equipment measurements 353 | - ✅ **Efficient training**: Both stages train quickly on standard hardware 354 | 355 | **Trained Models**: [Available on Kaggle Models](https://www.kaggle.com/models/tianffan/industrial-digital-twin-by-transformer) 356 | 357 | **Model File Locations**: 358 | - **Stage1 Models**: Three files (`.pth`, `_config.json`, `_scaler.pkl`) are located in `saved_models/` 359 | - **Stage2 Models**: Located in `saved_models/stage2_boost/` 360 | 361 | **Note on Benchmarks**: 362 | These results are provided as reference examples on specific datasets. This project prioritizes **practical applicability and ease of deployment** over competitive benchmark scores. Performance will vary based on your specific industrial application, sensor characteristics, and data quality. We encourage users to evaluate the framework on their own use cases. 363 | 364 | --- 365 | 366 | #### 🌍 Atmospheric Physics Simulation Benchmark 367 | 368 | **Dataset**: LEAP atmospheric physics simulation dataset 369 | 370 | **Performance Results**: 371 | - **Hardware**: Single NVIDIA A100 GPU (Google Colab) 372 | - **Signals**: 164 output signals (excluding ptend_q family) 373 | - **Stage1 (SST)**: R² ≈ 0.56 374 | - **Stage2 Boost**: R² ≈ 0.58 375 | - **Training**: No data augmentation applied 376 | 377 | **Testing Notebook**: See `notebooks/transformer_boost_Leap_final.ipynb` (Author's testing file with comments in Chinese) 378 | 379 | --- 380 | 381 | ### 📌 Performance Notes 382 | 383 | **Variability Factors**: 384 | Results may vary based on: 385 | - Dataset characteristics (sensor correlation patterns, noise levels, signal complexity) 386 | - Physical system properties (sensor spatial relationships, temporal dynamics) 387 | - Model configuration (architecture size, training parameters) 388 | - Application domain (manufacturing, energy, chemical processes, etc.) 389 | 390 | **Best Results Observed**: 391 | - **Highly correlated sensor systems**: R² > 0.80 (e.g., rotating machinery) 392 | - **Complex multi-physics systems**: R² 0.55-0.65 (e.g., atmospheric simulation) 393 | 394 | The framework shows particularly strong performance when sensor outputs have **clear physical interdependencies and spatial relationships**, which aligns with its core design philosophy. 395 | 396 | --- 397 | 398 | ### 🤝 Community Contributions Welcome 399 | 400 | We warmly encourage users to share their benchmark results! If you have applied this framework to your domain, please contribute: 401 | - **Anonymized/desensitized datasets** from your industrial applications 402 | - **Performance metrics** (R², MAE, RMSE, etc.) and visualizations 403 | - **Use case descriptions** and domain insights 404 | 405 | Your contributions help build understanding of the framework's capabilities across diverse industrial scenarios. Please open an [issue](https://github.com/FTF1990/Industrial-digital-twin-by-transformer/issues) or submit a pull request! 406 | 407 | ## 🤝 Contributing 408 | 409 | Thank you for your interest in this project! We truly value community engagement and feedback. 
410 | 411 | **Ways to Support This Project**: 412 | - ⭐ **Give us a star!** It helps others discover this work and motivates continued development 413 | - 🐛 **Bug reports or suggestions?** Please feel free to open an [issue](https://github.com/FTF1990/Industrial-digital-twin-by-transformer/issues) 414 | - 💬 **Ideas or questions?** We welcome discussions in issues or comments 415 | - 📊 **Performance results?** Share your anonymized data and results - these are especially valuable! 416 | 417 | **Current Status**: Due to time constraints, the author may not be able to immediately review and merge external pull requests. We sincerely appreciate your understanding. 418 | 419 | **For major changes**: We kindly ask that you open an issue first to discuss your proposed changes before investing significant effort. 420 | 421 | ⏱️ **Response time**: The author will respond as time permits. Your patience is greatly appreciated. 422 | 423 | Your understanding, patience, and contributions are greatly appreciated! 🙏 424 | 425 | ### Development Setup 426 | 427 | ```bash 428 | # Clone repository 429 | git clone https://github.com/FTF1990/Industrial-digital-twin-by-transformer.git 430 | cd Industrial-digital-twin-by-transformer 431 | 432 | # Install in development mode 433 | pip install -e . 434 | 435 | # Run tests (if available) 436 | python -m pytest tests/ 437 | ``` 438 | 439 | ## 📄 License 440 | 441 | This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details. 442 | 443 | ## 🙏 Acknowledgments 444 | 445 | - Transformer architecture based on "Attention Is All You Need" (Vaswani et al., 2017) 446 | - Inspired by digital twin applications in industrial automation 447 | - Built with PyTorch, Gradio, and the amazing open-source community 448 | 449 | ## 📞 Contact 450 | 451 | For questions, issues, or collaborations: 452 | - **GitHub Issues**: [Create an issue](https://github.com/FTF1990/Industrial-digital-twin-by-transformer/issues) 453 | - **Email**: shvichenko11@gmail.com 454 | 455 | ## 🔗 Citation 456 | 457 | If you use this work in your research, please cite: 458 | 459 | ```bibtex 460 | @software{industrial_digital_twin_transformer, 461 | author = {FTF1990}, 462 | title = {Industrial Digital Twin by Transformer}, 463 | year = {2025}, 464 | url = {https://github.com/FTF1990/Industrial-digital-twin-by-transformer} 465 | } 466 | ``` 467 | 468 | ## 🗺️ Roadmap 469 | 470 | ### v1.0 (Current) ✅ 471 | - [x] Stage2 Boost training system 472 | - [x] Intelligent R² threshold selection 473 | - [x] Ensemble model generation 474 | - [x] Inference comparison tools 475 | - [x] Enhanced Gradio interface 476 | 477 | ### v2.0 (Upcoming) 🚀 478 | 479 | #### **Stage3 Temporal Oscillation Enhancement System** 🕐 480 | The next evolution targeting temporal oscillation signal reconstruction: 481 | 482 | - **Stage3 Temporal Oscillation Feature Extraction**: 483 | - Focus on signals with temporal oscillation characteristics (high-frequency pulsations, vibrations, etc.) 
484 | - Current spatial-sequence Transformers can only capture mean features of temporal oscillations, unable to reconstruct oscillation patterns 485 | - Use temporal ML models or temporal Transformers for pure time-series feature extraction 486 | - Enhance and restore temporal oscillation characteristics inherent to the signals themselves 487 | 488 | - **Final Residual Future Prediction**: 489 | - After Stage1 + Stage2 + Stage3, the final residuals are primarily devoid of spatial features 490 | - Enable pure time-series forecasting on final residuals for future timestep prediction 491 | - Suitable for applications requiring forward prediction capabilities 492 | 493 | - **Signal Relationship Mask Editing** (Planned): 494 | - Maximize Transformer flexibility with input-output signal relationship masks 495 | - Apply engineering knowledge to mask non-directly-related factors 496 | - Better reconstruct real system behaviors by incorporating domain expertise 497 | - Enhance model accuracy through expert-guided feature relationships 498 | 499 | - **Complete Spatial-Temporal Decomposition Architecture**: 500 | - **Stage1 (SST)**: Spatial sensor relationships and cross-sensor dependencies 501 | - **Stage2 (Boost)**: Spatial residual correction and secondary spatial patterns 502 | - **Stage3 (Temporal)**: Pure temporal oscillation features and time-series dynamics 503 | - **Final Goal**: Separate spatial and temporal features into hierarchical layers, capturing all predictable patterns except irreducible noise for universal digital twin applications 504 | 505 | - **Hierarchical Feature Extraction Philosophy**: 506 | - Layer 1: Primary spatial sensor correlations (SST) 507 | - Layer 2: Residual spatial patterns (Stage2 Boost) 508 | - Layer 3: Temporal oscillation characteristics (Stage3 Temporal) 509 | - Final Residual: Irreducible stochastic noise + optional future prediction 510 | 511 | This design aims to achieve **universal digital twin modeling** by systematically decomposing and capturing all predictable features across different domains. 512 | 513 | --- 514 | 515 | **Made with ❤️ for the Industrial AI Community** 516 | -------------------------------------------------------------------------------- /models/residual_tft.py: -------------------------------------------------------------------------------- 1 | """ 2 | Residual TFT Module for Stage2 Boost Training 3 | 4 | This module provides utilities for residual extraction and Stage2 model training 5 | in the Industrial Digital Twin framework. 6 | 7 | Key Components: 8 | - ResidualExtractor: Extract residuals from trained SST models 9 | - GroupedMultiTargetTFT: TFT-style model for residual prediction 10 | - Utility functions for residual data preparation 11 | - Mixed precision inference utilities 12 | - Safe R² computation for multi-output scenarios 13 | - Selective boosting based on R² thresholds 14 | """ 15 | 16 | import torch 17 | import torch.nn as nn 18 | import numpy as np 19 | import pandas as pd 20 | import gc 21 | from typing import Dict, List, Tuple, Any, Optional 22 | from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score 23 | from torch.cuda.amp import autocast 24 | 25 | 26 | class GroupedMultiTargetTFT(nn.Module): 27 | """ 28 | Grouped Multi-Target Temporal Fusion Transformer 29 | 30 | A TFT-style model for predicting multiple target sensors with optional grouping. 
31 | This model is compatible with the StaticSensorTransformer architecture but adds 32 | support for signal grouping and temporal fusion capabilities. 33 | 34 | Args: 35 | num_targets (int): Number of target sensors to predict 36 | num_external_factors (int): Number of external/boundary condition factors 37 | d_model (int): Model dimension 38 | nhead (int): Number of attention heads 39 | num_layers (int): Number of transformer layers 40 | dropout (float): Dropout rate 41 | use_grouping (bool): Whether to use signal grouping 42 | signal_groups (List[List[int]], optional): Groups of signal indices 43 | """ 44 | 45 | def __init__( 46 | self, 47 | num_targets: int, 48 | num_external_factors: int, 49 | d_model: int = 128, 50 | nhead: int = 8, 51 | num_layers: int = 3, 52 | dropout: float = 0.1, 53 | use_grouping: bool = False, 54 | signal_groups: Optional[List[List[int]]] = None 55 | ): 56 | super(GroupedMultiTargetTFT, self).__init__() 57 | 58 | self.num_targets = num_targets 59 | self.num_external_factors = num_external_factors 60 | self.d_model = d_model 61 | self.nhead = nhead 62 | self.num_layers = num_layers 63 | self.use_grouping = use_grouping 64 | self.signal_groups = signal_groups 65 | 66 | # Input embedding 67 | self.input_embedding = nn.Linear(1, d_model) 68 | self.position_encoding = nn.Parameter( 69 | torch.randn(num_external_factors, d_model) 70 | ) 71 | 72 | # Transformer encoder 73 | encoder_layer = nn.TransformerEncoderLayer( 74 | d_model=d_model, 75 | nhead=nhead, 76 | dim_feedforward=d_model * 2, 77 | dropout=dropout, 78 | batch_first=True 79 | ) 80 | self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers) 81 | 82 | # Output layers 83 | self.global_pool = nn.AdaptiveAvgPool1d(1) 84 | self.output_projection = nn.Linear(d_model, num_targets) 85 | 86 | self._init_weights() 87 | 88 | def _init_weights(self): 89 | """Initialize model weights""" 90 | for p in self.parameters(): 91 | if p.dim() > 1: 92 | nn.init.xavier_uniform_(p) 93 | 94 | def forward(self, x: torch.Tensor) -> torch.Tensor: 95 | """ 96 | Forward pass 97 | 98 | Args: 99 | x (torch.Tensor): Input tensor of shape (batch_size, num_external_factors) 100 | 101 | Returns: 102 | torch.Tensor: Predictions of shape (batch_size, num_targets) 103 | """ 104 | batch_size = x.shape[0] 105 | 106 | # Embed inputs 107 | x = x.unsqueeze(-1) # (batch, factors, 1) 108 | x = self.input_embedding(x) + self.position_encoding.unsqueeze(0) 109 | 110 | # Transform 111 | x = self.transformer(x) # (batch, factors, d_model) 112 | 113 | # Pool and project 114 | x = x.permute(0, 2, 1) # (batch, d_model, factors) 115 | x = self.global_pool(x).squeeze(-1) # (batch, d_model) 116 | predictions = self.output_projection(x) # (batch, num_targets) 117 | 118 | return predictions 119 | 120 | 121 | class ResidualExtractor: 122 | """ 123 | Utility class for extracting residuals from trained SST models 124 | 125 | This class provides methods to extract prediction residuals from trained 126 | StaticSensorTransformer models, which can then be used for Stage2 boost training. 
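    Typical use (illustrative): residuals_df, info = ResidualExtractor.extract_residuals_from_trained_models(model_name, df, global_state, device). The resulting '<target>_residual' columns can serve as Stage2 training targets, while the '<target>_true' and '<target>_pred' columns are kept for visualization and per-signal metrics.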
127 | """ 128 | 129 | @staticmethod 130 | def extract_residuals_from_trained_models( 131 | model_name: str, 132 | df: pd.DataFrame, 133 | global_state: Dict[str, Any], 134 | device: torch.device 135 | ) -> Tuple[pd.DataFrame, Dict[str, Any]]: 136 | """ 137 | Extract residuals from a trained SST model 138 | 139 | Args: 140 | model_name (str): Name of the trained model in global_state 141 | df (pd.DataFrame): Input dataframe with sensor data 142 | global_state (Dict): Global state containing trained models 143 | device (torch.device): Device to run inference on 144 | 145 | Returns: 146 | Tuple[pd.DataFrame, Dict]: 147 | - DataFrame containing residuals and predictions 148 | - Dictionary with extraction info 149 | """ 150 | try: 151 | # Get model from global state 152 | if model_name not in global_state.get('trained_models', {}): 153 | return pd.DataFrame(), { 154 | 'error': f"Model '{model_name}' not found in trained models" 155 | } 156 | 157 | model_info = global_state['trained_models'][model_name] 158 | model = model_info['model'] 159 | boundary_signals = model_info['boundary_signals'] 160 | target_signals = model_info['target_signals'] 161 | 162 | # Get scalers 163 | scalers = global_state.get('scalers', {}).get(model_name, {}) 164 | if not scalers: 165 | return pd.DataFrame(), { 166 | 'error': f"Scalers not found for model '{model_name}'" 167 | } 168 | 169 | scaler_X = scalers.get('X') 170 | scaler_y = scalers.get('y') 171 | 172 | # Prepare input data 173 | X = df[boundary_signals].values 174 | y_true = df[target_signals].values 175 | 176 | # Use improved residual computation with batch processing and mixed precision 177 | residuals = compute_residuals_correctly( 178 | X, y_true, model, scaler_X, scaler_y, device, batch_size=1024 179 | ) 180 | 181 | # Also get predictions for visualization 182 | X_scaled = scaler_X.transform(X) 183 | model.eval() 184 | with torch.no_grad(): 185 | X_tensor = torch.FloatTensor(X_scaled).to(device) 186 | y_pred_list = [] 187 | batch_size = 1024 188 | for i in range(0, len(X_tensor), batch_size): 189 | batch = X_tensor[i:i+batch_size] 190 | with autocast(): 191 | y_pred_batch = model(batch).cpu().numpy() 192 | y_pred_list.append(y_pred_batch) 193 | y_pred_scaled = np.vstack(y_pred_list) 194 | y_pred = scaler_y.inverse_transform(y_pred_scaled) 195 | 196 | # Create output dataframe 197 | residuals_df = pd.DataFrame() 198 | 199 | # Add boundary signals 200 | for sig in boundary_signals: 201 | residuals_df[sig] = df[sig].values 202 | 203 | # Add residuals for each target signal 204 | residual_signals = [] 205 | for i, sig in enumerate(target_signals): 206 | residual_col = f"{sig}_residual" 207 | residuals_df[residual_col] = residuals[:, i] 208 | residual_signals.append(residual_col) 209 | 210 | # Also add true and predicted values 211 | residuals_df[f"{sig}_true"] = y_true[:, i] 212 | residuals_df[f"{sig}_pred"] = y_pred[:, i] 213 | 214 | # Calculate per-signal metrics 215 | metrics = {} 216 | for i, sig in enumerate(target_signals): 217 | metrics[sig] = { 218 | 'mae': mean_absolute_error(y_true[:, i], y_pred[:, i]), 219 | 'rmse': np.sqrt(mean_squared_error(y_true[:, i], y_pred[:, i])), 220 | 'r2': r2_score(y_true[:, i], y_pred[:, i]) 221 | } 222 | 223 | # Create info dictionary 224 | info = { 225 | 'model_name': model_name, 226 | 'boundary_signals': boundary_signals, 227 | 'target_signals': target_signals, 228 | 'residual_signals': residual_signals, 229 | 'num_samples': len(residuals_df), 230 | 'metrics': metrics 231 | } 232 | 233 | return residuals_df, info 
234 | 235 | except Exception as e: 236 | import traceback 237 | return pd.DataFrame(), { 238 | 'error': f"Failed to extract residuals: {str(e)}", 239 | 'traceback': traceback.format_exc() 240 | } 241 | 242 | 243 | def prepare_residual_sequence_data( 244 | residuals_df: pd.DataFrame, 245 | boundary_signals: List[str], 246 | residual_signals: List[str], 247 | sequence_length: int = 10, 248 | future_horizon: int = 1 249 | ) -> Tuple[np.ndarray, np.ndarray]: 250 | """ 251 | Prepare sequential residual data for TFT-style models 252 | 253 | Args: 254 | residuals_df (pd.DataFrame): DataFrame containing residuals 255 | boundary_signals (List[str]): List of boundary signal column names 256 | residual_signals (List[str]): List of residual signal column names 257 | sequence_length (int): Length of input sequences 258 | future_horizon (int): Number of steps to predict ahead 259 | 260 | Returns: 261 | Tuple[np.ndarray, np.ndarray]: X sequences and y targets 262 | """ 263 | X_sequences = [] 264 | y_sequences = [] 265 | 266 | # Get data 267 | boundary_data = residuals_df[boundary_signals].values 268 | residual_data = residuals_df[residual_signals].values 269 | 270 | # Create sequences 271 | for i in range(len(residuals_df) - sequence_length - future_horizon + 1): 272 | # Input: boundary conditions + past residuals 273 | X_seq = np.concatenate([ 274 | boundary_data[i:i + sequence_length], 275 | residual_data[i:i + sequence_length] 276 | ], axis=1) 277 | 278 | # Target: future residuals 279 | y_seq = residual_data[i + sequence_length:i + sequence_length + future_horizon] 280 | 281 | X_sequences.append(X_seq) 282 | y_sequences.append(y_seq) 283 | 284 | return np.array(X_sequences), np.array(y_sequences) 285 | 286 | 287 | def train_residual_tft( 288 | model: nn.Module, 289 | train_loader: torch.utils.data.DataLoader, 290 | val_loader: torch.utils.data.DataLoader, 291 | config: Dict[str, Any], 292 | device: torch.device 293 | ) -> Tuple[nn.Module, Dict[str, List[float]]]: 294 | """ 295 | Train a Residual TFT model 296 | 297 | Args: 298 | model (nn.Module): The TFT model to train 299 | train_loader (DataLoader): Training data loader 300 | val_loader (DataLoader): Validation data loader 301 | config (Dict): Training configuration 302 | device (torch.device): Device to train on 303 | 304 | Returns: 305 | Tuple[nn.Module, Dict]: Trained model and training history 306 | """ 307 | optimizer = torch.optim.AdamW( 308 | model.parameters(), 309 | lr=config.get('lr', 0.001), 310 | weight_decay=config.get('weight_decay', 1e-5) 311 | ) 312 | 313 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( 314 | optimizer, 315 | mode='min', 316 | factor=config.get('scheduler_factor', 0.5), 317 | patience=config.get('scheduler_patience', 10) 318 | ) 319 | 320 | criterion = nn.MSELoss() 321 | 322 | history = { 323 | 'train_losses': [], 324 | 'val_losses': [], 325 | 'train_r2': [], 326 | 'val_r2': [] 327 | } 328 | 329 | best_val_loss = float('inf') 330 | best_model_state = None 331 | patience_counter = 0 332 | early_stop_patience = config.get('early_stop_patience', 25) 333 | 334 | for epoch in range(config.get('epochs', 100)): 335 | # Training 336 | model.train() 337 | train_loss = 0.0 338 | train_preds = [] 339 | train_targets = [] 340 | 341 | for batch_X, batch_y in train_loader: 342 | batch_X, batch_y = batch_X.to(device), batch_y.to(device) 343 | 344 | optimizer.zero_grad() 345 | outputs = model(batch_X) 346 | loss = criterion(outputs, batch_y) 347 | loss.backward() 348 | 349 | # Gradient clipping 350 | if 
config.get('grad_clip', 0) > 0: 351 | torch.nn.utils.clip_grad_norm_( 352 | model.parameters(), 353 | config['grad_clip'] 354 | ) 355 | 356 | optimizer.step() 357 | 358 | train_loss += loss.item() 359 | train_preds.append(outputs.detach().cpu().numpy()) 360 | train_targets.append(batch_y.detach().cpu().numpy()) 361 | 362 | train_loss /= len(train_loader) 363 | train_preds = np.vstack(train_preds) 364 | train_targets = np.vstack(train_targets) 365 | train_r2 = r2_score(train_targets, train_preds) 366 | 367 | # Validation 368 | model.eval() 369 | val_loss = 0.0 370 | val_preds = [] 371 | val_targets = [] 372 | 373 | with torch.no_grad(): 374 | for batch_X, batch_y in val_loader: 375 | batch_X, batch_y = batch_X.to(device), batch_y.to(device) 376 | outputs = model(batch_X) 377 | loss = criterion(outputs, batch_y) 378 | 379 | val_loss += loss.item() 380 | val_preds.append(outputs.cpu().numpy()) 381 | val_targets.append(batch_y.cpu().numpy()) 382 | 383 | val_loss /= len(val_loader) 384 | val_preds = np.vstack(val_preds) 385 | val_targets = np.vstack(val_targets) 386 | val_r2 = r2_score(val_targets, val_preds) 387 | 388 | # Record history 389 | history['train_losses'].append(train_loss) 390 | history['val_losses'].append(val_loss) 391 | history['train_r2'].append(train_r2) 392 | history['val_r2'].append(val_r2) 393 | 394 | # Learning rate scheduling 395 | scheduler.step(val_loss) 396 | 397 | # Early stopping check 398 | if val_loss < best_val_loss: 399 | best_val_loss = val_loss 400 | best_model_state = model.state_dict().copy() 401 | patience_counter = 0 402 | else: 403 | patience_counter += 1 404 | 405 | if patience_counter >= early_stop_patience: 406 | break 407 | 408 | # Load best model 409 | if best_model_state is not None: 410 | model.load_state_dict(best_model_state) 411 | 412 | return model, history 413 | 414 | 415 | def compute_r2_safe( 416 | y_true: np.ndarray, 417 | y_pred: np.ndarray, 418 | method: str = 'per_output_mean' 419 | ) -> Tuple[float, np.ndarray]: 420 | """ 421 | Safe R² calculation - avoid anomalies with multi-output 422 | 423 | This function computes R² scores per output signal and aggregates them 424 | to avoid anomalies that can occur with sklearn's default multi-output R² calculation. 
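    Example (illustrative): for three outputs with per-output R² of [0.90, 0.80, -50.0], method='per_output_mean' discards the -50.0 outlier (below the -10 cutoff) and returns 0.85, whereas 'sklearn_default' returns the uniform average of roughly -16.1.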
425 | 426 | Args: 427 | y_true: Ground truth (n_samples, n_outputs) or (n_samples,) 428 | y_pred: Predictions (n_samples, n_outputs) or (n_samples,) 429 | method: Aggregation method 430 | - 'per_output_mean': Mean of per-output R² (filters out anomalies) 431 | - 'per_output_median': Median of per-output R² 432 | - 'sklearn_default': Use sklearn's default multioutput='uniform_average' 433 | - 'global': Treat all values as one global prediction 434 | 435 | Returns: 436 | r2: Overall R² score 437 | per_output_r2: R² for each output (for diagnostics) 438 | """ 439 | if y_true.ndim == 1: 440 | y_true = y_true.reshape(-1, 1) 441 | y_pred = y_pred.reshape(-1, 1) 442 | 443 | n_outputs = y_true.shape[1] 444 | per_output_r2 = np.zeros(n_outputs) 445 | 446 | # Compute R² for each output separately 447 | for i in range(n_outputs): 448 | y_t = y_true[:, i] 449 | y_p = y_pred[:, i] 450 | 451 | # Check variance 452 | var_true = np.var(y_t) 453 | if var_true < 1e-10: 454 | per_output_r2[i] = 0.0 455 | else: 456 | try: 457 | per_output_r2[i] = r2_score(y_t, y_p) 458 | except Exception: 459 | per_output_r2[i] = -1.0 460 | 461 | # Aggregate based on method 462 | if method == 'per_output_mean': 463 | # Filter out anomalies 464 | valid_r2 = per_output_r2[np.isfinite(per_output_r2) & (per_output_r2 > -10)] 465 | r2 = np.mean(valid_r2) if len(valid_r2) > 0 else -1.0 466 | elif method == 'per_output_median': 467 | valid_r2 = per_output_r2[np.isfinite(per_output_r2) & (per_output_r2 > -10)] 468 | r2 = np.median(valid_r2) if len(valid_r2) > 0 else -1.0 469 | elif method == 'sklearn_default': 470 | r2 = r2_score(y_true, y_pred, multioutput='uniform_average') 471 | elif method == 'global': 472 | # Flatten and treat as single output 473 | r2 = r2_score(y_true.flatten(), y_pred.flatten()) 474 | else: 475 | r2 = r2_score(y_true, y_pred) 476 | 477 | return r2, per_output_r2 478 | 479 | 480 | def compute_residuals_correctly( 481 | X_orig: np.ndarray, 482 | y_orig: np.ndarray, 483 | base_model: nn.Module, 484 | scaler_X: Any, 485 | scaler_y: Any, 486 | device: torch.device, 487 | batch_size: int = 1024 488 | ) -> np.ndarray: 489 | """ 490 | Correctly compute residuals in original scale. 491 | 492 | CRITICAL: Residuals must be computed in the original (non-standardized) space. 493 | This ensures that Stage2 model learns meaningful residual patterns. 494 | 495 | Steps: 496 | 1. Standardize input using scaler_X 497 | 2. Predict in standardized space 498 | 3. Inverse transform predictions to original space 499 | 4. 
Compute residuals = y_true - y_pred (both in original space) 500 | 501 | Args: 502 | X_orig: Original input data (n_samples, n_features) 503 | y_orig: Original target data (n_samples, n_targets) 504 | base_model: Trained Stage1 model 505 | scaler_X: Input scaler (from Stage1 training) 506 | scaler_y: Output scaler (from Stage1 training) 507 | device: torch.device for inference 508 | batch_size: Batch size for memory-efficient inference 509 | 510 | Returns: 511 | residuals: Residuals in original space (n_samples, n_targets) 512 | """ 513 | base_model.eval() 514 | 515 | # Step 1: Standardize input 516 | X_scaled = scaler_X.transform(X_orig) 517 | 518 | # Step 2: Predict in standardized space (with batching for memory efficiency) 519 | with torch.no_grad(): 520 | X_tensor = torch.FloatTensor(X_scaled).to(device) 521 | y_pred_scaled_list = [] 522 | 523 | for i in range(0, len(X_tensor), batch_size): 524 | batch = X_tensor[i:i+batch_size] 525 | with autocast(): 526 | y_pred_batch = base_model(batch).cpu().numpy() 527 | y_pred_scaled_list.append(y_pred_batch) 528 | 529 | y_pred_scaled = np.vstack(y_pred_scaled_list) 530 | 531 | # Step 3: Inverse transform to original space 532 | y_pred_original = scaler_y.inverse_transform(y_pred_scaled) 533 | 534 | # Step 4: Compute residuals in original space 535 | residuals = y_orig - y_pred_original 536 | 537 | return residuals 538 | 539 | 540 | def batch_inference( 541 | model: nn.Module, 542 | X_data: np.ndarray, 543 | scaler_X: Any, 544 | scaler_y: Any, 545 | device: torch.device, 546 | batch_size: int = 512, 547 | model_name: str = "Model" 548 | ) -> np.ndarray: 549 | """ 550 | Batch processing inference to avoid GPU OOM 551 | 552 | This function performs inference in batches with automatic memory management 553 | to handle large datasets without running out of GPU memory. 
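    Example (a minimal, illustrative call; model, X_raw and the scalers are placeholder names assumed to come from Stage1 training):

        # placeholder objects from Stage1 training, not defined in this module
        y_hat = batch_inference(model, X_raw, scaler_X, scaler_y,
                                device=torch.device('cuda'), batch_size=512)
        # y_hat has shape (X_raw.shape[0], n_targets), in original signal units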
554 | 555 | Args: 556 | model: Trained model 557 | X_data: Input data in original space (n_samples, n_features) 558 | scaler_X: Input scaler 559 | scaler_y: Output scaler 560 | device: torch.device for inference 561 | batch_size: Batch size for processing 562 | model_name: Name for logging 563 | 564 | Returns: 565 | y_pred: Predictions in original space (n_samples, n_targets) 566 | """ 567 | model.eval() 568 | n_samples = X_data.shape[0] 569 | predictions_list = [] 570 | 571 | for i in range(0, n_samples, batch_size): 572 | end_idx = min(i + batch_size, n_samples) 573 | batch_X = X_data[i:end_idx] 574 | 575 | # Standardize 576 | batch_X_scaled = scaler_X.transform(batch_X) 577 | batch_X_tensor = torch.FloatTensor(batch_X_scaled).to(device) 578 | 579 | # Inference with mixed precision 580 | with torch.no_grad(): 581 | with autocast(): 582 | batch_pred_scaled = model(batch_X_tensor).cpu().numpy() 583 | 584 | # Inverse transform 585 | batch_pred = scaler_y.inverse_transform(batch_pred_scaled) 586 | predictions_list.append(batch_pred) 587 | 588 | # Clean up GPU memory 589 | del batch_X_tensor, batch_pred_scaled 590 | if (i // batch_size + 1) % 10 == 0: 591 | torch.cuda.empty_cache() 592 | gc.collect() 593 | 594 | y_pred = np.vstack(predictions_list) 595 | return y_pred 596 | 597 | 598 | def inference_with_boosting( 599 | X_test: np.ndarray, 600 | base_model: nn.Module, 601 | residual_model: nn.Module, 602 | scalers: Dict[str, Any], 603 | device: torch.device, 604 | signal_r2_scores: Optional[np.ndarray] = None, 605 | r2_threshold: float = 0.4, 606 | batch_size: int = 512, 607 | use_selective_boosting: bool = True 608 | ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: 609 | """ 610 | Apply boosting selectively based on R² scores 611 | 612 | This function implements intelligent model selection: 613 | - For signals with high R² (≥ threshold): Use Stage1 predictions only 614 | - For signals with low R² (< threshold): Apply Stage2 boost correction 615 | 616 | Args: 617 | X_test: Test data in original space 618 | base_model: Stage 1 model 619 | residual_model: Stage 2 residual model 620 | scalers: Dictionary containing: 621 | - 'stage1_scaler_X': Stage1 input scaler 622 | - 'stage1_scaler_y': Stage1 output scaler 623 | - 'stage2_scaler_X': Stage2 input scaler 624 | - 'stage2_scaler_residual': Stage2 residual scaler 625 | device: torch.device 626 | signal_r2_scores: R² score for each signal from validation (optional) 627 | r2_threshold: Threshold for determining weak signals 628 | batch_size: Batch size for inference 629 | use_selective_boosting: If True, only boost weak signals; if False, boost all 630 | 631 | Returns: 632 | y_pred_stage1: Stage 1 predictions only 633 | y_pred_boosted: Final predictions (with selective boosting) 634 | boosting_mask: Boolean mask indicating which signals were boosted 635 | """ 636 | 637 | # Stage 1 prediction 638 | y_pred_stage1 = batch_inference( 639 | base_model, X_test, 640 | scalers['stage1_scaler_X'], 641 | scalers['stage1_scaler_y'], 642 | device, 643 | batch_size, 644 | "Stage 1" 645 | ) 646 | 647 | # Stage 2 residual prediction 648 | residual_pred = batch_inference( 649 | residual_model, X_test, 650 | scalers['stage2_scaler_X'], 651 | scalers['stage2_scaler_residual'], 652 | device, 653 | batch_size, 654 | "Stage 2" 655 | ) 656 | 657 | # Full boosting (all signals) 658 | y_pred_full_boosted = y_pred_stage1 + residual_pred 659 | 660 | # Selective boosting (only weak signals) 661 | if use_selective_boosting and signal_r2_scores is not None: 662 | # Identify weak 
signals (R² < threshold) 663 | boosting_mask = signal_r2_scores < r2_threshold 664 | num_boosted = np.sum(boosting_mask) 665 | 666 | print(f"Selective Boosting: {num_boosted}/{len(boosting_mask)} signals boosted (R² < {r2_threshold})") 667 | 668 | # Apply boosting only to weak signals 669 | y_pred_boosted = y_pred_stage1.copy() 670 | y_pred_boosted[:, boosting_mask] = y_pred_full_boosted[:, boosting_mask] 671 | else: 672 | # Use full boosting for all signals 673 | y_pred_boosted = y_pred_full_boosted 674 | boosting_mask = np.ones(y_pred_stage1.shape[1], dtype=bool) 675 | 676 | return y_pred_stage1, y_pred_boosted, boosting_mask 677 | 678 | 679 | def compute_per_signal_metrics( 680 | y_true: np.ndarray, 681 | y_pred: np.ndarray 682 | ) -> Dict[str, np.ndarray]: 683 | """ 684 | Compute detailed metrics for each signal 685 | 686 | Args: 687 | y_true: Ground truth (n_samples, n_signals) 688 | y_pred: Predictions (n_samples, n_signals) 689 | 690 | Returns: 691 | metrics: Dictionary with per-signal metrics 692 | - 'mae': Mean Absolute Error per signal 693 | - 'rmse': Root Mean Squared Error per signal 694 | - 'r2': R² score per signal 695 | - 'mape': Mean Absolute Percentage Error per signal 696 | - 'true_mean': Mean of true values per signal 697 | - 'true_std': Std of true values per signal 698 | - 'pred_mean': Mean of predictions per signal 699 | - 'pred_std': Std of predictions per signal 700 | """ 701 | n_signals = y_true.shape[1] 702 | 703 | metrics = { 704 | 'mae': np.zeros(n_signals), 705 | 'rmse': np.zeros(n_signals), 706 | 'r2': np.zeros(n_signals), 707 | 'mape': np.zeros(n_signals), 708 | 'true_mean': np.zeros(n_signals), 709 | 'true_std': np.zeros(n_signals), 710 | 'pred_mean': np.zeros(n_signals), 711 | 'pred_std': np.zeros(n_signals) 712 | } 713 | 714 | for i in range(n_signals): 715 | y_t = y_true[:, i] 716 | y_p = y_pred[:, i] 717 | 718 | metrics['mae'][i] = mean_absolute_error(y_t, y_p) 719 | metrics['rmse'][i] = np.sqrt(mean_squared_error(y_t, y_p)) 720 | 721 | var_true = np.var(y_t) 722 | if var_true < 1e-10: 723 | metrics['r2'][i] = 0.0 724 | else: 725 | try: 726 | metrics['r2'][i] = r2_score(y_t, y_p) 727 | except Exception: 728 | metrics['r2'][i] = -1.0 729 | 730 | # MAPE for non-zero values 731 | non_zero_mask = np.abs(y_t) > 1e-6 732 | if np.sum(non_zero_mask) > 0: 733 | mape = np.mean(np.abs((y_t[non_zero_mask] - y_p[non_zero_mask]) / 734 | y_t[non_zero_mask])) * 100 735 | metrics['mape'][i] = mape 736 | else: 737 | metrics['mape'][i] = np.nan 738 | 739 | metrics['true_mean'][i] = np.mean(y_t) 740 | metrics['true_std'][i] = np.std(y_t) 741 | metrics['pred_mean'][i] = np.mean(y_p) 742 | metrics['pred_std'][i] = np.std(y_p) 743 | 744 | return metrics 745 | 746 | 747 | def clear_gpu_memory(): 748 | """Clean up GPU memory""" 749 | if torch.cuda.is_available(): 750 | torch.cuda.empty_cache() 751 | gc.collect() 752 | 753 | 754 | def print_gpu_memory(): 755 | """Print GPU memory usage""" 756 | if torch.cuda.is_available(): 757 | allocated = torch.cuda.memory_allocated() / 1024**3 758 | reserved = torch.cuda.memory_reserved() / 1024**3 759 | print(f"GPU Memory: Allocated {allocated:.2f} GB | Reserved {reserved:.2f} GB") 760 | --------------------------------------------------------------------------------
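The utilities above are intended to be chained into a Stage1 → Stage2 selective-boosting workflow. The sketch below shows one plausible wiring, not a prescribed API: base_model, residual_model, the fitted scalers dict, and the X/y arrays are placeholders assumed to come from the earlier training steps.

    import torch

    # Sketch only: base_model, residual_model, scalers, X_val/y_val and X_test/y_test
    # are placeholders produced by the Stage1/Stage2 training steps described above.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Per-signal R² of the Stage1 model on held-out validation data (original units)
    y_val_pred = batch_inference(base_model, X_val, scalers['stage1_scaler_X'],
                                 scalers['stage1_scaler_y'], device)
    val_metrics = compute_per_signal_metrics(y_val, y_val_pred)

    # Boost only the signals whose validation R² falls below the threshold
    y_stage1, y_boosted, boosting_mask = inference_with_boosting(
        X_test, base_model, residual_model, scalers, device,
        signal_r2_scores=val_metrics['r2'], r2_threshold=0.4,
        use_selective_boosting=True
    )

    # Compare Stage1-only vs. selectively boosted predictions
    r2_stage1, _ = compute_r2_safe(y_test, y_stage1)
    r2_boosted, _ = compute_r2_safe(y_test, y_boosted)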