├── Figure_1.pdf
├── Related_Work_A.1.pdf
├── data
├── ist_cnt2category2cnt.npz
├── jkt_cnt2category2cnt.npz
├── nyc_cnt2category2cnt.npz
└── tky_cnt2category2cnt.npz
├── model
├── base.py
├── configuration_phi.py
├── layers.py
├── positional_encoding.py
├── bert.py
├── utils.py
├── llm.py
├── MobilityLLM.py
└── phi_model.py
├── config
├── MobilityLLM_wee_POI.conf
├── MobilityLLM_wee_TPP.conf
├── MobilityLLM_wee_TUL.conf
├── MobilityLLM_bkc_POI.conf
├── MobilityLLM_bkc_TPP.conf
├── MobilityLLM_bkc_TUL.conf
├── MobilityLLM_gow_POI.conf
├── MobilityLLM_gow_TPP.conf
├── MobilityLLM_gow_TUL.conf
├── MobilityLLM_tky_POI.conf
├── MobilityLLM_tky_TPP.conf
└── MobilityLLM_tky_TUL.conf
├── README.md
├── utils.py
├── preprocess
└── load_data.py
└── train_MobilityLLM.py
/Figure_1.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LetianGong/Mobility-LLM/HEAD/Figure_1.pdf
--------------------------------------------------------------------------------
/Related_Work_A.1.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LetianGong/Mobility-LLM/HEAD/Related_Work_A.1.pdf
--------------------------------------------------------------------------------
/data/ist_cnt2category2cnt.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LetianGong/Mobility-LLM/HEAD/data/ist_cnt2category2cnt.npz
--------------------------------------------------------------------------------
/data/jkt_cnt2category2cnt.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LetianGong/Mobility-LLM/HEAD/data/jkt_cnt2category2cnt.npz
--------------------------------------------------------------------------------
/data/nyc_cnt2category2cnt.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LetianGong/Mobility-LLM/HEAD/data/nyc_cnt2category2cnt.npz
--------------------------------------------------------------------------------
/data/tky_cnt2category2cnt.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LetianGong/Mobility-LLM/HEAD/data/tky_cnt2category2cnt.npz
--------------------------------------------------------------------------------
/model/base.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | from torch import nn
5 | import torch.nn.functional as F
6 | import numpy as np
7 | from einops import repeat, rearrange
8 |
9 |
class Encoder(nn.Module):
    """Minimal base class for encoder modules.

    It adds nothing to ``nn.Module`` beyond remembering the label it was
    constructed with.

    :param name: label stored unchanged on the instance as ``self.name``.
    """

    def __init__(self, name):
        super().__init__()
        self.name = name
15 |
16 |
class Decoder(nn.Module):
    """Minimal base class for decoder modules.

    Only the constructor label is kept on the instance.

    :param name: label stored unchanged on the instance as ``self.name``.
    """

    def __init__(self, name):
        super().__init__()
        # A ``denormalizer`` attribute used to be assigned here; that assignment
        # is disabled, so ``name`` is the only state this class stores.
        self.name = name
23 |
24 |
class Denoiser(nn.Module):
    """Minimal base class for denoiser modules.

    :param name: label stored unchanged on the instance as ``self.name``.
    """

    def __init__(self, name):
        super().__init__()
        self.name = name
29 |
30 |
class GNN(nn.Module):
    """Minimal base class for graph-network modules.

    Both constructor arguments are stored on the instance without any
    processing.

    :param graph: graph object kept as ``self.graph``.
    :param name: label kept as ``self.name``.
    """

    def __init__(self, graph, name):
        super().__init__()
        self.graph, self.name = graph, name
38 |
--------------------------------------------------------------------------------
/config/MobilityLLM_wee_POI.conf:
--------------------------------------------------------------------------------
1 | [Data]
2 | dataset_name = www_WEE
3 | country_name = US
4 | max_his_period_days = 30
5 | max_merge_seconds_limit = 3600
6 | max_delta_mins = 1440
7 | min_session_mins = 120
8 | least_disuser_count = 10
9 | least_checkins_count = 10
10 |
11 | latN = 50
12 | lngN = 40
13 | split_save = 1
14 |
15 | [Training]
16 | use_nni = 0
17 | mode = train
18 | ctx = 0
19 | regularization = 1e-5
20 | learning_rate = 1e-3
21 | max_epochs = 100
22 | display_step = 1
23 | patience = 6
24 | train_batch = 8
25 | val_batch = 8
26 | test_batch = 8
27 | batch_size = 8
28 | save_results = 0
29 |
30 |
31 | [Model]
32 | loc_emb_size = 256
33 | geohash_size = 128
34 | category_size = 128
35 | tim_emb_size = 256
36 | user_emb_size = 256
37 | hidden_size = 256
38 | loc_noise_mean = 0
39 | loc_noise_sigma = 0.01
40 | tim_noise_mean = 0
41 | tim_noise_sigma = 0.01
42 | user_noise_mean = 0
43 | user_noise_sigma = 0.01
44 | tau = 4
45 | pos_eps = 0.5
46 | neg_eps = 0.5
47 | dropout_rate_1 = 0.5
48 | dropout_rate_2 = 0.5
49 | adv = 1
50 | self_weight = 0.05
51 | self_weight_s = 0.05
52 | self_weight_t = 0.05
53 | self_weight_st = 0.05
54 | k = 8
55 | momentum = 0.95
56 | theta = 0.18
57 | temperature = 0.1
58 | rnn_type = GRU
59 | num_layers = 3
60 | downstream = POI
61 | dump_path = checkpoints
62 | rank = 0
63 | queue_length = 1024
64 | world_size = -1
65 | epoch_queue_starts = 0
66 | crops_for_assign = 01
67 | feat_dim = 256
68 | loss = mae
69 | tpp = linear
70 | epsilon = 0.05
71 | dropout_spatial = 0.3
72 | learnable_param_size = 5
--------------------------------------------------------------------------------
/config/MobilityLLM_wee_TPP.conf:
--------------------------------------------------------------------------------
1 | [Data]
2 | dataset_name = www_WEE
3 | country_name = US
4 | max_his_period_days = 30
5 | max_merge_seconds_limit = 3600
6 | max_delta_mins = 1440
7 | min_session_mins = 120
8 | least_disuser_count = 10
9 | least_checkins_count = 10
10 |
11 | latN = 50
12 | lngN = 40
13 | split_save = 1
14 |
15 | [Training]
16 | use_nni = 0
17 | mode = train
18 | ctx = 0
19 | regularization = 1e-5
20 | learning_rate = 1e-3
21 | max_epochs = 100
22 | display_step = 1
23 | patience = 6
24 | train_batch = 8
25 | val_batch = 8
26 | test_batch = 8
27 | batch_size = 8
28 | save_results = 0
29 |
30 |
31 | [Model]
32 | loc_emb_size = 256
33 | geohash_size = 128
34 | category_size = 128
35 | tim_emb_size = 256
36 | user_emb_size = 256
37 | hidden_size = 256
38 | loc_noise_mean = 0
39 | loc_noise_sigma = 0.01
40 | tim_noise_mean = 0
41 | tim_noise_sigma = 0.01
42 | user_noise_mean = 0
43 | user_noise_sigma = 0.01
44 | tau = 4
45 | pos_eps = 0.5
46 | neg_eps = 0.5
47 | dropout_rate_1 = 0.5
48 | dropout_rate_2 = 0.5
49 | adv = 1
50 | self_weight = 0.05
51 | self_weight_s = 0.05
52 | self_weight_t = 0.05
53 | self_weight_st = 0.05
54 | k = 8
55 | momentum = 0.95
56 | theta = 0.18
57 | temperature = 0.1
58 | rnn_type = GRU
59 | num_layers = 3
60 | downstream = TPP
61 | dump_path = checkpoints
62 | rank = 0
63 | queue_length = 1024
64 | world_size = -1
65 | epoch_queue_starts = 0
66 | crops_for_assign = 01
67 | feat_dim = 256
68 | loss = mae
69 | tpp = linear
70 | epsilon = 0.05
71 | dropout_spatial = 0.3
72 | learnable_param_size = 3
--------------------------------------------------------------------------------
/config/MobilityLLM_wee_TUL.conf:
--------------------------------------------------------------------------------
1 | [Data]
2 | dataset_name = www_WEE
3 | country_name = US
4 | max_his_period_days = 30
5 | max_merge_seconds_limit = 3600
6 | max_delta_mins = 1440
7 | min_session_mins = 120
8 | least_disuser_count = 10
9 | least_checkins_count = 10
10 |
11 | latN = 50
12 | lngN = 40
13 | split_save = 1
14 |
15 | [Training]
16 | use_nni = 0
17 | mode = train
18 | ctx = 0
19 | regularization = 1e-5
20 | learning_rate = 1e-3
21 | max_epochs = 100
22 | display_step = 1
23 | patience = 6
24 | train_batch = 8
25 | val_batch = 8
26 | test_batch = 8
27 | batch_size = 8
28 | save_results = 0
29 |
30 |
31 | [Model]
32 | loc_emb_size = 256
33 | geohash_size = 128
34 | category_size = 128
35 | tim_emb_size = 256
36 | user_emb_size = 256
37 | hidden_size = 256
38 | loc_noise_mean = 0
39 | loc_noise_sigma = 0.01
40 | tim_noise_mean = 0
41 | tim_noise_sigma = 0.01
42 | user_noise_mean = 0
43 | user_noise_sigma = 0.01
44 | tau = 4
45 | pos_eps = 0.5
46 | neg_eps = 0.5
47 | dropout_rate_1 = 0.5
48 | dropout_rate_2 = 0.5
49 | adv = 1
50 | self_weight = 0.05
51 | self_weight_s = 0.05
52 | self_weight_t = 0.05
53 | self_weight_st = 0.05
54 | k = 8
55 | momentum = 0.95
56 | theta = 0.18
57 | temperature = 0.1
58 | rnn_type = GRU
59 | num_layers = 3
60 | downstream = TUL
61 | dump_path = checkpoints
62 | rank = 0
63 | queue_length = 1024
64 | world_size = -1
65 | epoch_queue_starts = 0
66 | crops_for_assign = 01
67 | feat_dim = 256
68 | loss = mae
69 | tpp = linear
70 | epsilon = 0.05
71 | dropout_spatial = 0.3
72 | learnable_param_size = 5
--------------------------------------------------------------------------------
/config/MobilityLLM_bkc_POI.conf:
--------------------------------------------------------------------------------
1 | [Data]
2 | dataset_name = www_BKC
3 | country_name = US
4 | max_his_period_days = 120
5 | max_merge_seconds_limit = 10800
6 | max_delta_mins = 1440
7 | min_session_mins = 1440
8 | least_disuser_count = 10
9 | least_checkins_count = 10
10 |
11 | latN = 50
12 | lngN = 40
13 | split_save = 1
14 |
15 | [Training]
16 | use_nni = 0
17 | mode = train
18 | ctx = 0
19 | regularization = 1e-5
20 | learning_rate = 1e-3
21 | max_epochs = 100
22 | # max_epochs = 10
23 | display_step = 1
24 | patience = 6
25 | train_batch = 8
26 | val_batch = 8
27 | test_batch = 8
28 | batch_size = 8
29 | save_results = 0
30 |
31 |
32 | [Model]
33 | loc_emb_size = 256
34 | geohash_size = 128
35 | category_size = 128
36 | tim_emb_size = 256
37 | user_emb_size = 256
38 | hidden_size = 256
39 | loc_noise_mean = 0
40 | loc_noise_sigma = 0.01
41 | tim_noise_mean = 0
42 | tim_noise_sigma = 0.01
43 | user_noise_mean = 0
44 | user_noise_sigma = 0.01
45 | tau = 4
46 | pos_eps = 0.5
47 | neg_eps = 0.5
48 | dropout_rate_1 = 0.5
49 | dropout_rate_2 = 0.5
50 | adv = 1
51 | self_weight = 0.05
52 | self_weight_s = 0.05
53 | self_weight_t = 0.05
54 | self_weight_st = 0.05
55 | k = 8
56 | momentum = 0.95
57 | theta = 0.18
58 | temperature = 0.1
59 | rnn_type = GRU
60 | num_layers = 3
61 | downstream = POI
62 | # downstream = TPP
63 | dump_path = checkpoints
64 | rank = 0
65 | queue_length = 1024
66 | world_size = -1
67 | epoch_queue_starts = 0
68 | crops_for_assign = 01
69 | feat_dim = 256
70 | loss = mae
71 | tpp = linear
72 | epsilon = 0.05
73 | dropout_spatial = 0.3
74 | learnable_param_size = 1
--------------------------------------------------------------------------------
/config/MobilityLLM_bkc_TPP.conf:
--------------------------------------------------------------------------------
1 | [Data]
2 | dataset_name = www_BKC
3 | country_name = US
4 | max_his_period_days = 120
5 | max_merge_seconds_limit = 10800
6 | max_delta_mins = 1440
7 | min_session_mins = 1440
8 | least_disuser_count = 10
9 | least_checkins_count = 10
10 |
11 | latN = 50
12 | lngN = 40
13 | split_save = 1
14 |
15 | [Training]
16 | use_nni = 0
17 | mode = train
18 | ctx = 0
19 | regularization = 1e-5
20 | learning_rate = 1e-3
21 | max_epochs = 100
22 | # max_epochs = 10
23 | display_step = 1
24 | patience = 6
25 | train_batch = 8
26 | val_batch = 8
27 | test_batch = 8
28 | batch_size = 8
29 | save_results = 0
30 |
31 |
32 | [Model]
33 | loc_emb_size = 256
34 | geohash_size = 128
35 | category_size = 128
36 | tim_emb_size = 256
37 | user_emb_size = 256
38 | hidden_size = 256
39 | loc_noise_mean = 0
40 | loc_noise_sigma = 0.01
41 | tim_noise_mean = 0
42 | tim_noise_sigma = 0.01
43 | user_noise_mean = 0
44 | user_noise_sigma = 0.01
45 | tau = 4
46 | pos_eps = 0.5
47 | neg_eps = 0.5
48 | dropout_rate_1 = 0.5
49 | dropout_rate_2 = 0.5
50 | adv = 1
51 | self_weight = 0.05
52 | self_weight_s = 0.05
53 | self_weight_t = 0.05
54 | self_weight_st = 0.05
55 | k = 8
56 | momentum = 0.95
57 | theta = 0.18
58 | temperature = 0.1
59 | rnn_type = GRU
60 | num_layers = 3
61 | # downstream = POI
62 | downstream = TPP
63 | dump_path = checkpoints
64 | rank = 0
65 | queue_length = 1024
66 | world_size = -1
67 | epoch_queue_starts = 0
68 | crops_for_assign = 01
69 | feat_dim = 256
70 | loss = mae
71 | tpp = linear
72 | epsilon = 0.05
73 | dropout_spatial = 0.3
74 | learnable_param_size = 3
--------------------------------------------------------------------------------
/config/MobilityLLM_bkc_TUL.conf:
--------------------------------------------------------------------------------
1 | [Data]
2 | dataset_name = www_BKC
3 | country_name = US
4 | max_his_period_days = 120
5 | max_merge_seconds_limit = 10800
6 | max_delta_mins = 1440
7 | min_session_mins = 1440
8 | least_disuser_count = 10
9 | least_checkins_count = 10
10 |
11 | latN = 50
12 | lngN = 40
13 | split_save = 1
14 |
15 | [Training]
16 | use_nni = 0
17 | mode = train
18 | ctx = 0
19 | regularization = 1e-5
20 | learning_rate = 1e-3
21 | max_epochs = 100
22 | # max_epochs = 10
23 | display_step = 1
24 | patience = 6
25 | train_batch = 8
26 | val_batch = 8
27 | test_batch = 8
28 | batch_size = 8
29 | save_results = 0
30 |
31 |
32 | [Model]
33 | loc_emb_size = 256
34 | geohash_size = 128
35 | category_size = 128
36 | tim_emb_size = 256
37 | user_emb_size = 256
38 | hidden_size = 256
39 | loc_noise_mean = 0
40 | loc_noise_sigma = 0.01
41 | tim_noise_mean = 0
42 | tim_noise_sigma = 0.01
43 | user_noise_mean = 0
44 | user_noise_sigma = 0.01
45 | tau = 4
46 | pos_eps = 0.5
47 | neg_eps = 0.5
48 | dropout_rate_1 = 0.5
49 | dropout_rate_2 = 0.5
50 | adv = 1
51 | self_weight = 0.05
52 | self_weight_s = 0.05
53 | self_weight_t = 0.05
54 | self_weight_st = 0.05
55 | k = 8
56 | momentum = 0.95
57 | theta = 0.18
58 | temperature = 0.1
59 | rnn_type = GRU
60 | num_layers = 3
61 | downstream = TUL
62 | # downstream = TPP
63 | dump_path = checkpoints
64 | rank = 0
65 | queue_length = 1024
66 | world_size = -1
67 | epoch_queue_starts = 0
68 | crops_for_assign = 01
69 | feat_dim = 256
70 | loss = mae
71 | tpp = linear
72 | epsilon = 0.05
73 | dropout_spatial = 0.3
74 | learnable_param_size = 1
--------------------------------------------------------------------------------
/config/MobilityLLM_gow_POI.conf:
--------------------------------------------------------------------------------
1 | [Data]
2 | dataset_name = www_GOW
3 | country_name = US
4 | max_his_period_days = 120
5 | max_merge_seconds_limit = 10800
6 | max_delta_mins = 1440
7 | min_session_mins = 1440
8 | least_disuser_count = 10
9 | least_checkins_count = 10
10 |
11 | latN = 50
12 | lngN = 40
13 | split_save = 1
14 |
15 | [Training]
16 | use_nni = 0
17 | mode = train
18 | ctx = 0
19 | regularization = 1e-5
20 | learning_rate = 1e-3
21 | max_epochs = 100
22 | # max_epochs = 10
23 | display_step = 1
24 | patience = 6
25 | train_batch = 8
26 | val_batch = 8
27 | test_batch = 8
28 | batch_size = 8
29 | save_results = 0
30 |
31 |
32 | [Model]
33 | loc_emb_size = 256
34 | geohash_size = 128
35 | category_size = 128
36 | tim_emb_size = 256
37 | user_emb_size = 256
38 | hidden_size = 256
39 | loc_noise_mean = 0
40 | loc_noise_sigma = 0.01
41 | tim_noise_mean = 0
42 | tim_noise_sigma = 0.01
43 | user_noise_mean = 0
44 | user_noise_sigma = 0.01
45 | tau = 4
46 | pos_eps = 0.5
47 | neg_eps = 0.5
48 | dropout_rate_1 = 0.5
49 | dropout_rate_2 = 0.5
50 | adv = 1
51 | self_weight = 0.05
52 | self_weight_s = 0.05
53 | self_weight_t = 0.05
54 | self_weight_st = 0.05
55 | k = 8
56 | momentum = 0.95
57 | theta = 0.18
58 | temperature = 0.1
59 | rnn_type = GRU
60 | num_layers = 3
61 | downstream = POI
62 | # downstream = TPP
63 | dump_path = checkpoints
64 | rank = 0
65 | queue_length = 1024
66 | world_size = -1
67 | epoch_queue_starts = 0
68 | crops_for_assign = 01
69 | feat_dim = 256
70 | loss = mae
71 | tpp = linear
72 | epsilon = 0.05
73 | dropout_spatial = 0.3
74 | learnable_param_size = 1
--------------------------------------------------------------------------------
/config/MobilityLLM_gow_TPP.conf:
--------------------------------------------------------------------------------
1 | [Data]
2 | dataset_name = www_GOW
3 | country_name = US
4 | max_his_period_days = 120
5 | max_merge_seconds_limit = 10800
6 | max_delta_mins = 1440
7 | min_session_mins = 1440
8 | least_disuser_count = 10
9 | least_checkins_count = 10
10 |
11 | latN = 50
12 | lngN = 40
13 | split_save = 1
14 |
15 | [Training]
16 | use_nni = 0
17 | mode = train
18 | ctx = 0
19 | regularization = 1e-5
20 | learning_rate = 1e-3
21 | max_epochs = 100
22 | # max_epochs = 10
23 | display_step = 1
24 | patience = 6
25 | train_batch = 8
26 | val_batch = 8
27 | test_batch = 8
28 | batch_size = 8
29 | save_results = 0
30 |
31 |
32 | [Model]
33 | loc_emb_size = 256
34 | geohash_size = 128
35 | category_size = 128
36 | tim_emb_size = 256
37 | user_emb_size = 256
38 | hidden_size = 256
39 | loc_noise_mean = 0
40 | loc_noise_sigma = 0.01
41 | tim_noise_mean = 0
42 | tim_noise_sigma = 0.01
43 | user_noise_mean = 0
44 | user_noise_sigma = 0.01
45 | tau = 4
46 | pos_eps = 0.5
47 | neg_eps = 0.5
48 | dropout_rate_1 = 0.5
49 | dropout_rate_2 = 0.5
50 | adv = 1
51 | self_weight = 0.05
52 | self_weight_s = 0.05
53 | self_weight_t = 0.05
54 | self_weight_st = 0.05
55 | k = 8
56 | momentum = 0.95
57 | theta = 0.18
58 | temperature = 0.1
59 | rnn_type = GRU
60 | num_layers = 3
61 | # downstream = POI
62 | downstream = TPP
63 | dump_path = checkpoints
64 | rank = 0
65 | queue_length = 1024
66 | world_size = -1
67 | epoch_queue_starts = 0
68 | crops_for_assign = 01
69 | feat_dim = 256
70 | loss = mae
71 | tpp = linear
72 | epsilon = 0.05
73 | dropout_spatial = 0.3
74 | learnable_param_size = 3
--------------------------------------------------------------------------------
/config/MobilityLLM_gow_TUL.conf:
--------------------------------------------------------------------------------
1 | [Data]
2 | dataset_name = www_GOW
3 | country_name = US
4 | max_his_period_days = 120
5 | max_merge_seconds_limit = 10800
6 | max_delta_mins = 1440
7 | min_session_mins = 1440
8 | least_disuser_count = 10
9 | least_checkins_count = 10
10 |
11 | latN = 50
12 | lngN = 40
13 | split_save = 1
14 |
15 | [Training]
16 | use_nni = 0
17 | mode = train
18 | ctx = 0
19 | regularization = 1e-5
20 | learning_rate = 1e-3
21 | max_epochs = 100
22 | # max_epochs = 10
23 | display_step = 1
24 | patience = 6
25 | train_batch = 8
26 | val_batch = 8
27 | test_batch = 8
28 | batch_size = 8
29 | save_results = 0
30 |
31 |
32 | [Model]
33 | loc_emb_size = 256
34 | geohash_size = 128
35 | category_size = 128
36 | tim_emb_size = 256
37 | user_emb_size = 256
38 | hidden_size = 256
39 | loc_noise_mean = 0
40 | loc_noise_sigma = 0.01
41 | tim_noise_mean = 0
42 | tim_noise_sigma = 0.01
43 | user_noise_mean = 0
44 | user_noise_sigma = 0.01
45 | tau = 4
46 | pos_eps = 0.5
47 | neg_eps = 0.5
48 | dropout_rate_1 = 0.5
49 | dropout_rate_2 = 0.5
50 | adv = 1
51 | self_weight = 0.05
52 | self_weight_s = 0.05
53 | self_weight_t = 0.05
54 | self_weight_st = 0.05
55 | k = 8
56 | momentum = 0.95
57 | theta = 0.18
58 | temperature = 0.1
59 | rnn_type = GRU
60 | num_layers = 3
61 | downstream = TUL
62 | # downstream = TPP
63 | dump_path = checkpoints
64 | rank = 0
65 | queue_length = 1024
66 | world_size = -1
67 | epoch_queue_starts = 0
68 | crops_for_assign = 01
69 | feat_dim = 256
70 | loss = mae
71 | tpp = linear
72 | epsilon = 0.05
73 | dropout_spatial = 0.3
74 | learnable_param_size = 1
--------------------------------------------------------------------------------
/config/MobilityLLM_tky_POI.conf:
--------------------------------------------------------------------------------
1 | [Data]
2 | dataset_name = www_TKY
3 | country_name = JP
4 | max_his_period_days = 120
5 | max_merge_seconds_limit = 10800
6 | max_delta_mins = 1440
7 | min_session_mins = 1440
8 | least_disuser_count = 10
9 | least_checkins_count = 10
10 |
11 | latN = 40
12 | lngN = 40
13 | split_save = 1
14 |
15 | [Training]
16 | use_nni = 0
17 | mode = train
18 | ctx = 0
19 | regularization = 1e-5
20 | learning_rate = 1e-3
21 | max_epochs = 100
22 | # max_epochs = 10
23 | display_step = 1
24 | patience = 6
25 | train_batch = 8
26 | val_batch = 8
27 | test_batch = 8
28 | batch_size = 8
29 | save_results = 0
30 |
31 |
32 | [Model]
33 | loc_emb_size = 256
34 | geohash_size = 128
35 | category_size = 128
36 | tim_emb_size = 256
37 | user_emb_size = 256
38 | hidden_size = 256
39 | loc_noise_mean = 0
40 | loc_noise_sigma = 0.01
41 | tim_noise_mean = 0
42 | tim_noise_sigma = 0.01
43 | user_noise_mean = 0
44 | user_noise_sigma = 0.01
45 | tau = 4
46 | pos_eps = 0.5
47 | neg_eps = 0.5
48 | dropout_rate_1 = 0.5
49 | dropout_rate_2 = 0.5
50 | adv = 1
51 | self_weight = 0.05
52 | self_weight_s = 0.05
53 | self_weight_t = 0.05
54 | self_weight_st = 0.05
55 | k = 8
56 | momentum = 0.95
57 | theta = 0.18
58 | temperature = 0.1
59 | rnn_type = GRU
60 | num_layers = 3
61 | downstream = POI
62 | # downstream = TPP
63 | dump_path = checkpoints
64 | rank = 0
65 | queue_length = 1024
66 | world_size = -1
67 | epoch_queue_starts = 0
68 | crops_for_assign = 01
69 | feat_dim = 256
70 | loss = mae
71 | tpp = linear
72 | epsilon = 0.05
73 | dropout_spatial = 0.3
74 | learnable_param_size = 3
--------------------------------------------------------------------------------
/config/MobilityLLM_tky_TPP.conf:
--------------------------------------------------------------------------------
1 | [Data]
2 | dataset_name = www_TKY
3 | country_name = JP
4 | max_his_period_days = 120
5 | max_merge_seconds_limit = 10800
6 | max_delta_mins = 1440
7 | min_session_mins = 1440
8 | least_disuser_count = 10
9 | least_checkins_count = 10
10 |
11 | latN = 40
12 | lngN = 40
13 | split_save = 1
14 |
15 | [Training]
16 | use_nni = 0
17 | mode = train
18 | ctx = 0
19 | regularization = 1e-5
20 | learning_rate = 1e-3
21 | max_epochs = 100
22 | # max_epochs = 10
23 | display_step = 1
24 | patience = 6
25 | train_batch = 8
26 | val_batch = 8
27 | test_batch = 8
28 | batch_size = 8
29 | save_results = 0
30 |
31 |
32 | [Model]
33 | loc_emb_size = 256
34 | geohash_size = 128
35 | category_size = 128
36 | tim_emb_size = 256
37 | user_emb_size = 256
38 | hidden_size = 256
39 | loc_noise_mean = 0
40 | loc_noise_sigma = 0.01
41 | tim_noise_mean = 0
42 | tim_noise_sigma = 0.01
43 | user_noise_mean = 0
44 | user_noise_sigma = 0.01
45 | tau = 4
46 | pos_eps = 0.5
47 | neg_eps = 0.5
48 | dropout_rate_1 = 0.5
49 | dropout_rate_2 = 0.5
50 | adv = 1
51 | self_weight = 0.05
52 | self_weight_s = 0.05
53 | self_weight_t = 0.05
54 | self_weight_st = 0.05
55 | k = 8
56 | momentum = 0.95
57 | theta = 0.18
58 | temperature = 0.1
59 | rnn_type = GRU
60 | num_layers = 3
61 | # downstream = POI
62 | downstream = TPP
63 | dump_path = checkpoints
64 | rank = 0
65 | queue_length = 1024
66 | world_size = -1
67 | epoch_queue_starts = 0
68 | crops_for_assign = 01
69 | feat_dim = 256
70 | loss = mae
71 | tpp = linear
72 | epsilon = 0.05
73 | dropout_spatial = 0.3
74 | learnable_param_size = 3
--------------------------------------------------------------------------------
/config/MobilityLLM_tky_TUL.conf:
--------------------------------------------------------------------------------
1 | [Data]
2 | dataset_name = www_TKY
3 | country_name = JP
4 | max_his_period_days = 120
5 | max_merge_seconds_limit = 10800
6 | max_delta_mins = 1440
7 | min_session_mins = 1440
8 | least_disuser_count = 10
9 | least_checkins_count = 10
10 |
11 | latN = 40
12 | lngN = 40
13 | split_save = 1
14 |
15 | [Training]
16 | use_nni = 0
17 | mode = train
18 | ctx = 0
19 | regularization = 1e-5
20 | learning_rate = 1e-3
21 | max_epochs = 100
22 | # max_epochs = 10
23 | display_step = 1
24 | patience = 6
25 | train_batch = 8
26 | val_batch = 8
27 | test_batch = 8
28 | batch_size = 8
29 | save_results = 0
30 |
31 |
32 | [Model]
33 | loc_emb_size = 256
34 | geohash_size = 128
35 | category_size = 128
36 | tim_emb_size = 256
37 | user_emb_size = 256
38 | hidden_size = 256
39 | loc_noise_mean = 0
40 | loc_noise_sigma = 0.01
41 | tim_noise_mean = 0
42 | tim_noise_sigma = 0.01
43 | user_noise_mean = 0
44 | user_noise_sigma = 0.01
45 | tau = 4
46 | pos_eps = 0.5
47 | neg_eps = 0.5
48 | dropout_rate_1 = 0.5
49 | dropout_rate_2 = 0.5
50 | adv = 1
51 | self_weight = 0.05
52 | self_weight_s = 0.05
53 | self_weight_t = 0.05
54 | self_weight_st = 0.05
55 | k = 8
56 | momentum = 0.95
57 | theta = 0.18
58 | temperature = 0.1
59 | rnn_type = GRU
60 | num_layers = 3
61 | downstream = TUL
62 | # downstream = TPP
63 | dump_path = checkpoints
64 | rank = 0
65 | queue_length = 1024
66 | world_size = -1
67 | epoch_queue_starts = 0
68 | crops_for_assign = 01
69 | feat_dim = 256
70 | loss = mae
71 | tpp = linear
72 | epsilon = 0.05
73 | dropout_spatial = 0.3
74 | learnable_param_size = 1
--------------------------------------------------------------------------------
/model/configuration_phi.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 | import math
5 | from typing import Optional
6 |
7 | from transformers import PretrainedConfig
8 |
9 |
class PhiConfig(PretrainedConfig):
    """Hyper-parameter container for the ``phi-msft`` model type.

    ``attribute_map`` exposes the GPT-style field names used here
    (``n_positions``, ``n_embd``, ``n_head``, ``n_layer``) under the aliases
    ``max_position_embeddings``, ``hidden_size``, ``num_attention_heads`` and
    ``num_hidden_layers``.
    """

    model_type = "phi-msft"
    attribute_map = {
        "max_position_embeddings": "n_positions",
        "hidden_size": "n_embd",
        "num_attention_heads": "n_head",
        "num_hidden_layers": "n_layer",
    }

    def __init__(
        self,
        vocab_size: int = 50304,
        n_positions: int = 2048,
        n_embd: int = 1024,
        n_layer: int = 20,
        n_inner: Optional[int] = None,
        n_head: int = 16,
        n_head_kv: Optional[int] = None,
        rotary_dim: Optional[int] = 32,
        activation_function: Optional[str] = "gelu_new",
        flash_attn: bool = False,
        flash_rotary: bool = False,
        fused_dense: bool = False,
        attn_pdrop: float = 0.0,
        embd_pdrop: float = 0.0,
        resid_pdrop: float = 0.0,
        layer_norm_epsilon: float = 1e-5,
        initializer_range: float = 0.02,
        tie_word_embeddings: bool = False,
        pad_vocab_size_multiple: int = 64,
        **kwargs
    ) -> None:
        """Store the hyper-parameters and forward the rest to the base class.

        Two values are derived rather than stored verbatim:

        * ``vocab_size`` is rounded up to the next multiple of
          ``pad_vocab_size_multiple``.
        * ``rotary_dim`` is capped at ``n_embd // n_head``.

        ``tie_word_embeddings`` and any extra ``kwargs`` are passed to
        ``PretrainedConfig.__init__``.
        """
        # Number of ``pad_vocab_size_multiple``-sized chunks needed to hold the vocabulary.
        vocab_chunks = math.ceil(vocab_size / pad_vocab_size_multiple)
        self.vocab_size = int(vocab_chunks * pad_vocab_size_multiple)

        self.n_positions = n_positions
        self.n_embd = n_embd
        self.n_layer = n_layer
        self.n_inner = n_inner
        self.n_head = n_head
        self.n_head_kv = n_head_kv

        # NOTE(review): the signature allows ``rotary_dim=None`` but ``min`` would
        # raise TypeError for it -- confirm whether None is ever passed.
        per_head_width = n_embd // n_head
        self.rotary_dim = min(rotary_dim, per_head_width)

        self.activation_function = activation_function
        self.flash_attn = flash_attn
        self.flash_rotary = flash_rotary
        self.fused_dense = fused_dense
        self.attn_pdrop = attn_pdrop
        self.embd_pdrop = embd_pdrop
        self.resid_pdrop = resid_pdrop
        self.layer_norm_epsilon = layer_norm_epsilon
        self.initializer_range = initializer_range

        super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
