17 |
18 |
19 | ## 🌟 Overview
20 |
21 | **SensorLLM** is a two-stage framework that aligns sensor time series with human-intuitive text, enabling LLMs to interpret complex numerical sensor data and achieve state-of-the-art human activity recognition across varying sensor types, sensor counts, and sequence lengths.
22 |
23 |
24 |
25 |
26 | ### 🔑 Key Features
27 | - Aligns sensor time-series with ***human-intuitive, annotation-free*** textual trend descriptions and summaries via a QA-based framework.
28 | - ***Sensor–Language Alignment Stage*** operates on single-channel, variable-length segments for fine-grained trend-text alignment.
29 | - ***Task-Aware Tuning Stage*** handles multi-channel, multi-sensor data for downstream human activity recognition (HAR).
30 |
31 | ### 📂 Datasets
32 | The current implementation supports five HAR datasets: [`USC-HAD`](https://sipi.usc.edu/had/), [`UCI-HAR`](https://archive.ics.uci.edu/dataset/240/human+activity+recognition+using+smartphones), [`MHealth`](https://archive.ics.uci.edu/dataset/319/mhealth+dataset), [`Capture-24`](https://ora.ox.ac.uk/objects/uuid:99d7c092-d865-4a19-b096-cc16440cd001), and [`PAMAP2`](https://archive.ics.uci.edu/dataset/231/pamap2+physical+activity+monitoring).
33 |
34 | To apply SensorLLM to other datasets, please refer to the code and configuration examples provided for the supported datasets. In particular, you may need to modify the corresponding entries in [`ts_backbone.yaml`](./sensorllm/model/ts_backbone.yaml) and adapt the data loading logic in the [`./sensorllm/data`](./sensorllm/data) folder to match your dataset’s format.
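
As a rough guide, the per-dataset entries that the data loaders read from `ts_backbone.yaml` include at least a sampling rate, a channel count, and a start/end marker token pair for every channel (see `get_token_dict` and `get_token_list` in [`./sensorllm/data/utils.py`](./sensorllm/data/utils.py)). The sketch below illustrates the kind of entry a new dataset would need; the dataset name, token strings, and concrete values are placeholders rather than the exact schema.

```python
# Illustrative only: these keys mirror what the loaders in ./sensorllm/data read
# for the supported datasets; "my_dataset" and the token strings are placeholders.
import yaml

with open("sensorllm/model/ts_backbone.yaml") as f:
    cfg = yaml.safe_load(f)

cfg["my_dataset"] = {
    "sample_rate": 100,   # Hz, quoted in the Stage 1 system prompt
    "channel_num": 6,     # must equal the channel dimension of your pickled data
    # one start/end marker token per channel, e.g. for an x-axis accelerometer:
    "default_x_acc_start_token": "<x_acc_start>",
    "default_x_acc_end_token": "<x_acc_end>",
    # ... repeat for the remaining channels ...
}

with open("sensorllm/model/ts_backbone.yaml", "w") as f:
    yaml.safe_dump(cfg, f)
```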
35 |
36 |
37 | ## 🚀 Getting Started
38 |
39 | > Currently supported pretrained models:
40 | > - Time-series models: [Chronos](https://arxiv.org/abs/2403.07815)
41 | > - Language models: [LLaMA](https://arxiv.org/abs/2407.21783)
42 | >
43 | > Other pretrained models **can be used with minor modifications to the SensorLLM framework**.
44 |
45 |
46 | ### Sensor–Language QA Pair Generation
47 | We provide two example notebooks to generate QA pairs for aligning sensor time-series data with human-intuitive text:
48 | - [`mhealth_stage1.ipynb`](./mhealth_stage1.ipynb): Generates QA pairs for Stage 1 by aligning single-channel sensor segments with trend-based natural language descriptions.
49 | - [`mhealth_stage2.ipynb`](./mhealth_stage2.ipynb): Generates statistical information text for Stage 2, which performs HAR classification using multi-channel sensor data.
50 |
51 | You can also customize or extend the QA templates in these notebooks to generate more diverse types of sensor–language QA pairs for your own use cases.
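
For reference, the data loaders in [`./sensorllm/data`](./sensorllm/data) expect the generated files to follow roughly the layout sketched below. Field values are placeholders; the per-channel keys must match the channel names used in your notebook, and for Stage 2 classification the answer text must match a `label2id` entry.

```python
# Rough shape of the generated QA files, inferred from the dataset loaders in
# ./sensorllm/data/stage1_dataset.py and stage2_dataset.py. Values are placeholders.
stage1_qa = {
    "dataset": [
        {
            "index": 0,  # row index into the pickled time-series file
            "qa_pairs": {
                "x-axis accelerometer": [
                    {"Q": "...", "A": "...", "type": "trend_table"},
                ],
                # ... one list of QA pairs per channel ...
            },
            "summaries": {
                "x-axis accelerometer": {"A": "..."},  # appended to the trend answer
                # ... one summary per channel ...
            },
        }
    ]
}

stage2_qa = {
    "dataset": [
        {
            "index": 0,                           # row index into the pickled multi-channel data
            "qa_pair": {"Q": "...", "A": "..."},  # may also carry fields such as "smry"
        }
    ]
}
```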
52 |
53 | ### Sensor–Language Alignment
54 | To align sensor time-series data with text, run the following command:
55 |
56 | ```bash
57 | torchrun --nproc_per_node=[NUM_GPUS] sensorllm/train/train_mem.py \
58 | --model_name_or_path [LLM_PATH] \
59 | --pt_encoder_backbone_ckpt [TS_EMBEDDER_PATH] \
60 | --tokenize_method 'StanNormalizeUniformBins' \
61 | --dataset [DATASET_NAME] \
62 | --data_path [TS_TRAIN_PATH] \
63 | --eval_data_path [TS_EVAL_PATH] \
64 | --qa_path [QA_TRAIN_PATH] \
65 | --eval_qa_path [QA_EVAL_PATH] \
66 | --output_dir [OUTPUT_PATH] \
67 | --model_max_length [MAX_LEN] \
68 | --num_train_epochs [EPOCH] \
69 | --per_device_train_batch_size [TRAIN_BATCH] \
70 | --per_device_eval_batch_size [EVAL_BATCH] \
71 | --evaluation_strategy "steps" \
72 | --save_strategy "steps" \
73 | --save_steps [SAVE_STEPS] \
74 | --eval_steps [EVAL_STEPS] \
75 | --learning_rate 2e-3 \
76 | --weight_decay 0.0 \
77 | --warmup_ratio 0.03 \
78 | --lr_scheduler_type "cosine" \
79 | --logging_steps 1 \
80 | --gradient_checkpointing True \
81 | --save_total_limit 1 \
82 | --bf16 True \
83 | --fix_llm True \
84 | --fix_ts_encoder True \
85 | --model_type CasualLM \
86 | --load_best_model_at_end True
87 | ```
88 |
89 | ### Evaluation or Inference
90 | To perform evaluation or inference for the Sensor–Language Alignment stage, run the following command:
91 |
92 | ```bash
93 | python sensorllm/eval/eval.py \
94 | --model_name_or_path [STAGE1_MODEL_PATH] \
95 | --pt_encoder_backbone_ckpt [TS_EMBEDDER_PATH] \
96 | --torch_dtype bfloat16 \
97 | --tokenize_method 'StanNormalizeUniformBins' \
98 | --dataset [DATASET_NAME] \
99 | --data_path [TS_DATASET_PATH] \
100 | --qa_path [QA_DATASET_PATH] \
101 | --output_file_name [OUTPUT_FILE_NAME] \
102 | --model_max_length [MAX_LEN] \
103 | --shuffle False
104 | ```
105 |
106 | ### Task-Aware Tuning
107 | To perform a HAR task, use the following command:
108 | ```bash
109 | torchrun --nproc_per_node=[NUM_GPUS] sensorllm/train/train_mem.py \
110 | --model_name_or_path [STAGE1_MODEL_PATH] \
111 | --pt_encoder_backbone_ckpt [TS_EMBEDDER_PATH] \
112 | --model_type "SequenceClassification" \
113 | --num_labels [ACTIVITY_NUM] \
114 | --use_weighted_loss True \
115 | --tokenize_method 'StanNormalizeUniformBins' \
116 | --dataset [DATASET_NAME] \
117 | --data_path [TS_TRAIN_PATH] \
118 | --eval_data_path [TS_EVAL_PATH] \
119 | --qa_path [QA_TRAIN_PATH] \
120 | --eval_qa_path [QA_EVAL_PATH] \
121 | --output_dir [OUTPUT_PATH] \
122 | --model_max_length [MAX_LEN] \
123 |     --num_train_epochs [EPOCH] \
125 | --per_device_train_batch_size [TRAIN_BATCH] \
126 | --per_device_eval_batch_size [EVAL_BATCH] \
127 | --evaluation_strategy "steps" \
128 | --save_strategy "steps" \
129 | --save_steps [SAVE_STEPS] \
130 | --eval_steps [EVAL_STEPS] \
131 | --save_total_limit 1 \
132 | --load_best_model_at_end True \
133 | --learning_rate 2e-3 \
134 | --weight_decay 0.0 \
135 | --warmup_ratio 0.03 \
136 | --lr_scheduler_type "cosine" \
137 | --logging_steps 1 \
138 | --bf16 True \
139 | --fix_llm True \
140 | --fix_cls_head False \
141 | --fix_ts_encoder True \
142 | --gradient_checkpointing True \
143 | --metric_for_best_model "f1_macro" \
144 | --preprocess_type "smry+Q" \
145 | --greater_is_better True \
146 | --stage_2 True \
147 | --shuffle True
148 | ```
149 | See [`./sensorllm/data/utils.py`](./sensorllm/data/utils.py) for all available `preprocess_type` options, or edit that file to add your own.
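
As a rough orientation, each Stage 2 classification sample carries several text fields (see `preprocess_time_series_CLS_stage2` in [`./sensorllm/data/stage2_dataset.py`](./sensorllm/data/stage2_dataset.py)); a `preprocess_type` such as `"smry+Q"` selects which of them are concatenated into the model input. The values below are placeholders:

```python
# Text fields attached to each Stage 2 classification sample; a preprocess_type
# value chooses which of these end up in the prompt. Values are placeholders.
stage2_cls_fields = {
    "Q": "...",           # HAR question text
    "smry": "...",        # statistical summary text from the Stage 2 notebook
    "trend_text": "...",
    "corr_text": "...",
    "info_text": "...",
}
```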
150 |
151 | ## 🌍 Citation
152 |
153 | If you find this repository useful for your research, please cite our paper:
154 |
155 | ```
156 | @misc{li2025sensorllm,
157 | title={SensorLLM: Human-Intuitive Alignment of Multivariate Sensor Data with LLMs for Activity Recognition},
158 | author={Zechen Li and Shohreh Deldari and Linyao Chen and Hao Xue and Flora D. Salim},
159 | year={2025},
160 | eprint={2410.10624},
161 | archivePrefix={arXiv},
162 | primaryClass={cs.CL},
163 | url={https://arxiv.org/abs/2410.10624},
164 | }
165 | ```
166 |
167 | ## 📄 License
168 |
169 |
170 |
171 | This work is licensed under the [Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License](https://creativecommons.org/licenses/by-nc-sa/4.0/).
172 |
173 |
174 | ## 📩 Contact
175 |
176 | If you have any questions or suggestions, feel free to contact Zechen at `zechen.li(at)unsw(dot)edu(dot)au`.
177 |
--------------------------------------------------------------------------------
/sensorllm/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zechenli03/SensorLLM/1f4142a30f452721e943771190fe1dade3337249/sensorllm/__init__.py
--------------------------------------------------------------------------------
/sensorllm/data/__init__.py:
--------------------------------------------------------------------------------
1 | from .stage1_dataset import make_ts_text_data_module, UniChannelTimeSeriesDataset
2 | from .stage2_dataset import make_ts_text_data_module_stage2, MultiChannelTimeSeriesDatasetStage2, make_ts_classification_data_module_stage2
--------------------------------------------------------------------------------
/sensorllm/data/stage1_dataset.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Sequence
2 | from torch.utils.data import Dataset
3 | import numpy as np
4 | import random
5 | import copy
6 | import json
7 | import pickle
8 | import logging
9 | from dataclasses import dataclass
10 |
11 | from sensorllm.data.utils import generate_chat_template, preprocess, get_token_dict
12 | from sensorllm.model.chronos_model import *
13 |
14 | import transformers
15 |
16 | import torch
17 |
18 | IGNORE_INDEX = -100
19 | RDM_SEED = 42
20 |
21 |
22 | def preprocess_time_series2(
23 | sources: Sequence[Dict[str, str]], # [{"Q": "...", "A": "...", "type": ...}]
24 | channel_names: Sequence[str],
25 | ts_list: list,
26 | dataset: str,
27 | data_args: dict,
28 | ) -> Sequence[Dict[str, str]]:
29 | # assert len(sources) == 6
30 | ts_token = data_args["default_ts_token"]
31 | modified_sources = []
32 | start_tokens_dict, end_tokens_dict = get_token_dict(dataset, data_args)
33 |
34 | for source, channel_name, ts in zip(sources, channel_names, ts_list):
35 | if data_args["last_token"]:
36 | added_token = ts_token * (len(ts) + 1)
37 | else:
38 | added_token = ts_token * len(ts)
39 | assert channel_name in list(start_tokens_dict.keys()), f"Start token {channel_name} not found"
40 | assert channel_name in list(end_tokens_dict.keys()), f"End token {channel_name} not found"
41 | start_token = start_tokens_dict[channel_name]
42 | end_token = end_tokens_dict[channel_name]
43 | modified_q = start_token + added_token + end_token + source["Q"]
44 | modified_a = source["A"]+"\n\n"+source["summary"]["A"]
45 | modified_sources.append({"Q": modified_q, "A": modified_a, "type": source["type"]})
46 | return modified_sources
47 |
48 |
49 | class UniChannelTimeSeriesDataset(Dataset):
50 | def __init__(self, data_path=None, qa_path=None, tokenizer=None, chronos_tokenizer=None, split=None, data_args=None):
51 | """
52 | data_path: a tensor of shape (N, C, L) where N is the number of multichannel time-series samples,
53 | C is the number of channels (6), and L is the sequence length (200).
54 | qa_path: a list of QA texts corresponding to each channel of each sample.
55 | """
56 |
57 | super(UniChannelTimeSeriesDataset, self).__init__()
58 | self.data_path = data_path
59 | self.qa_path = qa_path
60 | self.tokenizer = tokenizer
61 | self.chronos_tokenizer = chronos_tokenizer
62 | self.split = split
63 |
64 | ignore_qa_types = data_args.ignore_qa_types
65 | self.dataset = data_args.dataset
66 |
67 | shuffle = data_args.shuffle
68 | self.data_args = data_args.ts_backbone_config[self.dataset]
69 | self.data_args["default_ts_token"] = data_args.ts_backbone_config["default_ts_token"]
70 | self.data_args["last_token"] = data_args.ts_backbone_config["chronos_model"]["last_token"]
71 |
72 | self.SYS_INST = f"A dialogue between a curious researcher and an AI assistant. The AI analyzes a sensor time-series dataset (N points, {self.data_args['sample_rate']}Hz sampling rate) to answer specific questions. This interaction demonstrates the AI's data analysis skills and the potential of human-AI collaboration in interpreting complex data."
73 | print(f"INSTRUCTION Template: {self.SYS_INST}")
74 | self.ts_data, self.list_data_dict, self.channel_list = self._flatten_data(ignore_qa_types, shuffle)
75 |
76 | print(
77 | f"The dataset size is: {len(self.list_data_dict)}."
78 | )
79 |
80 | def _flatten_data(self, ignore_qa_types: list, shuffle: bool):
81 | logging.warning("Loading data...")
82 | with open(self.data_path, "rb") as f:
83 | data_file = pickle.load(f)
84 | with open(self.qa_path, "r") as file:
85 | qa_file = json.load(file)
86 | qa_dict = []
87 | ts_data = []
88 | channel_list = []
89 | for d in qa_file["dataset"]:
90 | data_idx = d["index"]
91 | data = data_file[int(data_idx)]
92 | if self.dataset in ["usc-had", "uci"]:
93 | for x_acc, y_acc, z_acc, x_g, y_g, z_g in zip(
94 | d["qa_pairs"]["x-axis accelerometer"],
95 | d["qa_pairs"]["y-axis accelerometer"],
96 | d["qa_pairs"]["z-axis accelerometer"],
97 | d["qa_pairs"]["x-axis gyroscope"],
98 | d["qa_pairs"]["y-axis gyroscope"],
99 | d["qa_pairs"]["z-axis gyroscope"],
100 | ):
101 | assert x_acc["type"] == y_acc["type"] == z_acc["type"] == x_g["type"] == y_g["type"] == z_g[
102 | "type"], "QA type values error"
103 | # if x_acc["type"] not in ["sub_trend_no_val", "trend_table"]:
104 | if x_acc["type"] not in ignore_qa_types:
105 | x_acc["summary"] = d['summaries']["x-axis accelerometer"]
106 | y_acc["summary"] = d['summaries']["y-axis accelerometer"]
107 | z_acc["summary"] = d['summaries']["z-axis accelerometer"]
108 | x_g["summary"] = d['summaries']["x-axis gyroscope"]
109 | y_g["summary"] = d['summaries']["y-axis gyroscope"]
110 | z_g["summary"] = d['summaries']["z-axis gyroscope"]
111 | qa_dict.append([x_acc, y_acc, z_acc, x_g, y_g, z_g])
112 | ts_data.append(
113 | [torch.from_numpy(data[:, i]).to(torch.float64) for i in range(data.shape[1])]
114 | )
115 | channel_list.extend(["x_acc", "y_acc", "z_acc", "x_g", "y_g", "z_g"])
116 | elif self.dataset == "capture24":
117 | for x_acc, y_acc, z_acc in zip(
118 | d["qa_pairs"]["x-axis accelerometer"],
119 | d["qa_pairs"]["y-axis accelerometer"],
120 | d["qa_pairs"]["z-axis accelerometer"]
121 | ):
122 | assert x_acc["type"] == y_acc["type"] == z_acc["type"], "QA type values error"
123 | if x_acc["type"] not in ignore_qa_types:
124 | x_acc["summary"] = d['summaries']["x-axis accelerometer"]
125 | y_acc["summary"] = d['summaries']["y-axis accelerometer"]
126 | z_acc["summary"] = d['summaries']["z-axis accelerometer"]
127 | qa_dict.append([x_acc, y_acc, z_acc])
128 | ts_data.append(
129 | [torch.from_numpy(data[:, i]).to(torch.float64) for i in range(data.shape[1])]
130 | )
131 | channel_list.extend(["x_acc", "y_acc", "z_acc"])
132 | elif self.dataset == "mhealth":
133 | for c_acc_x, c_acc_y, c_acc_z, la_acc_x, la_acc_y, la_acc_z, la_gs_x, la_gs_y, la_gs_z, rla_acc_x, rla_acc_y, rla_acc_z, rla_gs_x, rla_gs_y, rla_gs_z in zip(
134 | d["qa_pairs"]["chest x-axis accelerometer"], d["qa_pairs"]["chest y-axis accelerometer"], d["qa_pairs"]["chest z-axis accelerometer"],
135 | d["qa_pairs"]["left-ankle x-axis accelerometer"], d["qa_pairs"]["left-ankle y-axis accelerometer"], d["qa_pairs"]["left-ankle z-axis accelerometer"],
136 | d["qa_pairs"]["left-ankle x-axis gyroscope"], d["qa_pairs"]["left-ankle y-axis gyroscope"], d["qa_pairs"]["left-ankle z-axis gyroscope"],
137 | d["qa_pairs"]["right-lower-arm x-axis accelerometer"], d["qa_pairs"]["right-lower-arm y-axis accelerometer"], d["qa_pairs"]["right-lower-arm z-axis accelerometer"],
138 | d["qa_pairs"]["right-lower-arm x-axis gyroscope"], d["qa_pairs"]["right-lower-arm y-axis gyroscope"], d["qa_pairs"]["right-lower-arm z-axis gyroscope"]
139 | ):
140 | assert c_acc_x["type"] == c_acc_y["type"] == c_acc_z["type"] == la_acc_x["type"] == la_acc_y[
141 | "type"] == la_acc_z["type"] == la_gs_x["type"] == la_gs_y["type"] == la_gs_z[
142 | "type"] == rla_acc_x["type"] == rla_acc_y["type"] == rla_acc_z["type"] == rla_gs_x[
143 | "type"] == rla_gs_y["type"] == rla_gs_z["type"], "QA type values error"
144 |
145 | if c_acc_x["type"] not in ignore_qa_types:
146 | c_acc_x["summary"] = d['summaries']["chest x-axis accelerometer"]
147 | c_acc_y["summary"] = d['summaries']["chest y-axis accelerometer"]
148 | c_acc_z["summary"] = d['summaries']["chest z-axis accelerometer"]
149 | la_acc_x["summary"] = d['summaries']["left-ankle x-axis accelerometer"]
150 | la_acc_y["summary"] = d['summaries']["left-ankle y-axis accelerometer"]
151 |                         la_acc_z["summary"] = d['summaries']["left-ankle z-axis accelerometer"]
152 | la_gs_x["summary"] = d['summaries']["left-ankle x-axis gyroscope"]
153 | la_gs_y["summary"] = d['summaries']["left-ankle y-axis gyroscope"]
154 | la_gs_z["summary"] = d['summaries']["left-ankle z-axis gyroscope"]
155 | rla_acc_x["summary"] = d['summaries']["right-lower-arm x-axis accelerometer"]
156 | rla_acc_y["summary"] = d['summaries']["right-lower-arm y-axis accelerometer"]
157 | rla_acc_z["summary"] = d['summaries']["right-lower-arm z-axis accelerometer"]
158 | rla_gs_x["summary"] = d['summaries']["right-lower-arm x-axis gyroscope"]
159 | rla_gs_y["summary"] = d['summaries']["right-lower-arm y-axis gyroscope"]
160 | rla_gs_z["summary"] = d['summaries']["right-lower-arm z-axis gyroscope"]
161 | qa_dict.append([c_acc_x, c_acc_y, c_acc_z, la_acc_x, la_acc_y, la_acc_z, la_gs_x, la_gs_y, la_gs_z, rla_acc_x, rla_acc_y, rla_acc_z, rla_gs_x, rla_gs_y, rla_gs_z])
162 | ts_data.append(
163 | [torch.from_numpy(data[:, i]).to(torch.float64) for i in range(data.shape[1])]
164 | )
165 | channel_list.extend(["c_acc_x", "c_acc_y", "c_acc_z", "la_acc_x", "la_acc_y", "la_acc_z", "la_gs_x", "la_gs_y", "la_gs_z", "rla_acc_x", "rla_acc_y", "rla_acc_z", "rla_gs_x", "rla_gs_y", "rla_gs_z"])
166 | elif self.dataset == "pamap" or self.dataset == "pamap50":
167 | for acc_hand_x, acc_hand_y, acc_hand_z, gyr_hand_x, gyr_hand_y, gyr_hand_z, mag_hand_x, mag_hand_y, \
168 | mag_hand_z, acc_chest_x, acc_chest_y, acc_chest_z, gyr_chest_x, gyr_chest_y, gyr_chest_z, \
169 | mag_chest_x, mag_chest_y, mag_chest_z, acc_ankle_x, acc_ankle_y, acc_ankle_z, gyr_ankle_x, \
170 | gyr_ankle_y, gyr_ankle_z, mag_ankle_x, mag_ankle_y, mag_ankle_z in zip(
171 | d["qa_pairs"]["hand x-axis accelerometer"], d["qa_pairs"]["hand y-axis accelerometer"], d["qa_pairs"]["hand z-axis accelerometer"], d["qa_pairs"]["hand x-axis gyroscope"], d["qa_pairs"]["hand y-axis gyroscope"],
172 | d["qa_pairs"]["hand z-axis gyroscope"],d["qa_pairs"]["hand x-axis magnetometer"], d["qa_pairs"]["hand y-axis magnetometer"], d["qa_pairs"]["hand z-axis magnetometer"],d["qa_pairs"]["chest x-axis accelerometer"],
173 | d["qa_pairs"]["chest y-axis accelerometer"], d["qa_pairs"]["chest z-axis accelerometer"],d["qa_pairs"]["chest x-axis gyroscope"], d["qa_pairs"]["chest y-axis gyroscope"], d["qa_pairs"]["chest z-axis gyroscope"],
174 | d["qa_pairs"]["chest x-axis magnetometer"], d["qa_pairs"]["chest y-axis magnetometer"], d["qa_pairs"]["chest z-axis magnetometer"],d["qa_pairs"]["ankle x-axis accelerometer"], d["qa_pairs"]["ankle y-axis accelerometer"],
175 | d["qa_pairs"]["ankle z-axis accelerometer"],d["qa_pairs"]["ankle x-axis gyroscope"], d["qa_pairs"]["ankle y-axis gyroscope"], d["qa_pairs"]["ankle z-axis gyroscope"],d["qa_pairs"]["ankle x-axis magnetometer"],
176 | d["qa_pairs"]["ankle y-axis magnetometer"], d["qa_pairs"]["ankle z-axis magnetometer"]
177 |
178 | ):
179 | assert acc_hand_x["type"] == acc_hand_y["type"] == acc_hand_z["type"] == gyr_hand_x["type"] == \
180 | gyr_hand_y[
181 | "type"] == gyr_hand_z["type"] == mag_hand_x["type"] == mag_hand_y["type"] == mag_hand_z[
182 | "type"] == acc_chest_x["type"] == acc_chest_y["type"] == acc_chest_z["type"] == \
183 | gyr_chest_x[
184 | "type"] == gyr_chest_y["type"] == gyr_chest_z["type"] == mag_chest_x["type"] == \
185 | mag_chest_y[
186 | "type"] == mag_chest_z["type"] == acc_ankle_x["type"] == acc_ankle_y["type"] == \
187 | acc_ankle_z[
188 | "type"] == gyr_ankle_x["type"] == gyr_ankle_y["type"] == gyr_ankle_z["type"] == \
189 | mag_ankle_x[
190 | "type"] == mag_ankle_y["type"] == mag_ankle_z["type"], "QA type values error"
191 |
192 | if acc_hand_x["type"] not in ignore_qa_types:
193 | acc_hand_x["summary"] = d['summaries']["hand x-axis accelerometer"]
194 | acc_hand_y["summary"] = d['summaries']["hand y-axis accelerometer"]
195 | acc_hand_z["summary"] = d['summaries']["hand z-axis accelerometer"]
196 | gyr_hand_x["summary"] = d['summaries']["hand x-axis gyroscope"]
197 | gyr_hand_y["summary"] = d['summaries']["hand y-axis gyroscope"]
198 | gyr_hand_z["summary"] = d['summaries']["hand z-axis gyroscope"]
199 | mag_hand_x["summary"] = d['summaries']["hand x-axis magnetometer"]
200 | mag_hand_y["summary"] = d['summaries']["hand y-axis magnetometer"]
201 | mag_hand_z["summary"] = d['summaries']["hand z-axis magnetometer"]
202 |
203 | acc_chest_x["summary"] = d['summaries']["chest x-axis accelerometer"]
204 | acc_chest_y["summary"] = d['summaries']["chest y-axis accelerometer"]
205 | acc_chest_z["summary"] = d['summaries']["chest z-axis accelerometer"]
206 | gyr_chest_x["summary"] = d['summaries']["chest x-axis gyroscope"]
207 | gyr_chest_y["summary"] = d['summaries']["chest y-axis gyroscope"]
208 | gyr_chest_z["summary"] = d['summaries']["chest z-axis gyroscope"]
209 | mag_chest_x["summary"] = d['summaries']["chest x-axis magnetometer"]
210 | mag_chest_y["summary"] = d['summaries']["chest y-axis magnetometer"]
211 | mag_chest_z["summary"] = d['summaries']["chest z-axis magnetometer"]
212 |
213 | acc_ankle_x["summary"] = d['summaries']["ankle x-axis accelerometer"]
214 | acc_ankle_y["summary"] = d['summaries']["ankle y-axis accelerometer"]
215 | acc_ankle_z["summary"] = d['summaries']["ankle z-axis accelerometer"]
216 | gyr_ankle_x["summary"] = d['summaries']["ankle x-axis gyroscope"]
217 | gyr_ankle_y["summary"] = d['summaries']["ankle y-axis gyroscope"]
218 | gyr_ankle_z["summary"] = d['summaries']["ankle z-axis gyroscope"]
219 | mag_ankle_x["summary"] = d['summaries']["ankle x-axis magnetometer"]
220 | mag_ankle_y["summary"] = d['summaries']["ankle y-axis magnetometer"]
221 | mag_ankle_z["summary"] = d['summaries']["ankle z-axis magnetometer"]
222 |
223 | qa_dict.append(
224 | [acc_hand_x, acc_hand_y, acc_hand_z, gyr_hand_x, gyr_hand_y, gyr_hand_z, mag_hand_x,
225 | mag_hand_y, mag_hand_z,
226 | acc_chest_x, acc_chest_y, acc_chest_z, gyr_chest_x, gyr_chest_y, gyr_chest_z, mag_chest_x,
227 | mag_chest_y, mag_chest_z,
228 | acc_ankle_x, acc_ankle_y, acc_ankle_z, gyr_ankle_x, gyr_ankle_y, gyr_ankle_z, mag_ankle_x,
229 | mag_ankle_y, mag_ankle_z])
230 | ts_data.append(
231 | [torch.from_numpy(data[:, i]).to(torch.float64) for i in range(data.shape[1])]
232 | )
233 | channel_list.extend([
234 | "acc_hand_x", "acc_hand_y", "acc_hand_z",
235 | "gyr_hand_x", "gyr_hand_y", "gyr_hand_z",
236 | "mag_hand_x", "mag_hand_y", "mag_hand_z",
237 | "acc_chest_x", "acc_chest_y", "acc_chest_z",
238 | "gyr_chest_x", "gyr_chest_y", "gyr_chest_z",
239 | "mag_chest_x", "mag_chest_y", "mag_chest_z",
240 | "acc_ankle_x", "acc_ankle_y", "acc_ankle_z",
241 | "gyr_ankle_x", "gyr_ankle_y", "gyr_ankle_z",
242 | "mag_ankle_x", "mag_ankle_y", "mag_ankle_z"
243 | ])
244 | else:
245 | raise ValueError(f"Wrong dataset name in _flatten_data: {self.dataset}")
246 |
247 |
248 | assert len(ts_data) == len(qa_dict), "ts_data, qa_dict shape mismatched"
249 |
250 | if shuffle:
251 | print("Shuffling data...")
252 | random.seed(RDM_SEED)
253 | indexes = list(range(len(qa_dict)))
254 | random.shuffle(indexes)
255 |
256 | qa_dict = [qa_dict[i] for i in indexes]
257 | ts_data = [ts_data[i] for i in indexes]
258 |
259 | qa_dict = [item for sublist in qa_dict for item in sublist]
260 | ts_data = [item for sublist in ts_data for item in sublist]
261 |
262 | return ts_data, qa_dict, channel_list
263 |
264 | def __len__(self):
265 | return len(self.list_data_dict)
266 |
267 | def __getitem__(self, index):
268 | sources = self.list_data_dict[index]
269 | channels = self.channel_list[index]
270 | ts = self.ts_data[index] # 1 * L
271 | if isinstance(index, int):
272 | sources = [sources]
273 | channels = [channels]
274 | ts = [ts]
275 | assert len(sources) == 1, "sources should be a list"
276 |
277 | sources = preprocess_time_series2(
278 | copy.deepcopy(sources), copy.deepcopy(channels), copy.deepcopy(ts), self.dataset, self.data_args
279 | )
280 |
281 | ts_token_ids_list = []
282 | ts_attention_mask_list = []
283 | ts_tokenizer_state_list = []
284 | for context in ts:
285 | if isinstance(context, list):
286 | context = left_pad_and_stack_1D(context)
287 | assert isinstance(context, torch.Tensor)
288 | if context.ndim == 1:
289 | context = context.unsqueeze(0)
290 | assert context.ndim == 2
291 |
292 | ts_token_ids, ts_attention_mask, ts_tokenizer_state = (
293 | self.chronos_tokenizer.context_input_transform(context)
294 | )
295 | ts_token_ids_list.append(ts_token_ids)
296 | ts_attention_mask_list.append(ts_attention_mask)
297 | ts_tokenizer_state_list.append(ts_tokenizer_state)
298 |
299 | if self.tokenizer is None:
300 | data_dict = dict(
301 | question=sources[0]["Q"],
302 | ground_truth=sources[0]["A"],
303 | type=sources[0]["type"],
304 | ts_token_ids=ts_token_ids_list[0],
305 | ts_attention_mask=ts_attention_mask_list[0],
306 | ts_tokenizer_state=ts_tokenizer_state_list[0]
307 | )
308 | return data_dict
309 |
310 | data_dict = preprocess(sources, self.tokenizer, self.SYS_INST, self.split, "Q", "Q")
311 |
312 | data_dict = dict(input_ids=data_dict["input_ids"][0],
313 | input_texts=sources[0]["Q"],
314 | answer=sources[0]["A"],
315 | labels=data_dict["labels"][0],
316 | ts_token_ids=ts_token_ids_list[0],
317 | ts_attention_mask=ts_attention_mask_list[0],
318 | ts_tokenizer_state=ts_tokenizer_state_list[0])
319 |
320 | return data_dict
321 |
322 |
323 | @dataclass
324 | class DataCollatorForTsTextDataset(object):
325 | """Collate examples for supervised fine-tuning."""
326 |
327 | tokenizer: transformers.PreTrainedTokenizer
328 |
329 | def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
330 | input_ids, labels, ts_token_ids, ts_attention_mask, ts_tokenizer_state = tuple(
331 | [instance[key] for instance in instances]
332 | for key in ("input_ids", "labels", "ts_token_ids", "ts_attention_mask", "ts_tokenizer_state")
333 | )
334 |
335 | input_ids = torch.nn.utils.rnn.pad_sequence(
336 | input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
337 | )
338 | labels = torch.nn.utils.rnn.pad_sequence(
339 | labels, batch_first=True, padding_value=IGNORE_INDEX
340 | )
341 |
342 | return dict(
343 | input_ids=input_ids,
344 | labels=labels,
345 | attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
346 | ts_token_ids=ts_token_ids, # return as list
347 | ts_attention_mask=ts_attention_mask,
348 | ts_tokenizer_state=ts_tokenizer_state
349 | )
350 |
351 |
352 | def make_ts_text_data_module(
353 | tokenizer: transformers.PreTrainedTokenizer, chronos_tokenizer, data_args
354 | ) -> Dict:
355 | """Make dataset and collator for supervised fine-tuning."""
356 | data_collator = DataCollatorForTsTextDataset(tokenizer=tokenizer)
357 | train_dataset = UniChannelTimeSeriesDataset(
358 | data_path=data_args.data_path, qa_path=data_args.qa_path, tokenizer=tokenizer, chronos_tokenizer=chronos_tokenizer, split="train", data_args=data_args
359 | )
360 | eval_dataset = UniChannelTimeSeriesDataset(
361 | data_path=data_args.eval_data_path, qa_path=data_args.eval_qa_path, tokenizer=tokenizer, chronos_tokenizer=chronos_tokenizer, split="train", data_args=data_args
362 | )
363 | for i in range(2):
364 | print(f"data example {i}:\nInput Text: {train_dataset[i]['input_texts']}\nAnswer: {train_dataset[i]['answer']}\nInput ids: {train_dataset[i]['input_ids']}\nTS token ids: {train_dataset[i]['ts_token_ids']}\n")
365 |
366 | return dict(
367 | train_dataset=train_dataset, eval_dataset=eval_dataset, data_collator=data_collator
368 | )
369 |
--------------------------------------------------------------------------------
/sensorllm/data/stage2_dataset.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Sequence
2 | from torch.utils.data import Dataset
3 | import numpy as np
4 | import pandas as pd
5 | import random
6 | import copy
7 | import json
8 | import re
9 | import pickle
10 | import logging
11 | from dataclasses import dataclass
12 |
13 | from sensorllm.data.utils import generate_chat_template, preprocess, preprocess_cls, get_token_list
14 | from sensorllm.model.chronos_model import *
15 | import transformers
16 |
17 | import torch
18 |
19 | IGNORE_INDEX = -100
20 | RDM_SEED = 42
21 |
22 |
23 | def preprocess_time_series_stage2(
24 | sources: Sequence[Dict[str, str]], # [{"Q": "...", "A": "..."}]
25 | added_str: str,
26 | ) -> Sequence[Dict[str, str]]:
27 | modified_sources = []
28 | pattern = r'\b\d+\.\s+[A-Za-z\s]+\.'
29 | for index, source in enumerate(sources):
30 | modified_q = added_str + source["Q"]
31 |
32 | matches = re.findall(pattern, source["A"])
33 | assert len(matches) == 1
34 | cot = source["A"].replace(matches[-1], '')
35 |
36 | modified_sources.append({"Q": modified_q, "A": source["A"], "cot": cot.strip(), "ground_truth": matches[-1]})
37 | return modified_sources
38 |
39 |
40 | def preprocess_time_series_CLS_stage2(
41 | sources: Sequence[Dict[str, str]], # [{"Q": "...", "A": "..."}]
42 | ) -> Sequence[Dict[str, str]]:
43 | modified_sources = []
44 | for index, source in enumerate(sources):
45 | modified_sources.append({
46 | "Q": source.get("Q", ""),
47 | "smry": source.get("smry", ""),
48 | "trend_text": source.get("trend_text", ""),
49 | "corr_text": source.get("corr_text", ""),
50 | "info_text": source.get("info_text", ""),
51 | "answer": source.get("A", ""),
52 | "label": source.get("label", "")
53 | })
54 | return modified_sources
55 |
56 |
57 | class MultiChannelTimeSeriesDatasetStage2(Dataset):
58 | def __init__(self, data_path=None, qa_path=None, tokenizer=None, chronos_tokenizer=None, split=None, data_args=None):
59 | super(MultiChannelTimeSeriesDatasetStage2, self).__init__()
60 | self.data_path = data_path
61 | self.qa_path = qa_path
62 | self.tokenizer = tokenizer
63 | self.chronos_tokenizer = chronos_tokenizer
64 | self.split = split
65 | self.preprocess_type = data_args.preprocess_type
66 | self.preprocess_type_eval = None if tokenizer is None else data_args.preprocess_type_eval
67 |
68 | shuffle = data_args.shuffle
69 | dataset = data_args.dataset
70 |
71 | if dataset == 'usc-had':
72 | self.SYS_INST = "The assistant is provided with time-series readings of six sensor channels, including three accelerometer channels (in g) and three gyroscope channels (in dps). Each channel contains 200 data representing information extracted from the same 2-second time window at a sampling rate of 100Hz. Please analyze the trends and patterns in each channel to identify the correct activity type from the following twelve activity options:\n\n1. Walking Forward\n2. Walking Left\n3. Walking Right\n4. Walking Upstairs\n5. Walking Downstairs\n6. Running Forward\n7. Jumping\n8. Sitting\n9. Standing\n10. Sleeping\n11. Elevator Up\n12. Elevator Down\n\nProvide the predicted activity as both the number and the name at the end."
73 | elif dataset == 'capture24':
74 | self.SYS_INST = "The assistant is provided with time-series readings of three accelerometer channels (in g). Each channel contains 500 data representing information extracted from the same 10-second time window at a sampling rate of 50Hz. Please analyze the trends and patterns in each channel to identify the correct activity type from the following 10 activity options:\n\n1. sleep\n2. sitting\n3. household-chores\n4. walking\n5. vehicle\n6. bicycling\n7. mixed-activity\n8. standing\n9. manual-work\n10. sports\n\nProvide the predicted activity as both the number and the name at the end."
75 | elif dataset == 'mhealth':
76 | self.SYS_INST = "The assistant is provided with time-series readings of 15 sensor channels, including acceleration sensors (in m/s^2) and gyroscope sensors (in deg/s). Each channel contains 100 data representing information extracted from the same 2-second time window at a sampling rate of 50Hz. Please analyze the trends and patterns in each channel to identify the correct activity type from the following twelve activity options:\n\n1. Standing still\n2. Sitting and relaxing\n3. Lying down\n4. Walking\n5. Climbing stairs\n6. Waist bends forward\n7. Frontal elevation of arms\n8. Knees bending (crouching)\n9. Cycling\n10. Jogging\n11. Running\n12. Jump front & back\n\nProvide the predicted activity as both the number and the name at the end."
77 | elif dataset == 'pamap50':
78 | self.SYS_INST = "The assistant is provided with time-series readings of 27 sensor channels, including acceleration sensors (in m/s^2) and gyroscope sensors (in rad/s) and magnetometer sensors (in μT). Each channel contains 100 data representing information extracted from the same 2-second time window at a sampling rate of 50Hz. Please analyze the trends and patterns in each channel to identify the correct activity type from the following twelve activity options:\n\n1. lying\n2. sitting\n3. standing\n4. walking\n5. running\n6. cycling\n7. Nordic walking\n8. ascending stairs\n9. descending stairs\n10. vacuum cleaning\n11. ironing\n12. rope jumping\n\nProvide the predicted activity as both the number and the name at the end."
79 | elif dataset == 'pamap':
80 | self.SYS_INST = "The assistant is provided with time-series readings of 27 sensor channels, including acceleration sensors (in m/s^2) and gyroscope sensors (in rad/s) and magnetometer sensors (in μT). Each channel contains 100 data representing information extracted from the same 2-second time window at a sampling rate of 100Hz. Please analyze the trends and patterns in each channel to identify the correct activity type from the following twelve activity options:\n\n1. lying\n2. sitting\n3. standing\n4. walking\n5. running\n6. cycling\n7. Nordic walking\n8. ascending stairs\n9. descending stairs\n10. vacuum cleaning\n11. ironing\n12. rope jumping\n\nProvide the predicted activity as both the number and the name at the end."
81 | elif dataset == 'uci':
82 |             self.SYS_INST = "The assistant is provided with time-series readings of six sensor channels, including three accelerometer channels (in g) and three gyroscope channels (in dps). Each channel contains 128 data representing information extracted from the same 2.56-second time window at a sampling rate of 50Hz. Please analyze the trends and patterns in each channel to identify the correct activity type from the following six activity options:\n\n1. Walking Forward\n2. Walking Upstairs\n3. Walking Downstairs\n4. Sitting\n5. Standing\n6. Laying\n\nProvide the predicted activity as both the number and the name at the end."
83 | else:
84 | raise ValueError(f"Wrong dataset name in __init__: {dataset}")
85 |
86 | self.ts_data, self.list_data_dict = self._flatten_data(shuffle)
87 |
88 | self.data_args = data_args.ts_backbone_config
89 |
90 | self.window_length = len(self.ts_data[0][0])
91 | self.channel_num = len(self.ts_data[0])
92 | assert self.channel_num == self.data_args[dataset]["channel_num"], "channel_num, data_args.channel_num shape mismatched"
93 |
94 | print(
95 | f"The dataset size is: {len(self.ts_data)}. Window size: {self.window_length}. Channel num: {self.channel_num}."
96 | )
97 |
98 | if self.data_args["chronos_model"]["last_token"]:
99 | added_token = self.data_args["default_ts_token"] * (self.window_length + 1)
100 | else:
101 | added_token = self.data_args["default_ts_token"] * self.window_length
102 |
103 | start_tokens_list, end_tokens_list = get_token_list(dataset, self.data_args[dataset], data_args.add_ts_special_token_text)
104 |
105 | added_str = ''
106 | for start_token, end_token in zip(start_tokens_list, end_tokens_list):
107 | added_str += start_token + added_token + end_token + '\n'
108 | self.added_str = added_str
109 |
110 | def _flatten_data(self, shuffle: bool):
111 | logging.warning(f"Loading {self.split} data...")
112 | with open(self.data_path, "rb") as f:
113 | data_file = pickle.load(f)
114 | with open(self.qa_path, "r") as file:
115 | qa_file = json.load(file)
116 | data_file = np.array(data_file, dtype=np.float64)
117 | ts_data = []
118 | qa_dict = []
119 | assert len(data_file) == len(qa_file["dataset"])
120 | for q in qa_file["dataset"]:
121 | data_idx = q["index"]
122 | data = data_file[int(data_idx)]
123 | ts_data.append([torch.from_numpy(data[:, i]).to(torch.float64) for i in range(data.shape[1])])
124 | qa_dict.append(q["qa_pair"])
125 | assert len(ts_data) == len(qa_dict), "ts_data, qa_dict, length not matched"
126 |
127 | if shuffle:
128 | print("Shuffling data...")
129 | random.seed(RDM_SEED)
130 | indexes = list(range(len(ts_data)))
131 | random.shuffle(indexes)
132 | ts_data = [ts_data[i] for i in indexes]
133 | qa_dict = [qa_dict[i] for i in indexes]
134 | #
135 | # if self.split == "eval":
136 | # return ts_data[:100], qa_dict[:100]
137 |
138 | return ts_data, qa_dict
139 |
140 | def __len__(self):
141 | return len(self.ts_data)
142 |
143 | def __getitem__(self, index):
144 | sources = self.list_data_dict[index] # {"Q": ..., "A": ...}
145 | multichannel_ts = self.ts_data[index] # C * L, 6 * 200
146 |
147 | if isinstance(index, int):
148 | sources = [sources]
149 | multichannel_ts = [multichannel_ts]
150 |
151 | assert (
152 | len(sources) == 1
153 | ), "sources should be a list"
154 |
155 | sources = preprocess_time_series_stage2(
156 | copy.deepcopy(sources), self.added_str
157 | )
158 |
159 | mts_token_ids_list = []
160 | mts_attention_mask_list = []
161 | mts_tokenizer_state_list = []
162 | for ts in multichannel_ts:
163 | context = torch.stack(ts)
164 | if isinstance(context, list):
165 | context = left_pad_and_stack_1D(context)
166 | assert isinstance(context, torch.Tensor)
167 | if context.ndim == 1:
168 | context = context.unsqueeze(0)
169 | assert context.ndim == 2
170 |
171 | mts_token_ids, mts_attention_mask, mts_tokenizer_state = (
172 | self.chronos_tokenizer.context_input_transform(context)
173 | )
174 | mts_token_ids_list.append(mts_token_ids)
175 | mts_attention_mask_list.append(mts_attention_mask)
176 | mts_tokenizer_state_list.append(mts_tokenizer_state)
177 |
178 | if self.tokenizer is None:
179 | data_dict = dict(
180 | question=sources[0]["Q"],
181 | answer=sources[0]["A"],
182 | cot=sources[0]["cot"],
183 | ground_truth=sources[0]["ground_truth"],
184 | mts_token_ids=mts_token_ids_list[0],
185 | mts_attention_mask=mts_attention_mask_list[0],
186 | mts_tokenizer_state=mts_tokenizer_state_list[0]
187 | )
188 | return data_dict
189 |
190 | data_dict = preprocess(sources, self.tokenizer, self.SYS_INST, self.split, self.preprocess_type, self.preprocess_type_eval)
191 |
192 | data_dict = dict(input_ids=data_dict["input_ids"][0],
193 | labels=data_dict["labels"][0],
194 | mts_token_ids=mts_token_ids_list[0],
195 | mts_attention_mask=mts_attention_mask_list[0],
196 | mts_tokenizer_state=mts_tokenizer_state_list[0])
197 | return data_dict
198 |
199 |
200 | @dataclass
201 | class DataCollatorForTsTextDatasetStage2(object):
202 | """Collate examples for supervised fine-tuning."""
203 |
204 | tokenizer: transformers.PreTrainedTokenizer
205 |
206 | def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
207 | input_ids, labels, mts_token_ids, mts_attention_mask, mts_tokenizer_state = tuple(
208 | [instance[key] for instance in instances]
209 | for key in ("input_ids", "labels", "mts_token_ids", "mts_attention_mask", "mts_tokenizer_state")
210 | )
211 |
212 | input_ids = torch.nn.utils.rnn.pad_sequence(
213 | input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
214 | )
215 | labels = torch.nn.utils.rnn.pad_sequence(
216 | labels, batch_first=True, padding_value=IGNORE_INDEX
217 | )
218 |
219 | return dict(
220 | input_ids=input_ids,
221 | labels=labels,
222 | attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
223 | mts_token_ids=torch.stack(mts_token_ids),
224 | mts_attention_mask=torch.stack(mts_attention_mask),
225 | mts_tokenizer_state=mts_tokenizer_state
226 | )
227 |
228 |
229 | def make_ts_text_data_module_stage2(
230 | tokenizer: transformers.PreTrainedTokenizer, chronos_tokenizer, data_args
231 | ) -> Dict:
232 | """Make dataset and collator for supervised fine-tuning."""
233 | data_collator = DataCollatorForTsTextDatasetStage2(tokenizer=tokenizer)
234 | train_dataset = MultiChannelTimeSeriesDatasetStage2(
235 | data_path=data_args.data_path, qa_path=data_args.qa_path, tokenizer=tokenizer, chronos_tokenizer=chronos_tokenizer, split="train", data_args=data_args
236 | )
237 | eval_dataset = MultiChannelTimeSeriesDatasetStage2(
238 | data_path=data_args.eval_data_path, qa_path=data_args.eval_qa_path, tokenizer=tokenizer, chronos_tokenizer=chronos_tokenizer, split="eval", data_args=data_args
239 | )
240 | return dict(
241 | train_dataset=train_dataset, eval_dataset=eval_dataset, data_collator=data_collator
242 | )
243 |
244 |
245 | @dataclass
246 | class DataCollatorForTsCLSDatasetStage2(object):
247 | """Collate examples for supervised fine-tuning."""
248 |
249 | tokenizer: transformers.PreTrainedTokenizer
250 |
251 | def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
252 | input_ids, labels, mts_token_ids, mts_attention_mask, mts_tokenizer_state = tuple(
253 | [instance[key] for instance in instances]
254 | for key in ("input_ids", "labels", "mts_token_ids", "mts_attention_mask", "mts_tokenizer_state")
255 | )
256 |
257 | input_ids = torch.nn.utils.rnn.pad_sequence(
258 | input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
259 | )
260 |
261 | return dict(
262 | input_ids=input_ids,
263 | labels=torch.tensor(labels),
264 | attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
265 | mts_token_ids=torch.stack(mts_token_ids),
266 | mts_attention_mask=torch.stack(mts_attention_mask),
267 | mts_tokenizer_state=mts_tokenizer_state
268 | )
269 |
270 |
271 | class MultiChannelTimeSeriesCLSDatasetStage2(Dataset):
272 | def __init__(self, data_path=None, qa_path=None, tokenizer=None, chronos_tokenizer=None, split=None, label2id=None, data_args=None):
273 | super(MultiChannelTimeSeriesCLSDatasetStage2, self).__init__()
274 | self.data_path = data_path
275 | self.qa_path = qa_path
276 | self.tokenizer = tokenizer
277 | self.chronos_tokenizer = chronos_tokenizer
278 | self.split = split
279 | self.label2id = label2id
280 | self.preprocess_type = data_args.preprocess_type
281 |
282 | shuffle = data_args.shuffle
283 | dataset = data_args.dataset
284 |
285 |
286 | self.ts_data, self.list_data_dict, self.class_weights = self._flatten_data(shuffle)
287 |
288 | self.data_args = data_args.ts_backbone_config
289 |
290 | self.window_length = len(self.ts_data[0][0])
291 | self.channel_num = len(self.ts_data[0])
292 | assert self.channel_num == self.data_args[dataset][
293 | "channel_num"], "channel_num, data_args.channel_num shape mismatched"
294 |
295 | print(
296 | f"The dataset size is: {len(self.ts_data)}. Window size: {self.window_length}. Channel num: {self.channel_num}."
297 | )
298 |
299 | if self.data_args["chronos_model"]["last_token"]:
300 | added_token = self.data_args["default_ts_token"] * (self.window_length + 1)
301 | else:
302 | added_token = self.data_args["default_ts_token"] * self.window_length
303 |
304 | start_tokens_list, end_tokens_list = get_token_list(dataset, self.data_args[dataset], data_args.add_ts_special_token_text)
305 |
306 | added_str = ''
307 | for start_token, end_token in zip(start_tokens_list, end_tokens_list):
308 | added_str += start_token + added_token + end_token + '\n'
309 | self.added_str = added_str
310 |
311 | def _flatten_data(self, shuffle: bool):
312 | logging.warning(f"Loading {self.split} data...")
313 | with open(self.data_path, "rb") as f:
314 | data_file = pickle.load(f)
315 | with open(self.qa_path, "r") as file:
316 | qa_file = json.load(file)
317 | ts_data = []
318 | qa_dict = []
319 | label_list = []
320 | assert len(data_file) == len(qa_file["dataset"])
321 | for q in qa_file["dataset"]:
322 | data_idx = q["index"]
323 | data = data_file[int(data_idx)]
324 | ts_data.append([torch.from_numpy(data[:, i]).to(torch.float64) for i in range(data.shape[1])])
325 | answer = q["qa_pair"]["A"]
326 | try:
327 | label = int(self.label2id[answer])
328 | except KeyError:
329 | raise ValueError(f"Text '{answer}' not found in label2id dictionary")
330 | label_list.append(label)
331 | q["qa_pair"]['label'] = label
332 | qa_dict.append(q["qa_pair"])
333 | # assert len(ts_data[0]) == 6, "ts_data channel length error"
334 | # assert len(ts_data[0][0]) == 200, "ts_data length error"
335 | # assert len(ts_data[0][1]) == 200, "ts_data length error"
336 | # assert len(ts_data[0][2]) == 200, "ts_data length error"
337 | assert len(ts_data) == len(qa_dict) == len(label_list), "ts_data, qa_dict, label_list, length not matched"
338 |
339 | class_weights = None
340 | if self.split == 'train':
341 | label_series = pd.Series(label_list)
342 | value_counts = label_series.value_counts(normalize=True)
343 | class_weights = (1 / value_counts.sort_index()).tolist()
344 | class_weights = torch.tensor(class_weights)
345 | class_weights = class_weights / class_weights.sum()
346 | if shuffle:
347 | print("Shuffling data...")
348 | random.seed(RDM_SEED)
349 | indexes = list(range(len(ts_data)))
350 | random.shuffle(indexes)
351 | ts_data = [ts_data[i] for i in indexes]
352 | qa_dict = [qa_dict[i] for i in indexes]
353 |
354 | # if self.split == "eval":
355 | # return ts_data[:100], qa_dict[:100]
356 |
357 | return ts_data, qa_dict, class_weights
358 |
359 | def __len__(self):
360 | return len(self.ts_data)
361 |
362 | def get_class_weights(self):
363 | return self.class_weights
364 |
365 | def __getitem__(self, index):
366 | sources = self.list_data_dict[index] # {"Q": ..., "A": ...}
367 | multichannel_ts = self.ts_data[index] # C * L, 6 * 200
368 |
369 | if isinstance(index, int):
370 | sources = [sources]
371 | multichannel_ts = [multichannel_ts]
372 |
373 | assert (
374 | len(sources) == 1
375 | ), "sources should be a list"
376 |
377 | sources = preprocess_time_series_CLS_stage2(
378 | copy.deepcopy(sources)
379 | )
380 |
381 | mts_token_ids_list = []
382 | mts_attention_mask_list = []
383 | mts_tokenizer_state_list = []
384 | for ts in multichannel_ts:
385 | context = torch.stack(ts)
386 | if isinstance(context, list):
387 | context = left_pad_and_stack_1D(context)
388 | assert isinstance(context, torch.Tensor)
389 | if context.ndim == 1:
390 | context = context.unsqueeze(0)
391 | assert context.ndim == 2
392 |
393 | mts_token_ids, mts_attention_mask, mts_tokenizer_state = (
394 | self.chronos_tokenizer.context_input_transform(context)
395 | )
396 | mts_token_ids_list.append(mts_token_ids)
397 | mts_attention_mask_list.append(mts_attention_mask)
398 | mts_tokenizer_state_list.append(mts_tokenizer_state)
399 |
400 | if self.tokenizer is None:
401 | data_dict = dict(
402 | added_str=self.added_str,
403 | question=sources[0]["Q"],
404 | smry=sources[0]["smry"],
405 | trend_text=sources[0]["trend_text"],
406 | corr_text=sources[0]["corr_text"],
407 | info_text=sources[0]["info_text"],
408 | answer=sources[0]["answer"],
409 | label=sources[0]["label"],
410 | mts_token_ids=mts_token_ids_list[0],
411 | mts_attention_mask=mts_attention_mask_list[0],
412 | mts_tokenizer_state=mts_tokenizer_state_list[0]
413 | )
414 | return data_dict
415 |
416 | data_dict = preprocess_cls(sources, self.tokenizer, self.added_str, self.preprocess_type)
417 |
418 | data_dict = dict(input_ids=data_dict["input_ids"][0],
419 | input_texts=data_dict["input_texts"][0],
420 | labels=sources[0]["label"],
421 | answer=sources[0]["answer"],
422 | mts_token_ids=mts_token_ids_list[0],
423 | mts_attention_mask=mts_attention_mask_list[0],
424 | mts_tokenizer_state=mts_tokenizer_state_list[0])
425 |
426 | return data_dict
427 |
428 |
429 | def make_ts_classification_data_module_stage2(
430 | tokenizer: transformers.PreTrainedTokenizer, chronos_tokenizer, label2id, data_args
431 | ) -> Dict:
432 | """Make dataset and collator for supervised fine-tuning."""
433 | data_collator = DataCollatorForTsCLSDatasetStage2(tokenizer=tokenizer)
434 | train_dataset = MultiChannelTimeSeriesCLSDatasetStage2(
435 | data_path=data_args.data_path, qa_path=data_args.qa_path, tokenizer=tokenizer, chronos_tokenizer=chronos_tokenizer, split="train", label2id=label2id, data_args=data_args
436 | )
437 | class_weights = train_dataset.get_class_weights()
438 | assert class_weights is not None, "class_weights should not be None"
439 |
440 | eval_dataset = MultiChannelTimeSeriesCLSDatasetStage2(
441 | data_path=data_args.eval_data_path, qa_path=data_args.eval_qa_path, tokenizer=tokenizer, chronos_tokenizer=chronos_tokenizer, split="eval", label2id=label2id, data_args=data_args
442 | )
443 |
444 | for i in range(2):
445 | print(f"data example {i}:\nInput Text: {train_dataset[i]['input_texts']}\nLabel: {train_dataset[i]['labels']}\nAnswer: {train_dataset[i]['answer']}\nInput ids: {train_dataset[i]['input_ids']}\n")
446 |
447 | return dict(
448 | train_dataset=train_dataset, eval_dataset=eval_dataset, data_collator=data_collator, class_weights=class_weights
449 | )
--------------------------------------------------------------------------------
/sensorllm/data/utils.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Sequence
2 | import transformers
3 | import copy
4 |
5 | IGNORE_INDEX = -100
6 |
7 |
8 | def generate_chat_template(messages, bos_token, eos_token, add_generation_prompt=False):
9 | LLAMA_3_CHAT_TEMPLATE = (
10 | "{% set loop_messages = messages %}"
11 | "{% for message in loop_messages %}"
12 | "{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}"
13 | "{% if loop.index0 == 0 %}"
14 | "{% set content = bos_token + content %}"
15 | "{% endif %}"
16 | "{{ content }}"
17 | "{% endfor %}"
18 | "{% if add_generation_prompt %}"
19 | "{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}"
20 | "{% endif %}"
21 | )
22 |
23 | from jinja2 import Template
24 | template = Template(LLAMA_3_CHAT_TEMPLATE)
25 | return template.render(messages=messages, bos_token=bos_token, eos_token=eos_token, add_generation_prompt=add_generation_prompt)
26 |
27 |
28 | def generate_chat_template2(messages, bos_token, eos_token, add_generation_prompt=False):
29 | LLAMA_3_CHAT_TEMPLATE = (
30 | "{% set loop_messages = messages %}"
31 | "{% for message in loop_messages %}"
32 | "{% if loop.last %}"
33 | "{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim %}"
34 | "{% else %}"
35 | "{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}"
36 | "{% endif %}"
37 | "{% if loop.index0 == 0 %}"
38 | "{% set content = bos_token + content %}"
39 | "{% endif %}"
40 | "{{ content }}"
41 | "{% endfor %}"
42 | "{% if add_generation_prompt %}"
43 | "{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}"
44 | "{% endif %}"
45 | )
46 |
47 | from jinja2 import Template
48 | template = Template(LLAMA_3_CHAT_TEMPLATE)
49 | return template.render(messages=messages, bos_token=bos_token, eos_token=eos_token, add_generation_prompt=add_generation_prompt)
50 |
51 |
52 | def _tokenize_fn(
53 | conversations: Sequence[str], tokenizer: transformers.PreTrainedTokenizer
54 | ) -> Dict:
55 | """Tokenize a list of strings."""
56 | tokenized_list = [
57 | tokenizer(
58 | conv,
59 | return_tensors="pt",
60 | padding="longest",
61 | max_length=tokenizer.model_max_length,
62 | truncation=True,
63 | )
64 | for conv in conversations
65 | ]
66 | input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
67 | input_ids_lens = labels_lens = [
68 | tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item()
69 | for tokenized in tokenized_list
70 | ]
71 | return dict(
72 | input_ids=input_ids,
73 | labels=labels,
74 | input_ids_lens=input_ids_lens,
75 | labels_lens=labels_lens,
76 | )
77 |
78 |
79 | def get_token_dict(dataset: str, data_args: dict):
80 | if dataset in ["usc-had", "uci"]:
81 | start_tokens_dict = {
82 | "x_acc": data_args["default_x_acc_start_token"],
83 | "y_acc": data_args["default_y_acc_start_token"],
84 | "z_acc": data_args["default_z_acc_start_token"],
85 | "x_g": data_args["default_x_gyro_start_token"],
86 | "y_g": data_args["default_y_gyro_start_token"],
87 | "z_g": data_args["default_z_gyro_start_token"]
88 | }
89 |
90 | end_tokens_dict = {
91 | "x_acc": data_args["default_x_acc_end_token"],
92 | "y_acc": data_args["default_y_acc_end_token"],
93 | "z_acc": data_args["default_z_acc_end_token"],
94 | "x_g": data_args["default_x_gyro_end_token"],
95 | "y_g": data_args["default_y_gyro_end_token"],
96 | "z_g": data_args["default_z_gyro_end_token"]
97 | }
98 | elif dataset == "capture24":
99 | start_tokens_dict = {
100 | "x_acc": data_args["default_x_acc_start_token"],
101 | "y_acc": data_args["default_y_acc_start_token"],
102 | "z_acc": data_args["default_z_acc_start_token"],
103 | }
104 |
105 | end_tokens_dict = {
106 | "x_acc": data_args["default_x_acc_end_token"],
107 | "y_acc": data_args["default_y_acc_end_token"],
108 | "z_acc": data_args["default_z_acc_end_token"],
109 | }
110 | elif dataset == "mhealth":
111 | start_tokens_dict = {
112 | "c_acc_x": data_args["default_chest_x_acc_start_token"],
113 | "c_acc_y": data_args["default_chest_y_acc_start_token"],
114 | "c_acc_z": data_args["default_chest_z_acc_start_token"],
115 | "la_acc_x": data_args["default_left_ankle_x_acc_start_token"],
116 | "la_acc_y": data_args["default_left_ankle_y_acc_start_token"],
117 | "la_acc_z": data_args["default_left_ankle_z_acc_start_token"],
118 | "la_gs_x": data_args["default_left_ankle_x_gyro_start_token"],
119 | "la_gs_y": data_args["default_left_ankle_y_gyro_start_token"],
120 | "la_gs_z": data_args["default_left_ankle_z_gyro_start_token"],
121 | "rla_acc_x": data_args["default_right_lower_arm_x_acc_start_token"],
122 | "rla_acc_y": data_args["default_right_lower_arm_y_acc_start_token"],
123 | "rla_acc_z": data_args["default_right_lower_arm_z_acc_start_token"],
124 | "rla_gs_x": data_args["default_right_lower_arm_x_gyro_start_token"],
125 | "rla_gs_y": data_args["default_right_lower_arm_y_gyro_start_token"],
126 | "rla_gs_z": data_args["default_right_lower_arm_z_gyro_start_token"]
127 | }
128 | end_tokens_dict = {
129 | "c_acc_x": data_args["default_chest_x_acc_end_token"],
130 | "c_acc_y": data_args["default_chest_y_acc_end_token"],
131 | "c_acc_z": data_args["default_chest_z_acc_end_token"],
132 | "la_acc_x": data_args["default_left_ankle_x_acc_end_token"],
133 | "la_acc_y": data_args["default_left_ankle_y_acc_end_token"],
134 | "la_acc_z": data_args["default_left_ankle_z_acc_end_token"],
135 | "la_gs_x": data_args["default_left_ankle_x_gyro_end_token"],
136 | "la_gs_y": data_args["default_left_ankle_y_gyro_end_token"],
137 | "la_gs_z": data_args["default_left_ankle_z_gyro_end_token"],
138 | "rla_acc_x": data_args["default_right_lower_arm_x_acc_end_token"],
139 | "rla_acc_y": data_args["default_right_lower_arm_y_acc_end_token"],
140 | "rla_acc_z": data_args["default_right_lower_arm_z_acc_end_token"],
141 | "rla_gs_x": data_args["default_right_lower_arm_x_gyro_end_token"],
142 | "rla_gs_y": data_args["default_right_lower_arm_y_gyro_end_token"],
143 | "rla_gs_z": data_args["default_right_lower_arm_z_gyro_end_token"]
144 | }
145 | elif dataset == "pamap" or dataset == "pamap50":
146 | start_tokens_dict = {
147 | "acc_hand_x": data_args["default_hand_x_acc_start_token"],
148 | "acc_hand_y": data_args["default_hand_y_acc_start_token"],
149 | "acc_hand_z": data_args["default_hand_z_acc_start_token"],
150 | "gyr_hand_x": data_args["default_hand_x_gyro_start_token"],
151 | "gyr_hand_y": data_args["default_hand_y_gyro_start_token"],
152 | "gyr_hand_z": data_args["default_hand_z_gyro_start_token"],
153 | "mag_hand_x": data_args["default_hand_x_mag_start_token"],
154 | "mag_hand_y": data_args["default_hand_y_mag_start_token"],
155 | "mag_hand_z": data_args["default_hand_z_mag_start_token"],
156 | "acc_chest_x": data_args["default_chest_x_acc_start_token"],
157 | "acc_chest_y": data_args["default_chest_y_acc_start_token"],
158 | "acc_chest_z": data_args["default_chest_z_acc_start_token"],
159 | "gyr_chest_x": data_args["default_chest_x_gyro_start_token"],
160 | "gyr_chest_y": data_args["default_chest_y_gyro_start_token"],
161 | "gyr_chest_z": data_args["default_chest_z_gyro_start_token"],
162 | "mag_chest_x": data_args["default_chest_x_mag_start_token"],
163 | "mag_chest_y": data_args["default_chest_y_mag_start_token"],
164 | "mag_chest_z": data_args["default_chest_z_mag_start_token"],
165 | "acc_ankle_x": data_args["default_ankle_x_acc_start_token"],
166 | "acc_ankle_y": data_args["default_ankle_y_acc_start_token"],
167 | "acc_ankle_z": data_args["default_ankle_z_acc_start_token"],
168 | "gyr_ankle_x": data_args["default_ankle_x_gyro_start_token"],
169 | "gyr_ankle_y": data_args["default_ankle_y_gyro_start_token"],
170 | "gyr_ankle_z": data_args["default_ankle_z_gyro_start_token"],
171 | "mag_ankle_x": data_args["default_ankle_x_mag_start_token"],
172 | "mag_ankle_y": data_args["default_ankle_y_mag_start_token"],
173 | "mag_ankle_z": data_args["default_ankle_z_mag_start_token"],
174 | }
175 | end_tokens_dict = {
176 | "acc_hand_x": data_args["default_hand_x_acc_end_token"],
177 | "acc_hand_y": data_args["default_hand_y_acc_end_token"],
178 | "acc_hand_z": data_args["default_hand_z_acc_end_token"],
179 | "gyr_hand_x": data_args["default_hand_x_gyro_end_token"],
180 | "gyr_hand_y": data_args["default_hand_y_gyro_end_token"],
181 | "gyr_hand_z": data_args["default_hand_z_gyro_end_token"],
182 | "mag_hand_x": data_args["default_hand_x_mag_end_token"],
183 | "mag_hand_y": data_args["default_hand_y_mag_end_token"],
184 | "mag_hand_z": data_args["default_hand_z_mag_end_token"],
185 | "acc_chest_x": data_args["default_chest_x_acc_end_token"],
186 | "acc_chest_y": data_args["default_chest_y_acc_end_token"],
187 | "acc_chest_z": data_args["default_chest_z_acc_end_token"],
188 | "gyr_chest_x": data_args["default_chest_x_gyro_end_token"],
189 | "gyr_chest_y": data_args["default_chest_y_gyro_end_token"],
190 | "gyr_chest_z": data_args["default_chest_z_gyro_end_token"],
191 | "mag_chest_x": data_args["default_chest_x_mag_end_token"],
192 | "mag_chest_y": data_args["default_chest_y_mag_end_token"],
193 | "mag_chest_z": data_args["default_chest_z_mag_end_token"],
194 | "acc_ankle_x": data_args["default_ankle_x_acc_end_token"],
195 | "acc_ankle_y": data_args["default_ankle_y_acc_end_token"],
196 | "acc_ankle_z": data_args["default_ankle_z_acc_end_token"],
197 | "gyr_ankle_x": data_args["default_ankle_x_gyro_end_token"],
198 | "gyr_ankle_y": data_args["default_ankle_y_gyro_end_token"],
199 | "gyr_ankle_z": data_args["default_ankle_z_gyro_end_token"],
200 | "mag_ankle_x": data_args["default_ankle_x_mag_end_token"],
201 | "mag_ankle_y": data_args["default_ankle_y_mag_end_token"],
202 | "mag_ankle_z": data_args["default_ankle_z_mag_end_token"],
203 | }
204 | else:
205 | raise ValueError(f"Wrong dataset name in preprocess_time_series2: {dataset}")
206 | return start_tokens_dict, end_tokens_dict
207 |
208 |
209 | def get_token_list(dataset: str, data_args: dict, add_ts_special_token_text: bool):
210 | if dataset in ["usc-had", "uci"]:
211 | if add_ts_special_token_text:
212 | start_tokens_list = [
213 | "x-axis accelerometer readings: " + data_args["default_x_acc_start_token"],
214 | "y-axis accelerometer readings: " + data_args["default_y_acc_start_token"],
215 | "z-axis accelerometer readings: " + data_args["default_z_acc_start_token"],
216 | "x-axis gyroscope readings: " + data_args["default_x_gyro_start_token"],
217 | "y-axis gyroscope readings: " + data_args["default_y_gyro_start_token"],
218 | "z-axis gyroscope readings: " + data_args["default_z_gyro_start_token"]
219 | ]
220 | else:
221 | start_tokens_list = [
222 | data_args["default_x_acc_start_token"],
223 | data_args["default_y_acc_start_token"],
224 | data_args["default_z_acc_start_token"],
225 | data_args["default_x_gyro_start_token"],
226 | data_args["default_y_gyro_start_token"],
227 | data_args["default_z_gyro_start_token"]
228 | ]
229 |
230 | end_tokens_list = [
231 | data_args["default_x_acc_end_token"],
232 | data_args["default_y_acc_end_token"],
233 | data_args["default_z_acc_end_token"],
234 | data_args["default_x_gyro_end_token"],
235 | data_args["default_y_gyro_end_token"],
236 | data_args["default_z_gyro_end_token"]
237 | ]
238 | elif dataset == "capture24":
239 | if add_ts_special_token_text:
240 | start_tokens_list = [
241 | "x-axis accelerometer readings: " + data_args["default_x_acc_start_token"],
242 | "y-axis accelerometer readings: " + data_args["default_y_acc_start_token"],
243 | "z-axis accelerometer readings: " + data_args["default_z_acc_start_token"]
244 | ]
245 | else:
246 | start_tokens_list = [
247 | data_args["default_x_acc_start_token"],
248 | data_args["default_y_acc_start_token"],
249 | data_args["default_z_acc_start_token"],
250 | ]
251 |
252 | end_tokens_list = [
253 | data_args["default_x_acc_end_token"],
254 | data_args["default_y_acc_end_token"],
255 | data_args["default_z_acc_end_token"],
256 | ]
257 | elif dataset == "mhealth":
258 | if add_ts_special_token_text:
259 | start_tokens_list = [
260 | "Chest x-axis accelerometer: " + data_args["default_chest_x_acc_start_token"],
261 | "Chest y-axis accelerometer: " + data_args["default_chest_y_acc_start_token"],
262 | "Chest z-axis accelerometer: " + data_args["default_chest_z_acc_start_token"],
263 | "left-ankle x-axis accelerometer: " + data_args["default_left_ankle_x_acc_start_token"],
264 | "left-ankle y-axis accelerometer: " + data_args["default_left_ankle_y_acc_start_token"],
265 | "left-ankle z-axis accelerometer: " + data_args["default_left_ankle_z_acc_start_token"],
266 | "left-ankle x-axis gyroscope: " + data_args["default_left_ankle_x_gyro_start_token"],
267 | "left-ankle y-axis gyroscope: " + data_args["default_left_ankle_y_gyro_start_token"],
268 | "left-ankle z-axis gyroscope: " + data_args["default_left_ankle_z_gyro_start_token"],
269 | "right-lower-arm x-axis accelerometer: " + data_args["default_right_lower_arm_x_acc_start_token"],
270 | "right-lower-arm y-axis accelerometer: " + data_args["default_right_lower_arm_y_acc_start_token"],
271 | "right-lower-arm z-axis accelerometer: " + data_args["default_right_lower_arm_z_acc_start_token"],
272 | "right-lower-arm x-axis gyroscope: " + data_args["default_right_lower_arm_x_gyro_start_token"],
273 | "right-lower-arm y-axis gyroscope: " + data_args["default_right_lower_arm_y_gyro_start_token"],
274 | "right-lower-arm z-axis gyroscope: " + data_args["default_right_lower_arm_z_gyro_start_token"]
275 | ]
276 | else:
277 | start_tokens_list = [
278 | data_args["default_chest_x_acc_start_token"],
279 | data_args["default_chest_y_acc_start_token"],
280 | data_args["default_chest_z_acc_start_token"],
281 | data_args["default_left_ankle_x_acc_start_token"],
282 | data_args["default_left_ankle_y_acc_start_token"],
283 | data_args["default_left_ankle_z_acc_start_token"],
284 | data_args["default_left_ankle_x_gyro_start_token"],
285 | data_args["default_left_ankle_y_gyro_start_token"],
286 | data_args["default_left_ankle_z_gyro_start_token"],
287 | data_args["default_right_lower_arm_x_acc_start_token"],
288 | data_args["default_right_lower_arm_y_acc_start_token"],
289 | data_args["default_right_lower_arm_z_acc_start_token"],
290 | data_args["default_right_lower_arm_x_gyro_start_token"],
291 | data_args["default_right_lower_arm_y_gyro_start_token"],
292 | data_args["default_right_lower_arm_z_gyro_start_token"]
293 | ]
294 | end_tokens_list = [
295 | data_args["default_chest_x_acc_end_token"],
296 | data_args["default_chest_y_acc_end_token"],
297 | data_args["default_chest_z_acc_end_token"],
298 | data_args["default_left_ankle_x_acc_end_token"],
299 | data_args["default_left_ankle_y_acc_end_token"],
300 | data_args["default_left_ankle_z_acc_end_token"],
301 | data_args["default_left_ankle_x_gyro_end_token"],
302 | data_args["default_left_ankle_y_gyro_end_token"],
303 | data_args["default_left_ankle_z_gyro_end_token"],
304 | data_args["default_right_lower_arm_x_acc_end_token"],
305 | data_args["default_right_lower_arm_y_acc_end_token"],
306 | data_args["default_right_lower_arm_z_acc_end_token"],
307 | data_args["default_right_lower_arm_x_gyro_end_token"],
308 | data_args["default_right_lower_arm_y_gyro_end_token"],
309 | data_args["default_right_lower_arm_z_gyro_end_token"]
310 | ]
311 | elif dataset == "pamap" or dataset == "pamap50":
312 | if add_ts_special_token_text:
313 | start_tokens_list = [
314 | "Hand x-axis accelerometer: " + data_args["default_hand_x_acc_start_token"],
315 |                 "Hand y-axis accelerometer: " + data_args["default_hand_y_acc_start_token"],
316 |                 "Hand z-axis accelerometer: " + data_args["default_hand_z_acc_start_token"],
317 |                 "Hand x-axis gyroscope: " + data_args["default_hand_x_gyro_start_token"],
318 |                 "Hand y-axis gyroscope: " + data_args["default_hand_y_gyro_start_token"],
319 |                 "Hand z-axis gyroscope: " + data_args["default_hand_z_gyro_start_token"],
320 |                 "Hand x-axis magnetometer: " + data_args["default_hand_x_mag_start_token"],
321 |                 "Hand y-axis magnetometer: " + data_args["default_hand_y_mag_start_token"],
322 |                 "Hand z-axis magnetometer: " + data_args["default_hand_z_mag_start_token"],
323 |                 "Chest x-axis accelerometer: " + data_args["default_chest_x_acc_start_token"],
324 |                 "Chest y-axis accelerometer: " + data_args["default_chest_y_acc_start_token"],
325 |                 "Chest z-axis accelerometer: " + data_args["default_chest_z_acc_start_token"],
326 |                 "Chest x-axis gyroscope: " + data_args["default_chest_x_gyro_start_token"],
327 |                 "Chest y-axis gyroscope: " + data_args["default_chest_y_gyro_start_token"],
328 |                 "Chest z-axis gyroscope: " + data_args["default_chest_z_gyro_start_token"],
329 |                 "Chest x-axis magnetometer: " + data_args["default_chest_x_mag_start_token"],
330 |                 "Chest y-axis magnetometer: " + data_args["default_chest_y_mag_start_token"],
331 |                 "Chest z-axis magnetometer: " + data_args["default_chest_z_mag_start_token"],
332 |                 "Ankle x-axis accelerometer: " + data_args["default_ankle_x_acc_start_token"],
333 |                 "Ankle y-axis accelerometer: " + data_args["default_ankle_y_acc_start_token"],
334 |                 "Ankle z-axis accelerometer: " + data_args["default_ankle_z_acc_start_token"],
335 |                 "Ankle x-axis gyroscope: " + data_args["default_ankle_x_gyro_start_token"],
336 |                 "Ankle y-axis gyroscope: " + data_args["default_ankle_y_gyro_start_token"],
337 |                 "Ankle z-axis gyroscope: " + data_args["default_ankle_z_gyro_start_token"],
338 |                 "Ankle x-axis magnetometer: " + data_args["default_ankle_x_mag_start_token"],
339 |                 "Ankle y-axis magnetometer: " + data_args["default_ankle_y_mag_start_token"],
340 |                 "Ankle z-axis magnetometer: " + data_args["default_ankle_z_mag_start_token"]
341 | ]
342 | else:
343 | start_tokens_list = [
344 | data_args["default_hand_x_acc_start_token"],
345 | data_args["default_hand_y_acc_start_token"],
346 | data_args["default_hand_z_acc_start_token"],
347 | data_args["default_hand_x_gyro_start_token"],
348 | data_args["default_hand_y_gyro_start_token"],
349 | data_args["default_hand_z_gyro_start_token"],
350 | data_args["default_hand_x_mag_start_token"],
351 | data_args["default_hand_y_mag_start_token"],
352 | data_args["default_hand_z_mag_start_token"],
353 | data_args["default_chest_x_acc_start_token"],
354 | data_args["default_chest_y_acc_start_token"],
355 | data_args["default_chest_z_acc_start_token"],
356 | data_args["default_chest_x_gyro_start_token"],
357 | data_args["default_chest_y_gyro_start_token"],
358 | data_args["default_chest_z_gyro_start_token"],
359 | data_args["default_chest_x_mag_start_token"],
360 | data_args["default_chest_y_mag_start_token"],
361 | data_args["default_chest_z_mag_start_token"],
362 | data_args["default_ankle_x_acc_start_token"],
363 | data_args["default_ankle_y_acc_start_token"],
364 | data_args["default_ankle_z_acc_start_token"],
365 | data_args["default_ankle_x_gyro_start_token"],
366 | data_args["default_ankle_y_gyro_start_token"],
367 | data_args["default_ankle_z_gyro_start_token"],
368 | data_args["default_ankle_x_mag_start_token"],
369 | data_args["default_ankle_y_mag_start_token"],
370 | data_args["default_ankle_z_mag_start_token"]
371 | ]
372 | end_tokens_list = [
373 | data_args["default_hand_x_acc_end_token"],
374 | data_args["default_hand_y_acc_end_token"],
375 | data_args["default_hand_z_acc_end_token"],
376 | data_args["default_hand_x_gyro_end_token"],
377 | data_args["default_hand_y_gyro_end_token"],
378 | data_args["default_hand_z_gyro_end_token"],
379 | data_args["default_hand_x_mag_end_token"],
380 | data_args["default_hand_y_mag_end_token"],
381 | data_args["default_hand_z_mag_end_token"],
382 | data_args["default_chest_x_acc_end_token"],
383 | data_args["default_chest_y_acc_end_token"],
384 | data_args["default_chest_z_acc_end_token"],
385 | data_args["default_chest_x_gyro_end_token"],
386 | data_args["default_chest_y_gyro_end_token"],
387 | data_args["default_chest_z_gyro_end_token"],
388 | data_args["default_chest_x_mag_end_token"],
389 | data_args["default_chest_y_mag_end_token"],
390 | data_args["default_chest_z_mag_end_token"],
391 | data_args["default_ankle_x_acc_end_token"],
392 | data_args["default_ankle_y_acc_end_token"],
393 | data_args["default_ankle_z_acc_end_token"],
394 | data_args["default_ankle_x_gyro_end_token"],
395 | data_args["default_ankle_y_gyro_end_token"],
396 | data_args["default_ankle_z_gyro_end_token"],
397 | data_args["default_ankle_x_mag_end_token"],
398 | data_args["default_ankle_y_mag_end_token"],
399 | data_args["default_ankle_z_mag_end_token"]
400 | ]
401 | else:
402 |         raise ValueError(f"Wrong dataset name in get_token_list: {dataset}")
403 | return start_tokens_list, end_tokens_list
404 |
405 |
406 | def preprocess(
407 | sources: Sequence[Dict[str, str]],
408 | tokenizer: transformers.PreTrainedTokenizer,
409 | SYS_INST: str,
410 | split: str,
411 | preprocess_type: str,
412 | preprocess_type_eval: str,
413 | ) -> Dict:
414 | """Preprocess the data by tokenizing."""
415 | examples = [generate_chat_template([
416 | {"role": "system", "content": SYS_INST},
417 | {"role": "user", "content": s["Q"]},
418 | {"role": "assistant", "content": s["A"]}], bos_token=tokenizer.bos_token, eos_token=tokenizer.eos_token, add_generation_prompt=False) for s in
419 | sources]
420 |
421 | if split == 'train':
422 | if preprocess_type == "Q":
423 | sources_q = [generate_chat_template([
424 | {"role": "system", "content": SYS_INST},
425 | {"role": "user", "content": s["Q"]}], bos_token=tokenizer.bos_token, eos_token=tokenizer.eos_token, add_generation_prompt=True) for s in
426 | sources]
427 | else:
428 | assert preprocess_type == "Q+cot"
429 | sources_q = [generate_chat_template2([
430 | {"role": "system", "content": SYS_INST},
431 | {"role": "user", "content": s["Q"]},
432 | {"role": "assistant", "content": s["cot"]}], bos_token=tokenizer.bos_token,
433 | eos_token=tokenizer.eos_token,
434 | add_generation_prompt=False) for s in
435 | sources]
436 | else:
437 | assert split == 'eval'
438 | if preprocess_type_eval == "Q":
439 | sources_q = [generate_chat_template([
440 | {"role": "system", "content": SYS_INST},
441 | {"role": "user", "content": s["Q"]}], bos_token=tokenizer.bos_token, eos_token=tokenizer.eos_token, add_generation_prompt=True) for s in
442 | sources]
443 | else:
444 | assert preprocess_type_eval == "Q+cot"
445 | sources_q = [generate_chat_template2([
446 | {"role": "system", "content": SYS_INST},
447 | {"role": "user", "content": s["Q"]},
448 | {"role": "assistant", "content": s["cot"]}], bos_token=tokenizer.bos_token,
449 | eos_token=tokenizer.eos_token,
450 | add_generation_prompt=False) for s in
451 | sources]
452 |
453 | examples_tokenized = _tokenize_fn(examples, tokenizer)
454 | sources_tokenized = _tokenize_fn(sources_q, tokenizer)
455 |
456 | input_ids = examples_tokenized["input_ids"]
457 | labels = copy.deepcopy(input_ids)
458 | for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]):
459 | label[:source_len] = IGNORE_INDEX
460 | return dict(input_ids=input_ids, labels=labels)
461 |
462 |
463 | def preprocess_cls(
464 | sources: Sequence[Dict[str, str]],
465 | tokenizer: transformers.PreTrainedTokenizer,
466 | added_str: str,
467 | preprocess_type: str
468 | ) -> Dict:
469 | """Preprocess the data by tokenizing."""
470 |
471 | if preprocess_type == "smry":
472 | inputs = [added_str + '\n' + s["smry"] for s in sources]
473 | elif preprocess_type == "trend":
474 | inputs = [added_str + '\n' + s["trend_text"] for s in sources]
475 | elif preprocess_type == "corr":
476 | inputs = [added_str + '\n' + s["corr_text"] for s in sources]
477 | elif preprocess_type == "none":
478 | inputs = [added_str for _ in sources]
479 | elif preprocess_type == "smry+Q":
480 | inputs = [added_str + '\n' + s["smry"] + '\n' + s["Q"] for s in sources]
481 | elif preprocess_type == "smry+meta":
482 | inputs = [added_str + '\n' + s["smry"] + '\n' + s["info_text"] for s in sources]
483 | elif preprocess_type == "smry+meta+Q":
484 | inputs = [added_str + '\n' + s["smry"] + '\n' + s["info_text"] + '\n' + s["Q"] for s in sources]
485 | elif preprocess_type == "smry+corr":
486 | inputs = [added_str + '\n' + s["smry"] + '\n' + s["corr_text"] for s in sources]
487 | elif preprocess_type == "smry+corr+Q":
488 | inputs = [added_str + '\n' + s["smry"] + '\n' + s["corr_text"] + '\n' + s["Q"] for s in sources]
489 | elif preprocess_type == "smry+trend+corr":
490 | inputs = [added_str + '\n' + s["smry"] + '\n' + s["trend_text"] + '\n' + s["corr_text"] for s in sources]
491 | elif preprocess_type == "smry+trend+corr+Q":
492 | inputs = [added_str + '\n' + s["smry"] + '\n' + s["trend_text"] + '\n' + s["corr_text"] + '\n' + s["Q"] for s in sources]
493 | elif preprocess_type == "smry+trend+Q":
494 | inputs = [added_str + '\n' + s["smry"] + '\n' + s["trend_text"] + '\n' + s["Q"] for s in sources]
495 | else:
496 | assert preprocess_type == "smry+trend", f"Undefined preprocess_type {preprocess_type}"
497 | inputs = [added_str + '\n' + s["smry"] + '\n' + s["trend_text"] for s in sources]
498 |
499 | inputs_tokenized = _tokenize_fn(inputs, tokenizer)
500 |
501 | input_ids = inputs_tokenized["input_ids"]
502 | return dict(input_ids=input_ids, input_texts=inputs)
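
> *Note:* `preprocess` above builds supervision targets by tokenizing the full conversation and the question-only prompt separately, then overwriting the prompt portion of the labels with `IGNORE_INDEX`, so only the assistant's answer contributes to the loss. The pattern is shown in isolation below as a minimal sketch with toy token IDs; the `-100` value for `IGNORE_INDEX` is an assumption (the usual sentinel skipped by Hugging Face's cross-entropy loss).

```python
import copy
import torch

IGNORE_INDEX = -100  # assumed sentinel value; CrossEntropyLoss ignores positions with this label

def mask_prompt_labels(full_ids: torch.Tensor, prompt_len: int) -> torch.Tensor:
    """Copy the token IDs and hide the prompt tokens from the LM loss."""
    labels = copy.deepcopy(full_ids)
    labels[:prompt_len] = IGNORE_INDEX
    return labels

# Toy example: 4 prompt tokens followed by 3 answer tokens.
full_ids = torch.tensor([11, 12, 13, 14, 201, 202, 203])
print(mask_prompt_labels(full_ids, prompt_len=4))
# tensor([-100, -100, -100, -100,  201,  202,  203])
```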
--------------------------------------------------------------------------------
/sensorllm/eval/eval.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import re
4 | from transformers import AutoTokenizer, AutoConfig
5 | from sensorllm.model.chronos_model import *
6 | import torch
7 | from torch.utils.data import DataLoader
8 | from tqdm import tqdm
9 | from sensorllm.model import *
10 | from sensorllm.data import UniChannelTimeSeriesDataset
11 | from sensorllm.data.utils import generate_chat_template
12 | from sensorllm.utils import disable_torch_init
13 | import warnings
14 | import argparse
15 |
16 |
17 | warnings.filterwarnings("ignore")
18 |
19 | SYS_INST = "A chat between a curious human and an AI assistant. The assistant is given a sequence of N features that represent information extracted from sensor (time-series) readings. The original readings consisted of N data points collected at a sample rate of 100Hz. The assistant's task is to analyze the trends and patterns in the sensor readings by leveraging the encoded information within the features to answer the following specific questions provided by the human."
20 |
21 |
22 | def parse_config():
23 | parser = argparse.ArgumentParser(description='arg parser')
24 | parser.add_argument('--model_name_or_path', type=str,
25 | default="")
26 | parser.add_argument('--pt_encoder_backbone_ckpt', type=str,
27 | default="")
28 | parser.add_argument('--tokenize_method', type=str, default="MeanScaleUniformBins")
29 | parser.add_argument("--torch_dtype", type=str, default="bfloat16", choices=["float32", "float16", "bfloat16"])
30 |
31 | parser.add_argument('--dataset', type=str, default="usc-had")
32 | parser.add_argument('--output_file_name', type=str, default="eval.json")
33 | parser.add_argument('--model_max_length', type=int, default=8192, help='context length during evaluation')
34 | parser.add_argument('--data_path', type=str, default="",
35 | help="Path to the testing data.")
36 | parser.add_argument('--qa_path', type=str, default="",
37 | help="Path to the testing QA data.")
38 | parser.add_argument('--ignore_qa_types', type=str, nargs='*', default=["sub_trend_no_val"])
39 |
40 | # * data loader, batch_size, shuffle, num_workers
41 | parser.add_argument("--batch_size", type=int, default=6)
42 | parser.add_argument("--shuffle", type=bool, default=False)
43 | parser.add_argument("--num_workers", type=int, default=2)
44 |
45 | args = parser.parse_args()
46 | return args
47 |
48 |
49 | def load_dataset(data_path, qa_path, chronos_tokenizer):
50 | print("Loading validation datasets.")
51 | dataset = UniChannelTimeSeriesDataset(
52 | data_path=data_path,
53 | qa_path=qa_path,
54 | tokenizer=None, # * load ts and QA
55 | chronos_tokenizer=chronos_tokenizer,
56 | data_args=args
57 | )
58 | print(f"Example data: {dataset[5]}")
59 | print("Done!")
60 | print(dataset)
61 | return dataset
62 |
63 |
64 | def custom_collate_fn(batch):
65 | batch_dict = {
66 | 'question': [],
67 | 'ground_truth': [],
68 | 'type': [],
69 | 'ts_token_ids': [],
70 | 'ts_attention_mask': []
71 | }
72 |
73 | for item in batch:
74 | for key in batch_dict:
75 | batch_dict[key].append(item[key])
76 |
77 | return batch_dict
78 |
79 |
80 | def get_dataloader(dataset, batch_size, num_workers=2):
81 | dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers,
82 | collate_fn=custom_collate_fn)
83 | return dataloader
84 |
85 |
86 | def init_model(args):
87 | # Model
88 | disable_torch_init()
89 | model_name = os.path.expanduser(args.model_name_or_path)
90 |
91 | # * print the model_name (get the basename)
92 | print(f'[INFO] Model name: {os.path.basename(model_name)}')
93 |
94 | tokenizer = AutoTokenizer.from_pretrained(
95 | model_name,
96 | padding_side="left"
97 | )
98 | model = SensorLLMStage1LlamaForCausalLM.from_pretrained(model_name, low_cpu_mem_usage=False, use_cache=False,
99 | torch_dtype=args.torch_dtype).cuda()
100 | model.get_model().load_pt_encoder_backbone_checkpoint(args.pt_encoder_backbone_ckpt,
101 | tc=args.tokenize_method,
102 | torch_dtype=args.torch_dtype)
103 | pt_backbone_config = AutoConfig.from_pretrained(args.pt_encoder_backbone_ckpt)
104 |
105 | assert hasattr(pt_backbone_config, "chronos_config"), "Not a Chronos config file"
106 |
107 | chronos_config = ChronosConfig(**pt_backbone_config.chronos_config)
108 | chronos_config.tokenizer_class = args.tokenize_method
109 | chronos_tokenizer = chronos_config.create_tokenizer()
110 |
111 | model.initialize_tokenizer_ts_backbone_config_wo_embedding(tokenizer, dataset=args.dataset)
112 | model.get_model().load_start_end_tokens(dataset=args.dataset)
113 |
114 | return model, tokenizer, chronos_tokenizer
115 |
116 |
117 | def generate_outputs(model, tokenizer, inputs, ts_token_ids, ts_attention_mask, do_sample=True, temperature=0.6,
118 | top_k=50, max_length=8192, top_p=0.9):
119 | model.eval()
120 | model.get_model().pt_encoder_backbone.eval()
121 | terminators = [
122 | tokenizer.eos_token_id,
123 | tokenizer.convert_tokens_to_ids("<|eot_id|>")
124 | ]
125 | with torch.inference_mode():
126 | outputs = model.generate(
127 | **inputs,
128 | ts_token_ids=ts_token_ids,
129 | ts_attention_mask=ts_attention_mask,
130 | do_sample=do_sample,
131 | use_cache=False,
132 | temperature=temperature,
133 | top_k=top_k,
134 | max_new_tokens=max_length,
135 | top_p=top_p,
136 | eos_token_id=terminators,
137 | pad_token_id=tokenizer.pad_token_id
138 | ) # * B, L'
139 | input_token_len = inputs.input_ids.shape[1]
140 | n_diff_input_output = (inputs.input_ids != outputs[:, :input_token_len]).sum().item()
141 |
142 | if n_diff_input_output > 0:
143 | print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
144 | outputs = tokenizer.batch_decode(outputs[:, input_token_len:], skip_special_tokens=True)
145 | outputs = [output.strip() for output in outputs]
146 |
147 | return outputs
148 |
149 |
150 | def start_generation(model, tokenizer, dataloader, output_dir, output_file_name):
151 | os.makedirs(output_dir, exist_ok=True)
152 | output_file = os.path.join(output_dir, output_file_name)
153 |
154 | results = {"prompt": SYS_INST, "results": []}
155 | if os.path.exists(output_file):
156 |         # If the output file already exists, load the existing results and resume
157 | with open(output_file, 'r') as f:
158 | results = json.load(f)
159 |
160 | processed_count = len(results["results"])
161 |
162 | o_i = 0
163 | for batch_idx, batch in enumerate(tqdm(dataloader)):
164 | if batch_idx * dataloader.batch_size < processed_count:
165 | continue
166 | print(f"start from id {batch_idx}...")
167 |
168 | ts_token_ids = [ts_tensor.cuda() for ts_tensor in batch["ts_token_ids"]] # * tensor of B, N.
169 | ts_attention_mask = [ts_tensor.cuda() for ts_tensor in batch["ts_attention_mask"]]
170 |
171 | ground_truths = batch["ground_truth"] # * list of string
172 | types = batch["type"]
173 | questions = batch["question"] # * list of string
174 |
175 | templated_questions = [generate_chat_template([
176 | {"role": "system", "content": SYS_INST},
177 | {"role": "user", "content": q}], bos_token=tokenizer.bos_token, eos_token=tokenizer.eos_token,
178 | add_generation_prompt=True) for q in
179 | questions]
180 |
181 | inputs = tokenizer(templated_questions, padding=True, return_tensors="pt").to(model.device)
182 | outputs = generate_outputs(model, tokenizer, inputs, ts_token_ids,
183 | ts_attention_mask) # List of str, length is B
184 |
185 | # saving results
186 | batch_results = []
187 | for q, gt, output, tp, ts in zip(questions, ground_truths, outputs, types, ts_token_ids):
188 | result = {
189 | "questions": q,
190 | "ground_truth": gt,
191 | "model_output": output,
192 | "model_len": len(ts[0]),
193 | "type": tp
194 | }
195 | batch_results.append(result)
196 | if o_i < 10:
197 | tqdm.write(f"Type: {tp}\nOutput: {output}\nGround-truth: {gt}\n\n")
198 | tqdm.write("---------" * 30)
199 | o_i += 1
200 | results["results"].extend(batch_results)
201 |
202 |         if batch_idx % 10 == 0:  # save intermediate results every 10 batches
203 | save_results(results, output_file)
204 |
205 | save_results(results, output_file)
206 | return results
207 |
208 | def save_results(results, output_file):
209 | temp_file = output_file + '.tmp'
210 | with open(temp_file, 'w') as f:
211 | json.dump(results, f, indent=2)
212 |
213 | if os.path.exists(output_file):
214 | os.replace(temp_file, output_file)
215 | else:
216 | os.rename(temp_file, output_file)
217 |
218 |
219 | def eval(args):
220 |     # * output
221 | args.output_dir = os.path.join(args.model_name_or_path, "evaluation")
222 | output_file_path = os.path.join(args.output_dir, args.output_file_name)
223 |
224 | if not os.path.exists(output_file_path):
225 | # * need inferencing
226 | model, tokenizer, chronos_tokenizer = init_model(args)
227 | ts_backbone_config = model.get_model().ts_backbone_config
228 | args.ts_backbone_config = ts_backbone_config
229 |
230 | dataset = load_dataset(args.data_path, args.qa_path, chronos_tokenizer)
231 | dataloader = get_dataloader(dataset, args.batch_size, args.num_workers)
232 |
233 | print(f'[INFO] Start generating results for {args.output_file_name}.')
234 | results = start_generation(model, tokenizer, dataloader, args.output_dir, args.output_file_name)
235 |
236 | del model
237 | del tokenizer
238 | torch.cuda.empty_cache()
239 | else:
240 | # * directly load the results
241 | print(f'[INFO] {output_file_path} already exists, directly loading...')
242 | with open(output_file_path, 'r') as fp:
243 | results = json.load(fp)
244 | print(results["results"][:10])
245 |
246 |
247 | if __name__ == "__main__":
248 | args = parse_config()
249 | dtype_mapping = {
250 | "float32": torch.float32,
251 | "float16": torch.float16,
252 | "bfloat16": torch.bfloat16,
253 | }
254 |
255 | args.torch_dtype = dtype_mapping[args.torch_dtype]
256 |
257 | eval(args)
258 |
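
> *Note:* a typical invocation of this evaluation script, with flags mirroring `parse_config()` above (paths are placeholders; results are written to `[STAGE1_MODEL_PATH]/evaluation/[OUTPUT_FILE_NAME]`):

```bash
python sensorllm/eval/eval.py \
    --model_name_or_path [STAGE1_MODEL_PATH] \
    --pt_encoder_backbone_ckpt [TS_EMBEDDER_PATH] \
    --tokenize_method 'StanNormalizeUniformBins' \
    --torch_dtype bfloat16 \
    --dataset [DATASET_NAME] \
    --data_path [TS_TEST_PATH] \
    --qa_path [QA_TEST_PATH] \
    --output_file_name eval.json \
    --batch_size 6 \
    --num_workers 2
```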
--------------------------------------------------------------------------------
/sensorllm/model/__init__.py:
--------------------------------------------------------------------------------
1 | from .stage1_sensorllm import SensorLLMStage1Config, SensorLLMStage1LlamaForCausalLM
2 | from .stage2_sensorllm import SensorLLMStage2Config, SensorLLMStage2LlamaForCausalLM, SensorLLMStage2LlamaForSequenceClassification
--------------------------------------------------------------------------------
/sensorllm/model/chronos_model/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 |
4 |
5 | from .chronos_model import (
6 | ChronosConfig,
7 | ChronosModel,
8 | ChronosPipeline,
9 | ChronosTokenizer,
10 | MeanScaleUniformBins,
11 | StanNormalizeUniformBins,
12 | left_pad_and_stack_1D
13 | )
14 |
15 | __all__ = [
16 | "ChronosConfig",
17 | "ChronosModel",
18 | "ChronosPipeline",
19 | "ChronosTokenizer",
20 | "MeanScaleUniformBins",
21 | "StanNormalizeUniformBins",
22 | "left_pad_and_stack_1D"
23 | ]
--------------------------------------------------------------------------------
/sensorllm/model/chronos_model/chronos_model.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | from dataclasses import dataclass
3 | from typing import Any, Dict, List, Literal, Optional, Tuple, Union
4 |
5 | from . import chronos_model
6 | import torch
7 | import torch.nn as nn
8 | from transformers import (
9 | AutoConfig,
10 | AutoModelForCausalLM,
11 | AutoModelForSeq2SeqLM,
12 | GenerationConfig,
13 | PreTrainedModel,
14 | )
15 |
16 |
17 | @dataclass
18 | class ChronosConfig:
19 | """
20 | This class holds all the configuration parameters to be used
21 | by ``ChronosTokenizer`` and ``ChronosModel``.
22 | """
23 |
24 | tokenizer_class: str
25 | tokenizer_kwargs: Dict[str, Any]
26 | context_length: int
27 | prediction_length: int
28 | n_tokens: int
29 | n_special_tokens: int
30 | pad_token_id: int
31 | eos_token_id: int
32 | use_eos_token: bool
33 | model_type: Literal["causal", "seq2seq"]
34 | num_samples: int
35 | temperature: float
36 | top_k: int
37 | top_p: float
38 |
39 | def __post_init__(self):
40 | assert (
41 | self.pad_token_id < self.n_special_tokens
42 | and self.eos_token_id < self.n_special_tokens
43 | ), f"Special token id's must be smaller than {self.n_special_tokens=}"
44 |
45 | def create_tokenizer(self) -> "ChronosTokenizer":
46 | class_ = getattr(chronos_model, self.tokenizer_class)
47 | return class_(**self.tokenizer_kwargs, config=self)
48 |
49 |
50 | class ChronosTokenizer:
51 | """
52 |     A ``ChronosTokenizer`` defines how time series are mapped into token IDs
53 | and back.
54 |
55 | For details, see the ``input_transform`` and ``output_transform`` methods,
56 | which concrete classes must implement.
57 | """
58 |
59 | def context_input_transform(
60 | self,
61 | context: torch.Tensor,
62 | ) -> Tuple:
63 | """
64 | Turn a batch of time series into token IDs, attention map, and tokenizer_state.
65 |
66 | Parameters
67 | ----------
68 | context
69 | A tensor shaped (batch_size, time_length), containing the
70 | timeseries to forecast. Use left-padding with ``torch.nan``
71 | to align time series of different lengths.
72 |
73 | Returns
74 | -------
75 | token_ids
76 | A tensor of integers, shaped (batch_size, time_length + 1)
77 | if ``config.use_eos_token`` and (batch_size, time_length)
78 | otherwise, containing token IDs for the input series.
79 | attention_mask
80 | A boolean tensor, same shape as ``token_ids``, indicating
81 | which input observations are not ``torch.nan`` (i.e. not
82 | missing nor padding).
83 | tokenizer_state
84 | An object that can be passed to ``label_input_transform``
85 | and ``output_transform``. Contains the relevant information
86 | to decode output samples into real values,
87 | such as location and scale parameters.
88 | """
89 | raise NotImplementedError()
90 |
91 | def label_input_transform(self, label: torch.Tensor, tokenizer_state: Any) -> Tuple:
92 | """
93 | Turn a batch of label slices of time series into token IDs and attention map
94 | using the ``tokenizer_state`` provided by ``context_input_transform``.
95 |
96 | Parameters
97 | ----------
98 |         label
99 |             A tensor shaped (batch_size, time_length), containing the
100 |             label slices of the time series. Use left-padding with ``torch.nan``
101 |             to align time series of different lengths.
102 | tokenizer_state
103 | An object returned by ``context_input_transform`` containing
104 | relevant information to preprocess data, such as location and
105 | scale. The nature of this depends on the specific tokenizer.
106 | This is used for tokenizing the label, in order to use the same
107 | scaling used to tokenize the context.
108 |
109 | Returns
110 | -------
111 | token_ids
112 | A tensor of integers, shaped (batch_size, time_length + 1)
113 | if ``config.use_eos_token`` and (batch_size, time_length)
114 | otherwise, containing token IDs for the input series.
115 | attention_mask
116 | A boolean tensor, same shape as ``token_ids``, indicating
117 | which input observations are not ``torch.nan`` (i.e. not
118 | missing nor padding).
119 | """
120 | raise NotImplementedError()
121 |
122 | def output_transform(
123 | self, samples: torch.Tensor, tokenizer_state: Any
124 | ) -> torch.Tensor:
125 | """
126 | Turn a batch of sample token IDs into real values.
127 |
128 | Parameters
129 | ----------
130 | samples
131 | A tensor of integers, shaped (batch_size, num_samples, time_length),
132 | containing token IDs of sample trajectories.
133 | tokenizer_state
134 | An object returned by ``input_transform`` containing
135 | relevant context to decode samples, such as location and scale.
136 | The nature of this depends on the specific tokenizer.
137 |
138 | Returns
139 | -------
140 | forecasts
141 | A real tensor, shaped (batch_size, num_samples, time_length),
142 | containing forecasted sample paths.
143 | """
144 | raise NotImplementedError()
145 |
146 |
147 | class MeanScaleUniformBins(ChronosTokenizer):
148 | def __init__(
149 | self, low_limit: float, high_limit: float, config: ChronosConfig
150 | ) -> None:
151 | self.config = config
152 | self.centers = torch.linspace(
153 | low_limit,
154 | high_limit,
155 | config.n_tokens - config.n_special_tokens - 1,
156 | )
157 | self.boundaries = torch.concat(
158 | (
159 | torch.tensor([-1e20], device=self.centers.device),
160 | (self.centers[1:] + self.centers[:-1]) / 2,
161 | torch.tensor([1e20], device=self.centers.device),
162 | )
163 | )
164 |
165 | def _input_transform(
166 | self, context: torch.Tensor, scale: Optional[torch.Tensor] = None
167 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
168 | attention_mask = ~torch.isnan(context)
169 |
170 | if scale is None:
171 | scale = torch.nansum(
172 | torch.abs(context) * attention_mask, dim=-1
173 | ) / torch.nansum(attention_mask, dim=-1)
174 | scale[~(scale > 0)] = 1.0
175 |
176 | scaled_context = context / scale.unsqueeze(dim=-1)
177 | token_ids = (
178 | torch.bucketize(
179 | input=scaled_context,
180 | boundaries=self.boundaries,
181 | # buckets are open to the right, see:
182 | # https://pytorch.org/docs/2.1/generated/torch.bucketize.html#torch-bucketize
183 | right=True,
184 | )
185 | + self.config.n_special_tokens
186 | )
187 | token_ids[~attention_mask] = self.config.pad_token_id
188 |
189 | return token_ids, attention_mask, scale
190 |
191 | def _append_eos_token(
192 | self, token_ids: torch.Tensor, attention_mask: torch.Tensor
193 | ) -> Tuple[torch.Tensor, torch.Tensor]:
194 | batch_size = token_ids.shape[0]
195 | eos_tokens = torch.full((batch_size, 1), fill_value=self.config.eos_token_id)
196 | token_ids = torch.concat((token_ids, eos_tokens), dim=1)
197 | eos_mask = torch.full((batch_size, 1), fill_value=True)
198 | attention_mask = torch.concat((attention_mask, eos_mask), dim=1)
199 |
200 | return token_ids, attention_mask
201 |
202 | def context_input_transform(
203 | self, context: torch.Tensor
204 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
205 | length = context.shape[-1]
206 |
207 | if length > self.config.context_length:
208 | context = context[..., -self.config.context_length :]
209 |
210 | token_ids, attention_mask, scale = self._input_transform(context=context)
211 |
212 | if self.config.use_eos_token and self.config.model_type == "seq2seq":
213 | token_ids, attention_mask = self._append_eos_token(
214 | token_ids=token_ids, attention_mask=attention_mask
215 | )
216 |
217 | return token_ids, attention_mask, scale
218 |
219 | def label_input_transform(
220 | self, label: torch.Tensor, scale: torch.Tensor
221 | ) -> Tuple[torch.Tensor, torch.Tensor]:
222 | length = label.shape[-1]
223 |
224 | assert length == self.config.prediction_length
225 | token_ids, attention_mask, _ = self._input_transform(context=label, scale=scale)
226 |
227 | if self.config.use_eos_token:
228 | token_ids, attention_mask = self._append_eos_token(
229 | token_ids=token_ids, attention_mask=attention_mask
230 | )
231 |
232 | return token_ids, attention_mask
233 |
234 | def output_transform(
235 | self, samples: torch.Tensor, scale: torch.Tensor
236 | ) -> torch.Tensor:
237 | scale_unsqueezed = scale.unsqueeze(-1).unsqueeze(-1)
238 | indices = torch.clamp(
239 | samples - self.config.n_special_tokens - 1,
240 | min=0,
241 | max=len(self.centers) - 1,
242 | )
243 | return self.centers[indices] * scale_unsqueezed
244 |
245 |
246 | class StanNormalizeUniformBins(ChronosTokenizer):
247 | def __init__(
248 | self, low_limit: float, high_limit: float, config: ChronosConfig
249 | ) -> None:
250 | self.config = config
251 | self.centers = torch.linspace(
252 | low_limit,
253 | high_limit,
254 | config.n_tokens - config.n_special_tokens - 1
255 | )
256 | # print("self.centers.device:", self.centers.device)
257 | self.boundaries = torch.concat(
258 | (
259 | torch.tensor([-1e20], device=self.centers.device),
260 | (self.centers[1:] + self.centers[:-1]) / 2,
261 | torch.tensor([1e20], device=self.centers.device),
262 | )
263 | )
264 |
265 | def _input_transform(
266 | self, context: torch.Tensor, scale: Tuple[torch.Tensor, torch.Tensor] = None, eps: float = 1e-8
267 | ) -> Tuple[torch.Tensor, torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
268 | attention_mask = ~torch.isnan(context)
269 |
270 | if scale is None:
271 | mean = torch.nansum(context * attention_mask, dim=-1) / torch.nansum(attention_mask, dim=-1)
272 | variance = torch.nansum((context - mean.unsqueeze(dim=-1)) ** 2 * attention_mask, dim=-1) / torch.nansum(
273 | attention_mask, dim=-1)
274 | std = torch.sqrt(variance + eps)
275 |             scale = (mean, std)  # eps is added to the variance above to avoid division by zero
276 |
277 | scaled_context = (context - scale[0].unsqueeze(dim=-1)) / scale[1].unsqueeze(dim=-1)
278 | # print("scaled_context.device:", scaled_context.device)
279 | token_ids = (
280 | torch.bucketize(
281 | input=scaled_context,
282 | boundaries=self.boundaries,
283 | # buckets are open to the right, see:
284 | # https://pytorch.org/docs/2.1/generated/torch.bucketize.html#torch-bucketize
285 | right=True,
286 | )
287 | + self.config.n_special_tokens
288 | )
289 | token_ids[~attention_mask] = self.config.pad_token_id
290 |
291 | return token_ids, attention_mask, scale
292 |
293 | def _append_eos_token(
294 | self, token_ids: torch.Tensor, attention_mask: torch.Tensor
295 | ) -> Tuple[torch.Tensor, torch.Tensor]:
296 | batch_size = token_ids.shape[0]
297 | eos_tokens = torch.full((batch_size, 1), fill_value=self.config.eos_token_id)
298 | token_ids = torch.concat((token_ids, eos_tokens), dim=1)
299 | eos_mask = torch.full((batch_size, 1), fill_value=True)
300 | attention_mask = torch.concat((attention_mask, eos_mask), dim=1)
301 |
302 | return token_ids, attention_mask
303 |
304 | def context_input_transform(
305 | self, context: torch.Tensor
306 | ) -> Tuple[torch.Tensor, torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
307 | length = context.shape[-1]
308 |
309 | if length > self.config.context_length:
310 | context = context[..., -self.config.context_length:]
311 |
312 | token_ids, attention_mask, scale = self._input_transform(context=context)
313 |
314 | if self.config.use_eos_token and self.config.model_type == "seq2seq":
315 | token_ids, attention_mask = self._append_eos_token(
316 | token_ids=token_ids, attention_mask=attention_mask
317 | )
318 |
319 | return token_ids, attention_mask, scale
320 |
321 | def label_input_transform(
322 | self, label: torch.Tensor, scale: Tuple[torch.Tensor, torch.Tensor]
323 | ) -> Tuple[torch.Tensor, torch.Tensor]:
324 | length = label.shape[-1]
325 |
326 | assert length == self.config.prediction_length
327 | token_ids, attention_mask, _ = self._input_transform(context=label, scale=scale)
328 |
329 | if self.config.use_eos_token:
330 | token_ids, attention_mask = self._append_eos_token(
331 | token_ids=token_ids, attention_mask=attention_mask
332 | )
333 |
334 | return token_ids, attention_mask
335 |
336 | def output_transform(
337 | self, samples: torch.Tensor, scale: Tuple[torch.Tensor, torch.Tensor]
338 | ) -> torch.Tensor:
339 | mean_unsqueezed = scale[0].unsqueeze(-1).unsqueeze(-1)
340 | std_unsqueezed = scale[1].unsqueeze(-1).unsqueeze(-1)
341 | indices = torch.clamp(
342 | samples - self.config.n_special_tokens - 1,
343 | min=0,
344 | max=len(self.centers) - 1,
345 | )
346 | return self.centers[indices] * std_unsqueezed + mean_unsqueezed
347 |
348 |
349 | class ChronosModel(nn.Module):
350 | """
351 | A ``ChronosModel`` wraps a ``PreTrainedModel`` object from ``transformers``
352 | and uses it to predict sample paths for time series tokens.
353 |
354 | Parameters
355 | ----------
356 | config
357 | The configuration to use.
358 | model
359 | The pretrained model to use.
360 | """
361 |
362 | def __init__(self, config: ChronosConfig, model: PreTrainedModel) -> None:
363 | super().__init__()
364 | self.config = config
365 | self.model = model
366 |
367 | @property
368 | def device(self):
369 | return self.model.device
370 |
371 | def encode(
372 | self,
373 | input_ids: torch.Tensor,
374 | attention_mask: torch.Tensor,
375 | ):
376 | """
377 | Extract the encoder embedding for the given token sequences.
378 |
379 | Parameters
380 | ----------
381 | input_ids
382 | Tensor of indices of input sequence tokens in the vocabulary
383 | with shape (batch_size, sequence_length).
384 | attention_mask
385 | A mask tensor of the same shape as input_ids to avoid attending
386 | on padding or missing tokens.
387 |
388 | Returns
389 | -------
390 | embedding
391 | A tensor of encoder embeddings with shape
392 | (batch_size, sequence_length, d_model).
393 | """
394 | assert (
395 | self.config.model_type == "seq2seq"
396 | ), "Encoder embeddings are only supported for encoder-decoder models"
397 | return self.model.encoder(
398 | input_ids=input_ids, attention_mask=attention_mask
399 | ).last_hidden_state
400 |
401 | def forward(
402 | self,
403 | input_ids: torch.Tensor,
404 | attention_mask: torch.Tensor,
405 | prediction_length: Optional[int] = None,
406 | num_samples: Optional[int] = None,
407 | temperature: Optional[float] = None,
408 | top_k: Optional[int] = None,
409 | top_p: Optional[float] = None,
410 | ) -> torch.Tensor:
411 | """
412 | Predict future sample tokens for the given token sequences.
413 |
414 | Arguments ``prediction_length``, ``num_samples``, ``temperature``,
415 | ``top_k``, ``top_p`` can be used to customize the model inference,
416 | and default to the corresponding attributes in ``self.config`` if
417 | not provided.
418 |
419 | Returns
420 | -------
421 | samples
422 | A tensor of integers, shaped (batch_size, num_samples, time_length),
423 | containing forecasted sample paths.
424 | """
425 | if prediction_length is None:
426 | prediction_length = self.config.prediction_length
427 | if num_samples is None:
428 | num_samples = self.config.num_samples
429 | if temperature is None:
430 | temperature = self.config.temperature
431 | if top_k is None:
432 | top_k = self.config.top_k
433 | if top_p is None:
434 | top_p = self.config.top_p
435 |
436 | preds = self.model.generate(
437 | input_ids=input_ids,
438 | attention_mask=attention_mask,
439 | generation_config=GenerationConfig(
440 | min_new_tokens=prediction_length,
441 | max_new_tokens=prediction_length,
442 | do_sample=True,
443 | num_return_sequences=num_samples,
444 | eos_token_id=self.config.eos_token_id,
445 | pad_token_id=self.config.pad_token_id,
446 | temperature=temperature,
447 | top_k=top_k,
448 | top_p=top_p,
449 | ),
450 | )
451 |
452 | if self.config.model_type == "seq2seq":
453 | preds = preds[..., 1:] # remove the decoder start token
454 | else:
455 | assert self.config.model_type == "causal"
456 | assert preds.size(-1) == input_ids.size(-1) + prediction_length
457 | preds = preds[..., -prediction_length:]
458 |
459 | return preds.reshape(input_ids.size(0), num_samples, -1)
460 |
461 |
462 | def left_pad_and_stack_1D(tensors: List[torch.Tensor]) -> torch.Tensor:
463 | max_len = max(len(c) for c in tensors)
464 | padded = []
465 | for c in tensors:
466 | assert isinstance(c, torch.Tensor)
467 | assert c.ndim == 1
468 | padding = torch.full(
469 | size=(max_len - len(c),), fill_value=torch.nan, device=c.device
470 | )
471 | padded.append(torch.concat((padding, c), dim=-1))
472 | return torch.stack(padded)
473 |
474 |
475 | @dataclass
476 | class ChronosPipeline:
477 | """
478 | A ``ChronosPipeline`` uses the given tokenizer and model to forecast
479 | input time series.
480 |
481 | Use the ``from_pretrained`` class method to load serialized models.
482 | Use the ``predict`` method to get forecasts.
483 |
484 | Parameters
485 | ----------
486 | tokenizer
487 | The tokenizer object to use.
488 | model
489 | The model to use.
490 | """
491 |
492 | tokenizer: ChronosTokenizer
493 | model: ChronosModel
494 |
495 | def _prepare_and_validate_context(
496 | self, context: Union[torch.Tensor, List[torch.Tensor]]
497 | ):
498 | if isinstance(context, list):
499 | context = left_pad_and_stack_1D(context)
500 | assert isinstance(context, torch.Tensor)
501 | if context.ndim == 1:
502 | context = context.unsqueeze(0)
503 | assert context.ndim == 2
504 |
505 | return context
506 |
507 | @torch.no_grad()
508 | def embed(
509 | self, context: Union[torch.Tensor, List[torch.Tensor]]
510 | ) -> Tuple[torch.Tensor, Any]:
511 | """
512 | Get encoder embeddings for the given time series.
513 |
514 | Parameters
515 | ----------
516 | context
517 | Input series. This is either a 1D tensor, or a list
518 | of 1D tensors, or a 2D tensor whose first dimension
519 | is batch. In the latter case, use left-padding with
520 | ``torch.nan`` to align series of different lengths.
521 |
522 | Returns
523 | -------
524 | embeddings, tokenizer_state
525 | A tuple of two tensors: the encoder embeddings and the tokenizer_state,
526 | e.g., the scale of the time series in the case of mean scaling.
527 | The encoder embeddings are shaped (batch_size, context_length, d_model)
528 | or (batch_size, context_length + 1, d_model), where context_length
529 | is the size of the context along the time axis if a 2D tensor was provided
530 | or the length of the longest time series, if a list of 1D tensors was
531 | provided, and the extra 1 is for EOS.
532 | """
533 | context_tensor = self._prepare_and_validate_context(context=context)
534 | token_ids, attention_mask, tokenizer_state = (
535 | self.tokenizer.context_input_transform(context_tensor)
536 | )
537 | embeddings = self.model.encode(
538 | input_ids=token_ids.to(self.model.device),
539 | attention_mask=attention_mask.to(self.model.device),
540 | ).cpu()
541 | return embeddings, tokenizer_state
542 |
543 | def predict(
544 | self,
545 | context: Union[torch.Tensor, List[torch.Tensor]],
546 | prediction_length: Optional[int] = None,
547 | num_samples: Optional[int] = None,
548 | temperature: Optional[float] = None,
549 | top_k: Optional[int] = None,
550 | top_p: Optional[float] = None,
551 | limit_prediction_length: bool = True,
552 | ) -> torch.Tensor:
553 | """
554 | Get forecasts for the given time series.
555 |
556 | Parameters
557 | ----------
558 | context
559 | Input series. This is either a 1D tensor, or a list
560 | of 1D tensors, or a 2D tensor whose first dimension
561 | is batch. In the latter case, use left-padding with
562 | ``torch.nan`` to align series of different lengths.
563 | prediction_length
564 |             Time steps to predict. Defaults to what is specified
565 |             in ``self.model.config``.
566 |         num_samples
567 |             Number of sample paths to predict. Defaults to what is
568 |             specified in ``self.model.config``.
569 |         temperature
570 |             Temperature to use for generating sample tokens.
571 |             Defaults to what is specified in ``self.model.config``.
572 |         top_k
573 |             Top-k parameter to use for generating sample tokens.
574 |             Defaults to what is specified in ``self.model.config``.
575 |         top_p
576 |             Top-p parameter to use for generating sample tokens.
577 |             Defaults to what is specified in ``self.model.config``.
578 |         limit_prediction_length
579 |             Force the prediction length to be smaller than or equal to
580 |             the built-in prediction length of the model. True by
581 |             default. When true, fail loudly if longer predictions
582 |             are requested; otherwise longer predictions are allowed.
583 |
584 | Returns
585 | -------
586 | samples
587 | Tensor of sample forecasts, of shape
588 | (batch_size, num_samples, prediction_length).
589 | """
590 | context_tensor = self._prepare_and_validate_context(context=context)
591 |
592 | if prediction_length is None:
593 | prediction_length = self.model.config.prediction_length
594 |
595 | if prediction_length > self.model.config.prediction_length:
596 | msg = (
597 | f"We recommend keeping prediction length <= {self.model.config.prediction_length}. "
598 | "The quality of longer predictions may degrade since the model is not optimized for it. "
599 | )
600 | if limit_prediction_length:
601 | msg += "You can turn off this check by setting `limit_prediction_length=False`."
602 | raise ValueError(msg)
603 | warnings.warn(msg)
604 |
605 | predictions = []
606 | remaining = prediction_length
607 |
608 | while remaining > 0:
609 | token_ids, attention_mask, scale = self.tokenizer.context_input_transform(
610 | context_tensor
611 | )
612 | samples = self.model(
613 | token_ids.to(self.model.device),
614 | attention_mask.to(self.model.device),
615 | min(remaining, self.model.config.prediction_length),
616 | num_samples,
617 | temperature,
618 | top_k,
619 | top_p,
620 | )
621 | prediction = self.tokenizer.output_transform(
622 | samples.to(scale.device), scale
623 | )
624 |
625 | predictions.append(prediction)
626 | remaining -= prediction.shape[-1]
627 |
628 | if remaining <= 0:
629 | break
630 |
631 | context_tensor = torch.cat(
632 | [context_tensor, prediction.median(dim=1).values], dim=-1
633 | )
634 |
635 | return torch.cat(predictions, dim=-1)
636 |
637 | @classmethod
638 | def from_pretrained(cls, *args, tc='MeanScaleUniformBins', **kwargs):
639 | """
640 | Load the model, either from a local path or from the HuggingFace Hub.
641 | Supports the same arguments as ``AutoConfig`` and ``AutoModel``
642 | from ``transformers``.
643 | """
644 |
645 | config = AutoConfig.from_pretrained(*args, **kwargs)
646 |
647 | assert hasattr(config, "chronos_config"), "Not a Chronos config file"
648 |
649 | chronos_config = ChronosConfig(**config.chronos_config)
650 | chronos_config.tokenizer_class = tc
651 |
652 | if chronos_config.model_type == "seq2seq":
653 | inner_model = AutoModelForSeq2SeqLM.from_pretrained(*args, **kwargs)
654 | else:
655 | assert chronos_config.model_type == "causal"
656 | inner_model = AutoModelForCausalLM.from_pretrained(*args, **kwargs)
657 |
658 | return cls(
659 | tokenizer=chronos_config.create_tokenizer(),
660 | model=ChronosModel(config=chronos_config, model=inner_model),
661 | )
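
> *Note:* the `ChronosPipeline` above can be exercised on its own to inspect encoder embeddings for a single-channel segment. A minimal sketch, assuming a Chronos T5 checkpoint is available locally or on the Hub (the checkpoint name below is a placeholder; the same path passed as `--pt_encoder_backbone_ckpt` should work):

```python
import torch
from sensorllm.model.chronos_model import ChronosPipeline

# Placeholder checkpoint; substitute your --pt_encoder_backbone_ckpt path.
pipeline = ChronosPipeline.from_pretrained(
    "amazon/chronos-t5-small",
    tc="StanNormalizeUniformBins",  # tokenizer class used by SensorLLM
)
segment = torch.randn(200)  # one single-channel, variable-length sensor segment
embeddings, tokenizer_state = pipeline.embed(segment)
print(embeddings.shape)  # (1, seq_len, d_model); seq_len gains +1 if an EOS token is appended
```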
--------------------------------------------------------------------------------
/sensorllm/model/stage1_sensorllm.py:
--------------------------------------------------------------------------------
1 | from typing import List, Optional, Tuple, Union
2 | from torch.nn import CrossEntropyLoss
3 | from .utils import *
4 | from contextlib import nullcontext
5 |
6 | from transformers import (
7 | AutoConfig,
8 | AutoModelForCausalLM,
9 | LlamaConfig,
10 | LlamaForCausalLM,
11 | )
12 |
13 | from transformers.modeling_outputs import (
14 | BaseModelOutputWithPast,
15 | CausalLMOutputWithPast,
16 | )
17 |
18 | import logging
19 |
20 | logger = logging.getLogger(__name__)
21 |
22 |
23 | class SensorLLMStage1Config(LlamaConfig):
24 | model_type = "sensorllmstage1"
25 |
26 |
27 | class SensorLLMStage1LlamaModel(BaseSensorLLMModel):
28 | config_class = SensorLLMStage1Config
29 |
30 | def __init__(self, config: LlamaConfig):
31 | super(SensorLLMStage1LlamaModel, self).__init__(config)
32 |
33 | def forward(
34 | self,
35 | input_ids: torch.LongTensor = None, # B, L
36 | attention_mask: Optional[torch.Tensor] = None,
37 | past_key_values: Optional[List[torch.FloatTensor]] = None,
38 | inputs_embeds: Optional[torch.FloatTensor] = None,
39 | use_cache: Optional[bool] = None,
40 | output_attentions: Optional[bool] = None,
41 | output_hidden_states: Optional[bool] = None,
42 | ts_token_ids: Optional[List[torch.Tensor]] = None, # B, L_ts
43 | ts_attention_mask: Optional[List[torch.Tensor]] = None,
44 | ts_tokenizer_state: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
45 | return_dict: Optional[bool] = None,
46 | cache_position: Optional[torch.LongTensor] = None,
47 | ) -> Union[Tuple, BaseModelOutputWithPast]:
48 |
49 | # Check the dimensions of input_ids
50 | if input_ids.dim() != 2:
51 | raise ValueError(f"Expected input_ids to be a 2D tensor, but got {input_ids.dim()}D tensor")
52 |
53 | orig_embeds_params = getattr(self, "orig_embeds_params", None)
54 |
55 | if inputs_embeds is None:
56 | inputs_embeds = self.embed_tokens(input_ids)
57 |
58 | pt_encoder_backbone = getattr(self, 'pt_encoder_backbone', None)
59 |
60 | if pt_encoder_backbone is not None and (input_ids.shape[1] != 1 or self.training) and ts_token_ids is not None:
61 | assert type(ts_token_ids) is list
62 |
63 | with torch.no_grad() if self.fix_ts_encoder else nullcontext():
64 | if self.fix_ts_encoder:
65 | self.pt_encoder_backbone.eval()
66 | ts_features = []
67 | for ti, am in zip(ts_token_ids, ts_attention_mask): # * iterate over batch
68 | ts_feature = self.ts_embed(ti, am)
69 | if torch.any(torch.isnan(ts_feature)) or torch.any(torch.isinf(ts_feature)):
70 | raise ValueError("ts_feature has NaN values")
71 | ts_features.append(ts_feature[0])
72 | summed_ts_embeddings = [self.ts_proj(ts_feature) for ts_feature in ts_features]
73 |
74 | new_input_embeds = []
75 | for cur_input_ids, cur_input_embeds, cur_ts_embeds in zip(
76 | input_ids, inputs_embeds, summed_ts_embeddings
77 | ): # * input_ids: B, L; input_embeds: B, L, C; summed_ts_embeddings: B, L_ts, C
78 | cur_ts_embeds = cur_ts_embeds.to(
79 | device=cur_input_embeds.device
80 | )
81 |
82 | num_ts_tokens = cur_ts_embeds.shape[0] # * number of ts tokens
83 | total_ts_token_count = (cur_input_ids == self.ts_backbone_config["ts_token_id"]).sum()
84 |
85 | if num_ts_tokens != total_ts_token_count:
86 | raise ValueError(
87 | f"The window size of time-series tokens ({num_ts_tokens}) and input template ts tokens ({total_ts_token_count}) should be the same.")
88 |
89 | for start_token_id, end_token_id in self.start_end_tokens.items():
90 | start_token_count = (cur_input_ids == start_token_id).sum()
91 | end_token_count = (cur_input_ids == end_token_id).sum()
92 |
93 | if start_token_count != end_token_count:
94 | raise ValueError(
95 | f"The number of {start_token_id} tokens ({start_token_count}) and {end_token_id} tokens ({end_token_count}) should be the same.")
96 |
97 | start_token_positions = torch.where(cur_input_ids == start_token_id)[0]
98 |
99 | for start_token_pos in start_token_positions:
100 | end_token_pos = start_token_pos + num_ts_tokens + 1
101 | total_ts_token_count -= num_ts_tokens
102 |
103 | if end_token_pos >= len(cur_input_ids) or cur_input_ids[end_token_pos] != end_token_id:
104 | raise ValueError(
105 | f"The end token '{end_token_id}' should follow the start token '{start_token_id}' after {num_ts_tokens} positions."
106 | )
107 |
108 | if orig_embeds_params is not None: # * will not update the original embeddings except for TS_START_TOKEN and TS_END_TOKEN
109 | # print("Will not update the original embeddings except for TS_START_TOKEN and TS_END_TOKEN")
110 | cur_input_embeds = torch.cat(
111 | (
112 | cur_input_embeds[:start_token_pos].detach(),
113 | cur_input_embeds[start_token_pos: start_token_pos + 1],
114 | cur_ts_embeds,
115 | cur_input_embeds[end_token_pos: end_token_pos + 1],
116 | cur_input_embeds[end_token_pos + 1:].detach(),
117 | ),
118 | dim=0,
119 | )
120 | else:
121 | # print("Will update the original embeddings")
122 | cur_input_embeds = torch.cat(
123 | (
124 | cur_input_embeds[:start_token_pos + 1],
125 | cur_ts_embeds,
126 | cur_input_embeds[end_token_pos:],
127 | ),
128 | dim=0,
129 | )
130 |
131 | if total_ts_token_count != 0:
132 | raise ValueError(
133 |                         f"The value of total_ts_token_count ({total_ts_token_count}) should be 0.")
134 | new_input_embeds.append(cur_input_embeds)
135 | inputs_embeds = torch.stack(new_input_embeds, dim=0)
136 |
137 | return super(SensorLLMStage1LlamaModel, self).forward(
138 | input_ids=None,
139 | attention_mask=attention_mask,
140 | past_key_values=past_key_values,
141 | inputs_embeds=inputs_embeds,
142 | use_cache=use_cache,
143 | output_attentions=output_attentions,
144 | output_hidden_states=output_hidden_states,
145 | return_dict=return_dict,
146 | cache_position=cache_position,
147 | )
148 |
149 |
150 | class SensorLLMStage1LlamaForCausalLM(BaseSensorLLM, LlamaForCausalLM):
151 | config_class = SensorLLMStage1Config
152 |
153 | def __init__(self, config):
154 | super(LlamaForCausalLM, self).__init__(config)
155 | self.model = SensorLLMStage1LlamaModel(config)
156 |
157 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
158 |
159 | # Initialize weights and apply final processing
160 | self.post_init()
161 |
162 | def get_model(self):
163 | return self.model
164 |
165 | def forward(
166 | self,
167 | input_ids: torch.LongTensor = None,
168 | attention_mask: Optional[torch.Tensor] = None,
169 | past_key_values: Optional[List[torch.FloatTensor]] = None,
170 | inputs_embeds: Optional[torch.FloatTensor] = None,
171 | labels: Optional[torch.LongTensor] = None,
172 | use_cache: Optional[bool] = None, # * control whether to return past_key_values
173 | output_attentions: Optional[bool] = None,
174 | output_hidden_states: Optional[bool] = None,
175 | ts_token_ids: Optional[List[torch.Tensor]] = None, # B, L_ts
176 | ts_attention_mask: Optional[List[torch.Tensor]] = None,
177 | ts_tokenizer_state: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
178 | return_dict: Optional[bool] = None,
179 | cache_position: Optional[torch.LongTensor] = None,
180 | ) -> Union[Tuple, CausalLMOutputWithPast]:
181 | output_attentions = (
182 | output_attentions
183 | if output_attentions is not None
184 | else self.config.output_attentions
185 | )
186 | output_hidden_states = (
187 | output_hidden_states
188 | if output_hidden_states is not None
189 | else self.config.output_hidden_states
190 | )
191 | return_dict = (
192 | return_dict if return_dict is not None else self.config.use_return_dict
193 | )
194 |
195 | # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
196 | outputs = self.model(
197 | input_ids=input_ids,
198 | attention_mask=attention_mask,
199 | past_key_values=past_key_values,
200 | inputs_embeds=inputs_embeds,
201 | use_cache=use_cache,
202 | output_attentions=output_attentions,
203 | output_hidden_states=output_hidden_states,
204 | return_dict=return_dict,
205 | ts_token_ids=ts_token_ids,
206 | ts_attention_mask=ts_attention_mask,
207 | ts_tokenizer_state=ts_tokenizer_state,
208 | cache_position=cache_position,
209 | )
210 |
211 | hidden_states = outputs[0]
212 | logits = self.lm_head(hidden_states)
213 |
214 | loss = None
215 | if labels is not None:
216 | # Shift so that tokens < n predict n
217 | shift_logits = logits[..., :-1, :].contiguous() # * B, L, V
218 | shift_labels = labels[..., 1:].contiguous() # * B, L
219 | # Flatten the tokens
220 | loss_fct = CrossEntropyLoss()
221 | shift_logits = shift_logits.view(-1, self.config.vocab_size)
222 | shift_labels = shift_labels.view(-1)
223 | # Enable model/pipeline parallelism
224 | shift_labels = shift_labels.to(shift_logits.device)
225 | loss = loss_fct(shift_logits, shift_labels)
226 |
227 | if not return_dict:
228 | output = (logits,) + outputs[1:]
229 | return (loss,) + output if loss is not None else output
230 |
231 | return CausalLMOutputWithPast(
232 | loss=loss,
233 | logits=logits,
234 | past_key_values=outputs.past_key_values,
235 | hidden_states=outputs.hidden_states,
236 | attentions=outputs.attentions,
237 | )
238 |
239 | def prepare_inputs_for_generation(
240 | self,
241 | input_ids,
242 | past_key_values=None,
243 | attention_mask=None,
244 | inputs_embeds=None,
245 | **kwargs,
246 | ):
247 | if past_key_values:
248 | input_ids = input_ids[:, -1:]
249 |
250 | # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
251 | if inputs_embeds is not None and past_key_values is None:
252 | model_inputs = {"inputs_embeds": inputs_embeds}
253 | else:
254 | model_inputs = {"input_ids": input_ids}
255 | model_inputs.update(
256 | {
257 | "past_key_values": past_key_values,
258 | "use_cache": kwargs.get("use_cache"),
259 | "attention_mask": attention_mask,
260 | "ts_token_ids": kwargs.get("ts_token_ids", None),
261 | "ts_attention_mask": kwargs.get("ts_attention_mask", None),
262 | "ts_tokenizer_state": kwargs.get("ts_tokenizer_state", None),
263 | "cache_position": kwargs.get("cache_position", None),
264 | }
265 | )
266 | return model_inputs
267 |
268 |
269 | AutoConfig.register("sensorllmstage1", SensorLLMStage1Config)
270 | AutoModelForCausalLM.register(SensorLLMStage1Config, SensorLLMStage1LlamaForCausalLM)
271 |
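A minimal loading sketch (the checkpoint path below is a placeholder): the `AutoConfig`/`AutoModelForCausalLM` registrations above let a saved Stage-1 checkpoint be restored through the standard `from_pretrained` interface, assuming the `sensorllm.model` package imports this module (as `train.py`'s `from sensorllm.model import *` suggests).

```python
# Sketch only: "path/to/stage1_checkpoint" is a placeholder for a checkpoint saved
# with SensorLLMStage1Config; importing sensorllm.model is assumed to run the
# register() calls above.
from transformers import AutoConfig, AutoModelForCausalLM
import sensorllm.model  # noqa: F401

cfg = AutoConfig.from_pretrained("path/to/stage1_checkpoint")
model = AutoModelForCausalLM.from_pretrained("path/to/stage1_checkpoint", config=cfg)
```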
--------------------------------------------------------------------------------
/sensorllm/model/stage2_sensorllm.py:
--------------------------------------------------------------------------------
1 | from typing import List, Optional, Tuple, Union
2 | from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
3 | from .utils import *
4 | from contextlib import nullcontext
5 | from transformers import (
6 | AutoConfig,
7 | AutoModelForCausalLM,
8 | AutoModelForSequenceClassification,
9 | LlamaConfig,
10 | LlamaForCausalLM,
11 | LlamaForSequenceClassification
12 | )
13 |
14 | from transformers.modeling_outputs import (
15 | BaseModelOutputWithPast,
16 | CausalLMOutputWithPast,
17 | SequenceClassifierOutputWithPast
18 | )
19 |
20 | import logging
21 |
22 | logger = logging.getLogger(__name__)
23 |
24 |
25 | class SensorLLMStage2Config(LlamaConfig):
26 | model_type = "sensorllmstage2"
27 |
28 |
29 | class SensorLLMStage2LlamaModel(BaseSensorLLMModel):
30 | config_class = SensorLLMStage2Config
31 |
32 | def __init__(self, config: LlamaConfig):
33 | super(SensorLLMStage2LlamaModel, self).__init__(config)
34 |
35 | def forward(
36 | self,
37 | input_ids: torch.LongTensor = None, # B, L
38 | attention_mask: Optional[torch.Tensor] = None,
39 | position_ids: Optional[torch.LongTensor] = None,
40 | past_key_values: Optional[List[torch.FloatTensor]] = None,
41 | inputs_embeds: Optional[torch.FloatTensor] = None,
42 | use_cache: Optional[bool] = None,
43 | output_attentions: Optional[bool] = None,
44 | output_hidden_states: Optional[bool] = None,
45 | mts_token_ids: Optional[torch.Tensor] = None,
46 | mts_attention_mask: Optional[torch.Tensor] = None,
47 | mts_tokenizer_state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
48 | return_dict: Optional[bool] = None,
49 | ) -> Union[Tuple, BaseModelOutputWithPast]:
50 | # Check the dimensions of input_ids
51 | if input_ids.dim() != 2:
52 | raise ValueError(f"Expected input_ids to be a 2D tensor, but got {input_ids.dim()}D tensor")
53 |
54 | if mts_token_ids.dim() != 3:
55 | raise ValueError(f"Expected multichannel_ts to be a 3D tensor, but got {mts_token_ids.dim()}D tensor")
56 |
57 | orig_embeds_params = getattr(self, "orig_embeds_params", None)
58 |
59 | if inputs_embeds is None:
60 | inputs_embeds = self.embed_tokens(input_ids)
61 |
62 | pt_encoder_backbone = getattr(self, 'pt_encoder_backbone', None)
63 |
64 | if pt_encoder_backbone is not None and (input_ids.shape[1] != 1 or self.training) and mts_token_ids is not None:
65 |
66 | channel_num = mts_token_ids.size(1)
67 | with torch.no_grad() if self.fix_ts_encoder else nullcontext():
68 | if self.fix_ts_encoder:
69 | self.pt_encoder_backbone.eval()
70 | ts_features = []
71 | for ts_token_ids, ts_attention_mask in zip(mts_token_ids, mts_attention_mask): # * iterate over batch
72 | ts_feature = self.ts_embed(ts_token_ids, ts_attention_mask)
73 | if torch.any(torch.isnan(ts_feature)) or torch.any(torch.isinf(ts_feature)):
74 |                         raise ValueError("ts_feature contains NaN or Inf values")
75 | ts_features.append(ts_feature)
76 |
77 | summed_ts_embeddings = [self.ts_proj(ts_feature) for ts_feature in ts_features]
78 | summed_ts_embeddings = torch.stack(summed_ts_embeddings)
79 |
80 | new_input_embeds = []
81 | for cur_input_ids, cur_input_embeds, cur_ts_embeds in zip(
82 | input_ids, inputs_embeds, summed_ts_embeddings
83 | ): # * input_ids: B, L; input_embeds: B, L, D; summed_ts_embeddings: B, C, L_ts, D
84 | cur_ts_embeds = cur_ts_embeds.to(
85 | device=cur_input_embeds.device
86 | )
87 | num_ts_tokens = cur_ts_embeds.shape[1] # * number of ts tokens
88 |
89 | if len(self.start_end_tokens) != channel_num:
90 | raise ValueError(
91 | f"The length of start_end_tokens ({len(self.start_end_tokens)}) and channel_num ({channel_num}) should be the same.")
92 | if len(self.start_end_tokens) != cur_ts_embeds.size(0):
93 | raise ValueError(
94 | f"The length of start_end_tokens ({len(self.start_end_tokens)}) and cur_ts_embeds ({cur_ts_embeds.size(0)}) should be the same.")
95 |
96 | total_ts_token_count = (cur_input_ids == self.ts_backbone_config["ts_token_id"]).sum()
97 | for (start_token_id, end_token_id), channel_ebd in zip(self.start_end_tokens.items(), cur_ts_embeds):
98 | start_token_count = (cur_input_ids == start_token_id).sum()
99 | end_token_count = (cur_input_ids == end_token_id).sum()
100 |
101 | if start_token_count != end_token_count:
102 | raise ValueError(
103 | f"The number of {start_token_id} tokens ({start_token_count}) and {end_token_id} tokens ({end_token_count}) should be the same.")
104 |
105 | start_token_positions = torch.where(cur_input_ids == start_token_id)[0]
106 |
107 | for start_token_pos in start_token_positions:
108 | end_token_pos = start_token_pos + num_ts_tokens + 1
109 | total_ts_token_count -= num_ts_tokens
110 |
111 | if end_token_pos >= len(cur_input_ids) or cur_input_ids[end_token_pos] != end_token_id:
112 | raise ValueError(
113 |                             f"The end token '{end_token_id}' should appear {num_ts_tokens + 1} positions after the start token '{start_token_id}'."
114 | )
115 |
116 | if orig_embeds_params is not None: # * will not update the original embeddings except for TS_START_TOKEN and TS_END_TOKEN
117 | cur_input_embeds = torch.cat(
118 | (
119 | cur_input_embeds[:start_token_pos].detach(),
120 | cur_input_embeds[start_token_pos: start_token_pos + 1],
121 | channel_ebd,
122 | cur_input_embeds[end_token_pos: end_token_pos + 1],
123 | cur_input_embeds[end_token_pos + 1:].detach(),
124 | ),
125 | dim=0,
126 | )
127 | else:
128 | cur_input_embeds = torch.cat(
129 | (
130 | cur_input_embeds[:start_token_pos + 1],
131 | channel_ebd,
132 | cur_input_embeds[end_token_pos:],
133 | ),
134 | dim=0,
135 | )
136 | if total_ts_token_count != 0:
137 | raise ValueError(
138 |                         f"The value of total_ts_token_count ({total_ts_token_count}) should be 0.")
139 | new_input_embeds.append(cur_input_embeds)
140 | inputs_embeds = torch.stack(new_input_embeds, dim=0)
141 |
142 | return super(SensorLLMStage2LlamaModel, self).forward(
143 | input_ids=None,
144 | attention_mask=attention_mask,
145 | position_ids=position_ids,
146 | past_key_values=past_key_values,
147 | inputs_embeds=inputs_embeds,
148 | use_cache=use_cache,
149 | output_attentions=output_attentions,
150 | output_hidden_states=output_hidden_states,
151 | return_dict=return_dict,
152 | )
153 |
154 |
155 | class SensorLLMStage2LlamaForCausalLM(BaseSensorLLM, LlamaForCausalLM):
156 | config_class = SensorLLMStage2Config
157 |
158 | def __init__(self, config):
159 | super(LlamaForCausalLM, self).__init__(config)
160 | self.model = SensorLLMStage2LlamaModel(config)
161 |
162 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
163 |
164 | # Initialize weights and apply final processing
165 | self.post_init()
166 |
167 | def get_model(self):
168 | return self.model
169 |
170 | def forward(
171 | self,
172 | input_ids: torch.LongTensor = None,
173 | attention_mask: Optional[torch.Tensor] = None,
174 | past_key_values: Optional[List[torch.FloatTensor]] = None,
175 | inputs_embeds: Optional[torch.FloatTensor] = None,
176 | labels: Optional[torch.LongTensor] = None,
177 | use_cache: Optional[bool] = None, # * control whether to return past_key_values
178 | output_attentions: Optional[bool] = None,
179 | output_hidden_states: Optional[bool] = None,
180 | mts_token_ids: Optional[torch.Tensor] = None,
181 | mts_attention_mask: Optional[torch.Tensor] = None,
182 | mts_tokenizer_state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
183 | return_dict: Optional[bool] = None,
184 | ) -> Union[Tuple, CausalLMOutputWithPast]:
185 | output_attentions = (
186 | output_attentions
187 | if output_attentions is not None
188 | else self.config.output_attentions
189 | )
190 | output_hidden_states = (
191 | output_hidden_states
192 | if output_hidden_states is not None
193 | else self.config.output_hidden_states
194 | )
195 | return_dict = (
196 | return_dict if return_dict is not None else self.config.use_return_dict
197 | )
198 |
199 | # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
200 | outputs = self.model(
201 | input_ids=input_ids,
202 | attention_mask=attention_mask,
203 | past_key_values=past_key_values,
204 | inputs_embeds=inputs_embeds,
205 | use_cache=use_cache,
206 | output_attentions=output_attentions,
207 | output_hidden_states=output_hidden_states,
208 | return_dict=return_dict,
209 | mts_token_ids=mts_token_ids,
210 | mts_attention_mask=mts_attention_mask,
211 | mts_tokenizer_state=mts_tokenizer_state
212 | )
213 |
214 | hidden_states = outputs[0]
215 | logits = self.lm_head(hidden_states)
216 |
217 | loss = None
218 | if labels is not None:
219 | # Shift so that tokens < n predict n
220 | shift_logits = logits[..., :-1, :].contiguous() # * B, L, V
221 | shift_labels = labels[..., 1:].contiguous() # * B, L
222 | # Flatten the tokens
223 | loss_fct = CrossEntropyLoss()
224 | shift_logits = shift_logits.view(-1, self.config.vocab_size)
225 | shift_labels = shift_labels.view(-1)
226 | # Enable model/pipeline parallelism
227 | shift_labels = shift_labels.to(shift_logits.device)
228 | loss = loss_fct(shift_logits, shift_labels)
229 |
230 | if not return_dict:
231 | output = (logits,) + outputs[1:]
232 | return (loss,) + output if loss is not None else output
233 |
234 | return CausalLMOutputWithPast(
235 | loss=loss,
236 | logits=logits,
237 | past_key_values=outputs.past_key_values,
238 | hidden_states=outputs.hidden_states,
239 | attentions=outputs.attentions,
240 | )
241 |
242 | def prepare_inputs_for_generation(
243 | self,
244 | input_ids,
245 | past_key_values=None,
246 | attention_mask=None,
247 | inputs_embeds=None,
248 | **kwargs,
249 | ):
250 | if past_key_values:
251 | input_ids = input_ids[:, -1:]
252 |
253 | # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
254 | if inputs_embeds is not None and past_key_values is None:
255 | model_inputs = {"inputs_embeds": inputs_embeds}
256 | else:
257 | model_inputs = {"input_ids": input_ids}
258 |
259 | model_inputs.update(
260 | {
261 | "past_key_values": past_key_values,
262 | "use_cache": kwargs.get("use_cache"),
263 | "attention_mask": attention_mask,
264 | "mts_token_ids": kwargs.get("mts_token_ids", None),
265 | "mts_attention_mask": kwargs.get("mts_attention_mask", None),
266 | "mts_tokenizer_state": kwargs.get("mts_tokenizer_state", None),
267 | }
268 | )
269 |         model_inputs.pop("cache_position", None)  # use a default: this dict never contains "cache_position", so a bare pop() would raise KeyError
270 | return model_inputs
271 |
272 |
273 | class SensorLLMStage2LlamaForSequenceClassification(BaseSensorLLM, LlamaForSequenceClassification):
274 | config_class = SensorLLMStage2Config
275 |
276 | def __init__(self, config):
277 | super(LlamaForSequenceClassification, self).__init__(config)
278 | self.num_labels = config.num_labels
279 | self.model = SensorLLMStage2LlamaModel(config)
280 | self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
281 |
282 | # Initialize weights and apply final processing
283 | self.post_init()
284 |
285 | def get_model(self):
286 | return self.model
287 |
288 | def forward(
289 | self,
290 | input_ids: torch.LongTensor = None,
291 | attention_mask: Optional[torch.Tensor] = None,
292 | position_ids: Optional[torch.LongTensor] = None,
293 | past_key_values: Optional[List[torch.FloatTensor]] = None,
294 | inputs_embeds: Optional[torch.FloatTensor] = None,
295 | labels: Optional[torch.LongTensor] = None,
296 | use_cache: Optional[bool] = None, # * control whether to return past_key_values
297 | output_attentions: Optional[bool] = None,
298 | output_hidden_states: Optional[bool] = None,
299 | mts_token_ids: Optional[torch.Tensor] = None,
300 | mts_attention_mask: Optional[torch.Tensor] = None,
301 | mts_tokenizer_state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
302 | return_dict: Optional[bool] = None,
303 | ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
304 | r"""
305 | labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
306 | Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
307 | config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
308 | `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
309 | """
310 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
311 |
312 | transformer_outputs = self.model(
313 | input_ids=input_ids,
314 | attention_mask=attention_mask,
315 | position_ids=position_ids,
316 | past_key_values=past_key_values,
317 | inputs_embeds=inputs_embeds,
318 | use_cache=use_cache,
319 | output_attentions=output_attentions,
320 | output_hidden_states=output_hidden_states,
321 | return_dict=return_dict,
322 | mts_token_ids=mts_token_ids,
323 | mts_attention_mask=mts_attention_mask,
324 | mts_tokenizer_state=mts_tokenizer_state
325 | )
326 | hidden_states = transformer_outputs[0]
327 | logits = self.score(hidden_states)
328 |
329 | if input_ids is not None:
330 | batch_size = input_ids.shape[0]
331 | else:
332 | batch_size = inputs_embeds.shape[0]
333 |
334 | if self.config.pad_token_id is None and batch_size != 1:
335 | raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
336 | if self.config.pad_token_id is None:
337 | sequence_lengths = -1
338 | else:
339 | if input_ids is not None:
340 | # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
341 | sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
342 | sequence_lengths = sequence_lengths % input_ids.shape[-1]
343 | sequence_lengths = sequence_lengths.to(logits.device)
344 | else:
345 | sequence_lengths = -1
346 |
347 | pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
348 |
349 | loss = None
350 | if labels is not None:
351 | labels = labels.to(logits.device)
352 | if self.config.problem_type is None:
353 | if self.num_labels == 1:
354 | self.config.problem_type = "regression"
355 | elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
356 | self.config.problem_type = "single_label_classification"
357 | else:
358 | self.config.problem_type = "multi_label_classification"
359 |
360 | if self.config.problem_type == "regression":
361 | loss_fct = MSELoss()
362 | if self.num_labels == 1:
363 | loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
364 | else:
365 | loss = loss_fct(pooled_logits, labels)
366 | elif self.config.problem_type == "single_label_classification":
367 | loss_fct = CrossEntropyLoss()
368 | loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
369 | elif self.config.problem_type == "multi_label_classification":
370 | loss_fct = BCEWithLogitsLoss()
371 | loss = loss_fct(pooled_logits, labels)
372 | if not return_dict:
373 | output = (pooled_logits,) + transformer_outputs[1:]
374 | return ((loss,) + output) if loss is not None else output
375 |
376 | return SequenceClassifierOutputWithPast(
377 | loss=loss,
378 | logits=pooled_logits,
379 | past_key_values=transformer_outputs.past_key_values,
380 | hidden_states=transformer_outputs.hidden_states,
381 | attentions=transformer_outputs.attentions,
382 | )
383 |
384 |
385 | AutoConfig.register("sensorllmstage2", SensorLLMStage2Config)
386 | AutoModelForCausalLM.register(SensorLLMStage2Config, SensorLLMStage2LlamaForCausalLM)
387 | AutoModelForSequenceClassification.register(SensorLLMStage2Config, SensorLLMStage2LlamaForSequenceClassification)
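Similarly, a hedged sketch for the classification variant (placeholder path; `num_labels` would normally match the dataset entry in `ts_backbone.yaml`, e.g. 12 for `usc-had`):

```python
# Sketch only; the checkpoint path is a placeholder and importing sensorllm.model
# is assumed to run the register() calls above.
from transformers import AutoConfig, AutoModelForSequenceClassification
import sensorllm.model  # noqa: F401

cfg = AutoConfig.from_pretrained("path/to/stage2_checkpoint", num_labels=12)
model = AutoModelForSequenceClassification.from_pretrained("path/to/stage2_checkpoint", config=cfg)
```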
--------------------------------------------------------------------------------
/sensorllm/model/ts_backbone.yaml:
--------------------------------------------------------------------------------
1 | ts_backbone_type: "chronos"
2 |
3 | chronos_model:
4 | name: "chronos-t5-large"
5 | encoder_output_dim: 1024
6 | projection_hidden_layer: 2
7 | projection_hidden_dim: [2048, 3072]
8 | last_token: true
9 |
10 | dropout_rate: 0.3
11 | default_ts_token: ""
12 |
13 | usc-had:
14 | id2label:
15 | 0: "1. Walking Forward"
16 | 1: "2. Walking Left"
17 | 2: "3. Walking Right"
18 | 3: "4. Walking Upstairs"
19 | 4: "5. Walking Downstairs"
20 | 5: "6. Running Forward"
21 | 6: "7. Jumping"
22 | 7: "8. Sitting"
23 | 8: "9. Standing"
24 | 9: "10. Sleeping"
25 | 10: "11. Elevator Up"
26 | 11: "12. Elevator Down"
27 | num_labels: 12
28 | channel_num: 6
29 | sample_rate: 100
30 | default_x_acc_start_token: ""
31 | default_x_acc_end_token: ""
32 | default_y_acc_start_token: ""
33 | default_y_acc_end_token: ""
34 | default_z_acc_start_token: ""
35 | default_z_acc_end_token: ""
36 | default_x_gyro_start_token: ""
37 | default_x_gyro_end_token: ""
38 | default_y_gyro_start_token: ""
39 | default_y_gyro_end_token: ""
40 | default_z_gyro_start_token: ""
41 | default_z_gyro_end_token: ""
42 |
43 | uci:
44 | id2label:
45 | 0: "Walking"
46 | 1: "Walking upstairs"
47 | 2: "Walking downstairs"
48 | 3: "Sitting"
49 | 4: "Standing"
50 | 5: "Laying"
51 | num_labels: 6
52 | channel_num: 6
53 | sample_rate: 50
54 | default_x_acc_start_token: ""
55 | default_x_acc_end_token: ""
56 | default_y_acc_start_token: ""
57 | default_y_acc_end_token: ""
58 | default_z_acc_start_token: ""
59 | default_z_acc_end_token: ""
60 | default_x_gyro_start_token: ""
61 | default_x_gyro_end_token: ""
62 | default_y_gyro_start_token: ""
63 | default_y_gyro_end_token: ""
64 | default_z_gyro_start_token: ""
65 | default_z_gyro_end_token: ""
66 |
67 | mhealth:
68 | id2label:
69 | 0: 'Standing still (1 min)'
70 | 1: 'Sitting and relaxing (1 min)'
71 | 2: 'Lying down (1 min)'
72 | 3: 'Walking (1 min)'
73 | 4: 'Climbing stairs (1 min)'
74 | 5: 'Waist bends forward (20x)'
75 | 6: 'Frontal elevation of arms (20x)'
76 | 7: 'Knees bending (crouching) (20x)'
77 | 8: 'Cycling (1 min)'
78 | 9: 'Jogging (1 min)'
79 | 10: 'Running (1 min)'
80 | 11: 'Jump front & back (20x)'
81 | num_labels: 12
82 | channel_num: 15
83 | sample_rate: 50
84 | default_chest_x_acc_start_token: ""
85 | default_chest_x_acc_end_token: ""
86 | default_chest_y_acc_start_token: ""
87 | default_chest_y_acc_end_token: ""
88 | default_chest_z_acc_start_token: ""
89 | default_chest_z_acc_end_token: ""
90 | default_left_ankle_x_acc_start_token: ""
91 | default_left_ankle_x_acc_end_token: ""
92 | default_left_ankle_y_acc_start_token: ""
93 | default_left_ankle_y_acc_end_token: ""
94 | default_left_ankle_z_acc_start_token: ""
95 | default_left_ankle_z_acc_end_token: ""
96 | default_left_ankle_x_gyro_start_token: ""
97 | default_left_ankle_x_gyro_end_token: ""
98 | default_left_ankle_y_gyro_start_token: ""
99 | default_left_ankle_y_gyro_end_token: ""
100 | default_left_ankle_z_gyro_start_token: ""
101 | default_left_ankle_z_gyro_end_token: ""
102 | default_right_lower_arm_x_acc_start_token: ""
103 | default_right_lower_arm_x_acc_end_token: ""
104 | default_right_lower_arm_y_acc_start_token: ""
105 | default_right_lower_arm_y_acc_end_token: ""
106 | default_right_lower_arm_z_acc_start_token: ""
107 | default_right_lower_arm_z_acc_end_token: ""
108 | default_right_lower_arm_x_gyro_start_token: ""
109 | default_right_lower_arm_x_gyro_end_token: ""
110 | default_right_lower_arm_y_gyro_start_token: ""
111 | default_right_lower_arm_y_gyro_end_token: ""
112 | default_right_lower_arm_z_gyro_start_token: ""
113 | default_right_lower_arm_z_gyro_end_token: ""
114 |
115 | pamap50:
116 | id2label:
117 | 0: 'lying'
118 | 1: 'sitting'
119 | 2: 'standing'
120 | 3: 'walking'
121 | 4: 'running'
122 | 5: 'cycling'
123 | 6: 'Nordic walking'
124 | 7: 'ascending stairs'
125 | 8: 'descending stairs'
126 | 9: 'vacuum cleaning'
127 | 10: 'ironing'
128 | 11: 'rope jumping'
129 | num_labels: 12
130 | channel_num: 27
131 | sample_rate: 50
132 | default_chest_x_acc_start_token: ""
133 | default_chest_x_acc_end_token: ""
134 | default_chest_y_acc_start_token: ""
135 | default_chest_y_acc_end_token: ""
136 | default_chest_z_acc_start_token: ""
137 | default_chest_z_acc_end_token: ""
138 | default_chest_x_gyro_start_token: ""
139 | default_chest_x_gyro_end_token: ""
140 | default_chest_y_gyro_start_token: ""
141 | default_chest_y_gyro_end_token: ""
142 | default_chest_z_gyro_start_token: ""
143 | default_chest_z_gyro_end_token: ""
144 | default_chest_x_mag_start_token: ""
145 | default_chest_x_mag_end_token: ""
146 | default_chest_y_mag_start_token: ""
147 | default_chest_y_mag_end_token: ""
148 | default_chest_z_mag_start_token: ""
149 | default_chest_z_mag_end_token: ""
150 | default_ankle_x_acc_start_token: ""
151 | default_ankle_x_acc_end_token: ""
152 | default_ankle_y_acc_start_token: ""
153 | default_ankle_y_acc_end_token: ""
154 | default_ankle_z_acc_start_token: ""
155 | default_ankle_z_acc_end_token: ""
156 | default_ankle_x_gyro_start_token: ""
157 | default_ankle_x_gyro_end_token: ""
158 | default_ankle_y_gyro_start_token: ""
159 | default_ankle_y_gyro_end_token: ""
160 | default_ankle_z_gyro_start_token: ""
161 | default_ankle_z_gyro_end_token: ""
162 | default_ankle_x_mag_start_token: ""
163 | default_ankle_x_mag_end_token: ""
164 | default_ankle_y_mag_start_token: ""
165 | default_ankle_y_mag_end_token: ""
166 | default_ankle_z_mag_start_token: ""
167 | default_ankle_z_mag_end_token: ""
168 | default_hand_x_acc_start_token: ""
169 | default_hand_x_acc_end_token: ""
170 | default_hand_y_acc_start_token: ""
171 | default_hand_y_acc_end_token: ""
172 | default_hand_z_acc_start_token: ""
173 | default_hand_z_acc_end_token: ""
174 | default_hand_x_gyro_start_token: ""
175 | default_hand_x_gyro_end_token: ""
176 | default_hand_y_gyro_start_token: ""
177 | default_hand_y_gyro_end_token: ""
178 | default_hand_z_gyro_start_token: ""
179 | default_hand_z_gyro_end_token: ""
180 | default_hand_x_mag_start_token: ""
181 | default_hand_x_mag_end_token: ""
182 | default_hand_y_mag_start_token: ""
183 | default_hand_y_mag_end_token: ""
184 | default_hand_z_mag_start_token: ""
185 | default_hand_z_mag_end_token: ""
186 |
187 | pamap:
188 | id2label:
189 | 0: 'lying'
190 | 1: 'sitting'
191 | 2: 'standing'
192 | 3: 'walking'
193 | 4: 'running'
194 | 5: 'cycling'
195 | 6: 'Nordic walking'
196 | 7: 'ascending stairs'
197 | 8: 'descending stairs'
198 | 9: 'vacuum cleaning'
199 | 10: 'ironing'
200 | 11: 'rope jumping'
201 | num_labels: 12
202 | channel_num: 27
203 | sample_rate: 100
204 | default_chest_x_acc_start_token: ""
205 | default_chest_x_acc_end_token: ""
206 | default_chest_y_acc_start_token: ""
207 | default_chest_y_acc_end_token: ""
208 | default_chest_z_acc_start_token: ""
209 | default_chest_z_acc_end_token: ""
210 | default_chest_x_gyro_start_token: ""
211 | default_chest_x_gyro_end_token: ""
212 | default_chest_y_gyro_start_token: ""
213 | default_chest_y_gyro_end_token: ""
214 | default_chest_z_gyro_start_token: ""
215 | default_chest_z_gyro_end_token: ""
216 | default_chest_x_mag_start_token: ""
217 | default_chest_x_mag_end_token: ""
218 | default_chest_y_mag_start_token: ""
219 | default_chest_y_mag_end_token: ""
220 | default_chest_z_mag_start_token: ""
221 | default_chest_z_mag_end_token: ""
222 | default_ankle_x_acc_start_token: ""
223 | default_ankle_x_acc_end_token: ""
224 | default_ankle_y_acc_start_token: ""
225 | default_ankle_y_acc_end_token: ""
226 | default_ankle_z_acc_start_token: ""
227 | default_ankle_z_acc_end_token: ""
228 | default_ankle_x_gyro_start_token: ""
229 | default_ankle_x_gyro_end_token: ""
230 | default_ankle_y_gyro_start_token: ""
231 | default_ankle_y_gyro_end_token: ""
232 | default_ankle_z_gyro_start_token: ""
233 | default_ankle_z_gyro_end_token: ""
234 | default_ankle_x_mag_start_token: ""
235 | default_ankle_x_mag_end_token: ""
236 | default_ankle_y_mag_start_token: ""
237 | default_ankle_y_mag_end_token: ""
238 | default_ankle_z_mag_start_token: ""
239 | default_ankle_z_mag_end_token: ""
240 | default_hand_x_acc_start_token: ""
241 | default_hand_x_acc_end_token: ""
242 | default_hand_y_acc_start_token: ""
243 | default_hand_y_acc_end_token: ""
244 | default_hand_z_acc_start_token: ""
245 | default_hand_z_acc_end_token: ""
246 | default_hand_x_gyro_start_token: ""
247 | default_hand_x_gyro_end_token: ""
248 | default_hand_y_gyro_start_token: ""
249 | default_hand_y_gyro_end_token: ""
250 | default_hand_z_gyro_start_token: ""
251 | default_hand_z_gyro_end_token: ""
252 | default_hand_x_mag_start_token: ""
253 | default_hand_x_mag_end_token: ""
254 | default_hand_y_mag_start_token: ""
255 | default_hand_y_mag_end_token: ""
256 | default_hand_z_mag_start_token: ""
257 | default_hand_z_mag_end_token: ""
258 |
259 | capture24:
260 | id2label:
261 | 0: 'sleep'
262 | 1: 'sitting'
263 | 2: 'household-chores'
264 | 3: 'walking'
265 | 4: 'vehicle'
266 | 5: 'bicycling'
267 | 6: 'mixed-activity'
268 | 7: 'standing'
269 | 8: 'manual-work'
270 | 9: 'sports'
271 | num_labels: 10
272 | channel_num: 3
273 | sample_rate: 50
274 | default_x_acc_start_token: ""
275 | default_x_acc_end_token: ""
276 | default_y_acc_start_token: ""
277 | default_y_acc_end_token: ""
278 | default_z_acc_start_token: ""
279 | default_z_acc_end_token: ""
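A small sketch of how this config can be inspected (mirroring the `yaml.safe_load` read in `train.py`); the values checked below come from the `usc-had` entry above.

```python
# Sketch: read ts_backbone.yaml and sanity-check one dataset entry.
import yaml

with open("./sensorllm/model/ts_backbone.yaml", "r") as f:
    cfg = yaml.safe_load(f)

usc = cfg["usc-had"]
print(usc["num_labels"], usc["channel_num"], usc["sample_rate"])  # 12 6 100

# One start/end token pair is expected per channel.
token_keys = [k for k in usc if k.startswith("default_") and k.endswith("_token")]
assert len(token_keys) == usc["channel_num"] * 2
```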
--------------------------------------------------------------------------------
/sensorllm/model/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from transformers import StoppingCriteria
3 | import torch.nn as nn
4 | import math
5 | from sensorllm.utils import *
6 | from .chronos_model import *
7 | from transformers import (
8 | LlamaConfig,
9 | LlamaModel,
10 | PreTrainedModel
11 | )
12 | import logging
13 |
14 | logger = logging.getLogger(__name__)
15 |
16 | DEFAULT_PAD_TOKEN = "[PAD]"
17 | DEFAULT_EOS_TOKEN = "</s>"
18 | DEFAULT_BOS_TOKEN = "<s>"
19 | DEFAULT_UNK_TOKEN = "<unk>"
20 |
21 |
22 | class BaseSensorLLMModel(LlamaModel):
23 | def __init__(self, config: LlamaConfig):
24 | super(BaseSensorLLMModel, self).__init__(config)
25 |
26 | current_dir = os.path.dirname(os.path.abspath(__file__))
27 | print(f"current dir: {current_dir}")
28 | ts_config_addr = os.path.join(current_dir, "ts_backbone.yaml")
29 | self.ts_backbone_config = cfg_from_yaml_file(ts_config_addr)
30 |
31 | # self.ts_backbone_config["ts_backbone_output_dimension"] = self.config.hidden_size
32 |
33 | logger.warning(
34 | f"The hidden size of LLM is {self.config.hidden_size}."
35 | )
36 |
37 | backbone_output_dim = self.ts_backbone_config['chronos_model']['encoder_output_dim']
38 | logger.warning(
39 | f"{self.ts_backbone_config['chronos_model']['name']} output dim: {self.ts_backbone_config['chronos_model']['encoder_output_dim']}.")
40 | logger.warning(
41 | f"Use {self.ts_backbone_config['chronos_model']['projection_hidden_layer']} projection hidden layers.")
42 | if self.ts_backbone_config['chronos_model']['projection_hidden_layer'] > 0:
43 | # Add projection layer with linear layers and GELU activation
44 | projection_layers = []
45 | last_dim = backbone_output_dim
46 | for i in range(self.ts_backbone_config['chronos_model']['projection_hidden_layer']):
47 | projection_layers.append(
48 | nn.Linear(last_dim, self.ts_backbone_config['chronos_model']["projection_hidden_dim"][i]))
49 | projection_layers.append(nn.GELU())
50 | last_dim = self.ts_backbone_config['chronos_model']["projection_hidden_dim"][i]
51 |
52 | projection_layers.append(nn.Linear(last_dim, self.config.hidden_size))
53 | self.ts_proj = nn.Sequential(*projection_layers)
54 | logger.warning(
55 | f"Each layer with {self.ts_backbone_config['chronos_model']['projection_hidden_dim']} hidden units.")
56 | else:
57 | # Single layer
58 | self.ts_proj = nn.Linear(backbone_output_dim, self.config.hidden_size)
59 | logger.warning(f"TS projector output dim: {self.config.hidden_size}.")
60 |
61 | self.fix_llm = False
62 | self.fix_ts_encoder = False
63 |
64 | def load_pt_encoder_backbone_checkpoint(self, checkpoint_path=None, tc=None, torch_dtype=None):
65 | logger.warning(f"Loading default pt_encoder_backbone_ckpt ...")
66 | pipeline = ChronosPipeline.from_pretrained(
67 | self.config.pt_encoder_backbone_ckpt if checkpoint_path is None else checkpoint_path,
68 | device_map=self.device,
69 | tc="MeanScaleUniformBins" if tc is None else tc,
70 | torch_dtype=torch.float32 if torch_dtype is None else torch_dtype,
71 | )
72 | self.pt_encoder_backbone = pipeline.model
73 | self.pt_encoder_backbone.to(self.device)
74 |
75 | def load_start_end_tokens(self, dataset=None):
76 | logger.warning(f"Loading start_end_tokens dict for {dataset} dataset ...")
77 | dataset_config = self.ts_backbone_config[dataset]
78 | self.start_end_tokens = {}
79 | for key in dataset_config:
80 | if key.endswith('_start_token_id'):
81 | end_token_key = key.replace('_start_token_id', '_end_token_id')
82 |
83 | assert end_token_key in dataset_config
84 | self.start_end_tokens[dataset_config[key]] = dataset_config[end_token_key]
85 |
86 | @torch.no_grad()
87 | def ts_embed(
88 | self,
89 | token_ids: torch.Tensor,
90 | attention_mask: torch.Tensor
91 | ) -> torch.Tensor:
92 |         """
93 |         Get encoder embeddings for the given tokenized time series.
94 | 
95 |         Parameters
96 |         ----------
97 |         token_ids
98 |             Token ids produced by the pretrained Chronos tokenizer for a
99 |             batch of time-series segments, shaped (batch_size, context_length).
100 |         attention_mask
101 |             Attention mask aligned with ``token_ids``, of the same shape,
102 |             marking padded positions.
103 | 
104 |         Returns
105 |         -------
106 |         embeddings
107 |             Encoder embeddings shaped (batch_size, context_length, d_model),
108 |             or (batch_size, context_length + 1, d_model) when the backbone
109 |             appends an EOS token, where d_model is the hidden size of the
110 |             pretrained Chronos encoder. Only the embeddings are returned;
111 |             the tokenizer state (e.g., the scale used for mean scaling) is
112 |             carried separately as ``ts_tokenizer_state``.
113 |         """
114 |
115 | embeddings = self.pt_encoder_backbone.encode(
116 | input_ids=token_ids,
117 | attention_mask=attention_mask,
118 | )
119 | # if str(self.model.device) == 'cuda:1':
120 | # print("3", embeddings)
121 |
122 | return embeddings
123 |
124 |
125 |
126 | class BaseSensorLLM(PreTrainedModel):
127 | def initialize_tokenizer_ts_backbone_config_wo_embedding(self, tokenizer, dataset):
128 | # * called when stage2 or inference or inference without pre-training, assume tokenizer has time-series tokens
129 | ts_backbone_config = self.get_model().ts_backbone_config
130 |
131 | default_ts_token = ts_backbone_config["default_ts_token"] #
132 | # print(tokenizer.convert_tokens_to_ids([default_ts_token, "", ""]))
133 |
134 | tokenizer.add_tokens([default_ts_token], special_tokens=True)
135 |
136 | # * assert tokenizer has the default_ts_token
137 | ts_backbone_config["ts_token_id"] = tokenizer.convert_tokens_to_ids([default_ts_token])[0]
138 |
139 | if dataset not in ts_backbone_config:
140 | raise ValueError(f"Cannot find {dataset} in ts_backbone.yaml file.")
141 |
142 | dataset_config = ts_backbone_config[dataset]
143 |
144 | token_keys = [key for key in dataset_config.keys() if key.startswith('default_') and key.endswith('_token')]
145 |         assert len(token_keys) == dataset_config["channel_num"]*2, f"len(token_keys) ({len(token_keys)}) != channel_num*2 ({dataset_config['channel_num']*2})"
146 | tokenizer.add_tokens([dataset_config[token_key] for token_key in token_keys], special_tokens=True)
147 |
148 | for token_key in token_keys:
149 | token_id_key = token_key.replace('default_', '').replace('_token', '_token_id')
150 | dataset_config[token_id_key] = tokenizer.convert_tokens_to_ids([dataset_config[token_key]])[0]
151 |
152 | special_tokens_dict = dict()
153 | if tokenizer.pad_token is None:
154 | special_tokens_dict["pad_token"] = DEFAULT_PAD_TOKEN
155 | if tokenizer.eos_token is None:
156 | special_tokens_dict["eos_token"] = DEFAULT_EOS_TOKEN
157 | if tokenizer.bos_token is None:
158 | special_tokens_dict["bos_token"] = DEFAULT_BOS_TOKEN
159 | if tokenizer.unk_token is None:
160 | special_tokens_dict["unk_token"] = DEFAULT_UNK_TOKEN
161 |
162 | tokenizer.add_special_tokens(special_tokens_dict)
163 |
164 | def initialize_tokenizer_ts_backbone_config(
165 | self, tokenizer, device, fix_llm=True, dataset='usc-had'
166 | ):
167 |
168 | ts_backbone_config = self.get_model().ts_backbone_config
169 |
170 | default_ts_token = ts_backbone_config["default_ts_token"] #
171 |
172 | tokenizer.add_tokens(
173 | [default_ts_token], special_tokens=True
174 | ) # * no need to update embed since it will be replaced
175 |
176 | ts_backbone_config["ts_token_id"] = tokenizer.convert_tokens_to_ids([default_ts_token])[0]
177 |
178 | if dataset not in ts_backbone_config:
179 | raise ValueError(f"Cannot find {dataset} in ts_backbone.yaml file.")
180 |
181 | dataset_config = ts_backbone_config[dataset]
182 |
183 | token_keys = [key for key in dataset_config.keys() if key.startswith('default_') and key.endswith('_token')]
184 |         assert len(token_keys) == dataset_config["channel_num"]*2, f"len(token_keys) ({len(token_keys)}) != channel_num*2 ({dataset_config['channel_num']*2})"
185 |
186 | num_new_tokens = tokenizer.add_tokens([dataset_config[token_key] for token_key in token_keys], special_tokens=True)
187 |
188 | special_tokens_dict = dict()
189 | if tokenizer.pad_token is None:
190 | special_tokens_dict["pad_token"] = DEFAULT_PAD_TOKEN
191 | if tokenizer.eos_token is None:
192 | special_tokens_dict["eos_token"] = DEFAULT_EOS_TOKEN
193 | if tokenizer.bos_token is None:
194 | special_tokens_dict["bos_token"] = DEFAULT_BOS_TOKEN
195 | if tokenizer.unk_token is None:
196 | special_tokens_dict["unk_token"] = DEFAULT_UNK_TOKEN
197 |
198 | num_new_tokens += tokenizer.add_special_tokens(special_tokens_dict)
199 |
200 | self.resize_token_embeddings(
201 | len(tokenizer)
202 | ) # ! resize_token_embeddings will make the tokens trainable again
203 |
204 | for token_key in token_keys:
205 | token_id_key = token_key.replace('default_', '').replace('_token', '_token_id')
206 | dataset_config[token_id_key] = tokenizer.convert_tokens_to_ids([dataset_config[token_key]])[0]
207 |
208 | if num_new_tokens > 0:
209 | # Get the input embedding and output embedding of the model
210 | print("Calculate the average of the input embedding as the initialization value of the new token...")
211 | input_embeddings = self.get_input_embeddings().weight.data
212 |
213 | # Calculate the average of the input embedding and output embedding as the initialization value of the new token
214 | input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
215 | dim=0, keepdim=True
216 | )
217 |
218 | # Initialize the embedding of the new token to the average of the previous tokens
219 | input_embeddings[-num_new_tokens:] = input_embeddings_avg
220 |
221 | if hasattr(self, 'get_output_embeddings') and callable(getattr(self, 'get_output_embeddings')):
222 | output_embeddings = self.get_output_embeddings()
223 | if output_embeddings is not None:
224 | print("Calculate the average of the output embedding as the initialization value of the new token...")
225 | output_embeddings = self.get_output_embeddings().weight.data
226 | output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
227 | dim=0, keepdim=True
228 | )
229 | output_embeddings[-num_new_tokens:] = output_embeddings_avg
230 | else:
231 | print(
232 | f"Output embeddings not available.")
233 |
234 |         # Resize the token embeddings to a multiple of 8 to improve performance
235 |             T, E = self.get_input_embeddings().weight.shape
236 | self.resize_token_embeddings(int(8 * math.ceil(T / 8.0)))
237 |
238 | # need to update the input embedding, but no need to update the output embedding
239 | for p in self.get_input_embeddings().parameters():
240 | p.requires_grad = True
241 |
242 | if fix_llm:
243 | # Save original input embeddings
244 | self.get_model().orig_embeds_params = [
245 | self.get_input_embeddings().weight.data.clone().to(device=device)
246 | ] # * only tuning the new embeddings
247 |
248 | # Try to fix output embeddings if the method exists
249 | if hasattr(self, 'get_output_embeddings') and callable(getattr(self, 'get_output_embeddings')):
250 | output_embeddings = self.get_output_embeddings()
251 | if output_embeddings is not None:
252 | for p in output_embeddings.parameters():
253 | p.requires_grad = False
254 | print("Setting output embeddings fixed.")
255 | else:
256 | print("Output embeddings not available.")
257 | print(f"Setting {num_new_tokens} new tokens' input embeddings trainable.")
258 | else:
259 | self.get_model().orig_embeds_params = None
260 |
261 | # Try to make output embeddings trainable if the method exists
262 | if hasattr(self, 'get_output_embeddings') and callable(getattr(self, 'get_output_embeddings')):
263 | output_embeddings = self.get_output_embeddings()
264 | if output_embeddings is not None:
265 | for p in output_embeddings.parameters():
266 | p.requires_grad = True
267 | print("Setting output embeddings and all input embeddings trainable.")
268 | else:
269 | print("Setting all input embeddings trainable.")
270 | else:
271 | print("Output embeddings not available. Setting all input embeddings trainable.")
272 |
273 |
274 | class KeywordsStoppingCriteria(StoppingCriteria):
275 | def __init__(self, keywords, tokenizer, input_ids):
276 | self.keywords = keywords
277 | self.keyword_ids = [tokenizer(keyword).input_ids for keyword in keywords]
278 | self.keyword_ids = [keyword_id[0] for keyword_id in self.keyword_ids if
279 | type(keyword_id) is list and len(keyword_id) == 1]
280 | self.tokenizer = tokenizer
281 | self.start_len = None
282 | self.input_ids = input_ids
283 |
284 | def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
285 | if self.start_len is None:
286 | self.start_len = self.input_ids.shape[1]
287 | else:
288 | for keyword_id in self.keyword_ids:
289 | if output_ids[0, -1] == keyword_id:
290 | return True
291 | outputs = self.tokenizer.batch_decode(output_ids[:, self.start_len:], skip_special_tokens=True)[0]
292 | for keyword in self.keywords:
293 | if keyword in outputs:
294 | return True
295 | return False
296 |
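A hedged usage sketch for `KeywordsStoppingCriteria`: it follows the standard `transformers` stopping-criteria interface, so it can be wrapped in a `StoppingCriteriaList` and passed to `generate`. The `model`, `tokenizer`, `prompt` and stop keyword below are placeholders.

```python
# Sketch only; model, tokenizer and prompt are assumed to exist.
from transformers import StoppingCriteriaList

input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
stopping = KeywordsStoppingCriteria(keywords=["###"], tokenizer=tokenizer, input_ids=input_ids)
output_ids = model.generate(
    input_ids,
    max_new_tokens=256,
    stopping_criteria=StoppingCriteriaList([stopping]),
)
```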
--------------------------------------------------------------------------------
/sensorllm/train/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zechenli03/SensorLLM/1f4142a30f452721e943771190fe1dade3337249/sensorllm/train/__init__.py
--------------------------------------------------------------------------------
/sensorllm/train/llama_flash_attn_monkey_patch.py:
--------------------------------------------------------------------------------
1 | # Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright:
2 | import warnings
3 | from typing import Optional, Tuple
4 |
5 | import torch
6 | from flash_attn import __version__ as flash_attn_version
7 | from flash_attn.bert_padding import pad_input, unpad_input
8 | from flash_attn.flash_attn_interface import (
9 | flash_attn_func,
10 | flash_attn_varlen_kvpacked_func,
11 | )
12 | from transformers.models.llama.modeling_llama import (
13 | LlamaAttention,
14 | LlamaModel,
15 | rotate_half,
16 | )
17 |
18 |
19 | def apply_rotary_pos_emb(q, k, cos_sin, position_ids):
20 | gather_indices = position_ids[:, :, None, None] # [bsz, seq_len, 1, 1]
21 | gather_indices = gather_indices.repeat(
22 | 1, 1, cos_sin[0].shape[1], cos_sin[0].shape[3]
23 | )
24 | bsz = gather_indices.shape[0]
25 | cos, sin = (
26 | torch.gather(x.transpose(1, 2).repeat(bsz, 1, 1, 1), 1, gather_indices)
27 | for x in cos_sin
28 | )
29 | q, k = ((x * cos) + (rotate_half(x) * sin) for x in (q, k))
30 | return q, k
31 |
32 |
33 | def forward(
34 | self,
35 | hidden_states: torch.Tensor,
36 | attention_mask: Optional[torch.Tensor] = None,
37 | position_ids: Optional[torch.Tensor] = None,
38 | past_key_value: Optional[Tuple[torch.Tensor]] = None,
39 | output_attentions: bool = False,
40 | use_cache: bool = False,
41 | padding_mask: Optional[torch.Tensor] = None,
42 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
43 | if output_attentions:
44 | warnings.warn(
45 | "Output attentions is not supported for patched `LlamaAttention`, returning `None` instead."
46 | )
47 |
48 | bsz, q_len, _ = hidden_states.size()
49 | kv_heads = getattr(self, "num_key_value_heads", self.num_heads)
50 |
51 | q, k, v = (
52 | op(hidden_states).view(bsz, q_len, nh, self.head_dim)
53 | for op, nh in (
54 | (self.q_proj, self.num_heads),
55 | (self.k_proj, kv_heads),
56 | (self.v_proj, kv_heads),
57 | )
58 | )
59 | # shape: (b, s, num_heads, head_dim)
60 |
61 | kv_seq_len = k.shape[1]
62 | past_kv_len = 0
63 | if past_key_value is not None:
64 | past_kv_len = past_key_value[0].shape[2]
65 | kv_seq_len += past_kv_len
66 |
67 | cos_sin = self.rotary_emb(v, seq_len=kv_seq_len)
68 | q, k = apply_rotary_pos_emb(q, k, cos_sin, position_ids)
69 |
70 | if past_key_value is not None:
71 | assert (
72 | flash_attn_version >= "2.1.0"
73 | ), "past_key_value support requires flash-attn >= 2.1.0"
74 | # reuse k, v
75 | k = torch.cat([past_key_value[0].transpose(1, 2), k], dim=1)
76 | v = torch.cat([past_key_value[1].transpose(1, 2), v], dim=1)
77 |
78 | past_key_value = (k.transpose(1, 2), v.transpose(1, 2)) if use_cache else None
79 |
80 | if attention_mask is None:
81 | output = flash_attn_func(q, k, v, 0.0, softmax_scale=None, causal=True).view(
82 | bsz, q_len, -1
83 | )
84 | else:
85 | q, indices, cu_q_lens, max_s = unpad_input(q, attention_mask[:, -q_len:])
86 | # We can skip concat and call unpad twice but seems better to call unpad only once.
87 | kv, _, cu_k_lens, max_k = unpad_input(
88 | torch.stack((k, v), dim=2), attention_mask
89 | )
90 | output_unpad = flash_attn_varlen_kvpacked_func(
91 | q,
92 | kv,
93 | cu_q_lens,
94 | cu_k_lens,
95 | max_s,
96 | max_k,
97 | 0.0,
98 | softmax_scale=None,
99 | causal=True,
100 | )
101 | output_unpad = output_unpad.reshape(-1, self.num_heads * self.head_dim)
102 | output = pad_input(output_unpad, indices, bsz, q_len)
103 |
104 | return self.o_proj(output), None, past_key_value
105 |
106 |
107 | # Disable the transformation of the attention mask in LlamaModel as flash attention
108 | # takes a boolean key_padding_mask. Fills in the past kv length for use in forward.
109 | def _prepare_decoder_attention_mask(
110 | self, attention_mask, input_shape, inputs_embeds, past_key_values_length
111 | ):
112 | # [bsz, seq_len]
113 | if past_key_values_length > 0 and attention_mask is not None:
114 | attention_mask = torch.cat(
115 | (
116 | torch.full(
117 | (input_shape[0], past_key_values_length),
118 | True,
119 | dtype=attention_mask.dtype,
120 | device=attention_mask.device,
121 | ),
122 | attention_mask,
123 | ),
124 | dim=-1,
125 | )
126 |
127 | if attention_mask is not None and torch.all(attention_mask):
128 | return None # This uses the faster call when training with full samples
129 |
130 | return attention_mask
131 |
132 |
133 | def replace_llama_attn_with_flash_attn():
134 | cuda_major, cuda_minor = torch.cuda.get_device_capability()
135 | if cuda_major < 8:
136 | warnings.warn(
137 |             "Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward. "
138 | "ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593"
139 | )
140 |
141 | LlamaModel._prepare_decoder_attention_mask = _prepare_decoder_attention_mask
142 | LlamaAttention.forward = forward
143 |
144 |
145 | def test():
146 | from fastchat.train.llama_flash_attn_monkey_patch import forward as fastchat_forward
147 | from transformers.models.llama.configuration_llama import LlamaConfig
148 |
149 | config = LlamaConfig(
150 | hidden_size=1024,
151 | intermediate_size=128,
152 | num_hidden_layers=1,
153 | num_attention_heads=8,
154 | max_position_embeddings=16,
155 | )
156 | device = torch.device("cuda")
157 | model = LlamaModel(config)
158 | attn = LlamaAttention(config).to(device).half()
159 | bsz, hs, seqlen = 2, config.hidden_size, config.max_position_embeddings
160 | position_ids = torch.arange(seqlen, dtype=torch.long, device=device).view(
161 | -1, seqlen
162 | )
163 |
164 | mask = torch.full((bsz, seqlen), True, dtype=torch.bool, device=device)
165 | for i in range(4):
166 | hidden = torch.rand((bsz, seqlen, hs), dtype=torch.float16, device=device)
167 | if i:
168 | mask[0, -i:] = False
169 | mask[1, :i] = False
170 |
171 | lmask = model._prepare_decoder_attention_mask(mask, hidden.shape[:2], hidden, 0)
172 | ref, _, _ = attn.forward(
173 | hidden, attention_mask=lmask, position_ids=position_ids
174 | )
175 |
176 | fast, _, _ = fastchat_forward(
177 | attn, hidden, attention_mask=mask, position_ids=position_ids
178 | )
179 |
180 | lmask = _prepare_decoder_attention_mask(
181 | model, mask, hidden.shape[:2], hidden, 0
182 | )
183 | test, _, _ = forward(
184 | attn, hidden, attention_mask=lmask, position_ids=position_ids
185 | )
186 |
187 | print(f"Mean(abs(ref)) = {torch.mean(torch.abs(ref))}")
188 | print(f"Mean(abs(ref - fast)) = {torch.mean(torch.abs(ref - fast))}")
189 | print(f"Mean(abs(ref - test)) = {torch.mean(torch.abs(ref - test))}")
190 | print(f"Mean(abs(fast - test)) = {torch.mean(torch.abs(fast - test))}")
191 | print(f"allclose(fast, test) = {torch.allclose(fast, test)}")
192 |
193 | with torch.no_grad():
194 | # Also check that past_kv is handled properly
195 | hidden = torch.rand((bsz, seqlen, hs), dtype=torch.float16, device=device)
196 | part_len = seqlen // 4
197 | assert part_len * 4 == seqlen
198 | mask = torch.full((bsz, seqlen), True, dtype=torch.bool, device=device)
199 | mask[0, -2:] = False
200 | lmask = _prepare_decoder_attention_mask(
201 | model, mask, hidden.shape[:2], hidden, 0
202 | )
203 | oneshot, _, _ = forward(
204 | attn, hidden, attention_mask=lmask, position_ids=position_ids
205 | )
206 | parts = []
207 | past_kv, past_kv_len = None, 0
208 | for i in range(4):
209 | start = part_len * i
210 | end = start + part_len
211 | hidden_part = hidden[:, start:end, ...]
212 | lmask = _prepare_decoder_attention_mask(
213 | model,
214 | mask[:, start:end],
215 | hidden_part.shape[:2],
216 | hidden_part,
217 | past_kv_len,
218 | )
219 | part, _, past_kv = forward(
220 | attn,
221 | hidden_part.clone(),
222 | attention_mask=lmask,
223 | position_ids=position_ids[:, start:end],
224 | past_key_value=past_kv,
225 | use_cache=True,
226 | )
227 | parts.append(part)
228 | past_kv_len = past_kv[0].shape[2]
229 |
230 | print(
231 | f"allclose(oneshot[:, 0], parts[0]) = {torch.allclose(oneshot[:, :part_len], parts[0])}"
232 | )
233 | print(
234 | f"allclose(oneshot, parts) = {torch.allclose(oneshot, torch.cat(parts, dim=1))}"
235 | )
236 |
237 |
238 | if __name__ == "__main__":
239 | test()
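Typical usage of this patch is a single call before the LLaMA-based model is constructed (a sketch, not a prescribed entry point; the patch rebinds `LlamaAttention.forward` and `LlamaModel._prepare_decoder_attention_mask` at class level):

```python
# Sketch: apply the flash-attention patch before building/loading the model.
from sensorllm.train.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn

replace_llama_attn_with_flash_attn()
# ... then load the SensorLLM model as usual (e.g. via from_pretrained).
```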
--------------------------------------------------------------------------------
/sensorllm/train/sensorllm_trainer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.nn as nn
4 |
5 | from transformers import Trainer
6 | from typing import Optional
7 |
8 |
9 | def unwrap_model(model: nn.Module) -> nn.Module:
10 | """
11 | Recursively unwraps a model from potential containers (as used in distributed training).
12 |
13 | Args:
14 | model (`torch.nn.Module`): The model to unwrap.
15 | """
16 | # since there could be multiple levels of wrapping, unwrap recursively
17 | if hasattr(model, "module"):
18 | return unwrap_model(model.module)
19 | else:
20 | return model
21 |
22 |
23 | class SensorLLMTrainer(Trainer):
24 |
25 | def _save(self, output_dir: Optional[str] = None, state_dict=None):
26 | # Save the model
27 | _state_dict = state_dict
28 | if _state_dict is None:
29 | # Only save the model itself if we are using distributed training
30 | model_to_save = unwrap_model(self.model)
31 | _state_dict = model_to_save.state_dict()
32 |
33 | keys_to_match = ['pt_encoder_backbone']
34 | filtered_state_dict = {k: v for k, v in _state_dict.items() if
35 | not any(key_match in k for key_match in keys_to_match)}
36 |
37 | super(SensorLLMTrainer, self)._save(output_dir, filtered_state_dict)
38 |
39 |
40 | class SensorLLMWeightedCELossTrainer(SensorLLMTrainer):
41 | def __init__(self, *args, class_weights=None, **kwargs):
42 | super().__init__(*args, **kwargs)
43 | # Ensure label_weights is a tensor
44 | assert class_weights is not None, "class_weights for SensorLLMWeightedCELossTrainer is None"
45 | print(f"class_weights: {class_weights}")
46 | self.class_weights = class_weights
47 |
48 | def compute_loss(self, model, inputs, return_outputs=False):
49 | # Extract labels and convert them to long type for cross_entropy
50 | labels = inputs.pop("labels")
51 |
52 | # Forward pass
53 | outputs = model(**inputs)
54 |
55 | # Extract logits assuming they are directly outputted by the model
56 | logits = outputs.get('logits')
57 |
58 | # Compute custom loss with class weights for imbalanced data handling
59 | assert self.class_weights is not None, "self.class_weights is None"
60 | loss_fct = torch.nn.CrossEntropyLoss(
61 | weight=torch.tensor(self.class_weights, device=model.device, dtype=logits.dtype))
62 |
63 | loss = loss_fct(logits.view(-1, self.model.config.num_labels), labels.view(-1))
64 |
65 | return (loss, outputs) if return_outputs else loss
66 |
67 |
68 |
69 |
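A minimal sketch of instantiating `SensorLLMWeightedCELossTrainer`; the inverse-frequency weighting is only an illustration, and `model`, `training_args`, `train_dataset` and `eval_dataset` are assumed to exist, with integer class labels stored under `"labels"`.

```python
# Sketch only: derive per-class weights from label counts and hand them to the trainer.
import numpy as np

labels = np.array([ex["labels"] for ex in train_dataset])
counts = np.bincount(labels, minlength=model.config.num_labels)
class_weights = (counts.sum() / np.maximum(counts, 1)).tolist()  # illustrative inverse-frequency weights

trainer = SensorLLMWeightedCELossTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    class_weights=class_weights,
)
trainer.train()
```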
--------------------------------------------------------------------------------
/sensorllm/train/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import pathlib
4 | import yaml
5 | from dataclasses import dataclass, field
6 | from typing import Optional, List
7 | import transformers
8 | from transformers import AutoConfig
9 |
10 | import nltk
11 | from nltk.translate.bleu_score import sentence_bleu
12 | from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
13 | import numpy as np
14 |
15 | from sensorllm.model import *
16 | from sensorllm.model.chronos_model import *
17 | from sensorllm.train.sensorllm_trainer import SensorLLMTrainer, SensorLLMWeightedCELossTrainer
18 | import logging as logger
19 | from sensorllm.data import make_ts_text_data_module, make_ts_text_data_module_stage2, make_ts_classification_data_module_stage2
20 | import warnings
21 |
22 | import evaluate
23 |
24 | warnings.filterwarnings("ignore")
25 |
26 |
27 | @dataclass
28 | class ModelArguments:
29 | model_name_or_path: Optional[str] = field(default="")
30 | pt_encoder_backbone_ckpt: str = field(default=None)
31 | tokenize_method: str = field(default='MeanScaleUniformBins',
32 | metadata={"help": "MeanScaleUniformBins or StanNormalizeUniformBins."})
33 | model_type: Optional[str] = field(default="CasualLM",
34 | metadata={"help": "CasualLM or SequenceClassification."})
35 |
36 |
37 | @dataclass
38 | class DataArguments:
39 | dataset: str = field(
40 | default="usc-had",
41 | metadata={"help": "usc-had, mhealth, pamap, pamap50, uci, capture24"},
42 | )
43 | data_path: str = field(
44 | default="",
45 | metadata={"help": "Path to the training data."},
46 | )
47 | qa_path: str = field(
48 | default="",
49 | metadata={"help": "Path to the training QA data."},
50 | )
51 | eval_data_path: str = field(
52 | default="",
53 | metadata={"help": "Path to the eval data."},
54 | )
55 | eval_qa_path: str = field(
56 | default="",
57 | metadata={"help": "Path to the eval QA data."},
58 | )
59 | shuffle: bool = field(default=True, metadata={"help": "Whether to shuffle data."})
60 | ignore_qa_types: List[str] = field(default_factory=lambda: ["trend"])
61 | preprocess_type: str = field(default='Q',
62 | metadata={"help": "Q or Q+cot."})
63 | preprocess_type_eval: str = field(default='Q+cot',
64 | metadata={"help": "Q or Q+cot."})
65 | add_ts_special_token_text: bool = field(default=False)
66 |
67 |
68 | @dataclass
69 | class TrainingArguments(transformers.TrainingArguments):
70 | cache_dir: Optional[str] = field(default=None)
71 | optim: str = field(default="adamw_torch")
72 | model_max_length: int = field(
73 | default=8192,
74 | metadata={
75 | "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
76 | },
77 | )
78 |
79 | num_labels: int = field(
80 | default=12,
81 | metadata={
82 | "help": "Number of output labels."
83 | },
84 | )
85 | use_weighted_loss: bool = field(default=True, metadata={"help": "Use weighted loss for classification model."})
86 |
87 | fix_llm: bool = field(default=True, metadata={"help": "Whether to fix the LLM."})
88 | fix_ts_encoder: bool = field(default=True, metadata={"help": "Whether to fix the pretrained ts encoder."})
89 | fix_cls_head: bool = field(default=False, metadata={"help": "Whether to fix the cls head of LLM."})
90 | stage_2: bool = field(default=False) # * set True when fine-tuning
91 | only_stage2: bool = field(default=False)
92 | # * ts backbone ckpt path
93 | tune_mm_mlp_adapter: bool = field(default=True)
94 | metric_for_best_model: str = field(default='eval_loss')
95 |
96 |
97 | def print_trainable_parameters(model):
98 | all_param = 0
99 | trainable_params = 0
100 | for name, param in model.named_parameters():
101 | all_param += param.numel()
102 | if param.requires_grad:
103 | print(f"Layer: {name}, Trainable: {param.requires_grad}")
104 | trainable_params += param.numel()
105 |
106 | for name, param in model.get_model().pt_encoder_backbone.model.named_parameters():
107 | all_param += param.numel()
108 | if param.requires_grad:
109 | print(f"Layer: {name}, Trainable: {param.requires_grad}")
110 | trainable_params += param.numel()
111 | print(
112 | f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param:.2f}"
113 | )
114 |
115 |
116 | def check_model_parameters(model_or_params, print_details=False):
117 | device_map = {}
118 |
119 | if hasattr(model_or_params, 'named_parameters'):
120 | param_iterator = model_or_params.named_parameters()
121 | elif isinstance(model_or_params, dict):
122 | param_iterator = model_or_params.items()
123 | else:
124 | raise ValueError("Input must be a model or a dictionary of parameters")
125 |
126 | for name, param in param_iterator:
127 | device = param.device
128 | shape = tuple(param.shape)
129 | device_str = str(device)
130 |
131 | if device_str not in device_map:
132 | device_map[device_str] = []
133 |
134 | device_map[device_str].append((name, shape))
135 |
136 | if print_details:
137 | print(f"Parameter: {name}")
138 | print(f" Shape: {shape}")
139 | print(f" Device: {device}")
140 | print("-" * 50)
141 |
142 | print("\nSummary:")
143 | for device, params in device_map.items():
144 | print(f"\nDevice: {device}")
145 | print(f"Number of parameters: {len(params)}")
146 | if print_details:
147 | for name, shape in params:
148 | print(f" {name}: {shape}")
149 |
150 | return device_map
151 |
152 |
153 | def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str):
154 | state_dict = trainer.model.state_dict()
155 | cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
156 | if trainer.args.should_save:
157 | trainer._save(output_dir, state_dict=cpu_state_dict)
158 |
159 |
160 | def train():
161 | parser = transformers.HfArgumentParser(
162 | (ModelArguments, DataArguments, TrainingArguments))
163 | model_args, data_args, training_args = parser.parse_args_into_dataclasses()
164 |
165 | training_args.log_level = "info" # * default is passive(warning)
166 | # * build logger
167 | # logger = build_logger(__name__, training_args.output_dir + '/train.log')
168 |
169 | logger.warning(f"Using device: {training_args.device}")
170 |
171 | if not training_args.stage_2:
172 | # stage 1
173 | logger.warning("Using model of Stage 1")
174 | model = SensorLLMStage1LlamaForCausalLM.from_pretrained(
175 | model_args.model_name_or_path,
176 | cache_dir=training_args.cache_dir,
177 | )
178 | else:
179 | if model_args.model_type == "CasualLM":
180 | model = SensorLLMStage2LlamaForCausalLM.from_pretrained(
181 | model_args.model_name_or_path,
182 | cache_dir=training_args.cache_dir,
183 | )
184 | logger.warning("Loaded CausalLM model.")
185 | else:
186 | logger.warning(f"Loading {data_args.dataset} dataset configs ...")
187 | with open('./sensorllm/model/ts_backbone.yaml', 'r') as file:
188 | dataset_configs = yaml.safe_load(file)
189 |
190 | dataset_config = dataset_configs[data_args.dataset]
191 |
192 | id2label = dataset_config["id2label"]
193 | print(f"Dataset id2label:\n{id2label}")
194 | label2id = {v: k for k, v in id2label.items()}
195 | assert training_args.num_labels == len(id2label)
196 | assert model_args.model_type == "SequenceClassification", f"Undefined model_type {model_args.model_type}"
197 | model = SensorLLMStage2LlamaForSequenceClassification.from_pretrained(
198 | model_args.model_name_or_path,
199 | num_labels=training_args.num_labels,
200 | id2label=id2label,
201 | label2id=label2id,
202 | cache_dir=training_args.cache_dir,
203 | )
204 | logger.warning("Loaded SequenceClassification model.")
205 |
206 | model.config.use_cache = False
207 |
208 | print(f"Default pt_encoder_backbone_ckpt is {model_args.pt_encoder_backbone_ckpt}.")
209 | model.get_model().load_pt_encoder_backbone_checkpoint(model_args.pt_encoder_backbone_ckpt,
210 | tc=model_args.tokenize_method)
211 | pt_backbone_config = AutoConfig.from_pretrained(model_args.pt_encoder_backbone_ckpt)
212 |
213 | assert hasattr(pt_backbone_config, "chronos_config"), "Not a Chronos config file"
214 |
215 | chronos_config = ChronosConfig(**pt_backbone_config.chronos_config)
216 | chronos_config.tokenizer_class = model_args.tokenize_method
217 | chronos_tokenizer = chronos_config.create_tokenizer()
218 |
219 | if training_args.fix_llm:
220 | # * This will fix all the parameters
221 | model.requires_grad_(False)
222 | # * fix llama, lm_head
223 | model.get_model().fix_llm = True
224 | logger.warning("LLM is fixed. Fix_llm flag is set to True")
225 | model.get_model().ts_proj.requires_grad_(True)
226 | model.get_model().pt_encoder_backbone.requires_grad_(True)
227 | else:
228 | model.get_model().fix_llm = False
229 | logger.warning("LLM is trainable. Fix_llm flag is set to False")
230 |
231 | tokenizer = transformers.AutoTokenizer.from_pretrained(
232 | model_args.model_name_or_path,
233 | cache_dir=training_args.cache_dir,
234 | model_max_length=training_args.model_max_length,
235 | padding_side="right",
236 | use_fast=False,
237 | )
238 |
239 | if not training_args.fix_ts_encoder:
240 | # * not fix pretrained ts encoder
241 | model.get_model().fix_ts_encoder = False
242 | logger.warning(
243 | "Pretrained TS backbone is trainable. fix_ts_encoder flag is set to False, ts net grad will be recorded.")
244 | else:
245 | model.get_model().fix_ts_encoder = True # * use with torch.inference_mode to control
246 | logger.warning(
247 | "Pretrained TS backbone is fixed. fix_ts_encoder flag is set to True, ts net grad will not be recorded.")
248 |
249 | logger.warning("Set requires_grad of Pretrained TS backbone to False")
250 | model.get_model().pt_encoder_backbone.requires_grad_(
251 | False)
252 |
253 | if training_args.tune_mm_mlp_adapter:
254 | # * not fix the projection layer
255 | # * may need to set the embed_tokens to require_grad = True if added new tokens
256 | # * this is done in initialize_tokenizer_ts_backbone_config
257 | logger.warning("Time-series Projector is trainable. ")
258 | else:
259 | model.get_model().ts_proj.requires_grad_(False)
260 | logger.warning("Time-series Projector is fixed.")
261 |
262 | if model_args.model_type == "SequenceClassification":
263 | if not training_args.fix_cls_head:
264 | model.score.requires_grad_(True)
265 | logger.warning("LLM classification head is trainable. ")
266 | else:
267 | model.score.requires_grad_(False)
268 | logger.warning("LLM classification head is fixed.")
269 |
270 | if not training_args.stage_2:
271 | # * In Stage 2 we assume the LLM and time-series embedder (and projection layer) can be loaded from the model checkpoint
272 | model.initialize_tokenizer_ts_backbone_config(tokenizer=tokenizer, device=training_args.device,
273 | fix_llm=training_args.fix_llm, dataset=data_args.dataset)
274 | else:
275 | # * stage2
276 | if training_args.only_stage2:
277 | logger.warning("The loaded model haven't been trained in Stage 1. Initializing the LLM tokenizer with new tokens now...")
278 | model.initialize_tokenizer_ts_backbone_config(tokenizer=tokenizer, device=training_args.device,
279 | fix_llm=training_args.fix_llm, dataset=data_args.dataset)
280 | else:
281 | model.initialize_tokenizer_ts_backbone_config_wo_embedding(tokenizer=tokenizer, dataset=data_args.dataset)
282 |
283 | model.get_model().load_start_end_tokens(dataset=data_args.dataset)
284 |
285 | ts_backbone_config = model.get_model().ts_backbone_config
286 | data_args.ts_backbone_config = ts_backbone_config
287 |
288 | if not training_args.stage_2:
289 | logger.warning("Stage 1")
290 | data_module = make_ts_text_data_module(tokenizer=tokenizer, chronos_tokenizer=chronos_tokenizer,
291 | data_args=data_args)
292 | else:
293 | # * stage2
294 | if model_args.model_type == "CasualLM":
295 | data_module = make_ts_text_data_module_stage2(tokenizer=tokenizer, chronos_tokenizer=chronos_tokenizer,
296 | data_args=data_args)
297 | else:
298 | assert model_args.model_type == "SequenceClassification", f"Undefined model_type {model_args.model_type} for data_module"
299 | data_module = make_ts_classification_data_module_stage2(tokenizer=tokenizer,
300 | chronos_tokenizer=chronos_tokenizer,
301 | label2id=label2id,
302 | data_args=data_args)
303 |
304 | if model_args.model_type == "CasualLM":
305 | trainer = SensorLLMTrainer(model=model,
306 | args=training_args,
307 | tokenizer=tokenizer,
308 | **data_module)
309 | else:
310 | assert model_args.model_type == "SequenceClassification", f"Undefined model_type {model_args.model_type} for Trainer"
311 |
312 | metric_f1 = evaluate.load("../metrics/f1")
313 | metric_acc = evaluate.load("../metrics/accuracy")
314 | metric_precision = evaluate.load("../metrics/precision")
315 | metric_recall = evaluate.load("../metrics/recall")
316 |
317 | def compute_metrics(eval_pred):
318 | predictions, labels = eval_pred
319 | predictions = np.argmax(predictions, axis=1)
320 |
321 | accuracy = metric_acc.compute(predictions=predictions, references=labels)["accuracy"]
322 |
323 | precision_macro = metric_precision.compute(predictions=predictions, references=labels, average='macro')[
324 | "precision"]
325 | recall_macro = metric_recall.compute(predictions=predictions, references=labels, average='macro')["recall"]
326 | f1_macro = metric_f1.compute(predictions=predictions, references=labels, average='macro')["f1"]
327 |
328 | f1_micro = metric_f1.compute(predictions=predictions, references=labels, average='micro')["f1"]
329 |
330 | precision_per_class = metric_precision.compute(predictions=predictions, references=labels, average=None)[
331 | "precision"]
332 | recall_per_class = metric_recall.compute(predictions=predictions, references=labels, average=None)["recall"]
333 | f1_per_class = metric_f1.compute(predictions=predictions, references=labels, average=None)["f1"]
334 |
335 | results = {
336 | "f1_macro": f1_macro,
337 | "f1_micro": f1_micro,
338 | "accuracy": accuracy,
339 | "precision_macro": precision_macro,
340 | "recall_macro": recall_macro
341 | }
342 |
343 | for i, (p, r, f) in enumerate(zip(precision_per_class, recall_per_class, f1_per_class)):
344 | results[f"precision_class_{i}"] = p
345 | results[f"recall_class_{i}"] = r
346 | results[f"f1_class_{i}"] = f
347 |
348 | return results
349 |
350 | model.config.pad_token_id = tokenizer.pad_token_id
351 | if training_args.use_weighted_loss:
352 | logger.warning("Using weighted_loss trainer")
353 | trainer = SensorLLMWeightedCELossTrainer(model=model,
354 | args=training_args,
355 | tokenizer=tokenizer,
356 | compute_metrics=compute_metrics,
357 | **data_module)
358 | else:
359 | del data_module['class_weights']
360 | print(data_module.keys())  # show the remaining data-module entries after dropping class_weights
361 | trainer = SensorLLMTrainer(model=model,
362 | args=training_args,
363 | tokenizer=tokenizer,
364 | compute_metrics=compute_metrics,
365 | **data_module)
366 |
367 | print_trainable_parameters(model)
368 |
369 | if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
370 | trainer.train(resume_from_checkpoint=True)
371 | else:
372 | trainer.train()
373 | trainer.save_state()
374 |
375 | safe_save_model_for_hf_trainer(trainer=trainer,
376 | output_dir=training_args.output_dir)
377 |
378 | # eval_results = trainer.evaluate()
379 | # print("Evaluation results:", eval_results)
380 |
381 |
382 | if __name__ == "__main__":
383 | train()
384 |
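
For reference, here is a toy check of the `compute_metrics()` contract defined in `train.py` above. This is an illustrative sketch, not part of the source: it loads the hub versions of the metrics instead of the local `../metrics` scripts used by the training script, but returns the same keys.

```python
# Illustrative sketch of the compute_metrics() contract in train.py above.
import numpy as np
import evaluate

logits = np.array([[2.0, 0.1], [0.2, 1.5], [1.0, 0.3]])  # (batch, num_labels)
labels = np.array([0, 1, 1])
preds = np.argmax(logits, axis=1)                        # -> [0, 1, 0]

accuracy = evaluate.load("accuracy").compute(predictions=preds, references=labels)["accuracy"]
f1_macro = evaluate.load("f1").compute(predictions=preds, references=labels, average="macro")["f1"]
print(accuracy, f1_macro)  # ~0.667, ~0.667
```
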
--------------------------------------------------------------------------------
/sensorllm/train/train_mem.py:
--------------------------------------------------------------------------------
1 | # Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright:
2 | # Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright:
3 | # Make it more memory efficient by monkey patching the LLaMA model with FlashAttn.
4 |
5 | # Need to call this before importing transformers.
6 | from sensorllm.train.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn
7 |
8 | replace_llama_attn_with_flash_attn()
9 |
10 | from sensorllm.train.train import train
11 |
12 | if __name__ == "__main__":
13 | train()
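
A brief usage note (not part of the source): `train_mem.py` only monkey-patches LLaMA attention with FlashAttention before delegating to `train()`. If `flash-attn` is not installed in your environment, the same arguments can be passed to `sensorllm/train/train.py` directly; this is a slower, more memory-hungry, but otherwise equivalent fallback.

```bash
# Fallback sketch when flash-attn is unavailable: run train.py with the same
# arguments you would pass to train_mem.py (bracketed values are placeholders).
torchrun --nproc_per_node=[NUM_GPUS] sensorllm/train/train.py [SAME_ARGS_AS_TRAIN_MEM]
```
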
--------------------------------------------------------------------------------
/sensorllm/utils.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import logging.handlers
3 | import os
4 | import sys
5 |
6 | import requests
7 |
8 | import yaml
9 |
10 | server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
11 | moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN."
12 |
13 | handler = None
14 |
15 |
16 | class EasyDict(dict):
17 | def __init__(self, d=None, **kwargs):
18 | if d is None:
19 | d = {}
20 | if kwargs:
21 | d.update(**kwargs)
22 | for k, v in d.items():
23 | setattr(self, k, v)
24 | # Class attributes
25 | for k in self.__class__.__dict__.keys():
26 | if not (k.startswith('__') and k.endswith('__')) and not k in ('update', 'pop'):
27 | setattr(self, k, getattr(self, k))
28 |
29 | def __setattr__(self, name, value):
30 | if isinstance(name, str):
31 | if isinstance(value, (list, tuple)):
32 | value = [self.__class__(x)
33 | if isinstance(x, dict) else x for x in value]
34 | elif isinstance(value, dict) and not isinstance(value, self.__class__):
35 | value = self.__class__(value)
36 | super(EasyDict, self).__setattr__(name, value)
37 | super(EasyDict, self).__setitem__(name, value)
38 | else:
39 | self.__setitem__(name, value)
40 |
41 | def __setitem__(self, name, value):
42 |
43 | if isinstance(value, (list, tuple)):
44 | value = [self.__class__(x)
45 | if isinstance(x, dict) else x for x in value]
46 | elif isinstance(value, dict) and not isinstance(value, self.__class__):
47 | value = self.__class__(value)
48 | super(EasyDict, self).__setitem__(name, value)
49 |
50 | def update(self, e=None, **f):
51 | d = e or dict()
52 | d.update(f)
53 | for k in d:
54 | setattr(self, k, d[k])
55 |
56 | def pop(self, k, d=None):
57 | delattr(self, k)
58 | return super(EasyDict, self).pop(k, d)
59 |
60 | def merge_new_config(config, new_config):
61 | for key, val in new_config.items():
62 | if not isinstance(val, dict):
63 | if key == '_base_':
64 | with open(new_config['_base_'], 'r') as f:
65 | try:
66 | val = yaml.load(f, Loader=yaml.FullLoader)
67 | except AttributeError:  # older PyYAML without FullLoader
68 | val = yaml.load(f)
69 | config[key] = EasyDict()
70 | merge_new_config(config[key], val)
71 | else:
72 | config[key] = val
73 | continue
74 | if key not in config:
75 | config[key] = EasyDict()
76 | merge_new_config(config[key], val)
77 | return config
78 |
79 | def cfg_from_yaml_file(cfg_file):
80 | config = EasyDict()
81 | with open(cfg_file, 'r') as f:
82 | new_config = yaml.load(f, Loader=yaml.FullLoader)
83 | merge_new_config(config=config, new_config=new_config)
84 | return config
85 |
86 | def build_logger(logger_name, logger_filepath):
87 | global handler
88 |
89 | formatter = logging.Formatter(
90 | fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
91 | datefmt="%Y-%m-%d %H:%M:%S",
92 | )
93 |
94 | # Set the format of root handlers
95 | if not logging.getLogger().handlers:
96 | logging.basicConfig(level=logging.INFO)
97 | else:
98 | logging.getLogger().handlers[0].setFormatter(formatter)
99 |
100 | # Redirect stdout and stderr to loggers
101 | stdout_logger = logging.getLogger("stdout")
102 | stdout_logger.setLevel(logging.INFO)
103 | sl_out = StreamToLogger(stdout_logger, logging.INFO)
104 | sys.stdout = sl_out
105 |
106 | stderr_logger = logging.getLogger("stderr")
107 | stderr_logger.setLevel(logging.ERROR)
108 | sl_err = StreamToLogger(stderr_logger, logging.ERROR)
109 | sys.stderr = sl_err
110 |
111 | # Get logger
112 | logger = logging.getLogger(logger_name)
113 | logger.setLevel(logging.INFO)
114 |
115 | # Add a file handler for all loggers
116 | if handler is None:
117 | # Get the logger_file's directory, and create it if it does not exist
118 | logger_filedir = os.path.dirname(logger_filepath)
119 | os.makedirs(logger_filedir, exist_ok=True)
120 | handler = logging.handlers.TimedRotatingFileHandler(
121 | logger_filepath, when='D', utc=True)
122 | handler.setFormatter(formatter)
123 |
124 | # Attach the handler to all existing loggers
125 | for name, item in logging.root.manager.loggerDict.items():
126 | if isinstance(item, logging.Logger):
127 | item.addHandler(handler)
128 |
129 | return logger
130 |
131 |
132 | class StreamToLogger(object):
133 | """
134 | Fake file-like stream object that redirects writes to a logger instance.
135 | """
136 | def __init__(self, logger, log_level=logging.INFO):
137 | self.terminal = sys.stdout if log_level == logging.INFO else sys.stderr
138 | self.logger = logger
139 | self.log_level = log_level
140 | self.linebuf = ''
141 |
142 | def __getattr__(self, attr):
143 | return getattr(self.terminal, attr)
144 |
145 | def write(self, buf):
146 | temp_linebuf = self.linebuf + buf
147 | self.linebuf = ''
148 | for line in temp_linebuf.splitlines(True):
149 | if line.endswith('\n'):
150 | self.logger.log(self.log_level, line.rstrip())
151 | else:
152 | self.linebuf += line
153 |
154 | def flush(self):
155 | if self.linebuf:
156 | self.logger.log(self.log_level, self.linebuf.rstrip())
157 | self.linebuf = ''
158 |
159 |
160 | def disable_torch_init():
161 | """
162 | Disable the redundant torch default initialization to accelerate model creation.
163 | """
164 | import torch
165 | setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
166 | setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
167 |
168 |
169 | def violates_moderation(text):
170 | """
171 | Check whether the text violates OpenAI moderation API.
172 | """
173 | url = "https://api.openai.com/v1/moderations"
174 | headers = {"Content-Type": "application/json",
175 | "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]}
176 | import json  # stdlib; build the request body safely even if text contains quotes
177 | text = text.replace("\n", "")
178 | data = json.dumps({"input": text}).encode("utf-8")
179 | try:
180 | ret = requests.post(url, headers=headers, data=data, timeout=5)
181 | flagged = ret.json()["results"][0]["flagged"]
182 | except requests.exceptions.RequestException as e:
183 | flagged = False
184 | except KeyError as e:
185 | flagged = False
186 |
187 | return flagged
188 |
189 |
190 | def pretty_print_semaphore(semaphore):
191 | if semaphore is None:
192 | return "None"
193 | return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})"
--------------------------------------------------------------------------------
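
A usage sketch for the helpers above (hypothetical output path; it mirrors the commented-out `build_logger` call in `train.py`): `build_logger()` attaches a daily-rotating file handler to every logger and redirects `stdout`/`stderr` through `StreamToLogger`, while `disable_torch_init()` skips PyTorch's default weight initialization to speed up creating models that are loaded from pretrained checkpoints.

```python
# Hypothetical example: set up logging (and faster model creation) for a run.
from sensorllm.utils import build_logger, disable_torch_init

logger = build_logger(__name__, "./outputs/train.log")  # log directory is created if missing
logger.info("run started")           # recorded in ./outputs/train.log
print("plain prints are captured")   # stdout is redirected via StreamToLogger

disable_torch_init()  # call before loading pretrained weights to skip default init
```
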