|>"
115 | )
116 |
117 | prompt.append(prompt_)
118 |
119 | x_enc = x_enc.reshape(B, N, T).permute(0, 2, 1).contiguous()
120 |
121 | prompt = self.tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, max_length=2048).input_ids
122 | prompt_embeddings = self.llm_model.get_input_embeddings()(prompt.to(x_enc.device)) # (batch, prompt_token, dim)
123 |
124 | source_embeddings = self.mapping_layer(self.word_embeddings.permute(1, 0)).permute(1, 0)
125 |
126 | x_enc = x_enc.permute(0, 2, 1).contiguous()
127 | enc_out, n_vars = self.patch_embedding(x_enc.to(torch.float32))
128 | enc_out = self.reprogramming_layer(enc_out, source_embeddings, source_embeddings)
129 | llama_enc_out = torch.cat([prompt_embeddings, enc_out], dim=1)
130 | dec_out = self.llm_model(inputs_embeds=llama_enc_out).last_hidden_state
131 | dec_out = dec_out[:, :, :self.d_ff] # (batch, patch_num, d_ff)
132 |
133 | dec_out = torch.reshape(
134 | dec_out, (-1, n_vars, dec_out.shape[-2], dec_out.shape[-1]))
135 | dec_out = dec_out.permute(0, 1, 3, 2).contiguous()
136 |
137 | dec_out = self.output_projection(dec_out[:, :, :, -self.patch_nums:])
138 | dec_out = dec_out.permute(0, 2, 1).contiguous()
139 |
140 | dec_out = self.normalize_layers(dec_out, 'denorm')
141 |
142 | return dec_out[:, -self.pred_len:, :]
143 |
--------------------------------------------------------------------------------
/ltsm/models/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | from math import sqrt
5 | from transformers.modeling_utils import PreTrainedModel, PretrainedConfig
6 |
7 | class Normalize(nn.Module):
8 | def __init__(self, num_features: int, eps=1e-5, affine=False, subtract_last=False, non_norm=False):
9 | """
10 | :param num_features: the number of features or channels
11 | :param eps: a value added for numerical stability
12 | :param affine: if True, RevIN has learnable affine parameters; subtract_last uses the last value instead of the mean, and non_norm disables normalization
13 | """
14 | super(Normalize, self).__init__()
15 | self.num_features = num_features
16 | self.eps = eps
17 | self.affine = affine
18 | self.subtract_last = subtract_last
19 | self.non_norm = non_norm
20 | if self.affine:
21 | self._init_params()
22 |
23 | def forward(self, x, mode: str):
24 | if mode == 'norm':
25 | self._get_statistics(x)
26 | x = self._normalize(x)
27 | elif mode == 'denorm':
28 | x = self._denormalize(x)
29 | else:
30 | raise NotImplementedError
31 | return x
32 |
33 | def _init_params(self):
34 | # initialize RevIN params: (C,)
35 | self.affine_weight = nn.Parameter(torch.ones(self.num_features))
36 | self.affine_bias = nn.Parameter(torch.zeros(self.num_features))
37 |
38 | def _get_statistics(self, x):
39 | dim2reduce = tuple(range(1, x.ndim - 1))
40 | if self.subtract_last:
41 | self.last = x[:, -1, :].unsqueeze(1)
42 | else:
43 | self.mean = torch.mean(x, dim=dim2reduce, keepdim=True).detach()
44 | self.stdev = torch.sqrt(torch.var(x, dim=dim2reduce, keepdim=True, unbiased=False) + self.eps).detach()
45 |
46 | def _normalize(self, x):
47 | if self.non_norm:
48 | return x
49 | if self.subtract_last:
50 | x = x - self.last
51 | else:
52 | x = x - self.mean
53 | x = x / self.stdev
54 | if self.affine:
55 | x = x * self.affine_weight
56 | x = x + self.affine_bias
57 | return x
58 |
59 | def _denormalize(self, x):
60 | if self.non_norm:
61 | return x
62 | if self.affine:
63 | x = x - self.affine_bias
64 | x = x / (self.affine_weight + self.eps * self.eps)
65 | x = x * self.stdev
66 | if self.subtract_last:
67 | x = x + self.last
68 | else:
69 | x = x + self.mean
70 | return x
71 |
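# --- Editor's sketch (not in the original file): a minimal round-trip usage of the Normalize
# --- (RevIN) module above. It assumes inputs shaped (batch, time, num_features), which matches
# --- how _get_statistics reduces dimensions.
_revin = Normalize(num_features=7)
_x = torch.randn(32, 336, 7)                   # (batch, seq_len, num_features)
_x_norm = _revin(_x, 'norm')                   # store per-series mean/std, then normalize
_x_back = _revin(_x_norm, 'denorm')            # invert the normalization
assert torch.allclose(_x, _x_back, atol=1e-4)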
72 |
73 | class FlattenHead(nn.Module):
74 | def __init__(self, n_vars, nf, target_window, head_dropout=0):
75 | super().__init__()
76 | self.n_vars = n_vars
77 | self.flatten = nn.Flatten(start_dim=-2)
78 | self.linear = nn.Linear(nf, target_window)
79 | self.dropout = nn.Dropout(head_dropout)
80 |
81 | def forward(self, x):
82 | x = self.flatten(x)
83 | x = self.linear(x)
84 | x = self.dropout(x)
85 | return x
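# --- Editor's sketch (not in the original file): expected FlattenHead shapes, assuming a
# --- (batch, n_vars, d_ff, patch_num) input with nf = d_ff * patch_num, as the model's
# --- forward pass provides.
_head = FlattenHead(n_vars=7, nf=128 * 12, target_window=96)
_z = torch.randn(8, 7, 128, 12)                # (batch, n_vars, d_ff, patch_num)
assert _head(_z).shape == (8, 7, 96)           # flatten last two dims, project to the forecast window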
86 |
87 | class ReprogrammingLayer(nn.Module):
88 | def __init__(self, d_model, n_heads, d_keys=None, d_llm=None, attention_dropout=0.1):
89 | super(ReprogrammingLayer, self).__init__()
90 |
91 | d_keys = d_keys or (d_model // n_heads)
92 |
93 | self.query_projection = nn.Linear(d_model, d_keys * n_heads)
94 | self.key_projection = nn.Linear(d_llm, d_keys * n_heads)
95 | self.value_projection = nn.Linear(d_llm, d_keys * n_heads)
96 | self.out_projection = nn.Linear(d_keys * n_heads, d_llm)
97 | self.n_heads = n_heads
98 | self.dropout = nn.Dropout(attention_dropout)
99 |
100 | def forward(self, target_embedding, source_embedding, value_embedding):
101 | B, L, _ = target_embedding.shape
102 | S, _ = source_embedding.shape
103 | H = self.n_heads
104 |
105 | target_embedding = self.query_projection(target_embedding).view(B, L, H, -1)
106 | source_embedding = self.key_projection(source_embedding).view(S, H, -1)
107 | value_embedding = self.value_projection(value_embedding).view(S, H, -1)
108 |
109 | out = self.reprogramming(target_embedding, source_embedding, value_embedding)
110 |
111 | out = out.reshape(B, L, -1)
112 |
113 | return self.out_projection(out)
114 |
115 | def reprogramming(self, target_embedding, source_embedding, value_embedding):
116 | B, L, H, E = target_embedding.shape
117 |
118 | scale = 1. / sqrt(E)
119 |
120 | scores = torch.einsum("blhe,she->bhls", target_embedding, source_embedding)
121 |
122 | A = self.dropout(torch.softmax(scale * scores, dim=-1))
123 | reprogramming_embedding = torch.einsum("bhls,she->blhe", A, value_embedding)
124 |
125 | return reprogramming_embedding
126 |
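# --- Editor's sketch (not in the original file): the reprogramming cross-attention maps patch
# --- embeddings (B, L, d_model) onto a bank of S source/value embeddings of width d_llm
# --- (e.g., mapped word embeddings); the sizes below are illustrative only.
_layer = ReprogrammingLayer(d_model=16, n_heads=4, d_llm=768)
_patches = torch.randn(2, 10, 16)              # (B, L, d_model) patch embeddings (queries)
_prototypes = torch.randn(1000, 768)           # (S, d_llm) source/value prototypes
assert _layer(_patches, _prototypes, _prototypes).shape == (2, 10, 768)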
127 |
128 |
129 | def get_model(config):
130 | if config.model == 'LTSM_WordPrompt':
131 | from .ltsm_wordprompt import LTSM_WordPrompt
132 | model = LTSM_WordPrompt(config)
133 | elif config.model == 'LTSM_Tokenizer':
134 | from .ltsm_tokenizer import LTSM_Tokenizer
135 | model = LTSM_Tokenizer(config)
136 | else:
137 | from .ltsm_model import LTSM
138 | if config.local_pretrain == "None":
139 | model = LTSM(config)
140 | else:
141 | model_config = PretrainedConfig.from_pretrained(config.local_pretrain)
142 | model = LTSM.from_pretrained(config.local_pretrain, model_config)
143 |
144 |
145 | return model
146 |
147 |
--------------------------------------------------------------------------------
/ltsm/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datamllab/ltsm/91ee7775ee5dabfd4baa4bf8713ecd111560655d/ltsm/utils/__init__.py
--------------------------------------------------------------------------------
/ltsm/utils/dist.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from scipy.spatial.distance import euclidean
3 | from fastdtw import fastdtw
4 | import torch
5 |
6 | def pairwise_dtw(x_batch, y_batch):
7 | """
8 |
9 | Computes the pair-wise DTW distance between two batches of series.
10 | :param x_batch: Tensor, [ Batchsize, Time, Dimension_x ]
11 | :param y_batch: Tensor, [ Batchsize, Time, Dimension_y ]
12 |
13 | The input tensors should have Dimension_x == Dimension_y.
14 |
15 | :return: Pair-wise Distance, Tensor, [ Batchsize, Batchsize ]
16 | """
17 |
18 | batchsize_x = x_batch.shape[0]
19 | batchsize_y = y_batch.shape[0]
20 | dist_matrix = torch.zeros((batchsize_x, batchsize_y), device=torch.device("cpu"))
21 | for idx1, x in enumerate(x_batch):
22 | for idx2, y in enumerate(y_batch):
23 | if x_batch is y_batch and dist_matrix[idx2, idx1] > 0:
24 | dist_matrix[idx1, idx2] = dist_matrix[idx2, idx1]
25 |
26 | else:
27 | distance_xy, _ = fastdtw(x, y, dist=euclidean)
28 | dist_matrix[idx1, idx2] = distance_xy
29 |
30 | return dist_matrix
31 |
32 |
33 |
34 |
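# --- Editor's sketch (not in the original file): pair-wise DTW over two small batches.
# --- fastdtw receives each (Time, Dimension) series with euclidean as the point-wise metric.
_x = torch.randn(4, 50, 1)                     # [ Batchsize, Time, Dimension ]
_y = torch.randn(3, 50, 1)
assert pairwise_dtw(_x, _y).shape == (4, 3)    # one DTW distance per (x, y) pair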
--------------------------------------------------------------------------------
/ltsm/utils/metrics.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | def RSE(pred, true):
5 | return np.sqrt(np.sum((true - pred) ** 2)) / np.sqrt(np.sum((true - true.mean()) ** 2))
6 |
7 |
8 | def CORR(pred, true):
9 | u = ((true - true.mean(0)) * (pred - pred.mean(0))).sum(0)
10 | d = np.sqrt(((true - true.mean(0)) ** 2 * (pred - pred.mean(0)) ** 2).sum(0))
11 | return (u / d).mean(-1)
12 |
13 |
14 | def MAE(pred, true):
15 | return np.mean(np.abs(pred - true))
16 |
17 |
18 | def MSE(pred, true):
19 | return np.mean((pred - true) ** 2)
20 |
21 |
22 | def RMSE(pred, true):
23 | return np.sqrt(MSE(pred, true))
24 |
25 |
26 | def MAPE(pred, true):
27 | return np.mean(np.abs(100 * (pred - true) / (true + 1e-8)))
28 |
29 |
30 | def MSPE(pred, true):
31 | return np.mean(np.square((pred - true) / (true + 1e-8)))
32 |
33 | def SMAPE(pred, true):
34 | return np.mean(200 * np.abs(pred - true) / (np.abs(pred) + np.abs(true) + 1e-8))
35 | # return np.mean(200 * np.abs(pred - true) / (pred + true + 1e-8))
36 |
37 | def ND(pred, true):
38 | return np.mean(np.abs(true - pred)) / np.mean(np.abs(true))
39 |
40 | def metric(pred, true):
41 | mae = MAE(pred, true)
42 | mse = MSE(pred, true)
43 | rmse = RMSE(pred, true)
44 | mape = MAPE(pred, true)
45 | mspe = MSPE(pred, true)
46 | smape = SMAPE(pred, true)
47 | nd = ND(pred, true)
48 |
49 | return mae, mse, rmse, mape, mspe, smape, nd
50 |
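# --- Editor's sketch (not in the original file): the metrics operate element-wise on NumPy
# --- arrays of identical shape; metric() bundles the seven scores returned above.
_pred = np.array([[1.0, 2.0], [3.0, 4.0]])
_true = np.array([[1.1, 1.9], [2.8, 4.2]])
_mae, _mse, _rmse, _mape, _mspe, _smape, _nd = metric(_pred, _true)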
--------------------------------------------------------------------------------
/ltsm/utils/timefeatures.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 |
3 | import numpy as np
4 | import pandas as pd
5 | from pandas.tseries import offsets
6 | from pandas.tseries.frequencies import to_offset
7 |
8 |
9 | class TimeFeature:
10 | def __init__(self):
11 | pass
12 |
13 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
14 | pass
15 |
16 | def __repr__(self):
17 | return self.__class__.__name__ + "()"
18 |
19 |
20 | class SecondOfMinute(TimeFeature):
21 | """Minute of hour encoded as value between [-0.5, 0.5]"""
22 |
23 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
24 | return index.second / 59.0 - 0.5
25 |
26 |
27 | class MinuteOfHour(TimeFeature):
28 | """Minute of hour encoded as value between [-0.5, 0.5]"""
29 |
30 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
31 | return index.minute / 59.0 - 0.5
32 |
33 |
34 | class HourOfDay(TimeFeature):
35 | """Hour of day encoded as value between [-0.5, 0.5]"""
36 |
37 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
38 | return index.hour / 23.0 - 0.5
39 |
40 |
41 | class DayOfWeek(TimeFeature):
42 | """Hour of day encoded as value between [-0.5, 0.5]"""
43 |
44 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
45 | return index.dayofweek / 6.0 - 0.5
46 |
47 |
48 | class DayOfMonth(TimeFeature):
49 | """Day of month encoded as value between [-0.5, 0.5]"""
50 |
51 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
52 | return (index.day - 1) / 30.0 - 0.5
53 |
54 |
55 | class DayOfYear(TimeFeature):
56 | """Day of year encoded as value between [-0.5, 0.5]"""
57 |
58 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
59 | return (index.dayofyear - 1) / 365.0 - 0.5
60 |
61 |
62 | class MonthOfYear(TimeFeature):
63 | """Month of year encoded as value between [-0.5, 0.5]"""
64 |
65 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
66 | return (index.month - 1) / 11.0 - 0.5
67 |
68 |
69 | class WeekOfYear(TimeFeature):
70 | """Week of year encoded as value between [-0.5, 0.5]"""
71 |
72 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
73 | return (index.isocalendar().week - 1) / 52.0 - 0.5
74 |
75 |
76 | def time_features_from_frequency_str(freq_str: str) -> List[TimeFeature]:
77 | """
78 | Returns a list of time features that will be appropriate for the given frequency string.
79 | Parameters
80 | ----------
81 | freq_str
82 | Frequency string of the form [multiple][granularity] such as "12H", "5min", "1D" etc.
83 | """
84 |
85 | features_by_offsets = {
86 | offsets.YearEnd: [],
87 | offsets.QuarterEnd: [MonthOfYear],
88 | offsets.MonthEnd: [MonthOfYear],
89 | offsets.Week: [DayOfMonth, WeekOfYear],
90 | offsets.Day: [DayOfWeek, DayOfMonth, DayOfYear],
91 | offsets.BusinessDay: [DayOfWeek, DayOfMonth, DayOfYear],
92 | offsets.Hour: [HourOfDay, DayOfWeek, DayOfMonth, DayOfYear],
93 | offsets.Minute: [
94 | MinuteOfHour,
95 | HourOfDay,
96 | DayOfWeek,
97 | DayOfMonth,
98 | DayOfYear,
99 | ],
100 | offsets.Second: [
101 | SecondOfMinute,
102 | MinuteOfHour,
103 | HourOfDay,
104 | DayOfWeek,
105 | DayOfMonth,
106 | DayOfYear,
107 | ],
108 | }
109 |
110 | offset = to_offset(freq_str)
111 |
112 | for offset_type, feature_classes in features_by_offsets.items():
113 | if isinstance(offset, offset_type):
114 | return [cls() for cls in feature_classes]
115 |
116 | supported_freq_msg = f"""
117 | Unsupported frequency {freq_str}
118 | The following frequencies are supported:
119 | Y - yearly
120 | alias: A
121 | M - monthly
122 | W - weekly
123 | D - daily
124 | B - business days
125 | H - hourly
126 | T - minutely
127 | alias: min
128 | S - secondly
129 | """
130 | raise RuntimeError(supported_freq_msg)
131 |
132 |
133 | def time_features(dates, freq='h'):
134 | return np.vstack([feat(dates) for feat in time_features_from_frequency_str(freq)])
135 |
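# --- Editor's sketch (not in the original file): an hourly index yields the four features
# --- registered for offsets.Hour, each row scaled to [-0.5, 0.5].
_dates = pd.date_range("2021-01-01", periods=24, freq="h")
assert time_features(_dates, freq="h").shape == (4, 24)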
--------------------------------------------------------------------------------
/ltsm/utils/tools.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import matplotlib.pyplot as plt
5 | from tqdm import tqdm
6 | from datetime import datetime
7 | from distutils.util import strtobool
8 | import pandas as pd
9 |
10 | from ltsm.utils.metrics import metric
11 |
12 | plt.switch_backend('agg')
13 |
14 |
15 | class dotdict(dict):
16 | """dot.notation access to dictionary attributes"""
17 | __getattr__ = dict.get
18 | __setattr__ = dict.__setitem__
19 | __delattr__ = dict.__delitem__
20 |
21 |
22 | class StandardScaler():
23 | def __init__(self, mean, std):
24 | self.mean = mean
25 | self.std = std
26 |
27 | def transform(self, data):
28 | return (data - self.mean) / self.std
29 |
30 | def inverse_transform(self, data):
31 | return (data * self.std) + self.mean
32 |
33 |
34 | def visual(true, preds=None, name='./pic/test.pdf'):
35 | """
36 | Results visualization
37 | """
38 | plt.figure()
39 | plt.plot(true, label='GroundTruth', linewidth=2)
40 | if preds is not None:
41 | plt.plot(preds, label='Prediction', linewidth=2)
42 | plt.legend()
43 | plt.savefig(name, bbox_inches='tight')
44 |
45 |
46 | def convert_tsf_to_dataframe(
47 | full_file_path_and_name,
48 | replace_missing_vals_with="NaN",
49 | value_column_name="series_value",
50 | ):
51 | col_names = []
52 | col_types = []
53 | all_data = {}
54 | line_count = 0
55 | frequency = None
56 | forecast_horizon = None
57 | contain_missing_values = None
58 | contain_equal_length = None
59 | found_data_tag = False
60 | found_data_section = False
61 | started_reading_data_section = False
62 |
63 | print(full_file_path_and_name)
64 | with open(full_file_path_and_name, "r", encoding="cp1252") as file:
65 | for line in file:
66 | # Strip white space from start/end of line
67 | line = line.strip()
68 |
69 | if line:
70 | if line.startswith("@"): # Read meta-data
71 | if not line.startswith("@data"):
72 | line_content = line.split(" ")
73 | if line.startswith("@attribute"):
74 | if (
75 | len(line_content) != 3
76 | ): # Attributes have both name and type
77 | raise Exception("Invalid meta-data specification.")
78 |
79 | col_names.append(line_content[1])
80 | col_types.append(line_content[2])
81 | else:
82 | if (
83 | len(line_content) != 2
84 | ): # Other meta-data have only values
85 | raise Exception("Invalid meta-data specification.")
86 |
87 | if line.startswith("@frequency"):
88 | frequency = line_content[1]
89 | elif line.startswith("@horizon"):
90 | forecast_horizon = int(line_content[1])
91 | elif line.startswith("@missing"):
92 | contain_missing_values = bool(
93 | strtobool(line_content[1])
94 | )
95 | elif line.startswith("@equallength"):
96 | contain_equal_length = bool(strtobool(line_content[1]))
97 |
98 | else:
99 | if len(col_names) == 0:
100 | raise Exception(
101 | "Missing attribute section. Attribute section must come before data."
102 | )
103 |
104 | found_data_tag = True
105 | elif not line.startswith("#"):
106 | if len(col_names) == 0:
107 | raise Exception(
108 | "Missing attribute section. Attribute section must come before data."
109 | )
110 | elif not found_data_tag:
111 | raise Exception("Missing @data tag.")
112 | else:
113 | if not started_reading_data_section:
114 | started_reading_data_section = True
115 | found_data_section = True
116 | all_series = []
117 |
118 | for col in col_names:
119 | all_data[col] = []
120 |
121 | full_info = line.split(":")
122 |
123 | if len(full_info) != (len(col_names) + 1):
124 | continue
125 | #raise Exception("Missing attributes/values in series.")
126 |
127 | series = full_info[len(full_info) - 1]
128 | series = series.split(",")
129 |
130 | if len(series) == 0:
131 | raise Exception(
132 | "A given series should contains a set of comma separated numeric values. At least one numeric value should be there in a series. Missing values should be indicated with ? symbol"
133 | )
134 |
135 | numeric_series = []
136 |
137 | for val in series:
138 | if val == "?":
139 | numeric_series.append(replace_missing_vals_with)
140 | else:
141 | numeric_series.append(float(val))
142 |
143 | if numeric_series.count(replace_missing_vals_with) == len(
144 | numeric_series
145 | ):
146 | raise Exception(
147 | "All series values are missing. A given series should contains a set of comma separated numeric values. At least one numeric value should be there in a series."
148 | )
149 |
150 | all_series.append(pd.Series(numeric_series).array)
151 |
152 | for i in range(len(col_names)):
153 | att_val = None
154 | if col_types[i] == "numeric":
155 | att_val = int(full_info[i])
156 | elif col_types[i] == "string":
157 | att_val = str(full_info[i])
158 | elif col_types[i] == "date":
159 | att_val = datetime.strptime(
160 | full_info[i], "%Y-%m-%d %H-%M-%S"
161 | )
162 | else:
163 | raise Exception(
164 | "Invalid attribute type."
165 | ) # Currently, the code supports only numeric, string and date types. Extend this as required.
166 |
167 | if att_val is None:
168 | raise Exception("Invalid attribute value.")
169 | else:
170 | all_data[col_names[i]].append(att_val)
171 |
172 | line_count = line_count + 1
173 |
174 | if line_count == 0:
175 | raise Exception("Empty file.")
176 | if len(col_names) == 0:
177 | raise Exception("Missing attribute section.")
178 | if not found_data_section:
179 | raise Exception("Missing series information under data section.")
180 |
181 | all_data[value_column_name] = all_series
182 | loaded_data = pd.DataFrame(all_data)
183 |
184 | # ipdb.set_trace()
185 |
186 | return (
187 | loaded_data,
188 | frequency,
189 | forecast_horizon,
190 | contain_missing_values,
191 | contain_equal_length,
192 | )
193 |
194 |
195 | def MASE(x, freq, pred, true):
196 | masep = np.mean(np.abs(x[:, freq:] - x[:, :-freq]))
197 | return np.mean(np.abs(pred - true) / (masep + 1e-8))
198 |
199 |
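# --- Editor's sketch (not in the original file): MASE scales the forecast error by the
# --- in-sample seasonal-naive error of x, here with a purely illustrative season length of 24.
_x = np.random.randn(2, 96)                    # in-sample history, one row per series
_pred, _true = np.random.randn(2, 24), np.random.randn(2, 24)
_score = MASE(_x, 24, _pred, _true)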
--------------------------------------------------------------------------------
/main_ltsm.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch import nn
4 | import os
5 | import argparse
6 | import random
7 | import ipdb
8 |
9 | from ltsm.data_provider.data_factory import get_data_loaders, get_datasets, get_test_datasets
10 | from ltsm.data_provider.data_loader import HF_Dataset
11 | from ltsm.models import get_model, LTSMConfig
12 | from peft import get_peft_config, get_peft_model, LoraConfig
13 |
14 | from transformers import (
15 | Trainer,
16 | TrainingArguments,
17 | EvalPrediction,
18 | set_seed,
19 | )
20 |
21 | def get_args():
22 | parser = argparse.ArgumentParser(description='LTSM')
23 |
24 | # Basic Config
25 | parser.add_argument('--model_id', type=str, default='test_run', help='model id')
26 | parser.add_argument('--model_name_or_path', type=str, default="gpt2-medium", help='model name')
27 | parser.add_argument('--seed', type=int, default=2024, help='random seed')
28 | parser.add_argument('--device', type=str, default="cuda:0")
29 | parser.add_argument('--checkpoints', type=str, default='./checkpoints/')
30 |
31 | # Data Settings
32 | parser.add_argument('--data_path', nargs='+', default='dataset/weather.csv', help='data files')
33 | parser.add_argument('--test_data_path_list', nargs='+', required=True, help='test data file')
34 | parser.add_argument('--prompt_data_path', type=str, default='./weather.csv', help='prompt data file')
35 | parser.add_argument('--data_processing', type=str, default="standard_scaler", help='data processing method')
36 | parser.add_argument('--train_ratio', type=float, default=0.7, help='train data ratio')
37 | parser.add_argument('--val_ratio', type=float, default=0.1, help='validation data ratio')
38 |
39 | # Forecasting Settings
40 | parser.add_argument('--seq_len', type=int, default=336, help='input sequence length')
41 | parser.add_argument('--pred_len', type=int, default=96, help='prediction sequence length')
42 | parser.add_argument('--prompt_len', type=int, default=133, help='prompt sequence length')
43 |
44 | # Model Settings
45 | parser.add_argument('--lora', action="store_true", help='use lora')
46 | parser.add_argument('--lora_dim', type=int, default=128, help='dimension of lora')
47 | parser.add_argument('--gpt_layers', type=int, default=3, help='number of gpt layers')
48 | parser.add_argument('--d_model', type=int, default=1024, help='dimension of model')
49 | parser.add_argument('--n_heads', type=int, default=16, help='number of heads')
50 | parser.add_argument('--d_ff', type=int, default=512, help='dimension of fcn')
51 | parser.add_argument('--dropout', type=float, default=0.2, help='dropout')
52 | parser.add_argument('--enc_in', type=int, default=1, help='encoder input size')
53 | parser.add_argument('--c_out', type=int, default=862, help='output size')
54 | parser.add_argument('--patch_size', type=int, default=16, help='patch size')
55 | parser.add_argument('--pretrain', type=int, default=1, help='is pretrain')
56 | parser.add_argument('--local_pretrain', type=str, default="None", help='local pretrain weight')
57 | parser.add_argument('--freeze', type=int, default=1, help='is model weight frozen')
58 | parser.add_argument('--model', type=str, default='model', help='model name, options: [LTSM, LTSM_WordPrompt, LTSM_Tokenizer]')
59 | parser.add_argument('--stride', type=int, default=8, help='stride')
60 | parser.add_argument('--tmax', type=int, default=10, help='tmax')
61 |
62 | # Training Settings
63 | parser.add_argument('--eval', type=int, default=0, help='evaluation')
64 | parser.add_argument('--itr', type=int, default=1, help='experiments times')
65 | parser.add_argument('--output_dir', type=str, default='output/ltsm_train_lr0005/', help='output directory')
66 | parser.add_argument('--downsample_rate', type=int, default=100, help='downsample rate')
67 | parser.add_argument('--llm_layers', type=int, default=32)
68 | parser.add_argument('--decay_fac', type=float, default=0.75, help='decay factor')
69 | parser.add_argument('--learning_rate', type=float, default=0.0001, help='learning rate')
70 | parser.add_argument('--batch_size', type=int, default=512, help='batch size')
71 | parser.add_argument('--num_workers', type=int, default=10, help='number of workers')
72 | parser.add_argument('--train_epochs', type=int, default=1, help='number of epochs')
73 | parser.add_argument('--lradj', type=str, default='type1', help='learning rate adjustment type')
74 | parser.add_argument('--patience', type=int, default=3, help='early stopping patience')
75 | parser.add_argument('--gradient_accumulation_steps', type=int, default=64, help='gradient accumulation steps')
76 | args, unknown = parser.parse_known_args()
77 |
78 | return args
79 |
80 |
81 | def seed_all(fixed_seed):
82 | random.seed(fixed_seed)
83 | torch.manual_seed(fixed_seed)
84 | np.random.seed(fixed_seed)
85 |
86 | def freeze_parameters(model):
87 |
88 | freeze_param_buf = ["gpt2"]
89 | for n, p in model.named_parameters():
90 | if any(fp in n for fp in freeze_param_buf):
91 | p.requires_grad = False
92 | print(f"{n} has been freeezed")
93 |
94 | trainable_param_buf = ["ln", "wpe", "in_layer", "out_layer", "lora"]
95 | for n, p in model.named_parameters():
96 | if any(fp in n for fp in trainable_param_buf):
97 | p.requires_grad = True
98 |
99 | def print_trainable_parameters(model):
100 | for n, p in model.named_parameters():
101 | if p.requires_grad:
102 | print(f"{n} is trainable...")
103 |
104 |
105 | def run(args):
106 | print(args)
107 |
108 | model_config = LTSMConfig(**vars(args))
109 | model = get_model(model_config)
110 |
111 | if args.lora:
112 | peft_config = LoraConfig(
113 | target_modules=["c_attn"],
114 | inference_mode=False,
115 | r=args.lora_dim,
116 | lora_alpha=32,
117 | lora_dropout=0.1
118 | )
119 | model = get_peft_model(model, peft_config)
120 | model.print_trainable_parameters()
121 |
122 | elif args.freeze:
123 | freeze_parameters(model)
124 |
125 | print_trainable_parameters(model)
126 |
127 | # Optimizer settings
128 | model_optim = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
129 | lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(model_optim, T_max=args.tmax, eta_min=1e-8)
130 |
131 | # Evaluation metrics
132 | def compute_metrics(p: EvalPrediction):
133 | preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
134 | preds = np.squeeze(preds)
135 | if preds.shape != p.label_ids.shape:
136 | label_ids = np.squeeze(p.label_ids)
137 | else:
138 | label_ids = p.label_ids
139 | return {
140 | "mse": ((preds - label_ids) ** 2).mean().item(),
141 | "mae": (np.abs(preds - label_ids)).mean().item()
142 | }
143 |
144 | # Loss function
145 | def compute_loss(model, inputs, return_outputs=False):
146 | outputs = model(inputs["input_data"])
147 | loss = nn.functional.mse_loss(outputs, inputs["labels"])
148 | return (loss, outputs) if return_outputs else loss
149 |
150 | # Data collator
151 | def collate_fn(batch):
152 | return {
153 | 'input_data': torch.from_numpy(np.stack([x['input_data'] for x in batch])).type(torch.float32),
154 | 'labels': torch.from_numpy(np.stack([x['labels'] for x in batch])).type(torch.float32),
155 | }
156 |
157 | # Prediction step
158 | @torch.no_grad()
159 | def prediction_step(model, inputs, prediction_loss_only=False, ignore_keys=None):
160 | # CSV
161 | input_data = inputs["input_data"].to(model.module.device)
162 | labels = inputs["labels"].to(model.module.device)
163 | outputs = model(input_data)
164 | loss = nn.functional.mse_loss(outputs, labels)
165 | return (loss, outputs, labels)
166 |
167 | # Training settings
168 | training_args = TrainingArguments(
169 | output_dir=args.output_dir,
170 | per_device_train_batch_size=args.batch_size,
171 | per_device_eval_batch_size=args.batch_size,
172 | evaluation_strategy="steps",
173 | num_train_epochs=args.train_epochs,
174 | fp16=False,
175 | save_steps=100,
176 | eval_steps=25,
177 | logging_steps=5,
178 | learning_rate=args.learning_rate,
179 | gradient_accumulation_steps=args.gradient_accumulation_steps,
180 | save_total_limit=10,
181 | remove_unused_columns=False,
182 | push_to_hub=False,
183 | load_best_model_at_end=True,
184 | )
185 |
186 | train_dataset, eval_dataset, _ = get_datasets(args)
187 | train_dataset, eval_dataset= HF_Dataset(train_dataset), HF_Dataset(eval_dataset)
188 |
189 | trainer = Trainer(
190 | model=model,
191 | args=training_args,
192 | data_collator=collate_fn,
193 | compute_metrics=compute_metrics,
194 | train_dataset=train_dataset,
195 | eval_dataset=eval_dataset,
196 | tokenizer=None,
197 | optimizers=(model_optim, lr_scheduler),
198 | )
199 |
200 | # Overload the trainer API
201 | if not args.eval:
202 | trainer.compute_loss = compute_loss
203 | trainer.prediction_step = prediction_step
204 | train_results = trainer.train()
205 | trainer.save_model()
206 | trainer.log_metrics("train", train_results.metrics)
207 | trainer.save_metrics("train", train_results.metrics)
208 | trainer.save_state()
209 |
210 | # Testing settings
211 | for data_path in args.test_data_path_list:
212 | trainer.compute_loss = compute_loss
213 | trainer.prediction_step = prediction_step
214 | args.test_data_path = data_path
215 | test_dataset, _ = get_test_datasets(args)
216 | test_dataset = HF_Dataset(test_dataset)
217 |
218 | metrics = trainer.evaluate(test_dataset)
219 | trainer.log_metrics("Test", metrics)
220 | trainer.save_metrics("Test", metrics)
221 |
222 |
223 | if __name__ == "__main__":
224 | args = get_args()
225 | seed_all(args.seed)
226 | run(args)
--------------------------------------------------------------------------------
/main_tokenizer.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch import nn
4 | import os
5 | import argparse
6 | import random
7 | import sys
8 |
9 | sys.path.append("/home/yc146/github_open_ltsm/ltsm")
10 |
11 | from ltsm.data_provider.data_factory import get_datasets, get_test_datasets
12 | from ltsm.data_provider.data_loader import HF_Dataset
13 | from ltsm.data_provider.data_processing.tokenizer_processor import TokenizerConfig
14 | from ltsm.models import get_model, LTSMConfig
15 | from peft import get_peft_model, LoraConfig
16 |
17 | from transformers import (
18 | Trainer,
19 | TrainingArguments,
20 | EvalPrediction,
21 | set_seed,
22 | )
23 | def get_args():
24 | parser = argparse.ArgumentParser(description='LTSM')
25 |
26 | # Basic Config
27 | parser.add_argument('--model_id', type=str, default='test_run', help='model id')
28 | parser.add_argument('--model_name_or_path', type=str, default="gpt2-medium", help='model name')
29 | parser.add_argument('--seed', type=int, default=2024, help='random seed')
30 | parser.add_argument('--device', type=str, default="cuda:0")
31 | parser.add_argument('--checkpoints', type=str, default='./checkpoints/')
32 |
33 | # Data Settings
34 | parser.add_argument('--data_path', nargs='+', default='dataset/weather.csv', help='data files')
35 | parser.add_argument('--test_data_path', type=str, default='dataset/weather.csv', help='test data file')
36 | parser.add_argument('--test_data_path_list', nargs='+', required=True, help='test data file')
37 | parser.add_argument('--prompt_data_path', type=str, default='./weather.csv', help='prompt data file')
38 | parser.add_argument('--data_processing', type=str, default="standard_scaler", help='data processing method')
39 | parser.add_argument('--train_ratio', type=float, default=0.7, help='train data ratio')
40 | parser.add_argument('--val_ratio', type=float, default=0.1, help='validation data ratio')
41 |
42 | # Forecasting Settings
43 | parser.add_argument('--seq_len', type=int, default=336, help='input sequence length')
44 | parser.add_argument('--pred_len', type=int, default=96, help='prediction sequence length')
45 | parser.add_argument('--prompt_len', type=int, default=133, help='prompt sequence length')
46 |
47 |
48 | # Model Settings
49 | parser.add_argument('--lora', action="store_true", help='use lora')
50 | parser.add_argument('--lora_dim', type=int, default=128, help='dimension of lora')
51 | parser.add_argument('--gpt_layers', type=int, default=3, help='number of gpt layers')
52 | parser.add_argument('--d_model', type=int, default=1024, help='dimension of model')
53 | parser.add_argument('--n_heads', type=int, default=16, help='number of heads')
54 | parser.add_argument('--d_ff', type=int, default=512, help='dimension of fcn')
55 | parser.add_argument('--dropout', type=float, default=0.2, help='dropout')
56 | parser.add_argument('--enc_in', type=int, default=1, help='encoder input size')
57 | parser.add_argument('--c_out', type=int, default=862, help='output size')
58 | parser.add_argument('--patch_size', type=int, default=16, help='patch size')
59 | parser.add_argument('--pretrain', type=int, default=1, help='is pretrain')
60 | parser.add_argument('--local_pretrain', type=str, default="None", help='local pretrain weight')
61 | parser.add_argument('--freeze', type=int, default=1, help='is model weight frozen')
62 | parser.add_argument('--model', type=str, default='model', help='model name, options: [LTSM, LTSM_WordPrompt, LTSM_Tokenizer]')
63 | parser.add_argument('--stride', type=int, default=8, help='stride')
64 | parser.add_argument('--tmax', type=int, default=10, help='tmax')
65 |
66 | # Training Settings
67 | parser.add_argument('--eval', type=int, default=0, help='evaluation')
68 | parser.add_argument('--itr', type=int, default=1, help='experiments times')
69 | parser.add_argument('--output_dir', type=str, default='output/ltsm_train_lr0005/', help='output directory')
70 | parser.add_argument('--downsample_rate', type=int, default=100, help='downsample rate')
71 | parser.add_argument('--llm_layers', type=int, default=32)
72 | parser.add_argument('--decay_fac', type=float, default=0.75, help='decay factor')
73 | parser.add_argument('--learning_rate', type=float, default=0.0001, help='learning rate')
74 | parser.add_argument('--batch_size', type=int, default=512, help='batch size')
75 | parser.add_argument('--num_workers', type=int, default=10, help='number of workers')
76 | parser.add_argument('--train_epochs', type=int, default=1, help='number of epochs')
77 | parser.add_argument('--lradj', type=str, default='type1', help='learning rate adjustment type')
78 | parser.add_argument('--patience', type=int, default=3, help='early stopping patience')
79 | parser.add_argument('--gradient_accumulation_steps', type=int, default=64, help='gradient accumulation steps')
80 | args, unknown = parser.parse_known_args()
81 |
82 | return args
83 |
84 |
85 | def seed_all(fixed_seed):
86 | random.seed(fixed_seed)
87 | torch.manual_seed(fixed_seed)
88 | np.random.seed(fixed_seed)
89 |
90 | def freeze_parameters(model):
91 |
92 | freeze_param_buf = ["gpt2"]
93 | for n, p in model.named_parameters():
94 | if any(fp in n for fp in freeze_param_buf):
95 | p.requires_grad = False
96 | print(f"{n} has been freeezed")
97 |
98 | trainable_param_buf = ["ln", "wpe", "in_layer", "out_layer", "lora"]
99 | for n, p in model.named_parameters():
100 | if any(fp in n for fp in trainable_param_buf):
101 | p.requires_grad = True
102 |
103 | def print_trainable_parameters(model):
104 | for n, p in model.named_parameters():
105 | if p.requires_grad:
106 | print(f"{n} is trainable...")
107 |
108 | def run(args):
109 | print(args)
110 | model_config = LTSMConfig(**vars(args))
111 | model = get_model(model_config)
112 |
113 | if args.lora:
114 | peft_config = LoraConfig(
115 | target_modules=["c_attn"], # ["q", "v"],
116 | inference_mode=False,
117 | r=args.lora_dim,
118 | lora_alpha=32,
119 | lora_dropout=0.1
120 | )
121 | model = get_peft_model(model, peft_config)
122 | model.print_trainable_parameters()
123 |
124 | elif args.freeze:
125 | freeze_parameters(model)
126 |
127 | print_trainable_parameters(model)
128 |
129 |
130 | model_optim = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
131 | lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(model_optim, T_max=args.tmax, eta_min=1e-8)
132 |
133 | # Load Tokenizer Config, Reference: https://github.com/amazon-science/chronos-forecasting
134 | context_length = args.seq_len+args.pred_len
135 | prediction_length = args.pred_len
136 | n_tokens = 1024
137 | n_special_tokens = 2
138 | config = TokenizerConfig(
139 | tokenizer_class="MeanScaleUniformBins",
140 | tokenizer_kwargs=dict(low_limit=-3.0, high_limit=3.0),
141 | n_tokens=n_tokens,
142 | n_special_tokens=n_special_tokens,
143 | pad_token_id=0,
144 | eos_token_id=1,
145 | use_eos_token=0,
146 | model_type="causal",
147 | context_length=context_length,
148 | prediction_length=prediction_length,
149 | num_samples=20,
150 | temperature=1.0,
151 | top_k=50,
152 | top_p=1.0,
153 | )
154 |
155 | tokenizer = config.create_tokenizer()
156 |
157 | def compute_metrics(p: EvalPrediction):
158 | preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
159 | preds = np.squeeze(preds)
160 | if preds.shape != p.label_ids.shape:
161 | label_ids = np.squeeze(p.label_ids)
162 | else:
163 | label_ids = p.label_ids
164 | return {
165 | "mse": ((preds - label_ids) ** 2).mean().item(),
166 | "mae": (np.abs(preds - label_ids)).mean().item()}
167 |
168 | def compute_loss(model, inputs, return_outputs=False):
169 | outputs = model(inputs["input_data"])
170 | B, L, M, _ = outputs.shape
171 | loss = nn.functional.cross_entropy(outputs.reshape(B*L,-1), inputs["labels"][:,1:].long().reshape(B*L))
172 | return (loss, outputs) if return_outputs else loss
173 |
174 | def collate_fn(batch):
175 | return {
176 | 'input_data': torch.from_numpy(np.stack([x['input_data'] for x in batch])).type(torch.float32),
177 | 'labels': torch.from_numpy(np.stack([x['labels'] for x in batch])).type(torch.float32),
178 | }
179 |
180 | @torch.no_grad()
181 | def prediction_step(model, inputs, prediction_loss_only=False, ignore_keys=None):
182 | input_data = inputs["input_data"].to(model.module.device)
183 | labels = inputs["labels"].to(model.module.device)
184 | scale = labels[:,0]
185 | labels = labels[:,1:]
186 | outputs = model(input_data)
187 | indices = torch.max(outputs, dim=-1).indices
188 |
189 | output_value = tokenizer.output_transform(indices, scale)
190 | label_value = tokenizer.output_transform(labels.unsqueeze(-1).long(), scale)
191 | loss = nn.functional.mse_loss(output_value, label_value)
192 | return (loss, output_value, label_value)
193 |
194 |
195 | training_args = TrainingArguments(
196 | output_dir=args.output_dir,
197 | per_device_train_batch_size=args.batch_size,
198 | per_device_eval_batch_size=args.batch_size,
199 | evaluation_strategy="steps",
200 | num_train_epochs=args.train_epochs,
201 | fp16=False,
202 | save_steps=100,
203 | eval_steps=25,
204 | logging_steps=5,
205 | learning_rate=args.learning_rate,
206 | gradient_accumulation_steps=args.gradient_accumulation_steps,
207 | save_total_limit=10,
208 | remove_unused_columns=False,
209 | push_to_hub=False,
210 | load_best_model_at_end=True,
211 | )
212 |
213 | # Training settings
214 | train_dataset, eval_dataset, _ = get_datasets(args)
215 | train_dataset, eval_dataset= HF_Dataset(train_dataset), HF_Dataset(eval_dataset)
216 |
217 | trainer = Trainer(
218 | model=model,
219 | args=training_args,
220 | data_collator=collate_fn,
221 | compute_metrics=compute_metrics,
222 | train_dataset=train_dataset,
223 | eval_dataset=eval_dataset,
224 | tokenizer=None,
225 | optimizers=(model_optim, lr_scheduler),
226 | )
227 |
228 | # Overload the trainer API
229 | if not args.eval:
230 | trainer.compute_loss = compute_loss
231 | trainer.prediction_step = prediction_step
232 | train_results = trainer.train()
233 | trainer.save_model()
234 | trainer.log_metrics("train", train_results.metrics)
235 | trainer.save_metrics("train", train_results.metrics)
236 | trainer.save_state()
237 |
238 | # Testing settings
239 | for data_path in args.test_data_path_list:
240 | trainer.compute_loss = compute_loss
241 | trainer.prediction_step = prediction_step
242 | args.test_data_path = data_path
243 | test_dataset, _ = get_test_datasets(args)
244 | test_dataset = HF_Dataset(test_dataset)
245 |
246 | metrics = trainer.evaluate(test_dataset)
247 | trainer.log_metrics("Test", metrics)
248 | trainer.save_metrics("Test", metrics)
249 |
250 |
251 | if __name__ == "__main__":
252 | args = get_args()
253 | seed_all(args.seed)
254 | run(args)
255 |
--------------------------------------------------------------------------------
/prompt_bank/prompt_data_normalize_split/README.md:
--------------------------------------------------------------------------------
1 | # Time Series Prompt Dataset
--------------------------------------------------------------------------------
/prompt_bank/stat-prompt/README.md:
--------------------------------------------------------------------------------
1 | # Time Series Prompt Generator
2 |
3 |
4 | Time series prompts are designed to capture the global characteristics of time series data. Unlike text-based prompts, they are created by extracting a wide range of statistical features from the entire training dataset, which gives a robust representation of the underlying dynamics and helps boost model performance.
5 |
6 | ## Quick Start
7 | **Step 1.** Download the dataset from our [Google Drive](). Make sure your local data folder looks like this:
8 | ````angular2html
9 | - ltsm/
10 | - datasets/
11 | electricity/
12 | ETT-small/
13 | exchange_rate/
14 | illness/
15 | traffic/
16 | weather/
17 | ...
18 | ````
19 |
20 | **Step 2.** Generate the time series prompts from the training, validation, and test splits:
21 | ````angular2html
22 | python3 prompt_generate_split.py
23 | ````
24 |
25 | **Step 3.** Find the generated time series prompts in the './prompt_data_split' folder, then run the following command to fit the normalization statistics:
26 | ````angular2html
27 | python3 prompt_normalization_split.py --mode fit
28 | ````
29 |
30 | **Step 4.** Run this command to export the normalized prompts to the "./prompt_data_normalize_split" folder:
31 | ````angular2html
32 | python3 prompt_normalization_split.py --mode transform
33 | ````
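
**Note.** To sanity-check the export, you can load one of the generated files directly. This is a minimal sketch; the file name below is only an example of the `<data>_<column>_prompt.pth.tar` naming used by `prompt_generate_split.py`:
````python
import torch

prompt = torch.load("./prompt_data_normalize_split/train/weather/weather_OT_prompt.pth.tar")
print(prompt.shape)   # one row of normalized global features per series
````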
--------------------------------------------------------------------------------
/prompt_bank/stat-prompt/prompt_generate_split.py:
--------------------------------------------------------------------------------
1 | # from ltsm.data_provider.data_factory import get_data_loader, get_data_loaders, get_dataset
2 | import argparse
3 | import ipdb
4 | import pandas as pd
5 | import numpy as np
6 | import tsfel
7 | from pandas import read_csv, read_feather
8 | import matplotlib.pyplot as plt
9 | import sys, os
10 | import torch
11 |
12 |
13 | def get_args():
14 | parser = argparse.ArgumentParser(description='LTSM')
15 |
16 | parser.add_argument('--data_path', type=str, default='dataset/weather.csv')
17 | parser.add_argument('--data', type=str, default='custom')
18 | parser.add_argument('--freq', type=str, default="h")
19 | parser.add_argument('--target', type=str, default='OT')
20 | parser.add_argument('--embed', type=str, default='timeF')
21 | parser.add_argument('--percent', type=int, default=10)
22 | parser.add_argument('--batch_size', type=int, default=512)
23 | parser.add_argument('--max_len', type=int, default=-1)
24 | parser.add_argument('--seq_len', type=int, default=512)
25 | parser.add_argument('--pred_len', type=int, default=96)
26 | parser.add_argument('--label_len', type=int, default=48)
27 | parser.add_argument('--features', type=str, default='M')
28 |
29 | args = parser.parse_args()
30 |
31 | return args
32 |
33 | def prompt_prune(pt):
34 | pt_dict = pt.to_dict()
35 | pt_keys = list(pt_dict.keys())
36 | for key in pt_keys:
37 | if key.startswith("0_FFT mean coefficient"):
38 | del pt[key]
39 |
40 | return pt
41 |
42 |
43 | def prompt_generation_single(ts):
44 | cfg = tsfel.get_features_by_domain()
45 | prompt = tsfel.time_series_features_extractor(cfg, ts)
46 | prompt = prompt_prune(prompt)
47 | return prompt
48 |
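# --- Editor's sketch (not in the original file): TSFEL summarizes one univariate series into a
# --- single feature row; prompt_prune then drops the verbose "0_FFT mean coefficient" columns.
_demo_series = pd.Series(np.sin(np.linspace(0, 10, 500)))
_demo_prompt = prompt_generation_single(_demo_series)
print("demo prompt shape:", _demo_prompt.shape)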
49 | def prompt_generation(ts, ts_name):
50 |
51 | print(ts.shape)
52 |
53 | if ts.shape[1] == 1:
54 |
55 | return None
56 |
57 | else:
58 |
59 | column_name = [name.replace("/", "-") for name in list(ts.columns)]
60 | prompt_buf_train = pd.DataFrame(np.zeros((133, ts.shape[1])), columns=column_name)
61 | prompt_buf_val = pd.DataFrame(np.zeros((133, ts.shape[1])), columns=column_name)
62 | prompt_buf_test = pd.DataFrame(np.zeros((133, ts.shape[1])), columns=column_name)
63 | for index, col in ts.T.iterrows():
64 | if "ETT" in ts_name:
65 | ts_len = len(ts)
66 | t1, t2 = int(0.6*ts_len), int(0.6*ts_len) + int(0.2*ts_len)
67 | ts_train, ts_val, ts_test = col[:t1], col[t1:t2].reset_index(drop=True), col[t2:].reset_index(drop=True)
68 | else:
69 | ts_len = len(ts)
70 | t1, t2 = int(0.7 * ts_len), int(0.7 * ts_len) + int(0.1 * ts_len)
71 | ts_train, ts_val, ts_test = col[:t1], col[t1:t2].reset_index(drop=True), col[t2:].reset_index(drop=True)
72 |
73 | prompt_train = prompt_generation_single(ts_train)
74 | prompt_val = prompt_generation_single(ts_val)
75 | prompt_test = prompt_generation_single(ts_test)
76 |
77 | prompt_buf_train[index.replace("/", "-")] = prompt_train.T.values
78 | prompt_buf_val[index.replace("/", "-")] = prompt_val.T.values
79 | prompt_buf_test[index.replace("/", "-")] = prompt_test.T.values
80 |
81 | prompt_buf_total = {"train": prompt_buf_train, "val": prompt_buf_val, "test": prompt_buf_test}
82 | print(prompt_buf_total)
83 | return prompt_buf_total
84 |
85 |
86 | def prompt_save(prompt_buf, output_path):
87 |
88 | print(prompt_buf["train"])
89 | if prompt_buf["train"].shape[1] == 1:
90 | # ipdb.set_trace()
91 | return None
92 |
93 | # prompt_train_fname = os.path.join(prompt_train_data_dir, data_name + "_prompt.pth.tar")
94 | # prompt_train = prompt_buf["train"]
95 | # print("Export", prompt_train_fname, prompt_train.shape)
96 | #
97 | # prompt_val_fname = os.path.join(prompt_val_data_dir, data_name + "_prompt.pth.tar")
98 | # prompt_val = prompt_buf["val"]
99 | # torch.save(prompt_val, prompt_val_fname)
100 | # print("Export", prompt_val_fname, prompt_val.shape)
101 | #
102 | # prompt_test_fname = os.path.join(prompt_test_data_dir, data_name + "_prompt.pth.tar")
103 | # prompt_test = prompt_buf["test"]
104 | # torch.save(prompt_test, prompt_test_fname)
105 | # print("Export", prompt_test_fname, prompt_test.shape)
106 |
107 | else:
108 |
109 | for index, col in prompt_buf["train"].T.iterrows():
110 |
111 | prompt_train_fname = os.path.join(output_path, "train", data_name + "_" + index + "_prompt.pth.tar")
112 | prompt_train = col
113 | prompt_train.columns = [index]
114 | prompt_train = prompt_train.T
115 | torch.save(prompt_train, prompt_train_fname)
116 | print("Export", prompt_train_fname, prompt_train.shape)
117 |
118 | for index, col in prompt_buf["val"].T.iterrows():
119 | prompt_val_fname = os.path.join(output_path, "val", data_name + "_" + index + "_prompt.pth.tar")
120 | prompt_val = col
121 | prompt_val.columns = [index]
122 | prompt_val = prompt_val.T
123 | torch.save(prompt_val, prompt_val_fname)
124 | print("Export", prompt_val_fname, prompt_val.shape)
125 |
126 | for index, col in prompt_buf["test"].T.iterrows():
127 | prompt_test_fname = os.path.join(output_path, "test", data_name + "_" + index + "_prompt.pth.tar")
128 | prompt_test = col
129 | prompt_test.columns = [index]
130 | prompt_test = prompt_test.T
131 | torch.save(prompt_test, prompt_test_fname)
132 | print("Export", prompt_test_fname, prompt_test.shape)
133 |
134 |
135 | def data_import(path, format="feather"):
136 |
137 | if format == "feather":
138 | data = read_feather(path)
139 | data_name = path.replace(root_path, "").replace(".feather", "")
140 | data_dir = data_name[0:data_name.rfind("/")]
141 | # ipdb.set_trace()
142 | data = data.value
143 |
144 | else:
145 | data = read_csv(path)
146 | data_name = path.replace(root_path, "").replace(".csv", "")
147 | data_dir = data_name[0:data_name.rfind("/")]
148 | if "date" in data.columns:
149 | data = data.drop("date", axis=1)
150 | # print(data)
151 | # data = data.value
152 |
153 |
154 | return data, data_name, data_dir
155 |
156 |
157 | def create_data_dir(dir_name):
158 | # prompt_dir =
159 | if not os.path.exists(dir_name):
160 | os.mkdir(dir_name)
161 |
162 |
163 | if __name__ == "__main__":
164 |
165 | root_path = "./datasets/"
166 | output_path = "./prompt_bank/stat-prompt/prompt_data_split/"
167 |
168 |
169 | dataset_name = [
170 | "electricity",
171 | "ETT-small",
172 | "exchange_rate",
173 | "illness",
174 | "traffic",
175 | "weather",
176 | ]
177 |
178 | dataset_fullname = [os.path.join(root_path, name) for name in dataset_name]
179 | data_path_buf = []
180 | for dataset_dir in dataset_fullname:
181 | for root, dirs, files in os.walk(dataset_dir):
182 | for file_name in files:
183 | if file_name.endswith(".csv"):
184 | file_path = os.path.join(root, file_name)
185 | data_path_buf.append(file_path)
186 |
187 | print(data_path_buf)
188 | create_data_dir(output_path)
189 | # ipdb.set_trace()
190 |
191 | for path_idx, path in enumerate(data_path_buf):
192 |
193 | # print(path)
194 |
195 | data, data_name, data_dir = data_import(path, "csv")
196 | # print("Data Shape:", data.shape)
197 | if data.shape[0] < 20:
198 | print(path, "Skip too short time-series data.", data.shape)
199 | continue
200 | else:
201 | print("Import", path, "data shape", data.shape)
202 |
203 | create_data_dir(os.path.join(output_path, "train"))
204 | create_data_dir(os.path.join(output_path, "val"))
205 | create_data_dir(os.path.join(output_path, "test"))
206 | create_data_dir(os.path.join(output_path, "train", data_dir))
207 | create_data_dir(os.path.join(output_path, "val", data_dir))
208 | create_data_dir(os.path.join(output_path, "test", data_dir))
209 |
210 | prompt_data_buf = prompt_generation(data, data_name)
211 | if prompt_data_buf is not None:
212 | prompt_save(prompt_data_buf, output_path)
213 |
214 |
--------------------------------------------------------------------------------
/prompt_bank/stat-prompt/prompt_normalization_split.py:
--------------------------------------------------------------------------------
1 | # from ltsm.data_provider.data_factory import get_data_loader, get_data_loaders, get_dataset
2 | import argparse
3 | import ipdb
4 | import pandas as pd
5 | import numpy as np
6 | # import tsfel
7 | from pandas import read_csv, read_feather
8 | import matplotlib.pyplot as plt
9 | import sys, os
10 | import torch
11 | from sklearn.preprocessing import StandardScaler
12 |
13 |
14 | def get_args():
15 | parser = argparse.ArgumentParser(description='LTSM')
16 | parser.add_argument('--mode', choices=["fit", "transform"], required=True)
17 | args = parser.parse_args()
18 |
19 | return args
20 |
21 |
22 | def prompt_generation(ts):
23 | cfg = tsfel.get_features_by_domain()
24 | prompt = tsfel.time_series_features_extractor(cfg, ts)
25 | return prompt
26 |
27 |
28 | def prompt_prune(pt):
29 | pt_dict = pt.to_dict()
30 | pt_keys = list(pt_dict.keys())
31 | for key in pt_keys:
32 | if type(key) == type("abc") and key.startswith("0_FFT mean coefficient"):
33 | del pt[key]
34 |
35 | return pt
36 |
37 |
38 | def mean_std_export_ds(data_path_buf, normalize_param_fname):
39 | prompt_data_buf = []
40 | output_dir_buf = []
41 | output_path_buf = []
42 | for index, dataset_path in enumerate(data_path_buf):
43 | prompt_data = torch.load(dataset_path)
44 | prompt_data = prompt_prune(prompt_data)
45 | # print(prompt_data)
46 | prompt_data_buf.append(prompt_data)
47 |
48 | data_name = dataset_path.replace(root_path, "").replace(".csv", "")
49 | data_dir = data_name[0:data_name.rfind("/")]
50 | prompt_dir = os.path.join(output_path, data_dir)
51 | prompt_fname = os.path.join(output_path, data_name)
52 | # print(prompt_fname)
53 | output_dir_buf.append(prompt_dir)
54 | output_path_buf.append(prompt_fname)
55 | print("Import from {}".format(dataset_path), prompt_data.shape)
56 | # ipdb.set_trace()
57 |
58 | prompt_data_all = pd.concat(prompt_data_buf, axis=1).T
59 | print(prompt_data_all)
60 |
61 | scaler = StandardScaler()
62 | scaler.fit(prompt_data_all)
63 |
64 | sc_mean = pd.DataFrame(scaler.mean_.reshape(1,-1), columns=prompt_data_all.keys())
65 | sc_scale = pd.DataFrame(scaler.scale_.reshape(1,-1), columns=prompt_data_all.keys())
66 |
67 | print({"mean": sc_mean, "scale": sc_scale})
68 | print("Save the mean and std to {}".format(normalize_param_fname))
69 | torch.save({"mean": sc_mean, "scale": sc_scale}, normalize_param_fname)
70 |
71 |
72 | def standardscale_export(data_path_buf, params_fname, output_path, root_path):
73 |
74 | params = torch.load(params_fname)
75 | mean, std = params["mean"], params["scale"]
76 | scaler = StandardScaler()
77 | scaler.mean_ = mean
78 | scaler.scale_ = std
79 | # ipdb.set_trace()
80 |
81 | for index, dataset_path in enumerate(data_path_buf):
82 | prompt_data_raw = torch.load(dataset_path)
83 | prompt_data_raw = prompt_prune(prompt_data_raw)
84 |
85 | prompt_data = scaler.transform(prompt_data_raw.values.reshape(1, -1))
86 | prompt_data_array = prompt_data
87 | # print(prompt_data)
88 | prompt_data_array[np.isnan(prompt_data_array)] = 0
89 | prompt_data_transform = pd.DataFrame(prompt_data_array, columns=prompt_data_raw.keys())
90 | # ipdb.set_trace()
91 |
92 | prompt_fname = dataset_path.replace(root_path, output_path)
93 | prompt_dir = prompt_fname[0:prompt_fname.rfind("/")]
94 | if not os.path.exists(prompt_dir):
95 | os.mkdir(prompt_dir)
96 |
97 | torch.save(prompt_data_transform, prompt_fname)
98 | print("Save to {}".format(prompt_fname))
99 | del prompt_data
100 |
101 |
102 | def create_data_dir(dir_name):
103 | # prompt_dir =
104 | if not os.path.exists(dir_name):
105 | os.mkdir(dir_name)
106 |
107 | if __name__ == "__main__":
108 |
109 | root_path_train = "./prompt_bank/stat-prompt/prompt_data_split/train"
110 | output_path_train = "./prompt_bank/stat-prompt/prompt_data_normalize_split/train"
111 | root_path_val = "./prompt_bank/stat-prompt/prompt_data_split/val"
112 | output_path_val = "./prompt_bank/stat-prompt/prompt_data_normalize_split/val"
113 | root_path_test = "./prompt_bank/stat-prompt/prompt_data_split/test"
114 | output_path_test = "./prompt_bank/stat-prompt/prompt_data_normalize_split/test"
115 | # normalize_param_fname = os.path.join(output_path, "normalization_params.pth.tar")
116 | ds_size = 50
117 | mode = get_args().mode # "transform" # "fit" #
118 |
119 | data_path_buf = {
120 | "train": {"root_path": root_path_train, "output_path": output_path_train, "normalize_param_fname": os.path.join(output_path_train, "normalization_params.pth.tar")},
121 | "val": {"root_path": root_path_val, "output_path": output_path_val, "normalize_param_fname": os.path.join(output_path_val, "normalization_params.pth.tar")},
122 | "test": {"root_path": root_path_test, "output_path": output_path_test, "normalize_param_fname": os.path.join(output_path_test, "normalization_params.pth.tar")},
123 | }
124 |
125 |
126 | dataset_name = [
127 | "electricity",
128 | "ETT-small",
129 | "exchange_rate",
130 | "illness",
131 | "traffic",
132 | "weather",
133 | ]
134 |
135 | for split_name, data_path in data_path_buf.items():
136 | root_path = data_path_buf[split_name]["root_path"]
137 | output_path = data_path_buf[split_name]["output_path"]
138 | normalize_param_fname = data_path_buf[split_name]["normalize_param_fname"]
139 |
140 | create_data_dir(output_path)
141 |
142 | dataset_fullname = [os.path.join(root_path, name) for name in dataset_name]
143 | data_path_buf_tmp = []
144 | if mode == "fit":
145 |
146 | for dataset_dir in dataset_fullname:
147 | paths = os.listdir(dataset_dir)
148 | new_dataset = [os.path.join(dataset_dir, path) for path in paths]
149 | sample_idx = np.random.permutation(len(new_dataset))[:ds_size].astype(np.int64)
150 | # ipdb.set_trace()
151 | new_dataset = np.array(new_dataset)[sample_idx].tolist()
152 | data_path_buf_tmp.extend(new_dataset)
153 |
154 | else:
155 | for dataset_dir in dataset_fullname:
156 | paths = os.listdir(dataset_dir)
157 | new_dataset = [os.path.join(dataset_dir, path) for path in paths]
158 | data_path_buf_tmp.extend(new_dataset)
159 |
160 |
161 | if mode == "fit":
162 |
163 | mean_std_export_ds(data_path_buf_tmp, normalize_param_fname)
164 | else:
165 | # ipdb.set_trace()
166 | standardscale_export(data_path_buf_tmp, normalize_param_fname, output_path, root_path)
167 |
168 |
169 |
170 |
171 |
172 |
173 |
174 |
175 |
176 |
177 |
--------------------------------------------------------------------------------
/prompt_bank/stat-prompt/prompt_tsne.py:
--------------------------------------------------------------------------------
1 | # from ltsm.data_provider.data_factory import get_data_loader, get_data_loaders, get_dataset
2 | import argparse
3 | import ipdb
4 | import pandas as pd
5 | import numpy as np
6 | # import tsfel
7 | from pandas import read_csv, read_feather
8 | import matplotlib.pyplot as plt
9 | import sys, os
10 | import torch
11 | from sklearn.preprocessing import StandardScaler
12 | from sklearn import manifold
13 |
14 |
15 | def get_args():
16 | parser = argparse.ArgumentParser(description='LTSM')
17 |
18 | parser.add_argument('--data_path', type=str, default='dataset/weather.csv')
19 | parser.add_argument('--data', type=str, default='custom')
20 | parser.add_argument('--freq', type=str, default="h")
21 | parser.add_argument('--target', type=str, default='OT')
22 | parser.add_argument('--embed', type=str, default='timeF')
23 | parser.add_argument('--percent', type=int, default=10)
24 | parser.add_argument('--batch_size', type=int, default=512)
25 | parser.add_argument('--max_len', type=int, default=-1)
26 | parser.add_argument('--seq_len', type=int, default=512)
27 | parser.add_argument('--pred_len', type=int, default=96)
28 | parser.add_argument('--label_len', type=int, default=48)
29 | parser.add_argument('--features', type=str, default='M')
30 |
31 | args = parser.parse_args()
32 |
33 | return args
34 |
35 |
36 | def prompt_generation(ts):
37 | cfg = tsfel.get_features_by_domain()
38 | prompt = tsfel.time_series_features_extractor(cfg, ts)
39 | return prompt
40 |
41 |
42 | def prompt_prune(pt):
43 | pt_dict = pt.to_dict()
44 | pt_keys = list(pt_dict.keys())
45 | for key in pt_keys:
46 | if key.startswith("0_FFT mean coefficient"):
47 | del pt[key]
48 |
49 | return pt
50 |
51 |
52 | if __name__ == "__main__":
53 |
54 | root_path = "./prompt_bank/stat-prompt/prompt_data_split/"
55 | # print(data_path_buf)
56 |
57 | dataset_name = [
58 | "electricity",
59 | "ETT-small",
60 | "exchange_rate",
61 | "illness",
62 | "traffic",
63 | "weather",
64 | ]
65 | split_buf = ["train", "val", "test"]
66 |
67 | dataset_fullname_train = [os.path.join(root_path, "train", name) for name in dataset_name]
68 | dataset_fullname_val = [os.path.join(root_path, "val", name) for name in dataset_name]
69 | dataset_fullname_test = [os.path.join(root_path, "test", name) for name in dataset_name]
70 | dataset_fullname = dataset_fullname_train + dataset_fullname_val + dataset_fullname_test
71 | data_path_buf = []
72 | dataset_dir_buf = []
73 | dataset_split_buf = []
74 | K = 100
75 | for index, dataset_dir in enumerate(dataset_fullname):
76 | paths = os.listdir(dataset_dir)
77 | new_dataset = [os.path.join(dataset_dir, path) for path in paths]
78 | sample_idx = np.random.permutation(len(new_dataset))[:K].astype(np.int64)
79 | # ipdb.set_trace()
80 | new_dataset = np.array(new_dataset)[sample_idx].tolist()
81 | data_path_buf.extend(new_dataset)
82 |
83 | for dataset_index, dname in enumerate(dataset_name):
84 | if dname in dataset_dir:
85 | dataset_dir_buf.extend(len(new_dataset) * [dataset_index])
86 |
87 | for split_index, split in enumerate(split_buf):
88 | if split in dataset_dir:
89 | dataset_split_buf.extend(len(new_dataset) * [split_index])
90 | break
91 |
92 | prompt_data_buf = []
93 | for index, dataset_path in enumerate(data_path_buf):
94 | prompt_data = torch.load(dataset_path)
95 | prompt_data_buf.append(prompt_data)
96 | print("Import from {}".format(dataset_path))
97 | # print(prompt_data)
98 |
99 | # if index == 100:
100 | # break
101 |
102 | # print(prompt_data_buf)
103 | # print(output_path_buf)
104 |
105 | prompt_data_all = pd.concat(prompt_data_buf, axis=0).values
106 | print(prompt_data_all.shape)
107 | # (3166, 133)
108 |
109 | # nan_index = np.where(np.isnan(prompt_data_all))[0]
110 | # prompt_data_all[nan_index] = 0
111 |
112 | # ipdb.set_trace()
113 | tsne = manifold.TSNE(n_components=2, init='pca', random_state=0)
114 | prompt_data_tsne = tsne.fit_transform(prompt_data_all)
115 | dataset_plot_buf = ["electricity"]
116 | color_buf = ["red", "blue", "black", "green", "pink", "brown"]
117 | marker_buf = [".", "^", "x"]
118 | for index, _ in enumerate(dataset_name):
119 | for sindex, split_fold in enumerate(split_buf):
120 | data_index = (np.array(dataset_dir_buf) == index)
121 | split_index = (np.array(dataset_split_buf) == sindex)
122 | plot_index = data_index & split_index
123 | plt.plot(prompt_data_tsne[plot_index, 0], prompt_data_tsne[plot_index, 1], linewidth=0, marker=marker_buf[sindex], label=str(dataset_name[index][0:8] + "-" + split_fold), color=color_buf[index])
124 | # plt.text(prompt_data_tsne[data_index, 0].mean()-20, prompt_data_tsne[data_index, 1].mean(), str(dataset_name[index][0:8]), fontdict={'weight': 'bold', 'size': 9})
125 |
126 | plt.legend(loc="right")
127 | plt.savefig("./figures/stat_prompt_tsne.png")
128 | plt.close()
129 |
130 | # ipdb.set_trace()
131 | # plt.xticks([])
132 | # plt.yticks([])
133 |
134 | # print(prompt_data_all)
135 | # , color = plt.cm.Set1(dataset_dir_buf[index])
136 | # print(prompt_data_transform)
137 | # print(prompt_data_transform_array.mean(axis=0))
138 | # print(prompt_data_transform_array.std(axis=0))
139 | # print(prompt_data_transform.loc[5])
140 |
141 |
142 |
143 |
144 |
145 |
146 |
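Note that StandardScaler is imported above but never used: the concatenated prompt features are passed to t-SNE unscaled. An optional sketch that standardizes the features first, so that large-magnitude statistics (e.g. absolute energy) do not dominate the embedding; it continues from prompt_data_all as defined above and is a suggestion rather than part of the original script:

    from sklearn.preprocessing import StandardScaler
    from sklearn import manifold

    prompt_data_scaled = StandardScaler().fit_transform(prompt_data_all)
    prompt_data_tsne = manifold.TSNE(n_components=2, init="pca", random_state=0).fit_transform(prompt_data_scaled)
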
--------------------------------------------------------------------------------
/prompt_bank/stat-prompt/tsfel/__init__.py:
--------------------------------------------------------------------------------
1 | from tsfel.utils import *
2 | from tsfel.feature_extraction import *
3 |
--------------------------------------------------------------------------------
/prompt_bank/stat-prompt/tsfel/feature_extraction/__init__.py:
--------------------------------------------------------------------------------
1 | from tsfel.feature_extraction.calc_features import *
2 | from tsfel.feature_extraction.features import *
3 | from tsfel.feature_extraction.features_settings import *
4 | from tsfel.feature_extraction.features_utils import *
5 |
--------------------------------------------------------------------------------
/prompt_bank/stat-prompt/tsfel/feature_extraction/calc_features.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import importlib
3 | import multiprocessing as mp
4 | import numbers
5 | import os
6 | import pathlib
7 | import sys
8 | import warnings
9 | from functools import partial
10 | from pathlib import Path
11 |
12 | import numpy as np
13 | import pandas as pd
14 |
15 | from IPython import get_ipython
16 | from IPython.display import display
17 |
18 | from tsfel.utils.progress_bar import display_progress_bar, progress_bar_notebook
19 | from tsfel.utils.signal_processing import merge_time_series, signal_window_splitter
20 |
21 | import ipdb
22 |
23 | def dataset_features_extractor(main_directory, feat_dict, verbose=1, **kwargs):
24 | """Extracts features from a dataset.
25 |
26 | Parameters
27 | ----------
28 | main_directory : String
29 | Input directory
30 | feat_dict : dict
31 | Dictionary with features
32 | verbose : int
33 | Verbosity mode. 0 = silent, 1 = progress bar.
34 | (0 or 1 (Default))
35 | \**kwargs:
36 | See below:
37 | * *search_criteria* (``list``) --
38 | List of file names to compute features. (Example: 'Accelerometer.txt')
39 | (default: ``None``)
40 |
41 | * *time_unit* (``float``) --
42 | Time unit
43 | (default: ``1e9``)
44 |
45 |             * *resample_rate* (``int``) --
46 |                 Resampling rate
47 |                 (default: ``30``)
48 |
49 | * *window_size* (``int``) --
50 | Window size in number of samples
51 | (default: ``100``)
52 |
53 | * *overlap* (``float``) --
54 | Overlap between 0 and 1
55 | (default: ``0``)
56 |
57 | * *pre_process* (``function``) --
58 | Function with pre processing code
59 |
60 | (default: ``None``)
61 |
62 | * *output_directory* (``String``) --
63 | Output directory
64 | (default: ``'output_directory', str(Path.home()) + '/tsfel_output'``)
65 |
66 | * *features_path* (``string``) --
67 | Directory of script with personal features
68 |
69 | * *header_names* (``list or array``) --
70 | Names of each column window
71 |
72 | * *n_jobs* (``int``) --
73 | The number of jobs to run in parallel. ``None`` means 1 unless in a :obj:`joblib.parallel_backend` context.
74 | ``-1`` means using all processors.
75 | (default: ``None`` in Windows and ``-1`` for other systems)
76 |
77 | Returns
78 | -------
79 | file
80 | csv file with the extracted features
81 |
82 | """
83 | search_criteria = kwargs.get('search_criteria', None)
84 | time_unit = kwargs.get('time_unit', 1e9)
85 | resample_rate = kwargs.get('resample_rate', 30)
86 | window_size = kwargs.get('window_size', 100)
87 | overlap = kwargs.get('overlap', 0)
88 | pre_process = kwargs.get('pre_process', None)
89 | output_directory = kwargs.get('output_directory', str(Path.home()) + '/tsfel_output')
90 | features_path = kwargs.get('features_path', None)
91 | names = kwargs.get('header_names', None)
92 |
93 | # Choosing default of n_jobs by operating system
94 | if sys.platform[:-2] == 'win':
95 | n_jobs_default = None
96 | else:
97 | n_jobs_default = -1
98 |
99 | # Choosing default of n_jobs by python interface
100 | if get_ipython().__class__.__name__ == 'ZMQInteractiveShell' or \
101 | get_ipython().__class__.__name__ == 'Shell':
102 | n_jobs_default = -1
103 |
104 | n_jobs = kwargs.get('n_jobs', n_jobs_default)
105 |
106 | if main_directory[-1] != os.sep:
107 | main_directory = main_directory + os.sep
108 |
109 | folders = [f for f in glob.glob(main_directory + "**/", recursive=True)]
110 |
111 | if folders:
112 | for fl in folders:
113 | sensor_data = {}
114 | if search_criteria:
115 | for c in search_criteria:
116 | if os.path.isfile(fl + c):
117 | key = c.split('.')[0]
118 | sensor_data[key] = pd.read_csv(fl + c, header=None)
119 | else:
120 | all_files = np.concatenate((glob.glob(fl + '/*.txt'), glob.glob(fl + '/*.csv')))
121 | for c in all_files:
122 | key = c.split(os.sep)[-1].split('.')[0]
123 | try:
124 | data_file = pd.read_csv(c, header=None)
125 | except pd.io.common.CParserError:
126 | continue
127 |
128 | if np.dtype('O') in np.array(data_file.dtypes):
129 | continue
130 |
131 | sensor_data[key] = pd.read_csv(c, header=None)
132 |
133 | if not sensor_data:
134 | continue
135 |
136 | pp_sensor_data = sensor_data if pre_process is None else pre_process(sensor_data)
137 |
138 | data_new = merge_time_series(pp_sensor_data, resample_rate, time_unit)
139 |
140 | windows = signal_window_splitter(data_new, window_size, overlap)
141 |
142 | if features_path:
143 | features = time_series_features_extractor(feat_dict, windows, fs=resample_rate, verbose=0,
144 | features_path=features_path, header_names=names, n_jobs=n_jobs)
145 | else:
146 | features = time_series_features_extractor(feat_dict, windows, fs=resample_rate, verbose=0,
147 | header_names=names, n_jobs=n_jobs)
148 |
149 | fl = '/'.join(fl.split(os.sep))
150 | invalid_char = '<>:"\|?* '
151 | for char in invalid_char:
152 | fl = fl.replace(char, '')
153 |
154 | pathlib.Path(output_directory + fl).mkdir(parents=True, exist_ok=True)
155 | features.to_csv(output_directory + fl + '/Features.csv', sep=',', encoding='utf-8')
156 |
157 | if verbose == 1:
158 | print('Features files saved in: ', output_directory)
159 | else:
160 | raise FileNotFoundError("There is no folder(s) in directory: " + main_directory)
161 |
162 |
163 | def calc_features(wind_sig, dict_features, fs, **kwargs):
164 | """Extraction of time series features.
165 |
166 | Parameters
167 | ----------
168 | wind_sig: list
169 | Input from which features are computed, window
170 | dict_features : dict
171 | Dictionary with features
172 | fs : float or None
173 | Sampling frequency
174 | \**kwargs:
175 | * *features_path* (``string``) --
176 | Directory of script with personal features
177 | * *header_names* (``list or array``) --
178 | Names of each column window
179 |
180 | Returns
181 | -------
182 | DataFrame
183 | Extracted features
184 |
185 | """
186 |
187 | features_path = kwargs.get('features_path', None)
188 | names = kwargs.get('header_names', None)
189 | feat_val = calc_window_features(dict_features, wind_sig, fs, features_path=features_path, header_names=names)
190 |     feat_val = feat_val.reset_index(drop=True)
191 |
192 | return feat_val
193 |
194 |
195 | def time_series_features_extractor(dict_features, signal_windows, fs=None, verbose=1, **kwargs):
196 | """Extraction of time series features.
197 |
198 | Parameters
199 | ----------
200 | dict_features : dict
201 | Dictionary with features
202 | signal_windows: list
203 | Input from which features are computed, window
204 | fs : int or None
205 | Sampling frequency
206 | verbose : int
207 | Verbosity mode. 0 = silent, 1 = progress bar.
208 | (0 or 1 (Default))
209 | \**kwargs:
210 | See below:
211 | * *window_size* (``int``) --
212 | Window size in number of samples
213 | (default: ``100``)
214 |
215 | * *overlap* (``float``) --
216 | Overlap between 0 and 1
217 | (default: ``0``)
218 |
219 | * *features_path* (``string``) --
220 | Directory of script with personal features
221 |
222 | * *header_names* (``list or array``) --
223 | Names of each column window
224 |
225 | * *n_jobs* (``int``) --
226 | The number of jobs to run in parallel. ``None`` means 1 unless in a :obj:`joblib.parallel_backend` context.
227 | ``-1`` means using all processors.
228 | (default: ``None`` in Windows and ``-1`` for other systems)
229 |
230 | Returns
231 | -------
232 | DataFrame
233 | Extracted features
234 |
235 | """
236 | if verbose == 1:
237 | print("*** Feature extraction started ***")
238 |
239 | window_size = kwargs.get('window_size', None)
240 | overlap = kwargs.get('overlap', 0)
241 | features_path = kwargs.get('features_path', None)
242 | names = kwargs.get('header_names', None)
243 |
244 | # Choosing default of n_jobs by operating system
245 | if sys.platform[:-2] == 'win':
246 | n_jobs_default = None
247 | else:
248 | n_jobs_default = -1
249 |
250 | # Choosing default of n_jobs by python interface
251 | if get_ipython().__class__.__name__ == 'ZMQInteractiveShell' or \
252 | get_ipython().__class__.__name__ == 'Shell':
253 | n_jobs_default = -1
254 |
255 | n_jobs = kwargs.get('n_jobs', n_jobs_default)
256 |
257 | if fs is None:
258 | warnings.warn('Using default sampling frequency set in configuration file.', stacklevel=2)
259 |
260 | if names is not None:
261 | names = list(names)
262 | else:
263 | # Name of each column to be concatenated with feature name
264 | if isinstance(signal_windows, pd.DataFrame):
265 | names = signal_windows.columns.values
266 | elif isinstance(signal_windows[0], pd.DataFrame):
267 | names = signal_windows[0].columns.values
268 |
269 | if window_size is not None:
270 | signal_windows = signal_window_splitter(signal_windows, window_size, overlap)
271 |
272 | if len(signal_windows) == 0:
273 | raise SystemExit('Empty signal windows. Please check window size input parameter.')
274 |
275 | features_final = pd.DataFrame()
276 |
277 | if isinstance(signal_windows, list) and isinstance(signal_windows[0], numbers.Real):
278 | signal_windows = np.array(signal_windows)
279 |
280 | # more than one window
281 | if isinstance(signal_windows, list):
282 | # Starting the display of progress bar for notebooks interfaces
283 | if (get_ipython().__class__.__name__ == "ZMQInteractiveShell") or (
284 | get_ipython().__class__.__name__ == "Shell"
285 | ):
286 |
287 | out = display(progress_bar_notebook(0, len(signal_windows)), display_id=True)
288 | else:
289 | out = None
290 |
291 | if isinstance(n_jobs, int):
292 | # Multiprocessing use
293 | if n_jobs == -1:
294 | cpu_count = mp.cpu_count()
295 | else:
296 | cpu_count = n_jobs
297 |
298 | pool = mp.Pool(cpu_count)
299 | features = pool.imap(
300 | partial(
301 | calc_features,
302 | dict_features=dict_features,
303 | fs=fs,
304 | features_path=features_path,
305 | header_names=names,
306 | ),
307 | signal_windows,
308 | )
309 |
310 | for i, feat in enumerate(features):
311 | if verbose == 1:
312 | display_progress_bar(i, len(signal_windows), out)
313 | features_final = pd.concat([features_final, feat], axis=0)
314 |
315 | pool.close()
316 | pool.join()
317 |
318 | elif n_jobs is None:
319 | for i, feat in enumerate(signal_windows):
320 | features_final = pd.concat(
321 | [
322 | features_final,
323 | calc_window_features(
324 | dict_features, feat, fs, features_path=features_path, header_names=names)
325 | ], axis=0)
326 | if verbose == 1:
327 | display_progress_bar(i, len(signal_windows), out)
328 | else:
329 | raise SystemExit(
330 | "n_jobs value is not valid. " "Choose an integer value or None for no multiprocessing."
331 | )
332 | # single window
333 | else:
334 | # import ipdb
335 | # ipdb.set_trace()
336 | features_final = calc_window_features(
337 | dict_features,
338 | signal_windows,
339 | fs,
340 | verbose=verbose,
341 | features_path=features_path,
342 | header_names=names,
343 | single_window=True,
344 | )
345 |
346 | if verbose == 1:
347 | print("\n"+"*** Feature extraction finished ***")
348 |
349 | # Assuring the same feature extraction order
350 | features_final = features_final.reindex(sorted(features_final.columns), axis=1)
351 | return features_final.reset_index(drop=True)
352 |
353 |
354 | def calc_window_features(dict_features, signal_window, fs, verbose=1, single_window=False, **kwargs):
355 | """This function computes features matrix for one window.
356 |
357 | Parameters
358 | ----------
359 | dict_features : dict
360 | Dictionary with features
361 | signal_window: pandas DataFrame
362 | Input from which features are computed, window
363 | fs : float
364 | Sampling frequency
365 | verbose : int
366 | Level of function communication
367 | (0 or 1 (Default))
368 | single_window: Bool
369 | Boolean value for printing the progress bar for only one window feature extraction
370 | \**kwargs:
371 | See below:
372 | * *features_path* (``string``) --
373 | Directory of script with personal features
374 | * *header_names* (``list or array``) --
375 | Names of each column window
376 |
377 | Returns
378 | -------
379 | pandas DataFrame
380 | (columns) names of the features
381 | (data) values of each features for signal
382 |
383 | """
384 |
385 | features_path = kwargs.get('features_path', None)
386 | header_names = kwargs.get('header_names', None)
387 |
388 | # To handle object type signals
389 | signal_window = np.array(signal_window).astype(float)
390 |
391 | single_axis = True if len(signal_window.shape) == 1 else False
392 |
393 | if header_names is None:
394 | header_names = np.array([0]) if single_axis else np.arange(signal_window.shape[-1])
395 | else:
396 | if (len(header_names) != signal_window.shape[-1] and not single_axis) or \
397 | (len(header_names) != 1 and single_axis):
398 | raise Exception("header_names dimension does not match input columns.")
399 |
400 | # Execute imports
401 | exec("from tsfel import *")
402 | domain = dict_features.keys()
403 |
404 | if features_path:
405 | sys.path.append(features_path[:-len(features_path.split(os.sep)[-1])-1])
406 | exec("import "+features_path.split(os.sep)[-1][:-3])
407 | importlib.reload(sys.modules[features_path.split(os.sep)[-1][:-3]])
408 | exec("from " + features_path.split(os.sep)[-1][:-3]+" import *")
409 |
410 | # Create global arrays
411 | feature_results = []
412 | feature_names = []
413 |
414 | # Starting the display of progress bar for notebooks interfaces
415 | # Iterating over features of a single window
416 | if verbose == 1 and single_window:
417 |
418 | feat_nb = np.hstack([list(dict_features[_type].keys()) for _type in domain])
419 |
420 | if (get_ipython().__class__.__name__ == 'ZMQInteractiveShell') or (
421 | get_ipython().__class__.__name__ == 'Shell'):
422 | out = display(progress_bar_notebook(0, len(feat_nb)), display_id=True)
423 | else:
424 | out = None
425 |
426 | i_feat = -1
427 |
428 | for _type in domain:
429 | domain_feats = dict_features[_type].keys()
430 | # print(domain_feats)
431 | # ipdb.set_trace()
432 |
433 | for feat in domain_feats:
434 |
435 | if verbose == 1 and single_window:
436 | i_feat = i_feat + 1
437 | display_progress_bar(i_feat, len(feat_nb), out)
438 |
439 | # Only returns used functions
440 | if dict_features[_type][feat]["use"] == "yes":
441 |
442 | # Read Function (real name of function)
443 | func_total = dict_features[_type][feat]["function"]
444 |
445 | if func_total.find("tsfel.") == 0:
446 | func_total = func_total.replace("tsfel.", "")
447 |
448 | # Check for parameters
449 | parameters_total = {}
450 |
451 | if dict_features[_type][feat]["parameters"] != "":
452 | parameters_total = dict_features[_type][feat]["parameters"]
453 |
454 | # Check assert fs parameter:
455 | if "fs" in parameters_total:
456 |
457 | # Select which fs to use
458 | if fs is None:
459 | # Check if features dict has default sampling frequency value
460 | if not (type(parameters_total["fs"]) is int or type(parameters_total["fs"]) is float):
461 | raise Exception("No sampling frequency assigned.")
462 | else:
463 | parameters_total["fs"] = fs
464 |
465 | # Eval feature results
466 | if single_axis:
467 | eval_result = locals()[func_total](signal_window, **parameters_total)
468 | eval_result = np.array([eval_result])
469 |
470 | for ax in range(len(header_names)):
471 | sig_ax = signal_window if single_axis else signal_window[:, ax]
472 | eval_result_ax = locals()[func_total](sig_ax, **parameters_total)
473 | # Function returns more than one element
474 | if type(eval_result_ax) == tuple:
475 | if np.isnan(eval_result_ax[0]):
476 | eval_result_ax = np.zeros(len(eval_result_ax))
477 | for rr in range(len(eval_result_ax)):
478 | feature_results += [eval_result_ax[rr]]
479 | feature_names += [str(header_names[ax]) + "_" + feat + "_" + str(rr)]
480 | else:
481 | feature_results += [eval_result_ax]
482 | feature_names += [str(header_names[ax]) + "_" + feat]
483 |
484 | features = pd.DataFrame(
485 | data=np.array(feature_results).reshape(1, len(feature_results)), columns=np.array(feature_names)
486 | )
487 |
488 | return features
489 |
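A minimal usage sketch for the extractor defined above, mirroring how prompt_generation in prompt_tsne.py uses it: build a configuration from features.json and extract every enabled feature from one window (the synthetic signal and fs=100 are placeholders):

    import numpy as np
    import tsfel

    signal = np.sin(2 * np.pi * 0.05 * np.arange(1000) / 100)   # toy 0.05 Hz sine sampled at 100 Hz
    cfg = tsfel.get_features_by_domain()                         # all domains from features.json
    features = tsfel.time_series_features_extractor(cfg, signal, fs=100, verbose=0)
    print(features.shape)                                        # one row, one column per extracted feature
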
--------------------------------------------------------------------------------
/prompt_bank/stat-prompt/tsfel/feature_extraction/features.json:
--------------------------------------------------------------------------------
1 | {
2 | "spectral": {
3 | "FFT mean coefficient": {
4 | "complexity": "constant",
5 | "description": "Computes the mean value of each spectrogram frequency.",
6 | "function": "tsfel.fft_mean_coeff",
7 | "parameters": {
8 | "fs": 100,
9 | "nfreq": 256
10 | },
11 | "n_features": "nfreq",
12 | "use": "yes"
13 | },
14 | "Fundamental frequency": {
15 | "complexity": "log",
16 | "description": "Computes the fundamental frequency.",
17 | "function": "tsfel.fundamental_frequency",
18 | "parameters": {
19 | "fs": 100
20 | },
21 | "n_features": 1,
22 | "use": "yes"
23 | },
24 | "Human range energy": {
25 | "complexity": "log",
26 | "description": "Computes the human range energy ratio given by the ratio between the energy in frequency 0.6-2.5Hz and the whole energy band.",
27 | "function": "tsfel.human_range_energy",
28 | "parameters": {
29 | "fs": 100
30 | },
31 | "n_features": 1,
32 | "use": "yes",
33 | "tag": "inertial"
34 | },
35 | "LPCC": {
36 | "complexity": "log",
37 | "description": "Computes the linear prediction cepstral coefficients.",
38 | "function": "tsfel.lpcc",
39 | "parameters": {
40 | "n_coeff": 12
41 | },
42 | "n_features": "n_coeff",
43 | "use": "yes",
44 | "tag": "audio"
45 | },
46 | "MFCC": {
47 | "complexity": "constant",
48 | "description": "Computes the MEL cepstral coefficients.",
49 | "function": "tsfel.mfcc",
50 | "parameters": {
51 | "cep_lifter": 22,
52 | "fs": 100,
53 | "nfft": 512,
54 | "nfilt": 40,
55 | "num_ceps": 12,
56 | "pre_emphasis": 0.97
57 | },
58 | "n_features": "num_ceps",
59 | "use": "yes",
60 | "tag": [
61 | "audio",
62 | "emg"
63 | ]
64 | },
65 | "Max power spectrum": {
66 | "complexity": "log",
67 | "description": "Computes the maximum power spectrum density.",
68 | "function": "tsfel.max_power_spectrum",
69 | "parameters": {
70 | "fs": 100
71 | },
72 | "n_features": 1,
73 | "use": "yes"
74 | },
75 | "Maximum frequency": {
76 | "complexity": "log",
77 | "description": "Computes the maximum frequency.",
78 | "function": "tsfel.max_frequency",
79 | "parameters": {
80 | "fs": 100
81 | },
82 | "n_features": 1,
83 | "use": "yes"
84 | },
85 | "Median frequency": {
86 | "complexity": "log",
87 | "description": "Computes the median frequency.",
88 | "function": "tsfel.median_frequency",
89 | "parameters": {
90 | "fs": 100
91 | },
92 | "n_features": 1,
93 | "use": "yes"
94 | },
95 | "Power bandwidth": {
96 | "complexity": "log",
97 | "description": "Computes power spectrum density bandwidth of the signal.",
98 | "function": "tsfel.power_bandwidth",
99 | "parameters": {
100 | "fs": 100
101 | },
102 | "n_features": 1,
103 | "use": "yes"
104 | },
105 | "Spectral centroid": {
106 | "complexity": "linear",
107 | "description": "Computes the barycenter of the spectrum.",
108 | "function": "tsfel.spectral_centroid",
109 | "parameters": {
110 | "fs": 100
111 | },
112 | "n_features": 1,
113 | "use": "yes",
114 | "tag": "audio"
115 | },
116 | "Spectral decrease": {
117 | "complexity": "log",
118 | "description": "Computes the amount of decreasing of the spectra amplitude.",
119 | "function": "tsfel.spectral_decrease",
120 | "parameters": {
121 | "fs": 100
122 | },
123 | "n_features": 1,
124 | "use": "yes"
125 | },
126 | "Spectral distance": {
127 | "complexity": "log",
128 | "description": "Computes the signal spectral distance.",
129 | "function": "tsfel.spectral_distance",
130 | "parameters": {
131 | "fs": 100
132 | },
133 | "n_features": 1,
134 | "use": "yes"
135 | },
136 | "Spectral entropy": {
137 | "complexity": "log",
138 | "description": "Computes the spectral entropy of the signal based on Fourier transform.",
139 | "function": "tsfel.spectral_entropy",
140 | "parameters": {
141 | "fs": 100
142 | },
143 | "n_features": 1,
144 | "use": "yes",
145 | "tag": "eeg"
146 | },
147 | "Spectral kurtosis": {
148 | "complexity": "linear",
149 | "description": "Computes the flatness of a distribution around its mean value.",
150 | "function": "tsfel.spectral_kurtosis",
151 | "parameters": {
152 | "fs": 100
153 | },
154 | "n_features": 1,
155 | "use": "yes"
156 | },
157 | "Spectral positive turning points": {
158 | "complexity": "log",
159 | "description": "Computes number of positive turning points of the fft magnitude signal",
160 | "function": "tsfel.spectral_positive_turning",
161 | "parameters": {
162 | "fs": 100
163 | },
164 | "n_features": 1,
165 | "use": "yes"
166 | },
167 | "Spectral roll-off": {
168 | "complexity": "log",
169 | "description": "Computes the frequency where 95% of the signal magnitude is contained below of this value.",
170 | "function": "tsfel.spectral_roll_off",
171 | "parameters": {
172 | "fs": 100
173 | },
174 | "n_features": 1,
175 | "use": "yes",
176 | "tag": "audio"
177 | },
178 | "Spectral roll-on": {
179 | "complexity": "log",
180 | "description": "Computes the frequency where 5% of the signal magnitude is contained below of this value.",
181 | "function": "tsfel.spectral_roll_on",
182 | "parameters": {
183 | "fs": 100
184 | },
185 | "n_features": 1,
186 | "use": "yes"
187 | },
188 | "Spectral skewness": {
189 | "complexity": "linear",
190 | "description": "Computes the asymmetry of a distribution around its mean value.",
191 | "function": "tsfel.spectral_skewness",
192 | "parameters": {
193 | "fs": 100
194 | },
195 | "n_features": 1,
196 | "use": "yes"
197 | },
198 | "Spectral slope": {
199 | "complexity": "log",
200 | "description": "Computes the spectral slope, obtained by linear regression of the spectral amplitude.",
201 | "function": "tsfel.spectral_slope",
202 | "parameters": {
203 | "fs": 100
204 | },
205 | "n_features": 1,
206 | "use": "yes"
207 | },
208 | "Spectral spread": {
209 | "complexity": "linear",
210 | "description": "Computes the spread of the spectrum around its mean value.",
211 | "function": "tsfel.spectral_spread",
212 | "parameters": {
213 | "fs": 100
214 | },
215 | "n_features": 1,
216 | "use": "yes"
217 | },
218 | "Spectral variation": {
219 | "complexity": "log",
220 | "description": "Computes the amount of variation of the spectrum along time.",
221 | "function": "tsfel.spectral_variation",
222 | "parameters": {
223 | "fs": 100
224 | },
225 | "n_features": 1,
226 | "use": "yes"
227 | },
228 | "Wavelet absolute mean": {
229 | "complexity": "linear",
230 | "description": "Computes CWT absolute mean value of each wavelet scale.",
231 | "function": "tsfel.wavelet_abs_mean",
232 | "parameters": {
233 | "function": "scipy.signal.ricker",
234 | "widths": "np.arange(1,10)"
235 | },
236 | "n_features": "widths",
237 | "use": "yes",
238 | "tag": [
239 | "eeg",
240 | "ecg"
241 | ]
242 | },
243 | "Wavelet energy": {
244 | "complexity": "linear",
245 | "description": "Computes CWT energy of each wavelet scale.",
246 | "function": "tsfel.wavelet_energy",
247 | "parameters": {
248 | "function": "scipy.signal.ricker",
249 | "widths": "np.arange(1,10)"
250 | },
251 | "n_features": "widths",
252 | "use": "yes",
253 | "tag": "eeg"
254 | },
255 | "Wavelet entropy": {
256 | "complexity": "linear",
257 | "description": "Computes CWT entropy of the signal.",
258 | "function": "tsfel.wavelet_entropy",
259 | "parameters": {
260 | "function": "scipy.signal.ricker",
261 | "widths": "np.arange(1,10)"
262 | },
263 | "n_features": 1,
264 | "use": "yes",
265 | "tag": "eeg"
266 | },
267 | "Wavelet standard deviation": {
268 | "complexity": "linear",
269 | "description": "Computes CWT std value of each wavelet scale.",
270 | "function": "tsfel.wavelet_std",
271 | "parameters": {
272 | "function": "scipy.signal.ricker",
273 | "widths": "np.arange(1,10)"
274 | },
275 | "n_features": "widths",
276 | "use": "yes",
277 | "tag": "eeg"
278 | },
279 | "Wavelet variance": {
280 | "complexity": "linear",
281 | "description": "Computes CWT variance value of each wavelet scale.",
282 | "function": "tsfel.wavelet_var",
283 | "parameters": {
284 | "function": "scipy.signal.ricker",
285 | "widths": "np.arange(1,10)"
286 | },
287 | "n_features": "widths",
288 | "use": "yes",
289 | "tag": "eeg"
290 | }
291 | },
292 | "statistical": {
293 | "Absolute energy": {
294 | "complexity": "log",
295 | "description": "Computes the absolute energy of the signal.",
296 | "function": "tsfel.abs_energy",
297 | "parameters": "",
298 | "n_features": 1,
299 | "use": "yes",
300 | "tag": "audio"
301 | },
302 | "Average power": {
303 | "complexity": "constant",
304 | "description": "Computes the average power of the signal.",
305 | "function": "tsfel.average_power",
306 | "parameters": {
307 | "fs": 100
308 | },
309 | "n_features": 1,
310 | "use": "yes",
311 | "tag": "audio"
312 | },
313 | "ECDF": {
314 | "complexity": "log",
315 | "description": "Computes the values of ECDF (empirical cumulative distribution function) along the time axis.",
316 | "function": "tsfel.ecdf",
317 | "parameters": {
318 | "d": 10
319 | },
320 | "n_features": "d",
321 | "use": "yes"
322 | },
323 | "ECDF Percentile": {
324 | "complexity": "log",
325 | "description": "Determines the percentile value of the ECDF.",
326 | "function": "tsfel.ecdf_percentile",
327 | "parameters": {
328 | "percentile": "[0.2, 0.8]"
329 | },
330 | "n_features": "percentile",
331 | "use": "yes"
332 | },
333 | "ECDF Percentile Count": {
334 | "complexity": "log",
335 | "description": "Determines the cumulative sum of samples that are less than the percentile.",
336 | "function": "tsfel.ecdf_percentile_count",
337 | "parameters": {
338 | "percentile": "[0.2, 0.8]"
339 | },
340 | "n_features": "percentile",
341 | "use": "yes"
342 | },
343 | "Entropy": {
344 | "complexity": "log",
345 | "description": "Computes the entropy of the signal using the Shannon Entropy.",
346 | "function": "tsfel.entropy",
347 | "parameters": {
348 | "prob": "standard"
349 | },
350 | "n_features": 1,
351 | "use": "yes",
352 | "tag": "eeg"
353 | },
354 | "Histogram": {
355 | "complexity": "log",
356 | "description": "Computes histogram of the signal.",
357 | "function": "tsfel.hist",
358 | "parameters": {
359 | "nbins": 10,
360 | "r": 1
361 | },
362 | "n_features": "nbins",
363 | "use": "yes"
364 | },
365 | "Interquartile range": {
366 | "complexity": "constant",
367 | "description": "Computes interquartile range of the signal.",
368 | "function": "tsfel.interq_range",
369 | "parameters": "",
370 | "n_features": 1,
371 | "use": "yes"
372 | },
373 | "Kurtosis": {
374 | "complexity": "constant",
375 | "description": "Computes kurtosis of the signal.",
376 | "function": "tsfel.kurtosis",
377 | "parameters": "",
378 | "n_features": 1,
379 | "use": "yes"
380 | },
381 | "Max": {
382 | "complexity": "constant",
383 | "description": "Computes the maximum value of the signal.",
384 | "function": "tsfel.calc_max",
385 | "parameters": "",
386 | "n_features": 1,
387 | "use": "yes"
388 | },
389 | "Mean": {
390 | "complexity": "constant",
391 | "description": "Computes the mean value of the signal.",
392 | "function": "tsfel.calc_mean",
393 | "parameters": "",
394 | "n_features": 1,
395 | "use": "yes",
396 | "tag": "inertial"
397 | },
398 | "Mean absolute deviation": {
399 | "complexity": "log",
400 | "description": "Computes mean absolute deviation of the signal.",
401 | "function": "tsfel.mean_abs_deviation",
402 | "parameters": "",
403 | "n_features": 1,
404 | "use": "yes"
405 | },
406 | "Median": {
407 | "complexity": "constant",
408 | "description": "Computes median of the signal.",
409 | "function": "tsfel.calc_median",
410 | "parameters": "",
411 | "n_features": 1,
412 | "use": "yes"
413 | },
414 | "Median absolute deviation": {
415 | "complexity": "constant",
416 | "description": "Computes median absolute deviation of the signal.",
417 | "function": "tsfel.median_abs_deviation",
418 | "parameters": "",
419 | "n_features": 1,
420 | "use": "yes"
421 | },
422 | "Min": {
423 | "complexity": "constant",
424 | "description": "Computes the minimum value of the signal.",
425 | "function": "tsfel.calc_min",
426 | "parameters": "",
427 | "n_features": 1,
428 | "use": "yes"
429 | },
430 | "Peak to peak distance": {
431 | "complexity": "constant",
432 | "description": "Computes the peak to peak distance.",
433 | "function": "tsfel.pk_pk_distance",
434 | "parameters": "",
435 | "n_features": 1,
436 | "use": "yes"
437 | },
438 | "Root mean square": {
439 | "complexity": "constant",
440 | "description": "Computes root mean square of the signal.",
441 | "function": "tsfel.rms",
442 | "parameters": "",
443 | "n_features": 1,
444 | "use": "yes",
445 | "tag": [
446 | "emg",
447 | "inertial"
448 | ]
449 | },
450 | "Skewness": {
451 | "complexity": "constant",
452 | "description": "Computes skewness of the signal.",
453 | "function": "tsfel.skewness",
454 | "parameters": "",
455 | "n_features": 1,
456 | "use": "yes"
457 | },
458 | "Standard deviation": {
459 | "complexity": "constant",
460 | "description": "Computes standard deviation of the signal.",
461 | "function": "tsfel.calc_std",
462 | "parameters": "",
463 | "n_features": 1,
464 | "use": "yes"
465 | },
466 | "Variance": {
467 | "complexity": "constant",
468 | "description": "Computes variance of the signal.",
469 | "function": "tsfel.calc_var",
470 | "parameters": "",
471 | "n_features": 1,
472 | "use": "yes"
473 | }
474 | },
475 | "temporal": {
476 | "Area under the curve": {
477 | "complexity": "log",
478 | "description": "Computes the area under the curve of the signal computed with trapezoid rule.",
479 | "function": "tsfel.auc",
480 | "parameters": {
481 | "fs": 100
482 | },
483 | "n_features": 1,
484 | "use": "yes"
485 | },
486 | "Autocorrelation": {
487 | "complexity": "constant",
488 | "description": "Computes autocorrelation of the signal.",
489 | "function": "tsfel.autocorr",
490 | "parameters": "",
491 | "n_features": 1,
492 | "use": "yes",
493 | "tag": "inertial"
494 | },
495 | "Centroid": {
496 | "complexity": "constant",
497 | "description": "Computes the centroid along the time axis.",
498 | "function": "tsfel.calc_centroid",
499 | "parameters": {
500 | "fs": 100
501 | },
502 | "n_features": 1,
503 | "use": "yes"
504 | },
505 | "Mean absolute diff": {
506 | "complexity": "constant",
507 | "description": "Computes mean absolute differences of the signal.",
508 | "function": "tsfel.mean_abs_diff",
509 | "parameters": "",
510 | "n_features": 1,
511 | "use": "yes"
512 | },
513 | "Mean diff": {
514 | "complexity": "constant",
515 | "description": "Computes mean of differences of the signal.",
516 | "function": "tsfel.mean_diff",
517 | "parameters": "",
518 | "n_features": 1,
519 | "use": "yes"
520 | },
521 | "Median absolute diff": {
522 | "complexity": "constant",
523 | "description": "Computes median absolute differences of the signal.",
524 | "function": "tsfel.median_abs_diff",
525 | "parameters": "",
526 | "n_features": 1,
527 | "use": "yes"
528 | },
529 | "Median diff": {
530 | "complexity": "constant",
531 | "description": "Computes median of differences of the signal.",
532 | "function": "tsfel.median_diff",
533 | "parameters": "",
534 | "n_features": 1,
535 | "use": "yes"
536 | },
537 | "Negative turning points": {
538 | "complexity": "constant",
539 | "description": "Computes number of negative turning points of the signal.",
540 | "function": "tsfel.negative_turning",
541 | "parameters": "",
542 | "n_features": 1,
543 | "use": "yes",
544 | "tag": "emg"
545 | },
546 | "Neighbourhood peaks": {
547 | "complexity": "constant",
548 | "description": "Computes the number of peaks from a defined neighbourhood of the signal.",
549 | "function": "tsfel.neighbourhood_peaks",
550 | "parameters": {
551 | "n": 10
552 | },
553 | "n_features": 1,
554 | "use": "yes"
555 | },
556 | "Positive turning points": {
557 | "complexity": "constant",
558 | "description": "Computes number of positive turning points of the signal.",
559 | "function": "tsfel.positive_turning",
560 | "parameters": "",
561 | "n_features": 1,
562 | "use": "yes",
563 | "tag": "emg"
564 | },
565 | "Signal distance": {
566 | "complexity": "constant",
567 | "description": "Computes signal traveled distance.",
568 | "function": "tsfel.distance",
569 | "parameters": "",
570 | "n_features": 1,
571 | "use": "yes"
572 | },
573 | "Slope": {
574 | "complexity": "log",
575 | "description": "Computes the slope of the signal by fitting a linear equation to the observed data.",
576 | "function": "tsfel.slope",
577 | "parameters": "",
578 | "n_features": 1,
579 | "use": "yes"
580 | },
581 | "Sum absolute diff": {
582 | "complexity": "constant",
583 | "description": "Computes sum of absolute differences of the signal.",
584 | "function": "tsfel.sum_abs_diff",
585 | "parameters": "",
586 | "n_features": 1,
587 | "use": "yes"
588 | },
589 | "Zero crossing rate": {
590 | "complexity": "constant",
591 | "description": "Computes Zero-crossing rate of the signal.",
592 | "function": "tsfel.zero_cross",
593 | "parameters": "",
594 | "n_features": 1,
595 | "use": "yes",
596 | "tag": [
597 | "audio",
598 | "emg"
599 | ]
600 | }
601 | }
602 | }
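This configuration is what get_features_by_domain returns as a plain dictionary: "function" names the callable in the tsfel namespace, "parameters" holds its keyword arguments (with "fs" overridden by the sampling frequency passed to the extractor), and "use" switches the feature on or off. A small sketch of tweaking it before extraction; the edited entries are just examples:

    import tsfel

    cfg = tsfel.get_features_by_domain("spectral")
    cfg["spectral"]["MFCC"]["use"] = "no"                                  # skip an expensive feature
    cfg["spectral"]["FFT mean coefficient"]["parameters"]["nfreq"] = 64    # keep fewer FFT coefficients
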
--------------------------------------------------------------------------------
/prompt_bank/stat-prompt/tsfel/feature_extraction/features_settings.py:
--------------------------------------------------------------------------------
1 | import json
2 | import tsfel
3 | import numpy as np
4 |
5 |
6 | def load_json(json_path):
7 | """Loads the json file given by filename.
8 |
9 | Parameters
10 | ----------
11 | json_path : string
12 | Json path
13 |
14 | Returns
15 | -------
16 | Dict
17 | Dictionary
18 |
19 | """
20 |
21 | return json.load(open(json_path))
22 |
23 |
24 | def get_features_by_domain(domain=None, json_path=None):
25 | """Creates a dictionary with the features settings by domain.
26 |
27 | Parameters
28 | ----------
29 | domain : string
30 | Available domains: "statistical"; "spectral"; "temporal"
31 | If domain equals None, then the features settings from all domains are returned.
32 | json_path : string
33 | Directory of json file. Default: package features.json directory
34 |
35 | Returns
36 | -------
37 | Dict
38 | Dictionary with the features settings
39 |
40 | """
41 |
42 | if json_path is None:
43 | json_path = tsfel.__path__[0] + "/feature_extraction/features.json"
44 |
45 | if domain not in ['statistical', 'temporal', 'spectral', None]:
46 | raise SystemExit(
47 | 'No valid domain. Choose: statistical, temporal, spectral or None (for all feature settings).')
48 |
49 | dict_features = load_json(json_path)
50 | if domain is None:
51 | return dict_features
52 | else:
53 | return {domain: dict_features[domain]}
54 |
55 |
56 | def get_features_by_tag(tag=None, json_path=None):
57 | """Creates a dictionary with the features settings by tag.
58 |
59 | Parameters
60 | ----------
61 | tag : string
62 | Available tags: "audio"; "inertial", "ecg"; "eeg"; "emg".
63 |         If tag equals None, then all available features are returned.
64 | json_path : string
65 | Directory of json file. Default: package features.json directory
66 |
67 | Returns
68 | -------
69 | Dict
70 | Dictionary with the features settings
71 |
72 | """
73 | if json_path is None:
74 | json_path = tsfel.__path__[0] + "/feature_extraction/features.json"
75 |
76 | if tag not in ["audio", "inertial", "ecg", "eeg", "emg", None]:
77 | raise SystemExit(
78 | "No valid tag. Choose: audio, inertial, ecg, eeg, emg or None.")
79 | features_tag = {}
80 | dict_features = load_json(json_path)
81 | if tag is None:
82 | return dict_features
83 | else:
84 | for domain in dict_features:
85 | features_tag[domain] = {}
86 | for feat in dict_features[domain]:
87 | if dict_features[domain][feat]["use"] == "no":
88 | continue
89 | # Check if tag is defined
90 | try:
91 | js_tag = dict_features[domain][feat]["tag"]
92 | if isinstance(js_tag, list):
93 | if any([tag in js_t for js_t in js_tag]):
94 | features_tag[domain].update({feat: dict_features[domain][feat]})
95 | elif js_tag == tag:
96 | features_tag[domain].update({feat: dict_features[domain][feat]})
97 | except KeyError:
98 | continue
99 | # To remove empty dicts
100 | return dict([[d, features_tag[d]] for d in list(features_tag.keys()) if bool(features_tag[d])])
101 |
102 |
103 | def get_number_features(dict_features):
104 | """Count the total number of features based on input parameters of each feature
105 |
106 | Parameters
107 | ----------
108 | dict_features : dict
109 | Dictionary with features settings
110 |
111 | Returns
112 | -------
113 | int
114 | Feature vector size
115 | """
116 | number_features = 0
117 | for domain in dict_features:
118 | for feat in dict_features[domain]:
119 | if dict_features[domain][feat]["use"] == "no":
120 | continue
121 | n_feat = dict_features[domain][feat]["n_features"]
122 |
123 | if isinstance(n_feat, int):
124 | number_features += n_feat
125 | else:
126 | n_feat_param = dict_features[domain][feat]["parameters"][n_feat]
127 | if isinstance(n_feat_param, int):
128 | number_features += n_feat_param
129 | else:
130 | number_features += eval("len(" + n_feat_param + ")")
131 |
132 | return number_features
133 |
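Usage sketch for the helpers above; json_path defaults to the features.json packaged with tsfel:

    import tsfel

    stat_cfg = tsfel.get_features_by_domain("statistical")   # settings for one domain
    audio_cfg = tsfel.get_features_by_tag("audio")           # only features tagged "audio"
    print(tsfel.get_number_features(stat_cfg))               # expected length of the feature vector
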
--------------------------------------------------------------------------------
/prompt_bank/stat-prompt/tsfel/feature_extraction/features_utils.py:
--------------------------------------------------------------------------------
1 | import scipy
2 | import numpy as np
3 |
4 |
5 | def set_domain(key, value):
6 | def decorate_func(func):
7 | setattr(func, key, value)
8 | return func
9 |
10 | return decorate_func
11 |
12 |
13 | def compute_time(signal, fs):
14 | """Creates the signal correspondent time array.
15 |
16 | Parameters
17 | ----------
18 | signal: nd-array
19 | Input from which the time is computed.
20 | fs: int
21 | Sampling Frequency
22 |
23 | Returns
24 | -------
25 | time : float list
26 | Signal time
27 |
28 | """
29 |
30 | return np.arange(0, len(signal))/fs
31 |
32 |
33 | def calc_fft(signal, fs):
34 |     """This function computes the FFT of a signal.
35 |
36 | Parameters
37 | ----------
38 | signal : nd-array
39 | The input signal from which fft is computed
40 | fs : float
41 | Sampling frequency
42 |
43 | Returns
44 | -------
45 | f: nd-array
46 | Frequency values (xx axis)
47 | fmag: nd-array
48 | Amplitude of the frequency values (yy axis)
49 |
50 | """
51 |
52 | fmag = np.abs(np.fft.rfft(signal))
53 | f = np.fft.rfftfreq(len(signal), d=1/fs)
54 |
55 | return f.copy(), fmag.copy()
56 |
57 |
58 | def filterbank(signal, fs, pre_emphasis=0.97, nfft=512, nfilt=40):
59 | """Computes the MEL-spaced filterbank.
60 |
61 | It provides the information about the power in each frequency band.
62 |
63 | Implementation details and description on:
64 | https://www.kaggle.com/ilyamich/mfcc-implementation-and-tutorial
65 | https://haythamfayek.com/2016/04/21/speech-processing-for-machine-learning.html#fnref:1
66 |
67 | Parameters
68 | ----------
69 | signal : nd-array
70 | Input from which filterbank is computed
71 | fs : float
72 | Sampling frequency
73 | pre_emphasis : float
74 | Pre-emphasis coefficient for pre-emphasis filter application
75 | nfft : int
76 | Number of points of fft
77 | nfilt : int
78 | Number of filters
79 |
80 | Returns
81 | -------
82 | nd-array
83 | MEL-spaced filterbank
84 |
85 | """
86 |
87 | # Signal is already a window from the original signal, so no frame is needed.
88 |     # According to the references, a window function (e.g. a Hann window) should be applied here.
89 |     # However, if the signal windows do not overlap, applying a Hann window would lose information,
90 |     # because it attenuates the edges of each window.
91 |
92 | # pre-emphasis filter to amplify the high frequencies
93 |
94 | emphasized_signal = np.append(np.array(signal)[0], np.array(signal[1:]) - pre_emphasis * np.array(signal[:-1]))
95 |
96 | # Fourier transform and Power spectrum
97 | mag_frames = np.absolute(np.fft.rfft(emphasized_signal, nfft)) # Magnitude of the FFT
98 |
99 | pow_frames = ((1.0 / nfft) * (mag_frames ** 2)) # Power Spectrum
100 |
101 | low_freq_mel = 0
102 | high_freq_mel = (2595 * np.log10(1 + (fs / 2) / 700)) # Convert Hz to Mel
103 | mel_points = np.linspace(low_freq_mel, high_freq_mel, nfilt + 2) # Equally spaced in Mel scale
104 | hz_points = (700 * (10 ** (mel_points / 2595) - 1)) # Convert Mel to Hz
105 | filter_bin = np.floor((nfft + 1) * hz_points / fs)
106 |
107 | fbank = np.zeros((nfilt, int(np.floor(nfft / 2 + 1))))
108 | for m in range(1, nfilt + 1):
109 |
110 | f_m_minus = int(filter_bin[m - 1]) # left
111 | f_m = int(filter_bin[m]) # center
112 | f_m_plus = int(filter_bin[m + 1]) # right
113 |
114 | for k in range(f_m_minus, f_m):
115 | fbank[m - 1, k] = (k - filter_bin[m - 1]) / (filter_bin[m] - filter_bin[m - 1])
116 | for k in range(f_m, f_m_plus):
117 | fbank[m - 1, k] = (filter_bin[m + 1] - k) / (filter_bin[m + 1] - filter_bin[m])
118 |
119 | # Area Normalization
120 | # If we don't normalize the noise will increase with frequency because of the filter width.
121 | enorm = 2.0 / (hz_points[2:nfilt + 2] - hz_points[:nfilt])
122 | fbank *= enorm[:, np.newaxis]
123 |
124 | filter_banks = np.dot(pow_frames, fbank.T)
125 | filter_banks = np.where(filter_banks == 0, np.finfo(float).eps, filter_banks) # Numerical Stability
126 | filter_banks = 20 * np.log10(filter_banks) # dB
127 |
128 | return filter_banks
129 |
130 |
131 | def autocorr_norm(signal):
132 | """Computes the autocorrelation.
133 |
134 | Implementation details and description in:
135 | https://ccrma.stanford.edu/~orchi/Documents/speaker_recognition_report.pdf
136 |
137 | Parameters
138 | ----------
139 | signal : nd-array
140 |         Input from which the autocorrelation is computed
141 |
142 | Returns
143 | -------
144 | nd-array
145 | Autocorrelation result
146 |
147 | """
148 |
149 | variance = np.var(signal)
150 | signal = np.copy(signal - signal.mean())
151 | r = scipy.signal.correlate(signal, signal)[-len(signal):]
152 |
153 | if (signal == 0).all():
154 | return np.zeros(len(signal))
155 |
156 | acf = r / variance / len(signal)
157 |
158 | return acf
159 |
160 |
161 | def create_symmetric_matrix(acf, order=11):
162 | """Computes a symmetric matrix.
163 |
164 | Implementation details and description in:
165 | https://ccrma.stanford.edu/~orchi/Documents/speaker_recognition_report.pdf
166 |
167 | Parameters
168 | ----------
169 | acf : nd-array
170 | Input from which a symmetric matrix is computed
171 | order : int
172 | Order
173 |
174 | Returns
175 | -------
176 | nd-array
177 | Symmetric Matrix
178 |
179 | """
180 |
181 | smatrix = np.empty((order, order))
182 | xx = np.arange(order)
183 | j = np.tile(xx, order)
184 | i = np.repeat(xx, order)
185 | smatrix[i, j] = acf[np.abs(i - j)]
186 |
187 | return smatrix
188 |
189 |
190 | def lpc(signal, n_coeff=12):
191 | """Computes the linear prediction coefficients.
192 |
193 | Implementation details and description in:
194 | https://ccrma.stanford.edu/~orchi/Documents/speaker_recognition_report.pdf
195 |
196 | Parameters
197 | ----------
198 | signal : nd-array
199 |         Input from which the linear prediction coefficients are computed
200 | n_coeff : int
201 | Number of coefficients
202 |
203 | Returns
204 | -------
205 | nd-array
206 | Linear prediction coefficients
207 |
208 | """
209 |
210 | if signal.ndim > 1:
211 | raise ValueError("Only 1 dimensional arrays are valid")
212 | if n_coeff > signal.size:
213 | raise ValueError("Input signal must have a length >= n_coeff")
214 |
215 | # Calculate the order based on the number of coefficients
216 | order = n_coeff - 1
217 |
218 | # Calculate LPC with Yule-Walker
219 | acf = np.correlate(signal, signal, 'full')
220 |
221 | r = np.zeros(order+1, 'float32')
222 | # Assuring that works for all type of input lengths
223 | nx = np.min([order+1, len(signal)])
224 | r[:nx] = acf[len(signal)-1:len(signal)+order]
225 |
226 | smatrix = create_symmetric_matrix(r[:-1], order)
227 |
228 | if np.sum(smatrix) == 0:
229 | return tuple(np.zeros(order+1))
230 |
231 | lpc_coeffs = np.dot(np.linalg.inv(smatrix), -r[1:])
232 |
233 | return tuple(np.concatenate(([1.], lpc_coeffs)))
234 |
235 |
236 | def create_xx(features):
237 | """Computes the range of features amplitude for the probability density function calculus.
238 |
239 | Parameters
240 | ----------
241 | features : nd-array
242 | Input features
243 |
244 | Returns
245 | -------
246 | nd-array
247 | range of features amplitude
248 |
249 | """
250 |
251 | features_ = np.copy(features)
252 |
253 | if max(features_) < 0:
254 | max_f = - max(features_)
255 | min_f = min(features_)
256 | else:
257 | min_f = min(features_)
258 | max_f = max(features_)
259 |
260 | if min(features_) == max(features_):
261 | xx = np.linspace(min_f, min_f + 10, len(features_))
262 | else:
263 | xx = np.linspace(min_f, max_f, len(features_))
264 |
265 | return xx
266 |
267 |
268 | def kde(features):
269 | """Computes the probability density function of the input signal using a Gaussian KDE (Kernel Density Estimate)
270 |
271 | Parameters
272 | ----------
273 | features : nd-array
274 | Input from which probability density function is computed
275 |
276 | Returns
277 | -------
278 | nd-array
279 | probability density values
280 |
281 | """
282 | features_ = np.copy(features)
283 | xx = create_xx(features_)
284 |
285 | if min(features_) == max(features_):
286 | noise = np.random.randn(len(features_)) * 0.0001
287 | features_ = np.copy(features_ + noise)
288 |
289 | kernel = scipy.stats.gaussian_kde(features_, bw_method='silverman')
290 |
291 | return np.array(kernel(xx) / np.sum(kernel(xx)))
292 |
293 |
294 | def gaussian(features):
295 | """Computes the probability density function of the input signal using a Gaussian function
296 |
297 | Parameters
298 | ----------
299 | features : nd-array
300 | Input from which probability density function is computed
301 | Returns
302 | -------
303 | nd-array
304 | probability density values
305 |
306 | """
307 |
308 | features_ = np.copy(features)
309 |
310 | xx = create_xx(features_)
311 | std_value = np.std(features_)
312 | mean_value = np.mean(features_)
313 |
314 | if std_value == 0:
315 | return 0.0
316 | pdf_gauss = scipy.stats.norm.pdf(xx, mean_value, std_value)
317 |
318 | return np.array(pdf_gauss / np.sum(pdf_gauss))
319 |
320 |
321 | def wavelet(signal, function=scipy.signal.ricker, widths=np.arange(1, 10)):
322 | """Computes CWT (continuous wavelet transform) of the signal.
323 |
324 | Parameters
325 | ----------
326 | signal : nd-array
327 | Input from which CWT is computed
328 | function : wavelet function
329 | Default: scipy.signal.ricker
330 | widths : nd-array
331 | Widths to use for transformation
332 | Default: np.arange(1,10)
333 |
334 | Returns
335 | -------
336 | nd-array
337 | The result of the CWT along the time axis
338 | matrix with size (len(widths),len(signal))
339 |
340 | """
341 |
342 | if isinstance(function, str):
343 | function = eval(function)
344 |
345 | if isinstance(widths, str):
346 | widths = eval(widths)
347 |
348 | cwt = scipy.signal.cwt(signal, function, widths)
349 |
350 | return cwt
351 |
352 |
353 | def calc_ecdf(signal):
354 | """Computes the ECDF of the signal.
355 |
356 | Parameters
357 | ----------
358 | signal : nd-array
359 | Input from which ECDF is computed
360 | Returns
361 | -------
362 | nd-array
363 | Sorted signal and computed ECDF.
364 |
365 | """
366 | return np.sort(signal), np.arange(1, len(signal)+1)/len(signal)
367 |
368 |
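A quick sketch of the FFT and LPC helpers above on a synthetic tone (they are re-exported at package level through the __init__ star imports): with fs = 100 Hz and a 5 Hz sine over 2 s, the FFT magnitude peaks at the 5 Hz bin:

    import numpy as np
    import tsfel

    fs = 100
    x = np.sin(2 * np.pi * 5 * np.arange(0, 2, 1 / fs))
    f, fmag = tsfel.calc_fft(x, fs)
    print(f[np.argmax(fmag)])          # ~5.0 Hz
    print(tsfel.lpc(x, n_coeff=12))    # 12 linear prediction coefficients, starting with 1.0
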
--------------------------------------------------------------------------------
/prompt_bank/stat-prompt/tsfel/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from tsfel.utils.calculate_complexity import *
2 | from tsfel.utils.signal_processing import *
3 | from tsfel.utils.add_personal_features import *
4 | from tsfel.utils.progress_bar import *
5 |
--------------------------------------------------------------------------------
/prompt_bank/stat-prompt/tsfel/utils/add_personal_features.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | import inspect
3 | import json
4 | import os
5 | import sys
6 | import warnings
7 | from inspect import getmembers, isfunction
8 |
9 | from tsfel.feature_extraction.features_settings import load_json
10 | from tsfel.utils.calculate_complexity import compute_complexity
11 |
12 |
13 | def add_feature_json(features_path, json_path):
14 | """Adds new feature to features.json.
15 |
16 | Parameters
17 | ----------
18 | features_path: string
19 | Personal Python module directory containing new features implementation.
20 |
21 | json_path: string
22 | Personal .json file directory containing existing features from TSFEL.
23 |         New customised features will be added to the file in this directory.
24 |
25 | """
26 |
27 | sys.path.append(features_path[:-len(features_path.split(os.sep)[-1]) - 1])
28 | exec("import " + features_path.split(os.sep)[-1][:-3])
29 |
30 | # Reload module containing the new features
31 | importlib.reload(sys.modules[features_path.split(os.sep)[-1][:-3]])
32 | exec("import " + features_path.split(os.sep)[-1][:-3] + " as pymodule")
33 |
34 | # Functions from module containing the new features
35 | functions_list = [o for o in getmembers(locals()['pymodule']) if isfunction(o[1])]
36 | function_names = [fname[0] for fname in functions_list]
37 |
38 | # Check if @set_domain was declared on features module
39 | vset_domain = False
40 |
41 | for fname, f in list(locals()['pymodule'].__dict__.items()):
42 |
43 | if getattr(f, "domain", None) is not None:
44 |
45 | vset_domain = True
46 |
47 | # Access to personal features.json
48 | feat_json = load_json(json_path)
49 |
50 | # Assign domain and tag
51 | domain = getattr(f, "domain", None)
52 | tag = getattr(f, "tag", None)
53 |
54 | # Feature specifications
55 | # Description
56 | if f.__doc__ is not None:
57 | descrip = f.__doc__.split("\n")[0]
58 | else:
59 | descrip = ""
60 | # Feature usage
61 | use = "yes"
62 | # Feature function arguments
63 | args_name = inspect.getfullargspec(f)[0]
64 |
65 | # Access feature parameters
66 | if args_name != "":
67 | # Retrieve default values of arguments
68 | spec = inspect.getfullargspec(f)
69 | defaults = dict(zip(spec.args[::-1], (spec.defaults or ())[::-1]))
70 | defaults.update(spec.kwonlydefaults or {})
71 |
72 | for p in args_name[1:]:
73 | if p not in list(defaults.keys()):
74 | if p == 'fs':
75 | # Assigning a default value for fs if not given
76 | defaults[p] = 100
77 | else:
78 | defaults[p] = None
79 | if len(defaults) == 0:
80 | defaults = ""
81 | else:
82 | defaults = ""
83 |
84 | # Settings of new feature
85 | new_feature = {"description": descrip,
86 | "parameters": defaults,
87 | "function": fname,
88 | "use": use
89 | }
90 |
91 | # Check if domain exists
92 | try:
93 | feat_json[domain][fname] = new_feature
94 | except KeyError:
95 | feat_json[domain] = {fname: new_feature}
96 |
97 | # Insert tag if it is declared
98 | if tag is not None:
99 | feat_json[domain][fname]['tag'] = tag
100 |
101 | # Write new feature on json file
102 | with open(json_path, "w") as fout:
103 | json.dump(feat_json, fout, indent=" ")
104 |
105 | # Calculate feature complexity
106 | compute_complexity(fname, domain, json_path, features_path=features_path)
107 | print('Feature '+str(fname)+' was added.')
108 |
109 | if vset_domain is False:
110 | warnings.warn('No features were added. Please declare @set_domain.', stacklevel=2)
111 |
112 |
113 |
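For orientation, a minimal sketch of the personal feature module that `add_feature_json` consumes is shown below; the module name `my_features.py`, the example feature, and the `set_domain` import path are illustrative assumptions rather than files in this repository:

```python
# my_features.py -- hypothetical personal feature module (name assumed)
import numpy as np

# TSFEL exposes a set_domain decorator for tagging custom features;
# the import path below is an assumption and may differ across versions.
from tsfel.feature_extraction.features_utils import set_domain


@set_domain("domain", "statistical")
def signal_range(signal):
    """Computes the range (max - min) of the signal."""
    return np.max(signal) - np.min(signal)


# Registering the feature into a personal copy of features.json (paths are placeholders):
# from tsfel.utils.add_personal_features import add_feature_json
# add_feature_json("/path/to/my_features.py", "/path/to/features.json")
```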
--------------------------------------------------------------------------------
/prompt_bank/stat-prompt/tsfel/utils/calculate_complexity.py:
--------------------------------------------------------------------------------
1 | import time
2 | import json
3 | import numpy as np
4 | from scipy.optimize import curve_fit
5 | from tsfel.feature_extraction.features_settings import load_json
6 | from tsfel.feature_extraction.calc_features import calc_window_features
7 |
8 |
9 | # curves
10 | def n_squared(x, no):
11 | """The model function"""
12 | return no * x ** 2
13 |
14 |
15 | def n_nlog(x, no):
16 | """The model function"""
17 | return no * x * np.log(x)
18 |
19 |
20 | def n_linear(x, no):
21 | """The model function"""
22 | return no * x
23 |
24 |
25 | def n_log(x, no):
26 | """The model function"""
27 | return no * np.log(x)
28 |
29 |
30 | def n_constant(x, no):
31 | """The model function"""
32 | return np.zeros(len(x)) + no
33 |
34 |
35 | def find_best_curve(t, signal):
36 | """Finds the best curve.
37 |
38 | Parameters
39 | ----------
40 | t : nd-array
41 | Log space
42 | signal : nd-array
43 | Mean execution time array
44 |
45 | Returns
46 | -------
47 | str
48 | Best fit curve name
49 |
50 | """
51 |
52 | all_chisq = []
53 | list_curves = [n_squared, n_nlog, n_linear, n_log, n_constant]
54 | all_curves = []
55 | # Model parameters
56 | stdev = 2
57 | sig = np.zeros(len(signal)) + stdev
58 |
59 | # Fit the curve
60 | for curve in list_curves:
61 | start = 1
62 | popt, pcov = curve_fit(curve, t, signal, sigma=sig, p0=start, absolute_sigma=True)
63 |
64 | # Compute chi square
65 | nexp = curve(t, *popt)
66 | r = signal - nexp
67 | chisq = np.sum((r / stdev) ** 2)
68 | all_chisq.append(chisq)
69 | all_curves.append(nexp)
70 |
71 | idx_best = np.argmin(all_chisq)
72 |
73 | curve_name = str(list_curves[idx_best])
74 | idx1 = curve_name.find("n_")
75 | idx2 = curve_name.find("at")
76 | curve_name = curve_name[idx1 + 2:idx2 - 1]
77 |
78 | return curve_name
79 |
80 |
81 | def compute_complexity(feature, domain, json_path, **kwargs):
82 | """Computes the feature complexity.
83 |
84 | Parameters
85 | ----------
86 | feature : string
87 | Feature name
88 | domain : string
89 | Feature domain
90 | json_path: json
91 | Features json file
92 | \**kwargs:
93 | See below:
94 | * *features_path* (``string``) --
95 | Directory of script with personal features
96 |
97 | Returns
98 | -------
99 | int
100 | Feature complexity
101 |
102 | Writes complexity in json file
103 |
104 | """
105 |
106 | dictionary = load_json(json_path)
107 |
108 | features_path = kwargs.get('features_path', None)
109 |
110 | # The inputs from this function should be replaced by a dictionary
111 | one_feat_dict = {domain: {feature: dictionary[domain][feature]}}
112 |
113 | t = np.logspace(3.0, 5.0, 6)
114 | signal, s = [], []
115 | f = 0.05
116 | x = np.arange(0, t[-1] + 1, 1)
117 | fs = 100
118 | wave = np.sin(2 * np.pi * f * x / fs)
119 |
120 | for ti in t:
121 | for _ in range(20):
122 |
123 | start = time.time()
124 | calc_window_features(one_feat_dict, wave[:int(ti)], fs, features_path=features_path)
125 | end = time.time()
126 |
127 | s += [end - start]
128 |
129 | signal += [np.mean(s)]
130 |
131 | curve_name = find_best_curve(t, signal)
132 | dictionary[domain][feature]['complexity'] = curve_name
133 |
134 | with open(json_path, "w") as write_file:
135 | json.dump(dictionary, write_file, indent=4, sort_keys=True)
136 |
137 | if curve_name == 'constant' or curve_name == 'log':
138 | return 1
139 | elif curve_name == 'linear':
140 | return 2
141 | elif curve_name == 'nlog' or curve_name == 'squared':
142 | return 3
143 | else:
144 | return 0
145 |
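As a quick sanity check of the curve-fitting logic above, `find_best_curve` can be exercised directly on synthetic timings; the runtimes below are fabricated purely for illustration:

```python
import numpy as np

from tsfel.utils.calculate_complexity import find_best_curve

# Fabricated example: mean execution times that grow roughly linearly with input length.
sizes = np.logspace(3.0, 5.0, 6)      # same size grid that compute_complexity uses
timings = 2e-6 * sizes + 1e-4         # made-up runtimes in seconds

# find_best_curve fits n^2, n*log(n), n, log(n) and constant models and returns
# the name of the lowest chi-square fit; here "linear" is the expected answer,
# which compute_complexity would then map to complexity bucket 2.
print(find_best_curve(sizes, timings))
```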
--------------------------------------------------------------------------------
/prompt_bank/stat-prompt/tsfel/utils/progress_bar.py:
--------------------------------------------------------------------------------
1 | from IPython.display import HTML
2 | from IPython import get_ipython
3 |
4 |
5 | def progress_bar_terminal(iteration, total, prefix="", suffix="", decimals=0, length=100, fill="█", printend="\r"):
6 | """Call in a loop to create terminal progress bar.
7 |
8 | Parameters
9 | ----------
10 | iteration: int
11 | current iteration
12 | total: int
13 | total iterations
14 | prefix: str
15 | prefix string
16 | suffix: str
17 | suffix string
18 | decimals: int
19 | positive number of decimals in percent complete
20 | length: int
21 | character length of bar
22 | fill: str
23 | bar fill character
24 | printend: str
25 | end character (e.g. "\r", "\r\n")
26 | """
27 |
28 | percent = ("{0:." + str(decimals) + "f}").format(100 * (iteration / float(total)))
29 | filledlength = int(length * iteration // total)
30 | bar = fill * filledlength + "-" * (length - filledlength)
31 | print("\r%s |%s| %s%% %s" % (prefix, bar, percent, suffix), end=printend)
32 | # Print New Line on Complete
33 | if iteration == total:
34 | print()
35 |
36 |
37 | def progress_bar_notebook(iteration, total=100):
38 | """Progress bar for notebooks.
39 |
40 | Parameters
41 | ----------
42 | iteration: int
43 | current iteration
44 | total: int
45 | total iterations
46 |
47 | Returns
48 | -------
49 | Progress bar for notebooks
50 |
51 | """
52 | result = int((iteration / total) * 100)
53 | return HTML(
54 | """
55 |         <p>
56 |             Progress: {result}% Complete
57 |         <p/>
58 |         <progress
59 |             value='{value}'
60 |             max='{max_value}',
61 |             style='width: 25%',
62 |         >
63 |             {value}
64 |         </progress>
65 |
66 | """.format(
67 | value=iteration, max_value=total, result=result
68 | )
69 | )
70 |
71 |
72 | def display_progress_bar(iteration, total, out):
73 | """Displays progress bar according to python interface.
74 |
75 | Parameters
76 | ----------
77 | iteration: int
78 | current iteration
79 | total: int
80 | total iterations
81 | out: progress bar notebook output
82 |
83 | """
84 |
85 | if (
86 | (get_ipython().__class__.__name__ == "ZMQInteractiveShell")
87 | or (get_ipython().__class__.__name__ == "Shell")
88 | and out is not None
89 | ):
90 | out.update(progress_bar_notebook(iteration + 1, total))
91 | else:
92 | progress_bar_terminal(iteration + 1, total, prefix="Progress:", suffix="Complete", length=50)
93 | return
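A short sketch of driving `display_progress_bar` from a plain terminal loop; the sleep simply stands in for per-window feature extraction, and in a notebook you would pass the display handle as `out` instead of None:

```python
import time

from tsfel.utils.progress_bar import display_progress_bar

n_windows = 20
for i in range(n_windows):
    time.sleep(0.05)                       # stand-in for real per-window work
    display_progress_bar(i, n_windows, out=None)
```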
--------------------------------------------------------------------------------
/prompt_bank/stat-prompt/tsfel/utils/signal_processing.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | from scipy.interpolate import interp1d
4 |
5 |
6 | def signal_window_splitter(signal, window_size, overlap=0):
7 | """Splits the signal into windows
8 | Parameters
9 | ----------
10 | signal : nd-array or pandas DataFrame
11 | input signal
12 | window_size : int
13 | number of points of window size
14 | overlap : float
15 | percentage of overlap, value between 0 and 1 (exclusive)
16 | Default: 0
17 | Returns
18 | -------
19 | list
20 | list of signal windows
21 | """
22 | if not isinstance(window_size, int):
23 | raise SystemExit('window_size must be an integer.')
24 | step = int(round(window_size)) if overlap == 0 else int(round(window_size * (1 - overlap)))
25 | if step == 0:
26 | raise SystemExit('Invalid overlap. '
27 | 'Choose a lower overlap value.')
28 | if len(signal) % window_size == 0 and overlap == 0:
29 | return [signal[i:i + window_size] for i in range(0, len(signal), step)]
30 | else:
31 | return [signal[i:i + window_size] for i in range(0, len(signal) - window_size + 1, step)]
32 |
33 |
34 | def merge_time_series(data, fs_resample, time_unit):
35 | """Time series data interpolation
36 |
37 | Parameters
38 | ----------
39 | data : dict
40 | data to interpolate
41 | fs_resample :
42 | resample sampling frequency
43 | time_unit :
44 | time unit in seconds
45 |
46 | Returns
47 | -------
48 | DataFrame
49 | Interpolated data
50 |
51 | """
52 |
53 | # time interval for interpolation
54 | sensors_time = np.array([[dn.iloc[0, 0], dn.iloc[-1, 0]] for k, dn in data.items()])
55 | t0 = np.max(sensors_time[:, 0])
56 | tn = np.min(sensors_time[:, 1])
57 | x_new = np.linspace(t0, tn, int((tn - t0) / ((1 / fs_resample) * time_unit)))
58 |
59 | # interpolation
60 | data_new = np.copy(x_new.reshape(len(x_new), 1))
61 | header_values = ['time']
62 | for k, dn in data.items():
63 | header_values += [k + str(i) for i in range(1, np.shape(dn)[1])]
64 | data_new = np.hstack((data_new, np.array([interp1d(dn.iloc[:, 0], dn.iloc[:, ax])(x_new) for ax in range(1, np.shape(dn)[1])]).T))
65 |
66 | return pd.DataFrame(data=data_new[:, 1:], columns=header_values[1:])
67 |
68 |
69 | def correlated_features(features, threshold=0.95):
70 |     """Compute pairwise correlation of features using the Pearson method
71 |
72 | Parameters
73 | ----------
74 | features : DataFrame
75 | features
76 | threshold :
77 | correlation value for removing highly correlated features
78 | Returns
79 | -------
80 |     list
81 |         names of highly correlated features to drop
82 |
83 | """
84 | corr_matrix = features.corr().abs()
85 | # Select upper triangle of correlation matrix
86 | upper = corr_matrix.where(np.triu(np.ones(corr_matrix.shape), k=1).astype(bool))
87 |     # Find column names of features with correlation greater than the threshold
88 | to_drop = [column for column in upper.columns if any(upper[column] > threshold)]
89 |
90 | return to_drop
91 |
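To make the windowing step and the correlation filter above concrete, here is a small sketch on synthetic data (the signal and feature values are made up for illustration):

```python
import numpy as np
import pandas as pd

from tsfel.utils.signal_processing import correlated_features, signal_window_splitter

# 1000-sample signal split into 100-sample windows with 50% overlap:
# step = round(100 * (1 - 0.5)) = 50, so windows start at 0, 50, ..., 900 (19 windows).
signal = np.sin(np.linspace(0, 20 * np.pi, 1000))
windows = signal_window_splitter(signal, window_size=100, overlap=0.5)
print(len(windows))

# correlated_features flags one column out of each highly correlated pair.
feats = pd.DataFrame({
    "a": signal,
    "b": 2.0 * signal + 0.01,                          # linear in "a", |corr| ~ 1
    "c": np.random.default_rng(0).normal(size=1000),   # unrelated noise
})
print(correlated_features(feats, threshold=0.95))      # expected: ['b']
```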
--------------------------------------------------------------------------------
/prompt_bank/text_prompt_data_csv/csv_prompt.json:
--------------------------------------------------------------------------------
1 | {
2 | "0": "The Electricity Transformer Temperature (ETT) is a crucial indicator in the electric power long-term deployment. This dataset consists of 2 years data from two separated counties in China. To explore the granularity on the Long sequence time-series forecasting (LSTF) problem, different subsets are created, {ETTh1, ETTh2} for 1-hour-level and ETTm1 for 15-minutes-level. Each data point consists of the target value ”oil temperature” and 6 power load features. The train/val/test is 12/4/4 months.",
3 | "1": "The Electricity Transformer Temperature (ETT) is a crucial indicator in the electric power long-term deployment. This dataset consists of 2 years data from two separated counties in China. To explore the granularity on the Long sequence time-series forecasting (LSTF) problem, different subsets are created, {ETTh1, ETTh2} for 1-hour-level and ETTm1 for 15-minutes-level. Each data point consists of the target value ”oil temperature” and 6 power load features. The train/val/test is 12/4/4 months.",
4 | "2": "The Electricity Transformer Temperature (ETT) is a crucial indicator in the electric power long-term deployment. This dataset consists of 2 years data from two separated counties in China. To explore the granularity on the Long sequence time-series forecasting (LSTF) problem, different subsets are created, {ETTh1, ETTh2} for 1-hour-level and ETTm1 for 15-minutes-level. Each data point consists of the target value ”oil temperature” and 6 power load features. The train/val/test is 12/4/4 months.",
5 | "3": "The Electricity Transformer Temperature (ETT) is a crucial indicator in the electric power long-term deployment. This dataset consists of 2 years data from two separated counties in China. To explore the granularity on the Long sequence time-series forecasting (LSTF) problem, different subsets are created, {ETTh1, ETTh2} for 1-hour-level and ETTm1 for 15-minutes-level. Each data point consists of the target value ”oil temperature” and 6 power load features. The train/val/test is 12/4/4 months.",
6 |     "4": "Electricity contains the electricity consumption of 321 clients from 2012 to 2014. The data was converted to reflect hourly consumption.",
7 | "5": "Exchange rate is a collection of the daily exchange rates of eight foreign countries ranging from 1990 to 2016.",
8 | "6": "Traffic is a collection of hourly data from California Department of Transportation, which describes the road occupancy rates measured by different sensors on San Francisco Bay area freeways.",
9 | "7": "Weather is recorded every 10 minutes for the 2020 whole year, which contains 21 meteorological indicators, such as air temperature, humidity, etc."
10 | }
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate==0.18.0
2 | charset-normalizer==3.1.0
3 | cmake==3.26.3
4 | contourpy==1.0.7
5 | cycler==0.11.0
6 | einops==0.6.0
7 | filelock==3.12.0
8 | fonttools==4.39.3
9 | huggingface-hub==0.13.4
10 | idna==3.4
11 | importlib-resources==5.12.0
12 | Jinja2==3.1.2
13 | joblib==1.2.0
14 | kiwisolver==1.4.4
15 | lit==16.0.1
16 | MarkupSafe==2.1.2
17 | matplotlib==3.7.1
18 | mpmath==1.3.0
19 | networkx==3.1
20 | numpy==1.24.2
21 | nvidia-cublas-cu11==11.10.3.66
22 | nvidia-cuda-cupti-cu11==11.7.101
23 | nvidia-cuda-nvrtc-cu11==11.7.99
24 | nvidia-cuda-runtime-cu11==11.7.99
25 | nvidia-cudnn-cu11==8.5.0.96
26 | nvidia-cufft-cu11==10.9.0.58
27 | nvidia-curand-cu11==10.2.10.91
28 | nvidia-cusolver-cu11==11.4.0.1
29 | nvidia-cusparse-cu11==11.7.4.91
30 | nvidia-nccl-cu11==2.14.3
31 | nvidia-nvtx-cu11==11.7.91
32 | packaging==23.1
33 | pandas==2.0.0
34 | Pillow==9.5.0
35 | pyparsing==3.0.9
36 | python-dateutil==2.8.2
37 | pytz==2023.3
38 | PyYAML==6.0
39 | regex==2023.3.23
40 | requests==2.28.2
41 | scikit-learn==1.2.2
42 | scipy==1.10.1
43 | six==1.16.0
44 | sympy==1.11.1
45 | threadpoolctl==3.1.0
46 | tokenizers==0.13.3
47 | torch==2.0.0
48 | tqdm==4.65.0
49 | transformers==4.28.1
50 | triton==2.0.0
51 | typing_extensions==4.5.0
52 | tzdata==2023.3
53 | urllib3==1.26.15
54 | zipp==3.15.0
55 |
--------------------------------------------------------------------------------
/scripts/test_csv_lora.sh:
--------------------------------------------------------------------------------
1 | TRAIN="datasets/ETT-small/ETTh1.csv
2 | datasets/ETT-small/ETTh2.csv
3 | datasets/ETT-small/ETTm1.csv
4 | datasets/ETT-small/ETTm2.csv
5 | datasets/electricity/electricity.csv
6 | datasets/exchange_rate/exchange_rate.csv
7 | datasets/traffic/traffic.csv
8 | datasets/weather/weather.csv"
9 |
10 | TEST="datasets/ETT-small/ETTh1.csv
11 | datasets/ETT-small/ETTh2.csv
12 | datasets/ETT-small/ETTm1.csv
13 | datasets/ETT-small/ETTm2.csv
14 | datasets/electricity/electricity.csv
15 | datasets/exchange_rate/exchange_rate.csv
16 | datasets/traffic/traffic.csv
17 | datasets/weather/weather.csv"
18 |
19 | PROMPT="prompt_bank/prompt_data_normalize_csv_split"
20 |
21 | epoch=500
22 | downsample_rate=20
23 | freeze=0
24 |
25 | for pred_len in 96 192 336 720
26 | do
27 | for lr in 1e-3
28 | do
29 | for lora_dim in 32 64
30 | do
31 | OUTPUT_PATH="output/test_ltsm_lr${lr}_lora${lora_dim}_down${downsample_rate}_freeze${freeze}_e${epoch}_pred${pred_len}/"
32 | CUDA_VISIBLE_DEVICES=1,2,3,4,5,6,7 python3 main_ltsm.py \
33 | --lora \
34 | --lora_dim ${lora_dim} \
35 | --model_id test_run \
36 | --train_epochs ${epoch} \
37 | --batch_size 800 \
38 | --pred_len ${pred_len} \
39 | --gradient_accumulation_steps 64 \
40 | --data_path ${TRAIN} \
41 |     --test_data_path_list ${TEST} \
42 |     --prompt_data_path ${PROMPT} \
43 |     --freeze ${freeze} \
44 |     --learning_rate ${lr} \
45 |     --downsample_rate ${downsample_rate} \
46 |     --output_dir ${OUTPUT_PATH}
47 | done
48 | done
49 | done
50 |
--------------------------------------------------------------------------------
/scripts/test_ltsm.sh:
--------------------------------------------------------------------------------
1 | TRAIN="
2 | all_six_datasets/ETT-small/ETTh1.csv
3 | all_six_datasets/ETT-small/ETTh2.csv
4 | all_six_datasets/ETT-small/ETTm1.csv
5 | all_six_datasets/ETT-small/ETTm2.csv
6 | all_six_datasets/electricity/electricity.csv
7 | all_six_datasets/exchange_rate/exchange_rate.csv
8 | all_six_datasets/traffic/traffic.csv
9 | all_six_datasets/weather/weather.csv"
10 |
11 |
12 | TEST="
13 | all_six_datasets/ETT-small/ETTh1.csv
14 | all_six_datasets/ETT-small/ETTh2.csv
15 | all_six_datasets/ETT-small/ETTm1.csv
16 | all_six_datasets/ETT-small/ETTm2.csv
17 | all_six_datasets/electricity/electricity.csv
18 | all_six_datasets/exchange_rate/exchange_rate.csv
19 | all_six_datasets/traffic/traffic.csv
20 | all_six_datasets/weather/weather.csv"
21 |
22 | PROMPT="prompt_bank/prompt_data_normalize_csv_split"
23 | epoch=500
24 | downsample_rate=20
25 | freeze=0
26 | lr=1e-3
27 |
28 |
29 | for pred_len in 96
30 | do
31 |
32 | CUDA_VISIBLE_DEVICES=0,1 python3 main_ltsm.py \
33 | --model LTSM \
34 | --model_name_or_path gpt2-medium \
35 | --local_pretrain LSC2204/LTSM-bundle \
36 | --train_epochs ${epoch} \
37 | --batch_size 800 \
38 | --pred_len ${pred_len} \
39 | --gradient_accumulation_steps 64 \
40 | --data_path ${TRAIN} \
41 | --test_data_path_list ${TEST} \
42 | --prompt_data_path ${PROMPT} \
43 | --freeze ${freeze} \
44 | --learning_rate ${lr} \
45 | --downsample_rate ${downsample_rate} \
46 | --output_dir "output/ltsm_csv_medium_lr${lr}_loraFalse_down${downsample_rate}_freeze${freeze}_e${epoch}_pred${pred_len}/"\
47 | --eval 1
48 | done
49 |
--------------------------------------------------------------------------------
/scripts/train_ltsm_csv.sh:
--------------------------------------------------------------------------------
1 | TRAIN="datasets/ETT-small/ETTh1.csv
2 | datasets/ETT-small/ETTh2.csv
3 | datasets/ETT-small/ETTm1.csv
4 | datasets/ETT-small/ETTm2.csv
5 | datasets/electricity/electricity.csv
6 | datasets/exchange_rate/exchange_rate.csv
7 | datasets/traffic/traffic.csv
8 | datasets/weather/weather.csv"
9 |
10 |
11 | TEST="datasets/ETT-small/ETTh1.csv
12 | datasets/ETT-small/ETTh2.csv
13 | datasets/ETT-small/ETTm1.csv
14 | datasets/ETT-small/ETTm2.csv
15 | datasets/electricity/electricity.csv
16 | datasets/exchange_rate/exchange_rate.csv
17 | datasets/traffic/traffic.csv
18 | datasets/weather/weather.csv"
19 |
20 | PROMPT="prompt_bank/prompt_data_normalize_split"
21 |
22 | epoch=1000
23 | downsample_rate=20
24 | freeze=0
25 | lr=1e-3
26 |
27 |
28 | for pred_len in 96 192 336 720
29 | do
30 | OUTPUT_PATH="output/ltsm_lr${lr}_loraFalse_down${downsample_rate}_freeze${freeze}_e${epoch}_pred${pred_len}/"
31 | echo "Current OUTPUT_PATH: ${OUTPUT_PATH}"
32 | CUDA_VISIBLE_DEVICES=0,1,2,3 python3 main_ltsm.py \
33 | --model LTSM \
34 | --model_name_or_path gpt2-medium \
35 | --train_epochs ${epoch} \
36 | --batch_size 100 \
37 | --pred_len ${pred_len} \
38 | --gradient_accumulation_steps 64 \
39 | --data_path ${TRAIN} \
40 | --test_data_path_list ${TEST} \
41 | --prompt_data_path ${PROMPT} \
42 | --freeze ${freeze} \
43 | --learning_rate ${lr} \
44 | --downsample_rate ${downsample_rate} \
45 | --output_dir ${OUTPUT_PATH}\
46 | --eval 0
47 | done
48 |
--------------------------------------------------------------------------------
/scripts/train_ltsm_textprompt_csv.sh:
--------------------------------------------------------------------------------
1 | TRAIN="datasets/ETT-small/ETTh1.csv
2 | datasets/ETT-small/ETTh2.csv
3 | datasets/ETT-small/ETTm1.csv
4 | datasets/ETT-small/ETTm2.csv
5 | datasets/electricity/electricity.csv
6 | datasets/exchange_rate/exchange_rate.csv
7 | datasets/traffic/traffic.csv
8 | datasets/weather/weather.csv"
9 |
10 | TEST="datasets/ETT-small/ETTh1.csv
11 | datasets/ETT-small/ETTh2.csv
12 | datasets/ETT-small/ETTm1.csv
13 | datasets/ETT-small/ETTm2.csv
14 | datasets/electricity/electricity.csv
15 | datasets/exchange_rate/exchange_rate.csv
16 | datasets/traffic/traffic.csv
17 | datasets/weather/weather.csv"
18 |
19 | PROMPT="prompt_bank/text_prompt_data_csv/csv_prompt.json"
20 | epoch=1000
21 | downsample_rate=20
22 | freeze=0
23 | lr=1e-3
24 |
25 |
26 | for pred_len in 96 192 336 720
27 | do
28 | OUTPUT_PATH="output/ltsm_textprompt_lr${lr}_loraFalse_down${downsample_rate}_freeze${freeze}_e${epoch}_pred${pred_len}/"
29 | CUDA_VISIBLE_DEVICES=0,1,2,3 python3 main_ltsm.py \
30 | --model LTSM_WordPrompt \
31 | --model_name_or_path gpt2-medium \
32 | --train_epochs ${epoch} \
33 | --batch_size 10 \
34 | --pred_len ${pred_len} \
35 | --gradient_accumulation_steps 64 \
36 | --data_path ${TRAIN} \
37 | --test_data_path_list ${TEST} \
38 | --prompt_data_path ${PROMPT} \
39 | --freeze ${freeze} \
40 | --learning_rate ${lr} \
41 | --downsample_rate ${downsample_rate} \
42 | --output_dir ${OUTPUT_PATH} \
43 | --eval 0
44 | done
45 |
--------------------------------------------------------------------------------
/scripts/train_ltsm_tokenizer_csv.sh:
--------------------------------------------------------------------------------
1 | TRAIN="datasets/ETT-small/ETTh1.csv
2 | datasets/ETT-small/ETTh2.csv
3 | datasets/ETT-small/ETTm1.csv
4 | datasets/ETT-small/ETTm2.csv
5 | datasets/electricity/electricity.csv
6 | datasets/exchange_rate/exchange_rate.csv
7 | datasets/traffic/traffic.csv
8 | datasets/weather/weather.csv"
9 |
10 | TEST="datasets/ETT-small/ETTh1.csv
11 | datasets/ETT-small/ETTh2.csv
12 | datasets/ETT-small/ETTm1.csv
13 | datasets/ETT-small/ETTm2.csv
14 | datasets/electricity/electricity.csv
15 | datasets/exchange_rate/exchange_rate.csv
16 | datasets/traffic/traffic.csv
17 | datasets/weather/weather.csv"
18 | PROMPT="prompt_bank/prompt_data_normalize_csv_split"
19 | lr=1e-3
20 | epoch=50
21 | downsample_rate=20
22 | freeze=0
23 | d_ff=128
24 |
25 | for pred_len in 96
26 | do
27 | OUTPUT_PATH="output/ltsm_tokenizer_lr${lr}_loraFalse_down${downsample_rate}_freeze${freeze}_e${epoch}_pred${pred_len}/"
28 | CUDA_VISIBLE_DEVICES=0,1 python3 main_tokenizer.py \
29 | --model LTSM_Tokenizer \
30 | --model_name_or_path gpt2-medium \
31 | --d_ff $d_ff \
32 | --train_epochs ${epoch} \
33 | --batch_size 20 \
34 | --pred_len ${pred_len} \
35 | --gradient_accumulation_steps 64 \
36 | --data_path ${TRAIN} \
37 | --test_data_path_list ${TEST} \
38 | --prompt_data_path ${PROMPT} \
39 | --freeze ${freeze} \
40 | --learning_rate ${lr} \
41 | --downsample_rate ${downsample_rate} \
42 | --output_dir ${OUTPUT_PATH}\
43 | --eval 1
44 | done
45 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import setuptools
2 |
3 | setuptools.setup(
4 | name="ltsm",
5 | version='1.0.0',
6 | author="Data Lab",
7 | author_email="daochen.zha@rice.edu",
8 |     description="Large Time Series Model",
9 | url="XXXX",
10 | keywords=["Time Series"],
11 | packages=setuptools.find_packages(exclude=('tests',)),
12 |     python_requires='>=3.8',
13 | classifiers=[
14 | "Programming Language :: Python :: 3.8",
15 | "License :: OSI Approved :: MIT License",
16 | "Operating System :: OS Independent",
17 | ],
18 | )
19 |
--------------------------------------------------------------------------------
/tutorial/README.md:
--------------------------------------------------------------------------------
1 | # Tutorial of LTSM-bundle
2 |
3 |
4 | ## Installation
5 | ```
6 | conda create -n ltsm python=3.8.0
7 | conda activate ltsm
8 | git clone git@github.com:daochenzha/ltsm.git
9 | cd ltsm
10 | pip3 install -e .
11 | pip3 install -r requirements.txt
12 | ```
13 |
14 |
15 | ## :bookmark: Step 0: Collect Datasets and Time Series Prompts
16 |
17 | ### :cyclone: You can use our prepared datasets to on-board yourself onto LTSM-bundle
18 |
19 | ### Download training datasets
20 | ```bash
21 | cd datasets
22 | # download the training datasets from: https://drive.google.com/drive/folders/1hLFbz0FRxdiDCzgFYtKCOPJYSBVvwW9P
23 | ```
24 |
25 | ### Download time series prompts
26 | ```bash
27 | cd prompt_bank/propmt_data_csv
28 | # download the time series prompts from: https://drive.google.com/drive/folders/1hLFbz0FRxdiDCzgFYtKCOPJYSBVvwW9P
29 | ```
30 |
31 | ### Check word prompts
32 | ```bash
33 | cd prompt_bank/text_prompt_data_csv/
34 | # check csv_prompt.json
35 | ```
36 |
37 | ## :bookmark: Step 1: Customize Datasets and Time Series Prompts
38 |
39 | ### :cyclone: If you prefer to build LTSM-bundle on your own dataset, please follow the 5-step instructions below:
40 |
41 | **Step 1-a.** Prepare your dataset. Make sure your local data folder looks like this:
42 | ````angular2html
43 | - ltsm/
44 | - datasets/
45 |         DATA_1.csv
46 |         DATA_2.csv
47 | ...
48 | ````
49 |
50 | **Step 1-b.** Generate the time series prompts from the training, validation, and testing datasets:
51 | ````angular2html
52 | python3 prompt_generate_split.py
53 | ````
54 |
55 | **Step 1-c.** Find the generated time series prompts in the './prompt_data_split' folder. Then run the following command to normalize the prompts:
56 | ````angular2html
57 | python3 prompt_normalization_split.py --mode fit
58 | ````
59 |
60 | **Step 1-d.** Run this command to export the prompts to the "./prompt_data_normalize_split" folder:
61 | ````angular2html
62 | python3 prompt_normalization_split.py --mode transform
63 | ````
64 |
65 | **Step 1-e.** Modify the word prompt based on your dataset description in "prompt_bank/text_prompt_data_csv/csv_prompt.json":
66 | ````angular2html
67 | vim prompt_bank/text_prompt_data_csv/csv_prompt.json
68 | ````
69 |
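If you prefer to extend the word prompts programmatically instead of editing the file in vim, the sketch below appends a new entry; the description string is a made-up example, and you should confirm how the numeric keys map to the order of your `--data_path` entries before relying on it:

```python
import json

prompt_path = "prompt_bank/text_prompt_data_csv/csv_prompt.json"

with open(prompt_path) as f:
    prompts = json.load(f)

# Keys in csv_prompt.json are stringified dataset indices ("0", "1", ...).
prompts[str(len(prompts))] = "MyData is recorded hourly for one year and contains 12 sensor channels."

with open(prompt_path, "w") as f:
    json.dump(prompts, f, indent=4)
```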
70 | ## :bookmark: Step 2: Customize your own LTSM-bundle
71 |
72 | ### :cyclone: Now, it's time to build your own LTSM-bundle!
73 |
74 | #### Option-(1) Explore [Word Prompt] and [Linear Tokenization] on gpt2-medium
75 | ```bash
76 | python3 main_ltsm.py \
77 | --model LTSM_WordPrompt \
78 | --model_name_or_path gpt2-medium \
79 | --train_epochs 500 \
80 | --batch_size 10 \
81 | --pred_len 96 \
82 | --data_path "datasets/ETT-small/ETTh1.csv" \
83 | --test_data_path_list "datasets/ETT-small/ETTh1.csv" \
84 | --prompt_data_path "prompt_bank/text_prompt_data_csv/csv_prompt.json" \
85 | --freeze 0 \
86 | --learning_rate 1e-3 \
87 | --downsample_rate 20 \
88 |     --output_dir [Your_Output_Path]
89 | ```
90 |
91 | #### Option-(2) Explore [Time Series Prompt] and [Linear Tokenization] on gpt2-medium
92 | ```bash
93 | python3 main_ltsm.py \
94 | --model LTSM \
95 | --model_name_or_path gpt2-medium \
96 | --train_epochs 500 \
97 | --batch_size 10 \
98 | --pred_len 96 \
99 | --data_path "datasets/ETT-small/ETTh1.csv" \
100 | --test_data_path_list "datasets/ETT-small/ETTh1.csv" \
101 | --prompt_data_path "prompt_bank/prompt_data_normalize_split" \
102 | --freeze 0 \
103 | --learning_rate 1e-3 \
104 | --downsample_rate 20 \
105 |     --output_dir [Your_Output_Path]
106 | ```
107 |
108 | #### Option-(3) Finetune the pre-trained LTSM-bundle model on your dataset: [Time Series Prompt] and [Linear Tokenization] on gpt2-medium (the released weight targets pred_len == 96)
109 | ```bash
110 | python3 main_ltsm.py \
111 | --model LTSM \
112 | --model_name_or_path gpt2-medium \
113 |     --local_pretrain LSC2204/LTSM-bundle \
114 | --train_epochs 500 \
115 | --batch_size 10 \
116 | --pred_len 96 \
117 | --data_path "datasets/ETT-small/ETTh1.csv" \
118 | --test_data_path_list "datasets/ETT-small/ETTh1.csv" \
119 | --prompt_data_path "prompt_bank/prompt_data_normalize_split" \
120 | --freeze 0 \
121 | --learning_rate 1e-3 \
122 | --downsample_rate 20 \
123 |     --output_dir [Your_Output_Path]
124 | ```
125 |
126 |
--------------------------------------------------------------------------------