63 |
--------------------------------------------------------------------------------
/model/layers.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import numpy as np
4 | import torch
5 | from torch import nn
6 |
7 |
class ContinuousEncoding(nn.Module):
    """
    Trigonometric encoding that maps continuous scalars to vectors via
    ``cos(x * omega + bias)`` with learnable ``omega`` and ``bias``.
    """

    def __init__(self, embed_size):
        super().__init__()
        # Initial frequencies 10^0 ... 10^-9, spaced evenly in the exponent.
        exponents = np.linspace(0, 9, embed_size)
        init_omega = torch.from_numpy(1 / 10 ** exponents).float()
        init_bias = torch.zeros(embed_size).float()
        self.omega = nn.Parameter(init_omega, requires_grad=True)
        self.bias = nn.Parameter(init_bias, requires_grad=True)
        # Constant output scale sqrt(1 / embed_size).
        self.div_term = math.sqrt(1. / embed_size)

    def forward(self, x):
        """
        :param x: input sequence for encoding, (batch_size, seq_len)
        :return: encoded sequence, shape (batch_size, seq_len, embed_size)
        """
        # Broadcast (..., 1) values against (1, 1, embed_size) frequencies and phases.
        phase = x.unsqueeze(-1) * self.omega[None, None, :] + self.bias[None, None, :]
        return self.div_term * torch.cos(phase)
