├── Figure_1.pdf ├── Related_Work_A.1.pdf ├── data ├── ist_cnt2category2cnt.npz ├── jkt_cnt2category2cnt.npz ├── nyc_cnt2category2cnt.npz └── tky_cnt2category2cnt.npz ├── model ├── base.py ├── configuration_phi.py ├── layers.py ├── positional_encoding.py ├── bert.py ├── utils.py ├── llm.py ├── MobilityLLM.py └── phi_model.py ├── config ├── MobilityLLM_wee_POI.conf ├── MobilityLLM_wee_TPP.conf ├── MobilityLLM_wee_TUL.conf ├── MobilityLLM_bkc_POI.conf ├── MobilityLLM_bkc_TPP.conf ├── MobilityLLM_bkc_TUL.conf ├── MobilityLLM_gow_POI.conf ├── MobilityLLM_gow_TPP.conf ├── MobilityLLM_gow_TUL.conf ├── MobilityLLM_tky_POI.conf ├── MobilityLLM_tky_TPP.conf └── MobilityLLM_tky_TUL.conf ├── README.md ├── utils.py ├── preprocess └── load_data.py └── train_MobilityLLM.py /Figure_1.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LetianGong/Mobility-LLM/HEAD/Figure_1.pdf -------------------------------------------------------------------------------- /Related_Work_A.1.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LetianGong/Mobility-LLM/HEAD/Related_Work_A.1.pdf -------------------------------------------------------------------------------- /data/ist_cnt2category2cnt.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LetianGong/Mobility-LLM/HEAD/data/ist_cnt2category2cnt.npz -------------------------------------------------------------------------------- /data/jkt_cnt2category2cnt.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LetianGong/Mobility-LLM/HEAD/data/jkt_cnt2category2cnt.npz -------------------------------------------------------------------------------- /data/nyc_cnt2category2cnt.npz: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/LetianGong/Mobility-LLM/HEAD/data/nyc_cnt2category2cnt.npz -------------------------------------------------------------------------------- /data/tky_cnt2category2cnt.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LetianGong/Mobility-LLM/HEAD/data/tky_cnt2category2cnt.npz -------------------------------------------------------------------------------- /model/base.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | from torch import nn 5 | import torch.nn.functional as F 6 | import numpy as np 7 | from einops import repeat, rearrange 8 | # NOTE(review): math, F, np, repeat and rearrange are not used anywhere in this module. 9 | 10 | class Encoder(nn.Module):  # minimal base class: only records a name 11 | def __init__(self, name): 12 | super().__init__() 13 | 14 | self.name = name 15 | 16 | 17 | class Decoder(nn.Module):  # minimal base class: only records a name 18 | def __init__(self, name): 19 | super().__init__() 20 | 21 | # self.denormalizer = denormalizer 22 | self.name = name 23 | 24 | 25 | class Denoiser(nn.Module):  # minimal base class: only records a name 26 | def __init__(self, name): 27 | super().__init__() 28 | self.name = name 29 | 30 | 31 | class GNN(nn.Module):  # base class that stores the given graph object and a name as attributes 32 | def __init__(self, graph, name): 33 | super().__init__() 34 | 35 | self.graph = graph 36 | 37 | self.name = name 38 | -------------------------------------------------------------------------------- /config/MobilityLLM_wee_POI.conf: -------------------------------------------------------------------------------- 1 | [Data] 2 | dataset_name = www_WEE 3 | country_name = US 4 | max_his_period_days = 30 5 | max_merge_seconds_limit = 3600 6 | max_delta_mins = 1440 7 | min_session_mins = 120 8 | least_disuser_count = 10 9 | least_checkins_count = 10 10 | 11 | latN = 50 12 | lngN = 40 13 | split_save = 1 14 | 15 | [Training] 16 | use_nni = 0 17 | mode = train 18 | ctx = 0 19 | regularization = 1e-5 20 | learning_rate = 1e-3 21 | max_epochs = 100 22 | display_step = 1 23 | patience = 6 24 | train_batch = 8 25 | val_batch = 8 26 |
test_batch = 8 27 | batch_size = 8 28 | save_results = 0 29 | 30 | 31 | [Model] 32 | loc_emb_size = 256 33 | geohash_size = 128 34 | category_size = 128 35 | tim_emb_size = 256 36 | user_emb_size = 256 37 | hidden_size = 256 38 | loc_noise_mean = 0 39 | loc_noise_sigma = 0.01 40 | tim_noise_mean = 0 41 | tim_noise_sigma = 0.01 42 | user_noise_mean = 0 43 | user_noise_sigma = 0.01 44 | tau = 4 45 | pos_eps = 0.5 46 | neg_eps = 0.5 47 | dropout_rate_1 = 0.5 48 | dropout_rate_2 = 0.5 49 | adv = 1 50 | self_weight = 0.05 51 | self_weight_s = 0.05 52 | self_weight_t = 0.05 53 | self_weight_st = 0.05 54 | k = 8 55 | momentum = 0.95 56 | theta = 0.18 57 | temperature = 0.1 58 | rnn_type = GRU 59 | num_layers = 3 60 | downstream = POI 61 | dump_path = checkpoints 62 | rank = 0 63 | queue_length = 1024 64 | world_size = -1 65 | epoch_queue_starts = 0 66 | crops_for_assign = 01 67 | feat_dim = 256 68 | loss = mae 69 | tpp = linear 70 | epsilon = 0.05 71 | dropout_spatial = 0.3 72 | learnable_param_size = 5 -------------------------------------------------------------------------------- /config/MobilityLLM_wee_TPP.conf: -------------------------------------------------------------------------------- 1 | [Data] 2 | dataset_name = www_WEE 3 | country_name = US 4 | max_his_period_days = 30 5 | max_merge_seconds_limit = 3600 6 | max_delta_mins = 1440 7 | min_session_mins = 120 8 | least_disuser_count = 10 9 | least_checkins_count = 10 10 | 11 | latN = 50 12 | lngN = 40 13 | split_save = 1 14 | 15 | [Training] 16 | use_nni = 0 17 | mode = train 18 | ctx = 0 19 | regularization = 1e-5 20 | learning_rate = 1e-3 21 | max_epochs = 100 22 | display_step = 1 23 | patience = 6 24 | train_batch = 8 25 | val_batch = 8 26 | test_batch = 8 27 | batch_size = 8 28 | save_results = 0 29 | 30 | 31 | [Model] 32 | loc_emb_size = 256 33 | geohash_size = 128 34 | category_size = 128 35 | tim_emb_size = 256 36 | user_emb_size = 256 37 | hidden_size = 256 38 | loc_noise_mean = 0 39 | loc_noise_sigma = 
0.01 40 | tim_noise_mean = 0 41 | tim_noise_sigma = 0.01 42 | user_noise_mean = 0 43 | user_noise_sigma = 0.01 44 | tau = 4 45 | pos_eps = 0.5 46 | neg_eps = 0.5 47 | dropout_rate_1 = 0.5 48 | dropout_rate_2 = 0.5 49 | adv = 1 50 | self_weight = 0.05 51 | self_weight_s = 0.05 52 | self_weight_t = 0.05 53 | self_weight_st = 0.05 54 | k = 8 55 | momentum = 0.95 56 | theta = 0.18 57 | temperature = 0.1 58 | rnn_type = GRU 59 | num_layers = 3 60 | downstream = TPP 61 | dump_path = checkpoints 62 | rank = 0 63 | queue_length = 1024 64 | world_size = -1 65 | epoch_queue_starts = 0 66 | crops_for_assign = 01 67 | feat_dim = 256 68 | loss = mae 69 | tpp = linear 70 | epsilon = 0.05 71 | dropout_spatial = 0.3 72 | learnable_param_size = 3 -------------------------------------------------------------------------------- /config/MobilityLLM_wee_TUL.conf: -------------------------------------------------------------------------------- 1 | [Data] 2 | dataset_name = www_WEE 3 | country_name = US 4 | max_his_period_days = 30 5 | max_merge_seconds_limit = 3600 6 | max_delta_mins = 1440 7 | min_session_mins = 120 8 | least_disuser_count = 10 9 | least_checkins_count = 10 10 | 11 | latN = 50 12 | lngN = 40 13 | split_save = 1 14 | 15 | [Training] 16 | use_nni = 0 17 | mode = train 18 | ctx = 0 19 | regularization = 1e-5 20 | learning_rate = 1e-3 21 | max_epochs = 100 22 | display_step = 1 23 | patience = 6 24 | train_batch = 8 25 | val_batch = 8 26 | test_batch = 8 27 | batch_size = 8 28 | save_results = 0 29 | 30 | 31 | [Model] 32 | loc_emb_size = 256 33 | geohash_size = 128 34 | category_size = 128 35 | tim_emb_size = 256 36 | user_emb_size = 256 37 | hidden_size = 256 38 | loc_noise_mean = 0 39 | loc_noise_sigma = 0.01 40 | tim_noise_mean = 0 41 | tim_noise_sigma = 0.01 42 | user_noise_mean = 0 43 | user_noise_sigma = 0.01 44 | tau = 4 45 | pos_eps = 0.5 46 | neg_eps = 0.5 47 | dropout_rate_1 = 0.5 48 | dropout_rate_2 = 0.5 49 | adv = 1 50 | self_weight = 0.05 51 | self_weight_s = 
0.05 52 | self_weight_t = 0.05 53 | self_weight_st = 0.05 54 | k = 8 55 | momentum = 0.95 56 | theta = 0.18 57 | temperature = 0.1 58 | rnn_type = GRU 59 | num_layers = 3 60 | downstream = TUL 61 | dump_path = checkpoints 62 | rank = 0 63 | queue_length = 1024 64 | world_size = -1 65 | epoch_queue_starts = 0 66 | crops_for_assign = 01 67 | feat_dim = 256 68 | loss = mae 69 | tpp = linear 70 | epsilon = 0.05 71 | dropout_spatial = 0.3 72 | learnable_param_size = 5 -------------------------------------------------------------------------------- /config/MobilityLLM_bkc_POI.conf: -------------------------------------------------------------------------------- 1 | [Data] 2 | dataset_name = www_BKC 3 | country_name = US 4 | max_his_period_days = 120 5 | max_merge_seconds_limit = 10800 6 | max_delta_mins = 1440 7 | min_session_mins = 1440 8 | least_disuser_count = 10 9 | least_checkins_count = 10 10 | 11 | latN = 50 12 | lngN = 40 13 | split_save = 1 14 | 15 | [Training] 16 | use_nni = 0 17 | mode = train 18 | ctx = 0 19 | regularization = 1e-5 20 | learning_rate = 1e-3 21 | max_epochs = 100 22 | # max_epochs = 10 23 | display_step = 1 24 | patience = 6 25 | train_batch = 8 26 | val_batch = 8 27 | test_batch = 8 28 | batch_size = 8 29 | save_results = 0 30 | 31 | 32 | [Model] 33 | loc_emb_size = 256 34 | geohash_size = 128 35 | category_size = 128 36 | tim_emb_size = 256 37 | user_emb_size = 256 38 | hidden_size = 256 39 | loc_noise_mean = 0 40 | loc_noise_sigma = 0.01 41 | tim_noise_mean = 0 42 | tim_noise_sigma = 0.01 43 | user_noise_mean = 0 44 | user_noise_sigma = 0.01 45 | tau = 4 46 | pos_eps = 0.5 47 | neg_eps = 0.5 48 | dropout_rate_1 = 0.5 49 | dropout_rate_2 = 0.5 50 | adv = 1 51 | self_weight = 0.05 52 | self_weight_s = 0.05 53 | self_weight_t = 0.05 54 | self_weight_st = 0.05 55 | k = 8 56 | momentum = 0.95 57 | theta = 0.18 58 | temperature = 0.1 59 | rnn_type = GRU 60 | num_layers = 3 61 | downstream = POI 62 | # downstream = TPP 63 | dump_path = checkpoints 
64 | rank = 0 65 | queue_length = 1024 66 | world_size = -1 67 | epoch_queue_starts = 0 68 | crops_for_assign = 01 69 | feat_dim = 256 70 | loss = mae 71 | tpp = linear 72 | epsilon = 0.05 73 | dropout_spatial = 0.3 74 | learnable_param_size = 1 -------------------------------------------------------------------------------- /config/MobilityLLM_bkc_TPP.conf: -------------------------------------------------------------------------------- 1 | [Data] 2 | dataset_name = www_BKC 3 | country_name = US 4 | max_his_period_days = 120 5 | max_merge_seconds_limit = 10800 6 | max_delta_mins = 1440 7 | min_session_mins = 1440 8 | least_disuser_count = 10 9 | least_checkins_count = 10 10 | 11 | latN = 50 12 | lngN = 40 13 | split_save = 1 14 | 15 | [Training] 16 | use_nni = 0 17 | mode = train 18 | ctx = 0 19 | regularization = 1e-5 20 | learning_rate = 1e-3 21 | max_epochs = 100 22 | # max_epochs = 10 23 | display_step = 1 24 | patience = 6 25 | train_batch = 8 26 | val_batch = 8 27 | test_batch = 8 28 | batch_size = 8 29 | save_results = 0 30 | 31 | 32 | [Model] 33 | loc_emb_size = 256 34 | geohash_size = 128 35 | category_size = 128 36 | tim_emb_size = 256 37 | user_emb_size = 256 38 | hidden_size = 256 39 | loc_noise_mean = 0 40 | loc_noise_sigma = 0.01 41 | tim_noise_mean = 0 42 | tim_noise_sigma = 0.01 43 | user_noise_mean = 0 44 | user_noise_sigma = 0.01 45 | tau = 4 46 | pos_eps = 0.5 47 | neg_eps = 0.5 48 | dropout_rate_1 = 0.5 49 | dropout_rate_2 = 0.5 50 | adv = 1 51 | self_weight = 0.05 52 | self_weight_s = 0.05 53 | self_weight_t = 0.05 54 | self_weight_st = 0.05 55 | k = 8 56 | momentum = 0.95 57 | theta = 0.18 58 | temperature = 0.1 59 | rnn_type = GRU 60 | num_layers = 3 61 | # downstream = POI 62 | downstream = TPP 63 | dump_path = checkpoints 64 | rank = 0 65 | queue_length = 1024 66 | world_size = -1 67 | epoch_queue_starts = 0 68 | crops_for_assign = 01 69 | feat_dim = 256 70 | loss = mae 71 | tpp = linear 72 | epsilon = 0.05 73 | dropout_spatial = 0.3 74 | 
learnable_param_size = 3 -------------------------------------------------------------------------------- /config/MobilityLLM_bkc_TUL.conf: -------------------------------------------------------------------------------- 1 | [Data] 2 | dataset_name = www_BKC 3 | country_name = US 4 | max_his_period_days = 120 5 | max_merge_seconds_limit = 10800 6 | max_delta_mins = 1440 7 | min_session_mins = 1440 8 | least_disuser_count = 10 9 | least_checkins_count = 10 10 | 11 | latN = 50 12 | lngN = 40 13 | split_save = 1 14 | 15 | [Training] 16 | use_nni = 0 17 | mode = train 18 | ctx = 0 19 | regularization = 1e-5 20 | learning_rate = 1e-3 21 | max_epochs = 100 22 | # max_epochs = 10 23 | display_step = 1 24 | patience = 6 25 | train_batch = 8 26 | val_batch = 8 27 | test_batch = 8 28 | batch_size = 8 29 | save_results = 0 30 | 31 | 32 | [Model] 33 | loc_emb_size = 256 34 | geohash_size = 128 35 | category_size = 128 36 | tim_emb_size = 256 37 | user_emb_size = 256 38 | hidden_size = 256 39 | loc_noise_mean = 0 40 | loc_noise_sigma = 0.01 41 | tim_noise_mean = 0 42 | tim_noise_sigma = 0.01 43 | user_noise_mean = 0 44 | user_noise_sigma = 0.01 45 | tau = 4 46 | pos_eps = 0.5 47 | neg_eps = 0.5 48 | dropout_rate_1 = 0.5 49 | dropout_rate_2 = 0.5 50 | adv = 1 51 | self_weight = 0.05 52 | self_weight_s = 0.05 53 | self_weight_t = 0.05 54 | self_weight_st = 0.05 55 | k = 8 56 | momentum = 0.95 57 | theta = 0.18 58 | temperature = 0.1 59 | rnn_type = GRU 60 | num_layers = 3 61 | downstream = TUL 62 | # downstream = TPP 63 | dump_path = checkpoints 64 | rank = 0 65 | queue_length = 1024 66 | world_size = -1 67 | epoch_queue_starts = 0 68 | crops_for_assign = 01 69 | feat_dim = 256 70 | loss = mae 71 | tpp = linear 72 | epsilon = 0.05 73 | dropout_spatial = 0.3 74 | learnable_param_size = 1 -------------------------------------------------------------------------------- /config/MobilityLLM_gow_POI.conf: -------------------------------------------------------------------------------- 
1 | [Data] 2 | dataset_name = www_GOW 3 | country_name = US 4 | max_his_period_days = 120 5 | max_merge_seconds_limit = 10800 6 | max_delta_mins = 1440 7 | min_session_mins = 1440 8 | least_disuser_count = 10 9 | least_checkins_count = 10 10 | 11 | latN = 50 12 | lngN = 40 13 | split_save = 1 14 | 15 | [Training] 16 | use_nni = 0 17 | mode = train 18 | ctx = 0 19 | regularization = 1e-5 20 | learning_rate = 1e-3 21 | max_epochs = 100 22 | # max_epochs = 10 23 | display_step = 1 24 | patience = 6 25 | train_batch = 8 26 | val_batch = 8 27 | test_batch = 8 28 | batch_size = 8 29 | save_results = 0 30 | 31 | 32 | [Model] 33 | loc_emb_size = 256 34 | geohash_size = 128 35 | category_size = 128 36 | tim_emb_size = 256 37 | user_emb_size = 256 38 | hidden_size = 256 39 | loc_noise_mean = 0 40 | loc_noise_sigma = 0.01 41 | tim_noise_mean = 0 42 | tim_noise_sigma = 0.01 43 | user_noise_mean = 0 44 | user_noise_sigma = 0.01 45 | tau = 4 46 | pos_eps = 0.5 47 | neg_eps = 0.5 48 | dropout_rate_1 = 0.5 49 | dropout_rate_2 = 0.5 50 | adv = 1 51 | self_weight = 0.05 52 | self_weight_s = 0.05 53 | self_weight_t = 0.05 54 | self_weight_st = 0.05 55 | k = 8 56 | momentum = 0.95 57 | theta = 0.18 58 | temperature = 0.1 59 | rnn_type = GRU 60 | num_layers = 3 61 | downstream = POI 62 | # downstream = TPP 63 | dump_path = checkpoints 64 | rank = 0 65 | queue_length = 1024 66 | world_size = -1 67 | epoch_queue_starts = 0 68 | crops_for_assign = 01 69 | feat_dim = 256 70 | loss = mae 71 | tpp = linear 72 | epsilon = 0.05 73 | dropout_spatial = 0.3 74 | learnable_param_size = 1 -------------------------------------------------------------------------------- /config/MobilityLLM_gow_TPP.conf: -------------------------------------------------------------------------------- 1 | [Data] 2 | dataset_name = www_GOW 3 | country_name = US 4 | max_his_period_days = 120 5 | max_merge_seconds_limit = 10800 6 | max_delta_mins = 1440 7 | min_session_mins = 1440 8 | least_disuser_count = 10 9 | 
least_checkins_count = 10 10 | 11 | latN = 50 12 | lngN = 40 13 | split_save = 1 14 | 15 | [Training] 16 | use_nni = 0 17 | mode = train 18 | ctx = 0 19 | regularization = 1e-5 20 | learning_rate = 1e-3 21 | max_epochs = 100 22 | # max_epochs = 10 23 | display_step = 1 24 | patience = 6 25 | train_batch = 8 26 | val_batch = 8 27 | test_batch = 8 28 | batch_size = 8 29 | save_results = 0 30 | 31 | 32 | [Model] 33 | loc_emb_size = 256 34 | geohash_size = 128 35 | category_size = 128 36 | tim_emb_size = 256 37 | user_emb_size = 256 38 | hidden_size = 256 39 | loc_noise_mean = 0 40 | loc_noise_sigma = 0.01 41 | tim_noise_mean = 0 42 | tim_noise_sigma = 0.01 43 | user_noise_mean = 0 44 | user_noise_sigma = 0.01 45 | tau = 4 46 | pos_eps = 0.5 47 | neg_eps = 0.5 48 | dropout_rate_1 = 0.5 49 | dropout_rate_2 = 0.5 50 | adv = 1 51 | self_weight = 0.05 52 | self_weight_s = 0.05 53 | self_weight_t = 0.05 54 | self_weight_st = 0.05 55 | k = 8 56 | momentum = 0.95 57 | theta = 0.18 58 | temperature = 0.1 59 | rnn_type = GRU 60 | num_layers = 3 61 | # downstream = POI 62 | downstream = TPP 63 | dump_path = checkpoints 64 | rank = 0 65 | queue_length = 1024 66 | world_size = -1 67 | epoch_queue_starts = 0 68 | crops_for_assign = 01 69 | feat_dim = 256 70 | loss = mae 71 | tpp = linear 72 | epsilon = 0.05 73 | dropout_spatial = 0.3 74 | learnable_param_size = 3 -------------------------------------------------------------------------------- /config/MobilityLLM_gow_TUL.conf: -------------------------------------------------------------------------------- 1 | [Data] 2 | dataset_name = www_GOW 3 | country_name = US 4 | max_his_period_days = 120 5 | max_merge_seconds_limit = 10800 6 | max_delta_mins = 1440 7 | min_session_mins = 1440 8 | least_disuser_count = 10 9 | least_checkins_count = 10 10 | 11 | latN = 50 12 | lngN = 40 13 | split_save = 1 14 | 15 | [Training] 16 | use_nni = 0 17 | mode = train 18 | ctx = 0 19 | regularization = 1e-5 20 | learning_rate = 1e-3 21 | max_epochs = 
100 22 | # max_epochs = 10 23 | display_step = 1 24 | patience = 6 25 | train_batch = 8 26 | val_batch = 8 27 | test_batch = 8 28 | batch_size = 8 29 | save_results = 0 30 | 31 | 32 | [Model] 33 | loc_emb_size = 256 34 | geohash_size = 128 35 | category_size = 128 36 | tim_emb_size = 256 37 | user_emb_size = 256 38 | hidden_size = 256 39 | loc_noise_mean = 0 40 | loc_noise_sigma = 0.01 41 | tim_noise_mean = 0 42 | tim_noise_sigma = 0.01 43 | user_noise_mean = 0 44 | user_noise_sigma = 0.01 45 | tau = 4 46 | pos_eps = 0.5 47 | neg_eps = 0.5 48 | dropout_rate_1 = 0.5 49 | dropout_rate_2 = 0.5 50 | adv = 1 51 | self_weight = 0.05 52 | self_weight_s = 0.05 53 | self_weight_t = 0.05 54 | self_weight_st = 0.05 55 | k = 8 56 | momentum = 0.95 57 | theta = 0.18 58 | temperature = 0.1 59 | rnn_type = GRU 60 | num_layers = 3 61 | downstream = TUL 62 | # downstream = TPP 63 | dump_path = checkpoints 64 | rank = 0 65 | queue_length = 1024 66 | world_size = -1 67 | epoch_queue_starts = 0 68 | crops_for_assign = 01 69 | feat_dim = 256 70 | loss = mae 71 | tpp = linear 72 | epsilon = 0.05 73 | dropout_spatial = 0.3 74 | learnable_param_size = 1 -------------------------------------------------------------------------------- /config/MobilityLLM_tky_POI.conf: -------------------------------------------------------------------------------- 1 | [Data] 2 | dataset_name = www_TKY 3 | country_name = JP 4 | max_his_period_days = 120 5 | max_merge_seconds_limit = 10800 6 | max_delta_mins = 1440 7 | min_session_mins = 1440 8 | least_disuser_count = 10 9 | least_checkins_count = 10 10 | 11 | latN = 40 12 | lngN = 40 13 | split_save = 1 14 | 15 | [Training] 16 | use_nni = 0 17 | mode = train 18 | ctx = 0 19 | regularization = 1e-5 20 | learning_rate = 1e-3 21 | max_epochs = 100 22 | # max_epochs = 10 23 | display_step = 1 24 | patience = 6 25 | train_batch = 8 26 | val_batch = 8 27 | test_batch = 8 28 | batch_size = 8 29 | save_results = 0 30 | 31 | 32 | [Model] 33 | loc_emb_size = 256 34 | 
geohash_size = 128 35 | category_size = 128 36 | tim_emb_size = 256 37 | user_emb_size = 256 38 | hidden_size = 256 39 | loc_noise_mean = 0 40 | loc_noise_sigma = 0.01 41 | tim_noise_mean = 0 42 | tim_noise_sigma = 0.01 43 | user_noise_mean = 0 44 | user_noise_sigma = 0.01 45 | tau = 4 46 | pos_eps = 0.5 47 | neg_eps = 0.5 48 | dropout_rate_1 = 0.5 49 | dropout_rate_2 = 0.5 50 | adv = 1 51 | self_weight = 0.05 52 | self_weight_s = 0.05 53 | self_weight_t = 0.05 54 | self_weight_st = 0.05 55 | k = 8 56 | momentum = 0.95 57 | theta = 0.18 58 | temperature = 0.1 59 | rnn_type = GRU 60 | num_layers = 3 61 | downstream = POI 62 | # downstream = TPP 63 | dump_path = checkpoints 64 | rank = 0 65 | queue_length = 1024 66 | world_size = -1 67 | epoch_queue_starts = 0 68 | crops_for_assign = 01 69 | feat_dim = 256 70 | loss = mae 71 | tpp = linear 72 | epsilon = 0.05 73 | dropout_spatial = 0.3 74 | learnable_param_size = 3 -------------------------------------------------------------------------------- /config/MobilityLLM_tky_TPP.conf: -------------------------------------------------------------------------------- 1 | [Data] 2 | dataset_name = www_TKY 3 | country_name = JP 4 | max_his_period_days = 120 5 | max_merge_seconds_limit = 10800 6 | max_delta_mins = 1440 7 | min_session_mins = 1440 8 | least_disuser_count = 10 9 | least_checkins_count = 10 10 | 11 | latN = 40 12 | lngN = 40 13 | split_save = 1 14 | 15 | [Training] 16 | use_nni = 0 17 | mode = train 18 | ctx = 0 19 | regularization = 1e-5 20 | learning_rate = 1e-3 21 | max_epochs = 100 22 | # max_epochs = 10 23 | display_step = 1 24 | patience = 6 25 | train_batch = 8 26 | val_batch = 8 27 | test_batch = 8 28 | batch_size = 8 29 | save_results = 0 30 | 31 | 32 | [Model] 33 | loc_emb_size = 256 34 | geohash_size = 128 35 | category_size = 128 36 | tim_emb_size = 256 37 | user_emb_size = 256 38 | hidden_size = 256 39 | loc_noise_mean = 0 40 | loc_noise_sigma = 0.01 41 | tim_noise_mean = 0 42 | tim_noise_sigma = 0.01 
43 | user_noise_mean = 0 44 | user_noise_sigma = 0.01 45 | tau = 4 46 | pos_eps = 0.5 47 | neg_eps = 0.5 48 | dropout_rate_1 = 0.5 49 | dropout_rate_2 = 0.5 50 | adv = 1 51 | self_weight = 0.05 52 | self_weight_s = 0.05 53 | self_weight_t = 0.05 54 | self_weight_st = 0.05 55 | k = 8 56 | momentum = 0.95 57 | theta = 0.18 58 | temperature = 0.1 59 | rnn_type = GRU 60 | num_layers = 3 61 | # downstream = POI 62 | downstream = TPP 63 | dump_path = checkpoints 64 | rank = 0 65 | queue_length = 1024 66 | world_size = -1 67 | epoch_queue_starts = 0 68 | crops_for_assign = 01 69 | feat_dim = 256 70 | loss = mae 71 | tpp = linear 72 | epsilon = 0.05 73 | dropout_spatial = 0.3 74 | learnable_param_size = 3 -------------------------------------------------------------------------------- /config/MobilityLLM_tky_TUL.conf: -------------------------------------------------------------------------------- 1 | [Data] 2 | dataset_name = www_TKY 3 | country_name = JP 4 | max_his_period_days = 120 5 | max_merge_seconds_limit = 10800 6 | max_delta_mins = 1440 7 | min_session_mins = 1440 8 | least_disuser_count = 10 9 | least_checkins_count = 10 10 | 11 | latN = 40 12 | lngN = 40 13 | split_save = 1 14 | 15 | [Training] 16 | use_nni = 0 17 | mode = train 18 | ctx = 0 19 | regularization = 1e-5 20 | learning_rate = 1e-3 21 | max_epochs = 100 22 | # max_epochs = 10 23 | display_step = 1 24 | patience = 6 25 | train_batch = 8 26 | val_batch = 8 27 | test_batch = 8 28 | batch_size = 8 29 | save_results = 0 30 | 31 | 32 | [Model] 33 | loc_emb_size = 256 34 | geohash_size = 128 35 | category_size = 128 36 | tim_emb_size = 256 37 | user_emb_size = 256 38 | hidden_size = 256 39 | loc_noise_mean = 0 40 | loc_noise_sigma = 0.01 41 | tim_noise_mean = 0 42 | tim_noise_sigma = 0.01 43 | user_noise_mean = 0 44 | user_noise_sigma = 0.01 45 | tau = 4 46 | pos_eps = 0.5 47 | neg_eps = 0.5 48 | dropout_rate_1 = 0.5 49 | dropout_rate_2 = 0.5 50 | adv = 1 51 | self_weight = 0.05 52 | self_weight_s = 0.05 
53 | self_weight_t = 0.05 54 | self_weight_st = 0.05 55 | k = 8 56 | momentum = 0.95 57 | theta = 0.18 58 | temperature = 0.1 59 | rnn_type = GRU 60 | num_layers = 3 61 | downstream = TUL 62 | # downstream = TPP 63 | dump_path = checkpoints 64 | rank = 0 65 | queue_length = 1024 66 | world_size = -1 67 | epoch_queue_starts = 0 68 | crops_for_assign = 01 69 | feat_dim = 256 70 | loss = mae 71 | tpp = linear 72 | epsilon = 0.05 73 | dropout_spatial = 0.3 74 | learnable_param_size = 1 -------------------------------------------------------------------------------- /model/configuration_phi.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | import math 5 | from typing import Optional 6 | 7 | from transformers import PretrainedConfig 8 | 9 | 10 | class PhiConfig(PretrainedConfig): 11 | """Phi ("phi-msft") model configuration; attribute_map exposes the standard Hugging Face names as aliases of the n_* fields.""" 12 | 13 | model_type = "phi-msft" 14 | attribute_map = { 15 | "max_position_embeddings": "n_positions", 16 | "hidden_size": "n_embd", 17 | "num_attention_heads": "n_head", 18 | "num_hidden_layers": "n_layer", 19 | } 20 | 21 | def __init__( 22 | self, 23 | vocab_size: int = 50304, 24 | n_positions: int = 2048, 25 | n_embd: int = 1024, 26 | n_layer: int = 20, 27 | n_inner: Optional[int] = None, 28 | n_head: int = 16, 29 | n_head_kv: Optional[int] = None, 30 | rotary_dim: Optional[int] = 32, 31 | activation_function: Optional[str] = "gelu_new", 32 | flash_attn: bool = False, 33 | flash_rotary: bool = False, 34 | fused_dense: bool = False, 35 | attn_pdrop: float = 0.0, 36 | embd_pdrop: float = 0.0, 37 | resid_pdrop: float = 0.0, 38 | layer_norm_epsilon: float = 1e-5, 39 | initializer_range: float = 0.02, 40 | tie_word_embeddings: bool = False, 41 | pad_vocab_size_multiple: int = 64, 42 | **kwargs 43 | ) -> None: 44 | self.vocab_size = int(math.ceil(vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple)  # round vocab_size up to a multiple of pad_vocab_size_multiple 45 | self.n_positions = n_positions
46 | self.n_embd = n_embd 47 | self.n_layer = n_layer 48 | self.n_inner = n_inner 49 | self.n_head = n_head 50 | self.n_head_kv = n_head_kv 51 | self.rotary_dim = min(rotary_dim, n_embd // n_head)  # capped at the per-head width; NOTE(review): min() raises TypeError if rotary_dim is None despite the Optional hint 52 | self.activation_function = activation_function 53 | self.flash_attn = flash_attn 54 | self.flash_rotary = flash_rotary 55 | self.fused_dense = fused_dense 56 | self.attn_pdrop = attn_pdrop 57 | self.embd_pdrop = embd_pdrop 58 | self.resid_pdrop = resid_pdrop 59 | self.layer_norm_epsilon = layer_norm_epsilon 60 | self.initializer_range = initializer_range 61 | 62 | super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) 63 | -------------------------------------------------------------------------------- /model/layers.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import numpy as np 4 | import torch 5 | from torch import nn 6 | 7 | 8 | class ContinuousEncoding(nn.Module): 9 | """ 10 | Trigonometric encoding that maps continuous values to distance-sensitive vectors: sqrt(1 / embed_size) * cos(x * omega + bias). 11 | """ 12 | 13 | def __init__(self, embed_size): 14 | super().__init__() 15 | self.omega = nn.Parameter((torch.from_numpy(1 / 10 ** np.linspace(0, 9, embed_size))).float(),  # learnable frequencies, log-spaced from 1 down to 1e-9 16 | requires_grad=True) 17 | self.bias = nn.Parameter(torch.zeros(embed_size).float(), requires_grad=True) 18 | self.div_term = math.sqrt(1. / embed_size) 19 | 20 | def forward(self, x): 21 | """ 22 | :param x: input sequence for encoding, (batch_size, seq_len) 23 | :return: encoded sequence, shape (batch_size, seq_len, embed_size) 24 | """ 25 | encode = x.unsqueeze(-1) * self.omega.reshape(1, 1, -1) + self.bias.reshape(1, 1, -1) 26 | encode = torch.cos(encode) 27 | return self.div_term * encode 28 | 29 | 30 | class PositionalEncoding(nn.Module): 31 | """ 32 | Fixed sinusoidal encoding indicating items' positions in sequences, stored as a (1, max_len, embed_size) buffer.
33 | """ 34 | 35 | def __init__(self, embed_size, max_len): 36 | super().__init__() 37 | 38 | pe = torch.zeros(max_len, embed_size).float() 39 | pe.requires_grad = False 40 | 41 | position = torch.arange(0, max_len).float().unsqueeze(1) 42 | div_term = (torch.arange(0, embed_size, 2).float() * -(math.log(10000.0) / embed_size)).exp() 43 | 44 | pe[:, 0::2] = torch.sin(position * div_term) 45 | pe[:, 1::2] = torch.cos(position * div_term) 46 | 47 | pe = pe.unsqueeze(0) 48 | self.register_buffer('pe', pe) 49 | 50 | def forward(self, x, position_ids=None): 51 | """ 52 | Args: 53 | x: (B, T, d_model) 54 | position_ids: (B, T) or None 55 | 56 | Returns: 57 | (1, T, d_model) / (B, T, d_model) 58 | """ 59 | if position_ids is None: 60 | return self.pe[:, :x.size(1)] 61 | else: 62 | batch_size, seq_len = position_ids.shape 63 | pe = self.pe[:, :seq_len, :] # (1, T, d_model) 64 | pe = pe.expand((position_ids.shape[0], -1, -1)) # (B, T, d_model) 65 | pe = pe.reshape(-1, self.d_model) # (B * T, d_model) 66 | position_ids = position_ids.reshape(-1, 1).squeeze(1) # (B * T,) 67 | output_pe = pe[position_ids].reshape(batch_size, seq_len, self.d_model).detach() 68 | return output_pe 69 | 70 | 71 | class SinusoidalPositionEmbeddings(nn.Module): 72 | """ 73 | Sinusoidal-based function used for encoding timestamps. 
74 | """ 75 | 76 | def __init__(self, dim): 77 | super().__init__() 78 | self.dim = dim 79 | 80 | def forward(self, time): 81 | device = time.device 82 | half_dim = self.dim // 2 83 | embeddings = math.log(10000) / (half_dim - 1) 84 | embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings) 85 | embeddings = time[:, None] * embeddings[None, :] 86 | embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1) 87 | return embeddings 88 | 89 | 90 | class TimeEmbed(nn.Module): 91 | def __init__(self, input_dim, output_dim): 92 | super().__init__() 93 | 94 | self.time_mlp = nn.Sequential( 95 | SinusoidalPositionEmbeddings(input_dim), 96 | nn.Linear(input_dim, output_dim), 97 | nn.SiLU(), 98 | nn.Linear(output_dim, output_dim) 99 | ) 100 | 101 | def forward(self, time): 102 | return self.time_mlp(time) 103 | -------------------------------------------------------------------------------- /model/positional_encoding.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Unofficial pytorch implementation of the paper "Learnable Fourier Features for Multi-Dimensional Spatial Positional Encoding", NeurIPS 2021. 3 | ''' 4 | import torch 5 | import torch.nn as nn 6 | import numpy as np 7 | from einops import rearrange 8 | 9 | 10 | # Learnable Fourier Features for Multi-Dimensional Spatial Positional Encoding 11 | class LearnableFourierFeatures(nn.Module): 12 | def __init__(self, pos_dim, f_dim, h_dim, d_dim, g_dim=1, gamma=1.0): 13 | super(LearnableFourierFeatures, self).__init__() 14 | assert f_dim % 2 == 0, 'number of fourier feature dimensions must be divisible by 2.' 15 | assert d_dim % g_dim == 0, 'number of D dimension must be divisible by the number of G dimension.' 
16 | enc_f_dim = int(f_dim / 2) 17 | dg_dim = int(d_dim / g_dim) 18 | self.Wr = nn.Parameter(torch.randn([enc_f_dim, pos_dim]) * (gamma ** 2)) 19 | self.mlp = nn.Sequential( 20 | nn.Linear(f_dim, h_dim), 21 | nn.GELU(), 22 | nn.Linear(h_dim, dg_dim) 23 | ) 24 | self.div_term = np.sqrt(f_dim) 25 | 26 | def forward(self, pos): 27 | # input pos dim: (B L G M) 28 | # output dim: (B L D) 29 | # L stands for sequence length. all dimensions must be flattened to a single dimension. 30 | XWr = torch.matmul(pos, self.Wr.T) 31 | F = torch.cat([torch.cos(XWr), torch.sin(XWr)], dim=-1) / self.div_term 32 | Y = self.mlp(F) 33 | pos_enc = rearrange(Y, 'b l g d -> b l (g d)') 34 | 35 | #return pos_enc 36 | return Y.squeeze() 37 | 38 | 39 | 40 | # Fourier Features Let Networks Learn High Frequency Functions in Low Dimensional Domains 41 | class FourierFeatures(nn.Module): 42 | def __init__(self, pos_dim, f_dim, sigma=10, train=False): 43 | super(FourierFeatures, self).__init__() 44 | assert f_dim % 2 == 0, 'number of channels must be divisible by 2.' 45 | enc_dim = int(f_dim / 2) 46 | self.B = torch.randn([pos_dim, enc_dim]) * sigma 47 | if train: 48 | self.B = nn.Parameter(self.B) 49 | 50 | def forward(self, pos): 51 | # pos: (B L C), (B H W C), (B H W T C) 52 | pos_enc = torch.matmul(pos, self.B.to(pos.device)) 53 | pos_enc = torch.cat([torch.sin(pos_enc), torch.cos(pos_enc)], dim=-1) 54 | return pos_enc 55 | 56 | 57 | # Attention is All You Need 58 | class PositionalEncoding(nn.Module): 59 | def __init__(self, pos_dim, enc_dim): 60 | super(PositionalEncoding, self).__init__() 61 | assert enc_dim % (pos_dim * 2) == 0, 'dimension of positional encoding must be equal to dim * 2.' 
62 | enc_dim = int(enc_dim / 2) 63 | div_term = torch.exp(torch.arange(0., enc_dim, 2) * -(np.log(10000.0) / enc_dim)) 64 | freqs = torch.zeros([pos_dim, enc_dim]) 65 | for i in range(pos_dim): 66 | freqs[i, : enc_dim // 2] = div_term 67 | freqs[i, enc_dim // 2:] = div_term 68 | self.freqs = freqs 69 | 70 | def forward(self, pos): 71 | # pos: (B L C), (B H W C), (B H W T C) 72 | pos_enc = torch.matmul(pos, self.freqs.to(pos.device)) 73 | pos_enc = torch.cat([torch.sin(pos_enc), torch.cos(pos_enc)], dim=-1) 74 | return pos_enc 75 | 76 | 77 | if __name__ == '__main__': 78 | """ 79 | example usage of LearnableFourierFeatures 80 | 81 | let 82 | positional dimension: 2 (2d spatial positions) 83 | fourier feature dimension: 128 84 | hidden dimension: 256 85 | positional encoding dimension: 64 86 | number of positional groups: 1 87 | 88 | batch size: 4 89 | sequence length: 1024 (== 32x32 in 2d spatial resolution) 90 | number of positional groups: 1 91 | positional dimension: 2 92 | """ 93 | lff = LearnableFourierFeatures(pos_dim=2, f_dim=128, h_dim=256, d_dim=64, g_dim=1).cuda() 94 | pos = torch.randn([4, 1024, 1, 2]).cuda() 95 | pe = lff(pos) 96 | print(pe) 97 | print(pe.shape) 98 | 99 | 100 | 101 | """ 102 | example usage of FourierFeatures 103 | 104 | let 105 | positional dimension: 2 (2d spatial positions) 106 | fourier feature dimension: 256 107 | 108 | batch size: 4 109 | sequence length: 32x32 110 | positional dimension: 2 111 | """ 112 | ff = FourierFeatures(pos_dim=2, f_dim=256).cuda() 113 | pos = torch.randn([4, 32, 32, 2]).cuda() 114 | pe = ff(pos) 115 | print(pe) 116 | print(pe.shape) 117 | 118 | 119 | 120 | """ 121 | example usage of PositionalEncoding 122 | 123 | let 124 | positional dimension: 2 (2d spatial positions) 125 | encoding dimension: 256 126 | 127 | batch size: 4 128 | sequence length: 1024 129 | positional dimension: 2 130 | """ 131 | PE = PositionalEncoding(pos_dim=2, enc_dim=256).cuda() 132 | pos = torch.randn([4, 1024, 2]).cuda() 133 | pe = 
PE(pos) 134 | print(pe) 135 | print(pe.shape) 136 | 137 | 138 | 139 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Mobility-LLM: Learning Visiting Intentions and Travel Preferences from Human Mobility Data with Large Language Models 2 | ### Datasets 3 | To demonstrate the superiority of our proposed model, our experiements are carried out on four real-world datasets derived from Gowalla (GOW), WeePlace (WEE), Brightkite (BKC) and FourSquare (TKY) check-in data. 4 | 5 | In order to facilitate the training of our model, Our model undergoes a filtering process that selects high-quality check-in sequences for training. To ensure data consistency, we set a maximum historical time limit of 120 days and filter out users with fewer than 10 records and places visited fewer than 10 times. 6 | 7 | The below table shows the statistics of three datasets. 8 | | | Gowalla | WeePlace | Brightkite | FourSquare | 9 | | ---------- |:---------:|:----------:|:------------:|:------------:| 10 | | # users | 5,853 | 1,028 | 431 | 703 | 11 | | # POIs | 52,032 | 9,295 | 3,554 | 11,117 | 12 | | # Samples | 413,563 | 104,762 | 44,716 | 60,734 | 13 | 14 | - Download processed datasets from following sources: 15 | - One Drive 16 | - https://1drv.ms/f/c/401a59cded375360/EuI9Lfh3qhVPrMdO22wDbc0BjJ_-5M3YIPEOLyrGoEMD3A?e=jBGTCC 17 | - BaiduNetDisk code: `x24z` 18 | - https://pan.baidu.com/s/1nWWMzS1yQaGdaiVd-njJxQ 19 | - Copy all files and directories to `MobilityLLM/data/new_datasets` 20 | 21 | ### Large Lanugage Models 22 | We compare eight representative backbones with varying capacities, including TinyLlama, TinyLlama-Chat, LiteLlama, phi-2, pythia-70M, pythia-1B, pythia-2.8B and GPT-2. 
23 | - Download the models from the following source: 24 | - https://huggingface.co/models 25 | - Copy all files and directories to `MobilityLLM/params/*/` 26 | 27 | ### Requirements 28 | - python >= 3.6 29 | - PyTorch >= 1.8 30 | 31 | ### Usage 32 | Enter the directory `MobilityLLM`. 33 | Downstream tasks: 34 | Location Prediction (LP), Trajectory User Link (TUL), or Time Prediction (TP). 35 | Model classes (the value passed as `<MODEL_CLASS>` to `--model_class`): 36 | TinyLlama-1_1B (TinyLlama), TinyLlama-Chat (TinyLlama-Chat), phi-2 (phi-2), pythia-70M (pythia-70M), pythia-2_8B (pythia-2.8B), pythia-1B (pythia-1B), LiteLlama (LiteLlama), gpt2 (GPT-2). 37 | - Train the model on WEE for the LP task: 38 | `python train_MobilityLLM.py --config config/MobilityLLM_wee_POI.conf --dataroot data/ --model_class <MODEL_CLASS>` 39 |
40 | - Train the model on TKY for the LP task: 41 | `python train_MobilityLLM.py --config config/MobilityLLM_tky_POI.conf --dataroot data/ --model_class <MODEL_CLASS>` 42 |
43 | - Train the model on GOW for the LP task: 44 | `python train_MobilityLLM.py --config config/MobilityLLM_gow_POI.conf --dataroot data/ --model_class <MODEL_CLASS>` 45 |
46 | - Train the model on BKC for the LP task: 47 | `python train_MobilityLLM.py --config config/MobilityLLM_bkc_POI.conf --dataroot data/ --model_class <MODEL_CLASS>` 48 |
49 | - Train the model on WEE for the TUL task: 50 | `python train_MobilityLLM.py --config config/MobilityLLM_wee_TUL.conf --dataroot data/ --model_class <MODEL_CLASS>` 51 |
52 | - Train the model on TKY for the TUL task: 53 | `python train_MobilityLLM.py --config config/MobilityLLM_tky_TUL.conf --dataroot data/ --model_class <MODEL_CLASS>` 54 |
55 | - Train the model on GOW for the TUL task: 56 | `python train_MobilityLLM.py --config config/MobilityLLM_gow_TUL.conf --dataroot data/ --model_class <MODEL_CLASS>` 57 |
58 | - Train the model on BKC for the TUL task: 59 | `python train_MobilityLLM.py --config config/MobilityLLM_bkc_TUL.conf --dataroot data/ --model_class <MODEL_CLASS>` 60 |
61 | - Train the model on WEE for the TP task: 62 | `python train_MobilityLLM.py --config config/MobilityLLM_wee_TPP.conf --dataroot data/ --model_class <MODEL_CLASS>` 63 |
64 | - Train the model on TKY for the TP task: 65 | `python train_MobilityLLM.py --config config/MobilityLLM_tky_TPP.conf --dataroot data/ --model_class <MODEL_CLASS>` 66 |
67 | - Train the model on GOW for the TP task: 68 | `python train_MobilityLLM.py --config config/MobilityLLM_gow_TPP.conf --dataroot data/ --model_class <MODEL_CLASS>` 69 |
70 | - Train the model on BKC for the TP task: 71 | `python train_MobilityLLM.py --config config/MobilityLLM_bkc_TPP.conf --dataroot data/ --model_class <MODEL_CLASS>` 72 | ### Configuration 73 | Each configuration file `MobilityLLM_*.conf` contains three sections: Data, Training and Model. 74 | 75 | #### Data 76 | - dataset_name: the name of the dataset, one of www_GOW, www_BKC, www_TKY or www_WEE. 77 | - max_his_period_days: the maximum history period, in days. 78 | - max_merge_seconds_limit: the time window (in seconds) within which repeated check-ins at the same location are treated as one event. 79 | - max_delta_mins: the upper limit (in minutes) of the prediction range. 80 | - least_disuser_count: location filter; keep only locations visited by at least this many users. 81 | - least_checkins_count: user filter; keep only users with at least this many check-ins. 82 | - split_save: 1 or 0, whether the train/val/test splits are stored as separate files. 83 | 84 | #### Training 85 | - mode: `train` by default. 86 | - ctx: CUDA device index, 0 by default. 87 | - regularization: float, regularization factor. 88 | - learning_rate: float 89 | - max_epochs: int 90 | - display_step: int 91 | - patience: int, for early stopping. 92 | - train_batch: int 93 | - val_batch: int 94 | - test_batch: int 95 | - batch_size: int 96 | - save_results: bool 97 | 98 | #### Model 99 | - adv: 0 or 1, whether adversarial training is enabled. 100 | - downstream: POI, TUL or TPP, representing Location Prediction, Trajectory User Link, and Time Prediction respectively. 101 | 102 | The remaining parameters are the best-performing hyperparameters for the model.
103 | -------------------------------------------------------------------------------- /model/bert.py: -------------------------------------------------------------------------------- 1 | from transformers import BertForMaskedLM 2 | 3 | from .base import * 4 | from .layers import PositionalEncoding, ContinuousEncoding 5 | 6 | 7 | def get_batch_mask(B, L, valid_len): 8 | mask = repeat(torch.arange(end=L, device=valid_len.device), 9 | 'L -> B L', B=B) < repeat(valid_len, 'B -> B L', L=L) # (B, L) 10 | return mask 11 | 12 | 13 | class BertLM(Encoder): 14 | def __init__(self, d_model, output_size, bert_path, 15 | dis_feats=[], num_embeds=[], con_feats=[], 16 | mlp_grad=True, pooling='mean'): 17 | super().__init__('BertLM-d{}-o{}-p{}'.format(d_model, output_size, pooling)) 18 | 19 | self.output_size = output_size 20 | self.d_model = d_model 21 | self.pooling = pooling 22 | self.dis_feats = dis_feats 23 | self.con_feats = con_feats 24 | 25 | self.pos_encode = PositionalEncoding(d_model, max_len=2001) 26 | if len(dis_feats): 27 | assert len(dis_feats) == len(num_embeds), \ 28 | 'length of num_embeds list should be equal to the number of discrete features.' 29 | self.dis_embeds = nn.ModuleList([nn.Embedding(num_embed, d_model) for num_embed in num_embeds]) 30 | else: 31 | self.dis_embeds = None 32 | 33 | if len(con_feats): 34 | self.con_embeds = nn.ModuleList([ContinuousEncoding(d_model) for _ in con_feats]) 35 | else: 36 | self.con_embeds = None 37 | 38 | self.bert = BertForMaskedLM.from_pretrained(bert_path) 39 | emb_size = self.bert.config.emb_size 40 | self.emb_size = emb_size 41 | self.hidden_size = self.bert.config.hidden_size 42 | 43 | # Froze or Fine-tune the parameters of BERT. 
44 | for i, (name, param) in enumerate(self.bert.named_parameters()): 45 | if 'ln' in name or 'wpe' in name: # or 'mlp' in name: 46 | param.requires_grad = True 47 | elif 'mlp' in name and mlp_grad: 48 | param.requires_grad = True 49 | else: 50 | param.requires_grad = False 51 | 52 | self.seq_projector = nn.Sequential(nn.Linear(d_model, emb_size, bias=False), 53 | nn.LayerNorm(emb_size), 54 | nn.ReLU(inplace=True), 55 | nn.Linear(emb_size, emb_size)) 56 | self.poi_projector = nn.Sequential(nn.Linear(emb_size, emb_size, bias=False), 57 | nn.LayerNorm(emb_size), 58 | nn.ReLU(inplace=True), 59 | nn.Linear(emb_size, emb_size)) 60 | 61 | self.out_linear = nn.Sequential(nn.Linear(self.bert.config.hidden_size, output_size, bias=False), 62 | nn.LayerNorm(output_size), 63 | nn.ReLU(inplace=True), 64 | nn.Linear(output_size, output_size)) 65 | 66 | self.trip_mask_token = nn.Parameter(torch.zeros(emb_size).float(), requires_grad=True) 67 | self.poi_mask_token = nn.Parameter(torch.zeros(emb_size).float(), requires_grad=True) 68 | 69 | def forward(self, trip, valid_len, o_poi_embeddings, d_poi_embeddings, 70 | trip_mask=None, poi_mask=None, **kwargs): 71 | B, L, E_in = trip.shape 72 | 73 | trip_batch_mask = get_batch_mask(B, L+2, valid_len+1) 74 | src_batch_mask = get_batch_mask(B, L+2, valid_len+2) 75 | 76 | h = torch.zeros(B, trip.size(1), self.d_model).to(trip.device) 77 | if self.dis_embeds is not None: 78 | for dis_embed, dis_feat in zip(self.dis_embeds, self.dis_feats): 79 | h += dis_embed(trip[..., dis_feat].long()) # (B, L, E) 80 | if self.con_embeds is not None: 81 | for con_embed, con_feat in zip(self.con_embeds, self.con_feats): 82 | h += con_embed(trip[..., con_feat].float()) 83 | h += self.pos_encode(h) 84 | h = self.seq_projector(h) 85 | 86 | input_seq = torch.zeros(B, L+2, self.emb_size).to(trip.device) 87 | input_seq[:, 0] += self.poi_projector(o_poi_embeddings) 88 | input_seq[:, 1:-1] += h 89 | input_seq[src_batch_mask.long() - trip_batch_mask.long() == 1] += 
self.poi_projector(d_poi_embeddings) 90 | 91 | if trip_mask is not None: 92 | input_seq[:, 1:-1][trip_mask] = self.trip_mask_token 93 | if poi_mask is not None: 94 | o_placeholder = torch.zeros(B, L+2).to(trip.device) 95 | o_placeholder[:, 0] = 1 96 | o_placeholder = o_placeholder.bool() 97 | o_poi_mask = repeat(poi_mask[:, 0], 'b -> b l', l=L+2) 98 | d_poi_mask = repeat(poi_mask[:, 1], 'b -> b l', l=L+2) 99 | input_seq[o_placeholder & o_poi_mask] = self.poi_mask_token 100 | input_seq[(src_batch_mask.long() - trip_batch_mask.long() == 1) & d_poi_mask] = self.poi_mask_token 101 | 102 | memory = self.bert(inputs_embeds=input_seq, attention_mask=src_batch_mask, output_hidden_states=True).hidden_states[-1] 103 | memory = torch.nan_to_num(memory) 104 | 105 | memory = self.out_linear(memory) # (B, E_out) or (B, L, E_out) 106 | 107 | if bool(kwargs.get('pretrain', False)): 108 | return memory 109 | 110 | if self.pooling == 'mean': 111 | mask_expanded = repeat(src_batch_mask.logical_not(), 'B L -> B L E', E=memory.size(2)) # (B, L, E) 112 | memory = memory.masked_fill(mask_expanded, 0) # (B, L, E) 113 | memory = torch.sum(memory, 1) / valid_len.unsqueeze(-1) 114 | elif self.pooling == 'cls': 115 | memory = memory[:, 0] 116 | 117 | return memory 118 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import math 4 | import nni 5 | #import seaborn as sns 6 | #import matplotlib.pyplot as plt 7 | import torch 8 | import torch.nn.functional as F 9 | 10 | class DotDict(dict): 11 | __getattr__ = dict.__getitem__ 12 | __setattr__ = dict.__setitem__ 13 | __delattr__ = dict.__delitem__ 14 | 15 | 16 | def evaluate_location(y, top_k_pred, K=20): 17 | ''' 18 | get hit ratio, mrr 19 | :param y: (batch,) 20 | :param top_k_pred: (batch, num_class) 21 | :param K 22 | :return: 23 | ''' 24 | total_num = top_k_pred.shape[0] 25 | 
hit_ratio = np.zeros(K) 26 | mrr = [] 27 | for i in range(total_num): 28 | rank = np.where(top_k_pred[i] == y[i])[0] + 1 29 | mrr.append(rank) 30 | for j in range(1, K+1): 31 | if y[i] in set(top_k_pred[i, :j]): 32 | hit_ratio[j-1] = hit_ratio[j-1] + 1 33 | hit_ratio = hit_ratio/total_num 34 | # print('mrr:',mrr) 35 | mrr = (1/np.array(mrr)).mean() 36 | return hit_ratio, mrr 37 | 38 | 39 | def get_total_prob_c(loader, model, gts, gss, time_info, week_info, feature_category, feature_lat, feature_lng, feature_lat_ori, feature_lng_ori, save_filename=None, params_path=None, distance=None): 40 | ''' 41 | calculates the loss, mae and mape for the entire data loader 42 | :param loader: 43 | :param save: 44 | :return: 45 | ''' 46 | all_topic = [] 47 | all_label = [] 48 | for input in loader: 49 | topic, gamma_c = model.get_gammac(input, gss, feature_category, feature_lat, feature_lng, feature_lat_ori, feature_lng_ori, gts, time_info, week_info, distance=distance) # (batch_size,), (batch_size,) 50 | all_topic.append(topic.detach().cpu().numpy()) 51 | all_label.append(gamma_c.detach().cpu().numpy()) 52 | all_topic = np.concatenate(all_topic) 53 | all_label = np.concatenate(all_label) 54 | all_label_index = np.argmax(all_label, axis=1) 55 | print('all_label:', all_label[:10]) 56 | print('all_label_index:', all_label_index[:10]) 57 | if save_filename is not None: 58 | filename = os.path.join(params_path, save_filename + '_gammac.npz') 59 | np.savez(filename, all_topic=all_topic, all_label=all_label, all_label_index=all_label_index) 60 | return all_topic, all_label, all_label_index 61 | 62 | 63 | def density_visualization(density, ground_truth, batch_cnt): 64 | ''' 65 | plot the probability density function 66 | :param density: 67 | :param ground_truth: 68 | :return: 69 | ''' 70 | n_samples = 1000 71 | hours = 48 72 | x = np.linspace(0, hours, n_samples) 73 | t = ground_truth 74 | cnt = 0 75 | length = len(density) 76 | ''' 77 | for i in range(length): 78 | y = density[i] 79 | 
plt.plot(x, y, "r-", label="STDGN") 80 | plt.legend() 81 | plt.xlabel(r"$\tau$", fontdict={'family': 'Times New Roman', 'size':16}) 82 | plt.ylabel(r"p($\tau$)", fontdict={'family': 'Times New Roman', 'size':16}) 83 | plt.yticks(fontproperties = 'Times New Roman', size = 14) 84 | plt.xticks(fontproperties = 'Times New Roman', size = 14) 85 | plt.grid() 86 | # plt.title("the probability density function in JKT dataset", fontdict={'family': 'Times New Roman', 'size':16}) 87 | true_value = round(t[i], 2) 88 | plt.axvline(x=true_value, ls=":", c="black") 89 | plt.text(x=true_value + 1, y=1/2*np.max(y), s=r"$\tau_{n+1}$=" + str(true_value), size=16, alpha=0.8) 90 | plt.legend(prop={'family' : 'Times New Roman', 'size' : 16}) 91 | plt.show() 92 | pic_name = str(batch_cnt) + '_' + str(cnt) 93 | cnt += 1 94 | plt.savefig(f'./data/density/jkt_{pic_name}.png') 95 | plt.savefig(f'./data/density/jkt_{pic_name}.eps',format='eps', dpi=10000) 96 | plt.close() 97 | ''' 98 | 99 | 100 | def softmax(x): 101 | ''' 102 | self-define softmax operation 103 | :param x: 104 | :return: 105 | ''' 106 | # print("before: ", x) 107 | x -= np.max(x, axis=1, keepdims=True) # for stationary computation 108 | x = np.exp(x) / np.sum(np.exp(x), axis=1, keepdims=True) # formula 109 | # print("after: ", x) 110 | return x 111 | 112 | 113 | def rad(d): 114 | ''' 115 | rad the latitude and longitude 116 | :param d: latitude or longitude 117 | :return rad: 118 | ''' 119 | return d * math.pi / 180.0 120 | 121 | 122 | def getDistance(lat1, lng1, lat2, lng2): 123 | ''' 124 | get the distance between two location using their latitude and longitude 125 | :param lat1: 126 | :param lng1: 127 | :param lat2: 128 | :param lng2: 129 | :return s: 130 | ''' 131 | EARTH_REDIUS = 6378.137 132 | radLat1 = rad(lat1) 133 | radLat2 = rad(lat2) 134 | a = radLat1 - radLat2 135 | b = rad(lng1) - rad(lng2) 136 | s = 2 * math.asin(math.sqrt(math.pow(math.sin(a/2), 2) + math.cos(radLat1) * math.cos(radLat2) * 
math.pow(math.sin(b/2), 2))) 137 | s = s * EARTH_REDIUS 138 | return s 139 | 140 | 141 | def get_s_baselines_total_loss_s_for_MobilityLLM_RNN(loader, model, save_filename=None, params_path=None): 142 | all_loss_s = [] 143 | all_ground_truth_location = [] 144 | all_predicted_topK = [] 145 | 146 | for input in loader: 147 | s_loss_score, top_k_pred = model(input) 148 | all_loss_s.append(s_loss_score.detach().cpu().numpy()) 149 | all_ground_truth_location.append(input.Y_location.cpu().numpy()) 150 | all_predicted_topK.append(top_k_pred.cpu().numpy()) 151 | 152 | all_loss_s = np.array(all_loss_s) 153 | all_loss_s = np.mean(all_loss_s) 154 | 155 | all_ground_truth_location = np.concatenate(all_ground_truth_location) 156 | all_predicted_topK = np.concatenate(all_predicted_topK) 157 | 158 | hit_ratio, mrr = evaluate_location(all_ground_truth_location, all_predicted_topK) 159 | 160 | if save_filename is not None: 161 | filename = os.path.join(params_path, save_filename + '_results.npz') 162 | np.savez(filename, all_ground_truth_location=all_ground_truth_location, all_predicted_topK=all_predicted_topK) 163 | 164 | return all_loss_s, hit_ratio, mrr 165 | 166 | 167 | # for downstream validation bencnmark 168 | def get_s_baselines_total_loss_s_for_MobilityLLM_DEMO_DOWN(loader, model, downstream='POI', save_filename=None, params_path=None): 169 | all_loss_s = [] 170 | all_ground_truth_users = [] 171 | all_predicted_topK = [] 172 | for input in loader: 173 | s_loss_score, top_k_pred = model(input, mode='downstream', downstream=downstream) 174 | all_loss_s.append(s_loss_score.detach().cpu().numpy()) 175 | if downstream == 'POI': 176 | all_ground_truth_users.append(input.Y_location.cpu().numpy()) 177 | # all_ground_truth_users.append(torch.index_select(torch.tensor(input.Y_location), dim=0, index=indice).cpu().numpy()) 178 | elif downstream == 'TUL': 179 | all_ground_truth_users.append(input.X_users.cpu().numpy()) 180 | else: 181 | raise ValueError('downstream is not in [POI, 
TUL]') 182 | 183 | all_predicted_topK.append(top_k_pred.cpu().numpy()) 184 | 185 | all_loss_s = np.array(all_loss_s) 186 | all_loss_s = np.mean(all_loss_s) 187 | 188 | all_ground_truth_users = np.concatenate(all_ground_truth_users) 189 | all_predicted_topK = np.concatenate(all_predicted_topK) 190 | 191 | hit_ratio, mrr = evaluate_location(all_ground_truth_users, all_predicted_topK) 192 | 193 | if save_filename is not None: 194 | filename = os.path.join(params_path, save_filename + '_results.npz') 195 | np.savez(filename, all_ground_truth_users=all_ground_truth_users, all_predicted_topK=all_predicted_topK) 196 | 197 | return all_loss_s, hit_ratio, mrr 198 | 199 | 200 | # for downstream validation bencnmark 201 | def get_s_baselines_total_loss_s_for_MobilityLLM_DOWN(loader, model, downstream='POI', save_filename=None, params_path=None): 202 | all_loss_s = [] 203 | all_ground_truth_users = [] 204 | all_predicted_topK = [] 205 | for input in loader: 206 | if input.X_all_loc.shape[1] >= 700: 207 | continue 208 | s_loss_score, top_k_pred, _ = model(input, mode='downstream', downstream=downstream) 209 | all_loss_s.append(s_loss_score.detach().cpu().numpy()) 210 | if downstream == 'POI': 211 | all_ground_truth_users.append(input.Y_location.cpu().numpy()) 212 | elif downstream == 'TUL': 213 | all_ground_truth_users.append(input.X_users.cpu().numpy()) 214 | else: 215 | raise ValueError('downstream is not in [POI, TUL]') 216 | 217 | all_predicted_topK.append(top_k_pred.cpu().numpy()) 218 | 219 | all_loss_s = np.array(all_loss_s) 220 | all_loss_s = np.mean(all_loss_s) 221 | 222 | all_ground_truth_users = np.concatenate(all_ground_truth_users) 223 | all_predicted_topK = np.concatenate(all_predicted_topK) 224 | 225 | hit_ratio, mrr = evaluate_location(all_ground_truth_users, all_predicted_topK) 226 | 227 | if save_filename is not None: 228 | filename = os.path.join(params_path, save_filename + '_results.npz') 229 | np.savez(filename, 
all_ground_truth_users=all_ground_truth_users, all_predicted_topK=all_predicted_topK) 230 | 231 | return all_loss_s, hit_ratio, mrr 232 | 233 | 234 | def get_t_for_IFLTPP(loader, model, save_filename=None, params_path=None, experiment_base_dir=None, use_nni=False): 235 | ''' 236 | calculates the loss, mae and mape for the entire data loader 237 | :param loader: 238 | :param save: 239 | :return: 240 | ''' 241 | ground_truth_Y_tau = [] 242 | predicted_Y_tau = [] 243 | all_tnll_t = [] 244 | all_nll = [] 245 | 246 | for input in loader: 247 | tnll, mean, _ = model(input, cont_conf=None, mode='downstream', downstream='TPP') # (batch_size,), (batch_size,) 248 | ground_truth_Y_tau.append(model.truth_Y_tau.detach().cpu().numpy()) 249 | predicted_Y_tau.append(mean.detach().cpu().numpy()) 250 | all_tnll_t.append(tnll.detach().cpu().numpy()) 251 | 252 | all_tnll_t = np.array(all_tnll_t) 253 | 254 | ground_truth_Y_tau = np.concatenate(ground_truth_Y_tau).flatten() 255 | predicted_Y_tau = np.concatenate(predicted_Y_tau).flatten() 256 | 257 | mae = np.mean(abs(ground_truth_Y_tau - predicted_Y_tau)) 258 | rmse = np.sqrt(((ground_truth_Y_tau - predicted_Y_tau) ** 2).mean()) 259 | mape = np.mean(abs(ground_truth_Y_tau - predicted_Y_tau) / np.maximum(np.mean(ground_truth_Y_tau), ground_truth_Y_tau)) 260 | tnll_t = np.mean(all_tnll_t) 261 | 262 | if (save_filename is not None) and (not use_nni): 263 | filename = os.path.join(params_path, save_filename+'_results.npz') 264 | np.savez(filename, ground_truth_Y_tau=ground_truth_Y_tau, predicted_Y_tau=predicted_Y_tau) 265 | 266 | return mae, mape, rmse, tnll_t 267 | 268 | 269 | def get_semantic_information(cnt2category, data_root): 270 | import pickle 271 | vecpath = data_root + "glove.twitter.27B.50d.pkl" 272 | pkl_data = open(vecpath, "rb") 273 | word_vec = pickle.load(pkl_data) 274 | for word in word_vec.keys(): 275 | word_vec[word] = word_vec[word] 276 | pkl_data.close() 277 | 278 | word_id = 0 279 | dataset_word_vec = [] 280 | 
dataset_word_index = {} 281 | categories = cnt2category.values() 282 | for category in categories: 283 | words = category.split(" ") 284 | # print(words) 285 | for word in words: 286 | word = word.lower() 287 | if (word in word_vec) and (word not in dataset_word_index): 288 | dataset_word_index[word] = word_id 289 | word_id += 1 290 | dataset_word_vec.append(word_vec[word]) 291 | print("word_index: ", dataset_word_index) 292 | return dataset_word_vec, dataset_word_index, word_id 293 | 294 | 295 | -------------------------------------------------------------------------------- /preprocess/load_data.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data_utils 2 | import os 3 | import itertools 4 | import numpy as np 5 | import torch 6 | 7 | 8 | def tid_list_48(tm): 9 | if tm.weekday() in [0, 1, 2, 3, 4]: 10 | tid = int(tm.hour) 11 | else: 12 | tid = int(tm.hour) + 24 13 | return tid 14 | 15 | 16 | def load_data_from_dataset(set_name, loader, device, user_cnt, venue_cnt, save_split, name, data_root): 17 | X_target_lengths = loader[f'{set_name}X_target_lengths'] 18 | X_arrival_times = loader[f'{set_name}X_arrival_times'] 19 | X_users = loader[f'{set_name}X_users'] 20 | X_locations = loader[f'{set_name}X_locations'] 21 | X_locations_category = loader[f'{set_name}X_locations_category'] 22 | X_lats = loader[f'{set_name}X_lats'] 23 | X_lons = loader[f'{set_name}X_lons'] 24 | X_geohash = loader[f'{set_name}X_geohash'] 25 | Y_location = loader[f'{set_name}Y_locations'] 26 | Y_location_category = loader[f'{set_name}Y_locations_category'] 27 | Y_lat = loader[f'{set_name}Y_lats'] 28 | Y_lon = loader[f'{set_name}Y_lons'] 29 | Y_geohash = loader[f'{set_name}Y_geohash'] 30 | 31 | X_tau = pad_1d(loader[f'{set_name}X_delta_times']) 32 | Y_tau = pad_1d(loader[f'{set_name}Y_delta_times']) 33 | 34 | X_all_loc = [] 35 | X_all_loc_category = [] 36 | X_all_lat = [] 37 | X_all_lon = [] 38 | X_all_geohash = [] 39 | X_all_tim 
= [] 40 | X_lengths = [] 41 | for i in range(len(X_arrival_times)): 42 | tim = X_arrival_times[i] 43 | loc = X_locations[i] 44 | loc_category = X_locations_category[i] 45 | lat = X_lats[i] 46 | lon = X_lons[i] 47 | geohash = X_geohash[i] 48 | 49 | len_ = len(tim) 50 | for j in range(len_): 51 | tim[j] = tid_list_48(tim[j]) 52 | 53 | X_all_loc.append(loc) 54 | X_all_loc_category.append(loc_category) 55 | X_all_lat.append(lat) 56 | X_all_lon.append(lon) 57 | X_all_geohash.append(geohash) 58 | X_all_tim.append(tim) 59 | X_lengths.append(len_) 60 | 61 | print("X_all_loc: ", len(X_all_loc), X_all_loc[0]) 62 | print("X_all_loc_category: ", len(X_all_loc), X_all_loc_category[0]) 63 | print("X_all_lat: ", len(X_all_loc), X_all_lat[0]) 64 | print("X_all_lon: ", len(X_all_loc), X_all_lon[0]) 65 | print("X_all_geohash: ", len(X_all_loc), X_all_geohash[0]) 66 | print("X_all_tim: ", len(X_all_tim), X_all_tim[0]) 67 | print("X_target_lengths: ", len(X_target_lengths), X_target_lengths[0]) 68 | print("X_lengths: ", len(X_lengths), X_lengths[0]) 69 | print("X_users:", len(X_users), X_users) 70 | print("Y_location:", len(Y_location), Y_location[0]) 71 | print("Y_location_category:", len(Y_location), Y_location_category[0]) 72 | print("Y_lat:", len(Y_location), Y_lat[0]) 73 | print("Y_lon:", len(Y_location), Y_lon[0]) 74 | print("Y_geohash:", len(Y_location), Y_geohash[0]) 75 | 76 | dataset = SessionBasedSequenceDataset(device, user_cnt, venue_cnt, X_users, X_all_loc, X_all_loc_category, 77 | X_all_lat, X_all_lon, X_all_geohash, 78 | X_all_tim, Y_location, Y_location_category, Y_lat, Y_lon, Y_geohash, 79 | X_target_lengths, X_lengths, None, X_tau, Y_tau) 80 | print(f'samples cnt of data_{set_name}:', dataset.real_length()) 81 | 82 | return dataset 83 | 84 | 85 | def load_dataset_for_MobilityLLM(name, data_root, save_split=False, device=None): 86 | ''' 87 | 1. load data and construct train/val/test dataset 88 | 2. construct temporal graphs gts and spatial graphs gss 89 | 3. 
construct SessionBasedSequenceDataset 90 | :param name: file name 91 | :param log_mode: whether log(X_taus), default True 92 | :return: 93 | ''' 94 | if not name.endswith('.npz'): 95 | name += '.npz' 96 | 97 | if save_split: 98 | train_loader = dict(np.load(os.path.join(data_root + 'new_datasets', 'train_' + name), allow_pickle=True)) 99 | val_loader = dict(np.load(os.path.join(data_root + 'new_datasets', 'val_' + name), allow_pickle=True)) 100 | loader = dict(np.load(os.path.join(data_root + 'new_datasets', 'test_' + name), allow_pickle=True)) 101 | category_vector = np.load(os.path.join(data_root + 'new_datasets', 'category_' + name)) 102 | 103 | else: 104 | loader = dict(np.load(os.path.join(data_root, 'test_' + name), allow_pickle=True)) 105 | train_loader = loader 106 | val_loader = loader 107 | category_vector = None 108 | 109 | user_cnt = loader['user_cnt'] 110 | venue_cnt = loader['venue_cnt'] 111 | print('user_cnt:', user_cnt) 112 | print('venue_cnt:', venue_cnt) 113 | 114 | feature_category = loader['feature_category'] 115 | feature_lat = loader['feature_lat'] # index 116 | feature_lng = loader['feature_lng'] # index 117 | 118 | # put spatial point features into tensor 119 | feature_category = torch.LongTensor(feature_category) 120 | feature_lat = torch.LongTensor(feature_lat) 121 | feature_lng = torch.LongTensor(feature_lng) 122 | 123 | latN, lngN = loader['latN'], loader['lngN'] 124 | category_cnt = loader['category_cnt'] 125 | 126 | # ----- load train / val / test to get dataset ----- 127 | data_train = load_data_from_dataset('train', train_loader, device, user_cnt, venue_cnt, save_split, name, data_root) 128 | data_val = load_data_from_dataset('val', val_loader, device, user_cnt, venue_cnt, save_split, name, data_root) 129 | data_test = load_data_from_dataset('test', loader, device, user_cnt, venue_cnt, save_split, name, data_root) 130 | 131 | return data_train, data_val, data_test, feature_category, feature_lat, feature_lng, latN, lngN, category_cnt, 
\ 132 | category_vector['categories'] 133 | 134 | 135 | class SessionBasedSequenceDataset(data_utils.Dataset): 136 | """Dataset class containing variable length sequences. 137 | """ 138 | 139 | def __init__(self, device, user_cnt, venue_cnt, X_users, X_all_loc, X_all_loc_category, X_all_lat, X_all_lon, 140 | X_all_geohash, 141 | X_all_tim, Y_location, Y_location_category, Y_lat, Y_lon, Y_geohash, target_lengths, X_lengths, 142 | X_all_text, 143 | X_tau, Y_tau): 144 | # torch.set_default_tensor_type(torch.cuda.FloatTensor) 145 | self.device = device 146 | self.user_cnt = user_cnt 147 | self.venue_cnt = venue_cnt 148 | self.X_users = X_users 149 | self.X_all_loc = X_all_loc 150 | self.X_all_loc_category = X_all_loc_category 151 | self.X_all_lat = X_all_lat 152 | self.X_all_lon = X_all_lon 153 | self.X_all_geohash = X_all_geohash 154 | self.X_all_tim = X_all_tim 155 | self.target_lengths = target_lengths 156 | self.X_lengths = X_lengths 157 | self.Y_location = Y_location 158 | self.Y_location_category = Y_location_category 159 | self.Y_lat = Y_lat 160 | self.Y_lon = Y_lon 161 | self.Y_geohash = Y_geohash 162 | self.X_all_text = X_all_text 163 | 164 | self.X_tau = [torch.Tensor(_) for _ in X_tau] 165 | # self.Y_tau = torch.Tensor(Y_tau.astype('float64')) / 60 # mins->hour 166 | self.Y_tau = torch.Tensor(Y_tau.astype('float64')) # mins 167 | 168 | self.validate_data() 169 | 170 | @property 171 | def num_series(self): 172 | return len(self.Y_location) 173 | 174 | def real_length(self): 175 | res = 0 176 | n = len(self.Y_location) 177 | for i in range(n): 178 | res += len(self.Y_location[i]) 179 | return res 180 | 181 | def validate_data(self): 182 | if len(self.X_all_loc) != len(self.Y_location) or len(self.X_all_tim) != len(self.Y_location): 183 | raise ValueError("Length of X_all_loc, X_all_tim, Y_location should match") 184 | 185 | def get_tau_log_mean_std_Y(self): 186 | """Get mean and std of Y_taus.""" 187 | y = torch.flatten(self.Y_tau) 188 | logy = y[y != 0].log() 
189 | return logy.mean(), logy.std() 190 | 191 | def get_mean_std_Y_tau(self): 192 | """Get mean and std of Y_tau.""" 193 | y = torch.flatten(self.Y_tau) 194 | y = y[y != 0] 195 | return y.mean(), y.std() 196 | 197 | def normalize_Y_tau(self, mean=None, std=None): 198 | self.Y_tau = (self.Y_tau - mean)/std 199 | return self 200 | 201 | def __getitem__(self, key): 202 | ''' 203 | the outputs are feed into collate() 204 | :param key: 205 | :return: 206 | ''' 207 | return self.X_all_loc[key], self.X_all_tim[key], None, self.Y_location[key], self.target_lengths[key], \ 208 | self.X_lengths[key], self.X_users[key], self.X_all_loc_category[key], self.X_all_lat[key], \ 209 | self.X_all_lon[key], self.Y_location_category[key], self.Y_lat[key], self.Y_lon[key], \ 210 | self.X_all_geohash[key], self.Y_geohash[key], self.X_tau[key], self.Y_tau[key], self.device 211 | 212 | def __len__(self): 213 | return self.num_series 214 | 215 | def __repr__(self): 216 | pass 217 | 218 | 219 | def pad_session_data_one(data): 220 | fillvalue = 0 221 | # zip_longest 222 | data = list(zip(*itertools.zip_longest(*data, fillvalue=fillvalue))) 223 | res = [] 224 | res.extend([list(data[i]) for i in range(len(data))]) 225 | return res 226 | 227 | 228 | def pad_session_data_geohash(data): 229 | fillvalue = [0 for i in range(12)] 230 | # zip_longest 231 | data = list(zip(*itertools.zip_longest(*data, fillvalue=fillvalue))) 232 | res = [] 233 | res.extend([list(data[i]) for i in range(len(data))]) 234 | return res 235 | 236 | 237 | def collate_session_based(batch): 238 | ''' 239 | get the output of dataset.__getitem__, and perform padding 240 | :param batch: 241 | :return: 242 | ''' 243 | device = batch[0][-1] 244 | batch = sorted(batch, key=lambda x: len(x[0]), reverse=True) 245 | 246 | X_all_loc = [item[0] for item in batch] 247 | 248 | X_all_loc_category = [item[7] for item in batch] 249 | X_all_lat = [item[8] for item in batch] 250 | X_all_lon = [item[9] for item in batch] 251 | X_all_geohash = 
[item[13] for item in batch] 252 | 253 | X_all_tim = [item[1] for item in batch] 254 | X_all_text = [item[2] for item in batch] # None per item when items come from __getitem__ 255 | Y_location = [lid for item in batch for lid in item[3]] # targets of all samples flattened into one list 256 | 257 | Y_location_category = [lid for item in batch for lid in item[10]] 258 | Y_lat = [lid for item in batch for lid in item[11]] 259 | Y_lon = [lid for item in batch for lid in item[12]] 260 | Y_geohash = [lid for item in batch for lid in item[14]] 261 | 262 | target_lengths = [item[4] for item in batch] 263 | X_lengths = [item[5] for item in batch] 264 | X_users = [item[6] for item in batch] 265 | 266 | X_tau = torch.stack([item[15] for item in batch]) # requires equal-shaped tau tensors across the batch 267 | torch.where(X_tau<5, torch.mean(X_tau), X_tau) # NOTE(review): no-op - torch.where returns a new tensor and the result is discarded, so X_tau is unchanged 268 | Y_tau = torch.stack([item[16] for item in batch]) 269 | torch.where(Y_tau<5, torch.mean(Y_tau), Y_tau) # NOTE(review): no-op as well - Y_tau is unchanged 270 | 271 | padded_X_all_loc = pad_session_data_one(X_all_loc) 272 | padded_X_all_loc_category = pad_session_data_one(X_all_loc_category) 273 | padded_X_all_lat = pad_session_data_one(X_all_lat) 274 | padded_X_all_lon = pad_session_data_one(X_all_lon) 275 | padded_X_all_geohash = pad_session_data_geohash(X_all_geohash) 276 | 277 | padded_X_all_tim = pad_session_data_one(X_all_tim) 278 | padded_X_all_loc = torch.tensor(padded_X_all_loc).long() # (batch, max_len) 279 | padded_X_all_tim = torch.tensor(padded_X_all_tim).long() # (batch, max_len) 280 | 281 | return session_Batch(padded_X_all_loc, padded_X_all_tim, X_all_text, Y_location, target_lengths, X_lengths, X_users, 282 | padded_X_all_loc_category, padded_X_all_lat, padded_X_all_lon, Y_location_category, Y_lat, 283 | Y_lon, padded_X_all_geohash, Y_geohash, X_tau, Y_tau, device) 284 | 285 | 286 | class session_Batch(): # container for one collated batch 287 | def __init__(self, padded_X_all_loc, padded_X_all_tim, X_all_text, Y_location, target_lengths, X_lengths, X_users, 288 | padded_X_all_loc_category, padded_X_all_lat, padded_X_all_lon, Y_location_category, Y_lat, Y_lon, 289 | padded_X_all_geohash, Y_geohash, X_tau, Y_tau, device): 290 | self.X_all_loc =
torch.LongTensor(padded_X_all_loc).to(device) # (batch, max_all_length) 291 | self.X_all_tim = torch.LongTensor(padded_X_all_tim).to(device) # (batch, max_all_length) 292 | self.X_all_text = X_all_text 293 | self.X_all_loc_category = padded_X_all_loc_category # category/lat/lon stay plain padded lists (not tensors) 294 | self.X_all_lat = padded_X_all_lat 295 | self.X_all_lon = padded_X_all_lon 296 | self.X_all_geohash = torch.FloatTensor(padded_X_all_geohash).to(device) 297 | self.Y_location = torch.Tensor(Y_location).long().to(device) # (Batch,); NOTE(review): torch.Tensor converts via float32, exact only for ids below 2**24 298 | self.Y_location_category = Y_location_category 299 | self.Y_lat = Y_lat 300 | self.Y_lon = Y_lon 301 | self.Y_geohash = Y_geohash 302 | self.target_lengths = target_lengths 303 | self.X_lengths = X_lengths 304 | self.X_users = torch.Tensor(X_users).long().to(device) # NOTE(review): same float32 round-trip as Y_location 305 | 306 | self.X_tau = X_tau # NOTE(review): X_tau/Y_tau are stored as-is, not moved to device 307 | self.Y_tau = Y_tau 308 | 309 | 310 | def pad_1d(inputs, pad=0): # right-pad 1-D sequences with pad and stack them into one array; max() raises ValueError for empty inputs 311 | def pad_data(x, length, pad): 312 | x_padded = np.pad( 313 | x, (0, length - len(x)), mode="constant", constant_values=pad 314 | ) 315 | return x_padded 316 | 317 | max_len = max((len(x) for x in inputs)) 318 | padded = np.stack([pad_data(x, max_len, pad) for x in inputs]) 319 | return padded 320 | -------------------------------------------------------------------------------- /model/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import copy 5 | import math 6 | from torch.distributions import Uniform 7 | 8 | class DotDict(dict): # dict whose keys can also be read/written/deleted as attributes 9 | __getattr__ = dict.__getitem__ 10 | __setattr__ = dict.__setitem__ 11 | __delattr__ = dict.__delitem__ 12 | 13 | def clamp_preserve_gradients(x, min, max): 14 | """Clamp the tensor while preserving gradients in the clamped region.""" 15 | return x + (x.clamp(min, max) - x).detach() # forward value equals x.clamp(min, max); gradient w.r.t. x is the identity 16 | 17 | 18 | def KLD(mu, logvar): 19 | ''' 20 | the KL divergence of the diagonal Gaussian N(mu, exp(logvar)) from the standard normal distribution 21 |
https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence#Multivariate_normal_distributions 22 | :param mu: the mean (batch, dim) 23 | :param logvar: the log of variance (batch, dim) 24 | :return: 25 | ''' 26 | return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / mu.shape[0] 27 | 28 | def KLD_category(w): 29 | ''' 30 | the KL divergency of category distribution with param=w and the uniform category distribution 31 | :param w: (batch, nClass) 32 | :return: 33 | ''' 34 | nClass = w.shape[1] 35 | p = torch.ones_like(w)/nClass # (batch, nClass) 36 | # print(p[0]) 37 | return torch.sum(w * torch.log(w/p)) / w.shape[0] 38 | 39 | 40 | class MLP2(nn.Module): 41 | """ 42 | MLP with two outputs, one for mu, one for log(var) 43 | """ 44 | def __init__(self, input_size, output_size, 45 | dropout=.0, hidden_size=128, use_selu=True): 46 | super(MLP2, self).__init__() 47 | self.hidden_size = hidden_size 48 | if self.hidden_size > 0: 49 | self.fc1 = nn.Linear(input_size, hidden_size) 50 | self.fc21 = nn.Linear(hidden_size, output_size) 51 | self.fc22 = nn.Linear(hidden_size, output_size) 52 | self.nonlinear_f = F.selu if use_selu else F.relu 53 | self.dropout = nn.Dropout(dropout) 54 | else: 55 | self.fc21 = nn.Linear(input_size, output_size) 56 | self.fc22 = nn.Linear(input_size, output_size) 57 | self.nonlinear_f = F.selu if use_selu else F.relu 58 | self.dropout = nn.Dropout(dropout) 59 | 60 | def forward(self, x): 61 | ''' 62 | 63 | :param x: (batch, dim) 64 | :return: 65 | ''' 66 | # print('mlp x:', x[:3,:]) 67 | # print('mpl self.fc1(x):', self.fc1(x)[:3, :]) 68 | # print('mpl self.nonlinear_f(self.fc1(x)):', self.nonlinear_f(self.fc1(x))[:3, :]) 69 | if self.hidden_size > 0: 70 | h1 = self.dropout(self.nonlinear_f(self.fc1(x))) 71 | return self.fc21(h1), self.fc22(h1) 72 | else: 73 | return self.fc21(x), self.fc22(x) 74 | 75 | 76 | class MLP(nn.Module): 77 | """ 78 | MLP with one output (not normalized) for multinomial distribution 79 | """ 80 | def 
__init__(self, input_size, hidden_size=64, output_size=1, dropout=0.0, use_selu=True): 81 | ''' 82 | Build a two-layer MLP: fc1 -> activation -> dropout -> fc2. 83 | :param input_size: size of the last input dimension 84 | :param hidden_size: width of the hidden layer 85 | :param output_size: the number of clusters (size of the unnormalized output) 86 | :param dropout: dropout probability applied after the hidden activation 87 | :param use_selu: use SELU if True, otherwise leaky ReLU 88 | ''' 89 | super(MLP, self).__init__() 90 | self.fc1 = nn.Linear(input_size, hidden_size) 91 | self.fc2 = nn.Linear(hidden_size, output_size) 92 | self.nonlinear_f = F.selu if use_selu else F.leaky_relu 93 | self.dropout = nn.Dropout(dropout) 94 | 95 | def forward(self, x): 96 | h1 = self.dropout(self.nonlinear_f(self.fc1(x))) 97 | return self.fc2(h1) # unnormalized scores 98 | 99 | 100 | def clones(module, N): 101 | "Produce N independent deep copies of module, wrapped in an nn.ModuleList." 102 | return nn.ModuleList([copy.deepcopy(module) for _ in range(N)]) 103 | 104 | 105 | class SublayerConnection(nn.Module): 106 | """ 107 | A residual connection around a normalized sublayer: x + dropout(sublayer(norm(x))). 108 | Note for code simplicity the norm is first as opposed to last. 109 | """ 110 | def __init__(self, size, dropout): 111 | super(SublayerConnection, self).__init__() 112 | self.norm = nn.LayerNorm(size) 113 | self.dropout = nn.Dropout(dropout) 114 | 115 | def forward(self, x, sublayer): 116 | "Apply residual connection to any sublayer with the same size."
117 | return x + self.dropout(sublayer(self.norm(x))) 118 | 119 | 120 | def attention(query, key, value, mask=None, dropout=None): 121 | ''' 122 | Scaled dot-product attention over the last two dimensions. 123 | :param query: (B, h, max_length, d_k) 124 | :param key: (B, h, max_length, d_k) 125 | :param value: (B, h, max_length, d_k) 126 | :param mask: (B, <1>, max_length, max_length), true/false matrix, and true means paddings 127 | :param dropout: optional nn.Dropout applied to the attention weights 128 | :return: outputs:(B, h, max_length, d_k), att_scores:(B, h, max_length, max_length) 129 | ''' 130 | "Compute 'Scaled Dot Product Attention'" 131 | # print('query:', query.shape) 132 | # print('key:', key.shape) 133 | # print('value:', value.shape) 134 | d_k = query.size(-1) 135 | # print('start 4 query:', query[-1,0]) 136 | # print('start 4 key:', key[-1,0]) 137 | scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k) 138 | # print('start 4 scores:', scores.shape, scores[-1, 0, :, :]) 139 | if mask is not None: 140 | scores = scores.masked_fill(mask, -1e9) # true->-1e9 141 | # print('mask:', mask.shape, mask[-1,0,0,:]) 142 | # print('start 5 scores:', scores.shape, scores[-1,0,:,:]) 143 | p_attn = F.softmax(scores, dim=-1) # each row sums to 1 144 | # print('start 5 p_attn:', p_attn.shape, p_attn[-1, 0, :, :]) 145 | # print('----') 146 | if dropout is not None: 147 | p_attn = dropout(p_attn) 148 | return torch.matmul(p_attn, value), p_attn 149 | 150 | 151 | class MultiHeadedAttention(nn.Module): 152 | def __init__(self, h, d_model, dropout=0.1): 153 | "Take in model size and number of heads."
154 | super(MultiHeadedAttention, self).__init__() 155 | assert d_model % h == 0 156 | # We assume d_v always equals d_k 157 | self.d_k = d_model // h 158 | self.h = h 159 | self.linears = clones(nn.Linear(d_model, d_model), 4) # for query, key, value, output 160 | self.attn = None # attention weights of the last forward call 161 | self.dropout = nn.Dropout(p=dropout) 162 | 163 | def forward(self, query, key, value, mask=None, distance=None): 164 | ''' 165 | Multi-head attention; distance is accepted only for interface compatibility and is not used. 166 | :param query: (B, max_length, d_model) 167 | :param key: (B, max_length, d_model) 168 | :param value: (B, max_length, d_model) 169 | :param mask: (B, max_length, max_length), True marks positions that are masked out 170 | :return: (B, max_length, d_model) 171 | ''' 172 | # print('start 3 MHA query:', query[0]) 173 | # print('start 3 MHA key:', key[0]) 174 | if mask is not None: 175 | # Same mask applied to all h heads. 176 | mask = mask.unsqueeze(1) # (B, max_length, max_length) -> (B, 1, max_length, max_length), broadcast over heads 177 | nbatches = query.size(0) 178 | 179 | # 1) Do all the linear projections in batch from d_model => h x d_k 180 | query, key, value = \ 181 | [l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2) 182 | for l, x in zip(self.linears, (query, key, value))] 183 | 184 | # 2) Apply attention on all the projected vectors in batch. 185 | x, self.attn = attention(query, key, value, mask=mask, 186 | dropout=self.dropout) 187 | 188 | # 3) "Concat" using a view and apply a final linear.
189 | x = x.transpose(1, 2).contiguous() \ 190 | .view(nbatches, -1, self.h * self.d_k) 191 | return self.linears[-1](x) 192 | 193 | 194 | def spatial_aware_attention(query, key, value, distance, mask=None, dropout=None): 195 | ''' 196 | Scaled dot-product attention with an additive distance penalty: softmax(QK^T / sqrt(d_k) - distance). 197 | :param query: (B, h, max_length, d_k) 198 | :param key: (B, h, max_length, d_k) 199 | :param value: (B, h, max_length, d_k) 200 | :param distance: (B, h, max_length, max_length) 201 | :param mask: (B, 1, max_length, max_length), true/false matrix, and true means paddings 202 | :param dropout: optional nn.Dropout applied to the attention weights 203 | :return: outputs:(B, h, max_length, d_k), att_scores:(B, h, max_length, max_length) 204 | ''' 205 | "Compute 'Scaled Dot Product Attention'" 206 | # print('query:', query.shape) 207 | # print('key:', key.shape) 208 | # print('value:', value.shape) 209 | d_k = query.size(-1) 210 | # print('start 4 query:', query[-1,0]) 211 | # print('start 4 key:', key[-1,0]) 212 | scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k) # (B, h, max_length, max_length) 213 | # print('start 4 scores:', scores.shape, scores[-1, 0, :, :]) 214 | scores = scores - distance # (B, h, max_length, max_length); a larger distance lowers the score 215 | if mask is not None: 216 | scores = scores.masked_fill(mask, -1e9) # true->-1e9 217 | # print('mask:', mask.shape, mask[-1,0,0,:]) 218 | # print('start 5 scores:', scores.shape, scores[-1,0,:,:]) 219 | p_attn = F.softmax(scores, dim=-1) # each row sums to 1 220 | # print('start 5 p_attn:', p_attn.shape, p_attn[-1, 0, :, :]) 221 | # print('----') 222 | if dropout is not None: 223 | p_attn = dropout(p_attn) 224 | return torch.matmul(p_attn, value), p_attn 225 | 226 | 227 | class SpatialAwareMultiHeadedAttention(nn.Module): 228 | def __init__(self, h, d_model, dropout=0.1): 229 | "Take in model size and number of heads."
230 | super(SpatialAwareMultiHeadedAttention, self).__init__() 231 | assert d_model % h == 0 232 | # We assume d_v always equals d_k 233 | self.d_k = d_model // h 234 | self.h = h 235 | self.linears = clones(nn.Linear(d_model, d_model), 4) # for query, key, value, output 236 | self.logwb = nn.Parameter(Uniform(0.0, 1.0).sample((self.h,))) # (h,) per-head log distance weight; exp() in forward keeps the weight positive 237 | self.attn = None # attention weights of the last forward call 238 | self.dropout = nn.Dropout(p=dropout) 239 | 240 | def forward(self, query, key, value, mask=None, distance=None): 241 | ''' 242 | Multi-head attention whose scores are penalised by the distance matrix scaled with a learned per-head weight. 243 | :param query: (B, max_length, d_model) 244 | :param key: (B, max_length, d_model) 245 | :param value: (B, max_length, d_model) 246 | :param distance: (B, max_length, max_length); required despite the None default (it is used unconditionally) 247 | :param mask: (B, max_length, max_length), True marks positions that are masked out 248 | :return: (B, max_length, d_model) 249 | ''' 250 | # print('start 3 MHA query:', query[0]) 251 | # print('start 3 MHA key:', key[0]) 252 | if mask is not None: 253 | # Same mask applied to all h heads. 254 | mask = mask.unsqueeze(1) # (B, max_length, max_length) -> (B, 1, max_length, max_length), broadcast over heads 255 | nbatches = query.size(0) 256 | 257 | # 1) Do all the linear projections in batch from d_model => h x d_k 258 | query, key, value = \ 259 | [l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2) 260 | for l, x in zip(self.linears, (query, key, value))] 261 | # distance to multi-head: (B, max_length, max_length) --> (B, h, max_length, max_length) 262 | wb = torch.exp(self.logwb).unsqueeze(0).unsqueeze(-1).unsqueeze(-1) # (1,h,1,1) 263 | mh_distance = distance.unsqueeze(1).repeat(1, self.h, 1, 1) * wb # (B, h, max_length, max_length)*(1,h,1,1) 264 | # 2) Apply attention on all the projected vectors in batch. 265 | x, self.attn = spatial_aware_attention(query, key, value, mh_distance, mask=mask, 266 | dropout=self.dropout) 267 | 268 | # 3) "Concat" using a view and apply a final linear.
269 | x = x.transpose(1, 2).contiguous() \ 270 | .view(nbatches, -1, self.h * self.d_k) 271 | return self.linears[-1](x) 272 | 273 | 274 | class PositionwiseFeedForward(nn.Module): 275 | "Position-wise feed-forward network: w_2(dropout(relu(w_1(x)))) applied at every position." 276 | def __init__(self, d_model, d_ff, dropout=0.1): 277 | super(PositionwiseFeedForward, self).__init__() 278 | self.w_1 = nn.Linear(d_model, d_ff) 279 | self.w_2 = nn.Linear(d_ff, d_model) 280 | self.dropout = nn.Dropout(dropout) 281 | 282 | def forward(self, x): 283 | ''' 284 | Apply the two linear layers with a ReLU and dropout in between. 285 | :param x: (B, max_length, d_model) 286 | :return: (B, max_length, d_model) 287 | ''' 288 | return self.w_2(self.dropout(F.relu(self.w_1(x)))) 289 | 290 | 291 | class TransformerEncoderLayer(nn.Module): 292 | "Encoder layer: residual (pre-norm) self-attention followed by a residual (pre-norm) feed-forward sublayer." 293 | def __init__(self, size, self_attn, feed_forward, dropout): 294 | super(TransformerEncoderLayer, self).__init__() 295 | self.self_attn = self_attn 296 | self.feed_forward = feed_forward 297 | self.sublayer = clones(SublayerConnection(size, dropout), 2) 298 | self.size = size 299 | 300 | def forward(self, x, mask=None, distance=None): 301 | ''' 302 | Run the self-attention sublayer, then the feed-forward sublayer; distance is forwarded to self_attn. 303 | :param x: (B, max_length, d_model) 304 | :param mask: boolean mask, True = masked out; TransformerEncoder passes (B, max_length, max_length) 305 | :return: (B, max_length, d_model) 306 | ''' 307 | # print('start 2 x:', x[0]) 308 | x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, mask=mask, distance=distance)) 309 | return self.sublayer[1](x, self.feed_forward) 310 | 311 | 312 | class TransformerEncoder(nn.Module): 313 | "Core encoder: a stack of N deep-copied layers followed by a final LayerNorm." 314 | 315 | def __init__(self, layer, N): 316 | super(TransformerEncoder, self).__init__() 317 | self.layers = clones(layer, N) 318 | self.norm = nn.LayerNorm(layer.size) 319 | 320 | def forward(self, x, padding_mask=None, session_mask=None, subsequent_mask=None, distance=None): # OR-combines the optional masks (True = masked) and applies every layer with that mask 321 | mask = torch.zeros(x.size(0), x.size(1), x.size(1), device=x.device).bool() # .type(torch.uint8) # (B, max_length, max_length) 322 | #
torch.set_printoptions(threshold=1000000) 323 | if padding_mask is not None: 324 | padding_mask = padding_mask.repeat(1, x.size(1), 1).bool() # (B, max_length, max_length); assumes padding_mask arrives as (B, 1, max_length) - TODO confirm 325 | # print('in padding_mask:', padding_mask) 326 | mask = mask | padding_mask 327 | if session_mask is not None: 328 | # print('in session_mask:', session_mask) 329 | mask = mask | session_mask 330 | if subsequent_mask is not None: 331 | # print('in subsequent_mask:', subsequent_mask) 332 | mask = mask | subsequent_mask 333 | # print('in mask', mask) 334 | for layer in self.layers: 335 | x = layer(x, mask=mask, distance=distance) 336 | return self.norm(x) 337 | 338 | 339 | class Hypernet(nn.Module): 340 | """ 341 | Hypernetwork that maps the decoder input to the parameters for mu, sigma, w. 342 | 343 | Args: 344 | config: Model configuration; must provide decoder_input_size. 345 | hidden_sizes: Sizes of the hidden layers. [] (the default) corresponds to a linear layer. 346 | param_sizes: Sizes of the output parameters, e.g. [n_components, n_components, n_components] giving the sizes (components) of w, mu and s respectively. Defaults to [1, 1, 1]. 347 | activation: Activation applied before every linear layer after the first (the output itself is not activated).
348 | """ 349 | def __init__(self, config, hidden_sizes=None, param_sizes=None, activation=nn.Tanh()): 350 | super().__init__() 351 | if param_sizes is None: 352 | param_sizes = [1, 1, 1] 353 | if hidden_sizes is None: 354 | hidden_sizes = [] 355 | self.decoder_input_size = config.decoder_input_size 356 | self.activation = activation 357 | 358 | # print("hidden_sizes:", hidden_sizes) # [] 359 | # print("param_sizes:", param_sizes) # [64, 64, 64] 360 | # Indices for unpacking parameters 361 | ends = torch.cumsum(torch.tensor(param_sizes), dim=0) 362 | starts = torch.cat((torch.zeros(1).type_as(ends), ends[:-1])) 363 | self.param_slices = [slice(s.item(), e.item()) for s, e in zip(starts, ends)] 364 | # e.g. param_sizes [64, 64, 64] -> [slice(0, 64, None), slice(64, 128, None), slice(128, 192, None)] 365 | 366 | self.output_size = sum(param_sizes) 367 | layer_sizes = list(hidden_sizes) + [self.output_size] 368 | # print("Hypernet layer_sizes:", layer_sizes) # [192] 369 | # Bias used in the first linear layer 370 | self.first_bias = nn.Parameter(torch.empty(layer_sizes[0]).uniform_(-0.1, 0.1)) 371 | self.first_linear = nn.Linear(self.decoder_input_size, layer_sizes[0], bias=False) 372 | 373 | # Remaining linear layers 374 | self.linear_layers = nn.ModuleList() 375 | for idx, size in enumerate(layer_sizes[:-1]): 376 | self.linear_layers.append(nn.Linear(size, layer_sizes[idx + 1])) 377 | 378 | def reset_parameters(self): # not called by __init__; call explicitly to zero the first bias and re-init weights orthogonally 379 | self.first_bias.data.fill_(0.0) 380 | self.first_linear.reset_parameters() 381 | nn.init.orthogonal_(self.first_linear.weight) 382 | for layer in self.linear_layers: 383 | layer.reset_parameters() 384 | nn.init.orthogonal_(layer.weight) 385 | 386 | def forward(self, decoder_input): 387 | """Generate model parameters from the embeddings. 388 | 389 | Args: 390 | decoder_input: decoder input, shape (..., decoder_input_size) 391 | 392 | Returns: 393 | params: Tuple with one tensor per entry of param_sizes, sliced from the last dimension.
394 | """ 395 | # Generate the output based on the input 396 | hidden = self.first_bias 397 | hidden = hidden + self.first_linear(decoder_input) 398 | for layer in self.linear_layers: 399 | hidden = layer(self.activation(hidden)) 400 | 401 | # # Partition the output 402 | # if len(self.param_slices) == 1: 403 | # return hidden 404 | # else: 405 | return tuple([hidden[..., s] for s in self.param_slices]) 406 | -------------------------------------------------------------------------------- /model/llm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForMaskedLM 3 | from peft import LoraModel, LoraConfig 4 | 5 | from .base import * 6 | from .bert import get_batch_mask 7 | from .phi_model import PhiModel 8 | from .positional_encoding import LearnableFourierFeatures as LFF 9 | 10 | 11 | def get_encoder(model_path, model_class): 12 | if model_class == 'tinybert': 13 | lora_config = LoraConfig( 14 | task_type="SEQ_2_SEQ_LM", 15 | r=8, # Lora attention dimension. 16 | lora_alpha=32, # The alpha parameter for Lora scaling. 17 | target_modules=["query", "value"], # The names of the modules to apply Lora to. 18 | lora_dropout=0.01, # The dropout probability for Lora layers. 19 | ) 20 | model = AutoModelForMaskedLM.from_pretrained(model_path) 21 | tokenizer = AutoTokenizer.from_pretrained(model_path) 22 | emb_size = model.config.emb_size 23 | hidden_size = model.config.hidden_size 24 | 25 | elif model_class == 'phi-2': 26 | lora_config = LoraConfig( 27 | task_type="CAUSAL_LM", 28 | r=8, # Lora attention dimension. 29 | lora_alpha=32, # The alpha parameter for Lora scaling. 30 | target_modules=["Wqkv"], # The names of the modules to apply Lora to. 31 | lora_dropout=0.01, # The dropout probability for Lora layers. 
32 | ) 33 | model = CustomPhiModel.from_pretrained( 34 | model_path, 35 | torch_dtype=torch.float32, 36 | # device_map="cuda", 37 | trust_remote_code=True 38 | ) 39 | tokenizer = AutoTokenizer.from_pretrained(model_path) 40 | emb_size = model.config.n_embd 41 | hidden_size = model.config.n_embd 42 | 43 | elif model_class == 'gpt2': 44 | lora_config = LoraConfig( 45 | task_type="CAUSAL_LM", 46 | r=8, # Lora attention dimension. 47 | lora_alpha=32, # The alpha parameter for Lora scaling. 48 | target_modules=["c_attn"], # The names of the modules to apply Lora to. 49 | lora_dropout=0.02, # The dropout probability for Lora layers. 50 | ) 51 | model = AutoModelForCausalLM.from_pretrained(model_path) 52 | tokenizer = AutoTokenizer.from_pretrained(model_path) 53 | emb_size = model.config.n_embd 54 | hidden_size = model.config.n_embd 55 | 56 | elif model_class == 'TinyLlama-1_1B': 57 | lora_config = LoraConfig( 58 | task_type="CAUSAL_LM", 59 | r=8, # Lora attention dimension. 60 | lora_alpha=32, # The alpha parameter for Lora scaling. 61 | target_modules=["q_proj","k_proj"], # The names of the modules to apply Lora to. 62 | lora_dropout=0.02, # The dropout probability for Lora layers. 63 | ) 64 | model = CustomLlamaModel(model_path='params/TinyLlama-1.1B') 65 | tokenizer = AutoTokenizer.from_pretrained('params/TinyLlama-1.1B') 66 | emb_size = 2048 67 | hidden_size = 2048 68 | 69 | elif model_class == 'TinyLlama-Chat': 70 | lora_config = LoraConfig( 71 | task_type="CAUSAL_LM", 72 | r=8, # Lora attention dimension. 73 | lora_alpha=32, # The alpha parameter for Lora scaling. 74 | target_modules=["q_proj","k_proj"], # The names of the modules to apply Lora to. 75 | lora_dropout=0.02, # The dropout probability for Lora layers. 
76 | ) 77 | model = CustomLlamaModel(model_path= 'params/TinyLlama-Chat') 78 | tokenizer = AutoTokenizer.from_pretrained('params/TinyLlama-Chat') 79 | emb_size = 2048 80 | hidden_size = 2048 81 | 82 | elif model_class == 'pythia-70M': 83 | lora_config = LoraConfig( 84 | task_type="CAUSAL_LM", 85 | r=8, # Lora attention dimension. 86 | lora_alpha=32, # The alpha parameter for Lora scaling. 87 | target_modules=["query_key_value"], # The names of the modules to apply Lora to. 88 | lora_dropout=0.02, # The dropout probability for Lora layers. 89 | ) 90 | model = CustomPythiaModel(model_path= 'params/pythia-70M') 91 | tokenizer = AutoTokenizer.from_pretrained('params/pythia-70M') 92 | emb_size = 512 93 | hidden_size = 512 94 | 95 | elif model_class == 'pythia-2_8B': 96 | lora_config = LoraConfig( 97 | task_type="CAUSAL_LM", 98 | r=8, # Lora attention dimension. 99 | lora_alpha=32, # The alpha parameter for Lora scaling. 100 | target_modules=["query_key_value"], # The names of the modules to apply Lora to. 101 | lora_dropout=0.02, # The dropout probability for Lora layers. 102 | ) 103 | model = CustomPythiaModel(model_path= 'params/pythia-2.8B') 104 | tokenizer = AutoTokenizer.from_pretrained('params/pythia-2.8B') 105 | emb_size = 2560 106 | hidden_size = 2560 107 | 108 | elif model_class == 'pythia-1B': 109 | lora_config = LoraConfig( 110 | task_type="CAUSAL_LM", 111 | r=8, # Lora attention dimension. 112 | lora_alpha=32, # The alpha parameter for Lora scaling. 113 | target_modules=["query_key_value"], # The names of the modules to apply Lora to. 114 | lora_dropout=0.02, # The dropout probability for Lora layers. 115 | ) 116 | model = CustomPythiaModel(model_path= 'params/pythia-1B') 117 | tokenizer = AutoTokenizer.from_pretrained('params/pythia-1B') 118 | emb_size = 2048 119 | hidden_size = 2048 120 | 121 | elif model_class == 'LiteLlama': 122 | lora_config = LoraConfig( 123 | task_type="CAUSAL_LM", 124 | r=8, # Lora attention dimension. 
125 | lora_alpha=32, # The alpha parameter for Lora scaling. 126 | target_modules=["q_proj","k_proj"], # The names of the modules to apply Lora to. 127 | lora_dropout=0.02, # The dropout probability for Lora layers. 128 | ) 129 | model = CustomLlamaModel(model_path= 'params/LiteLlama') 130 | tokenizer = AutoTokenizer.from_pretrained('params/LiteLlama') 131 | emb_size = 1024 132 | hidden_size = 1024 133 | 134 | else: 135 | raise NotImplementedError("model_class should be one of ['tinybert', 'phi-2']") 136 | 137 | return LoraModel(model, lora_config, model_class), tokenizer, emb_size, hidden_size 138 | # return model, emb_size, hidden_size 139 | 140 | 141 | class CustomPhiModel(PhiModel): 142 | """ Phi for traj modeling """ 143 | 144 | _keys_to_ignore_on_load_missing = ["ladder_side_nets", "up_net"] 145 | # _keys_to_ignore_on_load_unexpected = [r"transformer\.h\.\d+\.mlp.(fc_in|fc_out)\.(weight|bias)"] 146 | 147 | def __init__(self, config, r=32): 148 | super().__init__(config) 149 | 150 | assert config.n_embd % r == 0, f"n_embd should be divisible by r, got {config.n_embd} and {r}" 151 | side_dim = config.n_embd // r 152 | 153 | self.side_dim = side_dim 154 | self.ladder_side_nets = nn.ModuleList([nn.Linear(config.n_embd, side_dim) for _ in range(config.n_layer)]) 155 | self.up_net = nn.Linear(side_dim, config.n_embd) 156 | 157 | def forward(self, input_ids=None, inputs_embeds=None, past_key_values=None, attention_mask=None): 158 | if inputs_embeds is None: 159 | assert input_ids is not None, "You have to specify either input_ids or inputs_embeds" 160 | hidden_states = self.embd(input_ids) 161 | else: 162 | hidden_states = inputs_embeds 163 | 164 | #hidden_states_backbone = hidden_states.detach() 165 | #side_states = torch.zeros(*hidden_states.shape[:-1], self.side_dim, 166 | # dtype=hidden_states.dtype).to(hidden_states.device) 167 | for i, layer in enumerate(self.h): 168 | hidden_states = layer( 169 | hidden_states, 170 | past_key_values=past_key_values, 171 | 
attention_mask=attention_mask, 172 | ) 173 | 174 | #hidden_states = self.up_net(side_states) + hidden_states 175 | return hidden_states 176 | 177 | class CustomLlamaModel(nn.Module): 178 | """ Phi for traj modeling """ 179 | 180 | 181 | def __init__(self, model_path): 182 | super().__init__() 183 | 184 | self.model = AutoModelForCausalLM.from_pretrained(model_path).model 185 | self.embd = self.model.embed_tokens 186 | 187 | def forward(self, input_ids=None, inputs_embeds=None, past_key_values=None, attention_mask=None): 188 | past_key_values_length = 0 189 | 190 | device = input_ids.device if input_ids is not None else inputs_embeds.device 191 | position_ids = torch.arange(0, inputs_embeds.shape[1], dtype=torch.long, device=device) 192 | position_ids = position_ids.unsqueeze(0) 193 | 194 | hidden_states = inputs_embeds 195 | 196 | # decoder layers 197 | all_hidden_states = () 198 | 199 | 200 | for decoder_layer in self.model.layers: 201 | 202 | all_hidden_states += (hidden_states,) 203 | 204 | layer_outputs = decoder_layer( 205 | hidden_states, 206 | attention_mask=attention_mask, 207 | position_ids=position_ids, 208 | past_key_value=past_key_values, 209 | 210 | ) 211 | 212 | hidden_states = layer_outputs[0] 213 | 214 | hidden_states = self.model.norm(hidden_states) 215 | # add hidden states from the last decoder layer 216 | all_hidden_states += (hidden_states,) 217 | return all_hidden_states 218 | 219 | class CustomPythiaModel(nn.Module): 220 | """ Phi for traj modeling """ 221 | 222 | 223 | def __init__(self, model_path): 224 | super().__init__() 225 | 226 | self.model = AutoModelForCausalLM.from_pretrained(model_path).gpt_neox 227 | self.embd = self.model.embed_in 228 | # 模型中有self.emb_dropout,是否也要加? 
229 | # forword函数相关代码为: 230 | # if inputs_embeds is None: 231 | # inputs_embeds = self.embed_in(input_ids) 232 | # hidden_states = self.emb_dropout(inputs_embeds) 233 | 234 | 235 | def forward(self, input_ids=None, inputs_embeds=None, past_key_values=None, attention_mask=None): 236 | past_key_values_length = 0 237 | 238 | device = input_ids.device if input_ids is not None else inputs_embeds.device 239 | position_ids = torch.arange(0, inputs_embeds.shape[1], dtype=torch.long, device=device) 240 | position_ids = position_ids.unsqueeze(0) 241 | 242 | hidden_states = inputs_embeds 243 | 244 | # decoder layers 245 | all_hidden_states = () 246 | past_length = 0 247 | past_key_values = tuple([None] * self.model.config.num_hidden_layers) 248 | 249 | for i, (layer, layer_past) in enumerate(zip(self.model.layers, past_key_values)): 250 | 251 | all_hidden_states = all_hidden_states + (hidden_states,) 252 | 253 | 254 | 255 | outputs = layer( 256 | hidden_states, 257 | attention_mask=attention_mask, 258 | position_ids=position_ids, 259 | 260 | layer_past=layer_past, 261 | 262 | ) 263 | hidden_states = outputs[0] 264 | 265 | hidden_states = self.model.final_layer_norm(hidden_states) 266 | # Add last hidden state 267 | 268 | all_hidden_states = all_hidden_states + (hidden_states,) 269 | return all_hidden_states 270 | 271 | class LLMModel(Encoder): 272 | def __init__(self, model_path,loc_size,device,output_size=256, 273 | 274 | model_class='gpt2', learnable_param_size=1,): 275 | super().__init__('LLM-gpt2') 276 | 277 | self.output_size = output_size 278 | self.loc_size=loc_size 279 | self.model_class = model_class 280 | self.device = device 281 | self.learnable_param_size=learnable_param_size 282 | 283 | self.encoder, self.tokenizer, self.emb_size, self.hidden_size = get_encoder(model_path, model_class) 284 | self.tokenizer.pad_token = self.tokenizer.eos_token 285 | 286 | # Froze the parameters. 
287 | for i, (name, param) in enumerate(self.encoder.named_parameters()): 288 | #param.requires_grad = False 289 | 290 | # if 'side_net' in name or 'up_net' in name: 291 | # param.requires_grad = True 292 | # else: 293 | # param.requires_grad = False 294 | 295 | # if 'ln' in name or 'wpe' in name: # or 'mlp' in name: 296 | # param.requires_grad = True 297 | # elif 'mlp' in name: 298 | # param.requires_grad = False 299 | # else: 300 | # param.requires_grad = False 301 | print(name, param.requires_grad) 302 | 303 | self.embedder = nn.Embedding(self.loc_size,self.emb_size) # location id -> LLM embedding space 304 | self.out_linear = nn.Sequential(nn.Linear(self.hidden_size, output_size, bias=False), 305 | nn.LayerNorm(output_size), 306 | nn.ReLU(inplace=True), 307 | nn.Linear(output_size, output_size)) 308 | # # (temporary) open question: choice of f_dim and h_dim 309 | # self.lff = LFF(pos_dim=1, f_dim=128, h_dim=256, d_dim=self.emb_size) # learnable fourier features module 310 | self.time_linear = nn.Sequential(nn.Linear(1, 16),nn.LayerNorm(16), # scalar time -> 16-d embedding 311 | nn.ReLU(inplace=True),nn.Linear(16, 16)) 312 | self.cat_linear = nn.Sequential(nn.Linear(self.emb_size+16, self.emb_size),nn.LayerNorm(self.emb_size), # concat(location emb, 16-d time emb) -> emb_size 313 | nn.ReLU(inplace=True), 314 | nn.Linear(self.emb_size, self.emb_size)) 315 | 316 | # TEMPO-style learnable suffix tokens (learnable_param_size x emb_size) 317 | self.cls_token = nn.Parameter(torch.zeros(self.learnable_param_size, self.emb_size).float(), requires_grad=True) 318 | 319 | def forward(self, x, valid_len, time, category, geohash_, **kwargs): 320 | return self.forward_suffix(x, valid_len, time, category, geohash_, self.cls_token) 321 | 322 | def forward_suffix(self, x, valid_len, time, category, geohash_, tokens): 323 | """ P-tuning-like suffix forward: the learnable tokens are written right after each sequence's valid prefix.
""" 324 | B, L, E_in = x.unsqueeze(-1).shape # time.shape=[B,L], category.shape=[B,L] 325 | 326 | # TEMPO-style improvement 327 | trip_batch_mask = get_batch_mask(B, L+self.learnable_param_size, valid_len) # presumably True for the first valid_len positions of each row - TODO confirm 328 | batch_mask = get_batch_mask(B, L+self.learnable_param_size, valid_len+self.learnable_param_size) # same, extended by learnable_param_size 329 | 330 | masked_values = torch.zeros_like(x) 331 | 332 | x = torch.where(get_batch_mask(B, L, valid_len).unsqueeze(-1), x.unsqueeze(-1), masked_values.unsqueeze(-1)) 333 | x_embeddings = self.embedder(x).squeeze(-2) # (B,L,self.emb_size) 334 | # TEST 335 | ''' 336 | # Basic usage of Learnable-Fourier-Features 337 | lff = LFF(pos_dim=2, f_dim=128, h_dim=256, d_dim=64) # learnable fourier features module 338 | pos = torch.randn([4, 1024, 1, 2]) # random positional coordinates 339 | pe = lff(pos) # forward 340 | ''' #128*173 341 | # time_embeddings = self.lff(time.unsqueeze(-1).unsqueeze(-1)) # (B L G M) → (B L D) 342 | # x_embeddings += time_embeddings * 0.01 343 | ## x_embeddings = self.cat_linear(torch.cat((x_embeddings, time_embeddings), dim=-1)) 344 | time_embeddings = self.time_linear(time.unsqueeze(-1)) 345 | 346 | category_vocabIds = [] 347 | for i, seq in enumerate(category): # the batch still has to be handled with a Python for loop 348 | input_seq = seq[:valid_len[i]] # padding in seq is the integer 0 (not a str), so first keep only the valid-length prefix 349 | if len(set(input_seq))==1 and '' in input_seq: # seq can look like ['', '', ..., 0, 0, 0, ...] (the trickiest case); temporarily replace it with EOS tokens 350 | input_seq = [self.tokenizer.eos_token] * valid_len[i] 351 | # for each sequence, keep the last vocab id of each category's tokenization 352 | # a loop-free version would have to pad the tokens first and then pick the last non-padding one, which may be no faster than this loop 353 | category_vocabId = self.tokenizer(input_seq, padding=True, return_tensors="pt")['input_ids'] 354 | last_index = torch.where(category_vocabId != self.tokenizer.eos_token_id, torch.full_like(category_vocabId, 1), 0).sum(dim=1) # number of non-EOS (non-pad) tokens per category 355 | category_vocabId = [int(category_vocabId[index][int(x.item()) - 1].item()) for index, x in enumerate(last_index)] 356 | ''' 357 | category_vocabId = [] 358 | for i in seq: 359 | if i == '': #
test时,若出现''要特殊处理一下,否则报错 360 | category_vocabId.append(torch.tensor(self.tokenizer.eos_token_id)) 361 | else: 362 | category_vocabId.append(self.tokenizer(i, return_tensors="pt")['input_ids'][-1][-1]) 363 | category_vocabId = torch.tensor(category_vocabId) 364 | ''' 365 | category_vocabIds.append(torch.tensor(category_vocabId+[self.tokenizer.eos_token_id]*(L-valid_len[i]))) 366 | category_vocabIds = torch.stack(category_vocabIds,dim=0).to(self.device) 367 | if self.model_class == 'gpt2': 368 | category_embeddings = self.encoder.model.transformer.wte(category_vocabIds) 369 | else: 370 | category_embeddings = self.encoder.model.embd(category_vocabIds) 371 | #x_embeddings += category_embeddings 372 | 373 | #x_embeddings = self.cat_linear(torch.cat((x_embeddings, time_embeddings, category_embeddings), dim=-1)) 374 | x_embeddings = self.cat_linear(torch.cat((x_embeddings, time_embeddings), dim=-1)) 375 | 376 | h = torch.zeros(B, L+self.learnable_param_size, self.emb_size).to(x.device) 377 | 378 | h[:, :-self.learnable_param_size] = x_embeddings 379 | # position = (batch_mask.long() - trip_batch_mask.long()) == 1 380 | # temp1 = h[(batch_mask.long() - trip_batch_mask.long()) == 1] 381 | # temp2 = tokens.repeat(B,1) 382 | h[(batch_mask.long() - trip_batch_mask.long()) == 1] = tokens.repeat(B,1) 383 | # h[(batch_mask.long() - trip_batch_mask.long()) == 1] = token 384 | if self.model_class == 'tinybert': 385 | h = self.encoder(inputs_embeds=h, attention_mask=batch_mask.unsqueeze(-1), output_hidden_states=True).hidden_states[-1] 386 | elif self.model_class == 'phi-2': 387 | h = self.encoder(inputs_embeds=h, attention_mask=batch_mask) 388 | elif self.model_class == 'gpt2': 389 | h = self.encoder(inputs_embeds=h, attention_mask=batch_mask.unsqueeze(-1), output_hidden_states=True).hidden_states[-1] 390 | elif self.model_class == 'LiteLlama': 391 | h = self.encoder(inputs_embeds=h, attention_mask=batch_mask.unsqueeze(-1).unsqueeze(1).expand(-1,-1,-1,batch_mask.shape[-1]))[0] 392 
| elif self.model_class == 'TinyLlama-Chat': 393 | h = self.encoder(inputs_embeds=h, attention_mask=batch_mask.unsqueeze(-1).unsqueeze(1).expand(-1,-1,-1,batch_mask.shape[-1]))[0] 394 | elif self.model_class == 'TinyLlama-1_1B': 395 | h = self.encoder(inputs_embeds=h, attention_mask=batch_mask.unsqueeze(-1).unsqueeze(1).expand(-1,-1,-1,batch_mask.shape[-1]))[0] 396 | elif self.model_class == 'pythia-70M': 397 | h = self.encoder(inputs_embeds=h, attention_mask=batch_mask.unsqueeze(-1).unsqueeze(1).expand(-1,-1,-1,batch_mask.shape[-1]))[0] 398 | elif self.model_class == 'pythia-2_8B': 399 | h = self.encoder(inputs_embeds=h, attention_mask=batch_mask.unsqueeze(-1).unsqueeze(1).expand(-1,-1,-1,batch_mask.shape[-1]))[0] 400 | elif self.model_class == 'pythia-1B': 401 | h = self.encoder(inputs_embeds=h, attention_mask=batch_mask.unsqueeze(-1).unsqueeze(1).expand(-1,-1,-1,batch_mask.shape[-1]))[0] 402 | h = torch.nan_to_num(h) 403 | #output = self.out_linear(h[:, -1]) 404 | output = self.out_linear(h) 405 | 406 | return output -------------------------------------------------------------------------------- /train_MobilityLLM.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import configparser 3 | # from tensorboardX import SummaryWriter 4 | import preprocess.load_data as preprocess 5 | from model.MobilityLLM import * 6 | from copy import deepcopy 7 | from utils import * 8 | from torch.nn.utils import clip_grad_norm_ 9 | import torch 10 | import torch.nn as nn 11 | import time 12 | from tqdm import tqdm 13 | import numpy as np 14 | 15 | #设置随机种子 16 | # 设置随机种子 17 | randomSeed = 202408 18 | torch.manual_seed(randomSeed) 19 | torch.cuda.manual_seed(randomSeed) 20 | torch.cuda.manual_seed_all(randomSeed) 21 | torch.backends.cudnn.deterministic = True 22 | torch.backends.cudnn.benchmark = False 23 | np.random.seed(randomSeed) 24 | 25 | # read hyper-param settings 26 | parser = argparse.ArgumentParser() 27 | 
parser.add_argument("--config", default='/data/ZhangXinyue/MobilityLLM/config/MobilityLLM_tky_TUL.conf', type=str,
                    help="configuration file path")
parser.add_argument("--dataroot", default='/data/ZhangXinyue/MobilityLLM/data/', type=str,
                    help="data root directory")
# The three options below previously all reused the help text "configuration file path".
parser.add_argument("--model_class", default='pythia-70M', type=str,
                    help="backbone language model name (e.g. pythia-70M, gpt2, phi-2, TinyLlama-1_1B)")
parser.add_argument("--device", default='1', type=str,
                    help="CUDA device index, used as cuda:<device> when CUDA is available")
parser.add_argument("--data_hist", default=1.0, type=float,
                    help="fraction of training batches used per epoch (1 = all batches)")
args = parser.parse_args()
config_file = args.config
data_root = args.dataroot
model_class = args.model_class
ctx = args.device  # kept as str: concatenated into the "cuda:<idx>" device string
data_hist = float(args.data_hist)
config = configparser.ConfigParser()
print('Read configuration file: %s' % (args.config))
print('>>>>>>> configuration <<<<<<<')
# Echo the raw configuration so every run log records the exact settings used.
with open(config_file, 'r', encoding='utf-8') as f:
    print(f.read())
print('\n')
config.read(config_file, encoding='utf-8')
data_config = config['Data']
training_config = config['Training']
model_config = config['Model']

# Data config
# NOTE(review): these values stay strings because they are concatenated into
# dataset_name below; max_delta_mins is later handed to the model config as a
# str as well -- confirm the model converts it before doing arithmetic.
dataset_name = data_config['dataset_name']
max_his_period_days = data_config['max_his_period_days']
max_merge_seconds_limit = data_config['max_merge_seconds_limit']
max_delta_mins = data_config['max_delta_mins']
min_session_mins = data_config['min_session_mins']
least_disuser_count = data_config['least_disuser_count']
least_checkins_count = data_config['least_checkins_count']
latN = data_config['latN']
lngN = data_config['lngN']
split_save = bool(int(data_config['split_save']))
# Encode the preprocessing settings into the dataset name,
# e.g. "www_WEE_30H3600M1440d120s10P10U".
dataset_name = (f"{dataset_name}_{max_his_period_days}H{max_merge_seconds_limit}M"
                f"{max_delta_mins}d{min_session_mins}s{least_disuser_count}P"
                f"{least_checkins_count}U")

# Training config
mode = training_config['mode'].strip()
# The device index comes from the --device CLI option (ctx above), not from
# training_config['ctx'].
# NOTE: exporting the device index through CUDA_VISIBLE_DEVICES was disabled
# because it raised errors while debugging; the index is used in torch.device():
# os.environ["CUDA_VISIBLE_DEVICES"] = ctx
USE_CUDA = torch.cuda.is_available()
print("CUDA:", USE_CUDA, ctx)
device = torch.device("cuda:" + ctx if USE_CUDA else "cpu")
print('device:', device)
use_nni = bool(int(training_config['use_nni']))
regularization = float(training_config['regularization'])
learning_rate = float(training_config['learning_rate'])
max_epochs = int(training_config['max_epochs'])
display_step = int(training_config['display_step'])
patience = int(training_config['patience'])
train_batch = int(training_config['train_batch'])
val_batch = int(training_config['val_batch'])
test_batch = int(training_config['test_batch'])
batch_size = int(training_config['batch_size'])
save_results = bool(int(training_config['save_results']))

specific_config = 'MobilityLLM'

# Model Setting
loc_emb_size = int(model_config['loc_emb_size'])
tim_emb_size = int(model_config['tim_emb_size'])
user_emb_size = int(model_config['user_emb_size'])
hidden_size = int(model_config['hidden_size'])

category_size = int(model_config['category_size'])
geohash_size = int(model_config['geohash_size'])

# Optional key: defaults to 1 when the .conf file does not set it.
learnable_param_size = model_config.getint('learnable_param_size', fallback=1)

adv = int(model_config['adv'])
rnn_type = model_config['rnn_type']
num_layers = int(model_config['num_layers'])
downstream = model_config['downstream']

loc_noise_mean = float(model_config['loc_noise_mean'])
loc_noise_sigma = float(model_config['loc_noise_sigma'])
tim_noise_mean = float(model_config['tim_noise_mean'])
tim_noise_sigma = float(model_config['tim_noise_sigma'])
user_noise_mean = float(model_config['user_noise_mean'])
user_noise_sigma = float(model_config['user_noise_sigma'])
tau = float(model_config['tau'])
pos_eps = float(model_config['pos_eps'])
neg_eps = float(model_config['neg_eps'])
dropout_rate_1 = float(model_config['dropout_rate_1'])
dropout_rate_2 = dropout_rate_1  # one configured rate is shared by both dropout layers

momentum = float(model_config['momentum'])
theta = float(model_config['theta'])
temperature = float(model_config['temperature'])
k = int(model_config['k'])
self_weight_s = float(model_config['self_weight_s'])
self_weight_t = float(model_config['self_weight_t'])
self_weight_st = float(model_config['self_weight_st'])

dump_path = 'checkpoints'
rank = model_config['rank']  # kept as str: concatenated into the queue checkpoint file name
epoch_queue_starts = int(model_config['epoch_queue_starts'])
crops_for_assign = [0, 1]
feat_dim = int(model_config['feat_dim'])
queue_length = int(model_config['queue_length'])
world_size = int(model_config['world_size'])
loss = model_config['loss']
tpp = model_config['tpp']
epsilon = float(model_config['epsilon'])
dropout_spatial = float(model_config['dropout_spatial'])

if use_nni:
    # NNI hyper-parameter search: values proposed by the tuner override the .conf file.
    import nni
    param = nni.get_next_parameter()
    # multi-dataset
    batch_size = int(param['batch_size'])
    hidden_size = int(param['hidden_size'])
    user_emb_size = int(param['user_emb_size'])
    category_size = int(param['category_size'])
    geohash_size = int(param['geohash_size'])
    num_layers = int(param['num_layers'])
    momentum = float(param['momentum'])
    theta = float(param['theta'])
    temperature = float(param['temperature'])
    k = int(param['k'])
    self_weight_s = float(param['self_weight_s'])
    self_weight_t = float(param['self_weight_t'])
    self_weight_st = float(param['self_weight_st'])
    epsilon = float(param['epsilon'])
    dropout_spatial = float(param['dropout_spatial'])

# NOTE(review): the original indentation of these three lines was lost in the
# source dump. As written they always replace the per-split batch sizes read
# above with batch_size. Confirm whether they were meant to sit inside
# `if use_nni:`.
train_batch = batch_size
val_batch = batch_size
test_batch = batch_size

print('load dataset:', dataset_name)
print('split_save:', split_save)

# Data
# Category lookup archive for each dataset; unlisted datasets fall back to NYC.
_CATEGORY_FILES = {
    "www_NYC": "nyc_cnt2category2cnt.npz",
    "TSMC_www_NYC": "nyc_cnt2category2cnt.npz",
    "www_JKT": "jkt_cnt2category2cnt.npz",
    "www_IST": "ist_cnt2category2cnt.npz",
    "www_TKY": "tky_cnt2category2cnt.npz",
    "TSMC_www_TKY": "tky_cnt2category2cnt.npz",
}
category_file = _CATEGORY_FILES.get(data_config['dataset_name'], "nyc_cnt2category2cnt.npz")

# The archive stores a 0-d object array that wraps a dict, so allow_pickle=True
# is required. Only load .npz files from a trusted data_root.
# The context manager closes the file handle. .item() materialises the dict
# first, so nothing is read from the archive after it is closed.
with np.load(data_root + category_file, allow_pickle=True) as data:
    cnt2category = data['cnt2category'].item()  # category index -> category

print('Loading data & Category vector...')
data_train, data_val, data_test, feature_category, feature_lat, feature_lng, latN, lngN, category_cnt, category_vector = preprocess.load_dataset_for_MobilityLLM(
    dataset_name, save_split=split_save, data_root=data_root, device=device)
print("feature_category: ", feature_category.shape,
      feature_category)  # feature_category[venue's index] -> venue's category's index

# Mean/std of the training targets' tau (judging by the method name, of its log).
# They are passed to the model as shift_init/scale_init for its affine
# normalisation layer.
trainY_tau_mean, trainY_tau_std = data_train.get_tau_log_mean_std_Y()
print('trainY_tau_mean:', trainY_tau_mean, flush=True) 198 | print('trainY_tau_std:', trainY_tau_std, flush=True) 199 | 200 | collate = preprocess.collate_session_based # padding sequence with variable len 201 | 202 | dl_train = torch.utils.data.DataLoader(data_train, batch_size=train_batch, shuffle=True, 203 | collate_fn=collate, drop_last=True) 204 | dl_val = torch.utils.data.DataLoader(data_val, batch_size=val_batch, shuffle=False, collate_fn=collate, drop_last=True) 205 | dl_test = torch.utils.data.DataLoader(data_test, batch_size=test_batch, shuffle=False, collate_fn=collate, drop_last=True) 206 | #训练集比例裁剪 207 | all_train_length=len(dl_train) 208 | new_train_length=int(data_hist*all_train_length) 209 | 210 | # Model setup 211 | print('Building model...', flush=True) 212 | # General model config 213 | tim_size = 48 214 | fill_value = 0 215 | use_semantic = True 216 | if use_semantic: 217 | print("Use semantic information from venue name!") 218 | else: 219 | print("Don't use semantic information from venue name!") 220 | general_config = MobilityLLM_ModelConfig(loc_size=int(data_train.venue_cnt), tim_size=tim_size, 221 | uid_size=int(data_train.user_cnt), tim_emb_size=tim_emb_size, 222 | loc_emb_size=loc_emb_size, hidden_size=hidden_size, user_emb_size=user_emb_size, 223 | model_class=model_class, device=device, 224 | geohash_size=geohash_size, category_size=category_size, 225 | loc_noise_mean=loc_noise_mean, loc_noise_sigma=loc_noise_sigma, 226 | tim_noise_mean=tim_noise_mean, tim_noise_sigma=tim_noise_sigma, 227 | user_noise_mean=user_noise_mean, user_noise_sigma=user_noise_sigma, tau=tau, 228 | momentum=momentum, k=k, theta=theta, temperature=temperature, 229 | pos_eps=pos_eps, neg_eps=neg_eps, dropout_rate_1=dropout_rate_1, 230 | dropout_rate_2=dropout_rate_2, category_vector=category_vector, rnn_type=rnn_type, num_layers=num_layers, downstream=downstream, 231 | scale_init=trainY_tau_std, shift_init=trainY_tau_mean, max_delta_mins=max_delta_mins, loss=loss, 
tpp=tpp, dropout_spatial = dropout_spatial, 232 | epsilon=epsilon, learnable_param_size=learnable_param_size) 233 | # Define model 234 | model = MobilityLLM(general_config).to(device) 235 | print(model, flush=True) 236 | 237 | params_path = os.path.join('experiments', dataset_name.replace('(', '').replace(')', ''), specific_config) 238 | print('params_path:', params_path) 239 | 240 | if use_nni: 241 | exp_id = nni.get_experiment_id() 242 | trail_id = nni.get_trial_id() 243 | best_name = str(exp_id) + '.' + str(trail_id) + downstream + 'best.params' 244 | params_filename = os.path.join(params_path, best_name) 245 | else: 246 | best_name = downstream + '_best_'+model_class+'.params' 247 | params_filename = os.path.join(params_path, best_name) 248 | 249 | if mode == 'train': 250 | for p in model.parameters(): 251 | if p.dim() > 1: 252 | nn.init.xavier_uniform_(p) 253 | 254 | # # 没有temporal_encoder,注释掉 255 | # for p_t in model.temporal_encoder_momentum.parameters(): 256 | # p_t.requires_grad = False 257 | 258 | total_params = sum(p.numel() for p in model.parameters()) 259 | total_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) 260 | print("total_params:", total_params, flush=True) 261 | print("total_trainable_params:", total_trainable_params, flush=True) 262 | 263 | if os.path.exists(params_path): 264 | # shutil.rmtree(params_path) 265 | # os.makedirs(params_path) 266 | # print('delete the old one and create params directory %s' % (params_path), flush=True) 267 | print('already exist %s' % (params_path), flush=True) 268 | else: 269 | os.makedirs(params_path) 270 | print('create params directory %s' % (params_path), flush=True) 271 | 272 | print('Starting training...', flush=True) 273 | 274 | # build the queue 275 | queue = None 276 | queue_path = os.path.join(dump_path, "queue" + rank + ".pth") 277 | # if os.path.isfile(queue_path): 278 | # queue = torch.load(queue_path)["queue"] 279 | # # the queue needs to be divisible by the batch 
size 280 | # queue_length -= queue_length % (batch_size * world_size) 281 | 282 | impatient = 0 283 | best_hit20 = -np.inf 284 | best_tnll = np.inf 285 | best_model = deepcopy(model.state_dict()) 286 | global_step = 0 287 | best_epoch = -1 288 | # sw = SummaryWriter(logdir=params_path, flush_secs=5) 289 | opt = torch.optim.Adam(model.parameters(), weight_decay=regularization, lr=learning_rate, amsgrad=True) 290 | # opt = Adafactor(model.parameters()) 291 | start = time.time() 292 | for epoch in range(0, max_epochs): 293 | model.train() 294 | batch_cnt = 0 295 | # optionally starts a queue 296 | if queue_length > 0 and epoch >= epoch_queue_starts and queue is None: 297 | queue = torch.zeros( 298 | len(crops_for_assign), 299 | -queue_length // world_size, 300 | feat_dim, 301 | ).cuda() 302 | train_count=0 303 | for input in tqdm(dl_train): 304 | train_count+=1 305 | if train_count<=new_train_length: 306 | opt.zero_grad() 307 | # cosine similarity matrix can be big (up to 5GiB), consider cutting. 
308 | batch_cnt += 1 309 | if input.X_all_loc.shape[1] >= 700: 310 | print(f'batch: {batch_cnt}, length: {input.X_all_loc.shape[1]}') 311 | continue 312 | if adv == 1: 313 | s_loss_score, top_k_pred, queue = model(input, mode='train', downstream=downstream, cont_conf=[1, 1, 1, 1], queue = queue) 314 | # print(top_k_pred, y) 315 | loss_total = (1 - self_weight_s - self_weight_t - self_weight_st) * s_loss_score 316 | else: 317 | s_loss_score, top_k_pred, queue = model(input, mode='train', downstream=downstream, queue = queue) 318 | loss_total = s_loss_score 319 | # print(s_loss_score) 320 | loss_total.backward() 321 | clip_grad_norm_(model.parameters(), max_norm=1.0) 322 | opt.step() 323 | 324 | # momentum update 325 | if model.momentum > 0: 326 | pass 327 | # # 没有temporal_encoder,注释掉 328 | # for param_momentum, param in zip(model.temporal_encoder_momentum.parameters(), model.temporal_encoder.parameters()): 329 | # param_momentum.data = param_momentum.data * model.momentum + (1. - model.momentum) * param.data 330 | 331 | global_step += 1 332 | if downstream == 'POI': 333 | ys = input.Y_location 334 | elif downstream == 'TUL': 335 | ys = input.X_users # (batch,) 336 | elif downstream == 'TPP': 337 | pass 338 | else: 339 | raise ValueError('downstream is not in [POI, TUL, TPP]') 340 | 341 | # print(f'batch: {batch_cnt}, total loss: {loss_total}, loss_ls: {[i for i in cont_loss_ls]}') 342 | loss_total = loss_total.item() 343 | # sw.add_scalar('training_loss_s', loss_total, global_step) 344 | if downstream != 'TPP': 345 | hit_ratio, mrr = evaluate_location(ys.cpu().numpy(), top_k_pred.cpu().numpy()) # [k] 346 | # sw.add_scalar('training_mrr', mrr, global_step) 347 | # sw.add_scalar('training_hit_1', hit_ratio[0], global_step) 348 | # sw.add_scalar('training_hit_20', hit_ratio[19], global_step) 349 | if queue is not None: 350 | torch.save({"queue": queue}, queue_path) 351 | 352 | model.eval() 353 | with torch.no_grad(): 354 | if downstream != 'TPP': 355 | all_loss_s_val, 
hit_ratio_val, mrr_val = get_s_baselines_total_loss_s_for_MobilityLLM_DOWN(dl_val, model, downstream=downstream) 356 | if (hit_ratio_val[19] - best_hit20) < 1e-4: 357 | impatient += 1 358 | if best_hit20 < hit_ratio_val[19]: 359 | best_hit20 = hit_ratio_val[19] 360 | best_model = deepcopy(model.state_dict()) 361 | best_epoch = epoch 362 | else: 363 | best_hit20 = hit_ratio_val[19] 364 | best_model = deepcopy(model.state_dict()) 365 | best_epoch = epoch 366 | impatient = 0 367 | else: 368 | mae_val, mape_val, rmse_val, nll_t_val = get_t_for_IFLTPP(dl_val, model) 369 | if (best_tnll - nll_t_val) < 1e-4: 370 | impatient += 1 371 | if nll_t_val < best_tnll: 372 | best_tnll = nll_t_val 373 | best_model = deepcopy(model.state_dict()) 374 | best_epoch = epoch 375 | else: 376 | best_tnll = nll_t_val 377 | best_model = deepcopy(model.state_dict()) 378 | best_epoch = epoch 379 | impatient = 0 380 | 381 | if impatient >= patience: 382 | print('Breaking due to early stopping at epoch %d,best epoch at %d' % (epoch, best_epoch), flush=True) 383 | break 384 | 385 | if epoch % display_step == 0: 386 | if adv == 1: 387 | if downstream != 'TPP': 388 | print('Epoch %4d, train_loss=%.4f, val_loss=%.4f, val_mrr=%.4f, val_hit_1=%.4f, val_hit_20=%.4f' % ( 389 | epoch, loss_total, all_loss_s_val, mrr_val, hit_ratio_val[0], hit_ratio_val[19]), flush=True) 390 | else: 391 | print('Epoch %4d, train_tnll=%.4f, val_tnll=%.4f, val_mae=%.4f, val_rmse=%.4f, val_mape=%.4f' % ( 392 | epoch, loss_total, nll_t_val, mae_val, rmse_val, mape_val), flush=True) 393 | else: 394 | if downstream != 'TPP': 395 | print('Epoch %4d, train_loss=%.4f, val_loss=%.4f, val_mrr=%.4f, val_hit_1=%.4f, val_hit_20=%.4f' % ( 396 | epoch, loss_total, all_loss_s_val, mrr_val, hit_ratio_val[0], hit_ratio_val[19]), flush=True) 397 | else: 398 | print('Epoch %4d, train_tnll=%.4f, val_tnll=%.4f, val_mae=%.4f, val_rmse=%.4f, val_mape=%.4f' % ( 399 | epoch, loss_total, nll_t_val, mae_val, rmse_val, mape_val), flush=True) 400 | 401 
| if use_nni: 402 | if downstream != 'TPP': 403 | nni.report_intermediate_result(hit_ratio_val[19]) 404 | else: 405 | nni.report_intermediate_result(mae_val) 406 | 407 | torch.save(best_model, params_filename) 408 | 409 | print("best epoch at %d" % best_epoch, flush=True) 410 | print('save parameters to file: %s' % params_filename, flush=True) 411 | print("training time: ", time.time() - start) 412 | 413 | ### Evaluation 414 | print('----- test ----') 415 | model.load_state_dict(torch.load(params_filename)) 416 | model.eval() 417 | with torch.no_grad(): 418 | if downstream != 'TPP': 419 | train_all_loss_s, train_hit_ratio, train_mrr = get_s_baselines_total_loss_s_for_MobilityLLM_DOWN(dl_train, model, downstream=downstream) 420 | val_all_loss_s, val_hit_ratio, val_mrr = get_s_baselines_total_loss_s_for_MobilityLLM_DOWN(dl_val, model, downstream=downstream) 421 | test_all_loss_s, test_hit_ratio, test_mrr = get_s_baselines_total_loss_s_for_MobilityLLM_DOWN(dl_test, model, downstream=downstream) 422 | 423 | print('Dataset\t loss\t hit_1\t hit_3\t hit_5\t hit_7\t hit_10\t hit_15\t hit_20\t MRR\t\n' + 424 | 'Train:\t %.4f\t %.4f\t %.4f\t %.4f\t %.4f\t %.4f\t %.4f\t %.4f\t %.4f\t\n' % ( 425 | train_all_loss_s, train_hit_ratio[0], train_hit_ratio[2], train_hit_ratio[4], train_hit_ratio[6], 426 | train_hit_ratio[9], train_hit_ratio[14], train_hit_ratio[19], train_mrr) + 427 | 'Val:\t %.4f\t %.4f\t %.4f\t %.4f\t %.4f\t %.4f\t %.4f\t %.4f\t %.4f\t\n' % ( 428 | val_all_loss_s, val_hit_ratio[0], val_hit_ratio[2], val_hit_ratio[4], val_hit_ratio[6], val_hit_ratio[9], 429 | val_hit_ratio[14], val_hit_ratio[19], val_mrr) + 430 | 'Test:\t %.4f\t %.4f\t %.4f\t %.4f\t %.4f\t %.4f\t %.4f\t %.4f\t %.4f\t\n' % ( 431 | test_all_loss_s, test_hit_ratio[0], test_hit_ratio[2], test_hit_ratio[4], test_hit_ratio[6], 432 | test_hit_ratio[9], test_hit_ratio[14], test_hit_ratio[19], test_mrr), flush=True) 433 | else: 434 | train_mae, train_mape, train_rmse, train_nll_t = 
get_t_for_IFLTPP(dl_train, model, save_filename='train', 435 | params_path=params_path, 436 | use_nni=use_nni) 437 | val_mae, val_mape, val_rmse, val_nll_t = get_t_for_IFLTPP(dl_val, model, save_filename='val', 438 | params_path=params_path, 439 | use_nni=use_nni) 440 | test_mae, test_mape, test_rmse, test_nll_t = get_t_for_IFLTPP(dl_test, model, save_filename='test', 441 | params_path=params_path, 442 | use_nni=use_nni) 443 | 444 | print('Dataset\t MAE\t RMSE\t MAPE\t TNll\t\n' + 445 | 'Train:\t %.4f\t %.4f\t %.4f\t %.4f\t\n' % (train_mae, train_rmse, train_mape, train_nll_t) + 446 | 'Val:\t %.4f\t %.4f\t %.4f\t %.4f\t\n' % (val_mae, val_rmse, val_mape, val_nll_t) + 447 | 'Test:\t %.4f\t %.4f\t %.4f\t %.4f\t\n' % (test_mae, test_rmse, test_mape, test_nll_t), flush=True) 448 | 449 | if use_nni: 450 | if downstream != 'TPP': 451 | nni.report_final_result(val_hit_ratio[19]) 452 | else: 453 | nni.report_final_result(val_mae) 454 | -------------------------------------------------------------------------------- /model/MobilityLLM.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | from torch.nn.utils.rnn import pack_padded_sequence, pad_sequence 5 | from torch.nn.utils.rnn import pad_packed_sequence 6 | from utils import DotDict 7 | from model.utils import * 8 | import torch.nn.functional as F 9 | from torch import nn 10 | #import faiss 11 | from transformers import AutoModelForCausalLM,AutoTokenizer 12 | from .llm import LLMModel 13 | 14 | class MobilityLLM_ModelConfig(DotDict): 15 | ''' 16 | configuration of the MobilityLLM 17 | ''' 18 | 19 | def __init__(self, loc_size=None, tim_size=None, uid_size=None, geohash_size=None, category_size=None, tim_emb_size=None, loc_emb_size=None, 20 | hidden_size=None, user_emb_size=None, model_class=None, device=None, 21 | loc_noise_mean=None, loc_noise_sigma=None, tim_noise_mean=None, tim_noise_sigma=None, 22 | 
user_noise_mean=None, user_noise_sigma=None, tau=None, 23 | pos_eps=None, neg_eps=None, dropout_rate_1=None, dropout_rate_2=None, category_vector=None, rnn_type='BiLSTM', 24 | num_layers=3, k=8, momentum=0.95, temperature=0.1, theta=0.18, 25 | n_components=4, shift_init=0.0, scale_init=0.0, min_clip=-5., max_clip=3., hypernet_hidden_sizes=None, max_delta_mins=1440, 26 | downstream='POI',tpp='pdf',loss='pdf', dropout_spatial = None, epsilon = None, learnable_param_size=1): 27 | super().__init__() 28 | self.max_delta_mins = max_delta_mins 29 | 30 | self.loc_size = loc_size # 31 | self.uid_size = uid_size # 32 | self.tim_size = tim_size # 33 | self.geohash_size = geohash_size # 34 | self.category_size = category_size # 35 | self.loc_emb_size = loc_emb_size 36 | self.tim_emb_size = tim_emb_size 37 | self.user_emb_size = user_emb_size 38 | self.hidden_size = hidden_size # RNN hidden_size 39 | self.model_class = model_class 40 | self.device = device 41 | self.rnn_type = rnn_type 42 | self.num_layers = num_layers 43 | 44 | self.loc_noise_mean = loc_noise_mean 45 | self.loc_noise_sigma = loc_noise_sigma 46 | self.tim_noise_mean = tim_noise_mean 47 | self.tim_noise_sigma = tim_noise_sigma 48 | self.user_noise_mean = user_noise_mean 49 | self.user_noise_sigma = user_noise_sigma 50 | self.tau = tau 51 | self.pos_eps = pos_eps 52 | self.neg_eps = neg_eps 53 | self.dropout_rate_1 = dropout_rate_1 54 | self.dropout_rate_2 = dropout_rate_2 55 | self.downstream = downstream 56 | self.category_vector = category_vector 57 | self.learnable_param_size=learnable_param_size 58 | 59 | self.k = k 60 | self.momentum = momentum 61 | self.theta = theta 62 | self.temperature = temperature 63 | 64 | self.n_components = n_components #需要 65 | self.min_clip = min_clip #需要 无输入 66 | self.max_clip = max_clip #需要 无输入 67 | self.shift_init = shift_init 68 | self.scale_init = scale_init 69 | self.hypernet_hidden_sizes = hypernet_hidden_sizes #需要 无输入 70 | self.decoder_input_size = user_emb_size + 
hidden_size * 2 #需要 71 | self.loss = loss 72 | self.tpp = tpp 73 | self.dropout_spatial = dropout_spatial 74 | self.epsilon = epsilon 75 | 76 | class MobilityLLM(nn.Module): 77 | def __init__(self, config): 78 | super(MobilityLLM, self).__init__() 79 | # initialize parameters 80 | self.max_delta_mins = config['max_delta_mins'] 81 | self.truth_Y_tau = None 82 | self.loc_size = config['loc_size'] 83 | self.loc_emb_size = config['loc_emb_size'] 84 | self.tim_size = config['tim_size'] 85 | self.tim_emb_size = config['tim_emb_size'] 86 | self.user_size = config['uid_size'] 87 | self.user_emb_size = config['user_emb_size'] 88 | 89 | self.category_size = config['category_size'] 90 | self.geohash_size = config['geohash_size'] 91 | self.category_vector = config['category_vector'] 92 | self.learnable_param_size = config['learnable_param_size'] # 可学习参数个数 93 | 94 | self.hidden_size = config['hidden_size'] 95 | self.rnn_type = config['rnn_type'] 96 | self.num_layers = config['num_layers'] 97 | self.device = config['device'] 98 | self.model_class = config['model_class'] 99 | self.downstream = config['downstream'] 100 | 101 | # parameters for cluster contrastive learning 102 | self.k = config['k'] 103 | 104 | # parameters for time contrastive learning (Angle & Momentum based) 105 | # momentum 106 | self.momentum = config['momentum'] 107 | # angle 108 | self.theta = config['theta'] 109 | # self.theta = 0.05 110 | self.temperature = config['temperature'] 111 | 112 | # spatial 113 | self.user_centroids = None 114 | self.user_2cluster = None 115 | self.item_centroids = None 116 | self.item_2cluster = None 117 | self.softmax = nn.Softmax() 118 | self.epsilon = config['epsilon'] 119 | self.sinkhorn_iterations = 3 120 | self.crops_for_assign = [0, 1] 121 | self.nmb_crops = [2] 122 | self.world_size = -1 123 | self.dropout = nn.Dropout(0.3) 124 | # todo 这里的0.1用超参调 125 | self.l2norm = True 126 | # location all size (embedding + geohash + category) 127 | self.rnn_input_size = 
self.loc_emb_size + self.geohash_size 128 | if self.rnn_type == 'BiLSTM': 129 | self.bi = 2 130 | else: 131 | self.bi = 1 132 | 133 | # parameters for social contrastive learning (4 group of parameters) 134 | self.para0 = nn.Parameter(torch.randn(1, 6)) 135 | self.para1 = nn.Parameter(torch.randn(1, 4)) 136 | self.para2 = nn.Parameter(torch.randn(1, 24)) 137 | self.para3 = nn.Parameter(torch.randn(1, 16)) 138 | 139 | # parameters for TPP 140 | self.shift_init = config['shift_init'] 141 | self.scale_init = config['scale_init'] 142 | self.min_clip = config['min_clip'] 143 | self.max_clip = config['max_clip'] 144 | 145 | ############################################## 146 | self.loc_noise_mean = config['loc_noise_mean'] 147 | self.loc_noise_sigma = config['loc_noise_sigma'] 148 | self.tim_noise_mean = config['tim_noise_mean'] 149 | self.tim_noise_sigma = config['tim_noise_sigma'] 150 | self.user_noise_mean = config['user_noise_mean'] 151 | self.user_noise_sigma = config['user_noise_sigma'] 152 | 153 | self.tau = config['tau'] 154 | self.pos_eps = config['pos_eps'] 155 | self.neg_eps = config['neg_eps'] 156 | self.dropout_rate_1 = config['dropout_rate_1'] 157 | self.dropout_rate_2 = config['dropout_rate_2'] 158 | 159 | self.dropout_1 = nn.Dropout(self.dropout_rate_1) 160 | self.dropout_2 = nn.Dropout(self.dropout_rate_2) 161 | ################################################ 162 | self.tpp = config['tpp'] 163 | self.loss = config['loss'] 164 | self.mae = torch.nn.L1Loss() 165 | # Embedding layer 166 | self.emb_loc = nn.Embedding(num_embeddings=self.loc_size, embedding_dim=self.loc_emb_size) 167 | self.emb_tim = nn.Embedding(num_embeddings=self.tim_size, embedding_dim=self.tim_emb_size) 168 | self.emb_user = nn.Embedding(num_embeddings=self.user_size, embedding_dim=self.user_emb_size) 169 | 170 | # Category dense layer 171 | self.category_dense = nn.Linear(768, self.category_size) 172 | # Geohash dense layer 173 | self.geohash_dense = nn.Linear(12, self.geohash_size) 174 
| 175 | # rnn layer 176 | self.spatial_encoder = LLMModel(model_path = "/data/ZhangXinyue/MobilityLLM/params/"+ self.model_class, model_class= self.model_class, loc_size = self.loc_size, learnable_param_size = self.learnable_param_size, device = self.device) 177 | # if self.rnn_type == 'GRU': 178 | # self.spatial_encoder = nn.GRU(self.rnn_input_size, self.hidden_size, num_layers=self.num_layers, 179 | # batch_first=False) 180 | # self.temporal_encoder = nn.GRU(self.tim_emb_size + 1, self.hidden_size, num_layers=self.num_layers, 181 | # batch_first=False) 182 | # self.temporal_encoder_momentum = nn.GRU(self.tim_emb_size + 1, self.hidden_size, num_layers=self.num_layers, 183 | # batch_first=False) 184 | # elif self.rnn_type == 'LSTM': 185 | # self.spatial_encoder = nn.LSTM(self.rnn_input_size, self.hidden_size, num_layers=self.num_layers, 186 | # batch_first=False) 187 | # elif self.rnn_type == 'BiLSTM': 188 | # self.spatial_encoder = nn.LSTM(self.rnn_input_size, self.hidden_size, num_layers=self.num_layers, 189 | # batch_first=False, bidirectional=True) 190 | # else: 191 | # raise ValueError("rnn_type should be ['GRU', 'LSTM', 'BiLSTM']") 192 | 193 | #spatial_adv 194 | # prototype layer 195 | self.prototypes = None 196 | if isinstance(self.k, list): 197 | self.prototypes = MultiPrototypes(self.hidden_size, self.k) 198 | elif self.k > 0: 199 | self.prototypes = nn.Linear(self.hidden_size, self.k, bias=False) 200 | 201 | # projection head 202 | self.projection_head = nn.Sequential( 203 | nn.Linear(self.hidden_size, 2048), 204 | nn.BatchNorm1d(2048), 205 | nn.ReLU(inplace=True), 206 | nn.Linear(2048, self.hidden_size), 207 | ) 208 | 209 | # Hypernet module for TPP 210 | self.hypernet = Hypernet(config, hidden_sizes=config.hypernet_hidden_sizes, 211 | param_sizes=[config.n_components, config.n_components, config.n_components]) 212 | # linear for TPP 213 | self.linear_p1 = nn.Linear(self.hidden_size + self.user_emb_size,1024) 214 | self.linear_p2 = nn.Linear(1024 , 4) 
215 | self.linear_m1 = nn.Linear(self.hidden_size + self.user_emb_size,986) 216 | self.linear_m2 = nn.Linear(986 , 4) 217 | self.linear_l1 = nn.Linear(self.hidden_size + self.user_emb_size,1560) 218 | self.linear_l2 = nn.Linear(1560 , 4) 219 | self.Tanh = nn.Tanh() 220 | self.sigmoid = nn.Sigmoid() 221 | 222 | # dense layer 223 | self.s2st_projection = nn.Linear(self.hidden_size * self.bi, self.hidden_size * self.bi) 224 | if self.downstream == 'TUL': 225 | self.dense = nn.Linear(in_features=self.hidden_size * self.bi, out_features=self.user_size) 226 | self.t2st_projection = nn.Linear(self.hidden_size * self.bi, self.hidden_size * self.bi) 227 | self.projection = nn.Sequential(nn.Linear(self.hidden_size * self.bi, self.hidden_size * self.bi), nn.ReLU()) 228 | elif self.downstream == 'POI': 229 | self.projection = nn.Sequential( 230 | nn.Linear(self.hidden_size * self.bi + self.user_emb_size, self.hidden_size * self.bi + self.user_emb_size), 231 | nn.ReLU()) 232 | self.dense = nn.Linear(in_features=self.hidden_size * self.bi + self.user_emb_size, out_features=self.loc_size) 233 | self.t2st_projection = nn.Linear(self.hidden_size * self.bi + self.user_emb_size, self.hidden_size * self.bi) 234 | elif self.downstream == 'TPP': 235 | self.projection = nn.Sequential( 236 | nn.Linear(self.hidden_size * self.bi + self.user_emb_size, 237 | self.hidden_size * self.bi + self.user_emb_size), 238 | nn.ReLU()) 239 | self.t2st_projection = nn.Linear(self.hidden_size * self.bi + self.user_emb_size, 240 | self.hidden_size * self.bi) 241 | final_in_size = self.hidden_size * self.bi + self.user_emb_size 242 | self.dense_s = nn.Linear(in_features=self.hidden_size, 243 | out_features=self.hidden_size) 244 | self.dense_t = nn.Linear(in_features=self.hidden_size * self.bi + self.user_emb_size, 245 | out_features=self.hidden_size * self.bi + self.user_emb_size) 246 | self.dense_st = nn.Linear(in_features=self.hidden_size * self.bi + self.user_emb_size, 247 | out_features=self.hidden_size 
* self.bi + self.user_emb_size) 248 | self.dense = nn.Sequential( 249 | nn.Linear(in_features=final_in_size, out_features=final_in_size // 4), 250 | nn.LeakyReLU(), 251 | nn.Linear(in_features=final_in_size // 4, out_features=final_in_size // 16), 252 | nn.LeakyReLU(), 253 | nn.Linear(in_features=final_in_size // 16, out_features=1), 254 | ) 255 | self.linear1 = nn.Linear(self.hidden_size + self.user_emb_size, (self.hidden_size + self.user_emb_size) // 4) 256 | self.linear2 = nn.Linear((self.hidden_size + self.user_emb_size) // 4, 257 | (self.hidden_size + self.user_emb_size) // 16) 258 | self.linear3 = nn.Linear((self.hidden_size + self.user_emb_size) // 16, 1) 259 | self.linear0 = nn.Linear((self.hidden_size) * 2, self.hidden_size) 260 | else: 261 | raise ValueError('downstream should in [TUL, POI, TPP]!') 262 | 263 | self.apply(self._init_weight) 264 | 265 | def _init_weight(self, module): 266 | if isinstance(module, nn.Embedding): 267 | nn.init.xavier_normal_(module.weight) 268 | elif isinstance(module, nn.Linear): 269 | nn.init.xavier_uniform_(module.weight) 270 | elif isinstance(module, nn.LSTM): 271 | for name, param in module.named_parameters(): 272 | if 'weight_ih' in name: 273 | nn.init.xavier_uniform_(param.data) 274 | elif 'weight_hh' in name: 275 | nn.init.orthogonal_(param.data) 276 | elif 'bias' in name: 277 | nn.init.constant_(param.data, 0) 278 | 279 | def spatial_encode(self, x, time, category, geohash_, all_len, cur_len, batch_size, momentum=False, downstream='POI'): 280 | if momentum == True: 281 | f_encoder = self.spatial_encoder_momentum 282 | else: 283 | f_encoder = self.spatial_encoder 284 | # self-attention (mask) 285 | spatial_out = self.spatial_encoder(x, torch.tensor(all_len).to(self.device), time, category, geohash_) # +time, category 286 | # if self.rnn_type == 'GRU': 287 | # spatial_out, h_n = f_encoder(packed_stuff) # max_len*batch*hidden_size 288 | # elif self.rnn_type == 'LSTM': 289 | # spatial_out, (h_n, c_n) = 
f_encoder(packed_stuff) # max_len*batch*hidden_size 290 | # elif self.rnn_type == 'BiLSTM': 291 | # spatial_out, (h_n, c_n) = f_encoder(packed_stuff) # max_len*batch*hidden_size 292 | # else : 293 | # raise ValueError('rnn type is not in GRU, LSTM, BiLSTM! ') 294 | 295 | # # unpack 296 | # spatial_out, out_len = pad_packed_sequence(spatial_out, batch_first=False) 297 | # spatial_out = spatial_out.permute(1, 0, 2) 298 | 299 | # out_len即all_len batch*max_len*hidden_size 300 | # concatenate 301 | if downstream == 'POI': 302 | final_out = spatial_out[0, (all_len[0] - cur_len[0]): all_len[0], :] 303 | for i in range(1, batch_size): 304 | final_out = torch.cat([final_out, spatial_out[i, (all_len[i] - cur_len[i]): all_len[i], :]], dim=0) 305 | # No longer concate user embedding 306 | # final_out = torch.cat([final_out, all_user_emb], 1) 307 | elif downstream == 'TPP': 308 | if all_len[0] == cur_len[0]: 309 | left = all_len[0] - cur_len[0] 310 | right = all_len[0] 311 | else: 312 | left = all_len[0] - cur_len[0] - 1 313 | right = all_len[0] - 1 314 | final_out = spatial_out[0, left: right, :] 315 | for i in range(1, batch_size): 316 | if all_len[i] == cur_len[i]: 317 | left = all_len[i] - cur_len[i] 318 | right = all_len[i] 319 | else: 320 | left = all_len[i] - cur_len[i] - 1 321 | right = all_len[i] - 1 322 | final_out = torch.cat([final_out, spatial_out[i, left: right, :]], dim=0) 323 | elif downstream == 'TUL': 324 | final_out = spatial_out[0, (all_len[0] - 1): all_len[0], :] 325 | slice_tensor = torch.mean(spatial_out[0, : all_len[0], :],dim=0,keepdim=True) 326 | #print("final_out_shape:",final_out.shape) 327 | #print("slice_tensor_shape:",slice_tensor.shape) 328 | for i in range(1, batch_size): 329 | final_out = torch.cat([final_out, torch.mean(spatial_out[i, : all_len[i], :],dim=0,keepdim=True)], dim=0) 330 | else: 331 | raise ValueError('downstream is not in [POI, TUL, TPP]') 332 | return final_out 333 | 334 | ''' 335 | def temporal_encode(self, packed_stuff, 
all_len, cur_len, batch_size, momentum=False, downstream='POI'): 336 | if momentum == True: 337 | f_encoder = self.temporal_encoder_momentum 338 | else: 339 | f_encoder = self.temporal_encoder 340 | if self.rnn_type == 'GRU': 341 | temporal_out, h_n = f_encoder(packed_stuff) # max_len*batch*hidden_size 342 | elif self.rnn_type == 'LSTM': 343 | temporal_out, (h_n, c_n) = f_encoder(packed_stuff) # max_len*batch*hidden_size 344 | elif self.rnn_type == 'BiLSTM': 345 | temporal_out, (h_n, c_n) = f_encoder(packed_stuff) # max_len*batch*hidden_size 346 | else : 347 | raise ValueError('rnn type is not in GRU, LSTM, BiLSTM! ') 348 | 349 | # unpack 350 | temporal_out, out_len = pad_packed_sequence(temporal_out, batch_first=False) 351 | temporal_out = temporal_out.permute(1, 0, 2) 352 | 353 | # out_len即all_len batch*max_len*hidden_size 354 | # concatenate 355 | if downstream == 'POI': 356 | final_out = temporal_out[0, (all_len[0] - cur_len[0]): all_len[0], :] 357 | for i in range(1, batch_size): 358 | final_out = torch.cat([final_out, temporal_out[i, (all_len[i] - cur_len[i]): all_len[i], :]], dim=0) 359 | 360 | elif downstream == 'TPP': 361 | if all_len[0] == cur_len[0]: 362 | left = all_len[0] - cur_len[0] 363 | right = all_len[0] 364 | else: 365 | left = all_len[0] - cur_len[0] - 1 366 | right = all_len[0] - 1 367 | final_out = temporal_out[0, left: right, :] 368 | for i in range(1, batch_size): 369 | if all_len[i] == cur_len[i]: 370 | left = all_len[i] - cur_len[i] 371 | right = all_len[i] 372 | else: 373 | left = all_len[i] - cur_len[i] - 1 374 | right = all_len[i] - 1 375 | final_out = torch.cat([final_out, temporal_out[i, left: right, :]], dim=0) 376 | elif downstream == 'TUL': 377 | final_out = temporal_out[0, (all_len[0] - 1): all_len[0], :] 378 | for i in range(1, batch_size): 379 | final_out = torch.cat([final_out, temporal_out[i, (all_len[i] - 1): all_len[i], :]], dim=0) 380 | else: 381 | raise ValueError('downstream is not in [POI, TUL, TPP]') 382 | 383 | return 
final_out 384 | ''' 385 | 386 | 387 | def normal_logpdf(self, x, mean, log_scale): 388 | ''' 389 | log pdf of the normal distribution with mean and log_sigma 390 | ''' 391 | # z = (x - mean[:,0:1]) * torch.exp(-log_scale[:,0:1]) 392 | z = (x - mean) * torch.exp(-log_scale) 393 | return -log_scale - 0.5 * z.pow(2.0) - 0.5 * np.log(2 * np.pi) 394 | 395 | def mixnormal_logpdf(self, x, log_prior, means, log_scales): 396 | ''' 397 | :param x: ground truth 398 | :param log_prior: 归一化后的权重系数,(batch,max_length/actual_length,n_components)=(64, 128, 64), 399 | :param means: (batch,max_length/actual_length,n_components)=(64, 128, 64) 400 | :param log_scales: (batch,max_length/actual_length,n_components)=(64, 128, 64), scales对应论文中的s 401 | :return: 402 | ''' 403 | return torch.logsumexp( 404 | log_prior + self.normal_logpdf(x, means, log_scales), 405 | dim=-1 406 | ) 407 | 408 | def get_params(self, decoder_input): 409 | """ 410 | Generate model parameters based on the inputs 411 | Args: 412 | input: decoder input [batch, decoder_input_size] 413 | 414 | Returns: 415 | prior_logits: shape [batch, n_components] 416 | means: shape [batch, n_components] 417 | log_scales: shape [batch, n_components] 418 | """ 419 | prior_logits, means, log_scales = self.hypernet(decoder_input) 420 | 421 | # Clamp values that go through exp for numerical stability 422 | prior_logits = clamp_preserve_gradients(prior_logits, self.min_clip, self.max_clip) 423 | log_scales = clamp_preserve_gradients(log_scales, self.min_clip, self.max_clip) 424 | 425 | # normalize prior_logits 426 | prior_logits = F.log_softmax(prior_logits, dim=-1) # 这里进行了权重w的归一化,而且之后进行了log 427 | return prior_logits, means, log_scales 428 | 429 | 430 | 431 | def forward(self, batch, mode='test', cont_conf=None, downstream='POI', queue = None, use_the_queue = False): 432 | 433 | 434 | loc = batch.X_all_loc 435 | 436 | 437 | tim = batch.X_all_tim 438 | user = batch.X_users 439 | geohash_ = batch.X_all_geohash 440 | cur_len = 
batch.target_lengths 441 | all_len = batch.X_lengths 442 | loc_cat = batch.X_all_loc_category 443 | 444 | batch_size = loc.shape[0] 445 | loc_emb = self.emb_loc(loc) 446 | tim_emb = self.emb_tim(tim) 447 | user_emb = self.emb_user(user) 448 | geohash_ = self.geohash_dense(geohash_) 449 | 450 | # concatenate 451 | #x = torch.cat([loc_emb, tim_emb], dim=2) 452 | x = loc.to(self.device) 453 | time = tim.float().to(self.device) 454 | # x = torch.cat([loc_emb, geohash_], dim=2).permute(1, 0, 2) 455 | # 456 | # x_temporal = tim_emb.permute(1, 0, 2) 457 | 458 | # cat locs & taus 459 | all_tau = [torch.cat((batch.X_tau[0, :all_len[0] - cur_len[0]], batch.Y_tau[0, :cur_len[0]]), dim=-1)] 460 | self.truth_Y_tau = all_tau[0][all_len[0] - cur_len[0]:all_len[0]] 461 | 462 | for i in range(1, batch_size): 463 | # taus 464 | cur_tau = torch.cat((batch.X_tau[i, :all_len[i] - cur_len[i]], batch.Y_tau[i, :cur_len[i]]), dim=-1) 465 | all_tau.append(cur_tau) 466 | 467 | self.truth_Y_tau = torch.cat((self.truth_Y_tau, all_tau[i][all_len[i] - cur_len[i]:all_len[i]]), dim=0) 468 | 469 | all_tau = pad_sequence(all_tau, batch_first=False).to(self.device) 470 | # x_temporal = torch.cat((all_tau.unsqueeze(-1), x_temporal), dim=-1) 471 | 472 | # # pack 473 | # pack_x = pack_padded_sequence(x, lengths=all_len, enforce_sorted=False) 474 | # pack_x_temporal = pack_padded_sequence(x_temporal, lengths=all_len, enforce_sorted=False) 475 | 476 | final_out = self.spatial_encode(x, all_tau.transpose(0,1), loc_cat, geohash_, all_len, cur_len, batch_size, downstream=downstream) 477 | # final_temporal_out = self.temporal_encode(pack_x_temporal, all_len, cur_len, batch_size, downstream=downstream) 478 | 479 | all_user_emb = user_emb[0].unsqueeze(dim=0).repeat(cur_len[0], 1) 480 | for i in range(1, batch_size): 481 | all_user_emb = torch.cat([all_user_emb, user_emb[i].unsqueeze(dim=0).repeat(cur_len[i], 1)], dim=0) 482 | 483 | #todo 去掉时间,试一下结果 484 | if downstream == 'POI': 485 | prediction_out = 
torch.cat([final_out, all_user_emb], 1) 486 | dense = self.dense(prediction_out) # Batch * loc_size 487 | pred = nn.LogSoftmax(dim=1)(dense) # result 488 | elif downstream == 'TUL': 489 | # prediction_out = torch.cat([final_spatial_out, final_temporal_out], 1) 490 | dense = self.dense(self.dropout(final_out)) 491 | pred = nn.LogSoftmax(dim=1)(dense) # result 492 | elif downstream == 'TPP': 493 | final_out = torch.cat([final_out, all_user_emb], 1) 494 | prediction_out = self.dense(final_out) # Batch * loc_size 495 | else: 496 | raise ValueError('downstream is not in [POI, TUL, TPP]') 497 | 498 | 499 | criterion = nn.NLLLoss().to(self.device) 500 | criterion1 = nn.L1Loss().to(self.device) 501 | 502 | 503 | if downstream == 'POI': 504 | s_loss_score = criterion(pred, batch.Y_location).requires_grad_(True) 505 | _, top_k_pred = torch.topk(pred, k=self.loc_size) # (batch, K)=(batch, num_class) 506 | elif downstream == 'TUL': 507 | s_loss_score = criterion(pred, batch.X_users).requires_grad_(True) 508 | _, top_k_pred = torch.topk(pred, k=self.user_size) # (batch, K)=(batch, num_class) 509 | elif downstream == 'TPP': 510 | # y = torch.log(self.truth_Y_tau + 1e-2).unsqueeze(-1) 511 | 512 | mean_time = self.linear3(self.sigmoid(self.linear2(self.sigmoid(self.linear1(final_out))))) 513 | # loss = self.mae(mean_time, y.to(mean_time.device)) 514 | 515 | # mean_time = torch.exp(mean_time) 516 | # y = torch.log(self.truth_Y_tau + 1e-2).unsqueeze(-1) 517 | # y = (y - self.shift_init.to(y.device)) / self.scale_init.to(y.device) 518 | # if self.tpp == 'pdf': 519 | # prior_logits, means, log_scales = self.get_params(prediction_out) 520 | # #todo 用可学习参数*pred_out代替以上三个 521 | # elif self.tpp == 'linear': 522 | # prior_logits = self.linear_p2(self.Tanh(self.linear_p1(prediction_out))) 523 | # means = self.linear_m2(self.Tanh(self.linear_m1(prediction_out))) 524 | # log_scales = self.linear_l2(self.Tanh(self.linear_l1(prediction_out))) 525 | # # 使用线性层替代 526 | # log_p = 
self.mixnormal_logpdf(y.to(self.device), prior_logits, means, log_scales) 527 | # prior = prior_logits.exp() # (batch, n_components=64) 528 | # scales_squared = (log_scales * 2).exp() 529 | # a = self.scale_init.to(y.device) 530 | # b = self.shift_init.to(y.device) 531 | # mean_time = (prior * torch.exp(a * means + b + 0.5 * a ** 2 * scales_squared)).sum(-1) 532 | # mean_time = torch.clip(mean_time, max=float(self.max_delta_mins), min=0.0) 533 | # if self.loss == 'pdf': 534 | # s_loss_score = -log_p.mean() 535 | if self.loss == 'mae': 536 | # s_loss_score = criterion1(prediction_out.to(y.device),(self.truth_Y_tau).to(y.device)) 537 | s_loss_score = criterion1(prediction_out.squeeze().to('cpu'),self.truth_Y_tau.to('cpu')) 538 | top_k_pred = prediction_out 539 | else: 540 | raise ValueError('downstream is not in [POI, TUL, TPP]') 541 | 542 | if mode == 'train' and sum(cont_conf) != 0: 543 | return s_loss_score, top_k_pred, queue 544 | else: 545 | return s_loss_score, top_k_pred, queue 546 | 547 | class MultiPrototypes(nn.Module): 548 | def __init__(self, output_dim, nmb_prototypes): 549 | super(MultiPrototypes, self).__init__() 550 | self.nmb_heads = len(nmb_prototypes) 551 | for i, k in enumerate(nmb_prototypes): 552 | self.add_module("prototypes" + str(i), nn.Linear(output_dim, k, bias=False)) 553 | 554 | def forward(self, x): 555 | out = [] 556 | for i in range(self.nmb_heads): 557 | out.append(getattr(self, "prototypes" + str(i))(x)) 558 | return out -------------------------------------------------------------------------------- /model/phi_model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | # 4 | # Copyright (c) 2022, Tri Dao, trid@cs.stanford.edu. 5 | # Licensed under the BSD 3-Clause License. 
6 | 7 | from __future__ import annotations 8 | 9 | import math 10 | from dataclasses import dataclass, field 11 | from typing import Any, Dict, Optional, Tuple, Union 12 | 13 | import torch 14 | import torch.nn as nn 15 | from einops import rearrange, repeat 16 | from transformers import PretrainedConfig, PreTrainedModel 17 | from transformers.activations import ACT2FN 18 | from transformers.modeling_outputs import CausalLMOutputWithPast 19 | 20 | from .configuration_phi import PhiConfig 21 | 22 | try: 23 | from flash_attn.bert_padding import pad_input, unpad_input 24 | from flash_attn.layers.rotary import RotaryEmbedding as FlashRotaryEmbedding 25 | from flash_attn.modules.mha import FlashCrossAttention, FlashSelfAttention 26 | from flash_attn.ops.fused_dense import FusedDense 27 | except: 28 | pad_input, unpad_input = None, None 29 | FlashRotaryEmbedding = None 30 | FlashSelfAttention, FlashCrossAttention = None, None 31 | FusedDense = None 32 | 33 | 34 | @dataclass 35 | class InferenceParams: 36 | """Inference parameters passed to model to efficiently calculate 37 | and store context during inference. 38 | 39 | Reference: 40 | https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/utils/generation.py. 41 | 42 | Args: 43 | max_seqlen: Maximum sequence length. 44 | max_batch_size: Maximum batch size. 45 | seqlen_offset: Sequence length offset. 46 | batch_size_offset: Batch size offset. 47 | key_value_memory_dict: Key value memory dictionary. 48 | lengths_per_sample: Lengths per sample. 
49 | 50 | """ 51 | 52 | max_seqlen: int = field(metadata={"help": "Maximum sequence length."}) 53 | 54 | max_batch_size: int = field(metadata={"help": "Maximum batch size."}) 55 | 56 | seqlen_offset: int = field(default=0, metadata={"help": "Sequence length offset."}) 57 | 58 | batch_size_offset: int = field(default=0, metadata={"help": "Batch size offset."}) 59 | 60 | key_value_memory_dict: Dict[str, Any] = field( 61 | default_factory=dict, metadata={"help": "Key value memory dictionary."} 62 | ) 63 | 64 | lengths_per_sample: torch.Tensor = field(default=None, metadata={"help": "Lengths per sample."}) 65 | 66 | 67 | class Embedding(nn.Module): 68 | """Token embedding with dropout.""" 69 | 70 | def __init__(self, config: PretrainedConfig) -> None: 71 | super().__init__() 72 | 73 | self.wte = nn.Embedding(config.vocab_size, config.n_embd) 74 | self.drop = nn.Dropout(config.embd_pdrop) 75 | 76 | def forward(self, input_ids: torch.LongTensor) -> torch.FloatTensor: 77 | input_shape = input_ids.size() 78 | input_ids = input_ids.view(-1, input_shape[-1]) 79 | 80 | hidden_states = self.wte(input_ids) 81 | hidden_states = self.drop(hidden_states) 82 | 83 | return hidden_states 84 | 85 | 86 | def _apply_rotary_emb( 87 | x: torch.FloatTensor, 88 | cos: torch.FloatTensor, 89 | sin: torch.FloatTensor, 90 | ) -> torch.FloatTensor: 91 | _, seqlen, _, _ = x.shape 92 | _, rotary_dim = cos.shape 93 | rotary_dim *= 2 94 | 95 | x_rot = x[:, :, :, :rotary_dim] 96 | x_pass = x[:, :, :, rotary_dim:] 97 | 98 | x1, x2 = x_rot.chunk(2, dim=-1) 99 | c, s = rearrange(cos[:seqlen], "s d -> s 1 d"), rearrange(sin[:seqlen], "s d -> s 1 d") 100 | x1, x2, c, s = [t.to(dtype=torch.float32) for t in [x1, x2, c, s]] 101 | 102 | x_rot = torch.cat([x1 * c - x2 * s, x1 * s + x2 * c], axis=-1).to(x.dtype) 103 | 104 | return torch.cat([x_rot, x_pass], axis=-1) 105 | 106 | 107 | def _apply_rotary_emb_kv( 108 | kv: torch.FloatTensor, 109 | cos: torch.FloatTensor, 110 | sin: torch.FloatTensor, 111 | cos_k: 
Optional[torch.FloatTensor] = None, 112 | sin_k: Optional[torch.FloatTensor] = None, 113 | ) -> torch.FloatTensor: 114 | _, seqlen, _, _, _ = kv.shape 115 | _, rotary_dim = cos.shape 116 | rotary_dim *= 2 117 | 118 | k_rot = kv[:, :, 0, :, :rotary_dim] 119 | k_pass = kv[:, :, 0, :, rotary_dim:] 120 | 121 | k1, k2 = k_rot.chunk(2, dim=-1) 122 | c, s = rearrange(cos[:seqlen], "s d -> s 1 d"), rearrange(sin[:seqlen], "s d -> s 1 d") 123 | k1, k2, c, s = [t.to(dtype=torch.float32) for t in [k1, k2, c, s]] 124 | 125 | k_rot = torch.cat([k1 * c - k2 * s, k1 * s + k2 * c], axis=-1).to(kv.dtype) 126 | 127 | return torch.cat( 128 | [ 129 | torch.cat([k_rot, k_pass], axis=-1).unsqueeze(2), 130 | kv[:, :, 1:2, :, :], 131 | ], 132 | axis=2, 133 | ) 134 | 135 | 136 | def _apply_rotary_emb_qkv( 137 | qkv: torch.FloatTensor, 138 | cos: torch.FloatTensor, 139 | sin: torch.FloatTensor, 140 | cos_k: Optional[torch.FloatTensor] = None, 141 | sin_k: Optional[torch.FloatTensor] = None, 142 | ) -> torch.FloatTensor: 143 | _, seqlen, _, _, _ = qkv.shape 144 | _, rotary_dim = cos.shape 145 | rotary_dim *= 2 146 | 147 | q_rot = qkv[:, :, 0, :, :rotary_dim] 148 | q_pass = qkv[:, :, 0, :, rotary_dim:] 149 | 150 | k_rot = qkv[:, :, 1, :, :rotary_dim] 151 | k_pass = qkv[:, :, 1, :, rotary_dim:] 152 | 153 | q1, q2 = q_rot.chunk(2, dim=-1) 154 | k1, k2 = k_rot.chunk(2, dim=-1) 155 | c, s = rearrange(cos[:seqlen], "s d -> s 1 d"), rearrange(sin[:seqlen], "s d -> s 1 d") 156 | q1, q2, k1, k2, c, s = [t.to(dtype=torch.float32) for t in [q1, q2, k1, k2, c, s]] 157 | 158 | q_rot = torch.cat([q1 * c - q2 * s, q1 * s + q2 * c], axis=-1).to(qkv.dtype) 159 | k_rot = torch.cat([k1 * c - k2 * s, k1 * s + k2 * c], axis=-1).to(qkv.dtype) 160 | 161 | return torch.cat( 162 | [ 163 | torch.cat([q_rot, q_pass], axis=-1).unsqueeze(2), 164 | torch.cat([k_rot, k_pass], axis=-1).unsqueeze(2), 165 | qkv[:, :, 2:3, :, :], 166 | ], 167 | axis=2, 168 | ) 169 | 170 | 171 | class RotaryEmbedding(nn.Module): 172 | 
"""Rotary positional embedding (RoPE). 173 | 174 | Reference: 175 | RoFormer: Enhanced Transformer with Rotary Position Embedding. 176 | https://arxiv.org/pdf/2104.09864.pdf. 177 | 178 | """ 179 | 180 | def __init__( 181 | self, 182 | dim: int, 183 | base: int = 10000, 184 | scale_base: Optional[float] = None, 185 | pos_idx_in_fp32: bool = True, 186 | max_position_embeddings: int = 2048, 187 | device: Optional[str] = None, 188 | **kwargs, 189 | ) -> None: 190 | super().__init__() 191 | 192 | if scale_base is not None: 193 | raise NotImplementedError 194 | 195 | self.dim = dim 196 | self.base = float(base) 197 | self.scale_base = scale_base 198 | self.pos_idx_in_fp32 = pos_idx_in_fp32 199 | self.max_position_embeddings = max_position_embeddings 200 | self.device = device 201 | 202 | # Generate and save the inverse frequency buffer (non-trainable) 203 | inv_freq = self._compute_inv_freq(device) 204 | self.register_buffer("inv_freq", inv_freq, persistent=False) 205 | 206 | # Generate and save the scale buffer (non-trainable) 207 | scale = ( 208 | (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim) 209 | if scale_base is not None 210 | else None 211 | ) 212 | self.register_buffer("scale", scale, persistent=False) 213 | 214 | # Initialize cached attributes since ONNX can't rely on dynamic initialization 215 | self._update_cos_sin_cache(max_position_embeddings, device=device, dtype=torch.float32) 216 | 217 | def _compute_inv_freq(self, device: Optional[str] = None) -> torch.FloatTensor: 218 | return 1.0 / (self.base ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim)) 219 | 220 | def _update_cos_sin_cache( 221 | self, 222 | seqlen: int, 223 | device: Optional[str] = None, 224 | dtype: Optional[torch.dtype] = None, 225 | ) -> None: 226 | self._seq_len_cached = seqlen 227 | 228 | # fp32 is preferred since the output of `torch.arange` can be quite large 229 | # and bf16 would lose a lot of precision 230 | if 
self.pos_idx_in_fp32: 231 | t = torch.arange(seqlen, device=device, dtype=torch.float32) 232 | if self.inv_freq.dtype != torch.float32: 233 | inv_freq = self._compute_inv_freq(device=device) 234 | else: 235 | inv_freq = self.inv_freq 236 | else: 237 | t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype) 238 | inv_freq = self.inv_freq 239 | 240 | # `torch.outer` is preferred since `torch.einsum` converts from fp32 to fp16 if used with AMP 241 | freqs = torch.outer(t, inv_freq) 242 | if self.scale is None: 243 | self._cos_cached = torch.cos(freqs).to(dtype) 244 | self._sin_cached = torch.sin(freqs).to(dtype) 245 | else: 246 | power = ( 247 | torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device) - seqlen // 2 248 | ) / self.scale_base 249 | scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1") 250 | 251 | # Force the scale multiplication to happen in fp32 252 | self._cos_cached = (torch.cos(freqs) * scale).to(dtype) 253 | self._sin_cached = (torch.sin(freqs) * scale).to(dtype) 254 | self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype) 255 | self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype) 256 | 257 | def forward( 258 | self, 259 | qkv: torch.Tensor, 260 | kv: Optional[torch.Tensor] = None, 261 | seqlen_offset: int = 0, 262 | **kwargs, 263 | ) -> Tuple[torch.Tensor, torch.Tensor]: 264 | if ( 265 | self._seq_len_cached < qkv.shape[1] + seqlen_offset 266 | or self._cos_cached.device != qkv.device 267 | or self._cos_cached.dtype != qkv.dtype 268 | or (self.training and self._cos_cached.is_inference()) 269 | ): 270 | self._update_cos_sin_cache(qkv.shape[1] + seqlen_offset, device=qkv.device, dtype=qkv.dtype) 271 | 272 | if kv is None: 273 | return _apply_rotary_emb_qkv( 274 | qkv, 275 | self._cos_cached[seqlen_offset:], 276 | self._sin_cached[seqlen_offset:], 277 | ) 278 | else: 279 | q = _apply_rotary_emb( 280 | qkv, 281 | self._cos_cached[seqlen_offset:], 282 | self._sin_cached[seqlen_offset:], 283 | ) 284 
| kv = _apply_rotary_emb_kv( 285 | kv, 286 | self._cos_cached[seqlen_offset:], 287 | self._sin_cached[seqlen_offset:], 288 | ) 289 | 290 | return q, kv 291 | 292 | 293 | class MLP(nn.Module): 294 | """Multi-Layer Perceptron. 295 | 296 | Reference: 297 | Attention Is All You Need. 298 | https://arxiv.org/pdf/1706.03762.pdf. 299 | 300 | """ 301 | 302 | def __init__( 303 | self, 304 | config: PretrainedConfig, 305 | n_inner: Optional[int] = None, 306 | act_fn: Optional[str] = None, 307 | ) -> None: 308 | super().__init__() 309 | 310 | act_fn = config.activation_function if act_fn is None else act_fn 311 | 312 | n_inner = getattr(config, "n_inner", None) if n_inner is None else n_inner 313 | n_inner = n_inner if n_inner is not None else 4 * config.n_embd 314 | 315 | self.fc1 = nn.Linear(config.n_embd, n_inner) 316 | self.fc2 = nn.Linear(n_inner, config.n_embd) 317 | self.act = ACT2FN[act_fn] 318 | 319 | def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: 320 | hidden_states = self.fc1(hidden_states) 321 | hidden_states = self.act(hidden_states) 322 | hidden_states = self.fc2(hidden_states) 323 | 324 | return hidden_states 325 | 326 | 327 | class SelfAttention(nn.Module): 328 | """Self-attention layer (compatible with PyTorch). 329 | 330 | Reference: 331 | https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/mha.py. 
332 | 333 | """ 334 | 335 | def __init__( 336 | self, 337 | causal: bool = True, 338 | softmax_scale: Optional[float] = None, 339 | attention_dropout: float = 0.0, 340 | ) -> None: 341 | super().__init__() 342 | 343 | self.causal = causal 344 | self.softmax_scale = softmax_scale 345 | self.drop = nn.Dropout(attention_dropout) 346 | 347 | @torch.autocast("cpu", enabled=False) 348 | @torch.autocast("cuda", enabled=False) 349 | def forward( 350 | self, 351 | qkv: torch.FloatTensor, 352 | causal: bool = None, 353 | key_padding_mask: Optional[torch.BoolTensor] = None, 354 | **kwargs, 355 | ) -> torch.FloatTensor: 356 | batch_size, seqlen = qkv.shape[0], qkv.shape[1] 357 | q, k, v = qkv.unbind(dim=2) 358 | 359 | q = q.to(torch.float32) 360 | k = k.to(torch.float32) 361 | 362 | causal = self.causal if causal is None else causal 363 | softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1]) 364 | 365 | # Autocast is manually disabled to avoid `torch.einsum` performing the operation 366 | # using float16, which might lead to overflow 367 | scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale) 368 | 369 | if key_padding_mask is not None: 370 | padding_mask = torch.full((batch_size, seqlen), -10000.0, dtype=scores.dtype, device=scores.device) 371 | padding_mask.masked_fill_(key_padding_mask, 0.0) 372 | 373 | scores = scores + rearrange(padding_mask, "b s -> b 1 1 s") 374 | 375 | if causal: 376 | causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1) 377 | scores = scores + causal_mask.to(dtype=scores.dtype) 378 | 379 | attention = torch.softmax(scores, dim=-1).to(v.dtype) 380 | attention = self.drop(attention) 381 | 382 | output = torch.einsum("bhts,bshd->bthd", attention, v) 383 | 384 | return output 385 | 386 | 387 | class CrossAttention(nn.Module): 388 | """Cross-attention layer (compatible with PyTorch). 389 | 390 | Reference: 391 | https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/mha.py. 
392 | 393 | """ 394 | 395 | def __init__( 396 | self, 397 | causal: bool = True, 398 | softmax_scale: Optional[float] = None, 399 | attention_dropout: float = 0.0, 400 | ) -> None: 401 | super().__init__() 402 | 403 | self.causal = causal 404 | self.softmax_scale = softmax_scale 405 | self.drop = nn.Dropout(attention_dropout) 406 | 407 | @torch.autocast("cpu", enabled=False) 408 | @torch.autocast("cuda", enabled=False) 409 | def forward( 410 | self, 411 | q: torch.FloatTensor, 412 | kv: torch.FloatTensor, 413 | causal: bool = None, 414 | key_padding_mask: Optional[torch.BoolTensor] = None, 415 | **kwargs, 416 | ) -> torch.FloatTensor: 417 | batch_size, seqlen_q = q.shape[0], q.shape[1] 418 | seqlen_k = kv.shape[1] 419 | 420 | if kv.shape[3] != q.shape[2]: 421 | kv = repeat(kv, "... hkv d -> ... (hkv g) d", g=q.shape[2] // kv.shape[3]) 422 | k, v = kv.unbind(dim=2) 423 | 424 | q = q.to(torch.float32) 425 | k = k.to(torch.float32) 426 | 427 | causal = self.causal if causal is None else causal 428 | softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1]) 429 | 430 | # Autocast is manually disabled to avoid `torch.einsum` performing the operation 431 | # using float16, which might lead to overflow 432 | scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale) 433 | 434 | if key_padding_mask is not None: 435 | padding_mask = torch.full( 436 | (batch_size, seqlen_k), 437 | -10000.0, 438 | dtype=scores.dtype, 439 | device=scores.device, 440 | ) 441 | padding_mask.masked_fill_(key_padding_mask, 0.0) 442 | 443 | scores = scores + rearrange(padding_mask, "b s -> b 1 1 s") 444 | 445 | if causal: 446 | rows = rearrange(torch.arange(seqlen_q, device=q.device, dtype=torch.long), "s -> s 1") 447 | cols = torch.arange(seqlen_k, device=k.device, dtype=torch.long) 448 | causal_mask = cols > rows + seqlen_k - seqlen_q 449 | 450 | scores = scores.masked_fill(causal_mask, -10000.0) 451 | 452 | attention = torch.softmax(scores, dim=-1).to(v.dtype) 453 | attention = 
self.drop(attention) 454 | 455 | output = torch.einsum("bhts,bshd->bthd", attention, v) 456 | 457 | return output 458 | 459 | 460 | def _find_mha_dims( 461 | config: PretrainedConfig, 462 | n_head: Optional[int] = None, 463 | n_head_kv: Optional[int] = None, 464 | head_dim: Optional[int] = None, 465 | ) -> Tuple[int, int]: 466 | if n_head is None and head_dim is None: 467 | head_dim = config.n_embd // config.n_head 468 | n_head = config.n_head 469 | elif n_head is None or head_dim is None: 470 | raise ValueError("`n_head` and `head_dim` must be both specified or `None`.") 471 | 472 | if n_head_kv is None: 473 | n_head_kv = getattr(config, "n_head_kv", None) or n_head 474 | 475 | return n_head, n_head_kv, head_dim 476 | 477 | 478 | def _update_kv_cache(kv: torch.FloatTensor, inference_params: InferenceParams, layer_idx: int) -> torch.FloatTensor: 479 | num_heads, head_dim = kv.shape[-2:] 480 | 481 | if layer_idx not in inference_params.key_value_memory_dict: 482 | inference_params.key_value_memory_dict[layer_idx] = torch.empty( 483 | inference_params.max_batch_size, 484 | inference_params.max_seqlen, 485 | 2, 486 | num_heads, 487 | head_dim, 488 | dtype=kv.dtype, 489 | device=kv.device, 490 | ) 491 | 492 | batch_start = inference_params.batch_size_offset 493 | batch_end = batch_start + kv.shape[0] 494 | 495 | sequence_start = inference_params.seqlen_offset 496 | sequence_end = sequence_start + kv.shape[1] 497 | 498 | # When the current sequence length is equal to or larger than the maximum sequence length, 499 | # we need to concatenate the current `kv` with the cached `kv` to expand its length 500 | if sequence_end >= inference_params.max_seqlen: 501 | inference_params.key_value_memory_dict[layer_idx] = torch.concatenate( 502 | (inference_params.key_value_memory_dict[layer_idx], kv), dim=1) 503 | 504 | inference_params.key_value_memory_dict[layer_idx][batch_start:batch_end, sequence_start:sequence_end, ...] 
= kv 505 | kv = inference_params.key_value_memory_dict[layer_idx][batch_start:batch_end, :sequence_end, ...] 506 | 507 | return kv 508 | 509 | 510 | class MHA(nn.Module): 511 | """Multi-head attention layer.""" 512 | 513 | def __init__( 514 | self, 515 | config: PretrainedConfig, 516 | dtype: Optional[torch.dtype] = None, 517 | device: Optional[str] = None, 518 | rotary_dim: Optional[int] = None, 519 | rotary_base: float = 10000.0, 520 | rotary_scale_base: Optional[float] = None, 521 | n_head: Optional[int] = None, 522 | n_head_kv: Optional[int] = None, 523 | head_dim: Optional[int] = None, 524 | bias: bool = True, 525 | causal: bool = True, 526 | softmax_scale: Optional[float] = None, 527 | layer_idx: Optional[int] = None, 528 | return_residual: bool = False, 529 | checkpointing: bool = False, 530 | ) -> None: 531 | super().__init__() 532 | 533 | # Rotary embedding 534 | self.rotary_dim = rotary_dim if rotary_dim is not None else getattr(config, "rotary_dim", 0) 535 | if self.rotary_dim > 0: 536 | rotary_cls = FlashRotaryEmbedding if config.flash_rotary else RotaryEmbedding 537 | if rotary_cls is None: 538 | rotary_cls = RotaryEmbedding 539 | 540 | rotary_kwargs = {} 541 | if rotary_cls is RotaryEmbedding: 542 | rotary_kwargs["max_position_embeddings"] = config.n_positions 543 | 544 | self.rotary_emb = rotary_cls( 545 | self.rotary_dim, 546 | base=rotary_base, 547 | scale_base=rotary_scale_base, 548 | device=device, 549 | **rotary_kwargs, 550 | ) 551 | 552 | # MLP 553 | self.n_head, self.n_head_kv, self.head_dim = _find_mha_dims( 554 | config, n_head=n_head, n_head_kv=n_head_kv, head_dim=head_dim 555 | ) 556 | op_size = self.head_dim * (self.n_head + 2 * self.n_head_kv) 557 | hidden_size = config.n_embd 558 | 559 | linear_cls = FusedDense if config.fused_dense else nn.Linear 560 | if linear_cls is None: 561 | linear_cls = nn.Linear 562 | 563 | self.Wqkv = linear_cls(hidden_size, op_size, bias=bias, device=device, dtype=dtype) 564 | self.out_proj = 
linear_cls(hidden_size, hidden_size, bias=bias, device=device, dtype=dtype) 565 | 566 | # Attention 567 | attn_cls = FlashSelfAttention if config.flash_attn else SelfAttention 568 | if attn_cls is None: 569 | attn_cls = SelfAttention 570 | 571 | cross_attn_cls = FlashCrossAttention if config.flash_attn else CrossAttention 572 | if cross_attn_cls is None: 573 | cross_attn_cls = CrossAttention 574 | 575 | self.inner_attn = attn_cls( 576 | causal=causal, 577 | softmax_scale=softmax_scale, 578 | attention_dropout=config.attn_pdrop, 579 | ) 580 | self.inner_cross_attn = cross_attn_cls( 581 | causal=causal, 582 | softmax_scale=softmax_scale, 583 | attention_dropout=config.attn_pdrop, 584 | ) 585 | 586 | self.flash_attn = config.flash_attn and attn_cls is FlashSelfAttention 587 | self.layer_idx = layer_idx 588 | self.return_residual = return_residual 589 | self.checkpointing = checkpointing 590 | 591 | def _forward_self_attn( 592 | self, x: torch.FloatTensor, key_padding_mask: Optional[torch.BoolTensor] 593 | ) -> torch.FloatTensor: 594 | qkv = self.Wqkv(x) 595 | qkv = rearrange(qkv, "... (three h d) -> ... 
three h d", three=3, d=self.head_dim) 596 | 597 | if self.rotary_dim > 0: 598 | qkv = self.rotary_emb(qkv) 599 | 600 | if self.flash_attn: 601 | batch_size, seqlen = qkv.shape[0], qkv.shape[1] 602 | 603 | cu_seqlens, max_seqlen = None, None 604 | if key_padding_mask is not None: 605 | # If `key_padding_mask` is supplied, we need to unpad the input and retrieve 606 | # the `cu_seqlens` and `max_seqlen` to be used by `flash-attn` 607 | qkv, indices, cu_seqlens, max_seqlen = unpad_input(qkv, key_padding_mask) 608 | 609 | if self.checkpointing: 610 | attn_output = torch.utils.checkpoint.checkpoint( 611 | self.inner_attn, qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen 612 | ) 613 | else: 614 | attn_output = self.inner_attn(qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen).to(qkv.device) 615 | 616 | # If `key_padding_mask` is supplied, we need to pad the output back to the original shape 617 | return pad_input(attn_output, indices, batch_size, seqlen) if key_padding_mask is not None else attn_output 618 | 619 | if self.checkpointing: 620 | return torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, key_padding_mask=key_padding_mask) 621 | 622 | return self.inner_attn(qkv, key_padding_mask=key_padding_mask) 623 | 624 | def _forward_cross_attn( 625 | self, 626 | x: torch.FloatTensor, 627 | past_key_values: Optional[InferenceParams], 628 | key_padding_mask: Optional[torch.BoolTensor], 629 | ) -> torch.FloatTensor: 630 | batch_size = x.shape[0] 631 | 632 | qkv = self.Wqkv(x) 633 | 634 | q = qkv[..., : self.n_head * self.head_dim] 635 | q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim) 636 | 637 | kv = qkv[..., self.n_head * self.head_dim:] 638 | kv = rearrange(kv, "... (two hkv d) -> ... 
two hkv d", two=2, d=self.head_dim) 639 | 640 | seqlen_offset = past_key_values.seqlen_offset if past_key_values is not None else 0 641 | causal = None if seqlen_offset == 0 else False 642 | if self.rotary_dim > 0: 643 | q, kv = self.rotary_emb(q, kv=kv, seqlen_offset=seqlen_offset) 644 | 645 | if past_key_values is not None: 646 | kv = _update_kv_cache(kv, past_key_values, self.layer_idx) 647 | 648 | if self.flash_attn: 649 | batch_size, seqlen_q = q.shape[0], q.shape[1] 650 | seqlen_k = kv.shape[1] 651 | 652 | cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k = ( 653 | None, 654 | None, 655 | None, 656 | None, 657 | ) 658 | if key_padding_mask is not None: 659 | kv, _, cu_seqlens_k, max_seqlen_k = unpad_input(kv, key_padding_mask) 660 | 661 | if seqlen_q == 1: 662 | key_padding_mask = torch.ones(batch_size, 1, device=q.device) 663 | elif seqlen_q != seqlen_k: 664 | key_padding_mask = key_padding_mask[:, -seqlen_q:] 665 | 666 | q, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, key_padding_mask) 667 | 668 | if self.checkpointing: 669 | attn_output = torch.utils.checkpoint.checkpoint( 670 | self.inner_cross_attn, 671 | q, 672 | kv, 673 | causal=causal, 674 | cu_seqlens=cu_seqlens_q, 675 | max_seqlen=max_seqlen_q, 676 | cu_seqlens_k=cu_seqlens_k, 677 | max_seqlen_k=max_seqlen_k, 678 | ) 679 | else: 680 | attn_output = self.inner_cross_attn( 681 | q, 682 | kv, 683 | causal=causal, 684 | cu_seqlens=cu_seqlens_q, 685 | max_seqlen=max_seqlen_q, 686 | cu_seqlens_k=cu_seqlens_k, 687 | max_seqlen_k=max_seqlen_k, 688 | ) 689 | 690 | return ( 691 | pad_input(attn_output, indices_q, batch_size, max_seqlen_q) 692 | if key_padding_mask is not None 693 | else attn_output 694 | ) 695 | 696 | if self.checkpointing: 697 | return torch.utils.checkpoint.checkpoint( 698 | self.inner_cross_attn, 699 | q, 700 | kv, 701 | key_padding_mask=key_padding_mask, 702 | causal=causal, 703 | ) 704 | 705 | return self.inner_cross_attn(q, kv, key_padding_mask=key_padding_mask, 
causal=causal) 706 | 707 | def forward( 708 | self, 709 | x: torch.FloatTensor, 710 | past_key_values: Optional[InferenceParams] = None, 711 | attention_mask: Optional[Union[torch.LongTensor, torch.BoolTensor]] = None, 712 | **kwargs, 713 | ) -> Tuple[torch.FloatTensor, torch.FloatTensor]: 714 | if attention_mask is not None: 715 | attention_mask = attention_mask.bool() 716 | else: 717 | attention_mask = None 718 | 719 | # MHA 720 | if self.n_head == self.n_head_kv: 721 | if past_key_values is None: 722 | # If `past_key_values` are not supplied, we run self-attention 723 | attn_output = self._forward_self_attn(x, attention_mask) 724 | else: 725 | # If `past_key_values` are supplied, it means that we might have cached values and 726 | # could take advantage of cross-attention 727 | attn_output = self._forward_cross_attn(x, past_key_values, attention_mask) 728 | # MQA / GQA 729 | else: 730 | # Regardless of `past_key_values` being supplied or not, it always use cross-attention 731 | # because `q` and `kv` lengths might be different 732 | attn_output = self._forward_cross_attn(x, past_key_values, attention_mask) 733 | 734 | output = rearrange(attn_output, "... h d -> ... (h d)") 735 | output = self.out_proj(output) 736 | 737 | return output if not self.return_residual else (output, x) 738 | 739 | 740 | class ParallelBlock(nn.Module): 741 | """Parallel block. 742 | 743 | This block applies parallel mixer and MLP layers to the input (used in GPT-J and CodeGen). 
744 | 745 | """ 746 | 747 | def __init__( 748 | self, 749 | config: PretrainedConfig, 750 | block_idx: Optional[int] = None, 751 | ) -> None: 752 | super().__init__() 753 | 754 | self.ln = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon) 755 | self.resid_dropout = nn.Dropout(config.resid_pdrop) 756 | self.block_idx = block_idx 757 | 758 | self.mixer = MHA(config, layer_idx=block_idx) 759 | self.mlp = MLP(config) 760 | 761 | def forward( 762 | self, 763 | hidden_states: torch.FloatTensor, 764 | past_key_values: Optional[Union[torch.FloatTensor, InferenceParams]] = None, 765 | attention_mask: Optional[torch.BoolTensor] = None, 766 | **kwargs, 767 | ) -> torch.FloatTensor: 768 | residual = hidden_states 769 | hidden_states = self.ln(hidden_states) 770 | 771 | attn_outputs = self.mixer( 772 | hidden_states, 773 | past_key_values=past_key_values, 774 | attention_mask=attention_mask, 775 | ) 776 | if isinstance(attn_outputs, tuple): 777 | attn_outputs = attn_outputs[0] 778 | 779 | attn_outputs = self.resid_dropout(attn_outputs) 780 | feed_forward_hidden_states = self.resid_dropout(self.mlp(hidden_states)) 781 | 782 | hidden_states = attn_outputs + feed_forward_hidden_states + residual 783 | 784 | return hidden_states 785 | 786 | 787 | class CausalLMHead(nn.Module): 788 | """Causal Language Modeling head. 789 | 790 | Reference: 791 | Improving Language Understanding by Generative Pre-Training. 792 | https://cdn.openai.com/research-covers/language-unsupervised/language_understanding_paper.pdf. 
793 | 794 | """ 795 | 796 | def __init__(self, config: PretrainedConfig) -> None: 797 | super().__init__() 798 | 799 | self.ln = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon) 800 | self.linear = nn.Linear(config.n_embd, config.vocab_size) 801 | 802 | def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: 803 | hidden_states = self.ln(hidden_states) 804 | logits = self.linear(hidden_states).to(torch.float32) 805 | 806 | return logits 807 | 808 | 809 | class CausalLMLoss(nn.Module): 810 | """Causal Language Modeling loss. 811 | 812 | Reference: 813 | Improving Language Understanding by Generative Pre-Training. 814 | https://cdn.openai.com/research-covers/language-unsupervised/language_understanding_paper.pdf. 815 | 816 | """ 817 | 818 | def __init__(self, shift_labels: bool = True) -> None: 819 | super().__init__() 820 | 821 | self.shift_labels = shift_labels 822 | self.loss_fct = nn.CrossEntropyLoss() 823 | 824 | def forward(self, logits: torch.FloatTensor, labels: torch.LongTensor) -> torch.FloatTensor: 825 | if self.shift_labels: 826 | logits = logits[..., :-1, :].contiguous() 827 | labels = labels[..., 1:].contiguous() 828 | 829 | loss = self.loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1)) 830 | 831 | return loss 832 | 833 | 834 | class PhiPreTrainedModel(PreTrainedModel): 835 | """Phi pre-trained model.""" 836 | 837 | config_class = PhiConfig 838 | base_model_prefix = "transformer" 839 | supports_gradient_checkpointing = False 840 | _no_split_modules = ["ParallelBlock"] 841 | 842 | def __init__(self, *inputs, **kwargs) -> None: 843 | super().__init__(*inputs, **kwargs) 844 | 845 | def _init_weights(self, module: nn.Module) -> None: 846 | if isinstance(module, (nn.Linear,)): 847 | module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) 848 | if module.bias is not None: 849 | module.bias.data.zero_() 850 | elif isinstance(module, nn.Embedding): 851 | module.weight.data.normal_(mean=0.0, 
std=self.config.initializer_range) 852 | if module.padding_idx is not None: 853 | module.weight.data[module.padding_idx].zero_() 854 | elif isinstance(module, nn.LayerNorm): 855 | if module.bias is not None: 856 | module.bias.data.zero_() 857 | module.weight.data.fill_(1.0) 858 | 859 | def prepare_inputs_for_generation( 860 | self, 861 | input_ids: torch.LongTensor, 862 | past_key_values: Optional[Union[torch.FloatTensor, InferenceParams]] = None, 863 | attention_mask: Optional[Union[torch.LongTensor, torch.BoolTensor]] = None, 864 | **kwargs, 865 | ) -> Dict[str, Any]: 866 | if past_key_values is None or not (isinstance(past_key_values, InferenceParams)): 867 | past_key_values = InferenceParams( 868 | max_seqlen=self.config.n_positions, 869 | max_batch_size=input_ids.shape[0], 870 | seqlen_offset=0, 871 | batch_size_offset=0, 872 | key_value_memory_dict={}, 873 | lengths_per_sample=None, 874 | ) 875 | else: 876 | # Assume that `past_key_values` has cached all tokens up to the last token in `input_ids` 877 | past_key_values.seqlen_offset = input_ids.shape[1] - 1 878 | input_ids = input_ids[:, -1].unsqueeze(-1) 879 | 880 | return { 881 | "input_ids": input_ids, 882 | "past_key_values": past_key_values, 883 | "attention_mask": attention_mask, 884 | } 885 | 886 | 887 | class PhiModel(PhiPreTrainedModel): 888 | """Phi model.""" 889 | 890 | _keys_to_ignore_on_load_missing = [""] 891 | _keys_to_ignore_on_load_unexpected = [r"h\.\d+\.mlp.(fc_in|fc_out)\.(weight|bias)"] 892 | 893 | def __init__(self, config: PhiConfig) -> None: 894 | super().__init__(config) 895 | 896 | self.embd = Embedding(config) 897 | self.h = nn.ModuleList([ParallelBlock(config, block_idx=i) for i in range(config.n_layer)]) 898 | self.gradient_checkpointing = False 899 | self.post_init() 900 | 901 | def get_input_embeddings(self) -> nn.Embedding: 902 | return self.embd.wte 903 | 904 | def set_input_embeddings(self, new_embeddings: nn.Embedding) -> None: 905 | self.embd.wte = new_embeddings 906 | 907 | 
def forward( 908 | self, 909 | input_ids: torch.LongTensor, 910 | past_key_values: Optional[Union[torch.FloatTensor, InferenceParams]] = None, 911 | attention_mask: Optional[torch.BoolTensor] = None, 912 | ) -> torch.FloatTensor: 913 | hidden_states = self.embd(input_ids) 914 | 915 | for layer in self.h: 916 | hidden_states = layer( 917 | hidden_states, 918 | past_key_values=past_key_values, 919 | attention_mask=attention_mask, 920 | ) 921 | 922 | return hidden_states 923 | 924 | 925 | class PhiForCausalLM(PhiPreTrainedModel): 926 | """Phi for Causal Language Modeling.""" 927 | 928 | _keys_to_ignore_on_load_missing = [""] 929 | _keys_to_ignore_on_load_unexpected = [r"transformer\.h\.\d+\.mlp.(fc_in|fc_out)\.(weight|bias)"] 930 | 931 | def __init__(self, config: PhiConfig) -> None: 932 | super().__init__(config) 933 | 934 | self.transformer = PhiModel(config) 935 | self.lm_head = CausalLMHead(config) 936 | self.loss = CausalLMLoss() 937 | 938 | self.post_init() 939 | 940 | def get_output_embeddings(self) -> nn.Linear: 941 | return self.lm_head.linear 942 | 943 | def set_output_embeddings(self, new_embeddings: nn.Linear) -> None: 944 | self.lm_head.linear = new_embeddings 945 | 946 | def forward( 947 | self, 948 | input_ids: torch.LongTensor, 949 | past_key_values: Optional[Union[torch.FloatTensor, InferenceParams]] = None, 950 | attention_mask: Optional[torch.BoolTensor] = None, 951 | labels: Optional[torch.LongTensor] = None, 952 | **kwargs, 953 | ) -> CausalLMOutputWithPast: 954 | hidden_states = self.transformer(input_ids, past_key_values=past_key_values, attention_mask=attention_mask) 955 | lm_logits = self.lm_head(hidden_states) 956 | 957 | loss = None 958 | if labels is not None: 959 | loss = self.loss(lm_logits, labels) 960 | 961 | return CausalLMOutputWithPast(loss=loss, logits=lm_logits, past_key_values=past_key_values) 962 | --------------------------------------------------------------------------------