├── .idea
│   ├── .gitignore
│   ├── inspectionProfiles
│   │   ├── Project_Default.xml
│   │   └── profiles_settings.xml
│   ├── misc.xml
│   ├── modules.xml
│   ├── sensorllm_code.iml
│   └── vcs.xml
├── asset
│   ├── logo_left.jpg
│   ├── logo_right.png
│   └── sensorllm_model.png
├── mhealth_stage1.ipynb
├── mhealth_stage2.ipynb
├── readme.md
└── sensorllm
    ├── __init__.py
    ├── data
    │   ├── __init__.py
    │   ├── stage1_dataset.py
    │   ├── stage2_dataset.py
    │   └── utils.py
    ├── eval
    │   └── eval.py
    ├── model
    │   ├── __init__.py
    │   ├── chronos_model
    │   │   ├── __init__.py
    │   │   └── chronos_model.py
    │   ├── stage1_sensorllm.py
    │   ├── stage2_sensorllm.py
    │   ├── ts_backbone.yaml
    │   └── utils.py
    ├── train
    │   ├── __init__.py
    │   ├── llama_flash_attn_monkey_patch.py
    │   ├── sensorllm_trainer.py
    │   ├── train.py
    │   └── train_mem.py
    └── utils.py

/.idea/.gitignore:
--------------------------------------------------------------------------------
# Default ignored files
/shelf/
/workspace.xml
# Editor-based HTTP Client requests
/httpRequests/
# Datasource local storage ignored files
/dataSources/
/dataSources.local.xml
--------------------------------------------------------------------------------
/asset/logo_left.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zechenli03/SensorLLM/1f4142a30f452721e943771190fe1dade3337249/asset/logo_left.jpg
--------------------------------------------------------------------------------
/asset/logo_right.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zechenli03/SensorLLM/1f4142a30f452721e943771190fe1dade3337249/asset/logo_right.png
--------------------------------------------------------------------------------
/asset/sensorllm_model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zechenli03/SensorLLM/1f4142a30f452721e943771190fe1dade3337249/asset/sensorllm_model.png
--------------------------------------------------------------------------------
/readme.md:
--------------------------------------------------------------------------------
SensorLLM

Human-Intuitive Alignment of Multivariate Sensor Data with LLMs for Activity Recognition

Zechen Li¹, Shohreh Deldari¹, Linyao Chen², Hao Xue¹, Flora D. Salim¹

¹ University of New South Wales, Sydney
² University of Tokyo

[arXiv](https://arxiv.org/abs/2410.10624)

## 🌟 Overview

**SensorLLM** is a two-stage framework that aligns sensor time series with human-intuitive text, enabling LLMs to interpret complex numerical data and achieve state-of-the-art (SOTA) human activity recognition across varying sensor types, sensor counts, and sequence lengths.

![sensorllm_model](./asset/sensorllm_model.png)

### 🔑 Key Features
- Aligns sensor time series with ***human-intuitive, annotation-free*** textual trend descriptions and summaries via a QA-based framework.
- The ***Sensor–Language Alignment Stage*** operates on single-channel, variable-length segments for fine-grained trend-text alignment.
- The ***Task-Aware Tuning Stage*** handles multi-channel, multi-sensor data for downstream human activity recognition (HAR).

### 📂 Datasets
The current implementation supports five HAR datasets: [`USC-HAD`](https://sipi.usc.edu/had/), [`UCI-HAR`](https://archive.ics.uci.edu/dataset/240/human+activity+recognition+using+smartphones), [`MHealth`](https://archive.ics.uci.edu/dataset/319/mhealth+dataset), [`Capture-24`](https://ora.ox.ac.uk/objects/uuid:99d7c092-d865-4a19-b096-cc16440cd001), and [`PAMAP2`](https://archive.ics.uci.edu/dataset/231/pamap2+physical+activity+monitoring).

To apply SensorLLM to other datasets, refer to the code and configuration examples provided for the supported datasets. In particular, you may need to modify the corresponding entries in [`ts_backbone.yaml`](./sensorllm/model/ts_backbone.yaml) and adapt the data loading logic in the [`./sensorllm/data`](./sensorllm/data) folder to match your dataset's format.


## 🚀 Getting started

> Currently supported pretrained models:
> - Time-series models: [Chronos](https://arxiv.org/abs/2403.07815)
> - Language models: [LLaMA](https://arxiv.org/abs/2407.21783)
>
> Other pretrained models **can be used with minor modifications to the SensorLLM framework**.


### Sensor-Language QA Pairs Generation
We provide two example notebooks that generate QA pairs for aligning sensor time-series data with human-intuitive text:
- [`mhealth_stage1.ipynb`](./mhealth_stage1.ipynb): Generates QA pairs for Stage 1 by aligning single-channel sensor segments with trend-based natural language descriptions.
- [`mhealth_stage2.ipynb`](./mhealth_stage2.ipynb): Generates statistical information text for Stage 2, which performs HAR classification on multi-channel sensor data.

You can also customize or extend the QA templates in these notebooks to generate more diverse types of sensor–language QA pairs for your own use cases.

### Sensor–Language Alignment
To align sensor time-series data with text, run the following command:

```bash
torchrun --nproc_per_node=[NUM_GPUS] sensorllm/train/train_mem.py \
  --model_name_or_path [LLM_PATH] \
  --pt_encoder_backbone_ckpt [TS_EMBEDDER_PATH] \
  --tokenize_method 'StanNormalizeUniformBins' \
  --dataset [DATASET_NAME] \
  --data_path [TS_TRAIN_PATH] \
  --eval_data_path [TS_EVAL_PATH] \
  --qa_path [QA_TRAIN_PATH] \
  --eval_qa_path [QA_EVAL_PATH] \
  --output_dir [OUTPUT_PATH] \
  --model_max_length [MAX_LEN] \
  --num_train_epochs [EPOCH] \
  --per_device_train_batch_size [TRAIN_BATCH] \
  --per_device_eval_batch_size [EVAL_BATCH] \
  --evaluation_strategy "steps" \
  --save_strategy "steps" \
  --save_steps [SAVE_STEPS] \
  --eval_steps [EVAL_STEPS] \
  --learning_rate 2e-3 \
  --weight_decay 0.0 \
  --warmup_ratio 0.03 \
  --lr_scheduler_type "cosine" \
  --logging_steps 1 \
  --gradient_checkpointing True \
  --save_total_limit 1 \
  --bf16 True \
  --fix_llm True \
  --fix_ts_encoder True \
  --model_type CasualLM \
  --load_best_model_at_end True
```

### Evaluation or Inference
To run evaluation or inference for the Sensor–Language Alignment stage, use the following command:

```bash
python sensorllm/eval/eval.py \
  --model_name_or_path [STAGE1_MODEL_PATH] \
  --pt_encoder_backbone_ckpt [TS_EMBEDDER_PATH] \
  --torch_dtype bfloat16 \
  --tokenize_method 'StanNormalizeUniformBins' \
  --dataset [DATASET_NAME] \
  --data_path [TS_DATASET_PATH] \
  --qa_path [QA_DATASET_PATH] \
  --output_file_name [OUTPUT_FILE_NAME] \
  --model_max_length [MAX_LEN] \
  --shuffle False
```

### Task-Aware Tuning
To fine-tune the Stage 1 model for HAR classification, use the following command:
```bash
torchrun --nproc_per_node=[NUM_GPUS] sensorllm/train/train_mem.py \
  --model_name_or_path [STAGE1_MODEL_PATH] \
  --pt_encoder_backbone_ckpt [TS_EMBEDDER_PATH] \
  --model_type "SequenceClassification" \
  --num_labels [ACTIVITY_NUM] \
  --use_weighted_loss True \
  --tokenize_method 'StanNormalizeUniformBins' \
  --dataset [DATASET_NAME] \
  --data_path [TS_TRAIN_PATH] \
  --eval_data_path [TS_EVAL_PATH] \
  --qa_path [QA_TRAIN_PATH] \
  --eval_qa_path [QA_EVAL_PATH] \
  --output_dir [OUTPUT_PATH] \
  --model_max_length [MAX_LEN] \
  --num_train_epochs [EPOCH] \
  --per_device_train_batch_size [TRAIN_BATCH] \
  --per_device_eval_batch_size [EVAL_BATCH] \
  --evaluation_strategy "steps" \
  --save_strategy "steps" \
  --save_steps [SAVE_STEPS] \
  --eval_steps [EVAL_STEPS] \
  --save_total_limit 1 \
  --load_best_model_at_end True \
  --learning_rate 2e-3 \
  --weight_decay 0.0 \
  --warmup_ratio 0.03 \
  --lr_scheduler_type "cosine" \
  --logging_steps 1 \
  --bf16 True \
  --fix_llm True \
  --fix_cls_head False \
  --fix_ts_encoder True \
  --gradient_checkpointing True \
  --metric_for_best_model "f1_macro" \
  --preprocess_type "smry+Q" \
  --greater_is_better True \
  --stage_2 True \
  --shuffle True
```
See [`./sensorllm/data/utils.py`](./sensorllm/data/utils.py) for all available `preprocess_type` options or to edit them.
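All three commands above pass `--tokenize_method 'StanNormalizeUniformBins'`, which selects how raw sensor values become tokens for the Chronos time-series encoder. The real tokenizer lives in the Chronos model code and is configured through `ts_backbone.yaml`; the sketch below only illustrates what the name suggests, assuming each channel is standardized (z-scored) and then quantized into equal-width bins. The function names, the bin count (4094), and the range [-15, 15] are illustrative assumptions, not values read from this repository, and the special tokens (such as padding) that a Chronos-style tokenizer reserves are omitted.

```python
import numpy as np

# Hypothetical sketch of a "standard-normalize + uniform bins" tokenizer.
# N_BINS, LOW and HIGH are illustrative assumptions, not SensorLLM's config.
N_BINS = 4094
LOW, HIGH = -15.0, 15.0


def stan_normalize_uniform_bins(x, n_bins=N_BINS, low=LOW, high=HIGH, eps=1e-8):
    """Z-score a 1-D series, then map each value to an equal-width bin id.

    Returns (token_ids, (mean, scale)); the state allows decoding later.
    """
    x = np.asarray(x, dtype=np.float64)
    mean = x.mean()
    scale = x.std() + eps  # eps keeps constant channels finite
    z = (x - mean) / scale  # standard normalization
    # Interior edges split [low, high] into n_bins equal-width buckets;
    # values outside the range fall into the first or last bucket.
    edges = np.linspace(low, high, n_bins + 1)[1:-1]
    token_ids = np.digitize(z, edges)  # integers in [0, n_bins - 1]
    return token_ids, (mean, scale)


def decode_bins(token_ids, state, n_bins=N_BINS, low=LOW, high=HIGH):
    """Map bin ids back to bin centers and undo the normalization."""
    mean, scale = state
    width = (high - low) / n_bins
    centers = low + width * (np.asarray(token_ids) + 0.5)
    return centers * scale + mean


if __name__ == "__main__":
    # A synthetic 200-point channel (2 s at 100 Hz, as in the USC-HAD setting).
    t = np.linspace(0.0, 2.0, 200)
    signal = np.sin(2 * np.pi * 1.5 * t)
    ids, state = stan_normalize_uniform_bins(signal)
    recon = decode_bins(ids, state)
    print(ids[:5], np.abs(recon - signal).max())
```

Standardizing per channel puts accelerometer (g or m/s^2), gyroscope (dps or rad/s), and magnetometer (μT) readings on a common scale, so a single token vocabulary can serve every sensor type; the decode step shows that the quantization error is bounded by half a bin width times the channel's standard deviation.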

## 🌍 Citation

If you find this repository useful for your research, please cite our paper:

```
@misc{li2025sensorllm,
  title={SensorLLM: Human-Intuitive Alignment of Multivariate Sensor Data with LLMs for Activity Recognition},
  author={Zechen Li and Shohreh Deldari and Linyao Chen and Hao Xue and Flora D. Salim},
  year={2025},
  eprint={2410.10624},
  archivePrefix={arXiv},
  primaryClass={cs.CL},
  url={https://arxiv.org/abs/2410.10624},
}
```

## 📄 License

This work is licensed under the Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License.


## 📩 Contact

If you have any questions or suggestions, feel free to contact Zechen at `zechen.li(at)unsw(dot)edu(dot)au`.
--------------------------------------------------------------------------------
/sensorllm/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zechenli03/SensorLLM/1f4142a30f452721e943771190fe1dade3337249/sensorllm/__init__.py
--------------------------------------------------------------------------------
/sensorllm/data/__init__.py:
--------------------------------------------------------------------------------
from .stage1_dataset import make_ts_text_data_module, UniChannelTimeSeriesDataset
from .stage2_dataset import make_ts_text_data_module_stage2, MultiChannelTimeSeriesDatasetStage2, make_ts_classification_data_module_stage2
--------------------------------------------------------------------------------
/sensorllm/data/stage1_dataset.py:
--------------------------------------------------------------------------------
from typing import Dict, Sequence
from torch.utils.data import Dataset
import numpy as np
import random
import copy
import json
import pickle
import logging
from dataclasses import dataclass

from sensorllm.data.utils import generate_chat_template, preprocess, get_token_dict
from sensorllm.model.chronos_model import *

import transformers

import torch

IGNORE_INDEX = -100
RDM_SEED = 42


def preprocess_time_series2(
    sources: Sequence[Dict[str, str]],  # [{"Q": "...", "A": "...", "type": ...}]
    channel_names: Sequence[str],
    ts_list: list,
    dataset: str,
    data_args: dict,
) -> Sequence[Dict[str, str]]:
    # assert len(sources) == 6
    ts_token = data_args["default_ts_token"]
    modified_sources = []
    start_tokens_dict, end_tokens_dict = get_token_dict(dataset, data_args)

    for source, channel_name, ts in zip(sources, channel_names, ts_list):
        if data_args["last_token"]:
            added_token = ts_token * (len(ts) + 1)
        else:
            added_token = ts_token * len(ts)
        assert channel_name in list(start_tokens_dict.keys()), f"Start token {channel_name} not found"
        assert channel_name in list(end_tokens_dict.keys()), f"End token {channel_name} not found"
        start_token = start_tokens_dict[channel_name]
        end_token = end_tokens_dict[channel_name]
        modified_q = start_token + added_token + end_token + source["Q"]
        modified_a = source["A"] + "\n\n" + source["summary"]["A"]
        modified_sources.append({"Q": modified_q, "A": modified_a, "type": source["type"]})
    return modified_sources


class UniChannelTimeSeriesDataset(Dataset):
    def __init__(self, data_path=None, qa_path=None, tokenizer=None, chronos_tokenizer=None, split=None, data_args=None):
        """
        data_path: path to a pickle file of N multichannel time-series samples; each sample is an array
            of shape (L, C), where C is the number of channels (e.g. 6) and L is the sequence length (e.g. 200).
        qa_path: path to a JSON file of QA texts corresponding to each channel of each sample.
        """

        super(UniChannelTimeSeriesDataset, self).__init__()
        self.data_path = data_path
        self.qa_path = qa_path
        self.tokenizer = tokenizer
        self.chronos_tokenizer = chronos_tokenizer
        self.split = split

        ignore_qa_types = data_args.ignore_qa_types
        self.dataset = data_args.dataset

        shuffle = data_args.shuffle
        self.data_args = data_args.ts_backbone_config[self.dataset]
        self.data_args["default_ts_token"] = data_args.ts_backbone_config["default_ts_token"]
        self.data_args["last_token"] = data_args.ts_backbone_config["chronos_model"]["last_token"]

        self.SYS_INST = f"A dialogue between a curious researcher and an AI assistant. The AI analyzes a sensor time-series dataset (N points, {self.data_args['sample_rate']}Hz sampling rate) to answer specific questions. This interaction demonstrates the AI's data analysis skills and the potential of human-AI collaboration in interpreting complex data."
        print(f"INSTRUCTION Template: {self.SYS_INST}")
        self.ts_data, self.list_data_dict, self.channel_list = self._flatten_data(ignore_qa_types, shuffle)

        print(
            f"The dataset size is: {len(self.list_data_dict)}."
        )

    def _flatten_data(self, ignore_qa_types: list, shuffle: bool):
        logging.warning("Loading data...")
        with open(self.data_path, "rb") as f:
            data_file = pickle.load(f)
        with open(self.qa_path, "r") as file:
            qa_file = json.load(file)
        qa_dict = []
        ts_data = []
        channel_list = []
        for d in qa_file["dataset"]:
            data_idx = d["index"]
            data = data_file[int(data_idx)]
            if self.dataset in ["usc-had", "uci"]:
                for x_acc, y_acc, z_acc, x_g, y_g, z_g in zip(
                        d["qa_pairs"]["x-axis accelerometer"],
                        d["qa_pairs"]["y-axis accelerometer"],
                        d["qa_pairs"]["z-axis accelerometer"],
                        d["qa_pairs"]["x-axis gyroscope"],
                        d["qa_pairs"]["y-axis gyroscope"],
                        d["qa_pairs"]["z-axis gyroscope"],
                ):
                    assert x_acc["type"] == y_acc["type"] == z_acc["type"] == x_g["type"] == y_g["type"] == z_g["type"], "QA type values error"
                    # if x_acc["type"] not in ["sub_trend_no_val", "trend_table"]:
                    if x_acc["type"] not in ignore_qa_types:
                        x_acc["summary"] = d['summaries']["x-axis accelerometer"]
                        y_acc["summary"] = d['summaries']["y-axis accelerometer"]
                        z_acc["summary"] = d['summaries']["z-axis accelerometer"]
                        x_g["summary"] = d['summaries']["x-axis gyroscope"]
                        y_g["summary"] = d['summaries']["y-axis gyroscope"]
                        z_g["summary"] = d['summaries']["z-axis gyroscope"]
                        qa_dict.append([x_acc, y_acc, z_acc, x_g, y_g, z_g])
                        ts_data.append(
                            [torch.from_numpy(data[:, i]).to(torch.float64) for i in range(data.shape[1])]
                        )
                        channel_list.extend(["x_acc", "y_acc", "z_acc", "x_g", "y_g", "z_g"])
            elif self.dataset == "capture24":
                for x_acc, y_acc, z_acc in zip(
                        d["qa_pairs"]["x-axis accelerometer"],
                        d["qa_pairs"]["y-axis accelerometer"],
                        d["qa_pairs"]["z-axis accelerometer"]
                ):
                    assert x_acc["type"] == y_acc["type"] == z_acc["type"], "QA type values error"
                    if x_acc["type"] not in ignore_qa_types:
                        x_acc["summary"] = d['summaries']["x-axis accelerometer"]
                        y_acc["summary"] = d['summaries']["y-axis accelerometer"]
                        z_acc["summary"] = d['summaries']["z-axis accelerometer"]
                        qa_dict.append([x_acc, y_acc, z_acc])
                        ts_data.append(
                            [torch.from_numpy(data[:, i]).to(torch.float64) for i in range(data.shape[1])]
                        )
                        channel_list.extend(["x_acc", "y_acc", "z_acc"])
            elif self.dataset == "mhealth":
                for c_acc_x, c_acc_y, c_acc_z, la_acc_x, la_acc_y, la_acc_z, la_gs_x, la_gs_y, la_gs_z, rla_acc_x, rla_acc_y, rla_acc_z, rla_gs_x, rla_gs_y, rla_gs_z in zip(
                        d["qa_pairs"]["chest x-axis accelerometer"], d["qa_pairs"]["chest y-axis accelerometer"], d["qa_pairs"]["chest z-axis accelerometer"],
                        d["qa_pairs"]["left-ankle x-axis accelerometer"], d["qa_pairs"]["left-ankle y-axis accelerometer"], d["qa_pairs"]["left-ankle z-axis accelerometer"],
                        d["qa_pairs"]["left-ankle x-axis gyroscope"], d["qa_pairs"]["left-ankle y-axis gyroscope"], d["qa_pairs"]["left-ankle z-axis gyroscope"],
                        d["qa_pairs"]["right-lower-arm x-axis accelerometer"], d["qa_pairs"]["right-lower-arm y-axis accelerometer"], d["qa_pairs"]["right-lower-arm z-axis accelerometer"],
                        d["qa_pairs"]["right-lower-arm x-axis gyroscope"], d["qa_pairs"]["right-lower-arm y-axis gyroscope"], d["qa_pairs"]["right-lower-arm z-axis gyroscope"]
                ):
                    assert c_acc_x["type"] == c_acc_y["type"] == c_acc_z["type"] == la_acc_x["type"] == \
                        la_acc_y["type"] == la_acc_z["type"] == la_gs_x["type"] == la_gs_y["type"] == \
                        la_gs_z["type"] == rla_acc_x["type"] == rla_acc_y["type"] == rla_acc_z["type"] == \
                        rla_gs_x["type"] == rla_gs_y["type"] == rla_gs_z["type"], "QA type values error"

                    if c_acc_x["type"] not in ignore_qa_types:
                        c_acc_x["summary"] = d['summaries']["chest x-axis accelerometer"]
                        c_acc_y["summary"] = d['summaries']["chest y-axis accelerometer"]
                        c_acc_z["summary"] = d['summaries']["chest z-axis accelerometer"]
                        la_acc_x["summary"] = d['summaries']["left-ankle x-axis accelerometer"]
                        la_acc_y["summary"] = d['summaries']["left-ankle y-axis accelerometer"]
                        la_acc_z["summary"] = d['summaries']["left-ankle z-axis accelerometer"]
                        la_gs_x["summary"] = d['summaries']["left-ankle x-axis gyroscope"]
                        la_gs_y["summary"] = d['summaries']["left-ankle y-axis gyroscope"]
                        la_gs_z["summary"] = d['summaries']["left-ankle z-axis gyroscope"]
                        rla_acc_x["summary"] = d['summaries']["right-lower-arm x-axis accelerometer"]
                        rla_acc_y["summary"] = d['summaries']["right-lower-arm y-axis accelerometer"]
                        rla_acc_z["summary"] = d['summaries']["right-lower-arm z-axis accelerometer"]
                        rla_gs_x["summary"] = d['summaries']["right-lower-arm x-axis gyroscope"]
                        rla_gs_y["summary"] = d['summaries']["right-lower-arm y-axis gyroscope"]
                        rla_gs_z["summary"] = d['summaries']["right-lower-arm z-axis gyroscope"]
                        qa_dict.append([c_acc_x, c_acc_y, c_acc_z, la_acc_x, la_acc_y, la_acc_z, la_gs_x, la_gs_y, la_gs_z, rla_acc_x, rla_acc_y, rla_acc_z, rla_gs_x, rla_gs_y, rla_gs_z])
                        ts_data.append(
                            [torch.from_numpy(data[:, i]).to(torch.float64) for i in range(data.shape[1])]
                        )
                        channel_list.extend(["c_acc_x", "c_acc_y", "c_acc_z", "la_acc_x", "la_acc_y", "la_acc_z", "la_gs_x", "la_gs_y", "la_gs_z", "rla_acc_x", "rla_acc_y", "rla_acc_z", "rla_gs_x", "rla_gs_y", "rla_gs_z"])
            elif self.dataset == "pamap" or self.dataset == "pamap50":
                for acc_hand_x, acc_hand_y, acc_hand_z, gyr_hand_x, gyr_hand_y, gyr_hand_z, mag_hand_x, mag_hand_y, \
                        mag_hand_z, acc_chest_x, acc_chest_y, acc_chest_z, gyr_chest_x, gyr_chest_y, gyr_chest_z, \
                        mag_chest_x, mag_chest_y, mag_chest_z, acc_ankle_x, acc_ankle_y, acc_ankle_z, gyr_ankle_x, \
                        gyr_ankle_y, gyr_ankle_z, mag_ankle_x, mag_ankle_y, mag_ankle_z in zip(
                        d["qa_pairs"]["hand x-axis accelerometer"], d["qa_pairs"]["hand y-axis accelerometer"], d["qa_pairs"]["hand z-axis accelerometer"], d["qa_pairs"]["hand x-axis gyroscope"], d["qa_pairs"]["hand y-axis gyroscope"],
                        d["qa_pairs"]["hand z-axis gyroscope"], d["qa_pairs"]["hand x-axis magnetometer"], d["qa_pairs"]["hand y-axis magnetometer"], d["qa_pairs"]["hand z-axis magnetometer"], d["qa_pairs"]["chest x-axis accelerometer"],
                        d["qa_pairs"]["chest y-axis accelerometer"], d["qa_pairs"]["chest z-axis accelerometer"], d["qa_pairs"]["chest x-axis gyroscope"], d["qa_pairs"]["chest y-axis gyroscope"], d["qa_pairs"]["chest z-axis gyroscope"],
                        d["qa_pairs"]["chest x-axis magnetometer"], d["qa_pairs"]["chest y-axis magnetometer"], d["qa_pairs"]["chest z-axis magnetometer"], d["qa_pairs"]["ankle x-axis accelerometer"], d["qa_pairs"]["ankle y-axis accelerometer"],
                        d["qa_pairs"]["ankle z-axis accelerometer"], d["qa_pairs"]["ankle x-axis gyroscope"], d["qa_pairs"]["ankle y-axis gyroscope"], d["qa_pairs"]["ankle z-axis gyroscope"], d["qa_pairs"]["ankle x-axis magnetometer"],
                        d["qa_pairs"]["ankle y-axis magnetometer"], d["qa_pairs"]["ankle z-axis magnetometer"]
                ):
                    assert acc_hand_x["type"] == acc_hand_y["type"] == acc_hand_z["type"] == gyr_hand_x["type"] == \
                        gyr_hand_y["type"] == gyr_hand_z["type"] == mag_hand_x["type"] == mag_hand_y["type"] == \
                        mag_hand_z["type"] == acc_chest_x["type"] == acc_chest_y["type"] == acc_chest_z["type"] == \
                        gyr_chest_x["type"] == gyr_chest_y["type"] == gyr_chest_z["type"] == mag_chest_x["type"] == \
                        mag_chest_y["type"] == mag_chest_z["type"] == acc_ankle_x["type"] == acc_ankle_y["type"] == \
                        acc_ankle_z["type"] == gyr_ankle_x["type"] == gyr_ankle_y["type"] == gyr_ankle_z["type"] == \
                        mag_ankle_x["type"] == mag_ankle_y["type"] == mag_ankle_z["type"], "QA type values error"

                    if acc_hand_x["type"] not in ignore_qa_types:
                        acc_hand_x["summary"] = d['summaries']["hand x-axis accelerometer"]
                        acc_hand_y["summary"] = d['summaries']["hand y-axis accelerometer"]
                        acc_hand_z["summary"] = d['summaries']["hand z-axis accelerometer"]
                        gyr_hand_x["summary"] = d['summaries']["hand x-axis gyroscope"]
                        gyr_hand_y["summary"] = d['summaries']["hand y-axis gyroscope"]
                        gyr_hand_z["summary"] = d['summaries']["hand z-axis gyroscope"]
                        mag_hand_x["summary"] = d['summaries']["hand x-axis magnetometer"]
                        mag_hand_y["summary"] = d['summaries']["hand y-axis magnetometer"]
                        mag_hand_z["summary"] = d['summaries']["hand z-axis magnetometer"]

                        acc_chest_x["summary"] = d['summaries']["chest x-axis accelerometer"]
                        acc_chest_y["summary"] = d['summaries']["chest y-axis accelerometer"]
                        acc_chest_z["summary"] = d['summaries']["chest z-axis accelerometer"]
                        gyr_chest_x["summary"] = d['summaries']["chest x-axis gyroscope"]
                        gyr_chest_y["summary"] = d['summaries']["chest y-axis gyroscope"]
                        gyr_chest_z["summary"] = d['summaries']["chest z-axis gyroscope"]
                        mag_chest_x["summary"] = d['summaries']["chest x-axis magnetometer"]
                        mag_chest_y["summary"] = d['summaries']["chest y-axis magnetometer"]
                        mag_chest_z["summary"] = d['summaries']["chest z-axis magnetometer"]

                        acc_ankle_x["summary"] = d['summaries']["ankle x-axis accelerometer"]
                        acc_ankle_y["summary"] = d['summaries']["ankle y-axis accelerometer"]
                        acc_ankle_z["summary"] = d['summaries']["ankle z-axis accelerometer"]
                        gyr_ankle_x["summary"] = d['summaries']["ankle x-axis gyroscope"]
                        gyr_ankle_y["summary"] = d['summaries']["ankle y-axis gyroscope"]
                        gyr_ankle_z["summary"] = d['summaries']["ankle z-axis gyroscope"]
                        mag_ankle_x["summary"] = d['summaries']["ankle x-axis magnetometer"]
                        mag_ankle_y["summary"] = d['summaries']["ankle y-axis magnetometer"]
                        mag_ankle_z["summary"] = d['summaries']["ankle z-axis magnetometer"]

                        qa_dict.append(
                            [acc_hand_x, acc_hand_y, acc_hand_z, gyr_hand_x, gyr_hand_y, gyr_hand_z, mag_hand_x,
                             mag_hand_y, mag_hand_z,
                             acc_chest_x, acc_chest_y, acc_chest_z, gyr_chest_x, gyr_chest_y, gyr_chest_z, mag_chest_x,
                             mag_chest_y, mag_chest_z,
                             acc_ankle_x, acc_ankle_y, acc_ankle_z, gyr_ankle_x, gyr_ankle_y, gyr_ankle_z, mag_ankle_x,
                             mag_ankle_y, mag_ankle_z])
                        ts_data.append(
                            [torch.from_numpy(data[:, i]).to(torch.float64) for i in range(data.shape[1])]
                        )
                        channel_list.extend([
                            "acc_hand_x", "acc_hand_y", "acc_hand_z",
                            "gyr_hand_x", "gyr_hand_y", "gyr_hand_z",
                            "mag_hand_x", "mag_hand_y", "mag_hand_z",
                            "acc_chest_x", "acc_chest_y", "acc_chest_z",
                            "gyr_chest_x", "gyr_chest_y", "gyr_chest_z",
                            "mag_chest_x", "mag_chest_y", "mag_chest_z",
                            "acc_ankle_x", "acc_ankle_y", "acc_ankle_z",
                            "gyr_ankle_x", "gyr_ankle_y", "gyr_ankle_z",
                            "mag_ankle_x", "mag_ankle_y", "mag_ankle_z"
                        ])
            else:
                raise ValueError(f"Wrong dataset name in _flatten_data: {self.dataset}")

        assert len(ts_data) == len(qa_dict), "ts_data, qa_dict shape mismatched"

        if shuffle:
            print("Shuffling data...")
            random.seed(RDM_SEED)
            indexes = list(range(len(qa_dict)))
            random.shuffle(indexes)

            qa_dict = [qa_dict[i] for i in indexes]
            ts_data = [ts_data[i] for i in indexes]

        qa_dict = [item for sublist in qa_dict for item in sublist]
        ts_data = [item for sublist in ts_data for item in sublist]

        return ts_data, qa_dict, channel_list

    def __len__(self):
        return len(self.list_data_dict)

    def __getitem__(self, index):
        sources = self.list_data_dict[index]
        channels = self.channel_list[index]
        ts = self.ts_data[index]  # 1 * L
        if isinstance(index, int):
            sources = [sources]
            channels = [channels]
            ts = [ts]
        assert len(sources) == 1, "sources should be a list"

        sources = preprocess_time_series2(
            copy.deepcopy(sources), copy.deepcopy(channels), copy.deepcopy(ts), self.dataset, self.data_args
        )

        ts_token_ids_list = []
        ts_attention_mask_list = []
        ts_tokenizer_state_list = []
        for context in ts:
            if isinstance(context, list):
                context = left_pad_and_stack_1D(context)
            assert isinstance(context, torch.Tensor)
            if context.ndim == 1:
                context = context.unsqueeze(0)
            assert context.ndim == 2

            ts_token_ids, ts_attention_mask, ts_tokenizer_state = (
                self.chronos_tokenizer.context_input_transform(context)
            )
            ts_token_ids_list.append(ts_token_ids)
            ts_attention_mask_list.append(ts_attention_mask)
            ts_tokenizer_state_list.append(ts_tokenizer_state)

        if self.tokenizer is None:
            data_dict = dict(
                question=sources[0]["Q"],
                ground_truth=sources[0]["A"],
                type=sources[0]["type"],
                ts_token_ids=ts_token_ids_list[0],
                ts_attention_mask=ts_attention_mask_list[0],
                ts_tokenizer_state=ts_tokenizer_state_list[0]
            )
            return data_dict

        data_dict = preprocess(sources, self.tokenizer, self.SYS_INST, self.split, "Q", "Q")

        data_dict = dict(input_ids=data_dict["input_ids"][0],
                         input_texts=sources[0]["Q"],
                         answer=sources[0]["A"],
                         labels=data_dict["labels"][0],
                         ts_token_ids=ts_token_ids_list[0],
                         ts_attention_mask=ts_attention_mask_list[0],
                         ts_tokenizer_state=ts_tokenizer_state_list[0])

        return data_dict


@dataclass
class DataCollatorForTsTextDataset(object):
    """Collate examples for supervised fine-tuning."""

    tokenizer: transformers.PreTrainedTokenizer

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        input_ids, labels, ts_token_ids, ts_attention_mask, ts_tokenizer_state = tuple(
            [instance[key] for instance in instances]
            for key in ("input_ids", "labels", "ts_token_ids", "ts_attention_mask", "ts_tokenizer_state")
        )

        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
        )
        labels = torch.nn.utils.rnn.pad_sequence(
            labels, batch_first=True, padding_value=IGNORE_INDEX
        )

        return dict(
            input_ids=input_ids,
            labels=labels,
            attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
            ts_token_ids=ts_token_ids,  # return as list
            ts_attention_mask=ts_attention_mask,
            ts_tokenizer_state=ts_tokenizer_state
        )


def make_ts_text_data_module(
        tokenizer: transformers.PreTrainedTokenizer, chronos_tokenizer, data_args
) -> Dict:
    """Make dataset and collator for supervised fine-tuning."""
    data_collator = DataCollatorForTsTextDataset(tokenizer=tokenizer)
    train_dataset = UniChannelTimeSeriesDataset(
        data_path=data_args.data_path, qa_path=data_args.qa_path, tokenizer=tokenizer, chronos_tokenizer=chronos_tokenizer, split="train", data_args=data_args
    )
    eval_dataset = UniChannelTimeSeriesDataset(
        data_path=data_args.eval_data_path, qa_path=data_args.eval_qa_path, tokenizer=tokenizer, chronos_tokenizer=chronos_tokenizer, split="train", data_args=data_args
    )
    for i in range(2):
        print(f"data example {i}:\nInput Text: {train_dataset[i]['input_texts']}\nAnswer: {train_dataset[i]['answer']}\nInput ids: {train_dataset[i]['input_ids']}\nTS token ids: {train_dataset[i]['ts_token_ids']}\n")

    return dict(
        train_dataset=train_dataset, eval_dataset=eval_dataset, data_collator=data_collator
    )
--------------------------------------------------------------------------------
/sensorllm/data/stage2_dataset.py:
--------------------------------------------------------------------------------
from typing import Dict, Sequence
from torch.utils.data import Dataset
import numpy as np
import pandas as pd
import random
import copy
import json
import re
import pickle
import logging
from dataclasses import dataclass

from sensorllm.data.utils import generate_chat_template, preprocess, preprocess_cls, get_token_list
from sensorllm.model.chronos_model import *
import transformers

import torch

IGNORE_INDEX = -100
RDM_SEED = 42


def preprocess_time_series_stage2(
    sources: Sequence[Dict[str, str]],  # [{"Q": "...", "A": "..."}]
    added_str: str,
) -> Sequence[Dict[str, str]]:
    modified_sources = []
    pattern = r'\b\d+\.\s+[A-Za-z\s]+\.'
    for index, source in enumerate(sources):
        modified_q = added_str + source["Q"]

        matches = re.findall(pattern, source["A"])
        assert len(matches) == 1
        cot = source["A"].replace(matches[-1], '')

        modified_sources.append({"Q": modified_q, "A": source["A"], "cot": cot.strip(), "ground_truth": matches[-1]})
    return modified_sources


def preprocess_time_series_CLS_stage2(
    sources: Sequence[Dict[str, str]],  # [{"Q": "...", "A": "..."}]
) -> Sequence[Dict[str, str]]:
    modified_sources = []
    for index, source in enumerate(sources):
        modified_sources.append({
            "Q": source.get("Q", ""),
            "smry": source.get("smry", ""),
            "trend_text": source.get("trend_text", ""),
            "corr_text": source.get("corr_text", ""),
            "info_text": source.get("info_text", ""),
            "answer": source.get("A", ""),
            "label": source.get("label", "")
        })
    return modified_sources


class MultiChannelTimeSeriesDatasetStage2(Dataset):
    def __init__(self, data_path=None, qa_path=None, tokenizer=None, chronos_tokenizer=None, split=None, data_args=None):
        super(MultiChannelTimeSeriesDatasetStage2, self).__init__()
        self.data_path = data_path
        self.qa_path = qa_path
        self.tokenizer = tokenizer
self.chronos_tokenizer = chronos_tokenizer 64 | self.split = split 65 | self.preprocess_type = data_args.preprocess_type 66 | self.preprocess_type_eval = None if tokenizer is None else data_args.preprocess_type_eval 67 | 68 | shuffle = data_args.shuffle 69 | dataset = data_args.dataset 70 | 71 | if dataset == 'usc-had': 72 | self.SYS_INST = "The assistant is provided with time-series readings of six sensor channels, including three accelerometer channels (in g) and three gyroscope channels (in dps). Each channel contains 200 data representing information extracted from the same 2-second time window at a sampling rate of 100Hz. Please analyze the trends and patterns in each channel to identify the correct activity type from the following twelve activity options:\n\n1. Walking Forward\n2. Walking Left\n3. Walking Right\n4. Walking Upstairs\n5. Walking Downstairs\n6. Running Forward\n7. Jumping\n8. Sitting\n9. Standing\n10. Sleeping\n11. Elevator Up\n12. Elevator Down\n\nProvide the predicted activity as both the number and the name at the end." 73 | elif dataset == 'capture24': 74 | self.SYS_INST = "The assistant is provided with time-series readings of three accelerometer channels (in g). Each channel contains 500 data representing information extracted from the same 10-second time window at a sampling rate of 50Hz. Please analyze the trends and patterns in each channel to identify the correct activity type from the following 10 activity options:\n\n1. sleep\n2. sitting\n3. household-chores\n4. walking\n5. vehicle\n6. bicycling\n7. mixed-activity\n8. standing\n9. manual-work\n10. sports\n\nProvide the predicted activity as both the number and the name at the end." 75 | elif dataset == 'mhealth': 76 | self.SYS_INST = "The assistant is provided with time-series readings of 15 sensor channels, including acceleration sensors (in m/s^2) and gyroscope sensors (in deg/s). 
Each channel contains 100 data representing information extracted from the same 2-second time window at a sampling rate of 50Hz. Please analyze the trends and patterns in each channel to identify the correct activity type from the following twelve activity options:\n\n1. Standing still\n2. Sitting and relaxing\n3. Lying down\n4. Walking\n5. Climbing stairs\n6. Waist bends forward\n7. Frontal elevation of arms\n8. Knees bending (crouching)\n9. Cycling\n10. Jogging\n11. Running\n12. Jump front & back\n\nProvide the predicted activity as both the number and the name at the end." 77 | elif dataset == 'pamap50': 78 | self.SYS_INST = "The assistant is provided with time-series readings of 27 sensor channels, including acceleration sensors (in m/s^2) and gyroscope sensors (in rad/s) and magnetometer sensors (in μT). Each channel contains 100 data representing information extracted from the same 2-second time window at a sampling rate of 50Hz. Please analyze the trends and patterns in each channel to identify the correct activity type from the following twelve activity options:\n\n1. lying\n2. sitting\n3. standing\n4. walking\n5. running\n6. cycling\n7. Nordic walking\n8. ascending stairs\n9. descending stairs\n10. vacuum cleaning\n11. ironing\n12. rope jumping\n\nProvide the predicted activity as both the number and the name at the end." 79 | elif dataset == 'pamap': 80 | self.SYS_INST = "The assistant is provided with time-series readings of 27 sensor channels, including acceleration sensors (in m/s^2) and gyroscope sensors (in rad/s) and magnetometer sensors (in μT). Each channel contains 100 data representing information extracted from the same 2-second time window at a sampling rate of 100Hz. Please analyze the trends and patterns in each channel to identify the correct activity type from the following twelve activity options:\n\n1. lying\n2. sitting\n3. standing\n4. walking\n5. running\n6. cycling\n7. Nordic walking\n8. ascending stairs\n9. descending stairs\n10. 
vacuum cleaning\n11. ironing\n12. rope jumping\n\nProvide the predicted activity as both the number and the name at the end." 81 | elif dataset == 'uci': 82 | self.SYS_INST = "The assistant is provided with time-series readings of six sensor channels, including three accelerometer channels (in g) and three gyroscope channels (in dps). Each channel contains 128 data representing information extracted from the same 2.56-second time window at a sampling rate of 50Hz. Please analyze the trends and patterns in each channel to identify the correct activity type from the following six activity options:\n\n1. Walking Forward\n2. Walking Upstairs\n3. Walking Downstairs\n4. Sitting\n5. Standing\n6. Laying\n\nProvide the predicted activity as both the number and the name at the end." 83 | else: 84 | raise ValueError(f"Wrong dataset name in __init__: {dataset}") 85 | 86 | self.ts_data, self.list_data_dict = self._flatten_data(shuffle) 87 | 88 | self.data_args = data_args.ts_backbone_config 89 | 90 | self.window_length = len(self.ts_data[0][0]) 91 | self.channel_num = len(self.ts_data[0]) 92 | assert self.channel_num == self.data_args[dataset]["channel_num"], "channel_num, data_args.channel_num shape mismatched" 93 | 94 | print( 95 | f"The dataset size is: {len(self.ts_data)}. Window size: {self.window_length}. Channel num: {self.channel_num}."
96 | ) 97 | 98 | if self.data_args["chronos_model"]["last_token"]: 99 | added_token = self.data_args["default_ts_token"] * (self.window_length + 1) 100 | else: 101 | added_token = self.data_args["default_ts_token"] * self.window_length 102 | 103 | start_tokens_list, end_tokens_list = get_token_list(dataset, self.data_args[dataset], data_args.add_ts_special_token_text) 104 | 105 | added_str = '' 106 | for start_token, end_token in zip(start_tokens_list, end_tokens_list): 107 | added_str += start_token + added_token + end_token + '\n' 108 | self.added_str = added_str 109 | 110 | def _flatten_data(self, shuffle: bool): 111 | logging.warning(f"Loading {self.split} data...") 112 | with open(self.data_path, "rb") as f: 113 | data_file = pickle.load(f) 114 | with open(self.qa_path, "r") as file: 115 | qa_file = json.load(file) 116 | data_file = np.array(data_file, dtype=np.float64) 117 | ts_data = [] 118 | qa_dict = [] 119 | assert len(data_file) == len(qa_file["dataset"]) 120 | for q in qa_file["dataset"]: 121 | data_idx = q["index"] 122 | data = data_file[int(data_idx)] 123 | ts_data.append([torch.from_numpy(data[:, i]).to(torch.float64) for i in range(data.shape[1])]) 124 | qa_dict.append(q["qa_pair"]) 125 | assert len(ts_data) == len(qa_dict), "ts_data, qa_dict, length not matched" 126 | 127 | if shuffle: 128 | print("Shuffling data...") 129 | random.seed(RDM_SEED) 130 | indexes = list(range(len(ts_data))) 131 | random.shuffle(indexes) 132 | ts_data = [ts_data[i] for i in indexes] 133 | qa_dict = [qa_dict[i] for i in indexes] 134 | # 135 | # if self.split == "eval": 136 | # return ts_data[:100], qa_dict[:100] 137 | 138 | return ts_data, qa_dict 139 | 140 | def __len__(self): 141 | return len(self.ts_data) 142 | 143 | def __getitem__(self, index): 144 | sources = self.list_data_dict[index] # {"Q": ..., "A": ...} 145 | multichannel_ts = self.ts_data[index] # C * L, 6 * 200 146 | 147 | if isinstance(index, int): 148 | sources = [sources] 149 | multichannel_ts = 
[multichannel_ts] 150 | 151 | assert ( 152 | len(sources) == 1 153 | ), "sources should be a list" 154 | 155 | sources = preprocess_time_series_stage2( 156 | copy.deepcopy(sources), self.added_str 157 | ) 158 | 159 | mts_token_ids_list = [] 160 | mts_attention_mask_list = [] 161 | mts_tokenizer_state_list = [] 162 | for ts in multichannel_ts: 163 | context = torch.stack(ts) 164 | if isinstance(context, list): 165 | context = left_pad_and_stack_1D(context) 166 | assert isinstance(context, torch.Tensor) 167 | if context.ndim == 1: 168 | context = context.unsqueeze(0) 169 | assert context.ndim == 2 170 | 171 | mts_token_ids, mts_attention_mask, mts_tokenizer_state = ( 172 | self.chronos_tokenizer.context_input_transform(context) 173 | ) 174 | mts_token_ids_list.append(mts_token_ids) 175 | mts_attention_mask_list.append(mts_attention_mask) 176 | mts_tokenizer_state_list.append(mts_tokenizer_state) 177 | 178 | if self.tokenizer is None: 179 | data_dict = dict( 180 | question=sources[0]["Q"], 181 | answer=sources[0]["A"], 182 | cot=sources[0]["cot"], 183 | ground_truth=sources[0]["ground_truth"], 184 | mts_token_ids=mts_token_ids_list[0], 185 | mts_attention_mask=mts_attention_mask_list[0], 186 | mts_tokenizer_state=mts_tokenizer_state_list[0] 187 | ) 188 | return data_dict 189 | 190 | data_dict = preprocess(sources, self.tokenizer, self.SYS_INST, self.split, self.preprocess_type, self.preprocess_type_eval) 191 | 192 | data_dict = dict(input_ids=data_dict["input_ids"][0], 193 | labels=data_dict["labels"][0], 194 | mts_token_ids=mts_token_ids_list[0], 195 | mts_attention_mask=mts_attention_mask_list[0], 196 | mts_tokenizer_state=mts_tokenizer_state_list[0]) 197 | return data_dict 198 | 199 | 200 | @dataclass 201 | class DataCollatorForTsTextDatasetStage2(object): 202 | """Collate examples for supervised fine-tuning.""" 203 | 204 | tokenizer: transformers.PreTrainedTokenizer 205 | 206 | def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]: 207 | 
input_ids, labels, mts_token_ids, mts_attention_mask, mts_tokenizer_state = tuple( 208 | [instance[key] for instance in instances] 209 | for key in ("input_ids", "labels", "mts_token_ids", "mts_attention_mask", "mts_tokenizer_state") 210 | ) 211 | 212 | input_ids = torch.nn.utils.rnn.pad_sequence( 213 | input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id 214 | ) 215 | labels = torch.nn.utils.rnn.pad_sequence( 216 | labels, batch_first=True, padding_value=IGNORE_INDEX 217 | ) 218 | 219 | return dict( 220 | input_ids=input_ids, 221 | labels=labels, 222 | attention_mask=input_ids.ne(self.tokenizer.pad_token_id), 223 | mts_token_ids=torch.stack(mts_token_ids), 224 | mts_attention_mask=torch.stack(mts_attention_mask), 225 | mts_tokenizer_state=mts_tokenizer_state 226 | ) 227 | 228 | 229 | def make_ts_text_data_module_stage2( 230 | tokenizer: transformers.PreTrainedTokenizer, chronos_tokenizer, data_args 231 | ) -> Dict: 232 | """Make dataset and collator for supervised fine-tuning.""" 233 | data_collator = DataCollatorForTsTextDatasetStage2(tokenizer=tokenizer) 234 | train_dataset = MultiChannelTimeSeriesDatasetStage2( 235 | data_path=data_args.data_path, qa_path=data_args.qa_path, tokenizer=tokenizer, chronos_tokenizer=chronos_tokenizer, split="train", data_args=data_args 236 | ) 237 | eval_dataset = MultiChannelTimeSeriesDatasetStage2( 238 | data_path=data_args.eval_data_path, qa_path=data_args.eval_qa_path, tokenizer=tokenizer, chronos_tokenizer=chronos_tokenizer, split="eval", data_args=data_args 239 | ) 240 | return dict( 241 | train_dataset=train_dataset, eval_dataset=eval_dataset, data_collator=data_collator 242 | ) 243 | 244 | 245 | @dataclass 246 | class DataCollatorForTsCLSDatasetStage2(object): 247 | """Collate examples for supervised fine-tuning.""" 248 | 249 | tokenizer: transformers.PreTrainedTokenizer 250 | 251 | def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]: 252 | input_ids, labels, mts_token_ids, 
mts_attention_mask, mts_tokenizer_state = tuple( 253 | [instance[key] for instance in instances] 254 | for key in ("input_ids", "labels", "mts_token_ids", "mts_attention_mask", "mts_tokenizer_state") 255 | ) 256 | 257 | input_ids = torch.nn.utils.rnn.pad_sequence( 258 | input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id 259 | ) 260 | 261 | return dict( 262 | input_ids=input_ids, 263 | labels=torch.tensor(labels), 264 | attention_mask=input_ids.ne(self.tokenizer.pad_token_id), 265 | mts_token_ids=torch.stack(mts_token_ids), 266 | mts_attention_mask=torch.stack(mts_attention_mask), 267 | mts_tokenizer_state=mts_tokenizer_state 268 | ) 269 | 270 | 271 | class MultiChannelTimeSeriesCLSDatasetStage2(Dataset): 272 | def __init__(self, data_path=None, qa_path=None, tokenizer=None, chronos_tokenizer=None, split=None, label2id=None, data_args=None): 273 | super(MultiChannelTimeSeriesCLSDatasetStage2, self).__init__() 274 | self.data_path = data_path 275 | self.qa_path = qa_path 276 | self.tokenizer = tokenizer 277 | self.chronos_tokenizer = chronos_tokenizer 278 | self.split = split 279 | self.label2id = label2id 280 | self.preprocess_type = data_args.preprocess_type 281 | 282 | shuffle = data_args.shuffle 283 | dataset = data_args.dataset 284 | 285 | 286 | self.ts_data, self.list_data_dict, self.class_weights = self._flatten_data(shuffle) 287 | 288 | self.data_args = data_args.ts_backbone_config 289 | 290 | self.window_length = len(self.ts_data[0][0]) 291 | self.channel_num = len(self.ts_data[0]) 292 | assert self.channel_num == self.data_args[dataset][ 293 | "channel_num"], "channel_num, data_args.channel_num shape mismatched" 294 | 295 | print( 296 | f"The dataset size is: {len(self.ts_data)}. Window size: {self.window_length}. Channel num: {self.channel_num}." 
297 | ) 298 | 299 | if self.data_args["chronos_model"]["last_token"]: 300 | added_token = self.data_args["default_ts_token"] * (self.window_length + 1) 301 | else: 302 | added_token = self.data_args["default_ts_token"] * self.window_length 303 | 304 | start_tokens_list, end_tokens_list = get_token_list(dataset, self.data_args[dataset], data_args.add_ts_special_token_text) 305 | 306 | added_str = '' 307 | for start_token, end_token in zip(start_tokens_list, end_tokens_list): 308 | added_str += start_token + added_token + end_token + '\n' 309 | self.added_str = added_str 310 | 311 | def _flatten_data(self, shuffle: bool): 312 | logging.warning(f"Loading {self.split} data...") 313 | with open(self.data_path, "rb") as f: 314 | data_file = pickle.load(f) 315 | with open(self.qa_path, "r") as file: 316 | qa_file = json.load(file) 317 | ts_data = [] 318 | qa_dict = [] 319 | label_list = [] 320 | assert len(data_file) == len(qa_file["dataset"]) 321 | for q in qa_file["dataset"]: 322 | data_idx = q["index"] 323 | data = data_file[int(data_idx)] 324 | ts_data.append([torch.from_numpy(data[:, i]).to(torch.float64) for i in range(data.shape[1])]) 325 | answer = q["qa_pair"]["A"] 326 | try: 327 | label = int(self.label2id[answer]) 328 | except KeyError: 329 | raise ValueError(f"Text '{answer}' not found in label2id dictionary") 330 | label_list.append(label) 331 | q["qa_pair"]['label'] = label 332 | qa_dict.append(q["qa_pair"]) 333 | # assert len(ts_data[0]) == 6, "ts_data channel length error" 334 | # assert len(ts_data[0][0]) == 200, "ts_data length error" 335 | # assert len(ts_data[0][1]) == 200, "ts_data length error" 336 | # assert len(ts_data[0][2]) == 200, "ts_data length error" 337 | assert len(ts_data) == len(qa_dict) == len(label_list), "ts_data, qa_dict, label_list, length not matched" 338 | 339 | class_weights = None 340 | if self.split == 'train': 341 | label_series = pd.Series(label_list) 342 | value_counts = label_series.value_counts(normalize=True) 343 | 
class_weights = (1 / value_counts.sort_index()).tolist() 344 | class_weights = torch.tensor(class_weights) 345 | class_weights = class_weights / class_weights.sum() 346 | if shuffle: 347 | print("Shuffling data...") 348 | random.seed(RDM_SEED) 349 | indexes = list(range(len(ts_data))) 350 | random.shuffle(indexes) 351 | ts_data = [ts_data[i] for i in indexes] 352 | qa_dict = [qa_dict[i] for i in indexes] 353 | 354 | # if self.split == "eval": 355 | # return ts_data[:100], qa_dict[:100] 356 | 357 | return ts_data, qa_dict, class_weights 358 | 359 | def __len__(self): 360 | return len(self.ts_data) 361 | 362 | def get_class_weights(self): 363 | return self.class_weights 364 | 365 | def __getitem__(self, index): 366 | sources = self.list_data_dict[index] # {"Q": ..., "A": ...} 367 | multichannel_ts = self.ts_data[index] # C * L, 6 * 200 368 | 369 | if isinstance(index, int): 370 | sources = [sources] 371 | multichannel_ts = [multichannel_ts] 372 | 373 | assert ( 374 | len(sources) == 1 375 | ), "sources should be a list" 376 | 377 | sources = preprocess_time_series_CLS_stage2( 378 | copy.deepcopy(sources) 379 | ) 380 | 381 | mts_token_ids_list = [] 382 | mts_attention_mask_list = [] 383 | mts_tokenizer_state_list = [] 384 | for ts in multichannel_ts: 385 | context = torch.stack(ts) 386 | if isinstance(context, list): 387 | context = left_pad_and_stack_1D(context) 388 | assert isinstance(context, torch.Tensor) 389 | if context.ndim == 1: 390 | context = context.unsqueeze(0) 391 | assert context.ndim == 2 392 | 393 | mts_token_ids, mts_attention_mask, mts_tokenizer_state = ( 394 | self.chronos_tokenizer.context_input_transform(context) 395 | ) 396 | mts_token_ids_list.append(mts_token_ids) 397 | mts_attention_mask_list.append(mts_attention_mask) 398 | mts_tokenizer_state_list.append(mts_tokenizer_state) 399 | 400 | if self.tokenizer is None: 401 | data_dict = dict( 402 | added_str=self.added_str, 403 | question=sources[0]["Q"], 404 | smry=sources[0]["smry"], 405 | 
trend_text=sources[0]["trend_text"], 406 | corr_text=sources[0]["corr_text"], 407 | info_text=sources[0]["info_text"], 408 | answer=sources[0]["answer"], 409 | label=sources[0]["label"], 410 | mts_token_ids=mts_token_ids_list[0], 411 | mts_attention_mask=mts_attention_mask_list[0], 412 | mts_tokenizer_state=mts_tokenizer_state_list[0] 413 | ) 414 | return data_dict 415 | 416 | data_dict = preprocess_cls(sources, self.tokenizer, self.added_str, self.preprocess_type) 417 | 418 | data_dict = dict(input_ids=data_dict["input_ids"][0], 419 | input_texts=data_dict["input_texts"][0], 420 | labels=sources[0]["label"], 421 | answer=sources[0]["answer"], 422 | mts_token_ids=mts_token_ids_list[0], 423 | mts_attention_mask=mts_attention_mask_list[0], 424 | mts_tokenizer_state=mts_tokenizer_state_list[0]) 425 | 426 | return data_dict 427 | 428 | 429 | def make_ts_classification_data_module_stage2( 430 | tokenizer: transformers.PreTrainedTokenizer, chronos_tokenizer, label2id, data_args 431 | ) -> Dict: 432 | """Make dataset and collator for supervised fine-tuning.""" 433 | data_collator = DataCollatorForTsCLSDatasetStage2(tokenizer=tokenizer) 434 | train_dataset = MultiChannelTimeSeriesCLSDatasetStage2( 435 | data_path=data_args.data_path, qa_path=data_args.qa_path, tokenizer=tokenizer, chronos_tokenizer=chronos_tokenizer, split="train", label2id=label2id, data_args=data_args 436 | ) 437 | class_weights = train_dataset.get_class_weights() 438 | assert class_weights is not None, "class_weights should not be None" 439 | 440 | eval_dataset = MultiChannelTimeSeriesCLSDatasetStage2( 441 | data_path=data_args.eval_data_path, qa_path=data_args.eval_qa_path, tokenizer=tokenizer, chronos_tokenizer=chronos_tokenizer, split="eval", label2id=label2id, data_args=data_args 442 | ) 443 | 444 | for i in range(2): 445 | print(f"data example {i}:\nInput Text: {train_dataset[i]['input_texts']}\nLabel: {train_dataset[i]['labels']}\nAnswer: {train_dataset[i]['answer']}\nInput ids: 
{train_dataset[i]['input_ids']}\n") 446 | 447 | return dict( 448 | train_dataset=train_dataset, eval_dataset=eval_dataset, data_collator=data_collator, class_weights=class_weights 449 | ) -------------------------------------------------------------------------------- /sensorllm/data/utils.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Sequence 2 | import transformers 3 | import copy 4 | 5 | IGNORE_INDEX = -100 6 | 7 | 8 | def generate_chat_template(messages, bos_token, eos_token, add_generation_prompt=False): 9 | LLAMA_3_CHAT_TEMPLATE = ( 10 | "{% set loop_messages = messages %}" 11 | "{% for message in loop_messages %}" 12 | "{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}" 13 | "{% if loop.index0 == 0 %}" 14 | "{% set content = bos_token + content %}" 15 | "{% endif %}" 16 | "{{ content }}" 17 | "{% endfor %}" 18 | "{% if add_generation_prompt %}" 19 | "{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}" 20 | "{% endif %}" 21 | ) 22 | 23 | from jinja2 import Template 24 | template = Template(LLAMA_3_CHAT_TEMPLATE) 25 | return template.render(messages=messages, bos_token=bos_token, eos_token=eos_token, add_generation_prompt=add_generation_prompt) 26 | 27 | 28 | def generate_chat_template2(messages, bos_token, eos_token, add_generation_prompt=False): 29 | LLAMA_3_CHAT_TEMPLATE = ( 30 | "{% set loop_messages = messages %}" 31 | "{% for message in loop_messages %}" 32 | "{% if loop.last %}" 33 | "{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim %}" 34 | "{% else %}" 35 | "{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}" 36 | "{% endif %}" 37 | "{% if loop.index0 == 0 %}" 38 | "{% set content = bos_token + content %}" 39 | "{% endif %}" 40 | "{{ content }}" 41 | "{% endfor %}" 
42 | "{% if add_generation_prompt %}" 43 | "{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}" 44 | "{% endif %}" 45 | ) 46 | 47 | from jinja2 import Template 48 | template = Template(LLAMA_3_CHAT_TEMPLATE) 49 | return template.render(messages=messages, bos_token=bos_token, eos_token=eos_token, add_generation_prompt=add_generation_prompt) 50 | 51 | 52 | def _tokenize_fn( 53 | conversations: Sequence[str], tokenizer: transformers.PreTrainedTokenizer 54 | ) -> Dict: 55 | """Tokenize a list of strings.""" 56 | tokenized_list = [ 57 | tokenizer( 58 | conv, 59 | return_tensors="pt", 60 | padding="longest", 61 | max_length=tokenizer.model_max_length, 62 | truncation=True, 63 | ) 64 | for conv in conversations 65 | ] 66 | input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list] 67 | input_ids_lens = labels_lens = [ 68 | tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() 69 | for tokenized in tokenized_list 70 | ] 71 | return dict( 72 | input_ids=input_ids, 73 | labels=labels, 74 | input_ids_lens=input_ids_lens, 75 | labels_lens=labels_lens, 76 | ) 77 | 78 | 79 | def get_token_dict(dataset: str, data_args: dict): 80 | if dataset in ["usc-had", "uci"]: 81 | start_tokens_dict = { 82 | "x_acc": data_args["default_x_acc_start_token"], 83 | "y_acc": data_args["default_y_acc_start_token"], 84 | "z_acc": data_args["default_z_acc_start_token"], 85 | "x_g": data_args["default_x_gyro_start_token"], 86 | "y_g": data_args["default_y_gyro_start_token"], 87 | "z_g": data_args["default_z_gyro_start_token"] 88 | } 89 | 90 | end_tokens_dict = { 91 | "x_acc": data_args["default_x_acc_end_token"], 92 | "y_acc": data_args["default_y_acc_end_token"], 93 | "z_acc": data_args["default_z_acc_end_token"], 94 | "x_g": data_args["default_x_gyro_end_token"], 95 | "y_g": data_args["default_y_gyro_end_token"], 96 | "z_g": data_args["default_z_gyro_end_token"] 97 | } 98 | elif dataset == "capture24": 99 | start_tokens_dict = { 100 | "x_acc": 
data_args["default_x_acc_start_token"], 101 | "y_acc": data_args["default_y_acc_start_token"], 102 | "z_acc": data_args["default_z_acc_start_token"], 103 | } 104 | 105 | end_tokens_dict = { 106 | "x_acc": data_args["default_x_acc_end_token"], 107 | "y_acc": data_args["default_y_acc_end_token"], 108 | "z_acc": data_args["default_z_acc_end_token"], 109 | } 110 | elif dataset == "mhealth": 111 | start_tokens_dict = { 112 | "c_acc_x": data_args["default_chest_x_acc_start_token"], 113 | "c_acc_y": data_args["default_chest_y_acc_start_token"], 114 | "c_acc_z": data_args["default_chest_z_acc_start_token"], 115 | "la_acc_x": data_args["default_left_ankle_x_acc_start_token"], 116 | "la_acc_y": data_args["default_left_ankle_y_acc_start_token"], 117 | "la_acc_z": data_args["default_left_ankle_z_acc_start_token"], 118 | "la_gs_x": data_args["default_left_ankle_x_gyro_start_token"], 119 | "la_gs_y": data_args["default_left_ankle_y_gyro_start_token"], 120 | "la_gs_z": data_args["default_left_ankle_z_gyro_start_token"], 121 | "rla_acc_x": data_args["default_right_lower_arm_x_acc_start_token"], 122 | "rla_acc_y": data_args["default_right_lower_arm_y_acc_start_token"], 123 | "rla_acc_z": data_args["default_right_lower_arm_z_acc_start_token"], 124 | "rla_gs_x": data_args["default_right_lower_arm_x_gyro_start_token"], 125 | "rla_gs_y": data_args["default_right_lower_arm_y_gyro_start_token"], 126 | "rla_gs_z": data_args["default_right_lower_arm_z_gyro_start_token"] 127 | } 128 | end_tokens_dict = { 129 | "c_acc_x": data_args["default_chest_x_acc_end_token"], 130 | "c_acc_y": data_args["default_chest_y_acc_end_token"], 131 | "c_acc_z": data_args["default_chest_z_acc_end_token"], 132 | "la_acc_x": data_args["default_left_ankle_x_acc_end_token"], 133 | "la_acc_y": data_args["default_left_ankle_y_acc_end_token"], 134 | "la_acc_z": data_args["default_left_ankle_z_acc_end_token"], 135 | "la_gs_x": data_args["default_left_ankle_x_gyro_end_token"], 136 | "la_gs_y": 
data_args["default_left_ankle_y_gyro_end_token"], 137 | "la_gs_z": data_args["default_left_ankle_z_gyro_end_token"], 138 | "rla_acc_x": data_args["default_right_lower_arm_x_acc_end_token"], 139 | "rla_acc_y": data_args["default_right_lower_arm_y_acc_end_token"], 140 | "rla_acc_z": data_args["default_right_lower_arm_z_acc_end_token"], 141 | "rla_gs_x": data_args["default_right_lower_arm_x_gyro_end_token"], 142 | "rla_gs_y": data_args["default_right_lower_arm_y_gyro_end_token"], 143 | "rla_gs_z": data_args["default_right_lower_arm_z_gyro_end_token"] 144 | } 145 | elif dataset == "pamap" or dataset == "pamap50": 146 | start_tokens_dict = { 147 | "acc_hand_x": data_args["default_hand_x_acc_start_token"], 148 | "acc_hand_y": data_args["default_hand_y_acc_start_token"], 149 | "acc_hand_z": data_args["default_hand_z_acc_start_token"], 150 | "gyr_hand_x": data_args["default_hand_x_gyro_start_token"], 151 | "gyr_hand_y": data_args["default_hand_y_gyro_start_token"], 152 | "gyr_hand_z": data_args["default_hand_z_gyro_start_token"], 153 | "mag_hand_x": data_args["default_hand_x_mag_start_token"], 154 | "mag_hand_y": data_args["default_hand_y_mag_start_token"], 155 | "mag_hand_z": data_args["default_hand_z_mag_start_token"], 156 | "acc_chest_x": data_args["default_chest_x_acc_start_token"], 157 | "acc_chest_y": data_args["default_chest_y_acc_start_token"], 158 | "acc_chest_z": data_args["default_chest_z_acc_start_token"], 159 | "gyr_chest_x": data_args["default_chest_x_gyro_start_token"], 160 | "gyr_chest_y": data_args["default_chest_y_gyro_start_token"], 161 | "gyr_chest_z": data_args["default_chest_z_gyro_start_token"], 162 | "mag_chest_x": data_args["default_chest_x_mag_start_token"], 163 | "mag_chest_y": data_args["default_chest_y_mag_start_token"], 164 | "mag_chest_z": data_args["default_chest_z_mag_start_token"], 165 | "acc_ankle_x": data_args["default_ankle_x_acc_start_token"], 166 | "acc_ankle_y": data_args["default_ankle_y_acc_start_token"], 167 | "acc_ankle_z": 
data_args["default_ankle_z_acc_start_token"], 168 | "gyr_ankle_x": data_args["default_ankle_x_gyro_start_token"], 169 | "gyr_ankle_y": data_args["default_ankle_y_gyro_start_token"], 170 | "gyr_ankle_z": data_args["default_ankle_z_gyro_start_token"], 171 | "mag_ankle_x": data_args["default_ankle_x_mag_start_token"], 172 | "mag_ankle_y": data_args["default_ankle_y_mag_start_token"], 173 | "mag_ankle_z": data_args["default_ankle_z_mag_start_token"], 174 | } 175 | end_tokens_dict = { 176 | "acc_hand_x": data_args["default_hand_x_acc_end_token"], 177 | "acc_hand_y": data_args["default_hand_y_acc_end_token"], 178 | "acc_hand_z": data_args["default_hand_z_acc_end_token"], 179 | "gyr_hand_x": data_args["default_hand_x_gyro_end_token"], 180 | "gyr_hand_y": data_args["default_hand_y_gyro_end_token"], 181 | "gyr_hand_z": data_args["default_hand_z_gyro_end_token"], 182 | "mag_hand_x": data_args["default_hand_x_mag_end_token"], 183 | "mag_hand_y": data_args["default_hand_y_mag_end_token"], 184 | "mag_hand_z": data_args["default_hand_z_mag_end_token"], 185 | "acc_chest_x": data_args["default_chest_x_acc_end_token"], 186 | "acc_chest_y": data_args["default_chest_y_acc_end_token"], 187 | "acc_chest_z": data_args["default_chest_z_acc_end_token"], 188 | "gyr_chest_x": data_args["default_chest_x_gyro_end_token"], 189 | "gyr_chest_y": data_args["default_chest_y_gyro_end_token"], 190 | "gyr_chest_z": data_args["default_chest_z_gyro_end_token"], 191 | "mag_chest_x": data_args["default_chest_x_mag_end_token"], 192 | "mag_chest_y": data_args["default_chest_y_mag_end_token"], 193 | "mag_chest_z": data_args["default_chest_z_mag_end_token"], 194 | "acc_ankle_x": data_args["default_ankle_x_acc_end_token"], 195 | "acc_ankle_y": data_args["default_ankle_y_acc_end_token"], 196 | "acc_ankle_z": data_args["default_ankle_z_acc_end_token"], 197 | "gyr_ankle_x": data_args["default_ankle_x_gyro_end_token"], 198 | "gyr_ankle_y": data_args["default_ankle_y_gyro_end_token"], 199 | "gyr_ankle_z": 
data_args["default_ankle_z_gyro_end_token"], 200 | "mag_ankle_x": data_args["default_ankle_x_mag_end_token"], 201 | "mag_ankle_y": data_args["default_ankle_y_mag_end_token"], 202 | "mag_ankle_z": data_args["default_ankle_z_mag_end_token"], 203 | } 204 | else: 205 | raise ValueError(f"Wrong dataset name in get_token_dict: {dataset}") 206 | return start_tokens_dict, end_tokens_dict 207 | 208 | 209 | def get_token_list(dataset: str, data_args: dict, add_ts_special_token_text: bool): 210 | if dataset in ["usc-had", "uci"]: 211 | if add_ts_special_token_text: 212 | start_tokens_list = [ 213 | "x-axis accelerometer readings: " + data_args["default_x_acc_start_token"], 214 | "y-axis accelerometer readings: " + data_args["default_y_acc_start_token"], 215 | "z-axis accelerometer readings: " + data_args["default_z_acc_start_token"], 216 | "x-axis gyroscope readings: " + data_args["default_x_gyro_start_token"], 217 | "y-axis gyroscope readings: " + data_args["default_y_gyro_start_token"], 218 | "z-axis gyroscope readings: " + data_args["default_z_gyro_start_token"] 219 | ] 220 | else: 221 | start_tokens_list = [ 222 | data_args["default_x_acc_start_token"], 223 | data_args["default_y_acc_start_token"], 224 | data_args["default_z_acc_start_token"], 225 | data_args["default_x_gyro_start_token"], 226 | data_args["default_y_gyro_start_token"], 227 | data_args["default_z_gyro_start_token"] 228 | ] 229 | 230 | end_tokens_list = [ 231 | data_args["default_x_acc_end_token"], 232 | data_args["default_y_acc_end_token"], 233 | data_args["default_z_acc_end_token"], 234 | data_args["default_x_gyro_end_token"], 235 | data_args["default_y_gyro_end_token"], 236 | data_args["default_z_gyro_end_token"] 237 | ] 238 | elif dataset == "capture24": 239 | if add_ts_special_token_text: 240 | start_tokens_list = [ 241 | "x-axis accelerometer readings: " + data_args["default_x_acc_start_token"], 242 | "y-axis accelerometer readings: " + data_args["default_y_acc_start_token"], 243 | "z-axis
accelerometer readings: " + data_args["default_z_acc_start_token"] 244 | ] 245 | else: 246 | start_tokens_list = [ 247 | data_args["default_x_acc_start_token"], 248 | data_args["default_y_acc_start_token"], 249 | data_args["default_z_acc_start_token"], 250 | ] 251 | 252 | end_tokens_list = [ 253 | data_args["default_x_acc_end_token"], 254 | data_args["default_y_acc_end_token"], 255 | data_args["default_z_acc_end_token"], 256 | ] 257 | elif dataset == "mhealth": 258 | if add_ts_special_token_text: 259 | start_tokens_list = [ 260 | "Chest x-axis accelerometer: " + data_args["default_chest_x_acc_start_token"], 261 | "Chest y-axis accelerometer: " + data_args["default_chest_y_acc_start_token"], 262 | "Chest z-axis accelerometer: " + data_args["default_chest_z_acc_start_token"], 263 | "left-ankle x-axis accelerometer: " + data_args["default_left_ankle_x_acc_start_token"], 264 | "left-ankle y-axis accelerometer: " + data_args["default_left_ankle_y_acc_start_token"], 265 | "left-ankle z-axis accelerometer: " + data_args["default_left_ankle_z_acc_start_token"], 266 | "left-ankle x-axis gyroscope: " + data_args["default_left_ankle_x_gyro_start_token"], 267 | "left-ankle y-axis gyroscope: " + data_args["default_left_ankle_y_gyro_start_token"], 268 | "left-ankle z-axis gyroscope: " + data_args["default_left_ankle_z_gyro_start_token"], 269 | "right-lower-arm x-axis accelerometer: " + data_args["default_right_lower_arm_x_acc_start_token"], 270 | "right-lower-arm y-axis accelerometer: " + data_args["default_right_lower_arm_y_acc_start_token"], 271 | "right-lower-arm z-axis accelerometer: " + data_args["default_right_lower_arm_z_acc_start_token"], 272 | "right-lower-arm x-axis gyroscope: " + data_args["default_right_lower_arm_x_gyro_start_token"], 273 | "right-lower-arm y-axis gyroscope: " + data_args["default_right_lower_arm_y_gyro_start_token"], 274 | "right-lower-arm z-axis gyroscope: " + data_args["default_right_lower_arm_z_gyro_start_token"] 275 | ] 276 | else: 277 | 
start_tokens_list = [ 278 | data_args["default_chest_x_acc_start_token"], 279 | data_args["default_chest_y_acc_start_token"], 280 | data_args["default_chest_z_acc_start_token"], 281 | data_args["default_left_ankle_x_acc_start_token"], 282 | data_args["default_left_ankle_y_acc_start_token"], 283 | data_args["default_left_ankle_z_acc_start_token"], 284 | data_args["default_left_ankle_x_gyro_start_token"], 285 | data_args["default_left_ankle_y_gyro_start_token"], 286 | data_args["default_left_ankle_z_gyro_start_token"], 287 | data_args["default_right_lower_arm_x_acc_start_token"], 288 | data_args["default_right_lower_arm_y_acc_start_token"], 289 | data_args["default_right_lower_arm_z_acc_start_token"], 290 | data_args["default_right_lower_arm_x_gyro_start_token"], 291 | data_args["default_right_lower_arm_y_gyro_start_token"], 292 | data_args["default_right_lower_arm_z_gyro_start_token"] 293 | ] 294 | end_tokens_list = [ 295 | data_args["default_chest_x_acc_end_token"], 296 | data_args["default_chest_y_acc_end_token"], 297 | data_args["default_chest_z_acc_end_token"], 298 | data_args["default_left_ankle_x_acc_end_token"], 299 | data_args["default_left_ankle_y_acc_end_token"], 300 | data_args["default_left_ankle_z_acc_end_token"], 301 | data_args["default_left_ankle_x_gyro_end_token"], 302 | data_args["default_left_ankle_y_gyro_end_token"], 303 | data_args["default_left_ankle_z_gyro_end_token"], 304 | data_args["default_right_lower_arm_x_acc_end_token"], 305 | data_args["default_right_lower_arm_y_acc_end_token"], 306 | data_args["default_right_lower_arm_z_acc_end_token"], 307 | data_args["default_right_lower_arm_x_gyro_end_token"], 308 | data_args["default_right_lower_arm_y_gyro_end_token"], 309 | data_args["default_right_lower_arm_z_gyro_end_token"] 310 | ] 311 | elif dataset == "pamap" or dataset == "pamap50": 312 | if add_ts_special_token_text: 313 | start_tokens_list = [ 314 | "Hand x-axis accelerometer: " + data_args["default_hand_x_acc_start_token"], 315 | "Hand 
y-axis accelerometer: " + data_args["default_hand_y_acc_start_token"], 316 | "Hand z-axis accelerometer: " + data_args["default_hand_z_acc_start_token"], 317 | "Hand x-axis gyroscope: " + data_args["default_hand_x_gyro_start_token"], 318 | "Hand y-axis gyroscope: " + data_args["default_hand_y_gyro_start_token"], 319 | "Hand z-axis gyroscope: " + data_args["default_hand_z_gyro_start_token"], 320 | "Hand x-axis magnetometer: " + data_args["default_hand_x_mag_start_token"], 321 | "Hand y-axis magnetometer: " + data_args["default_hand_y_mag_start_token"], 322 | "Hand z-axis magnetometer: " + data_args["default_hand_z_mag_start_token"], 323 | "Chest x-axis accelerometer: " + data_args["default_chest_x_acc_start_token"], 324 | "Chest y-axis accelerometer: " + data_args["default_chest_y_acc_start_token"], 325 | "Chest z-axis accelerometer: " + data_args["default_chest_z_acc_start_token"], 326 | "Chest x-axis gyroscope: " + data_args["default_chest_x_gyro_start_token"], 327 | "Chest y-axis gyroscope: " + data_args["default_chest_y_gyro_start_token"], 328 | "Chest z-axis gyroscope: " + data_args["default_chest_z_gyro_start_token"], 329 | "Chest x-axis magnetometer: " + data_args["default_chest_x_mag_start_token"], 330 | "Chest y-axis magnetometer: " + data_args["default_chest_y_mag_start_token"], 331 | "Chest z-axis magnetometer: " + data_args["default_chest_z_mag_start_token"], 332 | "Ankle x-axis accelerometer: " + data_args["default_ankle_x_acc_start_token"], 333 | "Ankle y-axis accelerometer: " + data_args["default_ankle_y_acc_start_token"], 334 | "Ankle z-axis accelerometer: " + data_args["default_ankle_z_acc_start_token"], 335 | "Ankle x-axis gyroscope: " + data_args["default_ankle_x_gyro_start_token"], 336 | "Ankle y-axis gyroscope: " + data_args["default_ankle_y_gyro_start_token"], 337 | "Ankle z-axis gyroscope: " + data_args["default_ankle_z_gyro_start_token"], 338 | "Ankle x-axis magnetometer: " + 
data_args["default_ankle_x_mag_start_token"], 339 | "Ankle y-axis magnetometer: " + data_args["default_ankle_y_mag_start_token"], 340 | "Ankle z-axis magnetometer: " + data_args["default_ankle_z_mag_start_token"] 341 | ] 342 | else: 343 | start_tokens_list = [ 344 | data_args["default_hand_x_acc_start_token"], 345 | data_args["default_hand_y_acc_start_token"], 346 | data_args["default_hand_z_acc_start_token"], 347 | data_args["default_hand_x_gyro_start_token"], 348 | data_args["default_hand_y_gyro_start_token"], 349 | data_args["default_hand_z_gyro_start_token"], 350 | data_args["default_hand_x_mag_start_token"], 351 | data_args["default_hand_y_mag_start_token"], 352 | data_args["default_hand_z_mag_start_token"], 353 | data_args["default_chest_x_acc_start_token"], 354 | data_args["default_chest_y_acc_start_token"], 355 | data_args["default_chest_z_acc_start_token"], 356 | data_args["default_chest_x_gyro_start_token"], 357 | data_args["default_chest_y_gyro_start_token"], 358 | data_args["default_chest_z_gyro_start_token"], 359 | data_args["default_chest_x_mag_start_token"], 360 | data_args["default_chest_y_mag_start_token"], 361 | data_args["default_chest_z_mag_start_token"], 362 | data_args["default_ankle_x_acc_start_token"], 363 | data_args["default_ankle_y_acc_start_token"], 364 | data_args["default_ankle_z_acc_start_token"], 365 | data_args["default_ankle_x_gyro_start_token"], 366 | data_args["default_ankle_y_gyro_start_token"], 367 | data_args["default_ankle_z_gyro_start_token"], 368 | data_args["default_ankle_x_mag_start_token"], 369 | data_args["default_ankle_y_mag_start_token"], 370 | data_args["default_ankle_z_mag_start_token"] 371 | ] 372 | end_tokens_list = [ 373 | data_args["default_hand_x_acc_end_token"], 374 | data_args["default_hand_y_acc_end_token"], 375 | data_args["default_hand_z_acc_end_token"], 376 | data_args["default_hand_x_gyro_end_token"], 377 | data_args["default_hand_y_gyro_end_token"], 378 | data_args["default_hand_z_gyro_end_token"],
379 | data_args["default_hand_x_mag_end_token"], 380 | data_args["default_hand_y_mag_end_token"], 381 | data_args["default_hand_z_mag_end_token"], 382 | data_args["default_chest_x_acc_end_token"], 383 | data_args["default_chest_y_acc_end_token"], 384 | data_args["default_chest_z_acc_end_token"], 385 | data_args["default_chest_x_gyro_end_token"], 386 | data_args["default_chest_y_gyro_end_token"], 387 | data_args["default_chest_z_gyro_end_token"], 388 | data_args["default_chest_x_mag_end_token"], 389 | data_args["default_chest_y_mag_end_token"], 390 | data_args["default_chest_z_mag_end_token"], 391 | data_args["default_ankle_x_acc_end_token"], 392 | data_args["default_ankle_y_acc_end_token"], 393 | data_args["default_ankle_z_acc_end_token"], 394 | data_args["default_ankle_x_gyro_end_token"], 395 | data_args["default_ankle_y_gyro_end_token"], 396 | data_args["default_ankle_z_gyro_end_token"], 397 | data_args["default_ankle_x_mag_end_token"], 398 | data_args["default_ankle_y_mag_end_token"], 399 | data_args["default_ankle_z_mag_end_token"] 400 | ] 401 | else: 402 | raise ValueError(f"Wrong dataset name in get_token_list: {dataset}") 403 | return start_tokens_list, end_tokens_list 404 | 405 | 406 | def preprocess( 407 | sources: Sequence[Dict[str, str]], 408 | tokenizer: transformers.PreTrainedTokenizer, 409 | SYS_INST: str, 410 | split: str, 411 | preprocess_type: str, 412 | preprocess_type_eval: str, 413 | ) -> Dict: 414 | """Tokenize chat-formatted QA pairs and mask the prompt tokens in the labels.""" 415 | examples = [generate_chat_template([ 416 | {"role": "system", "content": SYS_INST}, 417 | {"role": "user", "content": s["Q"]}, 418 | {"role": "assistant", "content": s["A"]}], bos_token=tokenizer.bos_token, eos_token=tokenizer.eos_token, add_generation_prompt=False) for s in 419 | sources] 420 | 421 | if split == 'train': 422 | if preprocess_type == "Q": 423 | sources_q = [generate_chat_template([ 424 | {"role": "system", "content": SYS_INST}, 425 | {"role": "user", "content": s["Q"]}],
bos_token=tokenizer.bos_token, eos_token=tokenizer.eos_token, add_generation_prompt=True) for s in 426 | sources] 427 | else: 428 | assert preprocess_type == "Q+cot" 429 | sources_q = [generate_chat_template2([ 430 | {"role": "system", "content": SYS_INST}, 431 | {"role": "user", "content": s["Q"]}, 432 | {"role": "assistant", "content": s["cot"]}], bos_token=tokenizer.bos_token, 433 | eos_token=tokenizer.eos_token, 434 | add_generation_prompt=False) for s in 435 | sources] 436 | else: 437 | assert split == 'eval' 438 | if preprocess_type_eval == "Q": 439 | sources_q = [generate_chat_template([ 440 | {"role": "system", "content": SYS_INST}, 441 | {"role": "user", "content": s["Q"]}], bos_token=tokenizer.bos_token, eos_token=tokenizer.eos_token, add_generation_prompt=True) for s in 442 | sources] 443 | else: 444 | assert preprocess_type_eval == "Q+cot" 445 | sources_q = [generate_chat_template2([ 446 | {"role": "system", "content": SYS_INST}, 447 | {"role": "user", "content": s["Q"]}, 448 | {"role": "assistant", "content": s["cot"]}], bos_token=tokenizer.bos_token, 449 | eos_token=tokenizer.eos_token, 450 | add_generation_prompt=False) for s in 451 | sources] 452 | 453 | examples_tokenized = _tokenize_fn(examples, tokenizer) 454 | sources_tokenized = _tokenize_fn(sources_q, tokenizer) 455 | 456 | input_ids = examples_tokenized["input_ids"] 457 | labels = copy.deepcopy(input_ids) 458 | for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]): 459 | label[:source_len] = IGNORE_INDEX 460 | return dict(input_ids=input_ids, labels=labels) 461 | 462 | 463 | def preprocess_cls( 464 | sources: Sequence[Dict[str, str]], 465 | tokenizer: transformers.PreTrainedTokenizer, 466 | added_str: str, 467 | preprocess_type: str 468 | ) -> Dict: 469 | """Preprocess the data by tokenizing.""" 470 | 471 | if preprocess_type == "smry": 472 | inputs = [added_str + '\n' + s["smry"] for s in sources] 473 | elif preprocess_type == "trend": 474 | inputs = [added_str + '\n' + 
s["trend_text"] for s in sources] 475 | elif preprocess_type == "corr": 476 | inputs = [added_str + '\n' + s["corr_text"] for s in sources] 477 | elif preprocess_type == "none": 478 | inputs = [added_str for _ in sources] 479 | elif preprocess_type == "smry+Q": 480 | inputs = [added_str + '\n' + s["smry"] + '\n' + s["Q"] for s in sources] 481 | elif preprocess_type == "smry+meta": 482 | inputs = [added_str + '\n' + s["smry"] + '\n' + s["info_text"] for s in sources] 483 | elif preprocess_type == "smry+meta+Q": 484 | inputs = [added_str + '\n' + s["smry"] + '\n' + s["info_text"] + '\n' + s["Q"] for s in sources] 485 | elif preprocess_type == "smry+corr": 486 | inputs = [added_str + '\n' + s["smry"] + '\n' + s["corr_text"] for s in sources] 487 | elif preprocess_type == "smry+corr+Q": 488 | inputs = [added_str + '\n' + s["smry"] + '\n' + s["corr_text"] + '\n' + s["Q"] for s in sources] 489 | elif preprocess_type == "smry+trend+corr": 490 | inputs = [added_str + '\n' + s["smry"] + '\n' + s["trend_text"] + '\n' + s["corr_text"] for s in sources] 491 | elif preprocess_type == "smry+trend+corr+Q": 492 | inputs = [added_str + '\n' + s["smry"] + '\n' + s["trend_text"] + '\n' + s["corr_text"] + '\n' + s["Q"] for s in sources] 493 | elif preprocess_type == "smry+trend+Q": 494 | inputs = [added_str + '\n' + s["smry"] + '\n' + s["trend_text"] + '\n' + s["Q"] for s in sources] 495 | else: 496 | assert preprocess_type == "smry+trend", f"Undefined preprocess_type {preprocess_type}" 497 | inputs = [added_str + '\n' + s["smry"] + '\n' + s["trend_text"] for s in sources] 498 | 499 | inputs_tokenized = _tokenize_fn(inputs, tokenizer) 500 | 501 | input_ids = inputs_tokenized["input_ids"] 502 | return dict(input_ids=input_ids, input_texts=inputs) -------------------------------------------------------------------------------- /sensorllm/eval/eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import re 4 | from 
transformers import AutoTokenizer, AutoConfig 5 | from sensorllm.model.chronos_model import * 6 | import torch 7 | from torch.utils.data import DataLoader 8 | from tqdm import tqdm 9 | from sensorllm.model import * 10 | from sensorllm.data import UniChannelTimeSeriesDataset 11 | from sensorllm.data.utils import generate_chat_template 12 | from sensorllm.utils import disable_torch_init 13 | import warnings 14 | import argparse 15 | 16 | 17 | warnings.filterwarnings("ignore") 18 | 19 | SYS_INST = "A chat between a curious human and an AI assistant. The assistant is given a sequence of N features that represent information extracted from sensor (time-series) readings. The original readings consisted of N data points collected at a sample rate of 100Hz. The assistant's task is to analyze the trends and patterns in the sensor readings by leveraging the encoded information within the features to answer the following specific questions provided by the human." 20 | 21 | 22 | def parse_config(): 23 | parser = argparse.ArgumentParser(description='arg parser') 24 | parser.add_argument('--model_name_or_path', type=str, 25 | default="") 26 | parser.add_argument('--pt_encoder_backbone_ckpt', type=str, 27 | default="") 28 | parser.add_argument('--tokenize_method', type=str, default="MeanScaleUniformBins") 29 | parser.add_argument("--torch_dtype", type=str, default="bfloat16", choices=["float32", "float16", "bfloat16"]) 30 | 31 | parser.add_argument('--dataset', type=str, default="usc-had") 32 | parser.add_argument('--output_file_name', type=str, default="eval.json") 33 | parser.add_argument('--model_max_length', type=int, default=8192, help='context length during evaluation') 34 | parser.add_argument('--data_path', type=str, default="", 35 | help="Path to the testing data.") 36 | parser.add_argument('--qa_path', type=str, default="", 37 | help="Path to the testing QA data.") 38 | parser.add_argument('--ignore_qa_types', type=str, nargs='*', default=["sub_trend_no_val"]) 39 | 40 | 
# * data loader, batch_size, shuffle, num_workers 41 | parser.add_argument("--batch_size", type=int, default=6) 42 | parser.add_argument("--shuffle", action="store_true")  # * type=bool would parse "--shuffle False" as True 43 | parser.add_argument("--num_workers", type=int, default=2) 44 | 45 | args = parser.parse_args() 46 | return args 47 | 48 | 49 | def load_dataset(data_path, qa_path, chronos_tokenizer): 50 | print("Loading validation datasets.") 51 | dataset = UniChannelTimeSeriesDataset( 52 | data_path=data_path, 53 | qa_path=qa_path, 54 | tokenizer=None, # * load ts and QA 55 | chronos_tokenizer=chronos_tokenizer, 56 | data_args=args  # * relies on the global args parsed in __main__ 57 | ) 58 | print(f"Example data: {dataset[5]}") 59 | print("Done!") 60 | print(dataset) 61 | return dataset 62 | 63 | 64 | def custom_collate_fn(batch): 65 | batch_dict = { 66 | 'question': [], 67 | 'ground_truth': [], 68 | 'type': [], 69 | 'ts_token_ids': [], 70 | 'ts_attention_mask': [] 71 | } 72 | 73 | for item in batch: 74 | for key in batch_dict: 75 | batch_dict[key].append(item[key]) 76 | 77 | return batch_dict 78 | 79 | 80 | def get_dataloader(dataset, batch_size, num_workers=2): 81 | dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, 82 | collate_fn=custom_collate_fn) 83 | return dataloader 84 | 85 | 86 | def init_model(args): 87 | # Model 88 | disable_torch_init() 89 | model_name = os.path.expanduser(args.model_name_or_path) 90 | 91 | # * print the model_name (get the basename) 92 | print(f'[INFO] Model name: {os.path.basename(model_name)}') 93 | 94 | tokenizer = AutoTokenizer.from_pretrained( 95 | model_name, 96 | padding_side="left" 97 | ) 98 | model = SensorLLMStage1LlamaForCausalLM.from_pretrained(model_name, low_cpu_mem_usage=False, use_cache=False, 99 | torch_dtype=args.torch_dtype).cuda() 100 | model.get_model().load_pt_encoder_backbone_checkpoint(args.pt_encoder_backbone_ckpt, 101 | tc=args.tokenize_method, 102 | 
torch_dtype=args.torch_dtype) 103 | pt_backbone_config =
AutoConfig.from_pretrained(args.pt_encoder_backbone_ckpt) 104 | 105 | assert hasattr(pt_backbone_config, "chronos_config"), "Not a Chronos config file" 106 | 107 | chronos_config = ChronosConfig(**pt_backbone_config.chronos_config) 108 | chronos_config.tokenizer_class = args.tokenize_method 109 | chronos_tokenizer = chronos_config.create_tokenizer() 110 | 111 | model.initialize_tokenizer_ts_backbone_config_wo_embedding(tokenizer, dataset=args.dataset) 112 | model.get_model().load_start_end_tokens(dataset=args.dataset) 113 | 114 | return model, tokenizer, chronos_tokenizer 115 | 116 | 117 | def generate_outputs(model, tokenizer, inputs, ts_token_ids, ts_attention_mask, do_sample=True, temperature=0.6, 118 | top_k=50, max_length=8192, top_p=0.9): 119 | model.eval() 120 | model.get_model().pt_encoder_backbone.eval() 121 | terminators = [ 122 | tokenizer.eos_token_id, 123 | tokenizer.convert_tokens_to_ids("<|eot_id|>") 124 | ] 125 | with torch.inference_mode(): 126 | outputs = model.generate( 127 | **inputs, 128 | ts_token_ids=ts_token_ids, 129 | ts_attention_mask=ts_attention_mask, 130 | do_sample=do_sample, 131 | use_cache=False, 132 | temperature=temperature, 133 | top_k=top_k, 134 | max_new_tokens=max_length, 135 | top_p=top_p, 136 | eos_token_id=terminators, 137 | pad_token_id=tokenizer.pad_token_id 138 | ) # * B, L' 139 | input_token_len = inputs.input_ids.shape[1] 140 | n_diff_input_output = (inputs.input_ids != outputs[:, :input_token_len]).sum().item() 141 | 142 | if n_diff_input_output > 0: 143 | print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids') 144 | outputs = tokenizer.batch_decode(outputs[:, input_token_len:], skip_special_tokens=True) 145 | outputs = [output.strip() for output in outputs] 146 | 147 | return outputs 148 | 149 | 150 | def start_generation(model, tokenizer, dataloader, output_dir, output_file_name): 151 | os.makedirs(output_dir, exist_ok=True) 152 | output_file = os.path.join(output_dir, output_file_name) 
153 | 154 | results = {"prompt": SYS_INST, "results": []} 155 | if os.path.exists(output_file): 156 | # If the file already exists, load the existing results and resume from there 157 | with open(output_file, 'r') as f: 158 | results = json.load(f) 159 | 160 | processed_count = len(results["results"]) 161 | 162 | o_i = 0 163 | for batch_idx, batch in enumerate(tqdm(dataloader)): 164 | if batch_idx * dataloader.batch_size < processed_count: 165 | continue 166 | print(f"start from id {batch_idx}...") 167 | 168 | ts_token_ids = [ts_tensor.cuda() for ts_tensor in batch["ts_token_ids"]] # * list of B tensors 169 | ts_attention_mask = [ts_tensor.cuda() for ts_tensor in batch["ts_attention_mask"]] 170 | 171 | ground_truths = batch["ground_truth"] # * list of strings 172 | types = batch["type"] 173 | questions = batch["question"] # * list of strings 174 | 175 | templated_questions = [generate_chat_template([ 176 | {"role": "system", "content": SYS_INST}, 177 | {"role": "user", "content": q}], bos_token=tokenizer.bos_token, eos_token=tokenizer.eos_token, 178 | add_generation_prompt=True) for q in 179 | questions] 180 | 181 | inputs = tokenizer(templated_questions, padding=True, return_tensors="pt").to(model.device) 182 | outputs = generate_outputs(model, tokenizer, inputs, ts_token_ids, 183 | ts_attention_mask) # List of str, length is B 184 | 185 | # saving results 186 | batch_results = [] 187 | for q, gt, output, tp, ts in zip(questions, ground_truths, outputs, types, ts_token_ids): 188 | result = { 189 | "questions": q, 190 | "ground_truth": gt, 191 | "model_output": output, 192 | "model_len": len(ts[0]), 193 | "type": tp 194 | } 195 | batch_results.append(result) 196 | if o_i < 10: 197 | tqdm.write(f"Type: {tp}\nOutput: {output}\nGround-truth: {gt}\n\n") 198 | tqdm.write("---------" * 30) 199 | o_i += 1 200 | results["results"].extend(batch_results) 201 | 202 | if batch_idx % 10 == 0: # save a checkpoint every 10 batches 203 | save_results(results, 
output_file) 204 | 205 | save_results(results, output_file) 206 | return results 207 | 208 | def
save_results(results, output_file): 209 | temp_file = output_file + '.tmp'  # * write to a temp file first, then move it into place 210 | with open(temp_file, 'w') as f: 211 | json.dump(results, f, indent=2) 212 | 213 | if os.path.exists(output_file): 214 | os.replace(temp_file, output_file) 215 | else: 216 | os.rename(temp_file, output_file) 217 | 218 | 219 | def eval(args): 220 | # * output 221 | args.output_dir = os.path.join(args.model_name_or_path, "evaluation") 222 | output_file_path = os.path.join(args.output_dir, args.output_file_name) 223 | 224 | if not os.path.exists(output_file_path): 225 | # * run inference (an existing, possibly partial, results file skips this branch) 226 | model, tokenizer, chronos_tokenizer = init_model(args) 227 | ts_backbone_config = model.get_model().ts_backbone_config 228 | args.ts_backbone_config = ts_backbone_config 229 | 230 | dataset = load_dataset(args.data_path, args.qa_path, chronos_tokenizer) 231 | dataloader = get_dataloader(dataset, args.batch_size, args.num_workers) 232 | 233 | print(f'[INFO] Start generating results for {args.output_file_name}.') 234 | results = start_generation(model, tokenizer, dataloader, args.output_dir, args.output_file_name) 235 | 236 | del model 237 | del tokenizer 238 | torch.cuda.empty_cache() 239 | else: 240 | # * directly load the results 241 | print(f'[INFO] {output_file_path} already exists, directly loading...') 242 | with open(output_file_path, 'r') as fp: 243 | results = json.load(fp) 244 | print(results["results"][:10]) 245 | 246 | 247 | if __name__ == "__main__": 248 | args = parse_config() 249 | dtype_mapping = { 250 | "float32": torch.float32, 251 | "float16": torch.float16, 252 | "bfloat16": torch.bfloat16, 253 | } 254 | 255 | args.torch_dtype = dtype_mapping[args.torch_dtype] 256 | 257 | eval(args) 258 | -------------------------------------------------------------------------------- /sensorllm/model/__init__.py: 
-------------------------------------------------------------------------------- 1 | from .stage1_sensorllm import SensorLLMStage1Config, SensorLLMStage1LlamaForCausalLM 2 | from
.stage2_sensorllm import SensorLLMStage2Config, SensorLLMStage2LlamaForCausalLM, SensorLLMStage2LlamaForSequenceClassification -------------------------------------------------------------------------------- /sensorllm/model/chronos_model/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | 5 | from .chronos_model import ( 6 | ChronosConfig, 7 | ChronosModel, 8 | ChronosPipeline, 9 | ChronosTokenizer, 10 | MeanScaleUniformBins, 11 | StanNormalizeUniformBins, 12 | left_pad_and_stack_1D 13 | ) 14 | 15 | __all__ = [ 16 | "ChronosConfig", 17 | "ChronosModel", 18 | "ChronosPipeline", 19 | "ChronosTokenizer", 20 | "MeanScaleUniformBins", 21 | "StanNormalizeUniformBins", 22 | "left_pad_and_stack_1D" 23 | ] -------------------------------------------------------------------------------- /sensorllm/model/chronos_model/chronos_model.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from dataclasses import dataclass 3 | from typing import Any, Dict, List, Literal, Optional, Tuple, Union 4 | 5 | from . import chronos_model 6 | import torch 7 | import torch.nn as nn 8 | from transformers import ( 9 | AutoConfig, 10 | AutoModelForCausalLM, 11 | AutoModelForSeq2SeqLM, 12 | GenerationConfig, 13 | PreTrainedModel, 14 | ) 15 | 16 | 17 | @dataclass 18 | class ChronosConfig: 19 | """ 20 | This class holds all the configuration parameters to be used 21 | by ``ChronosTokenizer`` and ``ChronosModel``. 
22 | """ 23 | 24 | tokenizer_class: str 25 | tokenizer_kwargs: Dict[str, Any] 26 | context_length: int 27 | prediction_length: int 28 | n_tokens: int 29 | n_special_tokens: int 30 | pad_token_id: int 31 | eos_token_id: int 32 | use_eos_token: bool 33 | model_type: Literal["causal", "seq2seq"] 34 | num_samples: int 35 | temperature: float 36 | top_k: int 37 | top_p: float 38 | 39 | def __post_init__(self): 40 | assert ( 41 | self.pad_token_id < self.n_special_tokens 42 | and self.eos_token_id < self.n_special_tokens 43 | ), f"Special token IDs must be smaller than {self.n_special_tokens=}" 44 | 45 | def create_tokenizer(self) -> "ChronosTokenizer": 46 | class_ = getattr(chronos_model, self.tokenizer_class) 47 | return class_(**self.tokenizer_kwargs, config=self) 48 | 49 | 50 | class ChronosTokenizer: 51 | """ 52 | A ``ChronosTokenizer`` defines how time series are mapped into token IDs 53 | and back. 54 | 55 | For details, see the ``context_input_transform``, ``label_input_transform`` and ``output_transform`` methods, 56 | which concrete classes must implement. 57 | """ 58 | 59 | def context_input_transform( 60 | self, 61 | context: torch.Tensor, 62 | ) -> Tuple: 63 | """ 64 | Turn a batch of time series into token IDs, attention map, and tokenizer_state. 65 | 66 | Parameters 67 | ---------- 68 | context 69 | A tensor shaped (batch_size, time_length), containing the 70 | time series to forecast. Use left-padding with ``torch.nan`` 71 | to align time series of different lengths. 72 | 73 | Returns 74 | ------- 75 | token_ids 76 | A tensor of integers, shaped (batch_size, time_length + 1) 77 | if ``config.use_eos_token`` (seq2seq models only) and (batch_size, time_length) 78 | otherwise, containing token IDs for the input series. 79 | attention_mask 80 | A boolean tensor, same shape as ``token_ids``, indicating 81 | which input observations are not ``torch.nan`` (i.e. neither 82 | missing nor padding). 83 | tokenizer_state 84 | An object that can be passed to ``label_input_transform`` 85 | and ``output_transform``.
Contains the relevant information 86 | to decode output samples into real values, 87 | such as location and scale parameters. 88 | """ 89 | raise NotImplementedError() 90 | 91 | def label_input_transform(self, label: torch.Tensor, tokenizer_state: Any) -> Tuple: 92 | """ 93 | Turn a batch of label slices of time series into token IDs and attention map 94 | using the ``tokenizer_state`` provided by ``context_input_transform``. 95 | 96 | Parameters 97 | ---------- 98 | label 99 | A tensor shaped (batch_size, prediction_length), containing the 100 | label slice of each time series. Missing values can be marked 101 | with ``torch.nan`` and are masked out. 102 | tokenizer_state 103 | An object returned by ``context_input_transform`` containing 104 | relevant information to preprocess data, such as location and 105 | scale. The nature of this depends on the specific tokenizer. 106 | This is used for tokenizing the label, in order to use the same 107 | scaling used to tokenize the context. 108 | 109 | Returns 110 | ------- 111 | token_ids 112 | A tensor of integers, shaped (batch_size, time_length + 1) 113 | if ``config.use_eos_token`` and (batch_size, time_length) 114 | otherwise, containing token IDs for the input series. 115 | attention_mask 116 | A boolean tensor, same shape as ``token_ids``, indicating 117 | which input observations are not ``torch.nan`` (i.e. neither 118 | missing nor padding). 119 | """ 120 | raise NotImplementedError() 121 | 122 | def output_transform( 123 | self, samples: torch.Tensor, tokenizer_state: Any 124 | ) -> torch.Tensor: 125 | """ 126 | Turn a batch of sample token IDs into real values. 127 | 128 | Parameters 129 | ---------- 130 | samples 131 | A tensor of integers, shaped (batch_size, num_samples, time_length), 132 | containing token IDs of sample trajectories. 133 | tokenizer_state 134 | An object returned by ``context_input_transform`` containing 135 | relevant context to decode samples, such as location and scale.
136 | The nature of this depends on the specific tokenizer. 137 | 138 | Returns 139 | ------- 140 | forecasts 141 | A real tensor, shaped (batch_size, num_samples, time_length), 142 | containing forecasted sample paths. 143 | """ 144 | raise NotImplementedError() 145 | 146 | 147 | class MeanScaleUniformBins(ChronosTokenizer): 148 | def __init__( 149 | self, low_limit: float, high_limit: float, config: ChronosConfig 150 | ) -> None: 151 | self.config = config 152 | self.centers = torch.linspace( 153 | low_limit, 154 | high_limit, 155 | config.n_tokens - config.n_special_tokens - 1, 156 | ) 157 | self.boundaries = torch.concat( 158 | ( 159 | torch.tensor([-1e20], device=self.centers.device), 160 | (self.centers[1:] + self.centers[:-1]) / 2, 161 | torch.tensor([1e20], device=self.centers.device), 162 | ) 163 | ) 164 | 165 | def _input_transform( 166 | self, context: torch.Tensor, scale: Optional[torch.Tensor] = None 167 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 168 | attention_mask = ~torch.isnan(context) 169 | 170 | if scale is None: 171 | scale = torch.nansum( 172 | torch.abs(context) * attention_mask, dim=-1 173 | ) / torch.nansum(attention_mask, dim=-1) 174 | scale[~(scale > 0)] = 1.0 175 | 176 | scaled_context = context / scale.unsqueeze(dim=-1) 177 | token_ids = ( 178 | torch.bucketize( 179 | input=scaled_context, 180 | boundaries=self.boundaries, 181 | # buckets are open to the right, see: 182 | # https://pytorch.org/docs/2.1/generated/torch.bucketize.html#torch-bucketize 183 | right=True, 184 | ) 185 | + self.config.n_special_tokens 186 | ) 187 | token_ids[~attention_mask] = self.config.pad_token_id 188 | 189 | return token_ids, attention_mask, scale 190 | 191 | def _append_eos_token( 192 | self, token_ids: torch.Tensor, attention_mask: torch.Tensor 193 | ) -> Tuple[torch.Tensor, torch.Tensor]: 194 | batch_size = token_ids.shape[0] 195 | eos_tokens = torch.full((batch_size, 1), fill_value=self.config.eos_token_id) 196 | token_ids = 
torch.concat((token_ids, eos_tokens), dim=1) 197 | eos_mask = torch.full((batch_size, 1), fill_value=True) 198 | attention_mask = torch.concat((attention_mask, eos_mask), dim=1) 199 | 200 | return token_ids, attention_mask 201 | 202 | def context_input_transform( 203 | self, context: torch.Tensor 204 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 205 | length = context.shape[-1] 206 | 207 | if length > self.config.context_length: 208 | context = context[..., -self.config.context_length :] 209 | 210 | token_ids, attention_mask, scale = self._input_transform(context=context) 211 | 212 | if self.config.use_eos_token and self.config.model_type == "seq2seq": 213 | token_ids, attention_mask = self._append_eos_token( 214 | token_ids=token_ids, attention_mask=attention_mask 215 | ) 216 | 217 | return token_ids, attention_mask, scale 218 | 219 | def label_input_transform( 220 | self, label: torch.Tensor, scale: torch.Tensor 221 | ) -> Tuple[torch.Tensor, torch.Tensor]: 222 | length = label.shape[-1] 223 | 224 | assert length == self.config.prediction_length 225 | token_ids, attention_mask, _ = self._input_transform(context=label, scale=scale) 226 | 227 | if self.config.use_eos_token: 228 | token_ids, attention_mask = self._append_eos_token( 229 | token_ids=token_ids, attention_mask=attention_mask 230 | ) 231 | 232 | return token_ids, attention_mask 233 | 234 | def output_transform( 235 | self, samples: torch.Tensor, scale: torch.Tensor 236 | ) -> torch.Tensor: 237 | scale_unsqueezed = scale.unsqueeze(-1).unsqueeze(-1) 238 | indices = torch.clamp( 239 | samples - self.config.n_special_tokens - 1, 240 | min=0, 241 | max=len(self.centers) - 1, 242 | ) 243 | return self.centers[indices] * scale_unsqueezed 244 | 245 | 246 | class StanNormalizeUniformBins(ChronosTokenizer): 247 | def __init__( 248 | self, low_limit: float, high_limit: float, config: ChronosConfig 249 | ) -> None: 250 | self.config = config 251 | self.centers = torch.linspace( 252 | low_limit, 253 | 
high_limit, 254 | config.n_tokens - config.n_special_tokens - 1 255 | ) 256 | # print("self.centers.device:", self.centers.device) 257 | self.boundaries = torch.concat( 258 | ( 259 | torch.tensor([-1e20], device=self.centers.device), 260 | (self.centers[1:] + self.centers[:-1]) / 2, 261 | torch.tensor([1e20], device=self.centers.device), 262 | ) 263 | ) 264 | 265 | def _input_transform( 266 | self, context: torch.Tensor, scale: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, eps: float = 1e-8 267 | ) -> Tuple[torch.Tensor, torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: 268 | attention_mask = ~torch.isnan(context) 269 | 270 | if scale is None: 271 | mean = torch.nansum(context * attention_mask, dim=-1) / torch.nansum(attention_mask, dim=-1) 272 | variance = torch.nansum((context - mean.unsqueeze(dim=-1)) ** 2 * attention_mask, dim=-1) / torch.nansum( 273 | attention_mask, dim=-1) 274 | std = torch.sqrt(variance + eps) 275 | scale = (mean, std) # eps is added to the variance above to avoid division by zero 276 | 277 | scaled_context = (context - scale[0].unsqueeze(dim=-1)) / scale[1].unsqueeze(dim=-1) 278 | # print("scaled_context.device:", scaled_context.device) 279 | token_ids = ( 280 | torch.bucketize( 281 | input=scaled_context, 282 | boundaries=self.boundaries, 283 | # buckets are open to the right, see: 284 | # https://pytorch.org/docs/2.1/generated/torch.bucketize.html#torch-bucketize 285 | right=True, 286 | ) 287 | + self.config.n_special_tokens 288 | ) 289 | token_ids[~attention_mask] = self.config.pad_token_id 290 | 291 | return token_ids, attention_mask, scale 292 | 293 | def _append_eos_token( 294 | self, token_ids: torch.Tensor, attention_mask: torch.Tensor 295 | ) -> Tuple[torch.Tensor, torch.Tensor]: 296 | batch_size = token_ids.shape[0] 297 | eos_tokens = torch.full((batch_size, 1), fill_value=self.config.eos_token_id) 298 | token_ids = torch.concat((token_ids, eos_tokens), dim=1) 299 | eos_mask = torch.full((batch_size, 1), fill_value=True) 
300 | attention_mask = torch.concat((attention_mask,
eos_mask), dim=1) 301 | 302 | return token_ids, attention_mask 303 | 304 | def context_input_transform( 305 | self, context: torch.Tensor 306 | ) -> Tuple[torch.Tensor, torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: 307 | length = context.shape[-1] 308 | 309 | if length > self.config.context_length: 310 | context = context[..., -self.config.context_length:] 311 | 312 | token_ids, attention_mask, scale = self._input_transform(context=context) 313 | 314 | if self.config.use_eos_token and self.config.model_type == "seq2seq": 315 | token_ids, attention_mask = self._append_eos_token( 316 | token_ids=token_ids, attention_mask=attention_mask 317 | ) 318 | 319 | return token_ids, attention_mask, scale 320 | 321 | def label_input_transform( 322 | self, label: torch.Tensor, scale: Tuple[torch.Tensor, torch.Tensor] 323 | ) -> Tuple[torch.Tensor, torch.Tensor]: 324 | length = label.shape[-1] 325 | 326 | assert length == self.config.prediction_length 327 | token_ids, attention_mask, _ = self._input_transform(context=label, scale=scale) 328 | 329 | if self.config.use_eos_token: 330 | token_ids, attention_mask = self._append_eos_token( 331 | token_ids=token_ids, attention_mask=attention_mask 332 | ) 333 | 334 | return token_ids, attention_mask 335 | 336 | def output_transform( 337 | self, samples: torch.Tensor, scale: Tuple[torch.Tensor, torch.Tensor] 338 | ) -> torch.Tensor: 339 | mean_unsqueezed = scale[0].unsqueeze(-1).unsqueeze(-1) 340 | std_unsqueezed = scale[1].unsqueeze(-1).unsqueeze(-1) 341 | indices = torch.clamp( 342 | samples - self.config.n_special_tokens - 1, 343 | min=0, 344 | max=len(self.centers) - 1, 345 | ) 346 | return self.centers[indices] * std_unsqueezed + mean_unsqueezed 347 | 348 | 349 | class ChronosModel(nn.Module): 350 | """ 351 | A ``ChronosModel`` wraps a ``PreTrainedModel`` object from ``transformers`` 352 | and uses it to predict sample paths for time series tokens. 
353 | 354 | Parameters 355 | ---------- 356 | config 357 | The configuration to use. 358 | model 359 | The pretrained model to use. 360 | """ 361 | 362 | def __init__(self, config: ChronosConfig, model: PreTrainedModel) -> None: 363 | super().__init__() 364 | self.config = config 365 | self.model = model 366 | 367 | @property 368 | def device(self): 369 | return self.model.device 370 | 371 | def encode( 372 | self, 373 | input_ids: torch.Tensor, 374 | attention_mask: torch.Tensor, 375 | ): 376 | """ 377 | Extract the encoder embedding for the given token sequences. 378 | 379 | Parameters 380 | ---------- 381 | input_ids 382 | Tensor of indices of input sequence tokens in the vocabulary 383 | with shape (batch_size, sequence_length). 384 | attention_mask 385 | A mask tensor of the same shape as input_ids to avoid attending 386 | on padding or missing tokens. 387 | 388 | Returns 389 | ------- 390 | embedding 391 | A tensor of encoder embeddings with shape 392 | (batch_size, sequence_length, d_model). 393 | """ 394 | assert ( 395 | self.config.model_type == "seq2seq" 396 | ), "Encoder embeddings are only supported for encoder-decoder models" 397 | return self.model.encoder( 398 | input_ids=input_ids, attention_mask=attention_mask 399 | ).last_hidden_state 400 | 401 | def forward( 402 | self, 403 | input_ids: torch.Tensor, 404 | attention_mask: torch.Tensor, 405 | prediction_length: Optional[int] = None, 406 | num_samples: Optional[int] = None, 407 | temperature: Optional[float] = None, 408 | top_k: Optional[int] = None, 409 | top_p: Optional[float] = None, 410 | ) -> torch.Tensor: 411 | """ 412 | Predict future sample tokens for the given token sequences. 413 | 414 | Arguments ``prediction_length``, ``num_samples``, ``temperature``, 415 | ``top_k``, ``top_p`` can be used to customize the model inference, 416 | and default to the corresponding attributes in ``self.config`` if 417 | not provided. 
418 | 419 | Returns 420 | ------- 421 | samples 422 | A tensor of integers, shaped (batch_size, num_samples, time_length), 423 | containing forecasted sample paths. 424 | """ 425 | if prediction_length is None: 426 | prediction_length = self.config.prediction_length 427 | if num_samples is None: 428 | num_samples = self.config.num_samples 429 | if temperature is None: 430 | temperature = self.config.temperature 431 | if top_k is None: 432 | top_k = self.config.top_k 433 | if top_p is None: 434 | top_p = self.config.top_p 435 | 436 | preds = self.model.generate( 437 | input_ids=input_ids, 438 | attention_mask=attention_mask, 439 | generation_config=GenerationConfig( 440 | min_new_tokens=prediction_length, 441 | max_new_tokens=prediction_length, 442 | do_sample=True, 443 | num_return_sequences=num_samples, 444 | eos_token_id=self.config.eos_token_id, 445 | pad_token_id=self.config.pad_token_id, 446 | temperature=temperature, 447 | top_k=top_k, 448 | top_p=top_p, 449 | ), 450 | ) 451 | 452 | if self.config.model_type == "seq2seq": 453 | preds = preds[..., 1:] # remove the decoder start token 454 | else: 455 | assert self.config.model_type == "causal" 456 | assert preds.size(-1) == input_ids.size(-1) + prediction_length 457 | preds = preds[..., -prediction_length:] 458 | 459 | return preds.reshape(input_ids.size(0), num_samples, -1) 460 | 461 | 462 | def left_pad_and_stack_1D(tensors: List[torch.Tensor]) -> torch.Tensor: 463 | max_len = max(len(c) for c in tensors) 464 | padded = [] 465 | for c in tensors: 466 | assert isinstance(c, torch.Tensor) 467 | assert c.ndim == 1 468 | padding = torch.full( 469 | size=(max_len - len(c),), fill_value=torch.nan, device=c.device 470 | ) 471 | padded.append(torch.concat((padding, c), dim=-1)) 472 | return torch.stack(padded) 473 | 474 | 475 | @dataclass 476 | class ChronosPipeline: 477 | """ 478 | A ``ChronosPipeline`` uses the given tokenizer and model to forecast 479 | input time series. 
480 | 481 | Use the ``from_pretrained`` class method to load serialized models. 482 | Use the ``predict`` method to get forecasts. 483 | 484 | Parameters 485 | ---------- 486 | tokenizer 487 | The tokenizer object to use. 488 | model 489 | The model to use. 490 | """ 491 | 492 | tokenizer: ChronosTokenizer 493 | model: ChronosModel 494 | 495 | def _prepare_and_validate_context( 496 | self, context: Union[torch.Tensor, List[torch.Tensor]] 497 | ): 498 | if isinstance(context, list): 499 | context = left_pad_and_stack_1D(context) 500 | assert isinstance(context, torch.Tensor) 501 | if context.ndim == 1: 502 | context = context.unsqueeze(0) 503 | assert context.ndim == 2 504 | 505 | return context 506 | 507 | @torch.no_grad() 508 | def embed( 509 | self, context: Union[torch.Tensor, List[torch.Tensor]] 510 | ) -> Tuple[torch.Tensor, Any]: 511 | """ 512 | Get encoder embeddings for the given time series. 513 | 514 | Parameters 515 | ---------- 516 | context 517 | Input series. This is either a 1D tensor, or a list 518 | of 1D tensors, or a 2D tensor whose first dimension 519 | is batch. In the latter case, use left-padding with 520 | ``torch.nan`` to align series of different lengths. 521 | 522 | Returns 523 | ------- 524 | embeddings, tokenizer_state 525 | A tuple of two tensors: the encoder embeddings and the tokenizer_state, 526 | e.g., the scale of the time series in the case of mean scaling. 527 | The encoder embeddings are shaped (batch_size, context_length, d_model) 528 | or (batch_size, context_length + 1, d_model), where context_length 529 | is the size of the context along the time axis if a 2D tensor was provided 530 | or the length of the longest time series, if a list of 1D tensors was 531 | provided, and the extra 1 is for EOS. 
532 | """ 533 | context_tensor = self._prepare_and_validate_context(context=context) 534 | token_ids, attention_mask, tokenizer_state = ( 535 | self.tokenizer.context_input_transform(context_tensor) 536 | ) 537 | embeddings = self.model.encode( 538 | input_ids=token_ids.to(self.model.device), 539 | attention_mask=attention_mask.to(self.model.device), 540 | ).cpu() 541 | return embeddings, tokenizer_state 542 | 543 | def predict( 544 | self, 545 | context: Union[torch.Tensor, List[torch.Tensor]], 546 | prediction_length: Optional[int] = None, 547 | num_samples: Optional[int] = None, 548 | temperature: Optional[float] = None, 549 | top_k: Optional[int] = None, 550 | top_p: Optional[float] = None, 551 | limit_prediction_length: bool = True, 552 | ) -> torch.Tensor: 553 | """ 554 | Get forecasts for the given time series. 555 | 556 | Parameters 557 | ---------- 558 | context 559 | Input series. This is either a 1D tensor, or a list 560 | of 1D tensors, or a 2D tensor whose first dimension 561 | is batch. In the latter case, use left-padding with 562 | ``torch.nan`` to align series of different lengths. 563 | prediction_length 564 | Number of time steps to predict. Defaults to the value 565 | specified in ``self.model.config``. 566 | num_samples 567 | Number of sample paths to predict. Defaults to the value 568 | specified in ``self.model.config``. 569 | temperature 570 | Temperature to use for generating sample tokens. 571 | Defaults to the value specified in ``self.model.config``. 572 | top_k 573 | Top-k parameter to use for generating sample tokens. 574 | Defaults to the value specified in ``self.model.config``. 575 | top_p 576 | Top-p parameter to use for generating sample tokens. 577 | Defaults to the value specified in ``self.model.config``. 578 | limit_prediction_length 579 | Force the prediction length to be at most the 580 | built-in prediction length of the model. True by 581 | default.
When true, fail loudly if longer predictions 582 | are requested; otherwise, allow them with a warning. 583 | 584 | Returns 585 | ------- 586 | samples 587 | Tensor of sample forecasts, of shape 588 | (batch_size, num_samples, prediction_length). 589 | """ 590 | context_tensor = self._prepare_and_validate_context(context=context) 591 | 592 | if prediction_length is None: 593 | prediction_length = self.model.config.prediction_length 594 | 595 | if prediction_length > self.model.config.prediction_length: 596 | msg = ( 597 | f"We recommend keeping prediction length <= {self.model.config.prediction_length}. " 598 | "The quality of longer predictions may degrade since the model is not optimized for it. " 599 | ) 600 | if limit_prediction_length: 601 | msg += "You can turn off this check by setting `limit_prediction_length=False`." 602 | raise ValueError(msg) 603 | warnings.warn(msg) 604 | 605 | predictions = [] 606 | remaining = prediction_length 607 | 608 | while remaining > 0: 609 | token_ids, attention_mask, scale = self.tokenizer.context_input_transform( 610 | context_tensor 611 | ) 612 | samples = self.model( 613 | token_ids.to(self.model.device), 614 | attention_mask.to(self.model.device), 615 | min(remaining, self.model.config.prediction_length), 616 | num_samples, 617 | temperature, 618 | top_k, 619 | top_p, 620 | ) 621 | prediction = self.tokenizer.output_transform( 622 | samples.to(scale[0].device if isinstance(scale, tuple) else scale.device), scale # StanNormalizeUniformBins returns scale as a (mean, std) tuple 623 | ) 624 | 625 | predictions.append(prediction) 626 | remaining -= prediction.shape[-1] 627 | 628 | if remaining <= 0: 629 | break 630 | 631 | context_tensor = torch.cat( 632 | [context_tensor, prediction.median(dim=1).values], dim=-1 633 | ) 634 | 635 | return torch.cat(predictions, dim=-1) 636 | 637 | @classmethod 638 | def from_pretrained(cls, *args, tc='MeanScaleUniformBins', **kwargs): 639 | """ 640 | Load the model, either from a local path or from the HuggingFace Hub.
641 | Supports the same arguments as ``AutoConfig`` and ``AutoModel`` 642 | from ``transformers``. 643 | """ 644 | 645 | config = AutoConfig.from_pretrained(*args, **kwargs) 646 | 647 | assert hasattr(config, "chronos_config"), "Not a Chronos config file" 648 | 649 | chronos_config = ChronosConfig(**config.chronos_config) 650 | chronos_config.tokenizer_class = tc 651 | 652 | if chronos_config.model_type == "seq2seq": 653 | inner_model = AutoModelForSeq2SeqLM.from_pretrained(*args, **kwargs) 654 | else: 655 | assert chronos_config.model_type == "causal" 656 | inner_model = AutoModelForCausalLM.from_pretrained(*args, **kwargs) 657 | 658 | return cls( 659 | tokenizer=chronos_config.create_tokenizer(), 660 | model=ChronosModel(config=chronos_config, model=inner_model), 661 | ) -------------------------------------------------------------------------------- /sensorllm/model/stage1_sensorllm.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Tuple, Union 2 | from torch.nn import CrossEntropyLoss 3 | from .utils import * 4 | from contextlib import nullcontext 5 | 6 | from transformers import ( 7 | AutoConfig, 8 | AutoModelForCausalLM, 9 | LlamaConfig, 10 | LlamaForCausalLM, 11 | ) 12 | 13 | from transformers.modeling_outputs import ( 14 | BaseModelOutputWithPast, 15 | CausalLMOutputWithPast, 16 | ) 17 | 18 | import logging 19 | 20 | logger = logging.getLogger(__name__) 21 | 22 | 23 | class SensorLLMStage1Config(LlamaConfig): 24 | model_type = "sensorllmstage1" 25 | 26 | 27 | class SensorLLMStage1LlamaModel(BaseSensorLLMModel): 28 | config_class = SensorLLMStage1Config 29 | 30 | def __init__(self, config: LlamaConfig): 31 | super(SensorLLMStage1LlamaModel, self).__init__(config) 32 | 33 | def forward( 34 | self, 35 | input_ids: torch.LongTensor = None, # B, L 36 | attention_mask: Optional[torch.Tensor] = None, 37 | past_key_values: Optional[List[torch.FloatTensor]] = None, 38 | inputs_embeds: 
Optional[torch.FloatTensor] = None, 39 | use_cache: Optional[bool] = None, 40 | output_attentions: Optional[bool] = None, 41 | output_hidden_states: Optional[bool] = None, 42 | ts_token_ids: Optional[List[torch.Tensor]] = None, # B, L_ts 43 | ts_attention_mask: Optional[List[torch.Tensor]] = None, 44 | ts_tokenizer_state: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None, 45 | return_dict: Optional[bool] = None, 46 | cache_position: Optional[torch.LongTensor] = None, 47 | ) -> Union[Tuple, BaseModelOutputWithPast]: 48 | 49 | # Check the dimensions of input_ids 50 | if input_ids.dim() != 2: 51 | raise ValueError(f"Expected input_ids to be a 2D tensor, but got {input_ids.dim()}D tensor") 52 | 53 | orig_embeds_params = getattr(self, "orig_embeds_params", None) 54 | 55 | if inputs_embeds is None: 56 | inputs_embeds = self.embed_tokens(input_ids) 57 | 58 | pt_encoder_backbone = getattr(self, 'pt_encoder_backbone', None) 59 | 60 | if pt_encoder_backbone is not None and (input_ids.shape[1] != 1 or self.training) and ts_token_ids is not None: 61 | assert isinstance(ts_token_ids, list) 62 | 63 | with torch.no_grad() if self.fix_ts_encoder else nullcontext(): 64 | if self.fix_ts_encoder: 65 | self.pt_encoder_backbone.eval() 66 | ts_features = [] 67 | for ti, am in zip(ts_token_ids, ts_attention_mask): # * iterate over batch 68 | ts_feature = self.ts_embed(ti, am) 69 | if torch.any(torch.isnan(ts_feature)) or torch.any(torch.isinf(ts_feature)): 70 | raise ValueError("ts_feature has NaN or Inf values") 71 | ts_features.append(ts_feature[0]) 72 | summed_ts_embeddings = [self.ts_proj(ts_feature) for ts_feature in ts_features] 73 | 74 | new_input_embeds = [] 75 | for cur_input_ids, cur_input_embeds, cur_ts_embeds in zip( 76 | input_ids, inputs_embeds, summed_ts_embeddings 77 | ): # * input_ids: B, L; input_embeds: B, L, C; summed_ts_embeddings: B, L_ts, C 78 | cur_ts_embeds = cur_ts_embeds.to( 79 | device=cur_input_embeds.device 80 | ) 81 | 82 | num_ts_tokens = 
cur_ts_embeds.shape[0]
# * number of ts tokens 83 | total_ts_token_count = (cur_input_ids == self.ts_backbone_config["ts_token_id"]).sum() 84 | 85 | if num_ts_tokens != total_ts_token_count: 86 | raise ValueError( 87 | f"The window size of time-series tokens ({num_ts_tokens}) and input template ts tokens ({total_ts_token_count}) should be the same.") 88 | 89 | for start_token_id, end_token_id in self.start_end_tokens.items(): 90 | start_token_count = (cur_input_ids == start_token_id).sum() 91 | end_token_count = (cur_input_ids == end_token_id).sum() 92 | 93 | if start_token_count != end_token_count: 94 | raise ValueError( 95 | f"The number of {start_token_id} tokens ({start_token_count}) and {end_token_id} tokens ({end_token_count}) should be the same.") 96 | 97 | start_token_positions = torch.where(cur_input_ids == start_token_id)[0] 98 | 99 | for start_token_pos in start_token_positions: 100 | end_token_pos = start_token_pos + num_ts_tokens + 1 101 | total_ts_token_count -= num_ts_tokens 102 | 103 | if end_token_pos >= len(cur_input_ids) or cur_input_ids[end_token_pos] != end_token_id: 104 | raise ValueError( 105 | f"The end token '{end_token_id}' should follow the start token '{start_token_id}' after {num_ts_tokens} positions." 
) 107 | 108 | if orig_embeds_params is not None: # * will not update the original embeddings except for TS_START_TOKEN and TS_END_TOKEN 109 | # print("Will not update the original embeddings except for TS_START_TOKEN and TS_END_TOKEN") 110 | cur_input_embeds = torch.cat( 111 | ( 112 | cur_input_embeds[:start_token_pos].detach(), 113 | cur_input_embeds[start_token_pos: start_token_pos + 1], 114 | cur_ts_embeds, 115 | cur_input_embeds[end_token_pos: end_token_pos + 1], 116 | cur_input_embeds[end_token_pos + 1:].detach(), 117 | ), 118 | dim=0, 119 | ) 120 | else: 121 | # print("Will update the original embeddings") 122 | cur_input_embeds = torch.cat( 123 | ( 124 | cur_input_embeds[:start_token_pos + 1], 125 | cur_ts_embeds, 126 | cur_input_embeds[end_token_pos:], 127 | ), 128 | dim=0, 129 | ) 130 | 131 | if total_ts_token_count != 0: 132 | raise ValueError( 133 | f"The value of total_ts_token_count ({total_ts_token_count}) should be 0 after all time-series segments are replaced.") 134 | new_input_embeds.append(cur_input_embeds) 135 | inputs_embeds = torch.stack(new_input_embeds, dim=0) 136 | 137 | return super(SensorLLMStage1LlamaModel, self).forward( 138 | input_ids=None, 139 | attention_mask=attention_mask, 140 | past_key_values=past_key_values, 141 | inputs_embeds=inputs_embeds, 142 | use_cache=use_cache, 143 | output_attentions=output_attentions, 144 | output_hidden_states=output_hidden_states, 145 | return_dict=return_dict, 146 | cache_position=cache_position, 147 | ) 148 | 149 | 150 | class SensorLLMStage1LlamaForCausalLM(BaseSensorLLM, LlamaForCausalLM): 151 | config_class = SensorLLMStage1Config 152 | 153 | def __init__(self, config): 154 | super(LlamaForCausalLM, self).__init__(config) 155 | self.model = SensorLLMStage1LlamaModel(config) 156 | 157 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 158 | 159 | # Initialize weights and apply final processing 160 | self.post_init() 161 | 162 | def get_model(self): 163 | return 
self.model 164 | 165 | def forward(
self, 167 | input_ids: torch.LongTensor = None, 168 | attention_mask: Optional[torch.Tensor] = None, 169 | past_key_values: Optional[List[torch.FloatTensor]] = None, 170 | inputs_embeds: Optional[torch.FloatTensor] = None, 171 | labels: Optional[torch.LongTensor] = None, 172 | use_cache: Optional[bool] = None, # * control whether to return past_key_values 173 | output_attentions: Optional[bool] = None, 174 | output_hidden_states: Optional[bool] = None, 175 | ts_token_ids: Optional[List[torch.Tensor]] = None, # B, L_ts 176 | ts_attention_mask: Optional[List[torch.Tensor]] = None, 177 | ts_tokenizer_state: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None, 178 | return_dict: Optional[bool] = None, 179 | cache_position: Optional[torch.LongTensor] = None, 180 | ) -> Union[Tuple, CausalLMOutputWithPast]: 181 | output_attentions = ( 182 | output_attentions 183 | if output_attentions is not None 184 | else self.config.output_attentions 185 | ) 186 | output_hidden_states = ( 187 | output_hidden_states 188 | if output_hidden_states is not None 189 | else self.config.output_hidden_states 190 | ) 191 | return_dict = ( 192 | return_dict if return_dict is not None else self.config.use_return_dict 193 | ) 194 | 195 | # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) 196 | outputs = self.model( 197 | input_ids=input_ids, 198 | attention_mask=attention_mask, 199 | past_key_values=past_key_values, 200 | inputs_embeds=inputs_embeds, 201 | use_cache=use_cache, 202 | output_attentions=output_attentions, 203 | output_hidden_states=output_hidden_states, 204 | return_dict=return_dict, 205 | ts_token_ids=ts_token_ids, 206 | ts_attention_mask=ts_attention_mask, 207 | ts_tokenizer_state=ts_tokenizer_state, 208 | cache_position=cache_position, 209 | ) 210 | 211 | hidden_states = outputs[0] 212 | logits = self.lm_head(hidden_states) 213 | 214 | loss = None 215 | if labels is not None: 216 | # Shift so that tokens < n predict n 217 | shift_logits = 
logits[..., :-1, :].contiguous() # * B, L, V 218 | shift_labels = labels[..., 1:].contiguous() # * B, L 219 | # Flatten the tokens 220 | loss_fct = CrossEntropyLoss() 221 | shift_logits = shift_logits.view(-1, self.config.vocab_size) 222 | shift_labels = shift_labels.view(-1) 223 | # Enable model/pipeline parallelism 224 | shift_labels = shift_labels.to(shift_logits.device) 225 | loss = loss_fct(shift_logits, shift_labels) 226 | 227 | if not return_dict: 228 | output = (logits,) + outputs[1:] 229 | return (loss,) + output if loss is not None else output 230 | 231 | return CausalLMOutputWithPast( 232 | loss=loss, 233 | logits=logits, 234 | past_key_values=outputs.past_key_values, 235 | hidden_states=outputs.hidden_states, 236 | attentions=outputs.attentions, 237 | ) 238 | 239 | def prepare_inputs_for_generation( 240 | self, 241 | input_ids, 242 | past_key_values=None, 243 | attention_mask=None, 244 | inputs_embeds=None, 245 | **kwargs, 246 | ): 247 | if past_key_values: 248 | input_ids = input_ids[:, -1:] 249 | 250 | # if `inputs_embeds` are passed, we only want to use them in the 1st generation step 251 | if inputs_embeds is not None and past_key_values is None: 252 | model_inputs = {"inputs_embeds": inputs_embeds} 253 | else: 254 | model_inputs = {"input_ids": input_ids} 255 | model_inputs.update( 256 | { 257 | "past_key_values": past_key_values, 258 | "use_cache": kwargs.get("use_cache"), 259 | "attention_mask": attention_mask, 260 | "ts_token_ids": kwargs.get("ts_token_ids", None), 261 | "ts_attention_mask": kwargs.get("ts_attention_mask", None), 262 | "ts_tokenizer_state": kwargs.get("ts_tokenizer_state", None), 263 | "cache_position": kwargs.get("cache_position", None), 264 | } 265 | ) 266 | return model_inputs 267 | 268 | 269 | AutoConfig.register("sensorllmstage1", SensorLLMStage1Config) 270 | AutoModelForCausalLM.register(SensorLLMStage1Config, SensorLLMStage1LlamaForCausalLM) 271 | 
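`SensorLLMStage1LlamaModel.forward` above replaces the `ts_token_id` placeholders between a start/end token pair with the projected time-series embeddings and validates the prompt template along the way. The following is a minimal, framework-free sketch of that splicing step. It assumes a single start/end pair and represents embeddings as plain Python lists instead of torch tensors. The function name `splice_ts_embeddings`, its parameters, and the token ids in the demo are hypothetical, and the `detach()`-based freezing controlled by `orig_embeds_params` is omitted.

```python
from typing import List, Sequence

Vector = List[float]


def splice_ts_embeddings(
    input_ids: Sequence[int],
    input_embeds: List[Vector],
    ts_embeds: List[Vector],
    start_id: int,
    end_id: int,
    ts_token_id: int,
) -> List[Vector]:
    """Replace the placeholder embeddings between each start/end pair with ts_embeds.

    Hypothetical helper that mirrors the checks in SensorLLMStage1LlamaModel.forward.
    """
    if len(input_ids) != len(input_embeds):
        raise ValueError("input_ids and input_embeds must have the same length")

    num_ts_tokens = len(ts_embeds)
    # Count every placeholder; each spliced segment must consume exactly num_ts_tokens.
    remaining = sum(1 for tok in input_ids if tok == ts_token_id)

    starts = [i for i, tok in enumerate(input_ids) if tok == start_id]
    num_ends = sum(1 for tok in input_ids if tok == end_id)
    if len(starts) != num_ends:
        raise ValueError(f"{len(starts)} start tokens but {num_ends} end tokens")

    out = list(input_embeds)  # shallow copy; the caller's list is not modified
    for start in starts:
        end = start + num_ts_tokens + 1
        if end >= len(input_ids) or input_ids[end] != end_id:
            raise ValueError(
                f"end token must follow the start token after {num_ts_tokens} positions"
            )
        # Same-length slice assignment: the sequence length never changes,
        # so positions computed before the loop stay valid.
        out[start + 1 : end] = ts_embeds
        remaining -= num_ts_tokens

    if remaining != 0:
        raise ValueError(f"placeholder count mismatch (remaining={remaining})")
    return out


if __name__ == "__main__":
    ids = [1, 10, 99, 99, 11, 2]  # BOS, start, 2 placeholders, end, EOS (made-up ids)
    embeds = [[float(i)] for i in range(len(ids))]
    spliced = splice_ts_embeddings(ids, embeds, [[100.0], [200.0]], 10, 11, 99)
    print(spliced)  # [[0.0], [1.0], [100.0], [200.0], [4.0], [5.0]]
```

Because each time-series segment replaces exactly `num_ts_tokens` placeholders, every sequence keeps its padded length. This is why the original code can reuse `start_token_positions` computed before the splice and can combine the batch with `torch.stack`.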
-------------------------------------------------------------------------------- /sensorllm/model/stage2_sensorllm.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Tuple, Union 2 | from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss 3 | from .utils import * 4 | from contextlib import nullcontext 5 | from transformers import ( 6 | AutoConfig, 7 | AutoModelForCausalLM, 8 | AutoModelForSequenceClassification, 9 | LlamaConfig, 10 | LlamaForCausalLM, 11 | LlamaForSequenceClassification 12 | ) 13 | 14 | from transformers.modeling_outputs import ( 15 | BaseModelOutputWithPast, 16 | CausalLMOutputWithPast, 17 | SequenceClassifierOutputWithPast 18 | ) 19 | 20 | import logging 21 | 22 | logger = logging.getLogger(__name__) 23 | 24 | 25 | class SensorLLMStage2Config(LlamaConfig): 26 | model_type = "sensorllmstage2" 27 | 28 | 29 | class SensorLLMStage2LlamaModel(BaseSensorLLMModel): 30 | config_class = SensorLLMStage2Config 31 | 32 | def __init__(self, config: LlamaConfig): 33 | super(SensorLLMStage2LlamaModel, self).__init__(config) 34 | 35 | def forward( 36 | self, 37 | input_ids: torch.LongTensor = None, # B, L 38 | attention_mask: Optional[torch.Tensor] = None, 39 | position_ids: Optional[torch.LongTensor] = None, 40 | past_key_values: Optional[List[torch.FloatTensor]] = None, 41 | inputs_embeds: Optional[torch.FloatTensor] = None, 42 | use_cache: Optional[bool] = None, 43 | output_attentions: Optional[bool] = None, 44 | output_hidden_states: Optional[bool] = None, 45 | mts_token_ids: Optional[torch.Tensor] = None, 46 | mts_attention_mask: Optional[torch.Tensor] = None, 47 | mts_tokenizer_state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, 48 | return_dict: Optional[bool] = None, 49 | ) -> Union[Tuple, BaseModelOutputWithPast]: 50 | # Check the dimensions of input_ids 51 | if input_ids.dim() != 2: 52 | raise ValueError(f"Expected input_ids to be a 2D tensor, but got 
{input_ids.dim()}D tensor") 53 | 54 | if mts_token_ids is not None and mts_token_ids.dim() != 3: 55 | raise ValueError(f"Expected mts_token_ids to be a 3D tensor (B, C, L_ts), but got {mts_token_ids.dim()}D tensor") 56 | 57 | orig_embeds_params = getattr(self, "orig_embeds_params", None) 58 | 59 | if inputs_embeds is None: 60 | inputs_embeds = self.embed_tokens(input_ids) 61 | 62 | pt_encoder_backbone = getattr(self, 'pt_encoder_backbone', None) 63 | 64 | if pt_encoder_backbone is not None and (input_ids.shape[1] != 1 or self.training) and mts_token_ids is not None: 65 | 66 | channel_num = mts_token_ids.size(1) 67 | with torch.no_grad() if self.fix_ts_encoder else nullcontext(): 68 | if self.fix_ts_encoder: 69 | self.pt_encoder_backbone.eval() 70 | ts_features = [] 71 | for ts_token_ids, ts_attention_mask in zip(mts_token_ids, mts_attention_mask): # * iterate over batch 72 | ts_feature = self.ts_embed(ts_token_ids, ts_attention_mask) 73 | if torch.any(torch.isnan(ts_feature)) or torch.any(torch.isinf(ts_feature)): 74 | raise ValueError("ts_feature has NaN or Inf values") 75 | ts_features.append(ts_feature) 76 | 77 | summed_ts_embeddings = [self.ts_proj(ts_feature) for ts_feature in ts_features] 78 | summed_ts_embeddings = torch.stack(summed_ts_embeddings) 79 | 80 | new_input_embeds = [] 81 | for cur_input_ids, cur_input_embeds, cur_ts_embeds in zip( 82 | input_ids, inputs_embeds, summed_ts_embeddings 83 | ): # * input_ids: B, L; input_embeds: B, L, D; summed_ts_embeddings: B, C, L_ts, D 84 | cur_ts_embeds = cur_ts_embeds.to( 85 | device=cur_input_embeds.device 86 | ) 87 | num_ts_tokens = cur_ts_embeds.shape[1] # * number of ts tokens 88 | 89 | if len(self.start_end_tokens) != channel_num: 90 | raise ValueError( 91 | f"The length of start_end_tokens ({len(self.start_end_tokens)}) and channel_num ({channel_num}) should be the same.") 92 | if len(self.start_end_tokens) != cur_ts_embeds.size(0): 93 | raise ValueError( 94 | f"The length of start_end_tokens 
({len(self.start_end_tokens)}) and cur_ts_embeds
({cur_ts_embeds.size(0)}) should be the same.") 95 | 96 | total_ts_token_count = (cur_input_ids == self.ts_backbone_config["ts_token_id"]).sum() 97 | for (start_token_id, end_token_id), channel_ebd in zip(self.start_end_tokens.items(), cur_ts_embeds): 98 | start_token_count = (cur_input_ids == start_token_id).sum() 99 | end_token_count = (cur_input_ids == end_token_id).sum() 100 | 101 | if start_token_count != end_token_count: 102 | raise ValueError( 103 | f"The number of {start_token_id} tokens ({start_token_count}) and {end_token_id} tokens ({end_token_count}) should be the same.") 104 | 105 | start_token_positions = torch.where(cur_input_ids == start_token_id)[0] 106 | 107 | for start_token_pos in start_token_positions: 108 | end_token_pos = start_token_pos + num_ts_tokens + 1 109 | total_ts_token_count -= num_ts_tokens 110 | 111 | if end_token_pos >= len(cur_input_ids) or cur_input_ids[end_token_pos] != end_token_id: 112 | raise ValueError( 113 | f"The end token '{end_token_id}' should follow the start token '{start_token_id}' after {num_ts_tokens} positions." 
) 115 | 116 | if orig_embeds_params is not None: # * will not update the original embeddings except for TS_START_TOKEN and TS_END_TOKEN 117 | cur_input_embeds = torch.cat( 118 | ( 119 | cur_input_embeds[:start_token_pos].detach(), 120 | cur_input_embeds[start_token_pos: start_token_pos + 1], 121 | channel_ebd, 122 | cur_input_embeds[end_token_pos: end_token_pos + 1], 123 | cur_input_embeds[end_token_pos + 1:].detach(), 124 | ), 125 | dim=0, 126 | ) 127 | else: 128 | cur_input_embeds = torch.cat( 129 | ( 130 | cur_input_embeds[:start_token_pos + 1], 131 | channel_ebd, 132 | cur_input_embeds[end_token_pos:], 133 | ), 134 | dim=0, 135 | ) 136 | if total_ts_token_count != 0: 137 | raise ValueError( 138 | f"The value of total_ts_token_count ({total_ts_token_count}) should be 0 after all time-series segments are replaced.") 139 | new_input_embeds.append(cur_input_embeds) 140 | inputs_embeds = torch.stack(new_input_embeds, dim=0) 141 | 142 | return super(SensorLLMStage2LlamaModel, self).forward( 143 | input_ids=None, 144 | attention_mask=attention_mask, 145 | position_ids=position_ids, 146 | past_key_values=past_key_values, 147 | inputs_embeds=inputs_embeds, 148 | use_cache=use_cache, 149 | output_attentions=output_attentions, 150 | output_hidden_states=output_hidden_states, 151 | return_dict=return_dict, 152 | ) 153 | 154 | 155 | class SensorLLMStage2LlamaForCausalLM(BaseSensorLLM, LlamaForCausalLM): 156 | config_class = SensorLLMStage2Config 157 | 158 | def __init__(self, config): 159 | super(LlamaForCausalLM, self).__init__(config) 160 | self.model = SensorLLMStage2LlamaModel(config) 161 | 162 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 163 | 164 | # Initialize weights and apply final processing 165 | self.post_init() 166 | 167 | def get_model(self): 168 | return self.model 169 | 170 | def forward( 171 | self, 172 | input_ids: torch.LongTensor = None, 173 | attention_mask: Optional[torch.Tensor] = None, 174 | past_key_values: 
Optional[List[torch.FloatTensor]] = None,
175 | inputs_embeds: Optional[torch.FloatTensor] = None, 176 | labels: Optional[torch.LongTensor] = None, 177 | use_cache: Optional[bool] = None, # * control whether to return past_key_values 178 | output_attentions: Optional[bool] = None, 179 | output_hidden_states: Optional[bool] = None, 180 | mts_token_ids: Optional[torch.Tensor] = None, 181 | mts_attention_mask: Optional[torch.Tensor] = None, 182 | mts_tokenizer_state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, 183 | return_dict: Optional[bool] = None, 184 | ) -> Union[Tuple, CausalLMOutputWithPast]: 185 | output_attentions = ( 186 | output_attentions 187 | if output_attentions is not None 188 | else self.config.output_attentions 189 | ) 190 | output_hidden_states = ( 191 | output_hidden_states 192 | if output_hidden_states is not None 193 | else self.config.output_hidden_states 194 | ) 195 | return_dict = ( 196 | return_dict if return_dict is not None else self.config.use_return_dict 197 | ) 198 | 199 | # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) 200 | outputs = self.model( 201 | input_ids=input_ids, 202 | attention_mask=attention_mask, 203 | past_key_values=past_key_values, 204 | inputs_embeds=inputs_embeds, 205 | use_cache=use_cache, 206 | output_attentions=output_attentions, 207 | output_hidden_states=output_hidden_states, 208 | return_dict=return_dict, 209 | mts_token_ids=mts_token_ids, 210 | mts_attention_mask=mts_attention_mask, 211 | mts_tokenizer_state=mts_tokenizer_state 212 | ) 213 | 214 | hidden_states = outputs[0] 215 | logits = self.lm_head(hidden_states) 216 | 217 | loss = None 218 | if labels is not None: 219 | # Shift so that tokens < n predict n 220 | shift_logits = logits[..., :-1, :].contiguous() # * B, L, V 221 | shift_labels = labels[..., 1:].contiguous() # * B, L 222 | # Flatten the tokens 223 | loss_fct = CrossEntropyLoss() 224 | shift_logits = shift_logits.view(-1, self.config.vocab_size) 225 | shift_labels = shift_labels.view(-1) 226 | # 
Enable model/pipeline parallelism 227 | shift_labels = shift_labels.to(shift_logits.device) 228 | loss = loss_fct(shift_logits, shift_labels) 229 | 230 | if not return_dict: 231 | output = (logits,) + outputs[1:] 232 | return (loss,) + output if loss is not None else output 233 | 234 | return CausalLMOutputWithPast( 235 | loss=loss, 236 | logits=logits, 237 | past_key_values=outputs.past_key_values, 238 | hidden_states=outputs.hidden_states, 239 | attentions=outputs.attentions, 240 | ) 241 | 242 | def prepare_inputs_for_generation( 243 | self, 244 | input_ids, 245 | past_key_values=None, 246 | attention_mask=None, 247 | inputs_embeds=None, 248 | **kwargs, 249 | ): 250 | if past_key_values: 251 | input_ids = input_ids[:, -1:] 252 | 253 | # if `inputs_embeds` are passed, we only want to use them in the 1st generation step 254 | if inputs_embeds is not None and past_key_values is None: 255 | model_inputs = {"inputs_embeds": inputs_embeds} 256 | else: 257 | model_inputs = {"input_ids": input_ids} 258 | 259 | model_inputs.update( 260 | { 261 | "past_key_values": past_key_values, 262 | "use_cache": kwargs.get("use_cache"), 263 | "attention_mask": attention_mask, 264 | "mts_token_ids": kwargs.get("mts_token_ids", None), 265 | "mts_attention_mask": kwargs.get("mts_attention_mask", None), 266 | "mts_tokenizer_state": kwargs.get("mts_tokenizer_state", None), 267 | } 268 | ) 269 | model_inputs.pop("cache_position", None) # default avoids a KeyError: the key is never added above, and the stage-2 forward() takes no cache_position 270 | return model_inputs 271 | 272 | 273 | class SensorLLMStage2LlamaForSequenceClassification(BaseSensorLLM, LlamaForSequenceClassification): 274 | config_class = SensorLLMStage2Config 275 | 276 | def __init__(self, config): 277 | super(LlamaForSequenceClassification, self).__init__(config) 278 | self.num_labels = config.num_labels 279 | self.model = SensorLLMStage2LlamaModel(config) 280 | self.score = 
nn.Linear(config.hidden_size, self.num_labels, bias=False) 281 | 282 | # Initialize weights and apply final processing 283 | self.post_init() 284 | 285 | def
get_model(self): 286 | return self.model 287 | 288 | def forward( 289 | self, 290 | input_ids: torch.LongTensor = None, 291 | attention_mask: Optional[torch.Tensor] = None, 292 | position_ids: Optional[torch.LongTensor] = None, 293 | past_key_values: Optional[List[torch.FloatTensor]] = None, 294 | inputs_embeds: Optional[torch.FloatTensor] = None, 295 | labels: Optional[torch.LongTensor] = None, 296 | use_cache: Optional[bool] = None, # * control whether to return past_key_values 297 | output_attentions: Optional[bool] = None, 298 | output_hidden_states: Optional[bool] = None, 299 | mts_token_ids: Optional[torch.Tensor] = None, 300 | mts_attention_mask: Optional[torch.Tensor] = None, 301 | mts_tokenizer_state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, 302 | return_dict: Optional[bool] = None, 303 | ) -> Union[Tuple, SequenceClassifierOutputWithPast]: 304 | r""" 305 | labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): 306 | Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., 307 | config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If 308 | `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
309 | """ 310 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 311 | 312 | transformer_outputs = self.model( 313 | input_ids=input_ids, 314 | attention_mask=attention_mask, 315 | position_ids=position_ids, 316 | past_key_values=past_key_values, 317 | inputs_embeds=inputs_embeds, 318 | use_cache=use_cache, 319 | output_attentions=output_attentions, 320 | output_hidden_states=output_hidden_states, 321 | return_dict=return_dict, 322 | mts_token_ids=mts_token_ids, 323 | mts_attention_mask=mts_attention_mask, 324 | mts_tokenizer_state=mts_tokenizer_state 325 | ) 326 | hidden_states = transformer_outputs[0] 327 | logits = self.score(hidden_states) 328 | 329 | if input_ids is not None: 330 | batch_size = input_ids.shape[0] 331 | else: 332 | batch_size = inputs_embeds.shape[0] 333 | 334 | if self.config.pad_token_id is None and batch_size != 1: 335 | raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") 336 | if self.config.pad_token_id is None: 337 | sequence_lengths = -1 338 | else: 339 | if input_ids is not None: 340 | # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility 341 | sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 342 | sequence_lengths = sequence_lengths % input_ids.shape[-1] 343 | sequence_lengths = sequence_lengths.to(logits.device) 344 | else: 345 | sequence_lengths = -1 346 | 347 | pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] 348 | 349 | loss = None 350 | if labels is not None: 351 | labels = labels.to(logits.device) 352 | if self.config.problem_type is None: 353 | if self.num_labels == 1: 354 | self.config.problem_type = "regression" 355 | elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): 356 | self.config.problem_type = "single_label_classification" 357 | else: 358 | self.config.problem_type = "multi_label_classification" 359 | 360 | if 
self.config.problem_type == "regression": 361 | loss_fct = MSELoss() 362 | if self.num_labels == 1: 363 | loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) 364 | else: 365 | loss = loss_fct(pooled_logits, labels) 366 | elif self.config.problem_type == "single_label_classification": 367 | loss_fct = CrossEntropyLoss() 368 | loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) 369 | elif self.config.problem_type == "multi_label_classification": 370 | loss_fct = BCEWithLogitsLoss() 371 | loss = loss_fct(pooled_logits, labels) 372 | if not return_dict: 373 | output = (pooled_logits,) + transformer_outputs[1:] 374 | return ((loss,) + output) if loss is not None else output 375 | 376 | return SequenceClassifierOutputWithPast( 377 | loss=loss, 378 | logits=pooled_logits, 379 | past_key_values=transformer_outputs.past_key_values, 380 | hidden_states=transformer_outputs.hidden_states, 381 | attentions=transformer_outputs.attentions, 382 | ) 383 | 384 | 385 | AutoConfig.register("sensorllmstage2", SensorLLMStage2Config) 386 | AutoModelForCausalLM.register(SensorLLMStage2Config, SensorLLMStage2LlamaForCausalLM) 387 | AutoModelForSequenceClassification.register(SensorLLMStage2Config, SensorLLMStage2LlamaForSequenceClassification) -------------------------------------------------------------------------------- /sensorllm/model/ts_backbone.yaml: -------------------------------------------------------------------------------- 1 | ts_backbone_type: "chronos" 2 | 3 | chronos_model: 4 | name: "chronos-t5-large" 5 | encoder_output_dim: 1024 6 | projection_hidden_layer: 2 7 | projection_hidden_dim: [2048, 3072] 8 | last_token: true 9 | 10 | dropout_rate: 0.3 11 | default_ts_token: "" 12 | 13 | usc-had: 14 | id2label: 15 | 0: "1. Walking Forward" 16 | 1: "2. Walking Left" 17 | 2: "3. Walking Right" 18 | 3: "4. Walking Upstairs" 19 | 4: "5. Walking Downstairs" 20 | 5: "6. Running Forward" 21 | 6: "7. Jumping" 22 | 7: "8. Sitting" 23 | 8: "9. 
Standing" 24 | 9: "10. Sleeping" 25 | 10: "11. Elevator Up" 26 | 11: "12. Elevator Down" 27 | num_labels: 12 28 | channel_num: 6 29 | sample_rate: 100 30 | default_x_acc_start_token: "" 31 | default_x_acc_end_token: "" 32 | default_y_acc_start_token: "" 33 | default_y_acc_end_token: "" 34 | default_z_acc_start_token: "" 35 | default_z_acc_end_token: "" 36 | default_x_gyro_start_token: "" 37 | default_x_gyro_end_token: "" 38 | default_y_gyro_start_token: "" 39 | default_y_gyro_end_token: "" 40 | default_z_gyro_start_token: "" 41 | default_z_gyro_end_token: "" 42 | 43 | uci: 44 | id2label: 45 | 0: "Walking" 46 | 1: "Walking upstairs" 47 | 2: "Walking downstairs" 48 | 3: "Sitting" 49 | 4: "Standing" 50 | 5: "Laying" 51 | num_labels: 6 52 | channel_num: 6 53 | sample_rate: 50 54 | default_x_acc_start_token: "" 55 | default_x_acc_end_token: "" 56 | default_y_acc_start_token: "" 57 | default_y_acc_end_token: "" 58 | default_z_acc_start_token: "" 59 | default_z_acc_end_token: "" 60 | default_x_gyro_start_token: "" 61 | default_x_gyro_end_token: "" 62 | default_y_gyro_start_token: "" 63 | default_y_gyro_end_token: "" 64 | default_z_gyro_start_token: "" 65 | default_z_gyro_end_token: "" 66 | 67 | mhealth: 68 | id2label: 69 | 0: 'Standing still (1 min)' 70 | 1: 'Sitting and relaxing (1 min)' 71 | 2: 'Lying down (1 min)' 72 | 3: 'Walking (1 min)' 73 | 4: 'Climbing stairs (1 min)' 74 | 5: 'Waist bends forward (20x)' 75 | 6: 'Frontal elevation of arms (20x)' 76 | 7: 'Knees bending (crouching) (20x)' 77 | 8: 'Cycling (1 min)' 78 | 9: 'Jogging (1 min)' 79 | 10: 'Running (1 min)' 80 | 11: 'Jump front & back (20x)' 81 | num_labels: 12 82 | channel_num: 15 83 | sample_rate: 50 84 | default_chest_x_acc_start_token: "" 85 | default_chest_x_acc_end_token: "" 86 | default_chest_y_acc_start_token: "" 87 | default_chest_y_acc_end_token: "" 88 | default_chest_z_acc_start_token: "" 89 | default_chest_z_acc_end_token: "" 90 | default_left_ankle_x_acc_start_token: "" 91 | 
default_left_ankle_x_acc_end_token: "" 92 | default_left_ankle_y_acc_start_token: "" 93 | default_left_ankle_y_acc_end_token: "" 94 | default_left_ankle_z_acc_start_token: "" 95 | default_left_ankle_z_acc_end_token: "" 96 | default_left_ankle_x_gyro_start_token: "" 97 | default_left_ankle_x_gyro_end_token: "" 98 | default_left_ankle_y_gyro_start_token: "" 99 | default_left_ankle_y_gyro_end_token: "" 100 | default_left_ankle_z_gyro_start_token: "" 101 | default_left_ankle_z_gyro_end_token: "" 102 | default_right_lower_arm_x_acc_start_token: "" 103 | default_right_lower_arm_x_acc_end_token: "" 104 | default_right_lower_arm_y_acc_start_token: "" 105 | default_right_lower_arm_y_acc_end_token: "" 106 | default_right_lower_arm_z_acc_start_token: "" 107 | default_right_lower_arm_z_acc_end_token: "" 108 | default_right_lower_arm_x_gyro_start_token: "" 109 | default_right_lower_arm_x_gyro_end_token: "" 110 | default_right_lower_arm_y_gyro_start_token: "" 111 | default_right_lower_arm_y_gyro_end_token: "" 112 | default_right_lower_arm_z_gyro_start_token: "" 113 | default_right_lower_arm_z_gyro_end_token: "" 114 | 115 | pamap50: 116 | id2label: 117 | 0: 'lying' 118 | 1: 'sitting' 119 | 2: 'standing' 120 | 3: 'walking' 121 | 4: 'running' 122 | 5: 'cycling' 123 | 6: 'Nordic walking' 124 | 7: 'ascending stairs' 125 | 8: 'descending stairs' 126 | 9: 'vacuum cleaning' 127 | 10: 'ironing' 128 | 11: 'rope jumping' 129 | num_labels: 12 130 | channel_num: 27 131 | sample_rate: 50 132 | default_chest_x_acc_start_token: "" 133 | default_chest_x_acc_end_token: "" 134 | default_chest_y_acc_start_token: "" 135 | default_chest_y_acc_end_token: "" 136 | default_chest_z_acc_start_token: "" 137 | default_chest_z_acc_end_token: "" 138 | default_chest_x_gyro_start_token: "" 139 | default_chest_x_gyro_end_token: "" 140 | default_chest_y_gyro_start_token: "" 141 | default_chest_y_gyro_end_token: "" 142 | default_chest_z_gyro_start_token: "" 143 | default_chest_z_gyro_end_token: "" 144 | 
default_chest_x_mag_start_token: "" 145 | default_chest_x_mag_end_token: "" 146 | default_chest_y_mag_start_token: "" 147 | default_chest_y_mag_end_token: "" 148 | default_chest_z_mag_start_token: "" 149 | default_chest_z_mag_end_token: "" 150 | default_ankle_x_acc_start_token: "" 151 | default_ankle_x_acc_end_token: "" 152 | default_ankle_y_acc_start_token: "" 153 | default_ankle_y_acc_end_token: "" 154 | default_ankle_z_acc_start_token: "" 155 | default_ankle_z_acc_end_token: "" 156 | default_ankle_x_gyro_start_token: "" 157 | default_ankle_x_gyro_end_token: "" 158 | default_ankle_y_gyro_start_token: "" 159 | default_ankle_y_gyro_end_token: "" 160 | default_ankle_z_gyro_start_token: "" 161 | default_ankle_z_gyro_end_token: "" 162 | default_ankle_x_mag_start_token: "" 163 | default_ankle_x_mag_end_token: "" 164 | default_ankle_y_mag_start_token: "" 165 | default_ankle_y_mag_end_token: "" 166 | default_ankle_z_mag_start_token: "" 167 | default_ankle_z_mag_end_token: "" 168 | default_hand_x_acc_start_token: "" 169 | default_hand_x_acc_end_token: "" 170 | default_hand_y_acc_start_token: "" 171 | default_hand_y_acc_end_token: "" 172 | default_hand_z_acc_start_token: "" 173 | default_hand_z_acc_end_token: "" 174 | default_hand_x_gyro_start_token: "" 175 | default_hand_x_gyro_end_token: "" 176 | default_hand_y_gyro_start_token: "" 177 | default_hand_y_gyro_end_token: "" 178 | default_hand_z_gyro_start_token: "" 179 | default_hand_z_gyro_end_token: "" 180 | default_hand_x_mag_start_token: "" 181 | default_hand_x_mag_end_token: "" 182 | default_hand_y_mag_start_token: "" 183 | default_hand_y_mag_end_token: "" 184 | default_hand_z_mag_start_token: "" 185 | default_hand_z_mag_end_token: "" 186 | 187 | pamap: 188 | id2label: 189 | 0: 'lying' 190 | 1: 'sitting' 191 | 2: 'standing' 192 | 3: 'walking' 193 | 4: 'running' 194 | 5: 'cycling' 195 | 6: 'Nordic walking' 196 | 7: 'ascending stairs' 197 | 8: 'descending stairs' 198 | 9: 'vacuum cleaning' 199 | 10: 'ironing' 200 | 11: 
'rope jumping' 201 | num_labels: 12 202 | channel_num: 27 203 | sample_rate: 100 204 | default_chest_x_acc_start_token: "" 205 | default_chest_x_acc_end_token: "" 206 | default_chest_y_acc_start_token: "" 207 | default_chest_y_acc_end_token: "" 208 | default_chest_z_acc_start_token: "" 209 | default_chest_z_acc_end_token: "" 210 | default_chest_x_gyro_start_token: "" 211 | default_chest_x_gyro_end_token: "" 212 | default_chest_y_gyro_start_token: "" 213 | default_chest_y_gyro_end_token: "" 214 | default_chest_z_gyro_start_token: "" 215 | default_chest_z_gyro_end_token: "" 216 | default_chest_x_mag_start_token: "" 217 | default_chest_x_mag_end_token: "" 218 | default_chest_y_mag_start_token: "" 219 | default_chest_y_mag_end_token: "" 220 | default_chest_z_mag_start_token: "" 221 | default_chest_z_mag_end_token: "" 222 | default_ankle_x_acc_start_token: "" 223 | default_ankle_x_acc_end_token: "" 224 | default_ankle_y_acc_start_token: "" 225 | default_ankle_y_acc_end_token: "" 226 | default_ankle_z_acc_start_token: "" 227 | default_ankle_z_acc_end_token: "" 228 | default_ankle_x_gyro_start_token: "" 229 | default_ankle_x_gyro_end_token: "" 230 | default_ankle_y_gyro_start_token: "" 231 | default_ankle_y_gyro_end_token: "" 232 | default_ankle_z_gyro_start_token: "" 233 | default_ankle_z_gyro_end_token: "" 234 | default_ankle_x_mag_start_token: "" 235 | default_ankle_x_mag_end_token: "" 236 | default_ankle_y_mag_start_token: "" 237 | default_ankle_y_mag_end_token: "" 238 | default_ankle_z_mag_start_token: "" 239 | default_ankle_z_mag_end_token: "" 240 | default_hand_x_acc_start_token: "" 241 | default_hand_x_acc_end_token: "" 242 | default_hand_y_acc_start_token: "" 243 | default_hand_y_acc_end_token: "" 244 | default_hand_z_acc_start_token: "" 245 | default_hand_z_acc_end_token: "" 246 | default_hand_x_gyro_start_token: "" 247 | default_hand_x_gyro_end_token: "" 248 | default_hand_y_gyro_start_token: "" 249 | default_hand_y_gyro_end_token: "" 250 | 
default_hand_z_gyro_start_token: "" 251 | default_hand_z_gyro_end_token: "" 252 | default_hand_x_mag_start_token: "" 253 | default_hand_x_mag_end_token: "" 254 | default_hand_y_mag_start_token: "" 255 | default_hand_y_mag_end_token: "" 256 | default_hand_z_mag_start_token: "" 257 | default_hand_z_mag_end_token: "" 258 | 259 | capture24: 260 | id2label: 261 | 0: 'sleep' 262 | 1: 'sitting' 263 | 2: 'household-chores' 264 | 3: 'walking' 265 | 4: 'vehicle' 266 | 5: 'bicycling' 267 | 6: 'mixed-activity' 268 | 7: 'standing' 269 | 8: 'manual-work' 270 | 9: 'sports' 271 | num_labels: 10 272 | channel_num: 3 273 | sample_rate: 50 274 | default_x_acc_start_token: "" 275 | default_x_acc_end_token: "" 276 | default_y_acc_start_token: "" 277 | default_y_acc_end_token: "" 278 | default_z_acc_start_token: "" 279 | default_z_acc_end_token: "" -------------------------------------------------------------------------------- /sensorllm/model/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import StoppingCriteria 3 | import torch.nn as nn 4 | import math 5 | from sensorllm.utils import * 6 | from .chronos_model import * 7 | from transformers import ( 8 | LlamaConfig, 9 | LlamaModel, 10 | PreTrainedModel 11 | ) 12 | import logging 13 | 14 | logger = logging.getLogger(__name__) 15 | 16 | DEFAULT_PAD_TOKEN = "[PAD]" 17 | DEFAULT_EOS_TOKEN = "" 18 | DEFAULT_BOS_TOKEN = "" 19 | DEFAULT_UNK_TOKEN = "" 20 | 21 | 22 | class BaseSensorLLMModel(LlamaModel): 23 | def __init__(self, config: LlamaConfig): 24 | super(BaseSensorLLMModel, self).__init__(config) 25 | 26 | current_dir = os.path.dirname(os.path.abspath(__file__)) 27 | print(f"current dir: {current_dir}") 28 | ts_config_addr = os.path.join(current_dir, "ts_backbone.yaml") 29 | self.ts_backbone_config = cfg_from_yaml_file(ts_config_addr) 30 | 31 | # self.ts_backbone_config["ts_backbone_output_dimension"] = self.config.hidden_size 32 | 33 | logger.warning( 34 
| f"The hidden size of LLM is {self.config.hidden_size}." 35 | ) 36 | 37 | backbone_output_dim = self.ts_backbone_config['chronos_model']['encoder_output_dim'] 38 | logger.warning( 39 | f"{self.ts_backbone_config['chronos_model']['name']} output dim: {self.ts_backbone_config['chronos_model']['encoder_output_dim']}.") 40 | logger.warning( 41 | f"Use {self.ts_backbone_config['chronos_model']['projection_hidden_layer']} projection hidden layers.") 42 | if self.ts_backbone_config['chronos_model']['projection_hidden_layer'] > 0: 43 | # Add projection layer with linear layers and GELU activation 44 | projection_layers = [] 45 | last_dim = backbone_output_dim 46 | for i in range(self.ts_backbone_config['chronos_model']['projection_hidden_layer']): 47 | projection_layers.append( 48 | nn.Linear(last_dim, self.ts_backbone_config['chronos_model']["projection_hidden_dim"][i])) 49 | projection_layers.append(nn.GELU()) 50 | last_dim = self.ts_backbone_config['chronos_model']["projection_hidden_dim"][i] 51 | 52 | projection_layers.append(nn.Linear(last_dim, self.config.hidden_size)) 53 | self.ts_proj = nn.Sequential(*projection_layers) 54 | logger.warning( 55 | f"Each layer with {self.ts_backbone_config['chronos_model']['projection_hidden_dim']} hidden units.") 56 | else: 57 | # Single layer 58 | self.ts_proj = nn.Linear(backbone_output_dim, self.config.hidden_size) 59 | logger.warning(f"TS projector output dim: {self.config.hidden_size}.") 60 | 61 | self.fix_llm = False 62 | self.fix_ts_encoder = False 63 | 64 | def load_pt_encoder_backbone_checkpoint(self, checkpoint_path=None, tc=None, torch_dtype=None): 65 | logger.warning(f"Loading default pt_encoder_backbone_ckpt ...") 66 | pipeline = ChronosPipeline.from_pretrained( 67 | self.config.pt_encoder_backbone_ckpt if checkpoint_path is None else checkpoint_path, 68 | device_map=self.device, 69 | tc="MeanScaleUniformBins" if tc is None else tc, 70 | torch_dtype=torch.float32 if torch_dtype is None else torch_dtype, 71 | ) 72 | 
self.pt_encoder_backbone = pipeline.model 73 | self.pt_encoder_backbone.to(self.device) 74 | 75 | def load_start_end_tokens(self, dataset=None): 76 | logger.warning(f"Loading start_end_tokens dict for {dataset} dataset ...") 77 | dataset_config = self.ts_backbone_config[dataset] 78 | self.start_end_tokens = {} 79 | for key in dataset_config: 80 | if key.endswith('_start_token_id'): 81 | end_token_key = key.replace('_start_token_id', '_end_token_id') 82 | 83 | assert end_token_key in dataset_config 84 | self.start_end_tokens[dataset_config[key]] = dataset_config[end_token_key] 85 | 86 | @torch.no_grad() 87 | def ts_embed( 88 | self, 89 | token_ids: torch.Tensor, 90 | attention_mask: torch.Tensor 91 | ) -> torch.Tensor: 92 | """ 93 | Get encoder embeddings for time series that have already been tokenized. 94 | 95 | Parameters 96 | ---------- 97 | token_ids 98 | Chronos token ids, shaped (batch_size, context_length) or 99 | (batch_size, context_length + 1) if an EOS token was appended. 100 | attention_mask 101 | Tensor with the same shape as ``token_ids``; positions set to 0 102 | (e.g. the left-padding used for ``torch.nan`` values) are ignored. 103 | 104 | Returns 105 | ------- 106 | embeddings 107 | The encoder embeddings, shaped (batch_size, context_length, d_model) 108 | or (batch_size, context_length + 1, d_model), where the extra 1 is 109 | for EOS. Unlike the Chronos pipeline's ``embed``, no tokenizer state 110 | (e.g. the mean scale of each series) is returned; the forward methods 111 | receive it separately as ``mts_tokenizer_state``. 112 | Gradients are disabled by the ``@torch.no_grad()`` decorator.
113 | """ 114 | 115 | embeddings = self.pt_encoder_backbone.encode( 116 | input_ids=token_ids, 117 | attention_mask=attention_mask, 118 | ) 119 | # if str(self.model.device) == 'cuda:1': 120 | # print("3", embeddings) 121 | 122 | return embeddings 123 | 124 | 125 | 126 | class BaseSensorLLM(PreTrainedModel): 127 | def initialize_tokenizer_ts_backbone_config_wo_embedding(self, tokenizer, dataset): 128 | # * Called for stage 2, for inference, or for inference without pre-training; assumes the tokenizer already has the time-series tokens 129 | ts_backbone_config = self.get_model().ts_backbone_config 130 | 131 | default_ts_token = ts_backbone_config["default_ts_token"] # 132 | # print(tokenizer.convert_tokens_to_ids([default_ts_token, "", ""])) 133 | 134 | tokenizer.add_tokens([default_ts_token], special_tokens=True) 135 | 136 | # * assert tokenizer has the default_ts_token 137 | ts_backbone_config["ts_token_id"] = tokenizer.convert_tokens_to_ids([default_ts_token])[0] 138 | 139 | if dataset not in ts_backbone_config: 140 | raise ValueError(f"Cannot find {dataset} in ts_backbone.yaml file.") 141 | 142 | dataset_config = ts_backbone_config[dataset] 143 | 144 | token_keys = [key for key in dataset_config.keys() if key.startswith('default_') and key.endswith('_token')] 145 | assert len(token_keys) == dataset_config["channel_num"]*2, f"len(token_keys) ({len(token_keys)}) != channel_num*2 ({dataset_config['channel_num'] * 2})"
146 | tokenizer.add_tokens([dataset_config[token_key] for token_key in token_keys], special_tokens=True) 147 | 148 | for token_key in token_keys: 149 | token_id_key = token_key.replace('default_', '').replace('_token', '_token_id') 150 | dataset_config[token_id_key] = tokenizer.convert_tokens_to_ids([dataset_config[token_key]])[0] 151 | 152 | special_tokens_dict = dict() 153 | if tokenizer.pad_token is None: 154 | special_tokens_dict["pad_token"] = DEFAULT_PAD_TOKEN 155 | if tokenizer.eos_token is None: 156 | special_tokens_dict["eos_token"] = DEFAULT_EOS_TOKEN 157 | if tokenizer.bos_token is None: 158 | special_tokens_dict["bos_token"] = DEFAULT_BOS_TOKEN 159 | if tokenizer.unk_token is None: 160 | special_tokens_dict["unk_token"] = DEFAULT_UNK_TOKEN 161 | 162 | tokenizer.add_special_tokens(special_tokens_dict) 163 | 164 | def initialize_tokenizer_ts_backbone_config( 165 | self, tokenizer, device, fix_llm=True, dataset='usc-had' 166 | ): 167 | 168 | ts_backbone_config = self.get_model().ts_backbone_config 169 | 170 | default_ts_token = ts_backbone_config["default_ts_token"] # 171 | 172 | tokenizer.add_tokens( 173 | [default_ts_token], special_tokens=True 174 | ) # * no need to update embed since it will be replaced 175 | 176 | ts_backbone_config["ts_token_id"] = tokenizer.convert_tokens_to_ids([default_ts_token])[0] 177 | 178 | if dataset not in ts_backbone_config: 179 | raise ValueError(f"Cannot find {dataset} in ts_backbone.yaml file.") 180 | 181 | dataset_config = ts_backbone_config[dataset] 182 | 183 | token_keys = [key for key in dataset_config.keys() if key.startswith('default_') and key.endswith('_token')] 184 | assert len(token_keys) == dataset_config["channel_num"]*2, f"len(token_keys) ({len(token_keys)}) != channel_num*2 ({dataset_config['channel_num'] * 2})"
185 | 186 | num_new_tokens = tokenizer.add_tokens([dataset_config[token_key] for token_key in token_keys], special_tokens=True) 187 | 188 | special_tokens_dict = dict() 189 | if tokenizer.pad_token is None: 190 | special_tokens_dict["pad_token"] = DEFAULT_PAD_TOKEN 191 | if tokenizer.eos_token is None: 192 | special_tokens_dict["eos_token"] = DEFAULT_EOS_TOKEN 193 | if tokenizer.bos_token is None: 194 | special_tokens_dict["bos_token"] = DEFAULT_BOS_TOKEN 195 | if tokenizer.unk_token is None: 196 | special_tokens_dict["unk_token"] = DEFAULT_UNK_TOKEN 197 | 198 | num_new_tokens += tokenizer.add_special_tokens(special_tokens_dict) 199 | 200 | self.resize_token_embeddings( 201 | len(tokenizer) 202 | ) # ! resize_token_embeddings will make the tokens trainable again 203 | 204 | for token_key in token_keys: 205 | token_id_key = token_key.replace('default_', '').replace('_token', '_token_id') 206 | dataset_config[token_id_key] = tokenizer.convert_tokens_to_ids([dataset_config[token_key]])[0] 207 | 208 | if num_new_tokens > 0: 209 | # Initialize the new tokens' input embeddings (output embeddings are handled below) 210 | print("Calculate the average of the input embedding as the initialization value of the new token...") 211 | input_embeddings = self.get_input_embeddings().weight.data 212 | 213 | # Average the input embeddings of all tokens that existed before the new ones were added 214 | input_embeddings_avg = input_embeddings[:-num_new_tokens].mean( 215 | dim=0, keepdim=True 216 | ) 217 | 218 | # Initialize the embeddings of the new tokens to the average of the previous tokens 219 | input_embeddings[-num_new_tokens:] = input_embeddings_avg 220 | 221 | if hasattr(self, 'get_output_embeddings') and callable(getattr(self, 'get_output_embeddings')): 222 | output_embeddings = self.get_output_embeddings() 223 | if output_embeddings is not None: 224 | print("Calculate the average of the output embedding as the initialization value of the new token...") 225 |
output_embeddings = self.get_output_embeddings().weight.data 226 | output_embeddings_avg = output_embeddings[:-num_new_tokens].mean( 227 | dim=0, keepdim=True 228 | ) 229 | output_embeddings[-num_new_tokens:] = output_embeddings_avg 230 | else: 231 | print( 232 | "Output embeddings not available.") 233 | 234 | # Pad the embedding matrix to a multiple of 8 rows to improve performance 235 | T, _ = self.get_input_embeddings().weight.shape 236 | self.resize_token_embeddings(int(8 * math.ceil(T / 8.0))) 237 | 238 | # Always train the input embeddings; whether the output embeddings are trained depends on fix_llm 239 | for p in self.get_input_embeddings().parameters(): 240 | p.requires_grad = True 241 | 242 | if fix_llm: 243 | # Save original input embeddings 244 | self.get_model().orig_embeds_params = [ 245 | self.get_input_embeddings().weight.data.clone().to(device=device) 246 | ] # * only tuning the new embeddings 247 | 248 | # Freeze the output embeddings if the model has them 249 | if hasattr(self, 'get_output_embeddings') and callable(getattr(self, 'get_output_embeddings')): 250 | output_embeddings = self.get_output_embeddings() 251 | if output_embeddings is not None: 252 | for p in output_embeddings.parameters(): 253 | p.requires_grad = False 254 | print("Setting output embeddings fixed.") 255 | else: 256 | print("Output embeddings not available.") 257 | print(f"Setting {num_new_tokens} new tokens' input embeddings trainable.") 258 | else: 259 | self.get_model().orig_embeds_params = None 260 | 261 | # Try to make output embeddings trainable if the method exists 262 | if hasattr(self, 'get_output_embeddings') and callable(getattr(self, 'get_output_embeddings')): 263 | output_embeddings = self.get_output_embeddings() 264 | if output_embeddings is not None: 265 | for p in output_embeddings.parameters(): 266 | p.requires_grad = True 267 | print("Setting output embeddings and all input embeddings trainable.") 268 | else: 269 | print("Setting all input embeddings
trainable.") 270 | else: 271 | print("Output embeddings not available. Setting all input embeddings trainable.") 272 | 273 | 274 | class KeywordsStoppingCriteria(StoppingCriteria): 275 | def __init__(self, keywords, tokenizer, input_ids): 276 | self.keywords = keywords 277 | self.keyword_ids = [tokenizer(keyword).input_ids for keyword in keywords] 278 | self.keyword_ids = [keyword_id[0] for keyword_id in self.keyword_ids if 279 | type(keyword_id) is list and len(keyword_id) == 1] 280 | self.tokenizer = tokenizer 281 | self.start_len = None 282 | self.input_ids = input_ids 283 | 284 | def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: 285 | if self.start_len is None: 286 | self.start_len = self.input_ids.shape[1] 287 | else: 288 | for keyword_id in self.keyword_ids: 289 | if output_ids[0, -1] == keyword_id: 290 | return True 291 | outputs = self.tokenizer.batch_decode(output_ids[:, self.start_len:], skip_special_tokens=True)[0] 292 | for keyword in self.keywords: 293 | if keyword in outputs: 294 | return True 295 | return False 296 | -------------------------------------------------------------------------------- /sensorllm/train/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zechenli03/SensorLLM/1f4142a30f452721e943771190fe1dade3337249/sensorllm/train/__init__.py -------------------------------------------------------------------------------- /sensorllm/train/llama_flash_attn_monkey_patch.py: -------------------------------------------------------------------------------- 1 | # Adopted from https://github.com/lm-sys/FastChat. 
Below is the original copyright: 2 | import warnings 3 | from typing import Optional, Tuple 4 | 5 | import torch 6 | from flash_attn import __version__ as flash_attn_version 7 | from flash_attn.bert_padding import pad_input, unpad_input 8 | from flash_attn.flash_attn_interface import ( 9 | flash_attn_func, 10 | flash_attn_varlen_kvpacked_func, 11 | ) 12 | from transformers.models.llama.modeling_llama import ( 13 | LlamaAttention, 14 | LlamaModel, 15 | rotate_half, 16 | ) 17 | 18 | 19 | def apply_rotary_pos_emb(q, k, cos_sin, position_ids): 20 | gather_indices = position_ids[:, :, None, None] # [bsz, seq_len, 1, 1] 21 | gather_indices = gather_indices.repeat( 22 | 1, 1, cos_sin[0].shape[1], cos_sin[0].shape[3] 23 | ) 24 | bsz = gather_indices.shape[0] 25 | cos, sin = ( 26 | torch.gather(x.transpose(1, 2).repeat(bsz, 1, 1, 1), 1, gather_indices) 27 | for x in cos_sin 28 | ) 29 | q, k = ((x * cos) + (rotate_half(x) * sin) for x in (q, k)) 30 | return q, k 31 | 32 | 33 | def forward( 34 | self, 35 | hidden_states: torch.Tensor, 36 | attention_mask: Optional[torch.Tensor] = None, 37 | position_ids: Optional[torch.Tensor] = None, 38 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 39 | output_attentions: bool = False, 40 | use_cache: bool = False, 41 | padding_mask: Optional[torch.Tensor] = None, 42 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 43 | if output_attentions: 44 | warnings.warn( 45 | "Output attentions is not supported for patched `LlamaAttention`, returning `None` instead." 
46 | ) 47 | 48 | bsz, q_len, _ = hidden_states.size() 49 | kv_heads = getattr(self, "num_key_value_heads", self.num_heads) 50 | 51 | q, k, v = ( 52 | op(hidden_states).view(bsz, q_len, nh, self.head_dim) 53 | for op, nh in ( 54 | (self.q_proj, self.num_heads), 55 | (self.k_proj, kv_heads), 56 | (self.v_proj, kv_heads), 57 | ) 58 | ) 59 | # shape: (b, s, num_heads, head_dim) 60 | 61 | kv_seq_len = k.shape[1] 62 | past_kv_len = 0 63 | if past_key_value is not None: 64 | past_kv_len = past_key_value[0].shape[2] 65 | kv_seq_len += past_kv_len 66 | 67 | cos_sin = self.rotary_emb(v, seq_len=kv_seq_len) 68 | q, k = apply_rotary_pos_emb(q, k, cos_sin, position_ids) 69 | 70 | if past_key_value is not None: 71 | assert ( 72 | flash_attn_version >= "2.1.0" 73 | ), "past_key_value support requires flash-attn >= 2.1.0" 74 | # reuse k, v 75 | k = torch.cat([past_key_value[0].transpose(1, 2), k], dim=1) 76 | v = torch.cat([past_key_value[1].transpose(1, 2), v], dim=1) 77 | 78 | past_key_value = (k.transpose(1, 2), v.transpose(1, 2)) if use_cache else None 79 | 80 | if attention_mask is None: 81 | output = flash_attn_func(q, k, v, 0.0, softmax_scale=None, causal=True).view( 82 | bsz, q_len, -1 83 | ) 84 | else: 85 | q, indices, cu_q_lens, max_s = unpad_input(q, attention_mask[:, -q_len:]) 86 | # We can skip concat and call unpad twice but seems better to call unpad only once. 
87 | kv, _, cu_k_lens, max_k = unpad_input( 88 | torch.stack((k, v), dim=2), attention_mask 89 | ) 90 | output_unpad = flash_attn_varlen_kvpacked_func( 91 | q, 92 | kv, 93 | cu_q_lens, 94 | cu_k_lens, 95 | max_s, 96 | max_k, 97 | 0.0, 98 | softmax_scale=None, 99 | causal=True, 100 | ) 101 | output_unpad = output_unpad.reshape(-1, self.num_heads * self.head_dim) 102 | output = pad_input(output_unpad, indices, bsz, q_len) 103 | 104 | return self.o_proj(output), None, past_key_value 105 | 106 | 107 | # Disable the transformation of the attention mask in LlamaModel as flash attention 108 | # takes a boolean key_padding_mask. Fills in the past kv length for use in forward. 109 | def _prepare_decoder_attention_mask( 110 | self, attention_mask, input_shape, inputs_embeds, past_key_values_length 111 | ): 112 | # [bsz, seq_len] 113 | if past_key_values_length > 0 and attention_mask is not None: 114 | attention_mask = torch.cat( 115 | ( 116 | torch.full( 117 | (input_shape[0], past_key_values_length), 118 | True, 119 | dtype=attention_mask.dtype, 120 | device=attention_mask.device, 121 | ), 122 | attention_mask, 123 | ), 124 | dim=-1, 125 | ) 126 | 127 | if attention_mask is not None and torch.all(attention_mask): 128 | return None # This uses the faster call when training with full samples 129 | 130 | return attention_mask 131 | 132 | 133 | def replace_llama_attn_with_flash_attn(): 134 | cuda_major, cuda_minor = torch.cuda.get_device_capability() 135 | if cuda_major < 8: 136 | warnings.warn( 137 | "Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward." 
138 | " ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593" 139 | ) 140 | 141 | LlamaModel._prepare_decoder_attention_mask = _prepare_decoder_attention_mask 142 | LlamaAttention.forward = forward 143 | 144 | 145 | def test(): 146 | from fastchat.train.llama_flash_attn_monkey_patch import forward as fastchat_forward 147 | from transformers.models.llama.configuration_llama import LlamaConfig 148 | 149 | config = LlamaConfig( 150 | hidden_size=1024, 151 | intermediate_size=128, 152 | num_hidden_layers=1, 153 | num_attention_heads=8, 154 | max_position_embeddings=16, 155 | ) 156 | device = torch.device("cuda") 157 | model = LlamaModel(config) 158 | attn = LlamaAttention(config).to(device).half() 159 | bsz, hs, seqlen = 2, config.hidden_size, config.max_position_embeddings 160 | position_ids = torch.arange(seqlen, dtype=torch.long, device=device).view( 161 | -1, seqlen 162 | ) 163 | 164 | mask = torch.full((bsz, seqlen), True, dtype=torch.bool, device=device) 165 | for i in range(4): 166 | hidden = torch.rand((bsz, seqlen, hs), dtype=torch.float16, device=device) 167 | if i: 168 | mask[0, -i:] = False 169 | mask[1, :i] = False 170 | 171 | lmask = model._prepare_decoder_attention_mask(mask, hidden.shape[:2], hidden, 0) 172 | ref, _, _ = attn.forward( 173 | hidden, attention_mask=lmask, position_ids=position_ids 174 | ) 175 | 176 | fast, _, _ = fastchat_forward( 177 | attn, hidden, attention_mask=mask, position_ids=position_ids 178 | ) 179 | 180 | lmask = _prepare_decoder_attention_mask( 181 | model, mask, hidden.shape[:2], hidden, 0 182 | ) 183 | test, _, _ = forward( 184 | attn, hidden, attention_mask=lmask, position_ids=position_ids 185 | ) 186 | 187 | print(f"Mean(abs(ref)) = {torch.mean(torch.abs(ref))}") 188 | print(f"Mean(abs(ref - fast)) = {torch.mean(torch.abs(ref - fast))}") 189 | print(f"Mean(abs(ref - test)) = {torch.mean(torch.abs(ref - test))}") 190 | print(f"Mean(abs(fast - test)) = {torch.mean(torch.abs(fast - test))}")
191 | print(f"allclose(fast, test) = {torch.allclose(fast, test)}") 192 | 193 | with torch.no_grad(): 194 | # Also check that past_kv is handled properly 195 | hidden = torch.rand((bsz, seqlen, hs), dtype=torch.float16, device=device) 196 | part_len = seqlen // 4 197 | assert part_len * 4 == seqlen 198 | mask = torch.full((bsz, seqlen), True, dtype=torch.bool, device=device) 199 | mask[0, -2:] = False 200 | lmask = _prepare_decoder_attention_mask( 201 | model, mask, hidden.shape[:2], hidden, 0 202 | ) 203 | oneshot, _, _ = forward( 204 | attn, hidden, attention_mask=lmask, position_ids=position_ids 205 | ) 206 | parts = [] 207 | past_kv, past_kv_len = None, 0 208 | for i in range(4): 209 | start = part_len * i 210 | end = start + part_len 211 | hidden_part = hidden[:, start:end, ...] 212 | lmask = _prepare_decoder_attention_mask( 213 | model, 214 | mask[:, start:end], 215 | hidden_part.shape[:2], 216 | hidden_part, 217 | past_kv_len, 218 | ) 219 | part, _, past_kv = forward( 220 | attn, 221 | hidden_part.clone(), 222 | attention_mask=lmask, 223 | position_ids=position_ids[:, start:end], 224 | past_key_value=past_kv, 225 | use_cache=True, 226 | ) 227 | parts.append(part) 228 | past_kv_len = past_kv[0].shape[2] 229 | 230 | print( 231 | f"allclose(oneshot[:, :part_len], parts[0]) = {torch.allclose(oneshot[:, :part_len], parts[0])}" 232 | ) 233 | print( 234 | f"allclose(oneshot, parts) = {torch.allclose(oneshot, torch.cat(parts, dim=1))}" 235 | ) 236 | 237 | 238 | if __name__ == "__main__": 239 | test() -------------------------------------------------------------------------------- /sensorllm/train/sensorllm_trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | 5 | from transformers import Trainer 6 | from typing import Optional 7 | 8 | 9 | def unwrap_model(model: nn.Module) -> nn.Module: 10 | """ 11 | Recursively unwraps a model from potential containers (as used in
distributed training). 12 | 13 | Args: 14 | model (`torch.nn.Module`): The model to unwrap. 15 | """ 16 | # since there could be multiple levels of wrapping, unwrap recursively 17 | if hasattr(model, "module"): 18 | return unwrap_model(model.module) 19 | else: 20 | return model 21 | 22 | 23 | class SensorLLMTrainer(Trainer): 24 | 25 | def _save(self, output_dir: Optional[str] = None, state_dict=None): 26 | # Save the model 27 | _state_dict = state_dict 28 | if _state_dict is None: 29 | # Unwrap distributed/parallel containers so that only the underlying model is saved 30 | model_to_save = unwrap_model(self.model) 31 | _state_dict = model_to_save.state_dict() 32 | # Exclude the pretrained time-series encoder (pt_encoder_backbone) weights from the checkpoint 33 | keys_to_match = ['pt_encoder_backbone'] 34 | filtered_state_dict = {k: v for k, v in _state_dict.items() if 35 | not any(key_match in k for key_match in keys_to_match)} 36 | 37 | super(SensorLLMTrainer, self)._save(output_dir, filtered_state_dict) 38 | 39 | 40 | class SensorLLMWeightedCELossTrainer(SensorLLMTrainer): 41 | def __init__(self, *args, class_weights=None, **kwargs): 42 | super().__init__(*args, **kwargs) 43 | # class_weights is required; compute_loss converts it to a tensor on the logits' device 44 | assert class_weights is not None, "class_weights for SensorLLMWeightedCELossTrainer is None" 45 | print(f"class_weights: {class_weights}") 46 | self.class_weights = class_weights 47 | 48 | def compute_loss(self, model, inputs, return_outputs=False, **kwargs): # **kwargs absorbs extra arguments (e.g. num_items_in_batch) passed by newer transformers versions 49 | # Extract labels and convert them to long type for cross_entropy 50 | labels = inputs.pop("labels").long() 51 | 52 | # Forward pass 53 | outputs = model(**inputs) 54 | 55 | # Extract logits, assuming the model returns them directly 56 | logits = outputs.get('logits') 57 | 58 | # Compute custom loss with class weights for imbalanced data handling 59 | assert self.class_weights is not None, "self.class_weights is None" 60 | loss_fct = torch.nn.CrossEntropyLoss( 61 | weight=torch.tensor(self.class_weights, device=logits.device, dtype=logits.dtype)) 62 | 63 | loss = loss_fct(logits.view(-1,
self.model.config.num_labels), labels.view(-1)) 64 | 65 | return (loss, outputs) if return_outputs else loss 66 | 67 | 68 | 69 | -------------------------------------------------------------------------------- /sensorllm/train/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import pathlib 4 | import yaml 5 | from dataclasses import dataclass, field 6 | from typing import Optional, List 7 | import transformers 8 | from transformers import AutoConfig 9 | 10 | import nltk 11 | from nltk.translate.bleu_score import sentence_bleu 12 | from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score 13 | import numpy as np 14 | 15 | from sensorllm.model import * 16 | from sensorllm.model.chronos_model import * 17 | from sensorllm.train.sensorllm_trainer import SensorLLMTrainer, SensorLLMWeightedCELossTrainer 18 | import logging as logger 19 | from sensorllm.data import make_ts_text_data_module, make_ts_text_data_module_stage2, make_ts_classification_data_module_stage2 20 | import warnings 21 | 22 | import evaluate 23 | 24 | warnings.filterwarnings("ignore") 25 | 26 | 27 | @dataclass 28 | class ModelArguments: 29 | model_name_or_path: Optional[str] = field(default="") 30 | pt_encoder_backbone_ckpt: str = field(default=None) 31 | tokenize_method: str = field(default='MeanScaleUniformBins', 32 | metadata={"help": "MeanScaleUniformBins or StanNormalizeUniformBins."}) 33 | model_type: Optional[str] = field(default="CasualLM", 34 | metadata={"help": "CasualLM or SequenceClassification."}) 35 | 36 | 37 | @dataclass 38 | class DataArguments: 39 | dataset: str = field( 40 | default="usc-had", 41 | metadata={"help": "usc-had, mhealth, pamap, pamap50, uci, capture24"}, 42 | ) 43 | data_path: str = field( 44 | default="", 45 | metadata={"help": "Path to the training data."}, 46 | ) 47 | qa_path: str = field( 48 | default="", 49 | metadata={"help": "Path to the training QA data."}, 50 | ) 
51 | eval_data_path: str = field( 52 | default="", 53 | metadata={"help": "Path to the eval data."}, 54 | ) 55 | eval_qa_path: str = field( 56 | default="", 57 | metadata={"help": "Path to the eval QA data."}, 58 | ) 59 | shuffle: bool = field(default=True, metadata={"help": "Whether to shuffle data."}) 60 | ignore_qa_types: List[str] = field(default_factory=lambda: ["trend"]) 61 | preprocess_type: str = field(default='Q', 62 | metadata={"help": "Q or Q+cot."}) 63 | preprocess_type_eval: str = field(default='Q+cot', 64 | metadata={"help": "Q or Q+cot."}) 65 | add_ts_special_token_text: bool = field(default=False) 66 | 67 | 68 | @dataclass 69 | class TrainingArguments(transformers.TrainingArguments): 70 | cache_dir: Optional[str] = field(default=None) 71 | optim: str = field(default="adamw_torch") 72 | model_max_length: int = field( 73 | default=8192, 74 | metadata={ 75 | "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)." 76 | }, 77 | ) 78 | 79 | num_labels: int = field( 80 | default=12, 81 | metadata={ 82 | "help": "Number of output labels." 
83 | }, 84 | ) 85 | use_weighted_loss: bool = field(default=True, metadata={"help": "Use weighted loss for classification model."}) 86 | 87 | fix_llm: bool = field(default=True, metadata={"help": "Whether to fix the LLM."}) 88 | fix_ts_encoder: bool = field(default=True, metadata={"help": "Whether to fix the pretrained ts encoder."}) 89 | fix_cls_head: bool = field(default=False, metadata={"help": "Whether to fix the cls head of LLM."}) 90 | stage_2: bool = field(default=False) # * set True when fine-tuning 91 | only_stage2: bool = field(default=False) 92 | # * whether the time-series projection layer (ts_proj) is trainable 93 | tune_mm_mlp_adapter: bool = field(default=True) 94 | metric_for_best_model: str = field(default='eval_loss') 95 | 96 | 97 | def print_trainable_parameters(model): 98 | all_param = 0 99 | trainable_params = 0 100 | for name, param in model.named_parameters(): 101 | all_param += param.numel() 102 | if param.requires_grad: 103 | print(f"Layer: {name}, Trainable: {param.requires_grad}") 104 | trainable_params += param.numel() 105 | 106 | for name, param in model.get_model().pt_encoder_backbone.model.named_parameters(): 107 | all_param += param.numel() 108 | if param.requires_grad: 109 | print(f"Layer: {name}, Trainable: {param.requires_grad}") 110 | trainable_params += param.numel() 111 | print( 112 | f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param:.2f}" 113 | ) 114 | 115 | 116 | def check_model_parameters(model_or_params, print_details=False): 117 | device_map = {} 118 | 119 | if hasattr(model_or_params, 'named_parameters'): 120 | param_iterator = model_or_params.named_parameters() 121 | elif isinstance(model_or_params, dict): 122 | param_iterator = model_or_params.items() 123 | else: 124 | raise ValueError("Input must be a model or a dictionary of parameters") 125 | 126 | for name, param in param_iterator: 127 | device = param.device 128 | shape = tuple(param.shape) 129 | device_str = str(device) 130 | 131 | if
device_str not in device_map: 132 | device_map[device_str] = [] 133 | 134 | device_map[device_str].append((name, shape)) 135 | 136 | if print_details: 137 | print(f"Parameter: {name}") 138 | print(f" Shape: {shape}") 139 | print(f" Device: {device}") 140 | print("-" * 50) 141 | 142 | print("\nSummary:") 143 | for device, params in device_map.items(): 144 | print(f"\nDevice: {device}") 145 | print(f"Number of parameters: {len(params)}") 146 | if print_details: 147 | for name, shape in params: 148 | print(f" {name}: {shape}") 149 | 150 | return device_map 151 | 152 | 153 | def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str): 154 | state_dict = trainer.model.state_dict() 155 | cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()} 156 | if trainer.args.should_save: 157 | trainer._save(output_dir, state_dict=cpu_state_dict) 158 | 159 | 160 | def train(): 161 | parser = transformers.HfArgumentParser( 162 | (ModelArguments, DataArguments, TrainingArguments)) 163 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 164 | 165 | training_args.log_level = "info" # * default is passive(warning) 166 | # * build logger 167 | # logger = build_logger(__name__, training_args.output_dir + '/train.log') 168 | 169 | logger.warning(f"Using device: {training_args.device}") 170 | 171 | if not training_args.stage_2: 172 | # stage 1 173 | logger.warning("Using model of Stage 1") 174 | model = SensorLLMStage1LlamaForCausalLM.from_pretrained( 175 | model_args.model_name_or_path, 176 | cache_dir=training_args.cache_dir, 177 | ) 178 | else: 179 | if model_args.model_type == "CasualLM": 180 | model = SensorLLMStage2LlamaForCausalLM.from_pretrained( 181 | model_args.model_name_or_path, 182 | cache_dir=training_args.cache_dir, 183 | ) 184 | logger.warning("Loaded CausalLM model.") 185 | else: 186 | logger.warning(f"Loading {data_args.dataset} dataset configs ...") 187 | with open('./sensorllm/model/ts_backbone.yaml', 'r') 
as file: 188 | dataset_configs = yaml.safe_load(file) 189 | 190 | dataset_config = dataset_configs[data_args.dataset] 191 | 192 | id2label = dataset_config["id2label"] 193 | print(f"Dataset id2label:\n{id2label}") 194 | label2id = {v: k for k, v in id2label.items()} 195 | assert training_args.num_labels == len(id2label) 196 | assert model_args.model_type == "SequenceClassification", f"Undefined model_type {model_args.model_type}" 197 | model = SensorLLMStage2LlamaForSequenceClassification.from_pretrained( 198 | model_args.model_name_or_path, 199 | num_labels=training_args.num_labels, 200 | id2label=id2label, 201 | label2id=label2id, 202 | cache_dir=training_args.cache_dir, 203 | ) 204 | logger.warning("Loaded SequenceClassification model.") 205 | 206 | model.config.use_cache = False 207 | 208 | print(f"Default pt_encoder_backbone_ckpt is {model_args.pt_encoder_backbone_ckpt}.") 209 | model.get_model().load_pt_encoder_backbone_checkpoint(model_args.pt_encoder_backbone_ckpt, 210 | tc=model_args.tokenize_method) 211 | pt_backbone_config = AutoConfig.from_pretrained(model_args.pt_encoder_backbone_ckpt) 212 | 213 | assert hasattr(pt_backbone_config, "chronos_config"), "Not a Chronos config file" 214 | 215 | chronos_config = ChronosConfig(**pt_backbone_config.chronos_config) 216 | chronos_config.tokenizer_class = model_args.tokenize_method 217 | chronos_tokenizer = chronos_config.create_tokenizer() 218 | 219 | if training_args.fix_llm: 220 | # * This will fix all the parameters 221 | model.requires_grad_(False) 222 | # * fix llama, lm_head 223 | model.get_model().fix_llm = True 224 | logger.warning("LLM is fixed. Fix_llm flag is set to True") 225 | model.get_model().ts_proj.requires_grad_(True) 226 | model.get_model().pt_encoder_backbone.requires_grad_(True) 227 | else: 228 | model.get_model().fix_llm = False 229 | logger.warning("LLM is trainable. 
Fix_llm flag is set to False") 230 | 231 | tokenizer = transformers.AutoTokenizer.from_pretrained( 232 | model_args.model_name_or_path, 233 | cache_dir=training_args.cache_dir, 234 | model_max_length=training_args.model_max_length, 235 | padding_side="right", 236 | use_fast=False, 237 | ) 238 | 239 | if not training_args.fix_ts_encoder: 240 | # * not fix pretrained ts encoder 241 | model.get_model().fix_ts_encoder = False 242 | logger.warning( 243 | "Pretrained TS backbone is trainable. fix_ts_encoder flag is set to False, ts net grad will be recorded.") 244 | else: 245 | model.get_model().fix_ts_encoder = True # * use with torch.inference_mode to control 246 | logger.warning( 247 | "Pretrained TS backbone is fixed. fix_ts_encoder flag is set to True, ts net grad will not be recorded.") 248 | 249 | logger.warning("Set requires_grad of Pretrained TS backbone to False") 250 | model.get_model().pt_encoder_backbone.requires_grad_( 251 | False) 252 | 253 | if training_args.tune_mm_mlp_adapter: 254 | # * not fix the projection layer 255 | # * may need to set the embed_tokens to require_grad = True if added new tokens 256 | # * this is done in initialize_tokenizer_ts_backbone_config 257 | logger.warning("Time-series Projector is trainable. ") 258 | else: 259 | model.get_model().ts_proj.requires_grad_(False) 260 | logger.warning("Time-series Projector is fixed.") 261 | 262 | if model_args.model_type == "SequenceClassification": 263 | if not training_args.fix_cls_head: 264 | model.score.requires_grad_(True) 265 | logger.warning("LLM classification head is trainable. 
") 266 | else: 267 | model.score.requires_grad_(False) 268 | logger.warning("LLM classification head is fixed.") 269 | 270 | if not training_args.stage_2: 271 | # * we assume in stage2, llm, and time-series embedder (and projection layer) can be loaded from the model checkpoint 272 | model.initialize_tokenizer_ts_backbone_config(tokenizer=tokenizer, device=training_args.device, 273 | fix_llm=training_args.fix_llm, dataset=data_args.dataset) 274 | else: 275 | # * stage2 276 | if training_args.only_stage2: 277 | logger.warning("The loaded model haven't been trained in Stage 1. Initializing the LLM tokenizer with new tokens now...") 278 | model.initialize_tokenizer_ts_backbone_config(tokenizer=tokenizer, device=training_args.device, 279 | fix_llm=training_args.fix_llm, dataset=data_args.dataset) 280 | else: 281 | model.initialize_tokenizer_ts_backbone_config_wo_embedding(tokenizer=tokenizer, dataset=data_args.dataset) 282 | 283 | model.get_model().load_start_end_tokens(dataset=data_args.dataset) 284 | 285 | ts_backbone_config = model.get_model().ts_backbone_config 286 | data_args.ts_backbone_config = ts_backbone_config 287 | 288 | if not training_args.stage_2: 289 | logger.warning("Stage 1") 290 | data_module = make_ts_text_data_module(tokenizer=tokenizer, chronos_tokenizer=chronos_tokenizer, 291 | data_args=data_args) 292 | else: 293 | # * stage2 294 | if model_args.model_type == "CasualLM": 295 | data_module = make_ts_text_data_module_stage2(tokenizer=tokenizer, chronos_tokenizer=chronos_tokenizer, 296 | data_args=data_args) 297 | else: 298 | assert model_args.model_type == "SequenceClassification", f"Undefined model_type {model_args.model_type} for data_module" 299 | data_module = make_ts_classification_data_module_stage2(tokenizer=tokenizer, 300 | chronos_tokenizer=chronos_tokenizer, 301 | label2id=label2id, 302 | data_args=data_args) 303 | 304 | if model_args.model_type == "CasualLM": 305 | trainer = SensorLLMTrainer(model=model, 306 | args=training_args, 307 | 
tokenizer=tokenizer, 308 | **data_module) 309 | else: 310 | assert model_args.model_type == "SequenceClassification", f"Undefined model_type {model_args.model_type} for Trainer" 311 | 312 | metric_f1 = evaluate.load("../metrics/f1") 313 | metric_acc = evaluate.load("../metrics/accuracy") 314 | metric_precision = evaluate.load("../metrics/precision") 315 | metric_recall = evaluate.load("../metrics/recall") 316 | 317 | def compute_metrics(eval_pred): 318 | predictions, labels = eval_pred 319 | predictions = np.argmax(predictions, axis=1) 320 | 321 | accuracy = metric_acc.compute(predictions=predictions, references=labels)["accuracy"] 322 | 323 | precision_macro = metric_precision.compute(predictions=predictions, references=labels, average='macro')[ 324 | "precision"] 325 | recall_macro = metric_recall.compute(predictions=predictions, references=labels, average='macro')["recall"] 326 | f1_macro = metric_f1.compute(predictions=predictions, references=labels, average='macro')["f1"] 327 | 328 | f1_micro = metric_f1.compute(predictions=predictions, references=labels, average='micro')["f1"] 329 | 330 | precision_per_class = metric_precision.compute(predictions=predictions, references=labels, average=None)[ 331 | "precision"] 332 | recall_per_class = metric_recall.compute(predictions=predictions, references=labels, average=None)["recall"] 333 | f1_per_class = metric_f1.compute(predictions=predictions, references=labels, average=None)["f1"] 334 | 335 | results = { 336 | "f1_macro": f1_macro, 337 | "f1_micro": f1_micro, 338 | "accuracy": accuracy, 339 | "precision_macro": precision_macro, 340 | "recall_macro": recall_macro 341 | } 342 | 343 | for i, (p, r, f) in enumerate(zip(precision_per_class, recall_per_class, f1_per_class)): 344 | results[f"precision_class_{i}"] = p 345 | results[f"recall_class_{i}"] = r 346 | results[f"f1_class_{i}"] = f 347 | 348 | return results 349 | 350 | model.config.pad_token_id = tokenizer.pad_token_id 351 | if training_args.use_weighted_loss: 
352 | logger.warning("Using weighted_loss trainer") 353 | trainer = SensorLLMWeightedCELossTrainer(model=model, 354 | args=training_args, 355 | tokenizer=tokenizer, 356 | compute_metrics=compute_metrics, 357 | **data_module) 358 | else: 359 | del data_module['class_weights'] 360 | print(list(data_module.keys())) 361 | trainer = SensorLLMTrainer(model=model, 362 | args=training_args, 363 | tokenizer=tokenizer, 364 | compute_metrics=compute_metrics, 365 | **data_module) 366 | 367 | print_trainable_parameters(model) 368 | 369 | if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")): 370 | trainer.train(resume_from_checkpoint=True) 371 | else: 372 | trainer.train() 373 | trainer.save_state() 374 | 375 | safe_save_model_for_hf_trainer(trainer=trainer, 376 | output_dir=training_args.output_dir) 377 | 378 | # eval_results = trainer.evaluate() 379 | # print("Evaluation results:", eval_results) 380 | 381 | 382 | if __name__ == "__main__": 383 | train() 384 | -------------------------------------------------------------------------------- /sensorllm/train/train_mem.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/lm-sys/FastChat. Below is the original copyright: 2 | # Adapted from tatsu-lab@stanford_alpaca. Below is the original copyright: 3 | # Make it more memory efficient by monkey patching the LLaMA model with FlashAttn. 4 | 5 | # Need to call this before importing transformers.
6 | from sensorllm.train.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn 7 | 8 | replace_llama_attn_with_flash_attn() 9 | 10 | from sensorllm.train.train import train 11 | 12 | if __name__ == "__main__": 13 | train() -------------------------------------------------------------------------------- /sensorllm/utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import logging.handlers 3 | import os 4 | import sys 5 | 6 | import requests 7 | 8 | import yaml 9 | 10 | server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**" 11 | moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN." 12 | 13 | handler = None 14 | 15 | 16 | class EasyDict(dict): 17 | def __init__(self, d=None, **kwargs): 18 | if d is None: 19 | d = {} 20 | if kwargs: 21 | d.update(**kwargs) 22 | for k, v in d.items(): 23 | setattr(self, k, v) 24 | # Class attributes 25 | for k in self.__class__.__dict__.keys(): 26 | if not (k.startswith('__') and k.endswith('__')) and not k in ('update', 'pop'): 27 | setattr(self, k, getattr(self, k)) 28 | 29 | def __setattr__(self, name, value): 30 | if isinstance(name, str): 31 | if isinstance(value, (list, tuple)): 32 | value = [self.__class__(x) 33 | if isinstance(x, dict) else x for x in value] 34 | elif isinstance(value, dict) and not isinstance(value, self.__class__): 35 | value = self.__class__(value) 36 | super(EasyDict, self).__setattr__(name, value) 37 | super(EasyDict, self).__setitem__(name, value) 38 | else: 39 | self.__setitem__(name, value) 40 | 41 | def __setitem__(self, name, value): 42 | 43 | if isinstance(value, (list, tuple)): 44 | value = [self.__class__(x) 45 | if isinstance(x, dict) else x for x in value] 46 | elif isinstance(value, dict) and not isinstance(value, self.__class__): 47 | value = self.__class__(value) 48 | super(EasyDict, self).__setitem__(name, value) 49 | 50 | def 
update(self, e=None, **f): 51 | d = e or dict() 52 | d.update(f) 53 | for k in d: 54 | setattr(self, k, d[k]) 55 | 56 | def pop(self, k, d=None): 57 | self.__dict__.pop(k, None) # drop the attribute mirror without raising when k is missing 58 | return super(EasyDict, self).pop(k, d) 59 | 60 | def merge_new_config(config, new_config): 61 | for key, val in new_config.items(): 62 | if not isinstance(val, dict): 63 | if key == '_base_': 64 | with open(new_config['_base_'], 'r') as f: 65 | try: 66 | val = yaml.load(f, Loader=yaml.FullLoader) 67 | except AttributeError: # PyYAML < 5.1 has no FullLoader; other errors should propagate 68 | val = yaml.load(f) 69 | config[key] = EasyDict() 70 | merge_new_config(config[key], val) 71 | else: 72 | config[key] = val 73 | continue 74 | if key not in config: 75 | config[key] = EasyDict() 76 | merge_new_config(config[key], val) 77 | return config 78 | 79 | def cfg_from_yaml_file(cfg_file): 80 | config = EasyDict() 81 | with open(cfg_file, 'r') as f: 82 | new_config = yaml.load(f, Loader=yaml.FullLoader) 83 | merge_new_config(config=config, new_config=new_config) 84 | return config 85 | 86 | def build_logger(logger_name, logger_filepath): 87 | global handler 88 | 89 | formatter = logging.Formatter( 90 | fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 91 | datefmt="%Y-%m-%d %H:%M:%S", 92 | ) 93 | 94 | # Set the format of root handlers 95 | if not logging.getLogger().handlers: 96 | logging.basicConfig(level=logging.INFO) 97 | else: 98 | logging.getLogger().handlers[0].setFormatter(formatter) 99 | 100 | # Redirect stdout and stderr to loggers 101 | stdout_logger = logging.getLogger("stdout") 102 | stdout_logger.setLevel(logging.INFO) 103 | sl_out = StreamToLogger(stdout_logger, logging.INFO) 104 | sys.stdout = sl_out 105 | 106 | stderr_logger = logging.getLogger("stderr") 107 | stderr_logger.setLevel(logging.ERROR) 108 | sl_err = StreamToLogger(stderr_logger, logging.ERROR) 109 | sys.stderr = sl_err 110 | 111 | # Get logger 112 | logger = logging.getLogger(logger_name) 113 | logger.setLevel(logging.INFO) 114 | 115 | # Add a file handler for all loggers 116 | if handler is
None: 117 | # Get the logger_file's directory, and create it if it does not exist 118 | logger_filedir = os.path.dirname(logger_filepath) 119 | os.makedirs(logger_filedir, exist_ok=True) 120 | handler = logging.handlers.TimedRotatingFileHandler( 121 | logger_filepath, when='D', utc=True) 122 | handler.setFormatter(formatter) 123 | 124 | # Attach the handler to all existing loggers 125 | for name, item in logging.root.manager.loggerDict.items(): 126 | if isinstance(item, logging.Logger): 127 | item.addHandler(handler) 128 | 129 | return logger 130 | 131 | 132 | class StreamToLogger(object): 133 | """ 134 | Fake file-like stream object that redirects writes to a logger instance. 135 | """ 136 | def __init__(self, logger, log_level=logging.INFO): 137 | self.terminal = sys.stdout if log_level == logging.INFO else sys.stderr 138 | self.logger = logger 139 | self.log_level = log_level 140 | self.linebuf = '' 141 | 142 | def __getattr__(self, attr): 143 | return getattr(self.terminal, attr) 144 | 145 | def write(self, buf): 146 | temp_linebuf = self.linebuf + buf 147 | self.linebuf = '' 148 | for line in temp_linebuf.splitlines(True): 149 | if line.endswith('\n'): 150 | self.logger.log(self.log_level, line.rstrip()) 151 | else: 152 | self.linebuf += line 153 | 154 | def flush(self): 155 | if self.linebuf: 156 | self.logger.log(self.log_level, self.linebuf.rstrip()) 157 | self.linebuf = '' 158 | 159 | 160 | def disable_torch_init(): 161 | """ 162 | Disable the redundant torch default initialization to accelerate model creation. 163 | """ 164 | import torch 165 | setattr(torch.nn.Linear, "reset_parameters", lambda self: None) 166 | setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None) 167 | 168 | 169 | def violates_moderation(text): 170 | """ 171 | Check whether the text violates OpenAI moderation API. 
172 | """ 173 | url = "https://api.openai.com/v1/moderations" 174 | headers = {"Content-Type": "application/json", 175 | "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]} 176 | text = text.replace("\n", "") 177 | # Let requests serialize the payload (json=) so quotes and backslashes in text are escaped 178 | data = {"input": text} 179 | try: 180 | ret = requests.post(url, headers=headers, json=data, timeout=5) 181 | flagged = ret.json()["results"][0]["flagged"] 182 | except requests.exceptions.RequestException as e: 183 | flagged = False 184 | except KeyError as e: 185 | flagged = False 186 | 187 | return flagged 188 | 189 | 190 | def pretty_print_semaphore(semaphore): 191 | if semaphore is None: 192 | return "None" 193 | return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})" --------------------------------------------------------------------------------
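When `attention_mask` is not `None`, the patched `forward` in `llama_flash_attn_monkey_patch.py` calls `unpad_input` to drop padding tokens. It then runs `flash_attn_varlen_kvpacked_func` on the packed sequences, using cumulative lengths (`cu_q_lens`, `cu_k_lens`) and maximum lengths (`max_s`, `max_k`), and calls `pad_input` to restore the `(bsz, q_len, ...)` layout. The pure-Python sketch below illustrates that bookkeeping on a boolean mask. `unpad_sketch` and `pad_sketch` are hypothetical names, not flash-attn functions. The sketch assumes the conventions of flash-attn's padding helpers (flat indices of kept tokens, cumulative lengths starting at 0) and works on lists instead of tensors.

```python
from typing import List, Sequence, Tuple


def unpad_sketch(mask: Sequence[Sequence[bool]]) -> Tuple[List[int], List[int], int]:
    """Hypothetical helper mirroring the bookkeeping of a varlen "unpad" step.

    mask[b][t] is True for real tokens and False for padding. Returns:
      indices    - flat positions (b * seq_len + t) of the kept tokens
      cu_seqlens - cumulative kept-token counts, starting at 0 (length bsz + 1)
      max_seqlen - length of the longest unpadded sequence
    """
    seq_len = len(mask[0]) if mask else 0
    indices: List[int] = []
    cu_seqlens = [0]
    max_seqlen = 0
    for b, row in enumerate(mask):
        kept = [b * seq_len + t for t, keep in enumerate(row) if keep]
        indices.extend(kept)
        cu_seqlens.append(cu_seqlens[-1] + len(kept))
        max_seqlen = max(max_seqlen, len(kept))
    return indices, cu_seqlens, max_seqlen


def pad_sketch(values: Sequence[float], indices: Sequence[int],
               bsz: int, seq_len: int, fill: float = 0.0) -> List[List[float]]:
    """Hypothetical inverse of unpad_sketch: scatter packed values back into a
    bsz x seq_len grid, filling padded slots with `fill`."""
    flat = [fill] * (bsz * seq_len)
    for value, idx in zip(values, indices):
        flat[idx] = value
    return [flat[b * seq_len:(b + 1) * seq_len] for b in range(bsz)]
```

Take the mask `[[True, True, True, False], [False, True, True, True]]`, which has right padding in row 0 and left padding in row 1, the same pattern the patch's `test()` builds. For that mask, `unpad_sketch` returns indices `[0, 1, 2, 5, 6, 7]`, `cu_seqlens == [0, 3, 6]` and a maximum length of 3. The varlen kernel therefore only processes the six real tokens.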
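`SensorLLMWeightedCELossTrainer.compute_loss` passes `class_weights` to `torch.nn.CrossEntropyLoss(weight=...)` with the default `reduction='mean'`. In that mode, PyTorch's documented behaviour is to divide the weighted sum of per-sample losses by the sum of the weights of the target classes, not by the batch size. The stdlib-only sketch below reproduces that arithmetic for illustration. `log_softmax` and `weighted_cross_entropy` are hypothetical helpers, not a replacement for the PyTorch implementation.

```python
import math
from typing import List, Sequence


def log_softmax(logits: Sequence[float]) -> List[float]:
    # Subtract the max before exponentiating for numerical stability
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def weighted_cross_entropy(batch_logits: Sequence[Sequence[float]],
                           labels: Sequence[int],
                           class_weights: Sequence[float]) -> float:
    """Weighted mean cross-entropy (hypothetical helper).

    Computes sum_n w[y_n] * nll_n / sum_n w[y_n], the formula PyTorch's
    CrossEntropyLoss documents for weight=... with reduction='mean'.
    """
    weighted_nll = 0.0
    weight_sum = 0.0
    for logits, y in zip(batch_logits, labels):
        w = class_weights[y]
        weighted_nll += w * -log_softmax(logits)[y]  # per-sample NLL scaled by its class weight
        weight_sum += w
    return weighted_nll / weight_sum
```

Because the denominator is the sum of the target weights, a batch whose targets all belong to one class gives the same loss whatever that class's weight is. The weights only change the loss when a batch mixes classes, and then they up-weight the rarer ones.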
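`compute_metrics` in `train.py` reports both `f1_macro` and `f1_micro`, loading them from `../metrics/...`. These are presumably local copies of the Hugging Face `evaluate` metric scripts, which build on scikit-learn. For single-label multi-class predictions, micro-averaged F1 equals accuracy, so `f1_macro` is the score that reflects performance on minority classes. The sketch below makes this relationship explicit. `f1_scores` is a hypothetical helper, and the sketch assumes that, as in scikit-learn's default, only classes occurring in the predictions or labels are scored.

```python
from collections import Counter
from typing import Dict, Sequence


def f1_scores(preds: Sequence[int], labels: Sequence[int]) -> Dict[str, object]:
    """Macro/micro F1 and accuracy for single-label multi-class predictions.

    Hypothetical helper for illustration; assumes a non-empty batch and only
    scores classes that occur in preds or labels.
    """
    classes = sorted(set(labels) | set(preds))
    tp, fp, fn = Counter(), Counter(), Counter()
    for p, y in zip(preds, labels):
        if p == y:
            tp[y] += 1
        else:
            fp[p] += 1  # wrongly predicted class p
            fn[y] += 1  # missed the true class y
    # F1 = 2TP / (2TP + FP + FN); the denominator is > 0 because each class occurs somewhere
    per_class = {c: 2 * tp[c] / (2 * tp[c] + fp[c] + fn[c]) for c in classes}
    total_tp, total_fp, total_fn = sum(tp.values()), sum(fp.values()), sum(fn.values())
    return {
        "f1_macro": sum(per_class.values()) / len(classes),
        # Every error adds one FP and one FN, so this reduces to accuracy
        "f1_micro": 2 * total_tp / (2 * total_tp + total_fp + total_fn),
        "accuracy": total_tp / len(labels),
        "f1_per_class": per_class,
    }
```

With `preds = [0, 0, 0, 1]` and `labels = [0, 0, 1, 1]`, the per-class F1 scores are 0.8 and 2/3, so macro F1 is about 0.733. Micro F1 and accuracy are both 0.75.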