28 |
29 |
class PositionalEncoding(nn.Module):
    """
    A type of trigonometric encoding for indicating items' positions in sequences.

    A fixed (non-trainable) table of shape (1, max_len, embed_size) is built once
    and registered as the buffer ``pe``: even columns hold sines, odd columns
    hold cosines.
    """

    def __init__(self, embed_size, max_len):
        super().__init__()

        # forward() reshapes with self.d_model; it was never assigned before,
        # so every call with position_ids raised AttributeError.
        self.d_model = embed_size

        pe = torch.zeros(max_len, embed_size).float()
        pe.requires_grad = False

        position = torch.arange(0, max_len).float().unsqueeze(1)
        div_term = (torch.arange(0, embed_size, 2).float() * -(math.log(10000.0) / embed_size)).exp()

        pe[:, 0::2] = torch.sin(position * div_term)
        # With an odd embed_size there is one fewer cosine column than sine
        # columns; trim the frequencies so the assignment shapes match.
        pe[:, 1::2] = torch.cos(position * div_term[:embed_size // 2])

        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x, position_ids=None):
        """
        Args:
            x: (B, T, d_model); only its sequence length is used, and only
                when position_ids is None.
            position_ids: (B, T) integer indices, or None.

        Returns:
            (1, T, d_model) when position_ids is None, otherwise a detached
            (B, T, d_model) tensor gathered at the given indices.
        """
        if position_ids is None:
            return self.pe[:, :x.size(1)]

        batch_size, seq_len = position_ids.shape
        # Lookup table: the first seq_len rows tiled batch_size times, so an
        # index i selects the encoding of position i % seq_len.
        table = self.pe[:, :seq_len, :].expand((batch_size, -1, -1))  # (B, T, d_model)
        table = table.reshape(-1, self.d_model)  # (B * T, d_model)
        flat_ids = position_ids.reshape(-1)  # (B * T,)
        return table[flat_ids].reshape(batch_size, seq_len, self.d_model).detach()
69 |
70 |
class SinusoidalPositionEmbeddings(nn.Module):
    """
    Sinusoidal-based function used for encoding timestamps.

    The output concatenates ``dim // 2`` sine features with ``dim // 2`` cosine
    features, so its last dimension is ``2 * (dim // 2)`` (i.e. ``dim - 1`` for
    an odd ``dim``).
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        """
        :param time: 1-D tensor of timestamps, shape (N,)
        :return: tensor of shape (N, 2 * (dim // 2))
        """
        device = time.device
        half_dim = self.dim // 2
        # Frequencies decay geometrically from 1 to 1/10000 over half_dim steps.
        # Clamp the divisor to 1 so that dim in {2, 3} (half_dim == 1) no longer
        # raises ZeroDivisionError; larger dims are unaffected.
        log_step = math.log(10000) / max(half_dim - 1, 1)
        freqs = torch.exp(torch.arange(half_dim, device=device) * -log_step)
        angles = time[:, None] * freqs[None, :]
        return torch.cat((angles.sin(), angles.cos()), dim=-1)
88 |
89 |
class TimeEmbed(nn.Module):
    """
    Timestamp embedding: sinusoidal features followed by a two-layer MLP
    (Linear -> SiLU -> Linear).

    The first Linear layer expects ``input_dim`` features, while
    ``SinusoidalPositionEmbeddings`` emits ``2 * (input_dim // 2)``, so
    ``input_dim`` has to be even for the shapes to line up.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()

        stages = (
            SinusoidalPositionEmbeddings(input_dim),
            nn.Linear(input_dim, output_dim),
            nn.SiLU(),
            nn.Linear(output_dim, output_dim),
        )
        self.time_mlp = nn.Sequential(*stages)

    def forward(self, time):
        """
        :param time: timestamps passed straight to the sinusoidal stage
        :return: output of the MLP, last dimension ``output_dim``
        """
        embedded = self.time_mlp(time)
        return embedded
103 |
--------------------------------------------------------------------------------
/model/positional_encoding.py:
--------------------------------------------------------------------------------
1 | '''
2 | Unofficial pytorch implementation of the paper "Learnable Fourier Features for Multi-Dimensional Spatial Positional Encoding", NeurIPS 2021.
3 | '''
4 | import torch
5 | import torch.nn as nn
6 | import numpy as np
7 | from einops import rearrange
8 |
9 |
10 | # Learnable Fourier Features for Multi-Dimensional Spatial Positional Encoding
class LearnableFourierFeatures(nn.Module):
    """Learnable Fourier feature positional encoder.

    Positions are projected by the learnable matrix ``Wr``, turned into
    cosine/sine features scaled by ``1 / sqrt(f_dim)``, and passed through a
    two-layer GELU MLP whose output width is ``d_dim / g_dim``.
    """

    def __init__(self, pos_dim, f_dim, h_dim, d_dim, g_dim=1, gamma=1.0):
        super(LearnableFourierFeatures, self).__init__()
        assert f_dim % 2 == 0, 'number of fourier feature dimensions must be divisible by 2.'
        assert d_dim % g_dim == 0, 'number of D dimension must be divisible by the number of G dimension.'
        half_features = int(f_dim / 2)
        per_group_out = int(d_dim / g_dim)
        # Projection matrix drawn from a standard normal and scaled by gamma ** 2.
        self.Wr = nn.Parameter(torch.randn([half_features, pos_dim]) * (gamma ** 2))
        self.mlp = nn.Sequential(
            nn.Linear(f_dim, h_dim),
            nn.GELU(),
            nn.Linear(h_dim, per_group_out),
        )
        self.div_term = np.sqrt(f_dim)

    def forward(self, pos):
        """
        :param pos: positions laid out as (B, L, G, M), where M matches ``pos_dim``
        :return: the MLP output with every size-1 dimension squeezed away
        """
        projected = torch.matmul(pos, self.Wr.T)
        fourier = torch.cat([projected.cos(), projected.sin()], dim=-1) / self.div_term
        encoded = self.mlp(fourier)
        # The grouped (B, L, G*D) layout is computed but not returned; the call is
        # kept because it raises unless ``encoded`` has exactly four dimensions.
        rearrange(encoded, 'b l g d -> b l (g d)')

        # NOTE(review): squeeze() drops *all* singleton axes, so a batch or
        # sequence axis of length 1 disappears as well -- confirm callers expect it.
        return encoded.squeeze()
37 |
38 |
39 |
40 | # Fourier Features Let Networks Learn High Frequency Functions in Low Dimensional Domains
class FourierFeatures(nn.Module):
    """Random Fourier feature encoder: ``pos -> [sin(pos @ B), cos(pos @ B)]``.

    ``B`` is sampled from a normal distribution scaled by ``sigma``. With
    ``train=True`` it becomes an ``nn.Parameter``; otherwise it stays a plain
    tensor attribute (not a registered buffer), so it is absent from
    ``state_dict()`` and is moved to the input's device on every call.
    """

    def __init__(self, pos_dim, f_dim, sigma=10, train=False):
        super(FourierFeatures, self).__init__()
        assert f_dim % 2 == 0, 'number of channels must be divisible by 2.'
        half_channels = int(f_dim / 2)
        projection = torch.randn([pos_dim, half_channels]) * sigma
        self.B = nn.Parameter(projection) if train else projection

    def forward(self, pos):
        """
        :param pos: positions with the coordinate axis last,
            e.g. (B L C), (B H W C), (B H W T C)
        :return: tensor whose last dimension is ``f_dim`` (sines, then cosines)
        """
        projected = torch.matmul(pos, self.B.to(pos.device))
        return torch.cat([projected.sin(), projected.cos()], dim=-1)
55 |
56 |
57 | # Attention is All You Need
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding for multi-dimensional positions.

    Every position component is multiplied by the same geometric frequency
    table; the per-component products are summed by the matmul and mapped to
    ``[sin, cos]`` features.
    """

    def __init__(self, pos_dim, enc_dim):
        """
        :param pos_dim: size of the last axis of the position input
        :param enc_dim: size of the produced encoding; must be divisible by
            both ``pos_dim * 2`` and 4
        """
        super().__init__()
        assert enc_dim % (pos_dim * 2) == 0, \
            'dimension of positional encoding must be divisible by pos_dim * 2.'
        # The frequency table is two copies of a vector of length enc_dim / 4; other
        # sizes used to fail later with an opaque shape-mismatch error.
        assert enc_dim % 4 == 0, 'dimension of positional encoding must be divisible by 4.'
        half_dim = enc_dim // 2
        div_term = torch.exp(torch.arange(0., half_dim, 2) * -(np.log(10000.0) / half_dim))
        # Same row for every position component: [div_term, div_term].
        freqs = torch.cat([div_term, div_term]).repeat(pos_dim, 1)
        # Buffer so the table follows the module across devices and dtypes;
        # non-persistent so state_dict keys stay as they were.
        self.register_buffer('freqs', freqs, persistent=False)

    def forward(self, pos):
        """Encode positions.

        :param pos: tensor whose last axis has size pos_dim,
            e.g. (B L C), (B H W C) or (B H W T C)
        :return: tensor with the same leading axes and a last axis of size enc_dim
        """
        pos_enc = torch.matmul(pos, self.freqs.to(pos.device))
        return torch.cat([torch.sin(pos_enc), torch.cos(pos_enc)], dim=-1)
75 |
76 |
if __name__ == '__main__':
    # Run the demos on a GPU when one is available; otherwise fall back to the CPU
    # instead of crashing on the unconditional .cuda() calls.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    """
    example usage of LearnableFourierFeatures

    let
        positional dimension: 2 (2d spatial positions)
        fourier feature dimension: 128
        hidden dimension: 256
        positional encoding dimension: 64
        number of positional groups: 1

        batch size: 4
        sequence length: 1024 (== 32x32 in 2d spatial resolution)
        number of positional groups: 1
        positional dimension: 2
    """
    lff = LearnableFourierFeatures(pos_dim=2, f_dim=128, h_dim=256, d_dim=64, g_dim=1).to(device)
    pos = torch.randn([4, 1024, 1, 2]).to(device)
    pe = lff(pos)
    print(pe)
    print(pe.shape)

    """
    example usage of FourierFeatures

    let
        positional dimension: 2 (2d spatial positions)
        fourier feature dimension: 256

        batch size: 4
        sequence length: 32x32
        positional dimension: 2
    """
    ff = FourierFeatures(pos_dim=2, f_dim=256).to(device)
    pos = torch.randn([4, 32, 32, 2]).to(device)
    pe = ff(pos)
    print(pe)
    print(pe.shape)

    """
    example usage of PositionalEncoding

    let
        positional dimension: 2 (2d spatial positions)
        encoding dimension: 256

        batch size: 4
        sequence length: 1024
        positional dimension: 2
    """
    PE = PositionalEncoding(pos_dim=2, enc_dim=256).to(device)
    pos = torch.randn([4, 1024, 2]).to(device)
    pe = PE(pos)
    print(pe)
    print(pe.shape)
136 |
137 |
138 |
139 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## Mobility-LLM: Learning Visiting Intentions and Travel Preferences from Human Mobility Data with Large Language Models
2 | ### Datasets
3 | To demonstrate the superiority of our proposed model, our experiments are carried out on four real-world datasets derived from Gowalla (GOW), WeePlace (WEE), Brightkite (BKC) and FourSquare (TKY) check-in data.
4 |
5 | To facilitate training, the data undergo a filtering process that selects high-quality check-in sequences. To ensure data consistency, we set a maximum historical time limit of 120 days and filter out users with fewer than 10 records and places visited fewer than 10 times.
6 |
7 | The table below shows the statistics of the four datasets.
8 | | | Gowalla | WeePlace | Brightkite | FourSquare |
9 | | ---------- |:---------:|:----------:|:------------:|:------------:|
10 | | # users | 5,853 | 1,028 | 431 | 703 |
11 | | # POIs | 52,032 | 9,295 | 3,554 | 11,117 |
12 | | # Samples | 413,563 | 104,762 | 44,716 | 60,734 |
13 |
14 | - Download processed datasets from following sources:
15 | - One Drive
16 | - https://1drv.ms/f/c/401a59cded375360/EuI9Lfh3qhVPrMdO22wDbc0BjJ_-5M3YIPEOLyrGoEMD3A?e=jBGTCC
17 | - BaiduNetDisk code: `x24z`
18 | - https://pan.baidu.com/s/1nWWMzS1yQaGdaiVd-njJxQ
19 | - Copy all files and directories to `MobilityLLM/data/new_datasets`
20 |
21 | ### Large Language Models
22 | We compare eight representative backbones with varying capacities, including TinyLlama, TinyLlama-Chat, LiteLlama, phi-2, pythia-70M, pythia-1B, pythia-2.8B and GPT-2.
23 | - Download models from following sources:
24 | - https://huggingface.co/models
25 | - Copy all files and directories to `MobilityLLM/params/*/`
26 |
27 | ### Requirements
28 | - python >= 3.6
29 | - PyTorch >= 1.8
30 |
31 | ### Usage :
32 | Enter directory `MobilityLLM`.
33 | Downstream tasks:
34 | Location Prediction (LP), Trajectory User Link (TUL), or Time Prediction (TP).
35 | model class (the value to pass after `--model_class` in the commands below):
36 | TinyLlama-1_1B (TinyLlama), TinyLlama-Chat (TinyLlama-Chat), phi-2 (phi-2), pythia-70M (pythia-70M), pythia-2_8B (pythia-2.8B), pythia-1B (pythia-1B), LiteLlama (LiteLlama), gpt2 (GPT-2).
37 | - Train model on WEE of LP task:
38 | `python train_MobilityLLM.py --config config/MobilityLLM_wee_POI.conf --dataroot data/ --model_class`
39 |
40 | - Train model on TKY of LP task:
41 | `python train_MobilityLLM.py --config config/MobilityLLM_tky_POI.conf --dataroot data/ --model_class`
42 |
43 | - Train model on GOW of LP task:
44 | `python train_MobilityLLM.py --config config/MobilityLLM_gow_POI.conf --dataroot data/ --model_class`
45 |
46 | - Train model on BKC of LP task:
47 | `python train_MobilityLLM.py --config config/MobilityLLM_bkc_POI.conf --dataroot data/ --model_class`
48 |
49 | - Train model on WEE of TUL task:
50 | `python train_MobilityLLM.py --config config/MobilityLLM_wee_TUL.conf --dataroot data/ --model_class`
51 |
52 | - Train model on TKY of TUL task:
53 | `python train_MobilityLLM.py --config config/MobilityLLM_tky_TUL.conf --dataroot data/ --model_class`
54 |
55 | - Train model on GOW of TUL task:
56 | `python train_MobilityLLM.py --config config/MobilityLLM_gow_TUL.conf --dataroot data/ --model_class`
57 |
58 | - Train model on BKC of TUL task:
59 | `python train_MobilityLLM.py --config config/MobilityLLM_bkc_TUL.conf --dataroot data/ --model_class`
60 |
61 | - Train model on WEE of TP task:
62 | `python train_MobilityLLM.py --config config/MobilityLLM_wee_TPP.conf --dataroot data/ --model_class`
63 |
64 | - Train model on TKY of TP task:
65 | `python train_MobilityLLM.py --config config/MobilityLLM_tky_TPP.conf --dataroot data/ --model_class`
66 |
67 | - Train model on GOW of TP task:
68 | `python train_MobilityLLM.py --config config/MobilityLLM_gow_TPP.conf --dataroot data/ --model_class`
69 |
70 | - Train model on BKC of TP task:
71 | `python train_MobilityLLM.py --config config/MobilityLLM_bkc_TPP.conf --dataroot data/ --model_class`
72 | ### Configuration
73 | The configuration file `MobilityLLM_*.conf` contains three parts: Data, Training and Model:
74 |
75 | #### Data
76 | - dataset_name: The name of the dataset; one of www_GOW, www_BKC, www_TKY or www_WEE.
77 | - max_his_period_days: The max history time.
78 | - max_merge_seconds_limit: To judge whether two identical locations are the same event.
79 | - max_delta_mins: To limit the prediction range.
80 | - least_disuser_count: To filter locations, keep locations which have at least * users.
81 | - least_checkins_count: To filter users, keep users who have at least * checkins.
82 | - split_save: 1 or 0, representing whether datasets are split saved.
83 |
84 | #### Training
85 | - mode: train for default,
86 | - ctx: cuda index, 0 for default
87 | - regularization: float, regularization factor.
88 | - learning_rate: float
89 | - max_epochs: int
90 | - display_step: int
91 | - patience: int, for early stopping.
92 | - train_batch: int
93 | - val_batch: int
94 | - test_batch: int
95 | - batch_size: int
96 | - save_results: bool
97 |
98 | #### Model
99 | - adv: 0 or 1, enable adversarial or not.
100 | - downstream: POI, TUL or TPP, representing Location Prediction, Trajectory User Link, and Time Prediction respectively.
101 |
102 | The remaining parameters are the best parameters of the model.
103 |
--------------------------------------------------------------------------------
/model/bert.py:
--------------------------------------------------------------------------------
1 | from transformers import BertForMaskedLM
2 |
3 | from .base import *
4 | from .layers import PositionalEncoding, ContinuousEncoding
5 |
6 |
def get_batch_mask(B, L, valid_len):
    """Build a boolean (B, L) mask that is True at positions below each row's valid length.

    :param B: batch size
    :param L: sequence length of the mask
    :param valid_len: 1-D tensor with the number of valid positions per row
    :return: boolean tensor of shape (B, L) on valid_len's device
    """
    positions = torch.arange(end=L, device=valid_len.device).unsqueeze(0).expand(B, L)
    lengths = valid_len.unsqueeze(1).expand(-1, L)
    return positions < lengths  # (B, L)
11 |
12 |
class BertLM(Encoder):
    """Trip encoder that feeds embedded trip features through a pretrained BERT.

    The model input is laid out as ``[origin POI, trip steps..., destination POI]``
    (length L + 2); the last BERT hidden state is projected to ``output_size``
    and optionally pooled.
    """

    def __init__(self, d_model, output_size, bert_path,
                 dis_feats=None, num_embeds=None, con_feats=None,
                 mlp_grad=True, pooling='mean'):
        """
        :param d_model: width of the feature embeddings before projection
        :param output_size: width of the returned representation
        :param bert_path: path or name handed to BertForMaskedLM.from_pretrained
        :param dis_feats: indices (last axis of ``trip``) of discrete features
        :param num_embeds: vocabulary size for each discrete feature
        :param con_feats: indices (last axis of ``trip``) of continuous features
        :param mlp_grad: whether parameters whose name contains 'mlp' stay trainable
        :param pooling: 'mean' or 'cls'; any other value leaves the sequence unpooled
        """
        super().__init__('BertLM-d{}-o{}-p{}'.format(d_model, output_size, pooling))

        # None stands in for the former shared mutable [] defaults.
        dis_feats = [] if dis_feats is None else dis_feats
        num_embeds = [] if num_embeds is None else num_embeds
        con_feats = [] if con_feats is None else con_feats

        self.output_size = output_size
        self.d_model = d_model
        self.pooling = pooling
        self.dis_feats = dis_feats
        self.con_feats = con_feats

        self.pos_encode = PositionalEncoding(d_model, max_len=2001)
        if len(dis_feats):
            assert len(dis_feats) == len(num_embeds), \
                'length of num_embeds list should be equal to the number of discrete features.'
            self.dis_embeds = nn.ModuleList([nn.Embedding(num_embed, d_model) for num_embed in num_embeds])
        else:
            self.dis_embeds = None

        if len(con_feats):
            self.con_embeds = nn.ModuleList([ContinuousEncoding(d_model) for _ in con_feats])
        else:
            self.con_embeds = None

        self.bert = BertForMaskedLM.from_pretrained(bert_path)
        # NOTE(review): `emb_size` must exist on the loaded config - confirm the
        # checkpoint at bert_path defines it.
        emb_size = self.bert.config.emb_size
        self.emb_size = emb_size
        self.hidden_size = self.bert.config.hidden_size

        # Freeze or fine-tune the parameters of BERT, selected by name.
        # NOTE(review): 'ln', 'wpe' and 'mlp' look like GPT-2 style parameter names;
        # confirm they match parameters of the loaded BERT, otherwise every BERT
        # weight ends up frozen.
        for name, param in self.bert.named_parameters():
            if 'ln' in name or 'wpe' in name:
                param.requires_grad = True
            elif 'mlp' in name and mlp_grad:
                param.requires_grad = True
            else:
                param.requires_grad = False

        self.seq_projector = nn.Sequential(nn.Linear(d_model, emb_size, bias=False),
                                           nn.LayerNorm(emb_size),
                                           nn.ReLU(inplace=True),
                                           nn.Linear(emb_size, emb_size))
        self.poi_projector = nn.Sequential(nn.Linear(emb_size, emb_size, bias=False),
                                           nn.LayerNorm(emb_size),
                                           nn.ReLU(inplace=True),
                                           nn.Linear(emb_size, emb_size))

        self.out_linear = nn.Sequential(nn.Linear(self.bert.config.hidden_size, output_size, bias=False),
                                        nn.LayerNorm(output_size),
                                        nn.ReLU(inplace=True),
                                        nn.Linear(output_size, output_size))

        # Learnable replacements for masked trip steps / masked POI slots.
        self.trip_mask_token = nn.Parameter(torch.zeros(emb_size).float(), requires_grad=True)
        self.poi_mask_token = nn.Parameter(torch.zeros(emb_size).float(), requires_grad=True)

    def forward(self, trip, valid_len, o_poi_embeddings, d_poi_embeddings,
                trip_mask=None, poi_mask=None, **kwargs):
        """Encode a batch of trips.

        :param trip: (B, L, E_in) feature tensor
        :param valid_len: (B,) number of valid trip steps per row
        :param o_poi_embeddings: origin POI embeddings, placed at position 0
        :param d_poi_embeddings: destination POI embeddings, placed right after
            the last valid trip step
        :param trip_mask: optional boolean index over the L trip positions whose
            inputs are replaced by ``trip_mask_token`` (presumably (B, L))
        :param poi_mask: optional boolean tensor; column 0 masks the origin slot,
            column 1 the destination slot (presumably (B, 2))
        :param kwargs: ``pretrain=True`` returns the unpooled sequence output
        :return: (B, L + 2, output_size) when pretraining or when pooling is not
            'mean'/'cls'; otherwise (B, output_size)
        """
        B, L, E_in = trip.shape

        # trip_batch_mask covers origin + trip steps; src_batch_mask additionally
        # covers the destination slot. Their difference marks the destination position.
        trip_batch_mask = get_batch_mask(B, L+2, valid_len+1)
        src_batch_mask = get_batch_mask(B, L+2, valid_len+2)

        h = torch.zeros(B, trip.size(1), self.d_model).to(trip.device)
        if self.dis_embeds is not None:
            for dis_embed, dis_feat in zip(self.dis_embeds, self.dis_feats):
                h += dis_embed(trip[..., dis_feat].long())  # (B, L, E)
        if self.con_embeds is not None:
            for con_embed, con_feat in zip(self.con_embeds, self.con_feats):
                h += con_embed(trip[..., con_feat].float())
        h += self.pos_encode(h)
        h = self.seq_projector(h)

        input_seq = torch.zeros(B, L+2, self.emb_size).to(trip.device)
        input_seq[:, 0] += self.poi_projector(o_poi_embeddings)
        input_seq[:, 1:-1] += h
        input_seq[src_batch_mask.long() - trip_batch_mask.long() == 1] += self.poi_projector(d_poi_embeddings)

        if trip_mask is not None:
            # input_seq[:, 1:-1] is a view, so the indexed assignment writes into input_seq.
            input_seq[:, 1:-1][trip_mask] = self.trip_mask_token
        if poi_mask is not None:
            o_placeholder = torch.zeros(B, L+2).to(trip.device)
            o_placeholder[:, 0] = 1
            o_placeholder = o_placeholder.bool()
            o_poi_mask = repeat(poi_mask[:, 0], 'b -> b l', l=L+2)
            d_poi_mask = repeat(poi_mask[:, 1], 'b -> b l', l=L+2)
            input_seq[o_placeholder & o_poi_mask] = self.poi_mask_token
            input_seq[(src_batch_mask.long() - trip_batch_mask.long() == 1) & d_poi_mask] = self.poi_mask_token

        memory = self.bert(inputs_embeds=input_seq, attention_mask=src_batch_mask, output_hidden_states=True).hidden_states[-1]
        memory = torch.nan_to_num(memory)

        memory = self.out_linear(memory)  # (B, L + 2, E_out)

        if bool(kwargs.get('pretrain', False)):
            return memory

        if self.pooling == 'mean':
            mask_expanded = repeat(src_batch_mask.logical_not(), 'B L -> B L E', E=memory.size(2))  # (B, L, E)
            memory = memory.masked_fill(mask_expanded, 0)  # (B, L, E)
            # NOTE(review): the sum covers valid_len + 2 unmasked positions while the
            # divisor is valid_len - confirm this is intended.
            memory = torch.sum(memory, 1) / valid_len.unsqueeze(-1)
        elif self.pooling == 'cls':
            memory = memory[:, 0]

        return memory
118 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import math
4 | import nni
5 | #import seaborn as sns
6 | #import matplotlib.pyplot as plt
7 | import torch
8 | import torch.nn.functional as F
9 |
class DotDict(dict):
    """Dictionary whose items can also be read, written and deleted as attributes."""

    def __getattr__(self, name):
        # Only called when regular attribute lookup fails. A missing key must surface
        # as AttributeError (not KeyError) so hasattr(), getattr(obj, name, default)
        # and the copy module behave correctly.
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name) from None

    __setattr__ = dict.__setitem__

    def __delattr__(self, name):
        try:
            del self[name]
        except KeyError:
            raise AttributeError(name) from None
14 |
15 |
def evaluate_location(y, top_k_pred, K=20):
    '''
    get hit ratio@1..K and mean reciprocal rank
    :param y: (batch,) ground-truth labels
    :param top_k_pred: (batch, num_candidates) predicted labels, best first
    :param K: largest cut-off for the hit ratio
    :return: (hit_ratio, mrr) where hit_ratio[j-1] is the fraction of samples whose
        label is among the first j predictions and mrr is the mean of 1/rank
        (a sample whose label is not predicted at all contributes 0)
    '''
    y = np.asarray(y)
    top_k_pred = np.asarray(top_k_pred)
    total_num = top_k_pred.shape[0]
    hit_ratio = np.zeros(K)
    if total_num == 0:
        # Nothing to score; avoid dividing by zero.
        return hit_ratio, 0.0

    reciprocal_ranks = np.zeros(total_num)
    for i in range(total_num):
        positions = np.flatnonzero(top_k_pred[i] == y[i])
        if positions.size == 0:
            # Label missing from the predictions: no hit and reciprocal rank 0
            # (this case used to produce ragged ranks and crash).
            continue
        rank = positions[0] + 1
        reciprocal_ranks[i] = 1.0 / rank
        if rank <= K:
            # A hit at `rank` is a hit for every cut-off j >= rank.
            hit_ratio[rank - 1:] += 1
    hit_ratio = hit_ratio / total_num
    mrr = reciprocal_ranks.mean()
    return hit_ratio, mrr
37 |
38 |
def get_total_prob_c(loader, model, gts, gss, time_info, week_info, feature_category, feature_lat, feature_lng, feature_lat_ori, feature_lng_ori, save_filename=None, params_path=None, distance=None):
    '''
    collect the outputs of model.get_gammac over the entire data loader
    :param loader: iterable of batches
    :param model: object exposing get_gammac(...), which returns two tensors per batch
    :param save_filename: when given, results are written to
        <params_path>/<save_filename>_gammac.npz
    :param params_path: directory used when saving
    :return: (all_topic, all_label, all_label_index); all_label_index is the argmax
        of all_label along axis 1
    '''
    topic_chunks = []
    label_chunks = []
    for batch in loader:
        topic, gamma_c = model.get_gammac(batch, gss, feature_category, feature_lat, feature_lng,
                                          feature_lat_ori, feature_lng_ori, gts, time_info, week_info,
                                          distance=distance)
        topic_chunks.append(topic.detach().cpu().numpy())
        label_chunks.append(gamma_c.detach().cpu().numpy())

    all_topic = np.concatenate(topic_chunks)
    all_label = np.concatenate(label_chunks)
    all_label_index = np.argmax(all_label, axis=1)
    print('all_label:', all_label[:10])
    print('all_label_index:', all_label_index[:10])

    if save_filename is not None:
        target = os.path.join(params_path, save_filename + '_gammac.npz')
        np.savez(target, all_topic=all_topic, all_label=all_label, all_label_index=all_label_index)
    return all_topic, all_label, all_label_index
61 |
62 |
def density_visualization(density, ground_truth, batch_cnt):
    '''
    plot the probability density function (plotting is currently disabled)

    The function is a no-op that returns None: the matplotlib import at the top of
    the file is commented out, so the plotting routine is kept below as a comment
    for reference only.
    :param density: per-sample density curves
    :param ground_truth: per-sample true values
    :param batch_cnt: batch counter used in the output file names
    :return: None
    '''
    # Disabled plotting routine:
    #
    # n_samples = 1000
    # hours = 48
    # x = np.linspace(0, hours, n_samples)
    # t = ground_truth
    # cnt = 0
    # for i in range(len(density)):
    #     y = density[i]
    #     plt.plot(x, y, "r-", label="STDGN")
    #     plt.legend()
    #     plt.xlabel(r"$\tau$", fontdict={'family': 'Times New Roman', 'size':16})
    #     plt.ylabel(r"p($\tau$)", fontdict={'family': 'Times New Roman', 'size':16})
    #     plt.yticks(fontproperties = 'Times New Roman', size = 14)
    #     plt.xticks(fontproperties = 'Times New Roman', size = 14)
    #     plt.grid()
    #     true_value = round(t[i], 2)
    #     plt.axvline(x=true_value, ls=":", c="black")
    #     plt.text(x=true_value + 1, y=1/2*np.max(y), s=r"$\tau_{n+1}$=" + str(true_value), size=16, alpha=0.8)
    #     plt.legend(prop={'family' : 'Times New Roman', 'size' : 16})
    #     plt.show()
    #     pic_name = str(batch_cnt) + '_' + str(cnt)
    #     cnt += 1
    #     plt.savefig(f'./data/density/jkt_{pic_name}.png')
    #     plt.savefig(f'./data/density/jkt_{pic_name}.eps',format='eps', dpi=10000)
    #     plt.close()
    return None
98 |
99 |
def softmax(x):
    '''
    row-wise softmax of a 2-D array
    :param x: array-like of shape (n_rows, n_cols); it is not modified
    :return: array of the same shape whose rows sum to 1
    '''
    x = np.asarray(x)
    # Subtract the row maximum for numerical stability. This builds a new array
    # instead of the former in-place `x -= ...`, which overwrote the caller's data.
    shifted = x - np.max(x, axis=1, keepdims=True)
    exp_shifted = np.exp(shifted)
    return exp_shifted / np.sum(exp_shifted, axis=1, keepdims=True)
111 |
112 |
def rad(d):
    '''
    convert an angle (latitude or longitude) from degrees to radians
    :param d: angle in degrees
    :return: angle in radians
    '''
    half_turn_degrees = 180.0
    return d * math.pi / half_turn_degrees
120 |
121 |
def getDistance(lat1, lng1, lat2, lng2):
    '''
    great-circle (haversine) distance between two locations
    :param lat1: latitude of the first point, in degrees
    :param lng1: longitude of the first point, in degrees
    :param lat2: latitude of the second point, in degrees
    :param lng2: longitude of the second point, in degrees
    :return s: distance in the unit of the radius constant (6378.137, i.e. kilometres)
    '''
    EARTH_RADIUS = 6378.137
    rad_lat1 = math.radians(lat1)
    rad_lat2 = math.radians(lat2)
    delta_lat = rad_lat1 - rad_lat2
    delta_lng = math.radians(lng1) - math.radians(lng2)
    haversine = (math.sin(delta_lat / 2) ** 2
                 + math.cos(rad_lat1) * math.cos(rad_lat2) * math.sin(delta_lng / 2) ** 2)
    # Rounding can push the term slightly above 1 for (near-)antipodal points,
    # which would make asin raise a domain error.
    haversine = min(1.0, haversine)
    return 2 * math.asin(math.sqrt(haversine)) * EARTH_RADIUS
139 |
140 |
def get_s_baselines_total_loss_s_for_MobilityLLM_RNN(loader, model, save_filename=None, params_path=None):
    '''
    run the model over the entire data loader and score its location ranking
    :param loader: iterable of batches exposing Y_location
    :param model: callable returning (loss, top_k_pred) for a batch
    :param save_filename: when given, labels and predictions are written to
        <params_path>/<save_filename>_results.npz
    :param params_path: directory used when saving
    :return: (mean batch loss, hit_ratio, mrr)
    '''
    batch_losses = []
    label_chunks = []
    prediction_chunks = []

    for batch in loader:
        s_loss_score, top_k_pred = model(batch)
        batch_losses.append(s_loss_score.detach().cpu().numpy())
        label_chunks.append(batch.Y_location.cpu().numpy())
        prediction_chunks.append(top_k_pred.cpu().numpy())

    mean_loss = np.mean(np.array(batch_losses))

    all_ground_truth_location = np.concatenate(label_chunks)
    all_predicted_topK = np.concatenate(prediction_chunks)

    hit_ratio, mrr = evaluate_location(all_ground_truth_location, all_predicted_topK)

    if save_filename is not None:
        target = os.path.join(params_path, save_filename + '_results.npz')
        np.savez(target, all_ground_truth_location=all_ground_truth_location, all_predicted_topK=all_predicted_topK)

    return mean_loss, hit_ratio, mrr
165 |
166 |
167 | # for downstream validation bencnmark
def get_s_baselines_total_loss_s_for_MobilityLLM_DEMO_DOWN(loader, model, downstream='POI', save_filename=None, params_path=None):
    '''
    downstream validation benchmark: run the model over the loader and score its ranking
    :param loader: iterable of batches exposing Y_location (POI) or X_users (TUL)
    :param model: callable returning (loss, top_k_pred) when invoked with
        mode='downstream' and the chosen downstream task
    :param downstream: 'POI' or 'TUL'; selects which batch field holds the labels
    :param save_filename: when given, labels and predictions are written to
        <params_path>/<save_filename>_results.npz
    :param params_path: directory used when saving
    :return: (mean batch loss, hit_ratio, mrr)
    :raises ValueError: for any other downstream value (after the first model call)
    '''
    batch_losses = []
    label_chunks = []
    prediction_chunks = []
    for batch in loader:
        s_loss_score, top_k_pred = model(batch, mode='downstream', downstream=downstream)
        batch_losses.append(s_loss_score.detach().cpu().numpy())

        if downstream == 'TUL':
            labels = batch.X_users
        elif downstream == 'POI':
            labels = batch.Y_location
        else:
            raise ValueError('downstream is not in [POI, TUL]')
        label_chunks.append(labels.cpu().numpy())

        prediction_chunks.append(top_k_pred.cpu().numpy())

    mean_loss = np.mean(np.array(batch_losses))

    all_ground_truth_users = np.concatenate(label_chunks)
    all_predicted_topK = np.concatenate(prediction_chunks)

    hit_ratio, mrr = evaluate_location(all_ground_truth_users, all_predicted_topK)

    if save_filename is not None:
        target = os.path.join(params_path, save_filename + '_results.npz')
        np.savez(target, all_ground_truth_users=all_ground_truth_users, all_predicted_topK=all_predicted_topK)

    return mean_loss, hit_ratio, mrr
198 |
199 |
200 | # for downstream validation bencnmark
def get_s_baselines_total_loss_s_for_MobilityLLM_DOWN(loader, model, downstream='POI', save_filename=None, params_path=None):
    '''
    downstream validation benchmark: run the model over the loader and score its ranking
    :param loader: iterable of batches exposing X_all_loc and Y_location (POI) or X_users (TUL)
    :param model: callable returning (loss, top_k_pred, extra) when invoked with
        mode='downstream' and the chosen downstream task; the third value is ignored
    :param downstream: 'POI' or 'TUL'; selects which batch field holds the labels
    :param save_filename: when given, labels and predictions are written to
        <params_path>/<save_filename>_results.npz
    :param params_path: directory used when saving
    :return: (mean batch loss, hit_ratio, mrr)
    :raises ValueError: for any other downstream value (after the first model call)
    '''
    batch_losses = []
    label_chunks = []
    prediction_chunks = []
    for batch in loader:
        # Batches whose second X_all_loc axis has 700 or more entries are not evaluated.
        # NOTE(review): if every batch is skipped, np.concatenate below raises on the
        # empty lists.
        if batch.X_all_loc.shape[1] >= 700:
            continue
        s_loss_score, top_k_pred, _ = model(batch, mode='downstream', downstream=downstream)
        batch_losses.append(s_loss_score.detach().cpu().numpy())

        if downstream == 'TUL':
            labels = batch.X_users
        elif downstream == 'POI':
            labels = batch.Y_location
        else:
            raise ValueError('downstream is not in [POI, TUL]')
        label_chunks.append(labels.cpu().numpy())

        prediction_chunks.append(top_k_pred.cpu().numpy())

    mean_loss = np.mean(np.array(batch_losses))

    all_ground_truth_users = np.concatenate(label_chunks)
    all_predicted_topK = np.concatenate(prediction_chunks)

    hit_ratio, mrr = evaluate_location(all_ground_truth_users, all_predicted_topK)

    if save_filename is not None:
        target = os.path.join(params_path, save_filename + '_results.npz')
        np.savez(target, all_ground_truth_users=all_ground_truth_users, all_predicted_topK=all_predicted_topK)

    return mean_loss, hit_ratio, mrr
232 |
233 |
def get_t_for_IFLTPP(loader, model, save_filename=None, params_path=None, experiment_base_dir=None, use_nni=False):
    '''
    evaluate time prediction over the entire data loader
    :param loader: iterable of batches
    :param model: callable invoked with mode='downstream', downstream='TPP'; it returns
        (tnll, mean, extra) and exposes the matching targets as model.truth_Y_tau
    :param save_filename: when given (and use_nni is False), targets and predictions
        are written to <params_path>/<save_filename>_results.npz
    :param params_path: directory used when saving
    :param experiment_base_dir: unused
    :param use_nni: when True nothing is saved
    :return: (mae, mape, rmse, tnll_t)
    '''
    truth_chunks = []
    prediction_chunks = []
    batch_tnll = []

    for batch in loader:
        tnll, mean, _ = model(batch, cont_conf=None, mode='downstream', downstream='TPP')
        truth_chunks.append(model.truth_Y_tau.detach().cpu().numpy())
        prediction_chunks.append(mean.detach().cpu().numpy())
        batch_tnll.append(tnll.detach().cpu().numpy())

    batch_tnll = np.array(batch_tnll)

    ground_truth_Y_tau = np.concatenate(truth_chunks).flatten()
    predicted_Y_tau = np.concatenate(prediction_chunks).flatten()

    abs_error = abs(ground_truth_Y_tau - predicted_Y_tau)
    mae = np.mean(abs_error)
    rmse = np.sqrt(((ground_truth_Y_tau - predicted_Y_tau) ** 2).mean())
    # Percentage error with the denominator floored at the mean target value.
    mape = np.mean(abs_error / np.maximum(np.mean(ground_truth_Y_tau), ground_truth_Y_tau))
    tnll_t = np.mean(batch_tnll)

    if save_filename is not None and not use_nni:
        target = os.path.join(params_path, save_filename + '_results.npz')
        np.savez(target, ground_truth_Y_tau=ground_truth_Y_tau, predicted_Y_tau=predicted_Y_tau)

    return mae, mape, rmse, tnll_t
267 |
268 |
def get_semantic_information(cnt2category, data_root):
    '''
    collect the GloVe vectors of every word that occurs in the category names
    :param cnt2category: mapping whose values are category names (space-separated words)
    :param data_root: directory containing "glove.twitter.27B.50d.pkl"
    :return: (dataset_word_vec, dataset_word_index, word_cnt): the vectors in order of
        first appearance, a word -> row index mapping, and the number of words kept
    '''
    import pickle

    # os.path.join gives the same path as plain concatenation when data_root ends
    # with a separator and also works when it does not.
    vecpath = os.path.join(data_root, "glove.twitter.27B.50d.pkl")
    # SECURITY: pickle can execute arbitrary code while loading - only point
    # data_root at trusted files.
    with open(vecpath, "rb") as pkl_data:
        word_vec = pickle.load(pkl_data)

    dataset_word_vec = []
    dataset_word_index = {}
    for category in cnt2category.values():
        for word in category.split(" "):
            word = word.lower()
            # Keep each known word once, indexed in order of first appearance.
            if word in word_vec and word not in dataset_word_index:
                dataset_word_index[word] = len(dataset_word_vec)
                dataset_word_vec.append(word_vec[word])
    print("word_index: ", dataset_word_index)
    return dataset_word_vec, dataset_word_index, len(dataset_word_vec)
293 |
294 |
295 |
--------------------------------------------------------------------------------
/preprocess/load_data.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data as data_utils
2 | import os
3 | import itertools
4 | import numpy as np
5 | import torch
6 |
7 |
def tid_list_48(tm):
    """Map a timestamp to one of 48 slots: hours 0-23 on Monday-Friday, 24-47 otherwise.

    :param tm: object exposing ``weekday()`` and ``hour`` (e.g. a datetime)
    :return: integer slot id
    """
    weekend_offset = 0 if tm.weekday() in (0, 1, 2, 3, 4) else 24
    return int(tm.hour) + weekend_offset
14 |
15 |
def load_data_from_dataset(set_name, loader, device, user_cnt, venue_cnt, save_split, name, data_root):
    """Build a SessionBasedSequenceDataset for one split from a loaded npz mapping.

    :param set_name: key prefix of the split inside ``loader`` (e.g. 'train')
    :param loader: mapping holding the ``<set_name>X_*`` / ``<set_name>Y_*`` arrays
    :param device: passed through to the dataset
    :param user_cnt: passed through to the dataset
    :param venue_cnt: passed through to the dataset
    :param save_split: unused
    :param name: unused
    :param data_root: unused
    :return: the constructed SessionBasedSequenceDataset

    The arrival times stored in ``loader`` are overwritten in place by their
    48-slot time ids.
    """
    def field(suffix):
        return loader[f'{set_name}{suffix}']

    X_target_lengths = field('X_target_lengths')
    X_arrival_times = field('X_arrival_times')
    X_users = field('X_users')
    X_locations = field('X_locations')
    X_locations_category = field('X_locations_category')
    X_lats = field('X_lats')
    X_lons = field('X_lons')
    X_geohash = field('X_geohash')
    Y_location = field('Y_locations')
    Y_location_category = field('Y_locations_category')
    Y_lat = field('Y_lats')
    Y_lon = field('Y_lons')
    Y_geohash = field('Y_geohash')

    X_tau = pad_1d(field('X_delta_times'))
    Y_tau = pad_1d(field('Y_delta_times'))

    X_all_loc, X_all_loc_category = [], []
    X_all_lat, X_all_lon, X_all_geohash = [], [], []
    X_all_tim, X_lengths = [], []
    for i, tim in enumerate(X_arrival_times):
        loc = X_locations[i]
        loc_category = X_locations_category[i]
        lat = X_lats[i]
        lon = X_lons[i]
        geohash = X_geohash[i]

        len_ = len(tim)
        # Replace every timestamp by its weekday/weekend hour slot, in place.
        for j in range(len_):
            tim[j] = tid_list_48(tim[j])

        X_all_loc.append(loc)
        X_all_loc_category.append(loc_category)
        X_all_lat.append(lat)
        X_all_lon.append(lon)
        X_all_geohash.append(geohash)
        X_all_tim.append(tim)
        X_lengths.append(len_)

    # Summary of what was loaded (count and first sample of each field).
    n_sequences = len(X_all_loc)
    n_targets = len(Y_location)
    print("X_all_loc: ", n_sequences, X_all_loc[0])
    print("X_all_loc_category: ", n_sequences, X_all_loc_category[0])
    print("X_all_lat: ", n_sequences, X_all_lat[0])
    print("X_all_lon: ", n_sequences, X_all_lon[0])
    print("X_all_geohash: ", n_sequences, X_all_geohash[0])
    print("X_all_tim: ", len(X_all_tim), X_all_tim[0])
    print("X_target_lengths: ", len(X_target_lengths), X_target_lengths[0])
    print("X_lengths: ", len(X_lengths), X_lengths[0])
    print("X_users:", len(X_users), X_users)
    print("Y_location:", n_targets, Y_location[0])
    print("Y_location_category:", n_targets, Y_location_category[0])
    print("Y_lat:", n_targets, Y_lat[0])
    print("Y_lon:", n_targets, Y_lon[0])
    print("Y_geohash:", n_targets, Y_geohash[0])

    dataset = SessionBasedSequenceDataset(device, user_cnt, venue_cnt, X_users, X_all_loc, X_all_loc_category,
                                          X_all_lat, X_all_lon, X_all_geohash,
                                          X_all_tim, Y_location, Y_location_category, Y_lat, Y_lon, Y_geohash,
                                          X_target_lengths, X_lengths, None, X_tau, Y_tau)
    print(f'samples cnt of data_{set_name}:', dataset.real_length())

    return dataset
83 |
84 |
def load_dataset_for_MobilityLLM(name, data_root, save_split=False, device=None):
    '''
    1. load the npz archives of the train/val/test splits
    2. read the shared counters and spatial point features from the test archive
    3. construct one SessionBasedSequenceDataset per split
    :param name: dataset file name; '.npz' is appended when missing
    :param data_root: directory that holds the data files
    :param save_split: when True, read train_/val_/test_/category_ files from
        <data_root>/new_datasets; otherwise read only <data_root>/test_<name> and
        use it for all three splits
    :param device: passed through to the datasets
    :return: (data_train, data_val, data_test, feature_category, feature_lat,
        feature_lng, latN, lngN, category_cnt, categories); categories is None when
        save_split is False, because no category file is read in that mode
    '''
    if not name.endswith('.npz'):
        name += '.npz'

    def read_npz(path):
        # Load every array eagerly and close the archive instead of leaking the handle.
        # SECURITY: allow_pickle=True can execute code from untrusted files.
        with np.load(path, allow_pickle=True) as archive:
            return dict(archive)

    if save_split:
        # os.path.join matches the former `data_root + 'new_datasets'` when data_root
        # ends with a separator and also works when it does not.
        split_dir = os.path.join(data_root, 'new_datasets')
        train_loader = read_npz(os.path.join(split_dir, 'train_' + name))
        val_loader = read_npz(os.path.join(split_dir, 'val_' + name))
        loader = read_npz(os.path.join(split_dir, 'test_' + name))
        with np.load(os.path.join(split_dir, 'category_' + name)) as category_archive:
            categories = category_archive['categories']
    else:
        loader = read_npz(os.path.join(data_root, 'test_' + name))
        train_loader = loader
        val_loader = loader
        # No category file in this mode; this used to crash on None['categories'].
        categories = None

    user_cnt = loader['user_cnt']
    venue_cnt = loader['venue_cnt']
    print('user_cnt:', user_cnt)
    print('venue_cnt:', venue_cnt)

    # put spatial point features into tensor
    feature_category = torch.LongTensor(loader['feature_category'])
    feature_lat = torch.LongTensor(loader['feature_lat'])  # index
    feature_lng = torch.LongTensor(loader['feature_lng'])  # index

    latN, lngN = loader['latN'], loader['lngN']
    category_cnt = loader['category_cnt']

    # ----- load train / val / test to get dataset -----
    data_train = load_data_from_dataset('train', train_loader, device, user_cnt, venue_cnt, save_split, name, data_root)
    data_val = load_data_from_dataset('val', val_loader, device, user_cnt, venue_cnt, save_split, name, data_root)
    data_test = load_data_from_dataset('test', loader, device, user_cnt, venue_cnt, save_split, name, data_root)

    return data_train, data_val, data_test, feature_category, feature_lat, feature_lng, latN, lngN, category_cnt, \
        categories
133 |
134 |
class SessionBasedSequenceDataset(data_utils.Dataset):
    """Dataset class containing variable length sequences.

    Each item bundles one sample's input fields (locations, categories,
    coordinates, geohashes, time ids, inter-event times) with its targets;
    the tuple layout is what the collate function consumes.
    """

    def __init__(self, device, user_cnt, venue_cnt, X_users, X_all_loc, X_all_loc_category, X_all_lat, X_all_lon,
                 X_all_geohash,
                 X_all_tim, Y_location, Y_location_category, Y_lat, Y_lon, Y_geohash, target_lengths, X_lengths,
                 X_all_text,
                 X_tau, Y_tau):
        """Store the per-sample fields.

        ``X_tau`` entries are converted to tensors one by one; ``Y_tau`` must offer
        ``astype`` (e.g. a numpy array) and is converted to a single tensor.

        :raises ValueError: if X_all_loc, X_all_tim and Y_location differ in length
        """
        self.device = device
        self.user_cnt = user_cnt
        self.venue_cnt = venue_cnt
        self.X_users = X_users
        self.X_all_loc = X_all_loc
        self.X_all_loc_category = X_all_loc_category
        self.X_all_lat = X_all_lat
        self.X_all_lon = X_all_lon
        self.X_all_geohash = X_all_geohash
        self.X_all_tim = X_all_tim
        self.target_lengths = target_lengths
        self.X_lengths = X_lengths
        self.Y_location = Y_location
        self.Y_location_category = Y_location_category
        self.Y_lat = Y_lat
        self.Y_lon = Y_lon
        self.Y_geohash = Y_geohash
        self.X_all_text = X_all_text

        self.X_tau = [torch.Tensor(_) for _ in X_tau]
        self.Y_tau = torch.Tensor(Y_tau.astype('float64'))  # mins

        self.validate_data()

    @property
    def num_series(self):
        """Number of samples (one per entry of Y_location)."""
        return len(self.Y_location)

    def real_length(self):
        """Total number of target locations summed over all samples."""
        return sum(len(targets) for targets in self.Y_location)

    def validate_data(self):
        """Check that inputs and targets describe the same number of samples."""
        if len(self.X_all_loc) != len(self.Y_location) or len(self.X_all_tim) != len(self.Y_location):
            raise ValueError("Length of X_all_loc, X_all_tim, Y_location should match")

    def get_tau_log_mean_std_Y(self):
        """Get mean and std of log(Y_tau) over the non-zero entries."""
        y = torch.flatten(self.Y_tau)
        logy = y[y != 0].log()
        return logy.mean(), logy.std()

    def get_mean_std_Y_tau(self):
        """Get mean and std of Y_tau over the non-zero entries."""
        y = torch.flatten(self.Y_tau)
        y = y[y != 0]
        return y.mean(), y.std()

    def normalize_Y_tau(self, mean=None, std=None):
        """Standardise Y_tau in place with the given statistics and return self.

        Both ``mean`` and ``std`` have to be supplied; the ``None`` defaults are
        placeholders only.
        """
        self.Y_tau = (self.Y_tau - mean)/std
        return self

    def __getitem__(self, key):
        '''
        the outputs are fed into collate()
        :param key: sample index
        :return: 18-tuple; position 2 is always None and the last entry is the device
        '''
        return self.X_all_loc[key], self.X_all_tim[key], None, self.Y_location[key], self.target_lengths[key], \
            self.X_lengths[key], self.X_users[key], self.X_all_loc_category[key], self.X_all_lat[key], \
            self.X_all_lon[key], self.Y_location_category[key], self.Y_lat[key], self.Y_lon[key], \
            self.X_all_geohash[key], self.Y_geohash[key], self.X_tau[key], self.Y_tau[key], self.device

    def __len__(self):
        return self.num_series

    def __repr__(self):
        # Used to return None, which made repr()/print() raise TypeError.
        return (f'{type(self).__name__}(num_series={self.num_series}, '
                f'user_cnt={self.user_cnt}, venue_cnt={self.venue_cnt})')
217 |
218 |
def pad_session_data_one(data):
    """Right-pad every sequence in ``data`` with 0 up to the longest length.

    :param data: iterable of sequences
    :return: list of lists, all of equal length
    """
    # zip_longest pads column-wise; zipping again transposes back to rows.
    columns = itertools.zip_longest(*data, fillvalue=0)
    return [list(row) for row in zip(*columns)]
226 |
227 |
def pad_session_data_geohash(data):
    """Right-pad every sequence in ``data`` with a 12-zero vector up to the longest length.

    All padding positions refer to the same list object, so callers should
    not mutate padded entries in place.

    :param data: iterable of sequences whose elements are length-12 vectors
    :return: list of lists, all of equal length
    """
    zero_code = [0] * 12
    columns = itertools.zip_longest(*data, fillvalue=zero_code)
    return [list(row) for row in zip(*columns)]
235 |
236 |
def collate_session_based(batch):
    '''
    Turn a list of dataset samples (the tuples returned by __getitem__) into one session_Batch.

    Samples are ordered by descending length of field 0 (stable for ties).
    Input-side sequences are zero-padded to a common length; target-side
    sequences are flattened across the batch into one list each.

    :param batch: non-empty list of 18-field sample tuples; the last field is the device
    :return: session_Batch
    '''
    device = batch[0][-1]
    batch = sorted(batch, key=lambda sample: len(sample[0]), reverse=True)

    def column(idx):
        # One entry per sample.
        return [sample[idx] for sample in batch]

    def flat_column(idx):
        # All per-sample entries concatenated into one flat list.
        return [value for sample in batch for value in sample[idx]]

    X_all_text = column(2)
    target_lengths = column(4)
    X_lengths = column(5)
    X_users = column(6)

    Y_location = flat_column(3)
    Y_location_category = flat_column(10)
    Y_lat = flat_column(11)
    Y_lon = flat_column(12)
    Y_geohash = flat_column(14)

    # torch.stack requires every sample's interval tensor to share one shape.
    X_tau = torch.stack(column(15))
    Y_tau = torch.stack(column(16))
    # NOTE(review): this function used to call
    # torch.where(tau < 5, torch.mean(tau), tau) on both tensors and throw
    # the result away, so the intervals were never modified. Those dead
    # computations are removed here and the values are still passed through
    # untouched. If small intervals are meant to be replaced by the mean,
    # that must be enabled deliberately -- Y_tau may already have been
    # standardised by normalize_Y_tau, in which case "< 5" would match
    # almost everything.

    padded_X_all_loc = torch.tensor(pad_session_data_one(column(0))).long()
    padded_X_all_tim = torch.tensor(pad_session_data_one(column(1))).long()
    padded_X_all_loc_category = pad_session_data_one(column(7))
    padded_X_all_lat = pad_session_data_one(column(8))
    padded_X_all_lon = pad_session_data_one(column(9))
    padded_X_all_geohash = pad_session_data_geohash(column(13))

    return session_Batch(padded_X_all_loc, padded_X_all_tim, X_all_text, Y_location, target_lengths, X_lengths, X_users,
                         padded_X_all_loc_category, padded_X_all_lat, padded_X_all_lon, Y_location_category, Y_lat,
                         Y_lon, padded_X_all_geohash, Y_geohash, X_tau, Y_tau, device)
284 |
285 |
class session_Batch():
    """Plain container for one collated mini-batch.

    Location ids, timestamps, geohash codes, target locations and user ids
    are converted to tensors and moved to ``device``; every other field is
    stored exactly as passed in.
    """

    def __init__(self, padded_X_all_loc, padded_X_all_tim, X_all_text, Y_location, target_lengths, X_lengths, X_users,
                 padded_X_all_loc_category, padded_X_all_lat, padded_X_all_lon, Y_location_category, Y_lat, Y_lon,
                 padded_X_all_geohash, Y_geohash, X_tau, Y_tau, device):
        # Padded inputs (original comments give their shape as (batch, max_all_length)).
        self.X_all_loc = torch.LongTensor(padded_X_all_loc).to(device)
        self.X_all_tim = torch.LongTensor(padded_X_all_tim).to(device)
        self.X_all_geohash = torch.FloatTensor(padded_X_all_geohash).to(device)

        # Integer ids are built directly as int64. Going through
        # torch.Tensor(...) first would create a float32 tensor, which
        # cannot represent integers above 2**24 exactly and silently
        # alters large ids.
        self.Y_location = torch.as_tensor(Y_location, dtype=torch.long).to(device)
        self.X_users = torch.as_tensor(X_users, dtype=torch.long).to(device)

        # Fields kept as received (Python lists / tensors from the collate step).
        self.X_all_text = X_all_text
        self.X_all_loc_category = padded_X_all_loc_category
        self.X_all_lat = padded_X_all_lat
        self.X_all_lon = padded_X_all_lon
        self.Y_location_category = Y_location_category
        self.Y_lat = Y_lat
        self.Y_lon = Y_lon
        self.Y_geohash = Y_geohash
        self.target_lengths = target_lengths
        self.X_lengths = X_lengths
        self.X_tau = X_tau
        self.Y_tau = Y_tau
308 |
309 |
def pad_1d(inputs, pad=0):
    """Right-pad 1-D sequences with ``pad`` to the longest length and stack them.

    :param inputs: sequence of 1-D sequences (must be non-empty)
    :param pad: constant used for the padded positions
    :return: numpy array with one row per input sequence
    """
    target_len = max(len(seq) for seq in inputs)
    rows = [
        np.pad(seq, (0, target_len - len(seq)), mode="constant", constant_values=pad)
        for seq in inputs
    ]
    return np.stack(rows)
320 |
--------------------------------------------------------------------------------
/model/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import copy
5 | import math
6 | from torch.distributions import Uniform
7 |
class DotDict(dict):
    """dict whose items can also be read, written and deleted as attributes.

    Reading or deleting a missing attribute raises ``AttributeError`` (not
    ``KeyError``), as the attribute protocol requires. Without that,
    ``hasattr``, ``getattr(obj, name, default)`` and ``copy.deepcopy`` fail
    on instances because they probe for optional attributes and only
    tolerate ``AttributeError``.
    """

    def __getattr__(self, name):
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name) from None

    __setattr__ = dict.__setitem__

    def __delattr__(self, name):
        try:
            del self[name]
        except KeyError:
            raise AttributeError(name) from None
12 |
def clamp_preserve_gradients(x, min, max):
    """Clamp ``x`` to ``[min, max]`` in the forward pass while keeping an identity gradient.

    The clamping correction is detached, so back-propagation sees ``x`` itself
    even where the value was clipped.
    """
    correction = torch.clamp(x, min, max) - x
    return x + correction.detach()
16 |
17 |
def KLD(mu, logvar):
    '''
    KL divergence between a diagonal Gaussian and the standard normal, averaged over the batch.
    https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence#Multivariate_normal_distributions
    :param mu: the mean (batch, dim)
    :param logvar: the log of variance (batch, dim)
    :return: scalar tensor, sum over all elements divided by the batch size
    '''
    batch_size = mu.shape[0]
    per_element = 1 + logvar - mu.pow(2) - logvar.exp()
    return -0.5 * torch.sum(per_element) / batch_size
27 |
def KLD_category(w):
    '''
    KL divergence between the categorical distribution with parameters w and the
    uniform categorical distribution, averaged over the batch.

    Entries of w that are exactly 0 contribute 0 (the limit of x*log(x)).
    The plain product w * log(w / p) evaluates 0 * log(0) to NaN, which
    would poison the whole loss.

    :param w: (batch, nClass)
    :return: scalar tensor
    '''
    nClass = w.shape[1]
    p = torch.ones_like(w) / nClass  # (batch, nClass), uniform reference
    # xlogy(x, y) = x * log(y), defined as 0 wherever x == 0.
    return torch.sum(torch.xlogy(w, w / p)) / w.shape[0]
38 |
39 |
class MLP2(nn.Module):
    """
    Two-headed MLP: one head for mu, one for log(var).

    With ``hidden_size > 0`` a shared hidden layer (activation + dropout)
    feeds both heads; otherwise both heads are plain linear maps of the input.
    """
    def __init__(self, input_size, output_size,
                 dropout=.0, hidden_size=128, use_selu=True):
        super().__init__()
        self.hidden_size = hidden_size
        head_input = input_size
        if hidden_size > 0:
            self.fc1 = nn.Linear(input_size, hidden_size)
            head_input = hidden_size
        self.fc21 = nn.Linear(head_input, output_size)
        self.fc22 = nn.Linear(head_input, output_size)
        self.nonlinear_f = F.selu if use_selu else F.relu
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        '''
        :param x: (batch, dim)
        :return: pair (fc21 output, fc22 output)
        '''
        features = x
        if self.hidden_size > 0:
            features = self.dropout(self.nonlinear_f(self.fc1(x)))
        return self.fc21(features), self.fc22(features)
74 |
75 |
class MLP(nn.Module):
    """
    One-hidden-layer MLP with a single (un-normalized) output, e.g. logits of a multinomial.
    """
    def __init__(self, input_size, hidden_size=64, output_size=1, dropout=0.0, use_selu=True):
        '''
        :param input_size: width of the input features
        :param hidden_size: width of the hidden layer
        :param output_size: the num of cluster
        :param dropout: dropout probability applied after the activation
        :param use_selu: use SELU if true, otherwise leaky ReLU
        '''
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, output_size)
        self.nonlinear_f = F.selu if use_selu else F.leaky_relu
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        activated = self.nonlinear_f(self.fc1(x))
        return self.fc2(self.dropout(activated))
98 |
99 |
def clones(module, N):
    """Return an ``nn.ModuleList`` holding N independent deep copies of ``module``."""
    replicas = nn.ModuleList()
    for _ in range(N):
        replicas.append(copy.deepcopy(module))
    return replicas
103 |
104 |
class SublayerConnection(nn.Module):
    """
    Pre-norm residual wrapper: ``x + dropout(sublayer(layer_norm(x)))``.
    The norm is applied before the sublayer (not after) for code simplicity.
    """
    def __init__(self, size, dropout):
        super().__init__()
        self.norm = nn.LayerNorm(size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        "Apply the residual connection around ``sublayer``, which must preserve the size of its input."
        normed = self.norm(x)
        branch = self.dropout(sublayer(normed))
        return x + branch
118 |
119 |
def attention(query, key, value, mask=None, dropout=None):
    '''
    Compute 'Scaled Dot Product Attention'.

    :param query: (B, , max_length, d_k)
    :param key: (B, , max_length, d_k)
    :param value: (B, , max_length, d_k)
    :param mask: (B, <1>, max_length, max_length), true/false matrix, and true means paddings
    :param dropout: optional callable applied to the attention weights
    :return: outputs:(B, , max_length, d_k), att_scores:(B, , max_length, max_length)
    '''
    scale = math.sqrt(query.size(-1))
    logits = torch.matmul(query, key.transpose(-2, -1)) / scale
    if mask is not None:
        # True positions get a large negative logit, i.e. ~0 weight after softmax.
        logits = logits.masked_fill(mask, -1e9)
    weights = F.softmax(logits, dim=-1)  # each row sums to 1
    if dropout is not None:
        weights = dropout(weights)
    return torch.matmul(weights, value), weights
149 |
150 |
class MultiHeadedAttention(nn.Module):
    """Multi-head scaled dot-product attention built from four equal-size linear maps."""

    def __init__(self, h, d_model, dropout=0.1):
        "Take in model size and number of heads."
        super().__init__()
        assert d_model % h == 0
        # d_v is assumed to equal d_k.
        self.d_k = d_model // h
        self.h = h
        # Projections for query, key, value and the final output, in that order.
        self.linears = clones(nn.Linear(d_model, d_model), 4)
        self.attn = None  # attention weights of the most recent forward pass
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, query, key, value, mask=None, distance=None):
        '''
        :param query: (B, max_length, d_model)
        :param key: (B, max_length, d_model)
        :param value: (B, max_length, d_model)
        :param mask: (B, max_length, max_length)
        :param distance: accepted for interface parity; not used by this class
        :return: (B, max_length, d_model)
        '''
        if mask is not None:
            # Insert a head axis at dim 1 so the same mask applies to all h heads.
            mask = mask.unsqueeze(1)
        nbatches = query.size(0)

        # 1) Project with the first three linears and split d_model into h heads of d_k.
        projected = []
        for linear, tensor in zip(self.linears, (query, key, value)):
            heads = linear(tensor).view(nbatches, -1, self.h, self.d_k)
            projected.append(heads.transpose(1, 2))
        query, key, value = projected

        # 2) Attention over all heads at once.
        context, self.attn = attention(query, key, value, mask=mask, dropout=self.dropout)

        # 3) Merge the heads back and apply the output projection.
        merged = context.transpose(1, 2).contiguous().view(nbatches, -1, self.h * self.d_k)
        return self.linears[-1](merged)
192 |
193 |
def spatial_aware_attention(query, key, value, distance, mask=None, dropout=None):
    '''
    Scaled dot-product attention whose logits are reduced by a distance term.

    :param query: (B, h, max_length, d_k)
    :param key: (B, h, max_length, d_k)
    :param value: (B, h, max_length, d_k)
    :param distance: (B, h, max_length, max_length)
    :param mask: (B, 1, max_length, max_length), true/false matrix, and true means paddings
    :param dropout: optional callable applied to the attention weights
    :return: outputs:(B, h, max_length, d_k), att_scores:(B, h, max_length, max_length)
    '''
    scale = math.sqrt(query.size(-1))
    logits = torch.matmul(query, key.transpose(-2, -1)) / scale  # (B, h, max_length, max_length)
    logits = logits - distance  # larger distance -> smaller logit
    if mask is not None:
        logits = logits.masked_fill(mask, -1e9)  # true -> -1e9
    weights = F.softmax(logits, dim=-1)  # each row sums to 1
    if dropout is not None:
        weights = dropout(weights)
    return torch.matmul(weights, value), weights
225 |
226 |
class SpatialAwareMultiHeadedAttention(nn.Module):
    """Multi-head attention whose logits are penalised by a per-head scaled distance matrix."""

    def __init__(self, h, d_model, dropout=0.1):
        "Take in model size and number of heads."
        super().__init__()
        assert d_model % h == 0
        # d_v is assumed to equal d_k.
        self.d_k = d_model // h
        self.h = h
        # Projections for query, key, value and the final output, in that order.
        self.linears = clones(nn.Linear(d_model, d_model), 4)
        # One learnable log-scale per head, initialised from U(0, 1).
        self.logwb = nn.Parameter(Uniform(0.0, 1.0).sample((self.h,)))  # (h,)
        self.attn = None  # attention weights of the most recent forward pass
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, query, key, value, mask=None, distance=None):
        '''
        :param query: (B, max_length, d_model)
        :param key: (B, max_length, d_model)
        :param value: (B, max_length, d_model)
        :param distance: (B, max_length, max_length); required despite the None default
        :param mask: (B, max_length, max_length)
        :return: (B, max_length, d_model)
        '''
        if mask is not None:
            # Insert a head axis at dim 1 so the same mask applies to all h heads.
            mask = mask.unsqueeze(1)
        nbatches = query.size(0)

        # 1) Project with the first three linears and split d_model into h heads of d_k.
        projected = []
        for linear, tensor in zip(self.linears, (query, key, value)):
            heads = linear(tensor).view(nbatches, -1, self.h, self.d_k)
            projected.append(heads.transpose(1, 2))
        query, key, value = projected

        # Per-head positive scale exp(logwb), shaped (1, h, 1, 1) for broadcasting.
        head_scale = torch.exp(self.logwb).view(1, self.h, 1, 1)
        # (B, 1, L, L) * (1, h, 1, 1) broadcasts to (B, h, L, L).
        mh_distance = distance.unsqueeze(1) * head_scale

        # 2) Distance-penalised attention over all heads at once.
        context, self.attn = spatial_aware_attention(query, key, value, mh_distance, mask=mask,
                                                     dropout=self.dropout)

        # 3) Merge the heads back and apply the output projection.
        merged = context.transpose(1, 2).contiguous().view(nbatches, -1, self.h * self.d_k)
        return self.linears[-1](merged)
272 |
273 |
class PositionwiseFeedForward(nn.Module):
    "Position-wise feed-forward block: Linear -> ReLU -> Dropout -> Linear."
    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.w_1 = nn.Linear(d_model, d_ff)
        self.w_2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        '''
        :param x: (B, max_length, d_model)
        :return: (B, max_length, d_model)
        '''
        expanded = F.relu(self.w_1(x))
        return self.w_2(self.dropout(expanded))
289 |
290 |
class TransformerEncoderLayer(nn.Module):
    "One encoder layer: self-attention then feed-forward, each inside a SublayerConnection."
    def __init__(self, size, self_attn, feed_forward, dropout):
        super().__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.sublayer = clones(SublayerConnection(size, dropout), 2)
        self.size = size

    def forward(self, x, mask=None, distance=None):
        '''
        :param x: (B, max_length, d_model)
        :param mask: (B, 1, max_length)
        :param distance: forwarded unchanged to the attention module
        :return: (B, max_length, d_model)
        '''
        def attend(normed):
            # Self-attention: the same tensor serves as query, key and value.
            return self.self_attn(normed, normed, normed, mask=mask, distance=distance)

        attended = self.sublayer[0](x, attend)
        return self.sublayer[1](attended, self.feed_forward)
310 |
311 |
class TransformerEncoder(nn.Module):
    "Stack of N identical encoder layers followed by a final LayerNorm."

    def __init__(self, layer, N):
        super().__init__()
        self.layers = clones(layer, N)
        self.norm = nn.LayerNorm(layer.size)

    def forward(self, x, padding_mask=None, session_mask=None, subsequent_mask=None, distance=None):
        """Run ``x`` through every layer with the union (logical OR) of all supplied masks.

        The combined mask starts as all-False with shape (B, max_length, max_length).
        ``padding_mask`` is tiled ``max_length`` times along dim 1 before being
        merged; the other two masks are OR-ed in as given.
        """
        batch, length = x.size(0), x.size(1)
        mask = torch.zeros(batch, length, length, device=x.device).bool()
        if padding_mask is not None:
            mask = mask | padding_mask.repeat(1, length, 1).bool()  # (B, max_length, max_length)
        for extra_mask in (session_mask, subsequent_mask):
            if extra_mask is not None:
                mask = mask | extra_mask
        for layer in self.layers:
            x = layer(x, mask=mask, distance=distance)
        return self.norm(x)
337 |
338 |
class Hypernet(nn.Module):
    """
    Hypernetwork that maps the decoder input to the parameters mu, sigma, w.

    Args:
        config: Model configuration; only ``config.decoder_input_size`` is read.
        hidden_sizes: Sizes of the hidden layers. [] corresponds to a linear layer.
        param_sizes: Sizes of the output parameters, e.g.
            [n_components, n_components, n_components], giving the
            dimensions/components of w, mu and s respectively.
        activation: Activation applied before every layer after the first.
    """
    def __init__(self, config, hidden_sizes=None, param_sizes=None, activation=nn.Tanh()):
        super().__init__()
        param_sizes = [1, 1, 1] if param_sizes is None else param_sizes
        hidden_sizes = [] if hidden_sizes is None else hidden_sizes
        self.decoder_input_size = config.decoder_input_size
        self.activation = activation

        # Consecutive slices that cut the flat output into one chunk per parameter,
        # e.g. param_sizes [64, 64, 64] -> [0:64], [64:128], [128:192].
        self.param_slices = []
        offset = 0
        for width in param_sizes:
            self.param_slices.append(slice(offset, offset + width))
            offset += width

        self.output_size = sum(param_sizes)
        layer_sizes = list(hidden_sizes) + [self.output_size]

        # The first layer keeps its bias as a separate parameter.
        self.first_bias = nn.Parameter(torch.empty(layer_sizes[0]).uniform_(-0.1, 0.1))
        self.first_linear = nn.Linear(self.decoder_input_size, layer_sizes[0], bias=False)

        # Remaining layers connect consecutive entries of layer_sizes.
        self.linear_layers = nn.ModuleList()
        for fan_in, fan_out in zip(layer_sizes[:-1], layer_sizes[1:]):
            self.linear_layers.append(nn.Linear(fan_in, fan_out))

    def reset_parameters(self):
        """Zero the first bias and re-initialise every linear weight orthogonally."""
        self.first_bias.data.fill_(0.0)
        for linear in [self.first_linear, *self.linear_layers]:
            linear.reset_parameters()
            nn.init.orthogonal_(linear.weight)

    def forward(self, decoder_input):
        """Generate model parameters from the decoder input.

        Args:
            decoder_input: shape (batch, decoder_input_size)

        Returns:
            params: Tuple with one tensor per entry of ``param_sizes``,
                sliced from the last dimension of the network output.
        """
        hidden = self.first_bias + self.first_linear(decoder_input)
        for linear in self.linear_layers:
            hidden = linear(self.activation(hidden))
        return tuple(hidden[..., part] for part in self.param_slices)
406 |
--------------------------------------------------------------------------------
/model/llm.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForMaskedLM
3 | from peft import LoraModel, LoraConfig
4 |
5 | from .base import *
6 | from .bert import get_batch_mask
7 | from .phi_model import PhiModel
8 | from .positional_encoding import LearnableFourierFeatures as LFF
9 |
10 |
def get_encoder(model_path, model_class):
    """Build the LoRA-wrapped backbone selected by ``model_class``.

    ``tinybert``, ``phi-2`` and ``gpt2`` are loaded from ``model_path``. Every
    other supported class ignores ``model_path`` and loads both model and
    tokenizer from a fixed directory under ``params/``.

    :param model_path: checkpoint location, used only for tinybert / phi-2 / gpt2
    :param model_class: name of the backbone family (see the branches below)
    :return: (LoraModel wrapping the backbone, tokenizer, emb_size, hidden_size)
    :raises NotImplementedError: if ``model_class`` is not supported
    """
    def lora(task_type, target_modules, dropout=0.02):
        # Rank and alpha are shared by every backbone; only the task type,
        # the modules LoRA is applied to and the dropout vary.
        return LoraConfig(
            task_type=task_type,
            r=8,  # Lora attention dimension.
            lora_alpha=32,  # The alpha parameter for Lora scaling.
            target_modules=target_modules,  # The names of the modules to apply Lora to.
            lora_dropout=dropout,  # The dropout probability for Lora layers.
        )

    # model_class -> (wrapper class, local checkpoint dir, LoRA target modules, embedding/hidden width)
    local_backbones = {
        'TinyLlama-1_1B': (CustomLlamaModel, 'params/TinyLlama-1.1B', ["q_proj", "k_proj"], 2048),
        'TinyLlama-Chat': (CustomLlamaModel, 'params/TinyLlama-Chat', ["q_proj", "k_proj"], 2048),
        'pythia-70M': (CustomPythiaModel, 'params/pythia-70M', ["query_key_value"], 512),
        'pythia-2_8B': (CustomPythiaModel, 'params/pythia-2.8B', ["query_key_value"], 2560),
        'pythia-1B': (CustomPythiaModel, 'params/pythia-1B', ["query_key_value"], 2048),
        'LiteLlama': (CustomLlamaModel, 'params/LiteLlama', ["q_proj", "k_proj"], 1024),
    }
    path_based_classes = ['tinybert', 'phi-2', 'gpt2']

    if model_class == 'tinybert':
        lora_config = lora("SEQ_2_SEQ_LM", ["query", "value"], dropout=0.01)
        model = AutoModelForMaskedLM.from_pretrained(model_path)
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        emb_size = model.config.emb_size
        hidden_size = model.config.hidden_size

    elif model_class == 'phi-2':
        lora_config = lora("CAUSAL_LM", ["Wqkv"], dropout=0.01)
        model = CustomPhiModel.from_pretrained(
            model_path,
            torch_dtype=torch.float32,
            trust_remote_code=True
        )
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        emb_size = model.config.n_embd
        hidden_size = model.config.n_embd

    elif model_class == 'gpt2':
        lora_config = lora("CAUSAL_LM", ["c_attn"])
        model = AutoModelForCausalLM.from_pretrained(model_path)
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        emb_size = model.config.n_embd
        hidden_size = model.config.n_embd

    elif model_class in local_backbones:
        wrapper, checkpoint_dir, target_modules, width = local_backbones[model_class]
        lora_config = lora("CAUSAL_LM", target_modules)
        model = wrapper(model_path=checkpoint_dir)
        tokenizer = AutoTokenizer.from_pretrained(checkpoint_dir)
        emb_size = width
        hidden_size = width

    else:
        # List every supported name; the old message mentioned only two of them.
        supported = path_based_classes + list(local_backbones)
        raise NotImplementedError(f"model_class should be one of {supported}")

    # The third positional argument of LoraModel is the adapter name.
    return LoraModel(model, lora_config, model_class), tokenizer, emb_size, hidden_size
139 |
140 |
class CustomPhiModel(PhiModel):
    """ Phi for traj modeling.

    Registers ladder side-network modules (``ladder_side_nets``, ``up_net``)
    that the current ``forward`` does not use; the code that used them is
    disabled, so ``forward`` simply runs the stack of layers in ``self.h``.
    """

    # These modules are not part of the pretrained checkpoint.
    _keys_to_ignore_on_load_missing = ["ladder_side_nets", "up_net"]

    def __init__(self, config, r=32):
        super().__init__(config)

        assert config.n_embd % r == 0, f"n_embd should be divisible by r, got {config.n_embd} and {r}"
        self.side_dim = config.n_embd // r

        side_nets = [nn.Linear(config.n_embd, self.side_dim) for _ in range(config.n_layer)]
        self.ladder_side_nets = nn.ModuleList(side_nets)
        self.up_net = nn.Linear(self.side_dim, config.n_embd)

    def forward(self, input_ids=None, inputs_embeds=None, past_key_values=None, attention_mask=None):
        """Return the hidden states after the last layer.

        ``inputs_embeds`` takes precedence; ``input_ids`` is embedded with
        ``self.embd`` only when no embeddings are given.
        """
        if inputs_embeds is not None:
            hidden_states = inputs_embeds
        else:
            assert input_ids is not None, "You have to specify either input_ids or inputs_embeds"
            hidden_states = self.embd(input_ids)

        for block in self.h:
            hidden_states = block(
                hidden_states,
                past_key_values=past_key_values,
                attention_mask=attention_mask,
            )

        return hidden_states
176 |
class CustomLlamaModel(nn.Module):
    """ Llama-style decoder stack for traj modeling.

    Wraps the ``.model`` attribute of a causal LM loaded from ``model_path``
    and returns the hidden states seen at every layer.
    """

    def __init__(self, model_path):
        super().__init__()

        self.model = AutoModelForCausalLM.from_pretrained(model_path).model
        self.embd = self.model.embed_tokens

    def forward(self, input_ids=None, inputs_embeds=None, past_key_values=None, attention_mask=None):
        """Run the decoder layers and collect hidden states.

        :param input_ids: token ids; embedded with ``self.embd`` when
            ``inputs_embeds`` is not given
        :param inputs_embeds: pre-computed embeddings; take precedence over ``input_ids``
        :param past_key_values: forwarded to every layer as ``past_key_value``
        :param attention_mask: forwarded to every layer unchanged
        :return: tuple with the input of each decoder layer followed by the
            normalised output of the last layer (``len(layers) + 1`` tensors)
        :raises ValueError: if neither ``input_ids`` nor ``inputs_embeds`` is given
        """
        if inputs_embeds is None:
            # Previously this path crashed on ``inputs_embeds.shape`` even
            # though ``input_ids`` is part of the signature.
            if input_ids is None:
                raise ValueError("You have to specify either input_ids or inputs_embeds")
            inputs_embeds = self.embd(input_ids)

        device = input_ids.device if input_ids is not None else inputs_embeds.device
        # Positions 0..seq_len-1, shared by the whole batch via a leading axis of 1.
        position_ids = torch.arange(0, inputs_embeds.shape[1], dtype=torch.long, device=device)
        position_ids = position_ids.unsqueeze(0)

        hidden_states = inputs_embeds
        all_hidden_states = ()

        for decoder_layer in self.model.layers:
            all_hidden_states += (hidden_states,)

            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
            )
            hidden_states = layer_outputs[0]

        hidden_states = self.model.norm(hidden_states)
        # Add the hidden states from the last decoder layer (after the final norm).
        all_hidden_states += (hidden_states,)
        return all_hidden_states
218 |
class CustomPythiaModel(nn.Module):
    """ Pythia (GPT-NeoX) layer stack for traj modeling.

    Wraps the ``.gpt_neox`` attribute of a causal LM loaded from ``model_path``
    and returns the hidden states seen at every layer.
    """

    def __init__(self, model_path):
        super().__init__()

        self.model = AutoModelForCausalLM.from_pretrained(model_path).gpt_neox
        self.embd = self.model.embed_in
        # The upstream model also has self.emb_dropout; should it be applied here too?
        # The relevant upstream forward code is:
        # if inputs_embeds is None:
        #     inputs_embeds = self.embed_in(input_ids)
        # hidden_states = self.emb_dropout(inputs_embeds)

    def forward(self, input_ids=None, inputs_embeds=None, past_key_values=None, attention_mask=None):
        """Run the layers and collect hidden states.

        :param input_ids: token ids; embedded with ``self.embd`` when
            ``inputs_embeds`` is not given
        :param inputs_embeds: pre-computed embeddings; take precedence over ``input_ids``
        :param past_key_values: accepted for interface parity but ignored;
            every layer is always called with ``layer_past=None``
        :param attention_mask: forwarded to every layer unchanged
        :return: tuple with the input of each layer followed by the output of
            the final layer norm
        :raises ValueError: if neither ``input_ids`` nor ``inputs_embeds`` is given
        """
        if inputs_embeds is None:
            # Previously this path crashed on ``inputs_embeds.shape`` even
            # though ``input_ids`` is part of the signature.
            if input_ids is None:
                raise ValueError("You have to specify either input_ids or inputs_embeds")
            inputs_embeds = self.embd(input_ids)

        device = input_ids.device if input_ids is not None else inputs_embeds.device
        # Positions 0..seq_len-1, shared by the whole batch via a leading axis of 1.
        position_ids = torch.arange(0, inputs_embeds.shape[1], dtype=torch.long, device=device)
        position_ids = position_ids.unsqueeze(0)

        hidden_states = inputs_embeds
        all_hidden_states = ()
        # No cache is used: one None per configured layer (zip stops at the shorter sequence).
        layer_pasts = tuple([None] * self.model.config.num_hidden_layers)

        for layer, layer_past in zip(self.model.layers, layer_pasts):
            all_hidden_states = all_hidden_states + (hidden_states,)

            outputs = layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                layer_past=layer_past,
            )
            hidden_states = outputs[0]

        hidden_states = self.model.final_layer_norm(hidden_states)
        # Add last hidden state
        all_hidden_states = all_hidden_states + (hidden_states,)
        return all_hidden_states
270 |
271 | class LLMModel(Encoder):
def __init__(self, model_path,loc_size,device,output_size=256,

             model_class='gpt2', learnable_param_size=1,):
    """Build the LoRA-wrapped backbone and the layers placed around it.

    :param model_path: forwarded to ``get_encoder``
    :param loc_size: number of rows of the location embedding table
    :param device: stored as ``self.device``
    :param output_size: width produced by ``out_linear``
    :param model_class: backbone family, forwarded to ``get_encoder``
    :param learnable_param_size: number of learnable token vectors in ``cls_token``
    """
    super().__init__('LLM-gpt2')

    self.output_size = output_size
    self.loc_size = loc_size
    self.model_class = model_class
    self.device = device
    self.learnable_param_size = learnable_param_size

    encoder, tokenizer, emb_size, hidden_size = get_encoder(model_path, model_class)
    self.encoder = encoder
    self.tokenizer = tokenizer
    self.emb_size = emb_size
    self.hidden_size = hidden_size
    # Reuse the end-of-sequence token as the padding token.
    self.tokenizer.pad_token = self.tokenizer.eos_token

    # No parameter is frozen or unfrozen here (the alternative freezing
    # strategies are disabled); the loop only prints each encoder
    # parameter's name and its requires_grad flag.
    for name, param in self.encoder.named_parameters():
        print(name, param.requires_grad)

    # Location id -> embedding of the backbone's input width.
    self.embedder = nn.Embedding(self.loc_size, self.emb_size)
    # Backbone hidden width -> output_size.
    self.out_linear = nn.Sequential(
        nn.Linear(self.hidden_size, output_size, bias=False),
        nn.LayerNorm(output_size),
        nn.ReLU(inplace=True),
        nn.Linear(output_size, output_size),
    )
    # A learnable-Fourier-feature time encoder (LFF) is disabled here; the
    # choice of its f_dim / h_dim is still open.
    # Maps a 1-dimensional input to 16 features.
    self.time_linear = nn.Sequential(
        nn.Linear(1, 16),
        nn.LayerNorm(16),
        nn.ReLU(inplace=True),
        nn.Linear(16, 16),
    )
    # Maps emb_size + 16 features back to emb_size.
    self.cat_linear = nn.Sequential(
        nn.Linear(self.emb_size + 16, self.emb_size),
        nn.LayerNorm(self.emb_size),
        nn.ReLU(inplace=True),
        nn.Linear(self.emb_size, self.emb_size),
    )

    # TEMPO: learnable token vectors, initialised to zero.
    self.cls_token = nn.Parameter(torch.zeros(self.learnable_param_size, self.emb_size).float(), requires_grad=True)
318 |
319 | def forward(self, x, valid_len, time, category, geohash_, **kwargs):
320 | return self.forward_suffix(x, valid_len, time, category, geohash_, self.cls_token)
321 |
322 | def forward_suffix(self, x, valid_len, time, category, geohash_, tokens):
323 | """ P-tuning-like suffix forward. """
324 | B, L, E_in = x.unsqueeze(-1).shape # time.shape=[B,L], category.shape=[B,L]
325 |
326 | # TEMPO改进方案
327 | trip_batch_mask = get_batch_mask(B, L+self.learnable_param_size, valid_len)
328 | batch_mask = get_batch_mask(B, L+self.learnable_param_size, valid_len+self.learnable_param_size)
329 |
330 | masked_values = torch.zeros_like(x)
331 |
332 | x = torch.where(get_batch_mask(B, L, valid_len).unsqueeze(-1), x.unsqueeze(-1), masked_values.unsqueeze(-1))
333 | x_embeddings = self.embedder(x).squeeze(-2) # (B,L,self.emb_size)
334 | # TEST
335 | '''
336 | # Basic usage of Learnable-Fourier-Features
337 | lff = LFF(pos_dim=2, f_dim=128, h_dim=256, d_dim=64) # learnable fourier features module
338 | pos = torch.randn([4, 1024, 1, 2]) # random positional coordinates
339 | pe = lff(pos) # forward
340 | ''' #128*173
341 | # time_embeddings = self.lff(time.unsqueeze(-1).unsqueeze(-1)) # (B L G M) → (B L D)
342 | # x_embeddings += time_embeddings * 0.01
343 | ## x_embeddings = self.cat_linear(torch.cat((x_embeddings, time_embeddings), dim=-1))
344 | time_embeddings = self.time_linear(time.unsqueeze(-1))
345 |
346 | category_vocabIds = []
347 | for i, seq in enumerate(category): # 对batch数据应该还是得用for循环
348 | input_seq = seq[:valid_len[i]] # seq中的pad符号为数字0(不是str,麻烦),先获取有效长度的seq
349 | if len(set(input_seq))==1 and '' in input_seq: # seq存在['','',……,0,0,0,……]的情况(最麻烦),替换''(暂)
350 | input_seq = [self.tokenizer.eos_token] * valid_len[i]
351 | # 对每个序列取category对应的vocabId的最后一个
352 | # 不用for循环实现,需要先对tokens做padding,再取最后一个非padding的——可能不如普通的for循环实现快
353 | category_vocabId = self.tokenizer(input_seq, padding=True, return_tensors="pt")['input_ids']
354 | last_index = torch.where(category_vocabId != self.tokenizer.eos_token_id, torch.full_like(category_vocabId, 1), 0).sum(dim=1)
355 | category_vocabId = [int(category_vocabId[index][int(x.item()) - 1].item()) for index, x in enumerate(last_index)]
356 | '''
357 | category_vocabId = []
358 | for i in seq:
359 | if i == '': # test时,若出现''要特殊处理一下,否则报错
360 | category_vocabId.append(torch.tensor(self.tokenizer.eos_token_id))
361 | else:
362 | category_vocabId.append(self.tokenizer(i, return_tensors="pt")['input_ids'][-1][-1])
363 | category_vocabId = torch.tensor(category_vocabId)
364 | '''
365 | category_vocabIds.append(torch.tensor(category_vocabId+[self.tokenizer.eos_token_id]*(L-valid_len[i])))
366 | category_vocabIds = torch.stack(category_vocabIds,dim=0).to(self.device)
367 | if self.model_class == 'gpt2':
368 | category_embeddings = self.encoder.model.transformer.wte(category_vocabIds)
369 | else:
370 | category_embeddings = self.encoder.model.embd(category_vocabIds)
371 | #x_embeddings += category_embeddings
372 |
373 | #x_embeddings = self.cat_linear(torch.cat((x_embeddings, time_embeddings, category_embeddings), dim=-1))
374 | x_embeddings = self.cat_linear(torch.cat((x_embeddings, time_embeddings), dim=-1))
375 |
376 | h = torch.zeros(B, L+self.learnable_param_size, self.emb_size).to(x.device)
377 |
378 | h[:, :-self.learnable_param_size] = x_embeddings
379 | # position = (batch_mask.long() - trip_batch_mask.long()) == 1
380 | # temp1 = h[(batch_mask.long() - trip_batch_mask.long()) == 1]
381 | # temp2 = tokens.repeat(B,1)
382 | h[(batch_mask.long() - trip_batch_mask.long()) == 1] = tokens.repeat(B,1)
383 | # h[(batch_mask.long() - trip_batch_mask.long()) == 1] = token
384 | if self.model_class == 'tinybert':
385 | h = self.encoder(inputs_embeds=h, attention_mask=batch_mask.unsqueeze(-1), output_hidden_states=True).hidden_states[-1]
386 | elif self.model_class == 'phi-2':
387 | h = self.encoder(inputs_embeds=h, attention_mask=batch_mask)
388 | elif self.model_class == 'gpt2':
389 | h = self.encoder(inputs_embeds=h, attention_mask=batch_mask.unsqueeze(-1), output_hidden_states=True).hidden_states[-1]
390 | elif self.model_class == 'LiteLlama':
391 | h = self.encoder(inputs_embeds=h, attention_mask=batch_mask.unsqueeze(-1).unsqueeze(1).expand(-1,-1,-1,batch_mask.shape[-1]))[0]
392 | elif self.model_class == 'TinyLlama-Chat':
393 | h = self.encoder(inputs_embeds=h, attention_mask=batch_mask.unsqueeze(-1).unsqueeze(1).expand(-1,-1,-1,batch_mask.shape[-1]))[0]
394 | elif self.model_class == 'TinyLlama-1_1B':
395 | h = self.encoder(inputs_embeds=h, attention_mask=batch_mask.unsqueeze(-1).unsqueeze(1).expand(-1,-1,-1,batch_mask.shape[-1]))[0]
396 | elif self.model_class == 'pythia-70M':
397 | h = self.encoder(inputs_embeds=h, attention_mask=batch_mask.unsqueeze(-1).unsqueeze(1).expand(-1,-1,-1,batch_mask.shape[-1]))[0]
398 | elif self.model_class == 'pythia-2_8B':
399 | h = self.encoder(inputs_embeds=h, attention_mask=batch_mask.unsqueeze(-1).unsqueeze(1).expand(-1,-1,-1,batch_mask.shape[-1]))[0]
400 | elif self.model_class == 'pythia-1B':
401 | h = self.encoder(inputs_embeds=h, attention_mask=batch_mask.unsqueeze(-1).unsqueeze(1).expand(-1,-1,-1,batch_mask.shape[-1]))[0]
402 | h = torch.nan_to_num(h)
403 | #output = self.out_linear(h[:, -1])
404 | output = self.out_linear(h)
405 |
406 | return output
--------------------------------------------------------------------------------
/train_MobilityLLM.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import configparser
3 | # from tensorboardX import SummaryWriter
4 | import preprocess.load_data as preprocess
5 | from model.MobilityLLM import *
6 | from copy import deepcopy
7 | from utils import *
8 | from torch.nn.utils import clip_grad_norm_
9 | import torch
10 | import torch.nn as nn
11 | import time
12 | from tqdm import tqdm
13 | import numpy as np
14 |
# Fix every random seed so that runs are reproducible.
randomSeed = 202408
for _seed_fn in (torch.manual_seed, torch.cuda.manual_seed, torch.cuda.manual_seed_all, np.random.seed):
    _seed_fn(randomSeed)
# Ask cuDNN for deterministic kernels and disable its auto-tuner.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
24 |
# Read the command line options and the hyper-parameter configuration file.
parser = argparse.ArgumentParser()
parser.add_argument("--config", default='/data/ZhangXinyue/MobilityLLM/config/MobilityLLM_tky_TUL.conf', type=str,
                    help="configuration file path")
parser.add_argument("--dataroot", default='/data/ZhangXinyue/MobilityLLM/data/', type=str,
                    help="data root directory")
parser.add_argument("--model_class", default='pythia-70M', type=str,
                    help="name of the LLM backbone (e.g. gpt2, phi-2, pythia-70M)")
parser.add_argument("--device", default='1', type=str, help="index of the CUDA device to use")
parser.add_argument("--data_hist", default='1', type=str,
                    help="fraction of the training batches used in every epoch")
args = parser.parse_args()
config_file = args.config
data_root = args.dataroot
model_class = args.model_class
ctx = args.device
data_hist = float(args.data_hist)
config = configparser.ConfigParser()
print('Read configuration file: %s' % (args.config))
print('>>>>>>> configuration <<<<<<<')
# Echo the raw file first; opening it also fails loudly when the path is
# wrong, whereas ConfigParser.read() silently ignores missing files.
with open(config_file, 'r') as f:
    print(f.read())
print('\n')
config.read(args.config)
data_config = config['Data']
training_config = config['Training']
model_config = config['Model']
50 |
# Data config: every value is kept as the raw string read from the file.
(dataset_name, max_his_period_days, max_merge_seconds_limit, max_delta_mins, min_session_mins,
 least_disuser_count, least_checkins_count, latN, lngN) = (
    data_config[key] for key in (
        'dataset_name', 'max_his_period_days', 'max_merge_seconds_limit', 'max_delta_mins', 'min_session_mins',
        'least_disuser_count', 'least_checkins_count', 'latN', 'lngN'))
split_save = bool(int(data_config['split_save']))
# The processed dataset name encodes the preprocessing thresholds.
dataset_name = (f'{dataset_name}_{max_his_period_days}H{max_merge_seconds_limit}M{max_delta_mins}d'
                f'{min_session_mins}s{least_disuser_count}P{least_checkins_count}U')

# Training config
mode = training_config['mode'].strip()
# The device index comes from the command line (``ctx``), not from the file.
USE_CUDA = torch.cuda.is_available()
print("CUDA:", USE_CUDA, ctx)
device = torch.device("cuda:" + ctx if USE_CUDA else "cpu")
print('device:', device)
use_nni = bool(int(training_config['use_nni']))
regularization, learning_rate = (
    float(training_config[key]) for key in ('regularization', 'learning_rate'))
max_epochs, display_step, patience, train_batch, val_batch, test_batch, batch_size = (
    int(training_config[key]) for key in (
        'max_epochs', 'display_step', 'patience', 'train_batch', 'val_batch', 'test_batch', 'batch_size'))
save_results = bool(int(training_config['save_results']))

specific_config = 'MobilityLLM'

# Model setting
loc_emb_size, tim_emb_size, user_emb_size, hidden_size, category_size, geohash_size = (
    int(model_config[key]) for key in (
        'loc_emb_size', 'tim_emb_size', 'user_emb_size', 'hidden_size', 'category_size', 'geohash_size'))

# Optional key: defaults to a single learnable token.
learnable_param_size = int(model_config['learnable_param_size']) if 'learnable_param_size' in model_config else 1

adv = int(model_config['adv'])
rnn_type = model_config['rnn_type']
num_layers = int(model_config['num_layers'])
downstream = model_config['downstream']

(loc_noise_mean, loc_noise_sigma, tim_noise_mean, tim_noise_sigma, user_noise_mean, user_noise_sigma,
 tau, pos_eps, neg_eps, dropout_rate_1) = (
    float(model_config[key]) for key in (
        'loc_noise_mean', 'loc_noise_sigma', 'tim_noise_mean', 'tim_noise_sigma', 'user_noise_mean',
        'user_noise_sigma', 'tau', 'pos_eps', 'neg_eps', 'dropout_rate_1'))
# Both dropout layers share one rate.
dropout_rate_2 = dropout_rate_1

momentum, theta, temperature = (float(model_config[key]) for key in ('momentum', 'theta', 'temperature'))
k = int(model_config['k'])
self_weight_s, self_weight_t, self_weight_st = (
    float(model_config[key]) for key in ('self_weight_s', 'self_weight_t', 'self_weight_st'))

dump_path = 'checkpoints'
rank = model_config['rank']
epoch_queue_starts = int(model_config['epoch_queue_starts'])
crops_for_assign = [0, 1]
feat_dim, queue_length, world_size = (
    int(model_config[key]) for key in ('feat_dim', 'queue_length', 'world_size'))
loss = model_config['loss']
tpp = model_config['tpp']
epsilon, dropout_spatial = (float(model_config[key]) for key in ('epsilon', 'dropout_spatial'))
136 |
# When NNI drives the run, its trial parameters override the file values.
if use_nni:
    import nni
    param = nni.get_next_parameter()
    # multi-dataset
    batch_size, hidden_size, user_emb_size, category_size, geohash_size, num_layers = (
        int(param[key]) for key in (
            'batch_size', 'hidden_size', 'user_emb_size', 'category_size', 'geohash_size', 'num_layers'))
    momentum, theta, temperature = (float(param[key]) for key in ('momentum', 'theta', 'temperature'))
    k = int(param['k'])
    self_weight_s, self_weight_t, self_weight_st, epsilon, dropout_spatial = (
        float(param[key]) for key in (
            'self_weight_s', 'self_weight_t', 'self_weight_st', 'epsilon', 'dropout_spatial'))

    # All three loaders use the tuned batch size.
    train_batch = val_batch = test_batch = batch_size
160 |
print('load dataset:', dataset_name)
print('split_save:', split_save)

# Data: choose the category archive by the raw dataset name from the config
# file; unknown names fall back to the NYC archive.
_category_archives = {
    "www_NYC": "nyc_cnt2category2cnt.npz",
    "TSMC_www_NYC": "nyc_cnt2category2cnt.npz",
    "www_JKT": "jkt_cnt2category2cnt.npz",
    "www_IST": "ist_cnt2category2cnt.npz",
    "www_TKY": "tky_cnt2category2cnt.npz",
    "TSMC_www_TKY": "tky_cnt2category2cnt.npz",
}
# NOTE: allow_pickle=True unpickles stored objects, so only trusted archives
# must be loaded here.
data = np.load(data_root + _category_archives.get(data_config['dataset_name'], "nyc_cnt2category2cnt.npz"),
               allow_pickle=True)

# The archive stores the dict inside a 0-d object array; .item() unwraps it
# (category index -> category).
cnt2category = data['cnt2category'].item()

print('Loading data & Category vector...')
(data_train, data_val, data_test, feature_category, feature_lat, feature_lng,
 latN, lngN, category_cnt, category_vector) = preprocess.load_dataset_for_MobilityLLM(
    dataset_name, save_split=split_save, data_root=data_root, device=device)
# feature_category[venue's index] -> venue's category's index
print("feature_category: ", feature_category.shape,
      feature_category)

# Set the parameters for affine normalization layer depending on the decoder
trainY_tau_mean, trainY_tau_std = data_train.get_tau_log_mean_std_Y()
print('trainY_tau_mean:', trainY_tau_mean, flush=True)
print('trainY_tau_std:', trainY_tau_std, flush=True)

collate = preprocess.collate_session_based  # padding sequence with variable len

dl_train = torch.utils.data.DataLoader(data_train, batch_size=train_batch, shuffle=True,
                                       collate_fn=collate, drop_last=True)
# Validation and test loaders keep the dataset order.
dl_val, dl_test = (
    torch.utils.data.DataLoader(split, batch_size=size, shuffle=False, collate_fn=collate, drop_last=True)
    for split, size in ((data_val, val_batch), (data_test, test_batch)))
# Number of training batches actually used per epoch (fraction ``data_hist``).
all_train_length = len(dl_train)
new_train_length = int(data_hist * all_train_length)
209 |
# Model setup
print('Building model...', flush=True)
# General model config
tim_size = 48
fill_value = 0
use_semantic = True
print("Use semantic information from venue name!" if use_semantic
      else "Don't use semantic information from venue name!")
general_config = MobilityLLM_ModelConfig(
    # vocabulary sizes
    loc_size=int(data_train.venue_cnt),
    tim_size=tim_size,
    uid_size=int(data_train.user_cnt),
    geohash_size=geohash_size,
    category_size=category_size,
    # embedding / hidden sizes and backbone
    tim_emb_size=tim_emb_size,
    loc_emb_size=loc_emb_size,
    hidden_size=hidden_size,
    user_emb_size=user_emb_size,
    model_class=model_class,
    device=device,
    rnn_type=rnn_type,
    num_layers=num_layers,
    learnable_param_size=learnable_param_size,
    # noise and contrastive settings
    loc_noise_mean=loc_noise_mean,
    loc_noise_sigma=loc_noise_sigma,
    tim_noise_mean=tim_noise_mean,
    tim_noise_sigma=tim_noise_sigma,
    user_noise_mean=user_noise_mean,
    user_noise_sigma=user_noise_sigma,
    tau=tau,
    momentum=momentum,
    k=k,
    theta=theta,
    temperature=temperature,
    pos_eps=pos_eps,
    neg_eps=neg_eps,
    dropout_rate_1=dropout_rate_1,
    dropout_rate_2=dropout_rate_2,
    dropout_spatial=dropout_spatial,
    epsilon=epsilon,
    category_vector=category_vector,
    # downstream task and temporal point process settings
    downstream=downstream,
    scale_init=trainY_tau_std,
    shift_init=trainY_tau_mean,
    max_delta_mins=max_delta_mins,
    loss=loss,
    tpp=tpp,
)
# Define model
model = MobilityLLM(general_config).to(device)
print(model, flush=True)

params_path = os.path.join('experiments', dataset_name.replace('(', '').replace(')', ''), specific_config)
print('params_path:', params_path)

# Checkpoint file name: NNI trials are keyed by experiment/trial id, plain
# runs by downstream task and backbone.
if use_nni:
    exp_id = nni.get_experiment_id()
    trail_id = nni.get_trial_id()
    best_name = f'{exp_id}.{trail_id}{downstream}best.params'
else:
    best_name = f'{downstream}_best_{model_class}.params'
params_filename = os.path.join(params_path, best_name)
248 |
if mode == 'train':
    # NOTE(review): this visits every parameter registered on `model`, so any
    # pretrained weights among them would be re-initialised too — confirm
    # that this is intended.
    for p in model.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)

    total_params = sum(p.numel() for p in model.parameters())
    total_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("total_params:", total_params, flush=True)
    print("total_trainable_params:", total_trainable_params, flush=True)

    if os.path.exists(params_path):
        print('already exist %s' % (params_path), flush=True)
    else:
        os.makedirs(params_path)
        print('create params directory %s' % (params_path), flush=True)

    print('Starting training...', flush=True)

    # build the queue
    queue = None
    queue_path = os.path.join(dump_path, "queue" + rank + ".pth")

    impatient = 0
    best_hit20 = -np.inf
    best_tnll = np.inf
    best_model = deepcopy(model.state_dict())
    global_step = 0
    best_epoch = -1
    opt = torch.optim.Adam(model.parameters(), weight_decay=regularization, lr=learning_rate, amsgrad=True)
    start = time.time()
    for epoch in range(0, max_epochs):
        model.train()
        batch_cnt = 0
        # optionally starts a queue
        if queue_length > 0 and epoch >= epoch_queue_starts and queue is None:
            # abs(): a negative world_size yields the same length as the
            # former `-queue_length // world_size`, while a positive one no
            # longer produces a negative dimension.  The queue is created on
            # the selected device rather than on the default CUDA device.
            queue = torch.zeros(
                len(crops_for_assign),
                queue_length // abs(world_size),
                feat_dim,
                device=device,
            )
        train_count = 0
        for batch in tqdm(dl_train):
            train_count += 1
            # Only the first `new_train_length` batches of the epoch are
            # used; stop instead of loading the remaining ones for nothing.
            if train_count > new_train_length:
                break
            opt.zero_grad()
            # cosine similarity matrix can be big (up to 5GiB), consider cutting.
            batch_cnt += 1
            if batch.X_all_loc.shape[1] >= 700:
                print(f'batch: {batch_cnt}, length: {batch.X_all_loc.shape[1]}')
                continue
            if adv == 1:
                s_loss_score, top_k_pred, queue = model(batch, mode='train', downstream=downstream,
                                                        cont_conf=[1, 1, 1, 1], queue=queue)
                loss_total = (1 - self_weight_s - self_weight_t - self_weight_st) * s_loss_score
            else:
                s_loss_score, top_k_pred, queue = model(batch, mode='train', downstream=downstream, queue=queue)
                loss_total = s_loss_score
            loss_total.backward()
            clip_grad_norm_(model.parameters(), max_norm=1.0)
            opt.step()

            global_step += 1
            if downstream == 'POI':
                ys = batch.Y_location
            elif downstream == 'TUL':
                ys = batch.X_users  # (batch,)
            elif downstream == 'TPP':
                pass
            else:
                raise ValueError('downstream is not in [POI, TUL, TPP]')

            # Keep only the Python float; it is what the epoch summary prints.
            loss_total = loss_total.item()
            if downstream != 'TPP':
                hit_ratio, mrr = evaluate_location(ys.cpu().numpy(), top_k_pred.cpu().numpy())  # [k]

        if queue is not None:
            # The dump directory is not created anywhere else.
            os.makedirs(dump_path, exist_ok=True)
            torch.save({"queue": queue}, queue_path)

        model.eval()
        with torch.no_grad():
            if downstream != 'TPP':
                all_loss_s_val, hit_ratio_val, mrr_val = get_s_baselines_total_loss_s_for_MobilityLLM_DOWN(
                    dl_val, model, downstream=downstream)
                # Early stopping on validation hit@20: a gain below 1e-4
                # counts as "no progress" but still updates the best model
                # when it is a strict improvement.
                if (hit_ratio_val[19] - best_hit20) < 1e-4:
                    impatient += 1
                    if best_hit20 < hit_ratio_val[19]:
                        best_hit20 = hit_ratio_val[19]
                        best_model = deepcopy(model.state_dict())
                        best_epoch = epoch
                else:
                    best_hit20 = hit_ratio_val[19]
                    best_model = deepcopy(model.state_dict())
                    best_epoch = epoch
                    impatient = 0
            else:
                mae_val, mape_val, rmse_val, nll_t_val = get_t_for_IFLTPP(dl_val, model)
                # Same rule for TPP, on the validation time NLL (lower is better).
                if (best_tnll - nll_t_val) < 1e-4:
                    impatient += 1
                    if nll_t_val < best_tnll:
                        best_tnll = nll_t_val
                        best_model = deepcopy(model.state_dict())
                        best_epoch = epoch
                else:
                    best_tnll = nll_t_val
                    best_model = deepcopy(model.state_dict())
                    best_epoch = epoch
                    impatient = 0

        if impatient >= patience:
            print('Breaking due to early stopping at epoch %d,best epoch at %d' % (epoch, best_epoch), flush=True)
            break

        if epoch % display_step == 0:
            # The summary line is the same with and without `adv`.
            if downstream != 'TPP':
                print('Epoch %4d, train_loss=%.4f, val_loss=%.4f, val_mrr=%.4f, val_hit_1=%.4f, val_hit_20=%.4f' % (
                    epoch, loss_total, all_loss_s_val, mrr_val, hit_ratio_val[0], hit_ratio_val[19]), flush=True)
            else:
                print('Epoch %4d, train_tnll=%.4f, val_tnll=%.4f, val_mae=%.4f, val_rmse=%.4f, val_mape=%.4f' % (
                    epoch, loss_total, nll_t_val, mae_val, rmse_val, mape_val), flush=True)

        if use_nni:
            if downstream != 'TPP':
                nni.report_intermediate_result(hit_ratio_val[19])
            else:
                nni.report_intermediate_result(mae_val)

    torch.save(best_model, params_filename)

    print("best epoch at %d" % best_epoch, flush=True)
    print('save parameters to file: %s' % params_filename, flush=True)
    print("training time: ", time.time() - start)
412 |
### Evaluation
# Reload the best checkpoint and report metrics on the train/val/test splits.
print('----- test ----')
# map_location keeps the checkpoint loadable when it was written from a
# different device than the one selected for this run.
model.load_state_dict(torch.load(params_filename, map_location=device))
model.eval()
with torch.no_grad():
    if downstream != 'TPP':
        train_all_loss_s, train_hit_ratio, train_mrr = get_s_baselines_total_loss_s_for_MobilityLLM_DOWN(dl_train, model, downstream=downstream)
        val_all_loss_s, val_hit_ratio, val_mrr = get_s_baselines_total_loss_s_for_MobilityLLM_DOWN(dl_val, model, downstream=downstream)
        test_all_loss_s, test_hit_ratio, test_mrr = get_s_baselines_total_loss_s_for_MobilityLLM_DOWN(dl_test, model, downstream=downstream)

        # hit_ratio is indexed by k - 1, e.g. index 19 holds hit@20.
        print('Dataset\t loss\t hit_1\t hit_3\t hit_5\t hit_7\t hit_10\t hit_15\t hit_20\t MRR\t\n' +
              'Train:\t %.4f\t %.4f\t %.4f\t %.4f\t %.4f\t %.4f\t %.4f\t %.4f\t %.4f\t\n' % (
                  train_all_loss_s, train_hit_ratio[0], train_hit_ratio[2], train_hit_ratio[4], train_hit_ratio[6],
                  train_hit_ratio[9], train_hit_ratio[14], train_hit_ratio[19], train_mrr) +
              'Val:\t %.4f\t %.4f\t %.4f\t %.4f\t %.4f\t %.4f\t %.4f\t %.4f\t %.4f\t\n' % (
                  val_all_loss_s, val_hit_ratio[0], val_hit_ratio[2], val_hit_ratio[4], val_hit_ratio[6], val_hit_ratio[9],
                  val_hit_ratio[14], val_hit_ratio[19], val_mrr) +
              'Test:\t %.4f\t %.4f\t %.4f\t %.4f\t %.4f\t %.4f\t %.4f\t %.4f\t %.4f\t\n' % (
                  test_all_loss_s, test_hit_ratio[0], test_hit_ratio[2], test_hit_ratio[4], test_hit_ratio[6],
                  test_hit_ratio[9], test_hit_ratio[14], test_hit_ratio[19], test_mrr), flush=True)
    else:
        train_mae, train_mape, train_rmse, train_nll_t = get_t_for_IFLTPP(dl_train, model, save_filename='train',
                                                                          params_path=params_path,
                                                                          use_nni=use_nni)
        val_mae, val_mape, val_rmse, val_nll_t = get_t_for_IFLTPP(dl_val, model, save_filename='val',
                                                                  params_path=params_path,
                                                                  use_nni=use_nni)
        test_mae, test_mape, test_rmse, test_nll_t = get_t_for_IFLTPP(dl_test, model, save_filename='test',
                                                                      params_path=params_path,
                                                                      use_nni=use_nni)

        print('Dataset\t MAE\t RMSE\t MAPE\t TNll\t\n' +
              'Train:\t %.4f\t %.4f\t %.4f\t %.4f\t\n' % (train_mae, train_rmse, train_mape, train_nll_t) +
              'Val:\t %.4f\t %.4f\t %.4f\t %.4f\t\n' % (val_mae, val_rmse, val_mape, val_nll_t) +
              'Test:\t %.4f\t %.4f\t %.4f\t %.4f\t\n' % (test_mae, test_rmse, test_mape, test_nll_t), flush=True)

    # NNI receives the validation metric of the selected checkpoint.
    if use_nni:
        if downstream != 'TPP':
            nni.report_final_result(val_hit_ratio[19])
        else:
            nni.report_final_result(val_mae)
454 |
--------------------------------------------------------------------------------
/model/MobilityLLM.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 | from torch.nn.utils.rnn import pack_padded_sequence, pad_sequence
5 | from torch.nn.utils.rnn import pad_packed_sequence
6 | from utils import DotDict
7 | from model.utils import *
8 | import torch.nn.functional as F
9 | from torch import nn
10 | #import faiss
11 | from transformers import AutoModelForCausalLM,AutoTokenizer
12 | from .llm import LLMModel
13 |
class MobilityLLM_ModelConfig(DotDict):
    '''
    configuration of the MobilityLLM

    Every constructor argument is stored under an attribute of the same name.
    ``decoder_input_size`` is derived as ``user_emb_size + hidden_size * 2``
    and is left as ``None`` when either size is not supplied, so the
    configuration can also be built with its default arguments.
    '''

    def __init__(self, loc_size=None, tim_size=None, uid_size=None, geohash_size=None, category_size=None, tim_emb_size=None, loc_emb_size=None,
                 hidden_size=None, user_emb_size=None, model_class=None, device=None,
                 loc_noise_mean=None, loc_noise_sigma=None, tim_noise_mean=None, tim_noise_sigma=None,
                 user_noise_mean=None, user_noise_sigma=None, tau=None,
                 pos_eps=None, neg_eps=None, dropout_rate_1=None, dropout_rate_2=None, category_vector=None, rnn_type='BiLSTM',
                 num_layers=3, k=8, momentum=0.95, temperature=0.1, theta=0.18,
                 n_components=4, shift_init=0.0, scale_init=0.0, min_clip=-5., max_clip=3., hypernet_hidden_sizes=None, max_delta_mins=1440,
                 downstream='POI',tpp='pdf',loss='pdf', dropout_spatial = None, epsilon = None, learnable_param_size=1):
        super().__init__()
        self.max_delta_mins = max_delta_mins

        # vocabulary / feature sizes
        self.loc_size = loc_size
        self.uid_size = uid_size
        self.tim_size = tim_size
        self.geohash_size = geohash_size
        self.category_size = category_size
        self.loc_emb_size = loc_emb_size
        self.tim_emb_size = tim_emb_size
        self.user_emb_size = user_emb_size
        self.hidden_size = hidden_size  # RNN hidden_size
        self.model_class = model_class
        self.device = device
        self.rnn_type = rnn_type
        self.num_layers = num_layers

        # noise, temperature and dropout settings
        self.loc_noise_mean = loc_noise_mean
        self.loc_noise_sigma = loc_noise_sigma
        self.tim_noise_mean = tim_noise_mean
        self.tim_noise_sigma = tim_noise_sigma
        self.user_noise_mean = user_noise_mean
        self.user_noise_sigma = user_noise_sigma
        self.tau = tau
        self.pos_eps = pos_eps
        self.neg_eps = neg_eps
        self.dropout_rate_1 = dropout_rate_1
        self.dropout_rate_2 = dropout_rate_2
        self.downstream = downstream
        self.category_vector = category_vector
        self.learnable_param_size = learnable_param_size

        self.k = k
        self.momentum = momentum
        self.theta = theta
        self.temperature = temperature

        self.n_components = n_components  # required
        self.min_clip = min_clip  # required; the training script relies on the default
        self.max_clip = max_clip  # required; the training script relies on the default
        self.shift_init = shift_init
        self.scale_init = scale_init
        self.hypernet_hidden_sizes = hypernet_hidden_sizes  # required; the training script relies on the default
        # required; only computable when both sizes are known
        if user_emb_size is None or hidden_size is None:
            self.decoder_input_size = None
        else:
            self.decoder_input_size = user_emb_size + hidden_size * 2
        self.loss = loss
        self.tpp = tpp
        self.dropout_spatial = dropout_spatial
        self.epsilon = epsilon
75 |
76 | class MobilityLLM(nn.Module):
def __init__(self, config):
    """Build the MobilityLLM model from a configuration object.

    ``config`` is read both by subscript (``config['loc_size']``) and by attribute
    (``config.hypernet_hidden_sizes``, ``config.n_components``), so it must support
    both access styles.  ``config['downstream']`` selects which task head is built
    and must be one of ``'TUL'``, ``'POI'`` or ``'TPP'``.

    Raises:
        ValueError: if ``config['downstream']`` is not ``'TUL'``, ``'POI'`` or ``'TPP'``.
    """
    super(MobilityLLM, self).__init__()
    # initialize parameters
    self.max_delta_mins = config['max_delta_mins']
    # Filled in by forward(): the ground-truth inter-event times of the target steps.
    self.truth_Y_tau = None
    self.loc_size = config['loc_size']
    self.loc_emb_size = config['loc_emb_size']
    self.tim_size = config['tim_size']
    self.tim_emb_size = config['tim_emb_size']
    self.user_size = config['uid_size']
    self.user_emb_size = config['user_emb_size']

    self.category_size = config['category_size']
    self.geohash_size = config['geohash_size']
    self.category_vector = config['category_vector']
    self.learnable_param_size = config['learnable_param_size']  # number of learnable parameters

    self.hidden_size = config['hidden_size']
    self.rnn_type = config['rnn_type']
    self.num_layers = config['num_layers']
    self.device = config['device']
    self.model_class = config['model_class']
    self.downstream = config['downstream']

    # parameters for cluster contrastive learning
    self.k = config['k']

    # parameters for time contrastive learning (Angle & Momentum based)
    # momentum
    self.momentum = config['momentum']
    # angle
    self.theta = config['theta']
    # self.theta = 0.05
    self.temperature = config['temperature']

    # spatial
    self.user_centroids = None
    self.user_2cluster = None
    self.item_centroids = None
    self.item_2cluster = None
    # NOTE(review): no ``dim`` is passed; PyTorch's implicit-dimension softmax is deprecated.
    self.softmax = nn.Softmax()
    self.epsilon = config['epsilon']
    self.sinkhorn_iterations = 3
    self.crops_for_assign = [0, 1]
    self.nmb_crops = [2]
    self.world_size = -1
    self.dropout = nn.Dropout(0.3)
    # TODO: tune this dropout rate as a hyperparameter (the original note mentions 0.1; the code uses 0.3)
    self.l2norm = True
    # combined location feature size (location embedding + geohash; category is not included here)
    self.rnn_input_size = self.loc_emb_size + self.geohash_size
    # ``bi`` doubles the width of the task heads below when the configured rnn_type is bidirectional.
    if self.rnn_type == 'BiLSTM':
        self.bi = 2
    else:
        self.bi = 1

    # parameters for social contrastive learning (4 group of parameters)
    self.para0 = nn.Parameter(torch.randn(1, 6))
    self.para1 = nn.Parameter(torch.randn(1, 4))
    self.para2 = nn.Parameter(torch.randn(1, 24))
    self.para3 = nn.Parameter(torch.randn(1, 16))

    # parameters for TPP
    self.shift_init = config['shift_init']
    self.scale_init = config['scale_init']
    self.min_clip = config['min_clip']
    self.max_clip = config['max_clip']

    ##############################################
    self.loc_noise_mean = config['loc_noise_mean']
    self.loc_noise_sigma = config['loc_noise_sigma']
    self.tim_noise_mean = config['tim_noise_mean']
    self.tim_noise_sigma = config['tim_noise_sigma']
    self.user_noise_mean = config['user_noise_mean']
    self.user_noise_sigma = config['user_noise_sigma']

    self.tau = config['tau']
    self.pos_eps = config['pos_eps']
    self.neg_eps = config['neg_eps']
    self.dropout_rate_1 = config['dropout_rate_1']
    self.dropout_rate_2 = config['dropout_rate_2']

    self.dropout_1 = nn.Dropout(self.dropout_rate_1)
    self.dropout_2 = nn.Dropout(self.dropout_rate_2)
    ################################################
    self.tpp = config['tpp']
    self.loss = config['loss']
    self.mae = torch.nn.L1Loss()
    # Embedding layer
    self.emb_loc = nn.Embedding(num_embeddings=self.loc_size, embedding_dim=self.loc_emb_size)
    self.emb_tim = nn.Embedding(num_embeddings=self.tim_size, embedding_dim=self.tim_emb_size)
    self.emb_user = nn.Embedding(num_embeddings=self.user_size, embedding_dim=self.user_emb_size)

    # Category dense layer
    self.category_dense = nn.Linear(768, self.category_size)
    # Geohash dense layer
    self.geohash_dense = nn.Linear(12, self.geohash_size)

    # Sequence encoder (replaces the earlier RNN encoders kept below as comments).
    # NOTE(review): the checkpoint root is an absolute, machine-specific path; it has to exist on
    # the host that runs this code — consider moving it into the config.
    self.spatial_encoder = LLMModel(model_path = "/data/ZhangXinyue/MobilityLLM/params/"+ self.model_class, model_class= self.model_class, loc_size = self.loc_size, learnable_param_size = self.learnable_param_size, device = self.device)
    # if self.rnn_type == 'GRU':
    #     self.spatial_encoder = nn.GRU(self.rnn_input_size, self.hidden_size, num_layers=self.num_layers,
    #                                   batch_first=False)
    #     self.temporal_encoder = nn.GRU(self.tim_emb_size + 1, self.hidden_size, num_layers=self.num_layers,
    #                                    batch_first=False)
    #     self.temporal_encoder_momentum = nn.GRU(self.tim_emb_size + 1, self.hidden_size, num_layers=self.num_layers,
    #                                             batch_first=False)
    # elif self.rnn_type == 'LSTM':
    #     self.spatial_encoder = nn.LSTM(self.rnn_input_size, self.hidden_size, num_layers=self.num_layers,
    #                                    batch_first=False)
    # elif self.rnn_type == 'BiLSTM':
    #     self.spatial_encoder = nn.LSTM(self.rnn_input_size, self.hidden_size, num_layers=self.num_layers,
    #                                    batch_first=False, bidirectional=True)
    # else:
    #     raise ValueError("rnn_type should be ['GRU', 'LSTM', 'BiLSTM']")

    # spatial_adv
    # prototype layer: one bias-free linear head per entry when ``k`` is a list,
    # a single head when ``k`` is a positive number, otherwise no prototypes.
    self.prototypes = None
    if isinstance(self.k, list):
        self.prototypes = MultiPrototypes(self.hidden_size, self.k)
    elif self.k > 0:
        self.prototypes = nn.Linear(self.hidden_size, self.k, bias=False)

    # projection head
    self.projection_head = nn.Sequential(
        nn.Linear(self.hidden_size, 2048),
        nn.BatchNorm1d(2048),
        nn.ReLU(inplace=True),
        nn.Linear(2048, self.hidden_size),
    )

    # Hypernet module for TPP (note the attribute-style access on ``config`` here)
    self.hypernet = Hypernet(config, hidden_sizes=config.hypernet_hidden_sizes,
                             param_sizes=[config.n_components, config.n_components, config.n_components])
    # linear for TPP
    self.linear_p1 = nn.Linear(self.hidden_size + self.user_emb_size,1024)
    self.linear_p2 = nn.Linear(1024 , 4)
    self.linear_m1 = nn.Linear(self.hidden_size + self.user_emb_size,986)
    self.linear_m2 = nn.Linear(986 , 4)
    self.linear_l1 = nn.Linear(self.hidden_size + self.user_emb_size,1560)
    self.linear_l2 = nn.Linear(1560 , 4)
    self.Tanh = nn.Tanh()
    self.sigmoid = nn.Sigmoid()

    # dense layer: task-specific heads, chosen by ``self.downstream``
    self.s2st_projection = nn.Linear(self.hidden_size * self.bi, self.hidden_size * self.bi)
    if self.downstream == 'TUL':
        # TUL head: one output per user
        self.dense = nn.Linear(in_features=self.hidden_size * self.bi, out_features=self.user_size)
        self.t2st_projection = nn.Linear(self.hidden_size * self.bi, self.hidden_size * self.bi)
        self.projection = nn.Sequential(nn.Linear(self.hidden_size * self.bi, self.hidden_size * self.bi), nn.ReLU())
    elif self.downstream == 'POI':
        # POI head: one output per location; input is widened by the user embedding
        self.projection = nn.Sequential(
            nn.Linear(self.hidden_size * self.bi + self.user_emb_size, self.hidden_size * self.bi + self.user_emb_size),
            nn.ReLU())
        self.dense = nn.Linear(in_features=self.hidden_size * self.bi + self.user_emb_size, out_features=self.loc_size)
        self.t2st_projection = nn.Linear(self.hidden_size * self.bi + self.user_emb_size, self.hidden_size * self.bi)
    elif self.downstream == 'TPP':
        # TPP head: an MLP that ends in a single output value
        self.projection = nn.Sequential(
            nn.Linear(self.hidden_size * self.bi + self.user_emb_size,
                      self.hidden_size * self.bi + self.user_emb_size),
            nn.ReLU())
        self.t2st_projection = nn.Linear(self.hidden_size * self.bi + self.user_emb_size,
                                         self.hidden_size * self.bi)
        final_in_size = self.hidden_size * self.bi + self.user_emb_size
        self.dense_s = nn.Linear(in_features=self.hidden_size,
                                 out_features=self.hidden_size)
        self.dense_t = nn.Linear(in_features=self.hidden_size * self.bi + self.user_emb_size,
                                 out_features=self.hidden_size * self.bi + self.user_emb_size)
        self.dense_st = nn.Linear(in_features=self.hidden_size * self.bi + self.user_emb_size,
                                  out_features=self.hidden_size * self.bi + self.user_emb_size)
        self.dense = nn.Sequential(
            nn.Linear(in_features=final_in_size, out_features=final_in_size // 4),
            nn.LeakyReLU(),
            nn.Linear(in_features=final_in_size // 4, out_features=final_in_size // 16),
            nn.LeakyReLU(),
            nn.Linear(in_features=final_in_size // 16, out_features=1),
        )
        self.linear1 = nn.Linear(self.hidden_size + self.user_emb_size, (self.hidden_size + self.user_emb_size) // 4)
        self.linear2 = nn.Linear((self.hidden_size + self.user_emb_size) // 4,
                                 (self.hidden_size + self.user_emb_size) // 16)
        self.linear3 = nn.Linear((self.hidden_size + self.user_emb_size) // 16, 1)
        self.linear0 = nn.Linear((self.hidden_size) * 2, self.hidden_size)
    else:
        raise ValueError('downstream should in [TUL, POI, TPP]!')

    # Run _init_weight on this module and every registered submodule.
    # NOTE(review): nn.Module.apply recurses into all submodules; if LLMModel is an nn.Module
    # holding pretrained nn.Linear / nn.Embedding weights they would be re-initialised here — confirm.
    self.apply(self._init_weight)
264 |
def _init_weight(self, module):
    """Initialise a single submodule in place (used through ``nn.Module.apply``).

    * ``nn.Embedding`` -> Xavier-normal weight.
    * ``nn.Linear``    -> Xavier-uniform weight (the bias is left untouched).
    * ``nn.LSTM``      -> Xavier-uniform input-to-hidden weights, orthogonal
      hidden-to-hidden weights and zero biases.

    Any other module type is left unchanged.
    """
    if isinstance(module, nn.Embedding):
        nn.init.xavier_normal_(module.weight)
        return
    if isinstance(module, nn.Linear):
        nn.init.xavier_uniform_(module.weight)
        return
    if not isinstance(module, nn.LSTM):
        return

    # LSTM parameters are told apart by the naming scheme of named_parameters().
    for param_name, tensor in module.named_parameters():
        if 'weight_ih' in param_name:
            nn.init.xavier_uniform_(tensor.data)
        elif 'weight_hh' in param_name:
            nn.init.orthogonal_(tensor.data)
        elif 'bias' in param_name:
            nn.init.constant_(tensor.data, 0)
278 |
def spatial_encode(self, x, time, category, geohash_, all_len, cur_len, batch_size, momentum=False, downstream='POI'):
    """Encode a padded batch and gather the per-task output rows.

    The encoder is called as ``encoder(x, lengths, time, category, geohash_)`` and its
    result is indexed as ``[sample, step, feature]``.

    Args:
        x, time, category, geohash_: passed through to the encoder unchanged.
        all_len: per-sample number of valid steps (indexable, e.g. a list of ints).
        cur_len: per-sample number of target steps (indexable).
        batch_size: number of samples to gather from the encoder output.
        momentum: when True, use ``self.spatial_encoder_momentum`` if the model
            defines one; otherwise fall back to ``self.spatial_encoder``.
        downstream: ``'POI'``, ``'TPP'`` or ``'TUL'``.

    Returns:
        A 2-D tensor of rows concatenated over the batch:
        * ``'POI'``: the last ``cur_len[i]`` valid steps of every sample.
        * ``'TPP'``: the same number of steps, shifted one step earlier whenever the
          sample has history (``all_len[i] != cur_len[i]``).
        * ``'TUL'``: one row per sample, the mean over its ``all_len[i]`` valid steps.

    Raises:
        ValueError: if ``downstream`` is not one of the three supported tasks.
    """
    # The momentum encoder is optional: the original code referenced an attribute that is
    # never created and then ignored the selection, so fall back to the main encoder.
    encoder = getattr(self, 'spatial_encoder_momentum', None) if momentum else None
    if encoder is None:
        encoder = self.spatial_encoder

    spatial_out = encoder(x, torch.tensor(all_len).to(self.device), time, category, geohash_)

    # Collect the per-sample slices and concatenate once (avoids quadratic re-copying).
    pieces = []
    if downstream == 'POI':
        for i in range(batch_size):
            pieces.append(spatial_out[i, (all_len[i] - cur_len[i]): all_len[i], :])
    elif downstream == 'TPP':
        for i in range(batch_size):
            if all_len[i] == cur_len[i]:
                # No history: the window cannot be shifted back.
                left = all_len[i] - cur_len[i]
                right = all_len[i]
            else:
                # Use the state *before* each target step.
                left = all_len[i] - cur_len[i] - 1
                right = all_len[i] - 1
            pieces.append(spatial_out[i, left: right, :])
    elif downstream == 'TUL':
        # Mean-pool every sample, including sample 0 (previously sample 0 used its last
        # step while all other samples were mean-pooled).
        for i in range(batch_size):
            pieces.append(torch.mean(spatial_out[i, : all_len[i], :], dim=0, keepdim=True))
    else:
        raise ValueError('downstream is not in [POI, TUL, TPP]')

    return torch.cat(pieces, dim=0)
333 |
334 | '''
335 | def temporal_encode(self, packed_stuff, all_len, cur_len, batch_size, momentum=False, downstream='POI'):
336 | if momentum == True:
337 | f_encoder = self.temporal_encoder_momentum
338 | else:
339 | f_encoder = self.temporal_encoder
340 | if self.rnn_type == 'GRU':
341 | temporal_out, h_n = f_encoder(packed_stuff) # max_len*batch*hidden_size
342 | elif self.rnn_type == 'LSTM':
343 | temporal_out, (h_n, c_n) = f_encoder(packed_stuff) # max_len*batch*hidden_size
344 | elif self.rnn_type == 'BiLSTM':
345 | temporal_out, (h_n, c_n) = f_encoder(packed_stuff) # max_len*batch*hidden_size
346 | else :
347 | raise ValueError('rnn type is not in GRU, LSTM, BiLSTM! ')
348 |
349 | # unpack
350 | temporal_out, out_len = pad_packed_sequence(temporal_out, batch_first=False)
351 | temporal_out = temporal_out.permute(1, 0, 2)
352 |
353 | # out_len即all_len batch*max_len*hidden_size
354 | # concatenate
355 | if downstream == 'POI':
356 | final_out = temporal_out[0, (all_len[0] - cur_len[0]): all_len[0], :]
357 | for i in range(1, batch_size):
358 | final_out = torch.cat([final_out, temporal_out[i, (all_len[i] - cur_len[i]): all_len[i], :]], dim=0)
359 |
360 | elif downstream == 'TPP':
361 | if all_len[0] == cur_len[0]:
362 | left = all_len[0] - cur_len[0]
363 | right = all_len[0]
364 | else:
365 | left = all_len[0] - cur_len[0] - 1
366 | right = all_len[0] - 1
367 | final_out = temporal_out[0, left: right, :]
368 | for i in range(1, batch_size):
369 | if all_len[i] == cur_len[i]:
370 | left = all_len[i] - cur_len[i]
371 | right = all_len[i]
372 | else:
373 | left = all_len[i] - cur_len[i] - 1
374 | right = all_len[i] - 1
375 | final_out = torch.cat([final_out, temporal_out[i, left: right, :]], dim=0)
376 | elif downstream == 'TUL':
377 | final_out = temporal_out[0, (all_len[0] - 1): all_len[0], :]
378 | for i in range(1, batch_size):
379 | final_out = torch.cat([final_out, temporal_out[i, (all_len[i] - 1): all_len[i], :]], dim=0)
380 | else:
381 | raise ValueError('downstream is not in [POI, TUL, TPP]')
382 |
383 | return final_out
384 | '''
385 |
386 |
def normal_logpdf(self, x, mean, log_scale):
    """Element-wise log-density of a normal distribution.

    The distribution is parameterised by ``mean`` and by ``log_scale``, the natural
    logarithm of its standard deviation.
    """
    inv_sigma = torch.exp(-log_scale)
    standardized = (x - mean) * inv_sigma
    half_log_two_pi = 0.5 * np.log(2 * np.pi)
    return -log_scale - 0.5 * standardized.pow(2.0) - half_log_two_pi
394 |
def mixnormal_logpdf(self, x, log_prior, means, log_scales):
    """Log-density of ``x`` under a mixture of normal distributions.

    :param x: ground truth
    :param log_prior: log of the normalised mixture weights; the original note gives
        the shape as (batch, max_length/actual_length, n_components) = (64, 128, 64)
    :param means: component means, same shape as ``log_prior``
    :param log_scales: component log standard deviations, same shape as ``log_prior``
        (``scales`` corresponds to ``s`` in the paper)
    :return: the log-sum-exp over the last (component) axis of
        ``log_prior + normal_logpdf(x, means, log_scales)``
    """
    component_log_probs = self.normal_logpdf(x, means, log_scales)
    weighted_log_probs = log_prior + component_log_probs
    return torch.logsumexp(weighted_log_probs, dim=-1)
407 |
def get_params(self, decoder_input):
    """Produce the mixture parameters for the given decoder input.

    Args:
        decoder_input: decoder input [batch, decoder_input_size]

    Returns:
        prior_logits: log-normalised mixture weights, shape [batch, n_components]
        means: shape [batch, n_components]
        log_scales: clamped log scales, shape [batch, n_components]
    """
    raw_logits, means, raw_log_scales = self.hypernet(decoder_input)

    # Both quantities are later exponentiated, so clamp them for numerical stability.
    lower, upper = self.min_clip, self.max_clip
    clipped_logits = clamp_preserve_gradients(raw_logits, lower, upper)
    log_scales = clamp_preserve_gradients(raw_log_scales, lower, upper)

    # log_softmax normalises the mixture weights and takes their log in one step.
    return F.log_softmax(clipped_logits, dim=-1), means, log_scales
428 |
429 |
430 |
def forward(self, batch, mode='test', cont_conf=None, downstream='POI', queue = None, use_the_queue = False):
    """Run one batch through the encoder and the task head, and compute its loss.

    Args:
        batch: object exposing ``X_all_loc``, ``X_users``, ``X_all_geohash``,
            ``target_lengths``, ``X_lengths``, ``X_all_loc_category``, ``X_tau``,
            ``Y_tau`` and, for POI, ``Y_location``.
        mode, cont_conf, use_the_queue: accepted for interface compatibility; they do
            not influence the result.
        downstream: ``'POI'``, ``'TUL'`` or ``'TPP'``.
        queue: returned unchanged as the third element of the result.

    Returns:
        ``(loss, prediction, queue)``.  For POI/TUL ``loss`` is the NLL loss and
        ``prediction`` holds the class indices sorted by decreasing score; for TPP
        ``loss`` is the L1 loss (computed on CPU) and ``prediction`` is the raw
        output of the dense head.

    Side effects:
        ``self.truth_Y_tau`` is set to the concatenated ground-truth inter-event
        times of all target steps in the batch.

    Raises:
        ValueError: if ``downstream`` is unsupported, or if ``downstream == 'TPP'``
            and ``self.loss`` is not ``'mae'`` (the only TPP loss implemented here).
    """
    if downstream not in ('POI', 'TUL', 'TPP'):
        raise ValueError('downstream is not in [POI, TUL, TPP]')
    if downstream == 'TPP' and self.loss != 'mae':
        # Previously this fell through and failed with an UnboundLocalError at return.
        raise ValueError("TPP downstream only supports loss == 'mae'")

    loc = batch.X_all_loc
    user = batch.X_users
    cur_len = batch.target_lengths
    all_len = batch.X_lengths
    loc_cat = batch.X_all_loc_category
    batch_size = loc.shape[0]

    user_emb = self.emb_user(user)
    geohash_ = self.geohash_dense(batch.X_all_geohash)
    x = loc.to(self.device)

    # Per sample: history taus followed by the target taus; the target part is also
    # kept as ground truth for the TPP loss.
    tau_seqs = []
    target_taus = []
    for i in range(batch_size):
        n_hist = all_len[i] - cur_len[i]
        seq = torch.cat((batch.X_tau[i, :n_hist], batch.Y_tau[i, :cur_len[i]]), dim=-1)
        tau_seqs.append(seq)
        target_taus.append(seq[n_hist:all_len[i]])
    self.truth_Y_tau = torch.cat(target_taus, dim=0)
    all_tau = pad_sequence(tau_seqs, batch_first=False).to(self.device)

    final_out = self.spatial_encode(x, all_tau.transpose(0, 1), loc_cat, geohash_, all_len, cur_len,
                                    batch_size, downstream=downstream)

    if downstream == 'TUL':
        pred = torch.log_softmax(self.dense(self.dropout(final_out)), dim=1)
        s_loss_score = nn.NLLLoss()(pred, batch.X_users).requires_grad_(True)
        _, top_k_pred = torch.topk(pred, k=self.user_size)  # (batch, num_class)
        return s_loss_score, top_k_pred, queue

    # POI and TPP predict per target step, so repeat each user's embedding once per
    # target step of that user's sample.
    all_user_emb = torch.cat(
        [user_emb[i].unsqueeze(dim=0).repeat(cur_len[i], 1) for i in range(batch_size)], dim=0)
    features = torch.cat([final_out, all_user_emb], 1)

    if downstream == 'POI':
        pred = torch.log_softmax(self.dense(features), dim=1)
        s_loss_score = nn.NLLLoss()(pred, batch.Y_location).requires_grad_(True)
        _, top_k_pred = torch.topk(pred, k=self.loc_size)  # (batch, num_class)
        return s_loss_score, top_k_pred, queue

    # TPP: regress the inter-event time directly.
    prediction_out = self.dense(features)
    # squeeze(-1) keeps a 1-D tensor even when there is a single target step.
    # The loss is evaluated on CPU, as in the original implementation.
    s_loss_score = nn.L1Loss()(prediction_out.squeeze(-1).to('cpu'), self.truth_Y_tau.to('cpu'))
    return s_loss_score, prediction_out, queue
546 |
class MultiPrototypes(nn.Module):
    """A set of bias-free linear prototype heads that all read the same input.

    ``nmb_prototypes`` lists the output size of each head; head ``i`` is registered
    as the submodule ``prototypes<i>``.
    """

    def __init__(self, output_dim, nmb_prototypes):
        super().__init__()
        self.nmb_heads = len(nmb_prototypes)
        for head_idx, n_protos in enumerate(nmb_prototypes):
            head = nn.Linear(output_dim, n_protos, bias=False)
            self.add_module("prototypes" + str(head_idx), head)

    def forward(self, x):
        """Return a list with the output of every head applied to ``x``, in head order."""
        return [
            getattr(self, "prototypes" + str(head_idx))(x)
            for head_idx in range(self.nmb_heads)
        ]
--------------------------------------------------------------------------------
/model/phi_model.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 | #
4 | # Copyright (c) 2022, Tri Dao, trid@cs.stanford.edu.
5 | # Licensed under the BSD 3-Clause License.
6 |
7 | from __future__ import annotations
8 |
9 | import math
10 | from dataclasses import dataclass, field
11 | from typing import Any, Dict, Optional, Tuple, Union
12 |
13 | import torch
14 | import torch.nn as nn
15 | from einops import rearrange, repeat
16 | from transformers import PretrainedConfig, PreTrainedModel
17 | from transformers.activations import ACT2FN
18 | from transformers.modeling_outputs import CausalLMOutputWithPast
19 |
20 | from .configuration_phi import PhiConfig
21 |
22 | try:
23 | from flash_attn.bert_padding import pad_input, unpad_input
24 | from flash_attn.layers.rotary import RotaryEmbedding as FlashRotaryEmbedding
25 | from flash_attn.modules.mha import FlashCrossAttention, FlashSelfAttention
26 | from flash_attn.ops.fused_dense import FusedDense
27 | except:
28 | pad_input, unpad_input = None, None
29 | FlashRotaryEmbedding = None
30 | FlashSelfAttention, FlashCrossAttention = None, None
31 | FusedDense = None
32 |
33 |
@dataclass
class InferenceParams:
    """Container for the state a model keeps between calls during inference.

    Reference:
        https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/utils/generation.py.

    Fields:
        max_seqlen: Maximum sequence length (required).
        max_batch_size: Maximum batch size (required).
        seqlen_offset: Sequence length offset, 0 by default.
        batch_size_offset: Batch size offset, 0 by default.
        key_value_memory_dict: Key value memory dictionary; a fresh dict per instance.
        lengths_per_sample: Lengths per sample, ``None`` by default.

    """

    # Required sizes.
    max_seqlen: int = field(metadata={"help": "Maximum sequence length."})
    max_batch_size: int = field(metadata={"help": "Maximum batch size."})

    # Offsets, both starting at zero.
    seqlen_offset: int = field(default=0, metadata={"help": "Sequence length offset."})
    batch_size_offset: int = field(default=0, metadata={"help": "Batch size offset."})

    # Mutable state: default_factory gives every instance its own dict.
    key_value_memory_dict: Dict[str, Any] = field(default_factory=dict, metadata={"help": "Key value memory dictionary."})

    lengths_per_sample: torch.Tensor = field(default=None, metadata={"help": "Lengths per sample."})
65 |
66 |
class Embedding(nn.Module):
    """Token-embedding lookup followed by dropout."""

    def __init__(self, config: PretrainedConfig) -> None:
        super().__init__()

        self.wte = nn.Embedding(config.vocab_size, config.n_embd)
        self.drop = nn.Dropout(config.embd_pdrop)

    def forward(self, input_ids: torch.LongTensor) -> torch.FloatTensor:
        # Fold every leading dimension into one so the lookup sees a 2-D id matrix
        # whose second dimension is the original last dimension.
        last_dim = input_ids.size()[-1]
        flat_ids = input_ids.view(-1, last_dim)

        return self.drop(self.wte(flat_ids))
84 |
85 |
def _apply_rotary_emb(
    x: torch.FloatTensor,
    cos: torch.FloatTensor,
    sin: torch.FloatTensor,
) -> torch.FloatTensor:
    """Rotate the leading ``2 * cos.shape[1]`` features of a 4-D tensor.

    ``x`` is unpacked as four dimensions with the sequence length second; ``cos`` and
    ``sin`` are 2-D tables of which the first ``seqlen`` rows are used.  The rotation
    is computed in float32 and cast back to ``x.dtype``; features beyond the rotary
    width are returned unchanged.
    """
    _d0, seq_len, _d2, _d3 = x.shape
    _rows, half_width = cos.shape
    rot_width = 2 * half_width

    to_rotate = x[..., :rot_width]
    untouched = x[..., rot_width:]
    first_half, second_half = to_rotate.chunk(2, dim=-1)

    # Insert a singleton axis so the tables broadcast over the third dimension of x.
    cos_b = rearrange(cos[:seq_len], "s d -> s 1 d").to(dtype=torch.float32)
    sin_b = rearrange(sin[:seq_len], "s d -> s 1 d").to(dtype=torch.float32)
    first_half = first_half.to(dtype=torch.float32)
    second_half = second_half.to(dtype=torch.float32)

    out_first = first_half * cos_b - second_half * sin_b
    out_second = first_half * sin_b + second_half * cos_b
    rotated = torch.cat([out_first, out_second], axis=-1).to(x.dtype)

    return torch.cat([rotated, untouched], axis=-1)
105 |
106 |
def _apply_rotary_emb_kv(
    kv: torch.FloatTensor,
    cos: torch.FloatTensor,
    sin: torch.FloatTensor,
    cos_k: Optional[torch.FloatTensor] = None,
    sin_k: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
    """Rotate the key slice of a packed 5-D key/value tensor.

    Index 0 along dimension 2 is rotated (its first ``2 * cos.shape[1]`` features, in
    float32, then cast back to ``kv.dtype``); index 1 along dimension 2 is copied
    through unchanged.  ``cos_k`` and ``sin_k`` are accepted but not used.
    """
    _d0, seq_len, _d2, _d3, _d4 = kv.shape
    _rows, half_width = cos.shape
    rot_width = 2 * half_width

    keys = kv[:, :, 0]
    to_rotate = keys[..., :rot_width]
    untouched = keys[..., rot_width:]
    first_half, second_half = to_rotate.chunk(2, dim=-1)

    # Insert a singleton axis so the tables broadcast over the fourth dimension of kv.
    cos_b = rearrange(cos[:seq_len], "s d -> s 1 d").to(dtype=torch.float32)
    sin_b = rearrange(sin[:seq_len], "s d -> s 1 d").to(dtype=torch.float32)
    first_half = first_half.to(dtype=torch.float32)
    second_half = second_half.to(dtype=torch.float32)

    out_first = first_half * cos_b - second_half * sin_b
    out_second = first_half * sin_b + second_half * cos_b
    rotated = torch.cat([out_first, out_second], axis=-1).to(kv.dtype)

    new_keys = torch.cat([rotated, untouched], axis=-1).unsqueeze(2)
    values = kv[:, :, 1:2, :, :]
    return torch.cat([new_keys, values], axis=2)
134 |
135 |
def _apply_rotary_emb_qkv(
    qkv: torch.FloatTensor,
    cos: torch.FloatTensor,
    sin: torch.FloatTensor,
    cos_k: Optional[torch.FloatTensor] = None,
    sin_k: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
    """Rotate the query and key slices of a packed 5-D query/key/value tensor.

    Indices 0 and 1 along dimension 2 are rotated (their first ``2 * cos.shape[1]``
    features, in float32, then cast back to ``qkv.dtype``); index 2 along dimension 2
    is copied through unchanged.  ``cos_k`` and ``sin_k`` are accepted but not used.
    """
    _d0, seq_len, _d2, _d3, _d4 = qkv.shape
    _rows, half_width = cos.shape
    rot_width = 2 * half_width

    queries = qkv[:, :, 0]
    keys = qkv[:, :, 1]
    q_untouched = queries[..., rot_width:]
    k_untouched = keys[..., rot_width:]

    q_first, q_second = queries[..., :rot_width].chunk(2, dim=-1)
    k_first, k_second = keys[..., :rot_width].chunk(2, dim=-1)

    # Insert a singleton axis so the tables broadcast over the fourth dimension of qkv.
    cos_b = rearrange(cos[:seq_len], "s d -> s 1 d").to(dtype=torch.float32)
    sin_b = rearrange(sin[:seq_len], "s d -> s 1 d").to(dtype=torch.float32)
    q_first = q_first.to(dtype=torch.float32)
    q_second = q_second.to(dtype=torch.float32)
    k_first = k_first.to(dtype=torch.float32)
    k_second = k_second.to(dtype=torch.float32)

    q_rotated = torch.cat(
        [q_first * cos_b - q_second * sin_b, q_first * sin_b + q_second * cos_b], axis=-1
    ).to(qkv.dtype)
    k_rotated = torch.cat(
        [k_first * cos_b - k_second * sin_b, k_first * sin_b + k_second * cos_b], axis=-1
    ).to(qkv.dtype)

    new_queries = torch.cat([q_rotated, q_untouched], axis=-1).unsqueeze(2)
    new_keys = torch.cat([k_rotated, k_untouched], axis=-1).unsqueeze(2)
    values = qkv[:, :, 2:3, :, :]
    return torch.cat([new_queries, new_keys, values], axis=2)
169 |
170 |
class RotaryEmbedding(nn.Module):
    """Rotary positional embedding (RoPE) with cached cosine/sine tables.

    Reference:
        RoFormer: Enhanced Transformer with Rotary Position Embedding.
        https://arxiv.org/pdf/2104.09864.pdf.

    """

    def __init__(
        self,
        dim: int,
        base: int = 10000,
        scale_base: Optional[float] = None,
        pos_idx_in_fp32: bool = True,
        max_position_embeddings: int = 2048,
        device: Optional[str] = None,
        **kwargs,
    ) -> None:
        super().__init__()

        # Scaled variants are rejected up front.
        if scale_base is not None:
            raise NotImplementedError

        self.dim = dim
        self.base = float(base)
        self.scale_base = scale_base
        self.pos_idx_in_fp32 = pos_idx_in_fp32
        self.max_position_embeddings = max_position_embeddings
        self.device = device

        # Inverse-frequency buffer (non-trainable, not saved in the state dict)
        self.register_buffer("inv_freq", self._compute_inv_freq(device), persistent=False)

        # Scale buffer (non-trainable, not saved in the state dict)
        if scale_base is None:
            scale = None
        else:
            scale = (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)
        self.register_buffer("scale", scale, persistent=False)

        # Build the tables eagerly since ONNX can't rely on dynamic initialization
        self._update_cos_sin_cache(max_position_embeddings, device=device, dtype=torch.float32)

    def _compute_inv_freq(self, device: Optional[str] = None) -> torch.FloatTensor:
        """Return ``1 / base ** (2i / dim)`` for the even indices ``2i`` below ``dim``."""
        exponents = torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim
        return 1.0 / (self.base ** exponents)

    def _update_cos_sin_cache(
        self,
        seqlen: int,
        device: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> None:
        """Rebuild the cached cos/sin tables for ``seqlen`` positions."""
        self._seq_len_cached = seqlen

        if self.pos_idx_in_fp32:
            # fp32 positions are preferred: `torch.arange` output can be quite large
            # and bf16 would lose a lot of precision
            positions = torch.arange(seqlen, device=device, dtype=torch.float32)
            if self.inv_freq.dtype == torch.float32:
                inv_freq = self.inv_freq
            else:
                inv_freq = self._compute_inv_freq(device=device)
        else:
            positions = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
            inv_freq = self.inv_freq

        # `torch.outer` is preferred since `torch.einsum` converts from fp32 to fp16 if used with AMP
        angles = torch.outer(positions, inv_freq)

        if self.scale is None:
            self._cos_cached = torch.cos(angles).to(dtype)
            self._sin_cached = torch.sin(angles).to(dtype)
            return

        power = (
            torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device) - seqlen // 2
        ) / self.scale_base
        scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")

        # The scale multiplication happens in fp32 before the final cast
        self._cos_cached = (torch.cos(angles) * scale).to(dtype)
        self._sin_cached = (torch.sin(angles) * scale).to(dtype)
        self._cos_k_cached = (torch.cos(angles) / scale).to(dtype)
        self._sin_k_cached = (torch.sin(angles) / scale).to(dtype)

    def forward(
        self,
        qkv: torch.Tensor,
        kv: Optional[torch.Tensor] = None,
        seqlen_offset: int = 0,
        **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply the rotation to a packed ``qkv`` tensor, or to a query plus a packed ``kv``.

        When ``kv`` is ``None`` the packed ``qkv`` result is returned as a single tensor;
        otherwise ``qkv`` is treated as the query and ``(q, kv)`` is returned.
        """
        needed_len = qkv.shape[1] + seqlen_offset
        cache_is_stale = (
            self._seq_len_cached < needed_len
            or self._cos_cached.device != qkv.device
            or self._cos_cached.dtype != qkv.dtype
            or (self.training and self._cos_cached.is_inference())
        )
        if cache_is_stale:
            self._update_cos_sin_cache(needed_len, device=qkv.device, dtype=qkv.dtype)

        cos = self._cos_cached[seqlen_offset:]
        sin = self._sin_cached[seqlen_offset:]

        if kv is None:
            return _apply_rotary_emb_qkv(qkv, cos, sin)

        q = _apply_rotary_emb(qkv, cos, sin)
        kv = _apply_rotary_emb_kv(kv, cos, sin)

        return q, kv
291 |
292 |
class MLP(nn.Module):
    """Two-layer feed-forward block: linear, activation, linear.

    Reference:
        Attention Is All You Need.
        https://arxiv.org/pdf/1706.03762.pdf.

    """

    def __init__(
        self,
        config: PretrainedConfig,
        n_inner: Optional[int] = None,
        act_fn: Optional[str] = None,
    ) -> None:
        super().__init__()

        # Explicit arguments win; otherwise fall back to the config.
        if act_fn is None:
            act_fn = config.activation_function

        if n_inner is None:
            n_inner = getattr(config, "n_inner", None)
        if n_inner is None:
            # Neither the argument nor the config gives a width: use 4x the embedding size.
            n_inner = 4 * config.n_embd

        self.fc1 = nn.Linear(config.n_embd, n_inner)
        self.fc2 = nn.Linear(n_inner, config.n_embd)
        self.act = ACT2FN[act_fn]

    def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
        expanded = self.fc1(hidden_states)
        activated = self.act(expanded)
        return self.fc2(activated)
325 |
326 |
class SelfAttention(nn.Module):
    """Scaled dot-product self-attention on a packed ``qkv`` tensor (plain PyTorch).

    Reference:
        https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/mha.py.

    """

    def __init__(
        self,
        causal: bool = True,
        softmax_scale: Optional[float] = None,
        attention_dropout: float = 0.0,
    ) -> None:
        super().__init__()

        self.causal = causal
        self.softmax_scale = softmax_scale
        self.drop = nn.Dropout(attention_dropout)

    @torch.autocast("cpu", enabled=False)
    @torch.autocast("cuda", enabled=False)
    def forward(
        self,
        qkv: torch.FloatTensor,
        causal: bool = None,
        key_padding_mask: Optional[torch.BoolTensor] = None,
        **kwargs,
    ) -> torch.FloatTensor:
        """Attend over dimension 1 of ``qkv``; query, key and value are unbound from dimension 2.

        ``causal`` overrides the module default when given.  Where ``key_padding_mask``
        is False, -10000 is added to the scores of that key position.
        """
        n_batch, n_steps = qkv.shape[0], qkv.shape[1]
        query, key, value = qkv.unbind(dim=2)

        # Scores are computed in float32; the value tensor keeps its own dtype.
        query = query.to(torch.float32)
        key = key.to(torch.float32)

        use_causal = self.causal if causal is None else causal
        scale = self.softmax_scale or 1.0 / math.sqrt(query.shape[-1])

        # Autocast is manually disabled to avoid `torch.einsum` performing the operation
        # using float16, which might lead to overflow
        scores = torch.einsum("bthd,bshd->bhts", query, key * scale)

        if key_padding_mask is not None:
            additive_mask = torch.full((n_batch, n_steps), -10000.0, dtype=scores.dtype, device=scores.device)
            additive_mask.masked_fill_(key_padding_mask, 0.0)

            scores = scores + rearrange(additive_mask, "b s -> b 1 1 s")

        if use_causal:
            # Strictly upper-triangular penalty: a position cannot attend to later ones.
            future_penalty = torch.triu(torch.full((n_steps, n_steps), -10000.0, device=scores.device), 1)
            scores = scores + future_penalty.to(dtype=scores.dtype)

        weights = self.drop(torch.softmax(scores, dim=-1).to(value.dtype))

        return torch.einsum("bhts,bshd->bthd", weights, value)
385 |
386 |
class CrossAttention(nn.Module):
    """Scaled dot-product attention between a query tensor and a packed ``kv`` tensor.

    Written with plain PyTorch ops; the query and key/value sequences may have
    different lengths.

    Reference:
        https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/mha.py.

    """

    def __init__(
        self,
        causal: bool = True,
        softmax_scale: Optional[float] = None,
        attention_dropout: float = 0.0,
    ) -> None:
        """Store the masking/scaling defaults and build the dropout layer.

        Args:
            causal: Default for causal masking when ``forward`` is not told otherwise.
            softmax_scale: Multiplier applied to the keys before the dot product;
                a falsy value means ``1 / sqrt(head_dim)`` is used.
            attention_dropout: Dropout probability applied to the attention weights.

        """
        super().__init__()

        self.causal = causal
        self.softmax_scale = softmax_scale
        self.drop = nn.Dropout(attention_dropout)

    @torch.autocast("cpu", enabled=False)
    @torch.autocast("cuda", enabled=False)
    def forward(
        self,
        q: torch.FloatTensor,
        kv: torch.FloatTensor,
        causal: bool = None,
        key_padding_mask: Optional[torch.BoolTensor] = None,
        **kwargs,
    ) -> torch.FloatTensor:
        """Attend the queries ``q`` to the keys/values packed in ``kv``.

        Args:
            q: Queries indexed as (batch, seqlen_q, heads, head_dim).
            kv: Keys and values indexed as (batch, seqlen_k, 2, kv_heads, head_dim).
                When ``kv_heads`` differs from ``heads`` each key/value head is
                repeated ``heads // kv_heads`` times.
            causal: Overrides ``self.causal`` when not ``None``.
            key_padding_mask: Boolean (batch, seqlen_k) mask; ``True`` entries get a
                zero bias, all other entries get a ``-10000.0`` bias.
            **kwargs: Accepted and ignored.

        Returns:
            Attention output indexed as (batch, seqlen_q, heads, head_dim).

        """
        n_batch, seqlen_q = q.shape[0], q.shape[1]
        seqlen_k = kv.shape[1]

        n_q_heads, n_kv_heads = q.shape[2], kv.shape[3]
        if n_kv_heads != n_q_heads:
            # MQA / GQA: expand the key/value heads to match the query heads.
            kv = repeat(kv, "... hkv d -> ... (hkv g) d", g=n_q_heads // n_kv_heads)
        key, value = kv.unbind(dim=2)

        # Autocast is switched off by the decorators and the operands are upcast,
        # so the score computation never runs in float16 (which might overflow).
        query = q.to(torch.float32)
        key = key.to(torch.float32)

        use_causal = self.causal if causal is None else causal
        scale = self.softmax_scale or 1.0 / math.sqrt(query.shape[-1])

        scores = torch.einsum("bthd,bshd->bhts", query, key * scale)

        if key_padding_mask is not None:
            # Additive bias: 0 where the mask is True, -10000 elsewhere.
            key_bias = torch.full(
                (n_batch, seqlen_k),
                -10000.0,
                dtype=scores.dtype,
                device=scores.device,
            )
            key_bias.masked_fill_(key_padding_mask, 0.0)

            scores = scores + key_bias[:, None, None, :]

        if use_causal:
            # Queries are aligned to the END of the key sequence: query row ``t`` may
            # see key columns up to ``t + seqlen_k - seqlen_q``.
            q_pos = torch.arange(seqlen_q, device=query.device, dtype=torch.long).unsqueeze(-1)
            k_pos = torch.arange(seqlen_k, device=key.device, dtype=torch.long)
            hidden = k_pos > q_pos + seqlen_k - seqlen_q

            scores = scores.masked_fill(hidden, -10000.0)

        weights = self.drop(torch.softmax(scores, dim=-1).to(value.dtype))

        return torch.einsum("bhts,bshd->bthd", weights, value)
458 |
459 |
460 | def _find_mha_dims(
461 | config: PretrainedConfig,
462 | n_head: Optional[int] = None,
463 | n_head_kv: Optional[int] = None,
464 | head_dim: Optional[int] = None,
465 | ) -> Tuple[int, int]:
466 | if n_head is None and head_dim is None:
467 | head_dim = config.n_embd // config.n_head
468 | n_head = config.n_head
469 | elif n_head is None or head_dim is None:
470 | raise ValueError("`n_head` and `head_dim` must be both specified or `None`.")
471 |
472 | if n_head_kv is None:
473 | n_head_kv = getattr(config, "n_head_kv", None) or n_head
474 |
475 | return n_head, n_head_kv, head_dim
476 |
477 |
478 | def _update_kv_cache(kv: torch.FloatTensor, inference_params: InferenceParams, layer_idx: int) -> torch.FloatTensor:
479 | num_heads, head_dim = kv.shape[-2:]
480 |
481 | if layer_idx not in inference_params.key_value_memory_dict:
482 | inference_params.key_value_memory_dict[layer_idx] = torch.empty(
483 | inference_params.max_batch_size,
484 | inference_params.max_seqlen,
485 | 2,
486 | num_heads,
487 | head_dim,
488 | dtype=kv.dtype,
489 | device=kv.device,
490 | )
491 |
492 | batch_start = inference_params.batch_size_offset
493 | batch_end = batch_start + kv.shape[0]
494 |
495 | sequence_start = inference_params.seqlen_offset
496 | sequence_end = sequence_start + kv.shape[1]
497 |
498 | # When the current sequence length is equal to or larger than the maximum sequence length,
499 | # we need to concatenate the current `kv` with the cached `kv` to expand its length
500 | if sequence_end >= inference_params.max_seqlen:
501 | inference_params.key_value_memory_dict[layer_idx] = torch.concatenate(
502 | (inference_params.key_value_memory_dict[layer_idx], kv), dim=1)
503 |
504 | inference_params.key_value_memory_dict[layer_idx][batch_start:batch_end, sequence_start:sequence_end, ...] = kv
505 | kv = inference_params.key_value_memory_dict[layer_idx][batch_start:batch_end, :sequence_end, ...]
506 |
507 | return kv
508 |
509 |
class MHA(nn.Module):
    """Multi-head attention layer.

    A single fused linear layer (``Wqkv``) produces queries, keys and values, an
    optional rotary embedding is applied, and the result is passed either to a
    self-attention module (packed ``qkv``) or to a cross-attention module
    (separate ``q`` and ``kv``), followed by an output projection.

    """

    def __init__(
        self,
        config: PretrainedConfig,
        dtype: Optional[torch.dtype] = None,
        device: Optional[str] = None,
        rotary_dim: Optional[int] = None,
        rotary_base: float = 10000.0,
        rotary_scale_base: Optional[float] = None,
        n_head: Optional[int] = None,
        n_head_kv: Optional[int] = None,
        head_dim: Optional[int] = None,
        bias: bool = True,
        causal: bool = True,
        softmax_scale: Optional[float] = None,
        layer_idx: Optional[int] = None,
        return_residual: bool = False,
        checkpointing: bool = False,
    ) -> None:
        """Build the projections, the rotary embedding and the inner attention modules.

        Args:
            config: Model configuration. Reads ``n_embd``, ``attn_pdrop``,
                ``flash_rotary``, ``fused_dense``, ``flash_attn`` and, when rotary
                embeddings are active, ``n_positions``; ``rotary_dim`` is read if present.
            dtype: Passed to the linear layers.
            device: Passed to the linear layers and to the rotary embedding.
            rotary_dim: Rotary dimension; falls back to ``config.rotary_dim`` (or 0).
                Rotary embeddings are only created when the value is positive.
            rotary_base: ``base`` argument of the rotary embedding.
            rotary_scale_base: ``scale_base`` argument of the rotary embedding.
            n_head: Number of query heads (see ``_find_mha_dims``).
            n_head_kv: Number of key/value heads (see ``_find_mha_dims``).
            head_dim: Size of each head (see ``_find_mha_dims``).
            bias: Whether the linear layers use a bias.
            causal: Default causal flag handed to the inner attention modules.
            softmax_scale: Softmax scale handed to the inner attention modules.
            layer_idx: Index under which this layer stores its key/value cache.
            return_residual: When true, ``forward`` returns ``(output, x)``.
            checkpointing: When true, the inner attention call is wrapped in
                ``torch.utils.checkpoint.checkpoint``.

        """
        super().__init__()

        # Rotary embedding
        self.rotary_dim = rotary_dim if rotary_dim is not None else getattr(config, "rotary_dim", 0)
        if self.rotary_dim > 0:
            rotary_cls = FlashRotaryEmbedding if config.flash_rotary else RotaryEmbedding
            # NOTE(review): `FlashRotaryEmbedding` can apparently be `None` (presumably when
            # the optional flash-attn import failed); confirm against the file's imports.
            if rotary_cls is None:
                rotary_cls = RotaryEmbedding

            rotary_kwargs = {}
            # Only the local `RotaryEmbedding` receives `max_position_embeddings`.
            if rotary_cls is RotaryEmbedding:
                rotary_kwargs["max_position_embeddings"] = config.n_positions

            self.rotary_emb = rotary_cls(
                self.rotary_dim,
                base=rotary_base,
                scale_base=rotary_scale_base,
                device=device,
                **rotary_kwargs,
            )

        # Projections: one fused layer emits `n_head` query heads plus `n_head_kv`
        # key heads and `n_head_kv` value heads, each of size `head_dim`.
        self.n_head, self.n_head_kv, self.head_dim = _find_mha_dims(
            config, n_head=n_head, n_head_kv=n_head_kv, head_dim=head_dim
        )
        op_size = self.head_dim * (self.n_head + 2 * self.n_head_kv)
        hidden_size = config.n_embd

        linear_cls = FusedDense if config.fused_dense else nn.Linear
        # Same fall-back pattern as for the rotary class above.
        if linear_cls is None:
            linear_cls = nn.Linear

        self.Wqkv = linear_cls(hidden_size, op_size, bias=bias, device=device, dtype=dtype)
        self.out_proj = linear_cls(hidden_size, hidden_size, bias=bias, device=device, dtype=dtype)

        # Attention
        attn_cls = FlashSelfAttention if config.flash_attn else SelfAttention
        if attn_cls is None:
            attn_cls = SelfAttention

        cross_attn_cls = FlashCrossAttention if config.flash_attn else CrossAttention
        if cross_attn_cls is None:
            cross_attn_cls = CrossAttention

        self.inner_attn = attn_cls(
            causal=causal,
            softmax_scale=softmax_scale,
            attention_dropout=config.attn_pdrop,
        )
        self.inner_cross_attn = cross_attn_cls(
            causal=causal,
            softmax_scale=softmax_scale,
            attention_dropout=config.attn_pdrop,
        )

        # The flash code paths are only taken when the flash class was actually selected,
        # not merely requested in the configuration.
        self.flash_attn = config.flash_attn and attn_cls is FlashSelfAttention
        self.layer_idx = layer_idx
        self.return_residual = return_residual
        self.checkpointing = checkpointing

    def _forward_self_attn(
        self, x: torch.FloatTensor, key_padding_mask: Optional[torch.BoolTensor]
    ) -> torch.FloatTensor:
        """Run the self-attention path on a packed ``qkv`` tensor.

        Args:
            x: Input passed through ``Wqkv``.
            key_padding_mask: Optional padding mask forwarded to the inner attention
                (or used to unpad/pad the input on the flash path).

        Returns:
            The inner attention output, still split into heads.

        """
        qkv = self.Wqkv(x)
        # Split the last dimension into (3, heads, head_dim); this layout requires the
        # query and key/value head counts to be equal.
        qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim)

        if self.rotary_dim > 0:
            qkv = self.rotary_emb(qkv)

        if self.flash_attn:
            batch_size, seqlen = qkv.shape[0], qkv.shape[1]

            cu_seqlens, max_seqlen = None, None
            if key_padding_mask is not None:
                # If `key_padding_mask` is supplied, we need to unpad the input and retrieve
                # the `cu_seqlens` and `max_seqlen` to be used by `flash-attn`
                qkv, indices, cu_seqlens, max_seqlen = unpad_input(qkv, key_padding_mask)

            if self.checkpointing:
                attn_output = torch.utils.checkpoint.checkpoint(
                    self.inner_attn, qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen
                )
            else:
                attn_output = self.inner_attn(qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen).to(qkv.device)

            # If `key_padding_mask` is supplied, we need to pad the output back to the original shape
            return pad_input(attn_output, indices, batch_size, seqlen) if key_padding_mask is not None else attn_output

        if self.checkpointing:
            return torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, key_padding_mask=key_padding_mask)

        return self.inner_attn(qkv, key_padding_mask=key_padding_mask)

    def _forward_cross_attn(
        self,
        x: torch.FloatTensor,
        past_key_values: Optional[InferenceParams],
        key_padding_mask: Optional[torch.BoolTensor],
    ) -> torch.FloatTensor:
        """Run the cross-attention path with separate ``q`` and ``kv`` tensors.

        Args:
            x: Input passed through ``Wqkv``.
            past_key_values: Optional cache; when given, the new keys/values are
                written into it and the cached keys/values are attended to.
            key_padding_mask: Optional padding mask forwarded to the inner attention
                (or used to unpad/pad the inputs on the flash path).

        Returns:
            The inner attention output, still split into heads.

        """
        batch_size = x.shape[0]

        qkv = self.Wqkv(x)

        # The first `n_head * head_dim` features are the queries ...
        q = qkv[..., : self.n_head * self.head_dim]
        q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim)

        # ... and the remainder holds the keys and values.
        kv = qkv[..., self.n_head * self.head_dim:]
        kv = rearrange(kv, "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim)

        seqlen_offset = past_key_values.seqlen_offset if past_key_values is not None else 0
        # `None` leaves the decision to the inner attention's configured default (the PyTorch
        # attention classes in this file treat `None` that way); once an offset exists,
        # causal masking is switched off explicitly.
        causal = None if seqlen_offset == 0 else False
        if self.rotary_dim > 0:
            q, kv = self.rotary_emb(q, kv=kv, seqlen_offset=seqlen_offset)

        if past_key_values is not None:
            kv = _update_kv_cache(kv, past_key_values, self.layer_idx)

        if self.flash_attn:
            batch_size, seqlen_q = q.shape[0], q.shape[1]
            seqlen_k = kv.shape[1]

            cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k = (
                None,
                None,
                None,
                None,
            )
            if key_padding_mask is not None:
                kv, _, cu_seqlens_k, max_seqlen_k = unpad_input(kv, key_padding_mask)

                # Build the mask used to unpad `q`: a single query token is kept as-is,
                # otherwise the mask is cut to the last `seqlen_q` positions.
                if seqlen_q == 1:
                    key_padding_mask = torch.ones(batch_size, 1, device=q.device)
                elif seqlen_q != seqlen_k:
                    key_padding_mask = key_padding_mask[:, -seqlen_q:]

                q, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, key_padding_mask)

            if self.checkpointing:
                attn_output = torch.utils.checkpoint.checkpoint(
                    self.inner_cross_attn,
                    q,
                    kv,
                    causal=causal,
                    cu_seqlens=cu_seqlens_q,
                    max_seqlen=max_seqlen_q,
                    cu_seqlens_k=cu_seqlens_k,
                    max_seqlen_k=max_seqlen_k,
                )
            else:
                attn_output = self.inner_cross_attn(
                    q,
                    kv,
                    causal=causal,
                    cu_seqlens=cu_seqlens_q,
                    max_seqlen=max_seqlen_q,
                    cu_seqlens_k=cu_seqlens_k,
                    max_seqlen_k=max_seqlen_k,
                )

            # Pad the output back only when the inputs were unpadded above.
            return (
                pad_input(attn_output, indices_q, batch_size, max_seqlen_q)
                if key_padding_mask is not None
                else attn_output
            )

        if self.checkpointing:
            return torch.utils.checkpoint.checkpoint(
                self.inner_cross_attn,
                q,
                kv,
                key_padding_mask=key_padding_mask,
                causal=causal,
            )

        return self.inner_cross_attn(q, kv, key_padding_mask=key_padding_mask, causal=causal)

    def forward(
        self,
        x: torch.FloatTensor,
        past_key_values: Optional[InferenceParams] = None,
        attention_mask: Optional[Union[torch.LongTensor, torch.BoolTensor]] = None,
        **kwargs,
    ) -> Union[torch.FloatTensor, Tuple[torch.FloatTensor, torch.FloatTensor]]:
        """Compute attention over ``x`` and project the result.

        Args:
            x: Input hidden states.
            past_key_values: Optional key/value cache.
            attention_mask: Optional mask; converted to boolean and used as the
                key padding mask.
            **kwargs: Accepted and ignored.

        Returns:
            The projected attention output, or ``(output, x)`` when
            ``return_residual`` is set.

        """
        if attention_mask is not None:
            attention_mask = attention_mask.bool()
        else:
            attention_mask = None

        # MHA
        if self.n_head == self.n_head_kv:
            if past_key_values is None:
                # If `past_key_values` are not supplied, we run self-attention
                attn_output = self._forward_self_attn(x, attention_mask)
            else:
                # If `past_key_values` are supplied, it means that we might have cached values and
                # could take advantage of cross-attention
                attn_output = self._forward_cross_attn(x, past_key_values, attention_mask)
        # MQA / GQA
        else:
            # Regardless of `past_key_values` being supplied or not, it always use cross-attention
            # because `q` and `kv` lengths might be different
            attn_output = self._forward_cross_attn(x, past_key_values, attention_mask)

        # Merge the head and head_dim axes back into a single feature axis.
        output = rearrange(attn_output, "... h d -> ... (h d)")
        output = self.out_proj(output)

        return output if not self.return_residual else (output, x)
738 |
739 |
class ParallelBlock(nn.Module):
    """Transformer block whose mixer and MLP run side by side.

    Both branches consume the same layer-normalised input and their outputs are
    summed together with the residual (the layout used in GPT-J and CodeGen).

    """

    def __init__(
        self,
        config: PretrainedConfig,
        block_idx: Optional[int] = None,
    ) -> None:
        """Create the layer norm, dropout, attention mixer and MLP.

        Args:
            config: Model configuration; reads ``n_embd``, ``layer_norm_epsilon``
                and ``resid_pdrop`` here and is forwarded to ``MHA`` and ``MLP``.
            block_idx: Position of this block, forwarded to ``MHA`` as ``layer_idx``.

        """
        super().__init__()

        self.block_idx = block_idx

        self.ln = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        self.resid_dropout = nn.Dropout(config.resid_pdrop)

        self.mixer = MHA(config, layer_idx=block_idx)
        self.mlp = MLP(config)

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        past_key_values: Optional[Union[torch.FloatTensor, InferenceParams]] = None,
        attention_mask: Optional[torch.BoolTensor] = None,
        **kwargs,
    ) -> torch.FloatTensor:
        """Apply the block to ``hidden_states``.

        Args:
            hidden_states: Block input; also used as the residual.
            past_key_values: Forwarded to the mixer.
            attention_mask: Forwarded to the mixer.
            **kwargs: Accepted and ignored.

        Returns:
            ``dropout(mixer(ln(x))) + dropout(mlp(ln(x))) + x``.

        """
        normed = self.ln(hidden_states)

        mixer_out = self.mixer(
            normed,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
        )
        # The mixer may hand back a tuple; only its first element is the attention output.
        if isinstance(mixer_out, tuple):
            mixer_out = mixer_out[0]

        mixer_out = self.resid_dropout(mixer_out)
        mlp_out = self.resid_dropout(self.mlp(normed))

        return mixer_out + mlp_out + hidden_states
785 |
786 |
class CausalLMHead(nn.Module):
    """Language-modeling head: layer norm followed by a projection onto the vocabulary.

    Reference:
        Improving Language Understanding by Generative Pre-Training.
        https://cdn.openai.com/research-covers/language-unsupervised/language_understanding_paper.pdf.

    """

    def __init__(self, config: PretrainedConfig) -> None:
        """Create the layer norm and the ``n_embd -> vocab_size`` linear layer.

        Args:
            config: Reads ``n_embd``, ``layer_norm_epsilon`` and ``vocab_size``.

        """
        super().__init__()

        embed_dim = config.n_embd
        self.ln = nn.LayerNorm(embed_dim, eps=config.layer_norm_epsilon)
        self.linear = nn.Linear(embed_dim, config.vocab_size)

    def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
        """Return float32 vocabulary logits for ``hidden_states``."""
        normed = self.ln(hidden_states)

        return self.linear(normed).to(torch.float32)
807 |
808 |
class CausalLMLoss(nn.Module):
    """Cross-entropy loss for next-token prediction.

    Reference:
        Improving Language Understanding by Generative Pre-Training.
        https://cdn.openai.com/research-covers/language-unsupervised/language_understanding_paper.pdf.

    """

    def __init__(self, shift_labels: bool = True) -> None:
        """Create the loss.

        Args:
            shift_labels: When true, position ``t`` of the logits is scored against
                label ``t + 1``.

        """
        super().__init__()

        self.loss_fct = nn.CrossEntropyLoss()
        self.shift_labels = shift_labels

    def forward(self, logits: torch.FloatTensor, labels: torch.LongTensor) -> torch.FloatTensor:
        """Return the cross-entropy between ``logits`` and ``labels``.

        Args:
            logits: Scores whose last dimension is the vocabulary.
            labels: Target token ids.

        Returns:
            The loss produced by ``nn.CrossEntropyLoss`` over the flattened inputs.

        """
        if self.shift_labels:
            # Drop the last prediction and the first label so they line up.
            logits, labels = logits[..., :-1, :].contiguous(), labels[..., 1:].contiguous()

        vocab_size = logits.size(-1)
        flat_logits = logits.view(-1, vocab_size)
        flat_labels = labels.view(-1)

        return self.loss_fct(flat_logits, flat_labels)
832 |
833 |
class PhiPreTrainedModel(PreTrainedModel):
    """Common base of the Phi models: weight initialisation and generation inputs."""

    config_class = PhiConfig
    base_model_prefix = "transformer"
    supports_gradient_checkpointing = False
    _no_split_modules = ["ParallelBlock"]

    def __init__(self, *inputs, **kwargs) -> None:
        """Forward all arguments to the parent constructor."""
        super().__init__(*inputs, **kwargs)

    def _init_weights(self, module: nn.Module) -> None:
        """Initialise ``module`` in place according to its type.

        Linear and embedding weights are drawn from a normal distribution with
        standard deviation ``config.initializer_range``; biases and the embedding
        padding row are zeroed; layer-norm weights are set to one. Other module
        types are left untouched.

        """
        if isinstance(module, (nn.Linear,)):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
            return

        if isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
            return

        if isinstance(module, nn.LayerNorm):
            if module.bias is not None:
                module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor,
        past_key_values: Optional[Union[torch.FloatTensor, InferenceParams]] = None,
        attention_mask: Optional[Union[torch.LongTensor, torch.BoolTensor]] = None,
        **kwargs,
    ) -> Dict[str, Any]:
        """Assemble the keyword arguments for one generation step.

        Args:
            input_ids: Token ids generated so far.
            past_key_values: Cache from a previous step. Anything that is not an
                ``InferenceParams`` instance is replaced by a fresh, empty cache.
            attention_mask: Passed through unchanged.
            **kwargs: Accepted and ignored.

        Returns:
            A dict with the keys ``input_ids``, ``past_key_values`` and
            ``attention_mask``.

        """
        if isinstance(past_key_values, InferenceParams):
            # Assume that `past_key_values` has cached all tokens up to the last token in `input_ids`
            past_key_values.seqlen_offset = input_ids.shape[1] - 1
            input_ids = input_ids[:, -1].unsqueeze(-1)
        else:
            # First step (or unusable cache): start with an empty cache and keep the full prompt.
            past_key_values = InferenceParams(
                max_seqlen=self.config.n_positions,
                max_batch_size=input_ids.shape[0],
                seqlen_offset=0,
                batch_size_offset=0,
                key_value_memory_dict={},
                lengths_per_sample=None,
            )

        model_inputs = {
            "input_ids": input_ids,
            "past_key_values": past_key_values,
            "attention_mask": attention_mask,
        }

        return model_inputs
885 |
886 |
class PhiModel(PhiPreTrainedModel):
    """Phi backbone: token embedding followed by a stack of ``ParallelBlock`` layers."""

    _keys_to_ignore_on_load_missing = [""]
    _keys_to_ignore_on_load_unexpected = [r"h\.\d+\.mlp.(fc_in|fc_out)\.(weight|bias)"]

    def __init__(self, config: PhiConfig) -> None:
        """Build the embedding and ``config.n_layer`` blocks, then run ``post_init``."""
        super().__init__(config)

        self.embd = Embedding(config)
        blocks = [ParallelBlock(config, block_idx=idx) for idx in range(config.n_layer)]
        self.h = nn.ModuleList(blocks)
        self.gradient_checkpointing = False
        self.post_init()

    def get_input_embeddings(self) -> nn.Embedding:
        """Return the token embedding stored at ``embd.wte``."""
        return self.embd.wte

    def set_input_embeddings(self, new_embeddings: nn.Embedding) -> None:
        """Replace the token embedding stored at ``embd.wte``."""
        self.embd.wte = new_embeddings

    def forward(
        self,
        input_ids: torch.LongTensor,
        past_key_values: Optional[Union[torch.FloatTensor, InferenceParams]] = None,
        attention_mask: Optional[torch.BoolTensor] = None,
    ) -> torch.FloatTensor:
        """Embed ``input_ids`` and pass the result through every block in order.

        Args:
            input_ids: Token ids handed to the embedding.
            past_key_values: Forwarded to every block.
            attention_mask: Forwarded to every block.

        Returns:
            The hidden states produced by the last block.

        """
        states = self.embd(input_ids)

        for block in self.h:
            states = block(
                states,
                past_key_values=past_key_values,
                attention_mask=attention_mask,
            )

        return states
923 |
924 |
class PhiForCausalLM(PhiPreTrainedModel):
    """Phi backbone with a language-modeling head and an optional training loss."""

    _keys_to_ignore_on_load_missing = [""]
    _keys_to_ignore_on_load_unexpected = [r"transformer\.h\.\d+\.mlp.(fc_in|fc_out)\.(weight|bias)"]

    def __init__(self, config: PhiConfig) -> None:
        """Build the backbone, the head and the loss, then run ``post_init``."""
        super().__init__(config)

        self.transformer = PhiModel(config)
        self.lm_head = CausalLMHead(config)
        self.loss = CausalLMLoss()

        self.post_init()

    def get_output_embeddings(self) -> nn.Linear:
        """Return the head's output projection (``lm_head.linear``)."""
        return self.lm_head.linear

    def set_output_embeddings(self, new_embeddings: nn.Linear) -> None:
        """Replace the head's output projection (``lm_head.linear``)."""
        self.lm_head.linear = new_embeddings

    def forward(
        self,
        input_ids: torch.LongTensor,
        past_key_values: Optional[Union[torch.FloatTensor, InferenceParams]] = None,
        attention_mask: Optional[torch.BoolTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> CausalLMOutputWithPast:
        """Compute logits for ``input_ids`` and, when ``labels`` are given, the loss.

        Args:
            input_ids: Token ids handed to the backbone.
            past_key_values: Forwarded to the backbone and echoed in the output.
            attention_mask: Forwarded to the backbone.
            labels: Optional targets; the loss is ``None`` without them.
            **kwargs: Accepted and ignored.

        Returns:
            A ``CausalLMOutputWithPast`` carrying ``loss``, ``logits`` and
            ``past_key_values``.

        """
        features = self.transformer(input_ids, past_key_values=past_key_values, attention_mask=attention_mask)
        lm_logits = self.lm_head(features)

        loss = self.loss(lm_logits, labels) if labels is not None else None

        return CausalLMOutputWithPast(loss=loss, logits=lm_logits, past_key_values=past_key_values)
962 |
--------------------------------------------------------------------------------