├── .gitignore ├── LICENSE ├── README.md ├── configs ├── amazon_anhp.yaml ├── amazon_anhp_ebm_time.yaml ├── amazon_anhp_ebm_type.yaml ├── ke_anhp_gdelt.yaml ├── ke_anhp_gdelt_ebm_obj.yaml └── ke_anhp_gdelt_ebm_rel.yaml ├── eval.py ├── eval_tpp.py ├── exp_res ├── LAMP-neurips-2023-numerical-results - Amazon.csv ├── LAMP-neurips-2023-numerical-results - GDELT.csv ├── LAMP-neurips-2023-numerical-results - ICEWS.csv └── qualitative_res │ ├── gdelt_obj_pred_10.txt │ ├── gdelt_obj_pred_100.txt │ └── gdelt_predicate_pred_10.txt ├── main_basemodel.py ├── main_ebm.py ├── main_ebm_anhp.py ├── main_train_tpp.py ├── models ├── __init__.py ├── ebm │ ├── __init__.py │ ├── att_nhp_ebm.py │ ├── att_nhp_ebm_amazon.py │ ├── nhp_ebm.py │ └── rnn.py ├── ebm_runner.py ├── ke_anhp │ ├── __init__.py │ ├── _model.py │ └── _modules.py ├── kg_runner.py ├── modules.py ├── tpp │ ├── __init__.py │ ├── att_nhp.py │ ├── base_model.py │ ├── layers.py │ ├── nhp.py │ └── thinning.py └── tpp_runner.py ├── preprocess ├── __init__.py ├── data_loader_tpp.py ├── datasets │ ├── __init__.py │ ├── ebm_dataset.py │ ├── kg_dataset.py │ ├── tpp_dataset.py │ └── tpp_emb_dataset.py ├── kg_data_factory.py └── tpp_data_factory.py ├── requirements.txt ├── scripts ├── amazon │ ├── amazon_prompts │ │ ├── help_msg_1.txt │ │ ├── prompt_assistant_1.txt │ │ ├── prompt_assistant_10.txt │ │ ├── prompt_assistant_11.txt │ │ ├── prompt_assistant_12.txt │ │ ├── prompt_assistant_13.txt │ │ ├── prompt_assistant_14.txt │ │ ├── prompt_assistant_15.txt │ │ ├── prompt_assistant_16.txt │ │ ├── prompt_assistant_2.txt │ │ ├── prompt_assistant_3.txt │ │ ├── prompt_assistant_4.txt │ │ ├── prompt_assistant_5.txt │ │ ├── prompt_assistant_6.txt │ │ ├── prompt_assistant_7.txt │ │ ├── prompt_assistant_8.txt │ │ ├── prompt_assistant_9.txt │ │ ├── prompt_user_1.txt │ │ ├── prompt_user_10.txt │ │ ├── prompt_user_11.txt │ │ ├── prompt_user_12.txt │ │ ├── prompt_user_13.txt │ │ ├── prompt_user_14.txt │ │ ├── prompt_user_15.txt │ │ ├── prompt_user_16.txt │ │ ├── prompt_user_2.txt │ │ ├── prompt_user_3.txt │ │ ├── prompt_user_4.txt │ │ ├── prompt_user_5.txt │ │ ├── prompt_user_6.txt │ │ ├── prompt_user_7.txt │ │ ├── prompt_user_8.txt │ │ └── prompt_user_9.txt │ ├── chatgpt_script.txt │ ├── check_script.txt │ ├── step_1_extract_data.py │ ├── step_2_clean_data.py │ ├── step_3_make_model_input.py │ ├── step_4_query_chatgpt.py │ └── step_5_make_emb_dataset.py └── gdelt │ ├── gdelt_16shot.txt │ ├── gdelt_prompts_simplified │ ├── help_msg_1.txt │ ├── help_msg_2.txt │ ├── prompt_assistant_1.txt │ ├── prompt_assistant_10.txt │ ├── prompt_assistant_11.txt │ ├── prompt_assistant_12.txt │ ├── prompt_assistant_13.txt │ ├── prompt_assistant_14.txt │ ├── prompt_assistant_15.txt │ ├── prompt_assistant_16.txt │ ├── prompt_assistant_2.txt │ ├── prompt_assistant_3.txt │ ├── prompt_assistant_4.txt │ ├── prompt_assistant_5.txt │ ├── prompt_assistant_6.txt │ ├── prompt_assistant_7.txt │ ├── prompt_assistant_8.txt │ ├── prompt_assistant_9.txt │ ├── prompt_user_1.txt │ ├── prompt_user_10.txt │ ├── prompt_user_11.txt │ ├── prompt_user_12.txt │ ├── prompt_user_13.txt │ ├── prompt_user_14.txt │ ├── prompt_user_15.txt │ ├── prompt_user_16.txt │ ├── prompt_user_2.txt │ ├── prompt_user_3.txt │ ├── prompt_user_4.txt │ ├── prompt_user_5.txt │ ├── prompt_user_6.txt │ ├── prompt_user_7.txt │ ├── prompt_user_8.txt │ ├── prompt_user_9.txt │ └── system_msg_1.txt │ ├── step_1_news_scraper.py │ ├── step_2_clean_scraped_news.py │ ├── step_3_5_analyze_gdelt_entity.py │ ├── step_3_make_model_input.py │ ├── 
step_4_query_chatgpt.py │ ├── step_5_make_emb_dataset.py │ └── step_6_recompute_distance.py └── utils ├── __init__.py ├── bert_utils.py ├── general.py ├── gpt_utils.py ├── log_utils.py ├── metrics.py ├── misc.py ├── misc_amazon.py ├── persist_utils.py └── torch_utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 | 
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 | 
119 | # SageMath parsed files
120 | *.sage.py
121 | 
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 | 
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 | 
135 | # Rope project settings
136 | .ropeproject
137 | 
138 | # mkdocs documentation
139 | /site
140 | 
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 | 
146 | # Pyre type checker
147 | .pyre/
148 | 
149 | # pytype static type analyzer
150 | .pytype/
151 | 
152 | # Cython debug symbols
153 | cython_debug/
154 | 
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | .idea/
161 | .DS_Store
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # LAMP: LAnguage Model in event Prediction
2 | 
3 | 
4 | PyTorch code for the paper [Language Models Can Improve Event Prediction by Few-Shot Abductive Reasoning](https://arxiv.org/abs/2305.16646), NeurIPS'23.
5 | 
6 | 
7 | ## How to Run
8 | 
9 | ### Environment Requirements
10 | 
11 | First, please make sure you have an environment compatible with the following requirements:
12 | 
13 | ```bash
14 | torch == 1.9.0
15 | numpy
16 | pandas
17 | ```
18 | 
19 | Lower versions of PyTorch should also work, but we have not tested them.
20 | 
21 | 
22 | 
23 | ### Data Preparation
24 | 
25 | You can obtain the benchmark datasets from [Google Drive](https://drive.google.com/file/d/1XbPiPTNVprKaQwMvk9McaY2USJrEkUj6/view?usp=share_link). All the datasets are pre-processed and can be used directly.
26 | 
27 | **Please unzip the files and put the gdelt folder under the `./data/` directory**, so the directory becomes `./data/gdelt/*.pkl`.
28 | 
29 | 
30 | ### Training and Evaluation Example
31 | 
32 | Assume we are running the task of predicate (relation) prediction on the GDELT data and have set up the config files.
33 | 
34 | 
35 | Step 1: we train the chosen TPP model, ANHP, with the config `configs/ke_anhp_gdelt.yaml`:
36 | 
37 | ```
38 | python main_basemodel.py
39 | ```
40 | NOTE: in `configs/ke_anhp_gdelt.yaml`, one needs to set up the data and model specs; we have put default params there.
41 | 
42 | After the training is finished, the prediction results of the base model will be saved as `logs/ke_anhp_gdelt_test.pkl`.
43 | 
44 | 
45 | Step 2: we query ChatGPT to generate the causal events based on the prediction results from the previous step:
46 | 
47 | ```
48 | cd scripts/gdelt
49 | python step_4_query_chatgpt.py
50 | ```
51 | NOTE: in `scripts/gdelt/step_4_query_chatgpt.py`, one needs to set up a personal OpenAI account to query GPT, along with the params for the predicate (relation) prediction task.
52 | 
53 | After the query is finished, a JSON file `relation.json` will be generated at `scripts/gdelt/ddb_storage/gdelt_chatgpt`.
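For reference, the snippet below is a minimal sketch of the kind of few-shot chat-completion call that `step_4_query_chatgpt.py` wraps; it is not the repository's actual code. It assumes the pre-1.0 `openai` Python package, that it is run from `scripts/gdelt`, and that all 16 paired user/assistant exemplar files are sent with each query; the model name and the final query string are placeholders, and the real prompt assembly and result caching live in the script itself.

```python
# Hypothetical sketch of the ChatGPT query; see scripts/gdelt/step_4_query_chatgpt.py for the real logic.
import os

import openai  # assumes the pre-1.0 openai package

openai.api_key = os.environ["OPENAI_API_KEY"]  # personal OpenAI credential


def read_prompt(path):
    # the few-shot prompt files ship with the repo under scripts/gdelt/gdelt_prompts_simplified
    with open(path, encoding="utf-8") as f:
        return f.read()


prompt_dir = "gdelt_prompts_simplified"
messages = [{"role": "system", "content": read_prompt(os.path.join(prompt_dir, "system_msg_1.txt"))}]
for i in range(1, 17):  # 16 few-shot exemplars: paired user/assistant turns
    messages.append({"role": "user", "content": read_prompt(os.path.join(prompt_dir, f"prompt_user_{i}.txt"))})
    messages.append({"role": "assistant", "content": read_prompt(os.path.join(prompt_dir, f"prompt_assistant_{i}.txt"))})

# the final user turn is built from the base model's predictions saved in logs/ke_anhp_gdelt_test.pkl
messages.append({"role": "user", "content": "<query built from the top-k predicted relations>"})

response = openai.ChatCompletion.create(model="gpt-3.5-turbo", messages=messages)
print(response["choices"][0]["message"]["content"])
```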
54 | 
55 | 
56 | Step 3: we set up the samples to train the ranking model:
57 | 
58 | ```
59 | cd scripts/gdelt
60 | python step_5_make_emb_dataset.py
61 | ```
62 | After the generation is finished, a JSON file `relation.json` will be generated at `scripts/gdelt/ddb_storage/ke_anhp_gdelt_bert_ebm_dataset`.
63 | 
64 | 
65 | Step 4: we train the ranking model with the config `configs/ke_anhp_gdelt_ebm_rel.yaml` and evaluate it on the test set:
66 | 
67 | ```
68 | python main_ebm.py
69 | ```
70 | 
71 | For other tasks, one can simply modify the params in the previous steps.
72 | 
73 | 
74 | ## Reference
75 | 
76 | If you use this code as part of any published research, please acknowledge the following paper:
77 | ```
78 | @inproceedings{shi2023language,
79 |   title={Language Models Can Improve Event Prediction by Few-Shot Abductive Reasoning},
80 |   author={Xiaoming Shi and Siqiao Xue and Kangrui Wang and Fan Zhou and James Y. Zhang and Jun Zhou and Chenhao Tan and Hongyuan Mei},
81 |   year={2023},
82 |   booktitle = {Advances in Neural Information Processing Systems},
83 |   url={https://arxiv.org/abs/2305.16646}
84 | }
85 | ```
86 | 
--------------------------------------------------------------------------------
/configs/amazon_anhp.yaml:
--------------------------------------------------------------------------------
1 | data:
2 |   type: tpp
3 |   data_name: amazon
4 |   data_dir: data/amazon/amazon_v0327.pkl
5 |   num_event_types: 24
6 |   train_end_date: 2015-06-01
7 |   valid_end_date: 2016-01-01
8 | 
9 | model:
10 |   name: anhp
11 |   hidden_size: 32
12 |   time_emb_size: 32
13 |   num_event_types_no_pad: 24
14 |   num_event_types_pad: 25
15 |   event_pad_index: 24
16 |   num_heads: 1
17 |   num_layers: 2
18 |   dropout: 0.1
19 |   sharing_param_layer: False
20 |   use_ln: False
21 |   mc_num_sample_per_step: 100
22 |   train:
23 |     batch_size: 128
24 |     num_epoch: 2
25 |   thinning_params:
26 |     num_seqs: 100
27 |     num_samples: 1
28 |     num_exp: 500  # number of i.i.d.
Exp(intensity_bound) draws at one time in thinning algorithm 29 | look_ahead_time: 10 30 | patience: 5 # the maximum iteration used in adaptive thinning 31 | -------------------------------------------------------------------------------- /configs/amazon_anhp_ebm_time.yaml: -------------------------------------------------------------------------------- 1 | base_model: anhp 2 | dataset: amazon 3 | pred_type: dtime 4 | 5 | model: 6 | name: anhp 7 | hidden_size: 64 8 | time_emb_size: 8 9 | num_event_types_no_pad: 24 10 | num_event_types_pad: 25 11 | event_pad_index: 24 12 | num_heads: 4 13 | num_layers: 1 14 | dropout: 0.3 15 | sharing_param_layer: False 16 | use_ln: False 17 | mc_num_sample_per_step: 20 18 | train: 19 | num_epochs: 50 20 | lr: 0.0007 21 | batch_size: 8 22 | loss_function: bce 23 | thinning_params: 24 | num_seqs: 100 25 | num_samples: 1 26 | num_exp: 500 27 | look_ahead_time: 10 28 | patience: 5 -------------------------------------------------------------------------------- /configs/amazon_anhp_ebm_type.yaml: -------------------------------------------------------------------------------- 1 | base_model: anhp 2 | dataset: amazon 3 | pred_type: type 4 | 5 | model: 6 | name: anhp 7 | hidden_size: 64 8 | time_emb_size: 8 9 | num_event_types_no_pad: 24 10 | num_event_types_pad: 25 11 | event_pad_index: 24 12 | num_heads: 4 13 | num_layers: 1 14 | dropout: 0.3 15 | sharing_param_layer: False 16 | use_ln: False 17 | mc_num_sample_per_step: 20 18 | train: 19 | num_epochs: 50 20 | lr: 0.0007 21 | batch_size: 8 22 | loss_function: bce 23 | thinning_params: 24 | num_seqs: 100 25 | num_samples: 1 26 | num_exp: 500 27 | look_ahead_time: 10 28 | patience: 5 -------------------------------------------------------------------------------- /configs/ke_anhp_gdelt.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | type: kg 3 | data_dir: data/gdelt/gdelt.pkl 4 | num_relations: 20 5 | num_entities: 2279 6 | context_length: 100 7 | train_end_index_ratio: 0.8 8 | valid_end_index_ratio: 0.9 9 | test_end_index_ratio: 1.0 10 | time_factor: 100.0 11 | 12 | model: 13 | name: ke_anhp 14 | storage_uri: logs/ke_anhp_gdelt.pt 15 | 16 | dim_c: 4 17 | dim_d: 8 18 | dim_l: 8 19 | num_layers: 1 20 | n_heads: 1 21 | dropout_rate: 0.1 22 | 23 | num_epochs: 5 24 | batch_size: 128 25 | lr: 0.001 26 | -------------------------------------------------------------------------------- /configs/ke_anhp_gdelt_ebm_obj.yaml: -------------------------------------------------------------------------------- 1 | base_model: ke_anhp 2 | dataset: gdelt 3 | pred_type: object 4 | 5 | model: 6 | name: anhp 7 | embedding_dim: 40 8 | d_time: 16 9 | d_model: 50 10 | num_heads: 2 11 | num_layers: 3 12 | dropout: 0.083274 13 | 14 | time_emb_size: 24 15 | num_event_types_no_pad: 24 16 | num_event_types_pad: 25 17 | event_pad_index: 24 18 | sharing_param_layer: False 19 | use_ln: False 20 | mc_num_sample_per_step: 20 21 | train: 22 | num_epochs: 50 23 | lr: 0.00066 24 | batch_size: 8 25 | loss_function: mnce 26 | thinning_params: 27 | num_seqs: 100 28 | num_samples: 1 29 | num_exp: 500 30 | look_ahead_time: 10 31 | patience: 5 -------------------------------------------------------------------------------- /configs/ke_anhp_gdelt_ebm_rel.yaml: -------------------------------------------------------------------------------- 1 | base_model: ke_anhp 2 | dataset: gdelt 3 | pred_type: relation 4 | 5 | model: 6 | name: anhp 7 | embedding_dim: 40 8 | d_time: 16 9 | d_model: 50 10 | num_heads: 2 
11 | num_layers: 3 12 | dropout: 0.083274 13 | 14 | time_emb_size: 24 15 | num_event_types_no_pad: 24 16 | num_event_types_pad: 25 17 | event_pad_index: 24 18 | sharing_param_layer: False 19 | use_ln: False 20 | mc_num_sample_per_step: 20 21 | train: 22 | num_epochs: 50 23 | lr: 0.00066 24 | batch_size: 8 25 | loss_function: mnce 26 | thinning_params: 27 | num_seqs: 100 28 | num_samples: 1 29 | num_exp: 500 30 | look_ahead_time: 10 31 | patience: 5 -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | from utils.general import file_uri_reader_processor, list_of_dict_to_dict 4 | from utils.metrics import is_hit, rank, group_pred_data_in_time, get_precision_recall 5 | 6 | 7 | def eval_basemodel_precision_recall(pred_fn, source_fn, rel_topk, obj_topk, num_last_eval_points=4000): 8 | pred_data = file_uri_reader_processor(pred_fn)[-num_last_eval_points:] 9 | source_data = file_uri_reader_processor(source_fn)['data'] 10 | group_pred_by_time = group_pred_data_in_time(pred_data, source_data) 11 | out = get_precision_recall(group_pred_by_time, rel_topk=rel_topk, obj_topk=obj_topk) 12 | print(out) 13 | 14 | 15 | def eval_basemodel(pred_fn, num_last_eval_points=4000, rel_hit_top_k=3, obj_hit_top_k=10): 16 | pred_data = file_uri_reader_processor(pred_fn)[-num_last_eval_points:] 17 | pred_dict = list_of_dict_to_dict(pred_data) 18 | 19 | rel_pr_topk = 5 20 | object_pr_topk = 20 21 | 22 | rel_ranks = rank(pred_dict['relation'], pred_dict['pred_relation']) 23 | obj_ranks = rank(pred_dict['object'], pred_dict['pred_object']) 24 | 25 | rel_mask = rel_ranks <= rel_pr_topk 26 | obj_mask = obj_ranks <= object_pr_topk 27 | 28 | rel_mean_rank = np.mean(rel_ranks[rel_ranks <= rel_pr_topk]) 29 | obj_mean_rank = np.mean(obj_ranks[obj_ranks <= object_pr_topk]) 30 | relation_hit_ratio = np.mean( 31 | is_hit(pred_dict['relation'][rel_mask], pred_dict['pred_relation'][rel_mask], top_n=rel_hit_top_k)) 32 | object_hit_ratio = np.mean( 33 | is_hit(pred_dict['object'][obj_mask], pred_dict['pred_object'][obj_mask], top_n=obj_hit_top_k)) 34 | print( 35 | f'Relation top{rel_hit_top_k} ratio: {relation_hit_ratio}\n' 36 | f'Relation mean rank: {rel_mean_rank}\n' 37 | f'Object top{obj_hit_top_k} ratio: {object_hit_ratio}\n' 38 | f'Object mean rank: {obj_mean_rank}\n' 39 | ) 40 | 41 | 42 | def eval_topk_rerank(pred_fn, num_last_eval_points=4000, rel_hit_top_k=3, obj_hit_top_k=10): 43 | pred_data = file_uri_reader_processor(pred_fn)[-num_last_eval_points:] 44 | pred_dict = list_of_dict_to_dict(pred_data) 45 | 46 | rel_ranks = rank(pred_dict['relation'], pred_dict['pred_relation']) 47 | obj_ranks = rank(pred_dict['object'], pred_dict['pred_object']) 48 | 49 | rel_mean_rank = np.mean(rel_ranks[rel_ranks <= rel_hit_top_k]) 50 | obj_mean_rank = np.mean(obj_ranks[obj_ranks <= obj_hit_top_k]) 51 | print( 52 | f'Relation mean rank: {rel_mean_rank}\n' 53 | f'Object mean rank: {obj_mean_rank}\n' 54 | ) 55 | 56 | 57 | def eval_combination_rerank(pred_fn, topk=100, num_last_eval_points=4000): 58 | pred_data = file_uri_reader_processor(pred_fn)[-num_last_eval_points:] 59 | 60 | rank_list = [] 61 | hit_count = 0 62 | for point in pred_data: 63 | # to numerical 64 | rel_obj_label = '{0}_{1}'.format(np.sum(point['relation']), np.sum(point['object'])) 65 | pred_rel_obj = { 66 | '{0}_{1}'.format(np.sum(item['pred_relation']), np.sum(item['pred_object'])): np.sum(item['pred_score']) 67 | for item in 
point['pred_rel_obj']} 68 | sorted_pred_rel_obj_list = sorted(pred_rel_obj.items(), key=lambda item: item[1], reverse=True)[:topk] 69 | 70 | for i, (pair, _) in enumerate(sorted_pred_rel_obj_list): 71 | if pair == rel_obj_label: 72 | hit_count += 1 73 | rank_list.append(i + 1) 74 | break 75 | 76 | print(f'Hit count: {hit_count}, total: {len(pred_data)}') 77 | print('Hit ratio:', hit_count / len(pred_data)) 78 | print('Mean rank:', np.mean(rank_list)) 79 | 80 | 81 | if __name__ == '__main__': 82 | source_fn = 'data/gdelt/gdelt.pkl' 83 | ke_anhp_gdelt_fn = 'logs/ke_anhp_gdelt_test.pkl' 84 | 85 | eval_basemodel_precision_recall(pred_fn=ke_anhp_gdelt_fn, source_fn=source_fn, rel_topk=5, obj_topk=20) 86 | # eval_topk_rerank(ke_anhp_gdelt_fn, rel_hit_top_k=10, obj_hit_top_k=2) 87 | 88 | -------------------------------------------------------------------------------- /eval_tpp.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | from utils.general import file_uri_reader_processor, get_value_by_key 4 | from utils.metrics import mark, mapk, time_rmse_np, rank 5 | 6 | 7 | def eval_basemodel_precision_recall(pred_fn): 8 | pred_data = file_uri_reader_processor(pred_fn) 9 | 10 | label_dtime = [] 11 | pred_dtime = [] 12 | label_type = [] 13 | pred_type = [] 14 | 15 | 16 | for seq in pred_data: # by default, we do the inference with batch_size = 1 17 | label_type_ = get_value_by_key('label_type', seq) 18 | pred_type_score_ = get_value_by_key('pred_type_score', seq) 19 | label_dtime_ = get_value_by_key('label_dtime', seq) 20 | pred_dtime_ = get_value_by_key('pred_dtime', seq) 21 | 22 | label_dtime.extend(label_dtime_) 23 | pred_dtime.extend(pred_dtime_) 24 | 25 | # each label is a list 26 | label_type.extend([[x] for x in label_type_]) 27 | 28 | # make score in descending order 29 | pred_type_ = [np.argsort(x)[-5:][::-1].tolist() for x in pred_type_score_] 30 | pred_type.extend(pred_type_) 31 | 32 | 33 | # precision - recall 34 | print(f'MAP is {mapk(label_type, pred_type, k=5)}') 35 | print(f'MAR is {mark(label_type, pred_type, k=5)}') 36 | 37 | # compute time rmse 38 | time_rmse = time_rmse_np(pred_dtime, label_dtime) 39 | print(f'Time RMSE {time_rmse}') 40 | return 41 | 42 | 43 | def eval_base_model_mean_rank(pred_fn, target_events): 44 | pred_data = file_uri_reader_processor(pred_fn) 45 | 46 | pred_target_data = [] 47 | pred_type_score = [] 48 | label_type = [] 49 | for event in target_events: 50 | seq_idx, original_idx = eval(event[0]) 51 | pred_event = search_pred_data(pred_data, seq_idx, original_idx) 52 | pred_target_data.append(pred_event) 53 | pred_type_score.append(pred_event['pred_type_score']) 54 | label_type.append(pred_event['label_type']) 55 | 56 | 57 | type_pr_topk = 5 58 | type_ranks = rank(label_type, pred_type_score) 59 | type_mask = type_ranks <= type_pr_topk 60 | 61 | type_mean_ranks = np.mean(type_ranks[type_mask]) 62 | print(type_mean_ranks) 63 | return 64 | 65 | def search_pred_data(pred_data, seq_idx, original_idx): 66 | 67 | for pred_seq in pred_data: 68 | if pred_seq[0]['seq_idx'] == int(seq_idx): 69 | for pred_element in pred_seq: 70 | if pred_element['original_idx'] == int(original_idx): 71 | return pred_element 72 | 73 | print('Error, index not found') 74 | 75 | if __name__ == '__main__': 76 | anhp_amazon_fn = 'logs/tpp_amazon_test.pkl' 77 | 78 | import dictdatabase as DDB 79 | DDB.config.storage_directory = 'scripts/amazon/ddb_storage' 80 | ebm_data = list(DDB.at(f'anhp_amazon_bert_ebm_dataset', 
'type').read().values()) 81 | 82 | eval_base_model_mean_rank(anhp_amazon_fn, ebm_data) -------------------------------------------------------------------------------- /exp_res/LAMP-neurips-2023-numerical-results - Amazon.csv: -------------------------------------------------------------------------------- 1 | Mean Rank on type prediction ,Figure2 c upper,Figure 18,Figure 13,,,,, 2 | M,NHP,ANHP,NHP-G3.0,ANHP-G3.0,NHP-G3.5,ANHP-G3.5,NHP-G3.5-edit,ANHP-G3.5-edit 3 | 2,0.947,0.966,0.950,0.936,0.920,0.932,0.942,0.947 4 | 3,1.264,1.254,1.235,1.245,1.190,1.178,1.218,1.203 5 | 4,1.750,1.780,1.753,1.842,1.619,1.642,1.662,1.697 6 | 5,2.180,2.150,2.200,2.170,2.090,2.060,2.111,2.084 7 | ,,,,,,,, 8 | ,,,,,,,, 9 | ,,,,,,,, 10 | MRR on type prediction ,Figure 9 c upper,,,,,,, 11 | M,NHP,ANHP,NHP-G3.0,ANHP-G3.0,NHP-G3.5,ANHP-G3.5,, 12 | 2,1.042,1.050,0.950,1.094,1.092,1.130,, 13 | 3,0.768,0.760,1.235,1.245,0.852,0.846,, 14 | 4,0.512,0.528,1.753,1.842,0.578,0.566,, 15 | 5,0.458,0.476,2.200,2.170,0.530,0.525,, 16 | ,,,,,,,, 17 | ,,,,,,,, 18 | ,,,,,,,, 19 | RMSE on time prediction,Figure 9 c lower,Figure2 c lower,,,,,, 20 | M,NHP,ANHP,NHP-G3.0,ANHP-G3.0,NHP-G3.5,ANHP-G3.5,, 21 | 2,147.365,143.824,n.a,n.a,133.838,132.824,, 22 | 3,146.100,146.900,n.a,n.a,130.100,131.100,, 23 | 4,147.290,145.290,n.a,n.a,130.292,129.292,, 24 | 5,144.960,146.000,n.a,n.a,130.657,129.657,, 25 | -------------------------------------------------------------------------------- /exp_res/LAMP-neurips-2023-numerical-results - ICEWS.csv: -------------------------------------------------------------------------------- 1 | Mean Rank on predicate prediction ,,,,,,,,, 2 | M,KE,DyRep,ANHP,KE-G3.0,DyRep-G3.0,ANHP-G3.0,KE-G3.5,DyRep-G3.5,ANHP-G3.5 3 | 2,1.497,1.480,1.851,1.457,1.477,1.481,1.455,1.451,1.724 4 | 3,1.635,1.696,2.284,1.621,1.651,1.759,1.622,1.632,2.041 5 | 4,2.507,2.460,2.853,2.414,2.435,2.350,2.334,2.394,2.590 6 | 5,2.825,2.718,3.340,2.758,2.650,2.564,2.695,2.615,3.121 7 | ,,,,,,,,, 8 | ,,,,,,,,, 9 | ,,,,,,,,, 10 | Mean Rank on object prediction ,Figure 2 c upper,,,,,,,, 11 | M,KE,DyRep,ANHP,KE-G3.0,DyRep-G3.0,ANHP-G3.0,KE-G3.5,DyRep-G3.5,ANHP-G3.5 12 | 10,9.809,10.010,9.922,n.a.,n.a.,n.a.,7.294,6.994,6.943 13 | 25,23.541,25.554,24.541,n.a.,n.a.,n.a.,10.837,9.370,9.137 14 | 50,48.145,48.050,48.405,n.a.,n.a.,n.a.,15.700,14.437,14.070 15 | 75,73.118,70.172,71.718,n.a.,n.a.,n.a.,21.085,22.880,20.850 16 | 100,94.610,94.176,93.761,n.a.,n.a.,n.a.,28.730,28.830,26.430 17 | ,,,,,,,,, 18 | ,,,,,,,,, 19 | MRR on object prediction ,Figure 9 b upper,,,,,,,, 20 | M,KE,DyRep,ANHP,KE-G3.0,DyRep-G3.0,ANHP-G3.0,KE-G3.5,DyRep-G3.5,ANHP-G3.5 21 | 10,0.110,0.108,0.103,n.a.,n.a.,n.a.,0.241,0.241,0.240 22 | 25,0.054,0.049,0.044,n.a.,n.a.,n.a.,0.203,0.213,0.226 23 | 50,0.027,0.028,0.031,n.a.,n.a.,n.a.,0.189,0.189,0.199 24 | 75,0.020,0.023,0.029,n.a.,n.a.,n.a.,0.167,0.154,0.170 25 | 100,0.017,0.018,0.018,n.a.,n.a.,n.a.,0.160,0.154,0.163 26 | ,,,,,,,,, 27 | ,,,,,,,,, 28 | Time RMSE,Figure 2 b lower,Figure 9 b lower,,,,,,, 29 | M,KE,DyRep,ANHP,KE-G3.0,DyRep-G3.0,ANHP-G3.0,KE-G3.5,DyRep-G3.5,ANHP-G3.5 30 | 2,17.423,17.223,17.324,n.a.,n.a.,n.a.,16.738,16.703,16.816 31 | 3,17.1,16.99,16.900,n.a.,n.a.,n.a.,16.43,16.55,16.490 32 | 4,17,16.69,16.890,n.a.,n.a.,n.a.,16.289,16.289,16.349 33 | 5,16.9,16.85,16.750,n.a.,n.a.,n.a.,16.305,16.295,16.295 34 | -------------------------------------------------------------------------------- /main_basemodel.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | 3 | from 
models.ke_anhp import KnowEvolveANHP 4 | from models.kg_runner import KGRunner 5 | from models.tpp.att_nhp import AttNHP 6 | from models.tpp_runner import TPPRunner 7 | from preprocess.kg_data_factory import KGDatasetFactory 8 | from preprocess.tpp_data_factory import TPPDatasetFactory 9 | from utils.general import file_uri_writer_processor, setup_seed 10 | 11 | # device = 'cuda:0' 12 | from utils.torch_utils import count_torch_model_params 13 | 14 | device = 'cpu' 15 | 16 | 17 | def build_data_factory_and_runner_from_config_file(config_name): 18 | with open(config_name) as config_file: 19 | config = yaml.safe_load(config_file) 20 | data_config = config['data'] 21 | model_config = config['model'] 22 | 23 | if data_config['type'].lower() == 'kg': 24 | dataset_factory = KGDatasetFactory.build_from_config_dict(data_config) 25 | else: 26 | dataset_factory = TPPDatasetFactory.build_from_config_dict(data_config) 27 | 28 | if model_config['name'].lower() == 'ke_anhp': 29 | model = KnowEvolveANHP( 30 | num_entities=dataset_factory.num_entities, 31 | num_relations=dataset_factory.num_relations, 32 | dim_c=model_config['dim_c'], 33 | dim_l=model_config['dim_l'], 34 | dim_d=model_config['dim_d'], 35 | num_layers=model_config['num_layers'], 36 | n_heads=model_config['n_heads'], 37 | dropout_rate=model_config['dropout_rate'], 38 | ) 39 | runner = KGRunner( 40 | model, 41 | source_data=dataset_factory.data, 42 | lr=model_config.get('lr', 0.001), 43 | num_epochs=model_config.get('num_epochs', 10), 44 | storage_uri=model_config.get('storage_uri'), 45 | device=device 46 | ) 47 | else: 48 | model = AttNHP(model_config) 49 | runner = TPPRunner(model, 50 | lr=model_config.get('lr', 0.001), 51 | num_epochs=model_config.get('num_epochs', 10), 52 | ) 53 | 54 | return dataset_factory, runner 55 | 56 | 57 | if __name__ == '__main__': 58 | setup_seed() 59 | ke_anhp_config_fn = 'configs/ke_anhp_gdelt.yaml' 60 | 61 | dataset_factory, runner = build_data_factory_and_runner_from_config_file(ke_anhp_config_fn) 62 | 63 | runner.train( 64 | train_dl=dataset_factory.get_train_dataloader(), 65 | valid_dl=dataset_factory.get_valid_dataloader(), 66 | verbose=False 67 | ) 68 | 69 | metric, res = runner.evaluate_one_epoch( 70 | dataset_factory.iterate_dataset_with_original_index(dataset_factory.test_dataset), 71 | with_index=True, 72 | warmup_steps=5000, 73 | predict_relation=True, 74 | predict_object=True, 75 | ) 76 | metric, res = runner.evaluate_combination_one_epoch( 77 | dataset_factory.iterate_dataset_with_original_index(dataset_factory.test_dataset), 78 | with_index=True, 79 | warmup_steps=5000, 80 | ) 81 | file_uri_writer_processor(res, 'logs/ke_anhp_gdelt_test.pkl') 82 | -------------------------------------------------------------------------------- /main_ebm.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import yaml 4 | from torch.utils.data import DataLoader 5 | 6 | import dictdatabase as DDB 7 | 8 | from models.ebm.att_nhp_ebm import AttNHPEBM 9 | from models.ebm.rnn import RNNEbm 10 | from models.ebm_runner import EBMRunner 11 | from preprocess.datasets.ebm_dataset import EBMDataset, EBMValidDataset 12 | from utils.general import setup_seed 13 | from utils.metrics import is_hit, rank 14 | from utils.torch_utils import count_torch_model_params 15 | 16 | 17 | def main(dataset_name, model_name, pred_type, model_config, num_test_points: int = 15000, is_load: bool = False): 18 | DDB.config.storage_directory = 'scripts/gdelt/ddb_storage' 19 | ebm_data 
= list(DDB.at(f'{model_name}_{dataset_name}_bert_ebm_dataset', pred_type).read().values()) 20 | 21 | if pred_type == 'relation': 22 | num_noise_samples = 4 23 | top_n = 3 24 | elif pred_type == 'object': 25 | num_noise_samples = 19 26 | top_n = 10 27 | else: 28 | num_noise_samples = 99 29 | top_n = 50 30 | 31 | if dataset_name == 'gdelt': 32 | num_entities = 2279 33 | num_relations = 20 34 | else: 35 | return 36 | 37 | print('Original length', len(ebm_data)) 38 | train_data = ebm_data[:-num_test_points] 39 | test_data = ebm_data[-num_test_points:] 40 | 41 | train_dataset = EBMDataset(train_data, num_noise_samples=num_noise_samples, time_factor=100.0, max_seq_length=10) 42 | # train_dataset = EBMValidDataset(train_data, num_noise_samples=num_noise_samples, time_factor=100.0) # used when ke_anhp_gdelt_comb 43 | test_dataset = EBMValidDataset(test_data, num_noise_samples=num_noise_samples, time_factor=100.0, max_seq_length=10) 44 | print('train', len(train_dataset)) 45 | print('test', len(test_dataset)) 46 | 47 | train_dl = DataLoader( 48 | train_dataset, 49 | batch_size=model_config['train']['batch_size'], 50 | ) 51 | test_dl = DataLoader( 52 | test_dataset, 53 | batch_size=model_config['train']['batch_size'] * 2, 54 | ) 55 | if model_config['name'].lower() == 'rnn': 56 | model = RNNEbm( 57 | num_entities=num_entities, 58 | num_relations=num_relations, 59 | num_noise_samples=num_noise_samples, 60 | embedding_dim=model_config['embedding_dim'], 61 | num_cells=model_config['num_cells'], 62 | num_layers=model_config['num_layers'], 63 | dropout_rate=model_config['dropout_rate'] 64 | ) 65 | else: 66 | model = AttNHPEBM( 67 | num_entities=num_entities, 68 | num_relations=num_relations, 69 | num_noise_samples=num_noise_samples, 70 | embedding_dim=model_config['embedding_dim'], 71 | d_model=model_config['d_model'], 72 | d_time=model_config['d_time'], 73 | num_layers=model_config['num_layers'], 74 | num_heads=model_config['num_heads'], 75 | dropout=model_config['dropout'], 76 | use_ln=False 77 | ) 78 | 79 | runner = EBMRunner( 80 | model, 81 | loss_function=model_config['train']['loss_function'], 82 | lr=float(model_config['train']['lr']), 83 | log_path=f'logs/{model_name}_{dataset_name}_ebm_{pred_type}.pt', 84 | lr_scheduler_params=model_config.get('lr_scheduler') 85 | ) 86 | 87 | if is_load: 88 | runner.load() 89 | else: 90 | runner.train(train_dl, valid_dataloader=test_dl, num_epochs=model_config['train']['num_epochs'], verbose=False) 91 | runner.load() 92 | 93 | _, (label_score, fake_scores) = runner.evaluate_one_epoch(test_dl) 94 | label = np.zeros_like(label_score, dtype=np.int32) 95 | pred = np.concatenate([label_score[:, None], fake_scores], axis=-1) 96 | 97 | hit_ratio = np.mean(is_hit(label, pred, top_n=top_n)) 98 | mean_rank = np.mean(rank(label, pred)) 99 | 100 | print( 101 | f'Hit ratio: {hit_ratio}\n' 102 | f'Mean rank: {mean_rank}' 103 | ) 104 | return mean_rank 105 | 106 | 107 | if __name__ == '__main__': 108 | setup_seed() 109 | ke_gdelt_ebm_config_fn = 'configs/ke_anhp_gdelt_ebm_rel.yaml' 110 | with open(ke_gdelt_ebm_config_fn) as config_file: 111 | config = yaml.safe_load(config_file) 112 | 113 | main( 114 | dataset_name=config['dataset'], 115 | model_name=config['base_model'], 116 | pred_type=config['pred_type'], 117 | model_config=config['model'], 118 | is_load=False 119 | ) 120 | -------------------------------------------------------------------------------- /main_ebm_anhp.py: -------------------------------------------------------------------------------- 1 | import 
dictdatabase as DDB 2 | import numpy as np 3 | import yaml 4 | from torch.utils.data import DataLoader 5 | 6 | from models.ebm.att_nhp_ebm_amazon import AttNHPEBMTPP 7 | from models.ebm_runner import EBMRunner 8 | from preprocess.datasets.tpp_emb_dataset import TPPEBMDataset, TPPEBMValidDataset 9 | from utils.general import setup_seed 10 | from utils.metrics import is_hit, rank 11 | 12 | 13 | def main(dataset_name, model_name, pred_type, model_config, num_test_points: int = 400, is_load: bool = False): 14 | DDB.config.storage_directory = 'scripts/amazon/ddb_storage' 15 | ebm_data = list(DDB.at(f'{model_name}_{dataset_name}_ebm_dataset_time_bak', pred_type).read().values()) 16 | 17 | model = AttNHPEBMTPP(model_config) 18 | 19 | if pred_type == 'type': 20 | num_noise_samples = 4 21 | top_n = 3 22 | num_event_types = 24 23 | elif pred_type == 'dtime': 24 | top_n = 5 25 | num_event_types = 24 26 | num_noise_samples = 3 27 | else: 28 | raise RuntimeError(f'Unknown pred_type {pred_type}') 29 | 30 | print('Original length', len(ebm_data)) 31 | train_data = ebm_data[:-num_test_points] 32 | test_data = ebm_data[-num_test_points:] 33 | 34 | train_dataset = TPPEBMDataset(train_data, num_event_types=num_event_types, num_noise_samples=num_noise_samples) 35 | test_dataset = TPPEBMDataset(test_data, num_event_types=num_event_types, num_noise_samples=num_noise_samples) 36 | print('train', len(train_dataset)) 37 | print('test', len(test_dataset)) 38 | 39 | train_dl = DataLoader( 40 | train_dataset, 41 | batch_size=model_config['train']['batch_size'], 42 | ) 43 | test_dl = DataLoader( 44 | test_dataset, 45 | batch_size=model_config['train']['batch_size'] * 2, 46 | ) 47 | 48 | model_log_path = f'logs/{model_name}_{dataset_name}_ebm_{pred_type}.pt' 49 | runner = EBMRunner( 50 | model, 51 | is_tpp_model=True, 52 | loss_function=model_config['train']['loss_function'], 53 | lr=float(model_config['train']['lr']), 54 | log_path=model_log_path, 55 | lr_scheduler_params=model_config['train'].get('lr_scheduler') 56 | ) 57 | 58 | if is_load: 59 | runner.load() 60 | else: 61 | runner.train(train_dl, valid_dataloader=test_dl, num_epochs=model_config['train']['num_epochs'], verbose=False) 62 | runner.load() 63 | 64 | if pred_type == 'type': 65 | _, (label_score, fake_scores) = runner.evaluate_one_epoch(test_dl) 66 | label = np.zeros_like(label_score, dtype=np.int32) 67 | pred = np.concatenate([label_score[:, None], fake_scores], axis=-1) 68 | 69 | hit_ratio = np.mean(is_hit(label, pred, top_n=top_n)) 70 | metric = np.mean(rank(label, pred)) 71 | 72 | print( 73 | f'Hit ratio: {hit_ratio}\n' 74 | f'Mean rank: {metric}' 75 | ) 76 | else: 77 | metric, _ = runner.evaluate_one_epoch_time(test_dl) 78 | 79 | return metric 80 | 81 | 82 | if __name__ == '__main__': 83 | setup_seed() 84 | amazon_ebm_config_fn = 'configs/amazon_anhp_ebm_type.yaml' 85 | with open(amazon_ebm_config_fn) as config_file: 86 | config = yaml.safe_load(config_file) 87 | 88 | main( 89 | dataset_name=config['dataset'], 90 | model_name=config['base_model'], 91 | pred_type=config['pred_type'], 92 | model_config=config['model'], 93 | is_load=False 94 | ) 95 | -------------------------------------------------------------------------------- /main_train_tpp.py: -------------------------------------------------------------------------------- 1 | 2 | import yaml 3 | 4 | from models.know_evolve import KnowEvolve 5 | from models.tpp.att_nhp import AttNHP 6 | from models.kg_runner import KGRunner 7 | from models.tpp_runner import TPPRunner 8 | from 
preprocess.kg_data_factory import KGDatasetFactory 9 | from preprocess.tpp_data_factory import TPPDatasetFactory 10 | from utils.general import file_uri_writer_processor 11 | 12 | 13 | def build_data_factory_and_runner_from_config_file(config_name): 14 | with open(config_name) as config_file: 15 | config = yaml.safe_load(config_file) 16 | data_config = config['data'] 17 | model_config = config['model'] 18 | 19 | if data_config['type'].lower() == 'kg': 20 | dataset_factory = KGDatasetFactory.build_from_config_dict(data_config) 21 | else: 22 | dataset_factory = TPPDatasetFactory.build_from_config_dict(data_config) 23 | 24 | if model_config['name'].lower() == 'ke': 25 | model = KnowEvolve( 26 | num_entities=dataset_factory.num_entities, 27 | num_relations=dataset_factory.num_relations, 28 | dim_c=model_config['dim_c'], 29 | dim_l=model_config['dim_l'], 30 | dim_d=model_config['dim_d'], 31 | ) 32 | runner = KGRunner( 33 | model, 34 | source_data=dataset_factory.data, 35 | lr=model_config.get('lr', 0.001), 36 | num_epochs=model_config.get('num_epochs', 10), 37 | ) 38 | elif model_config['name'].lower() == 'ke-tpp': 39 | pass 40 | else: 41 | model = AttNHP(model_config) 42 | runner = TPPRunner(model, 43 | lr=model_config.get('lr', 0.001), 44 | num_epochs=model_config.get('num_epochs', 10), 45 | ) 46 | 47 | return dataset_factory, runner, model_config 48 | 49 | 50 | if __name__ == '__main__': 51 | # config_filename = 'configs/ke.yaml' 52 | config_filename = 'configs/amazon_anhp.yaml' # tpp 53 | 54 | dataset_factory, runner, model_config = build_data_factory_and_runner_from_config_file(config_filename) 55 | 56 | runner.train( 57 | train_dl=dataset_factory.get_train_dataloader(batch_size=model_config['train']['batch_size']), 58 | valid_dl=dataset_factory.get_valid_dataloader(), 59 | num_epochs=model_config['train']['num_epoch'] 60 | ) 61 | runner.save() 62 | 63 | metric, res = runner.evaluate_one_epoch( 64 | dataset_factory.get_train_dataloader() 65 | ) 66 | 67 | file_uri_writer_processor(res, 'tpp_amazon_test.pkl') 68 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .tpp.att_nhp import AttNHP 2 | from .tpp.base_model import BaseModel 3 | 4 | __all__ = ['BaseModel', 5 | 'AttNHP'] 6 | -------------------------------------------------------------------------------- /models/ebm/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iLampard/lamp/c48312108ef9b0eb54c9e4581b2af248abc09094/models/ebm/__init__.py -------------------------------------------------------------------------------- /models/ebm/att_nhp_ebm_amazon.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from ..tpp.att_nhp import AttNHP 4 | 5 | 6 | class AttNHPEBMTPP(AttNHP): 7 | def __init__(self, 8 | model_config): 9 | super(AttNHPEBMTPP, self).__init__(model_config) 10 | self.out_ffn_layer = torch.nn.Sequential( 11 | torch.nn.Linear(self.d_model * self.n_head, self.d_model), 12 | torch.nn.GELU(), 13 | torch.nn.Linear(self.d_model, 1) 14 | ) 15 | 16 | def forward( 17 | self, 18 | time_seqs, 19 | time_delta_seqs, 20 | type_seqs, 21 | batch_non_pad_mask, 22 | attention_mask, 23 | type_mask 24 | ): 25 | # shape -> [..., n_seqs, seq_len] 26 | 27 | rnn_seq_output = [] 28 | num_seqs = time_seqs.size(-2) 29 | for i in range(num_seqs): 30 | batch_i = (type_seqs[:, i, :], 
31 | time_seqs[:, i, :], 32 | batch_non_pad_mask[:, i, :], 33 | attention_mask[:, i, :, :]) 34 | # [batch_size, seq_len, hidden_size] 35 | rnn_seq_output_i = super().forward_along_seqs(*batch_i) 36 | 37 | # [batch_size, hidden_size] 38 | rnn_seq_output_i = self.get_logits_at_last_step(rnn_seq_output_i, batch_non_pad_mask[:, i, :]) 39 | 40 | rnn_seq_output.append(rnn_seq_output_i) 41 | 42 | # shape -> [batch_size, n_seqs, hidden_dim] 43 | rnn_logit = torch.stack(rnn_seq_output, dim=1) 44 | 45 | # shape -> [..., n_seqs, 1] 46 | logit = self.out_ffn_layer(rnn_logit) 47 | 48 | # shape -> [..., n_seqs] 49 | logit = logit[..., 0] 50 | 51 | return logit 52 | -------------------------------------------------------------------------------- /models/ebm/nhp_ebm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from models.tpp.nhp import NHP 5 | 6 | 7 | class NHPDisc(NHP): 8 | def __init__(self, model_config): 9 | super(NHPDisc, self).__init__(model_config) 10 | 11 | # prediction for discriminator 12 | self.discriminator_prediction_layer = torch.nn.Sequential( 13 | nn.Linear(self.hidden_dim, self.hidden_dim // 2), 14 | # nn.Linear(self.hidden_dim, self.hidden_dim), 15 | nn.Linear(self.hidden_dim // 2, 1), 16 | ) 17 | 18 | self.discriminator_loss = nn.CrossEntropyLoss(reduction='mean') 19 | 20 | def predict_as_discriminator(self, logits): 21 | logits = self.discriminator_prediction_layer(logits) 22 | # use softmax to scale values to probability 23 | # logits = torch.softmax(logits, dim=-1) 24 | 25 | return logits 26 | 27 | @staticmethod 28 | def get_logits_at_last_step(logits, batch_non_pad_mask, sample_len=None): 29 | # ref: https://medium.com/analytics-vidhya/understanding-indexing-with-pytorch-gather-33717a84ebc4 30 | seq_len = batch_non_pad_mask.sum(dim=1) 31 | select_index = seq_len - 1 if sample_len is None else seq_len - 1 - sample_len 32 | # [batch_size, hidden_dim] 33 | select_index = select_index.unsqueeze(1).repeat(1, logits.size(-1)) 34 | # [batch_size, 1, hidden_dim] 35 | select_index = select_index.unsqueeze(1) 36 | # [batch_size, hidden_dim] 37 | last_logits = torch.gather(logits, dim=1, index=select_index).squeeze(1) 38 | return last_logits 39 | -------------------------------------------------------------------------------- /models/ebm/rnn.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | 4 | 5 | class RNNEbm(torch.nn.Module): 6 | def __init__( 7 | self, 8 | num_entities: int, 9 | num_relations: int, 10 | num_noise_samples: int, 11 | embedding_dim: int, 12 | num_cells: int, 13 | num_layers: int = 2, 14 | dropout_rate: float = 0.1 15 | ): 16 | super().__init__() 17 | self.num_noise_samples = num_noise_samples 18 | 19 | self.entity_embedding_layer = torch.nn.Embedding(num_entities, embedding_dim=embedding_dim) 20 | self.relation_embedding_layer = torch.nn.Embedding(num_relations, embedding_dim=embedding_dim) 21 | 22 | self.in_ffn_layer = torch.nn.Sequential( 23 | torch.nn.Linear(embedding_dim * 3, num_cells), 24 | torch.nn.GELU(), 25 | torch.nn.Linear(num_cells, num_cells) 26 | ) 27 | self.out_ffn_layer = torch.nn.Sequential( 28 | torch.nn.Linear(num_cells, num_cells), 29 | torch.nn.GELU(), 30 | torch.nn.Linear(num_cells, 1) 31 | ) 32 | 33 | self.rnn_layer = RNN( 34 | feat_dim=num_cells, 35 | num_cells=num_cells, 36 | num_layers=num_layers, 37 | dropout_rate=dropout_rate, 38 | cell_type='lstm', 39 | ) 40 | 41 | def forward( 42 | self, 43 | 
noise_seq_subject, 44 | noise_seq_object, 45 | noise_seq_relation, 46 | noise_seq_time, 47 | ): 48 | # shape -> [..., n_seqs, seq_len, embedding_dim] 49 | noise_seq_subject_emb = self.entity_embedding_layer(noise_seq_subject) 50 | noise_seq_object_emb = self.entity_embedding_layer(noise_seq_object) 51 | noise_seq_relation_emb = self.relation_embedding_layer(noise_seq_relation) 52 | 53 | mask = noise_seq_time > 0 54 | 55 | rnn_input = torch.cat([ 56 | noise_seq_subject_emb, 57 | noise_seq_object_emb, 58 | noise_seq_relation_emb, 59 | ], dim=-1) 60 | 61 | rnn_input = rnn_input * mask[..., None] 62 | 63 | rnn_input = self.in_ffn_layer(rnn_input) 64 | 65 | rnn_seq_output, _ = self.rnn_layer(rnn_input) 66 | 67 | # shape -> [..., n_seqs, 1] 68 | logit = self.out_ffn_layer(rnn_seq_output[..., :, -1, :]) 69 | 70 | # shape -> [..., n_seqs] 71 | logit = logit[..., 0] 72 | 73 | return logit 74 | 75 | 76 | class RNN(torch.nn.Module): 77 | rnn_type_dict = { 78 | 'rnn': torch.nn.RNN, 79 | 'gru': torch.nn.GRU, 80 | 'lstm': torch.nn.LSTM 81 | } 82 | 83 | def __init__( 84 | self, 85 | feat_dim: int, 86 | num_cells: int, 87 | num_layers: int = 1, 88 | dropout_rate: float = 0.1, 89 | cell_type: str = 'lstm', 90 | bidirectional: bool = False, 91 | bias: bool = True, 92 | ): 93 | """ 94 | 95 | Args: 96 | feat_dim: int 97 | Dimension of input. 98 | num_cells: int 99 | Dimension of hidden state. 100 | num_layers: int 101 | Number of rnn layers. 102 | dropout_rate: float, default 0.1 103 | Dropout rate of rnn layer. 104 | cell_type: str, default 'lstm' 105 | Type of rnn cell, option in ['rnn', 'gru', 'lstm']. 106 | bidirectional: bool, default False. 107 | Identify if the rnn is bidirectional. 108 | bias: bool, default True. 109 | Identify if using bias in rnn. 110 | """ 111 | super(RNN, self).__init__() 112 | # assignment 113 | if num_layers == 1 and dropout_rate > 0: 114 | dropout_rate = 0 115 | self.cell_type = cell_type.lower() 116 | rnn_cls = self.rnn_type_dict.get(self.cell_type) 117 | 118 | self.rnn_layer = rnn_cls( 119 | input_size=feat_dim, 120 | hidden_size=num_cells, 121 | num_layers=num_layers, 122 | dropout=dropout_rate, 123 | bidirectional=bidirectional, 124 | batch_first=True, 125 | bias=bias 126 | ) 127 | 128 | def forward(self, x, initial_state=None) -> torch.Tensor: 129 | """ 130 | 131 | Args: 132 | x: tensor with shape [..., seq_len, feat_dim] 133 | initial_state: tensor or tuple of tensor 134 | Same like rnn, gru or lstm's state. 
135 | 136 | Returns: 137 | output tensor: [..., seq_len, num_directions * num_cells] 138 | last state tensor: (h_n, c_n) or h_n with shape [..., num_layers * num_directions, num_cells] 139 | """ 140 | if self.cell_type not in ['lstm'] and type(initial_state) in (tuple, list) and len(initial_state) == 1: 141 | initial_state = initial_state[0] 142 | self.rnn_layer.flatten_parameters() 143 | 144 | if len(x.size()) > 3: 145 | head_shape = x.shape[:-2] 146 | out_tuple = self.rnn_layer(x.flatten(0, -3), initial_state) 147 | seq_out = out_tuple[0].reshape([*head_shape, *out_tuple[0].shape[-2:]]) 148 | return seq_out, out_tuple[1] 149 | else: 150 | return self.rnn_layer(x, initial_state) 151 | -------------------------------------------------------------------------------- /models/ke_anhp/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from ._model import KnowEvolveANHP 3 | -------------------------------------------------------------------------------- /models/modules.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | 4 | 5 | class ParameterDot(torch.nn.Module): 6 | def __init__( 7 | self, 8 | num_embeddings: int, 9 | in1_features: int, 10 | in2_features: int, 11 | bias: bool = False, 12 | ): 13 | super(ParameterDot, self).__init__() 14 | self.use_bias = bias 15 | 16 | self.weights = torch.nn.Parameter(torch.Tensor(num_embeddings, in1_features, in2_features)) 17 | torch.nn.init.xavier_uniform_(self.weights, gain=torch.nn.init.calculate_gain('relu')) 18 | if self.use_bias: 19 | self.bias = torch.nn.Parameter(torch.empty((num_embeddings,)), requires_grad=True) 20 | torch.nn.init.constant_(self.bias, 0.1) 21 | 22 | def forward(self, tensor1, tensor2, parameters_index): 23 | """Applies a dot transformation to the incoming data: `y = x_1^T A x_2 + b`. 24 | 25 | Args: 26 | tensor1: tensor with shape [B, ..., in1_features] 27 | tensor2: tensor with shape [B, ..., in2_features] 28 | parameters_index: tensor with shape [B] 29 | 30 | Returns: 31 | Tensor with shape [B]. 32 | """ 33 | # [B, in1_features, in2_features] 34 | weight = self.weights[parameters_index] 35 | 36 | # [..., B, in1_features] 37 | tensor1 = tensor1.transpose(-2, 0) 38 | # [B, ..., in2_features] 39 | x = torch.sum(tensor1[..., None] * weight, dim=-2).transpose(0, -2) 40 | # [B, ...] 
41 | x = torch.sum(x * tensor2, dim=-1) 42 | 43 | if self.use_bias: 44 | # [B] 45 | bias = self.bias[parameters_index] 46 | 47 | x = x.transpose(-1, 0) + bias 48 | x = x.transpose(0, -1) 49 | 50 | return x 51 | -------------------------------------------------------------------------------- /models/tpp/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iLampard/lamp/c48312108ef9b0eb54c9e4581b2af248abc09094/models/tpp/__init__.py -------------------------------------------------------------------------------- /models/tpp/base_model.py: -------------------------------------------------------------------------------- 1 | """ Base model with common functionality """ 2 | 3 | import torch 4 | from torch import nn 5 | 6 | 7 | class BaseModel(nn.Module): 8 | def __init__(self, model_config): 9 | super(BaseModel, self).__init__() 10 | self.num_steps_integral_loss = model_config.get( 11 | 'num_steps_integral_loss', 20) 12 | self.add_bos = model_config.get('add_bos', False) 13 | self.hidden_size = model_config.get('hidden_size', 64) 14 | self.num_event_types_no_pad = model_config['num_event_types_no_pad'] # not include [PAD], [BOS], [EOS] 15 | self.num_event_types_pad = model_config['num_event_types_pad'] # include [PAD], [BOS], [EOS] 16 | self.event_pad_index = model_config['event_pad_index'] 17 | self.eps = torch.finfo(torch.float32).eps 18 | 19 | self.layer_event_emb = nn.Embedding(self.num_event_types_pad, # have padding 20 | self.hidden_size, 21 | padding_idx=self.event_pad_index) 22 | 23 | @staticmethod 24 | def generate_model_from_config(model_config): 25 | model_name = model_config.get('name') 26 | 27 | cls_list = [] 28 | cls_list.extend(BaseModel.__subclasses__()) 29 | idx = 0 30 | while idx < len(cls_list): 31 | subcls = cls_list[idx] 32 | if subcls.__name__ == model_name: 33 | return subcls(model_config) 34 | cls_list.extend(subcls.__subclasses__()) 35 | idx += 1 36 | 37 | raise RuntimeError('No model named ' + model_name) 38 | 39 | def compute_loglik(self, batch): 40 | raise NotImplementedError -------------------------------------------------------------------------------- /models/tpp/layers.py: -------------------------------------------------------------------------------- 1 | # Transformer Components Implementation Adapted from Annotated Transformer: 2 | # https://nlp.seas.harvard.edu/2018/04/03/attention.html 3 | import math 4 | 5 | import torch 6 | from torch import nn 7 | 8 | 9 | # ref: https://github.com/yangalan123/anhp-andtt/blob/master/anhp/model/xfmr.py 10 | 11 | def attention(query, key, value, mask=None, dropout=None): 12 | d_k = query.size(-1) 13 | scores = torch.matmul(query, key.transpose(-2, -1)) \ 14 | / math.sqrt(d_k) 15 | if mask is not None: 16 | # small change here -- we use "1" for masked element 17 | scores = scores.masked_fill(mask > 0, -1e9) 18 | p_attn = torch.softmax(scores, dim=-1) 19 | if dropout is not None: 20 | p_attn = dropout(p_attn) 21 | return torch.matmul(p_attn, value), p_attn 22 | 23 | 24 | class MultiHeadAttention(nn.Module): 25 | def __init__(self, n_head, d_input, d_model, dropout=0.1, output_linear=False): 26 | super(MultiHeadAttention, self).__init__() 27 | assert d_model % n_head == 0 28 | self.n_head = n_head 29 | self.d_k = d_model // n_head 30 | self.d_v = self.d_k 31 | self.d_model = d_model 32 | self.output_linear = output_linear 33 | 34 | if output_linear: 35 | self.linears = nn.ModuleList( 36 | [nn.Linear(d_input, d_model) for _ in range(3)] + 
[nn.Linear(d_model, d_model), ]) 37 | else: 38 | self.linears = nn.ModuleList([nn.Linear(d_input, d_model) for _ in range(3)]) 39 | # for i in range(len(self.linears)): 40 | # nn.init.xavier_uniform_(self.linears[i].weight) 41 | self.dropout = nn.Dropout(p=dropout) 42 | 43 | def forward(self, query, key, value, mask): 44 | if mask is not None: 45 | mask = mask.unsqueeze(1) 46 | nbatches = query.size(0) 47 | 48 | query, key, value = [ 49 | l(x).view(nbatches, -1, self.n_head, self.d_k).transpose(1, 2) 50 | for l, x in zip(self.linears, (query, key, value)) 51 | ] 52 | x, attn_weight = attention(query, key, value, mask=mask, dropout=self.dropout) 53 | 54 | x = x.transpose(1, 2).contiguous() \ 55 | .view(nbatches, -1, self.n_head * self.d_k) 56 | 57 | if self.output_linear: 58 | return self.linears[-1](x) 59 | else: 60 | return x 61 | 62 | 63 | class SublayerConnection(nn.Module): 64 | # used for residual connnection 65 | def __init__(self, d_model, dropout): 66 | super(SublayerConnection, self).__init__() 67 | self.norm = nn.LayerNorm(d_model) 68 | self.dropout = nn.Dropout(dropout) 69 | 70 | def forward(self, x, sublayer): 71 | return x + self.dropout(sublayer(self.norm(x))) 72 | 73 | 74 | class PositionwiseFeedForward(nn.Module): 75 | """Implements FFN equation.""" 76 | 77 | def __init__(self, d_model, d_ff, dropout=0.1): 78 | super(PositionwiseFeedForward, self).__init__() 79 | self.w_1 = nn.Linear(d_model, d_ff) 80 | self.w_2 = nn.Linear(d_ff, d_model) 81 | self.dropout = nn.Dropout(dropout) 82 | self.relu = nn.ReLU() 83 | 84 | def forward(self, x): 85 | return self.w_2(self.dropout(self.relu(self.w_1(x)))) 86 | 87 | 88 | class EncoderLayer(nn.Module): 89 | def __init__(self, d_model, self_attn, feed_forward=None, use_residual=False, dropout=0.1): 90 | super(EncoderLayer, self).__init__() 91 | self.self_attn = self_attn 92 | self.feed_forward = feed_forward 93 | self.use_residual = use_residual 94 | if use_residual: 95 | self.sublayer = nn.ModuleList([SublayerConnection(d_model, dropout) for _ in range(2)]) 96 | self.d_model = d_model 97 | 98 | def forward(self, x, mask): 99 | if self.use_residual: 100 | x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, mask)) 101 | if self.feed_forward is not None: 102 | return self.sublayer[1](x, self.feed_forward) 103 | else: 104 | return x 105 | else: 106 | return self.self_attn(x, x, x, mask) 107 | 108 | 109 | class XFMREncoder(nn.Module): 110 | def __init__(self, d_model, num_layers, self_attn, feed_forward, use_residual=False, dropout=0.1): 111 | super(XFMREncoder, self).__init__() 112 | self.layers = nn.ModuleList( 113 | [EncoderLayer(d_model, self_attn, feed_forward, use_residual, dropout) 114 | for _ in range(num_layers) 115 | ]) 116 | self.norm = nn.LayerNorm(d_model) 117 | 118 | def forward(self, x, mask): 119 | for layer in self.layers: 120 | x = layer(x, mask) 121 | return self.norm(x) -------------------------------------------------------------------------------- /models/tpp_runner.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from utils.general import get_value_by_key 5 | from utils.metrics import rank, is_hit, time_rmse_np 6 | 7 | 8 | class TPPRunner(): 9 | def __init__( 10 | self, 11 | model, 12 | lr: float = 1e-3, 13 | num_epochs: int = 10, 14 | ): 15 | self.model = model 16 | self.num_epochs = num_epochs 17 | self.optimizer = torch.optim.Adam(self.model.parameters(), lr=lr) 18 | 19 | def train(self, train_dl, valid_dl=None, num_epochs: int 
= None, verbose: bool = True, autosave: bool = True): 20 | num_epochs = num_epochs or self.num_epochs 21 | 22 | best_metric = float('inf') 23 | for epoch_i in range(num_epochs): 24 | loss = self.train_one_epoch(train_dl, verbose=verbose) 25 | print('Epoch', epoch_i, 'loss:', loss) 26 | if valid_dl is not None: 27 | # metric, _ = self.evaluate_one_epoch(valid_dl) 28 | metric = self.evaluate_one_epoch_by_loss(valid_dl) 29 | if metric < best_metric: 30 | if autosave: 31 | self.save() 32 | print('save model at epoch', epoch_i) 33 | best_metric = metric 34 | print('------------ Best metric:', metric) 35 | 36 | def train_one_epoch( 37 | self, 38 | train_dataloader, 39 | verbose: bool = True 40 | ): 41 | epoch_loss = 0 42 | num_batches = len(train_dataloader) 43 | for i, batch in enumerate(train_dataloader): 44 | self.optimizer.zero_grad() 45 | ret_tuple = self.model( 46 | batch, 47 | return_loss=True, 48 | ) 49 | loss = ret_tuple[0] 50 | 51 | loss.backward() 52 | 53 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=5.0) 54 | self.optimizer.step() 55 | 56 | loss = loss.detach().cpu().numpy() 57 | if verbose: 58 | print(f'--- batch {i} loss:', loss) 59 | 60 | epoch_loss += loss / num_batches 61 | 62 | return epoch_loss 63 | 64 | def evaluate_one_epoch_by_loss(self, dataloader): 65 | epoch_loss = 0 66 | num_batches = len(dataloader) 67 | with torch.no_grad(): 68 | for i, batch in enumerate(dataloader): 69 | 70 | if batch[0].size(1) <= 2: 71 | continue 72 | 73 | ret_tuple = self.model( 74 | batch, 75 | return_loss=True, 76 | ) 77 | loss = ret_tuple[0] 78 | 79 | loss = loss.detach().cpu().numpy() 80 | 81 | epoch_loss += loss / num_batches 82 | 83 | return epoch_loss 84 | 85 | def evaluate_one_epoch(self, dataloader, top_n_hit: int = 5): 86 | self.model.eval() 87 | 88 | metric_list_dict = { 89 | 'type': { 90 | 'mean_rank': [], 91 | 'hit_ratio': [] 92 | }, 93 | 'time': { 94 | 'rmse': [] 95 | }, 96 | } 97 | 98 | res_batches = [] 99 | label_dtime_batch = [] 100 | pred_dtime_batch = [] 101 | with torch.no_grad(): 102 | for i, batch in enumerate(dataloader): 103 | batch_pred = self.model.inference( 104 | batch 105 | ) 106 | 107 | if len(batch_pred): 108 | label_type = np.array(get_value_by_key('label_type', batch_pred)) 109 | pred_type_score = np.array(get_value_by_key('pred_type_score', batch_pred)) 110 | label_dtime = np.array(get_value_by_key('label_dtime', batch_pred)) 111 | pred_dtime = np.array(get_value_by_key('pred_dtime', batch_pred)) 112 | 113 | mean_rank = rank(label=label_type, pred=pred_type_score).tolist() 114 | metric_list_dict['type']['mean_rank'].extend(mean_rank) 115 | 116 | hit_ratio = is_hit(label=label_type, pred=pred_type_score, 117 | top_n=top_n_hit) 118 | 119 | metric_list_dict['type']['hit_ratio'].extend(hit_ratio) 120 | label_dtime_batch.extend(label_dtime) 121 | pred_dtime_batch.extend(pred_dtime) 122 | 123 | res_batches.append(batch_pred) 124 | 125 | for pred_type, metric_dict in metric_list_dict.items(): 126 | for metric_name, metric_list in metric_dict.items(): 127 | if len(metric_list) > 0: 128 | # set the important metric 129 | if pred_type == 'type' and metric_name == 'mean_rank': 130 | metric = np.mean(metric_list) 131 | print(f'--------- {pred_type}-{metric_name}:', np.mean(metric_list)) 132 | 133 | # print(f'--------- RMSE:', time_rmse_np(pred_dtime_batch, label_dtime_batch)) 134 | 135 | return metric, res_batches 136 | 137 | def save(self, path: str = 'anhp.pt'): 138 | torch.save(self.model.state_dict(), path) 139 | 140 | def load(self, path: str = 
'anhp.pt'): 141 | self.model.load_state_dict(torch.load(path)) 142 | 143 | def is_empty_pred(self, pred_output): 144 | if len(pred_output) == 0: 145 | return True 146 | elif len(pred_output[0]) == 0: 147 | return True 148 | return False 149 | -------------------------------------------------------------------------------- /preprocess/__init__.py: -------------------------------------------------------------------------------- 1 | from .data_loader_tpp import TPPDataset, create_tpp_dataloader 2 | 3 | __all__ = [ 4 | 'TPPDataset', 5 | 'create_tpp_dataloader' 6 | ] 7 | -------------------------------------------------------------------------------- /preprocess/data_loader_tpp.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch.utils.data import Dataset, DataLoader 4 | 5 | 6 | # ref: https://github.com/yangalan123/anhp-andtt/blob/master/anhp/data/NHPDataset.py 7 | 8 | 9 | class TPPDataset(Dataset): 10 | def __init__(self, data, event_num, eos_elapse=1, add_bos=True, add_eos=True, eps=np.finfo(float).eps, 11 | concurrent=False, 12 | skip_padding=False): 13 | """ 14 | :param data: list[list[dict{"time_since_last_event"[float], "time_since_start"[float], "type_event"[int]}]] 15 | :param event_num: int, how many events are there in the whole dataset? (not limited to train/test/dev) 16 | each internal list is a event stream sequence 17 | following the data format of GaTech and JHU 18 | event_type(int) : starting from 0, appended [BOS], [EOS] and [PAD] at |E|, |E| + 1, |E| + 2 19 | :param eos_elapse: int, how much time we should wait after the last event to give EOS mark 20 | :param add_bos / add_eos: bool, whether to add [BOS] / [EOS] 21 | :param eps: float, if |x1-x2| < eps, then x1 == x2 (avoid float error in comparison) 22 | :param concurrent: bool, whether to consider concurrent events or not 23 | """ 24 | assert eos_elapse >= 0, "EOS should not appear earlier than last event!" 25 | self.time_seq = [[x["time_since_start"] for x in seq] for seq in data] 26 | self.event_seq = [[x["type_event"] for x in seq] for seq in data] 27 | self.time_delta_seq = [[x["time_since_last_event"] for x in seq] for seq in data] 28 | self.skip_padding = skip_padding 29 | 30 | # starting from 0 31 | self.event_num = event_num 32 | self.pad_index = self.event_num 33 | self.bos_index = self.event_num + 1 34 | self.eos_index = self.event_num + 2 35 | self.eps = eps 36 | self.concurrent = concurrent 37 | 38 | self.add_bos = add_bos 39 | self.add_eos = add_eos 40 | self.data = data 41 | # at least include [PAD] 42 | self.num_types = self.event_num + 1 43 | assert max( 44 | [max(seq) for seq in self.event_seq]) + 1 <= self.event_num, "there are more event than specified?" 45 | if self.add_bos: 46 | self.time_seq = [[0, ] + seq for seq in self.time_seq] 47 | self.event_seq = [[self.bos_index, ] + seq for seq in self.event_seq] 48 | self.time_delta_seq = [[0, ] + seq for seq in self.time_delta_seq] 49 | self.num_types += 1 50 | if self.add_eos: 51 | self.time_seq = [seq + [seq[-1] + eos_elapse, ] for seq in self.time_seq] 52 | self.event_seq = [seq + [self.eos_index, ] for seq in self.event_seq] 53 | self.time_delta_seq = [seq + [eos_elapse, ] for seq in self.time_delta_seq] 54 | self.num_types += 1 55 | 56 | def __len__(self): 57 | assert len(self.time_seq) == len(self.event_seq) and len(self.time_delta_seq) == len(self.event_seq), \ 58 | f"Inconsistent lengths for data! 
time_seq_len:{len(self.time_seq)}, event_len: {len(self.event_seq)}, time_delta_seq_len: {len(self.time_delta_seq)}" 59 | return len(self.event_seq) 60 | 61 | def __getitem__(self, idx): 62 | return self.time_seq[idx], self.time_delta_seq[idx], self.event_seq[idx] 63 | 64 | def padding(self, seqs, dtype, max_len=None): 65 | # padding to the max_length 66 | if max_len is None: 67 | max_len = max(len(seq) for seq in seqs) 68 | batch_seq = np.array([seq + [self.pad_index] * (max_len - len(seq)) for seq in seqs], dtype=np.float64) 69 | 70 | # by default, return float32 tensor 71 | return torch.tensor(batch_seq, dtype=dtype) 72 | 73 | def createConcurrentMask(self, time_seq): 74 | max_length = max([len(x) for x in time_seq]) 75 | batch_size = len(time_seq) 76 | mask = torch.ones((batch_size, max_length, max_length), dtype=torch.uint8) 77 | for _batch_i, _time_seq in enumerate(time_seq): 78 | # buf[(start, end, can_attend_to_earlier_than)] 79 | # mask[_batch_i, start:end, 0:can_attend_to_earlier_than] = 0 80 | if self.add_bos: 81 | cur_index = 1 82 | buf = [(0, 1, 1)] 83 | else: 84 | cur_index = 0 85 | buf = [] 86 | cur_time = _time_seq[cur_index] 87 | last_index = cur_index 88 | while cur_index < len(_time_seq): 89 | if abs(_time_seq[cur_index] - cur_time) < self.eps: 90 | cur_index += 1 91 | else: 92 | buf.append((last_index, cur_index, last_index)) 93 | last_index = min(cur_index, len(_time_seq) - 1) 94 | cur_time = _time_seq[last_index] 95 | 96 | if buf[-1][1] <= len(_time_seq) - 1: 97 | buf.append((last_index, cur_index, last_index)) 98 | 99 | for item in buf: 100 | beg, end, attend_to = item 101 | mask[_batch_i, beg: end, 0: attend_to] = 0 102 | return mask 103 | 104 | def createPadAttnMask(self, event_seq, concurrent_mask=None): 105 | # 1 -- pad, 0 -- non-pad 106 | batch_size, seq_len = event_seq.size(0), event_seq.size(1) 107 | batch_seq_pad_mask = event_seq.eq(self.pad_index) 108 | attention_key_pad_mask = batch_seq_pad_mask.unsqueeze(1).expand(batch_size, seq_len, -1) 109 | subsequent_mask = torch.triu( 110 | torch.ones((seq_len, seq_len), device=event_seq.device, dtype=torch.uint8), diagonal=0 111 | ).unsqueeze(0).expand(batch_size, -1, -1) 112 | attention_mask = subsequent_mask | attention_key_pad_mask.bool() 113 | if concurrent_mask is None: 114 | # no way to judge concurrent events, simply believe there is no concurrent events 115 | pass 116 | else: 117 | attention_mask |= concurrent_mask.bool() 118 | return ~batch_seq_pad_mask, attention_mask 119 | 120 | def collate_fn(self, batch): 121 | time_seq, time_delta_seq, event_seq = list(zip(*batch)) 122 | if self.concurrent: 123 | concurrent_mask = self.createConcurrentMask(time_seq) 124 | else: 125 | concurrent_mask = None 126 | 127 | time_seq = self.padding(time_seq, torch.float64) 128 | time_delta_seq = self.padding(time_delta_seq, torch.float64) 129 | event_seq = self.padding(event_seq, torch.long) 130 | 131 | batch_non_pad_mask, attention_mask = self.createPadAttnMask(event_seq, concurrent_mask) 132 | 133 | type_mask = torch.zeros([*event_seq.size(), self.event_num]) 134 | for i in range(self.event_num): 135 | type_mask[:, :, i] = event_seq == i 136 | 137 | return time_seq, time_delta_seq, event_seq, batch_non_pad_mask, attention_mask, type_mask 138 | 139 | 140 | def create_tpp_dataloader(dataset, batch_size, shuffle=True): 141 | return DataLoader( 142 | dataset, 143 | # num_workers=num_workers, # comment this line which may cause warnings in ParallelNative.cpp 144 | batch_size=batch_size, 145 | collate_fn=dataset.collate_fn, 146 | 
shuffle=shuffle 147 | ) 148 | -------------------------------------------------------------------------------- /preprocess/datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iLampard/lamp/c48312108ef9b0eb54c9e4581b2af248abc09094/preprocess/datasets/__init__.py -------------------------------------------------------------------------------- /preprocess/datasets/ebm_dataset.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | from torch.utils.data import Dataset 4 | from torch.utils.data.dataset import T_co 5 | 6 | 7 | class EBMDataset(Dataset): 8 | def __init__( 9 | self, 10 | data, 11 | num_noise_samples: int, 12 | time_factor: float = 1.0, 13 | max_seq_length: int = 10, 14 | target_name: str = 'relation' 15 | ): 16 | super().__init__() 17 | self.data = data 18 | self.num_noise_samples = num_noise_samples 19 | self.time_factor = time_factor 20 | self.max_seq_length = max_seq_length 21 | self.target_name = target_name 22 | 23 | def __getitem__(self, index) -> T_co: 24 | point = self.data[index] 25 | original_index = point[0] 26 | label_causal_event_list = point[1] 27 | noise_causal_events_list = point[2] 28 | label_target = point[3] 29 | noise_target = point[4] 30 | 31 | label_dict = list_of_dict_to_dict(label_causal_event_list) 32 | label_dict = padding_lists_in_dict(label_dict, padding_length=self.max_seq_length) 33 | 34 | noise_dicts = [ 35 | padding_lists_in_dict( 36 | list_of_dict_to_dict(noise_causal_event_list), 37 | padding_length=self.max_seq_length 38 | ) for noise_causal_event_list in noise_causal_events_list 39 | ] 40 | 41 | assert len(noise_dicts) >= self.num_noise_samples, 'Length not match' 42 | 43 | noise_dict = list_of_dict_to_dict(noise_dicts) 44 | 45 | return { 46 | 'label_seq_subject': label_dict['subject'], 47 | 'label_seq_object': label_dict['object'], 48 | 'label_seq_relation': label_dict['relation'], 49 | 'label_seq_time': label_dict['time'] / self.time_factor, 50 | 51 | 'noise_seq_subject': noise_dict['subject'][:self.num_noise_samples], 52 | 'noise_seq_object': noise_dict['object'][:self.num_noise_samples], 53 | 'noise_seq_relation': noise_dict['relation'][:self.num_noise_samples], 54 | 'noise_seq_time': noise_dict['time'][:self.num_noise_samples] / self.time_factor, 55 | 56 | # label information 57 | 'label_target': label_target, # real rels 58 | 'pred_target': [label_target] + noise_target[:4], # make it the same length for all events 59 | 'original_index': original_index 60 | } 61 | 62 | def __len__(self): 63 | return len(self.data) 64 | 65 | 66 | class EBMValidDataset(EBMDataset): 67 | def __init__( 68 | self, 69 | data, 70 | num_noise_samples: int, 71 | time_factor: float = 1.0, 72 | max_seq_length: int = 10, 73 | target_name: str = 'relation' 74 | ): 75 | super().__init__( 76 | data, num_noise_samples, time_factor, max_seq_length, target_name 77 | ) 78 | self.data = [ 79 | item for item in self.data if len(item[4]) == num_noise_samples 80 | ] 81 | 82 | 83 | def list_of_dict_to_dict(list_of_dicts): 84 | if not list_of_dicts: 85 | raise ValueError("The list of dicts is empty") 86 | 87 | dict_of_lists = {key: np.array([d[key] for d in list_of_dicts]) for key in list_of_dicts[0]} 88 | 89 | return dict_of_lists 90 | 91 | 92 | def padding_lists_in_dict(a_dict, padding_length=10, padding_values=None): 93 | padding_values = padding_values or {} 94 | 95 | for k in a_dict: 96 | v = a_dict[k] 97 | len_diff = padding_length 
- len(v) 98 | if len_diff < 0: 99 | a_dict[k] = v[-padding_length:] 100 | elif len_diff > 0: 101 | a_dict[k] = np.concatenate([[padding_values.get(k, 0)] * len_diff, v]) 102 | 103 | return a_dict 104 | -------------------------------------------------------------------------------- /preprocess/datasets/kg_dataset.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import numpy as np 4 | from torch.utils.data import Dataset 5 | 6 | 7 | class KGDataset(Dataset): 8 | """Return a sequence whose length is context_length + 1, and drop last by default""" 9 | 10 | def __init__( 11 | self, 12 | data, 13 | context_length: int, 14 | time_factor: float = 100.0, 15 | start_ratio: float = 0.0, 16 | end_ratio: float = 1.0, 17 | overlap: bool = True, 18 | ): 19 | # data is the list of tuple 20 | self.data = data 21 | self.context_length = context_length 22 | self.time_factor = time_factor 23 | # start and end ratio of the length of the data 24 | self.start_ratio = start_ratio 25 | self.end_ratio = end_ratio 26 | self.overlap = overlap 27 | 28 | self.window_length = self.context_length + 1 29 | 30 | # set start and end indexes for each sequence (just one sequence in this case) 31 | data_length = len(self.data) 32 | assert data_length >= self.window_length, f'The length of dataset ({data_length}) ' \ 33 | f'is less than context_length ({self.window_length})' 34 | 35 | self.start_idx = int(len(self.data) * self.start_ratio) 36 | self.end_idx = int(len(self.data) * self.end_ratio) 37 | 38 | def __getitem__(self, idx): 39 | """Get a sequence of data whose length is context_length + 1 ([idx - context_length, idx]) 40 | 41 | Args: 42 | idx: 43 | 44 | Returns: 45 | 46 | """ 47 | original_idx = self.get_original_index(idx) 48 | sample = { 49 | 'seq_subject': self._extract_window_list_by_idx(original_idx, tuple_idx=0), 50 | 'seq_object': self._extract_window_list_by_idx(original_idx, tuple_idx=1), 51 | 'seq_relation': self._extract_window_list_by_idx(original_idx, tuple_idx=2), 52 | 'seq_time': (self._extract_window_list_by_idx(original_idx, tuple_idx=3) / self.time_factor).astype( 53 | np.float32), 54 | } 55 | return sample 56 | 57 | def __len__(self): 58 | if self.overlap: 59 | num = self.end_idx - self.start_idx - self.context_length 60 | else: 61 | num = math.floor((self.end_idx - self.start_idx) / self.window_length) 62 | return num 63 | 64 | def get_original_index(self, idx): 65 | if self.overlap: 66 | return self.start_idx + self.context_length + idx 67 | else: 68 | return self.start_idx + self.context_length + idx * self.window_length 69 | 70 | def _extract_window_list_by_idx(self, idx, tuple_idx): 71 | window = [] 72 | for i in range(idx - self.context_length, idx + 1): 73 | window.append(self.data[i][tuple_idx]) 74 | return np.array(window) 75 | -------------------------------------------------------------------------------- /preprocess/datasets/tpp_dataset.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | 3 | import numpy as np 4 | import torch 5 | from torch.utils.data import Dataset 6 | 7 | 8 | # ref: https://github.com/yangalan123/anhp-andtt/blob/master/anhp/data/NHPDataset.py 9 | 10 | 11 | class TPPDataset(Dataset): 12 | def __init__( 13 | self, 14 | data, 15 | num_event_types, 16 | start_date: str = '2000-01-01', 17 | end_date: str = '2030-01-01', 18 | time_factor: float = 100.0 19 | ): 20 | self.data = data 21 | self.time_factor = time_factor 22 | self.start_date = 
datetime.strptime(start_date, '%Y-%m-%d') if isinstance(start_date, str) else start_date 23 | self.end_date = datetime.strptime(end_date, '%Y-%m-%d') if isinstance(end_date, str) else end_date 24 | 25 | # only use end date to do the truncation 26 | # because we better need the prefix sequence in valid and test set 27 | self.time_seqs_used_in_model = [ 28 | [x["event_time"] for x in seq if self.is_before_time_period(x['event_date'])] for 29 | k, seq in data.items()] 30 | 31 | self.time_seqs_in_period = [ 32 | [x["event_time"] for x in seq if self.is_in_time_period(x['event_date'])] for 33 | k, seq in data.items()] 34 | 35 | self.type_seqs = [ 36 | [x["event_type"] for x in seq if self.is_before_time_period(x['event_date'])] for 37 | k, seq in data.items()] 38 | 39 | self.time_delta_seqs = [[x["event_dtime"] for x in seq if self.is_before_time_period(x['event_date'])] 40 | for k, seq in data.items()] 41 | 42 | # make the first timestamp to zero 43 | self.time_seqs = [] 44 | for seq in self.time_seqs_used_in_model: 45 | seq = [x - seq[0] for x in seq] 46 | self.time_seqs.append(seq) 47 | 48 | # seq idx 49 | self.seq_idx = list(data.keys()) 50 | 51 | # position in the seq 52 | self.original_idx = self.get_original_index() 53 | 54 | self.event_num = num_event_types 55 | self.pad_index = self.event_num 56 | 57 | def get_original_index(self): 58 | res = [] 59 | for time_seq, source_seq in zip(self.time_seqs_in_period, self.time_seqs_used_in_model): 60 | idx_in_seq = [source_seq.index(time_seq[i]) for i in range(len(time_seq))] 61 | res.append(idx_in_seq) 62 | return res 63 | 64 | def is_in_time_period(self, event_date: str, date_format='%Y-%m-%d'): 65 | event_date_ = datetime.strptime(event_date, date_format) 66 | if event_date_ >= self.start_date and event_date_ < self.end_date: 67 | return True 68 | else: 69 | return False 70 | 71 | def is_before_time_period(self, event_date: str, date_format='%Y-%m-%d'): 72 | event_date_ = datetime.strptime(event_date, date_format) 73 | if event_date_ < self.end_date: 74 | return True 75 | else: 76 | return False 77 | 78 | def __len__(self): 79 | """ 80 | 81 | Returns: length of the dataset 82 | 83 | """ 84 | 85 | return len(self.time_seqs) 86 | 87 | def __getitem__(self, idx): 88 | """ 89 | 90 | Args: 91 | idx: iteration index 92 | 93 | Returns: 94 | time_seq, time_delta_seq and event_seq element 95 | 96 | """ 97 | return self.time_seqs[idx], self.time_delta_seqs[idx], self.type_seqs[idx], self.seq_idx[idx], \ 98 | self.original_idx[idx] 99 | 100 | def padding(self, seqs, dtype, max_len=None, pad_index=None): 101 | pad_index = self.pad_index if pad_index is None else pad_index 102 | # padding to the max_length 103 | if max_len is None: 104 | max_len = max(len(seq) for seq in seqs) 105 | batch_seq = np.array([seq + [pad_index] * (max_len - len(seq)) for seq in seqs], dtype=np.float64) 106 | 107 | return torch.tensor(batch_seq, dtype=dtype) 108 | 109 | def createPadAttnMask(self, type_seqs, concurrent_mask=None): 110 | # 1 -- pad, 0 -- non-pad 111 | batch_size, seq_len = type_seqs.size(0), type_seqs.size(1) 112 | batch_seq_pad_mask = type_seqs.eq(self.pad_index) 113 | attention_key_pad_mask = batch_seq_pad_mask.unsqueeze(1).expand(batch_size, seq_len, -1) 114 | subsequent_mask = torch.triu( 115 | torch.ones((seq_len, seq_len), device=type_seqs.device, dtype=torch.uint8), diagonal=0 116 | ).unsqueeze(0).expand(batch_size, -1, -1) 117 | attention_mask = subsequent_mask | attention_key_pad_mask.bool() 118 | if concurrent_mask is None: 119 | # no way to judge 
concurrent events, simply believe there is no concurrent events 120 | pass 121 | else: 122 | attention_mask |= concurrent_mask.bool() 123 | return ~batch_seq_pad_mask, attention_mask 124 | 125 | def collate_fn(self, batch): 126 | """ 127 | 128 | Args: 129 | batch: batch sequence data 130 | 131 | Returns: 132 | batch tensors of time_seqs, time_delta_seqs, event_seqs, 133 | batch_non_pad_mask, attention_mask, type_mask 134 | 135 | """ 136 | time_seqs, time_delta_seqs, type_seqs, seq_idx, original_idx = list(zip(*batch)) 137 | 138 | # one could use float64 to avoid precision loss during conversion from numpy.array to torch.tensor 139 | # for generality we use float32 for the moment 140 | time_seqs = self.padding(time_seqs, torch.float32) 141 | time_delta_seqs = self.padding(time_delta_seqs, torch.float32) 142 | type_seqs = self.padding(type_seqs, torch.long) 143 | 144 | batch_non_pad_mask, attention_mask = self.createPadAttnMask(type_seqs) 145 | 146 | type_mask = torch.zeros([*type_seqs.size(), self.event_num]) 147 | for i in range(self.event_num): 148 | type_mask[:, :, i] = type_seqs == i 149 | 150 | # an ugly pad, we fix it later 151 | # this pad has no effect in evaluation because batch_size 152 | original_idx = self.padding(original_idx, pad_index=1000, dtype=torch.long) 153 | 154 | seq_idx = np.array(seq_idx)[:, None] 155 | seq_idx = torch.LongTensor(np.tile(seq_idx, [1, original_idx.size(-1)])) 156 | 157 | return time_seqs, time_delta_seqs, type_seqs, batch_non_pad_mask, attention_mask, type_mask, \ 158 | seq_idx, original_idx 159 | -------------------------------------------------------------------------------- /preprocess/kg_data_factory.py: -------------------------------------------------------------------------------- 1 | 2 | import yaml 3 | from torch.utils.data import Sampler, DataLoader, SequentialSampler 4 | from torch.utils.data.dataloader import default_collate 5 | 6 | from preprocess.datasets.kg_dataset import KGDataset 7 | from utils.general import file_uri_reader_processor 8 | 9 | 10 | class KGDatasetFactory: 11 | def __init__( 12 | self, 13 | data, 14 | context_length: int, 15 | num_entities: int, 16 | num_relations: int, 17 | train_end_index_ratio: float = 0.7, 18 | valid_end_index_ratio: float = 0.8, 19 | test_end_index_ratio: float = 1.0, 20 | time_factor: float = 1.0, 21 | batch_size: int = 32, 22 | ): 23 | # data is the list of tuple 24 | self.data = data 25 | self.num_entities = num_entities 26 | self.num_relations = num_relations 27 | self.context_length = context_length 28 | self.train_end_index_ratio = train_end_index_ratio 29 | self.valid_end_index_ratio = valid_end_index_ratio 30 | self.test_end_index_ratio = test_end_index_ratio or 1.0 31 | self.time_factor = time_factor 32 | self.batch_size = batch_size 33 | 34 | # set datasets 35 | self.train_dataset = KGDataset( 36 | data=self.data, 37 | context_length=self.context_length, 38 | time_factor=self.time_factor, 39 | end_ratio=self.train_end_index_ratio, 40 | overlap=False 41 | ) 42 | 43 | if self.valid_end_index_ratio is None or self.train_end_index_ratio >= self.valid_end_index_ratio: 44 | self.valid_dataset = None 45 | else: 46 | self.valid_dataset = KGDataset( 47 | data=self.data, 48 | context_length=self.context_length, 49 | time_factor=self.time_factor, 50 | start_ratio=self.train_end_index_ratio, 51 | end_ratio=self.valid_end_index_ratio, 52 | overlap=True 53 | ) 54 | 55 | if self.valid_end_index_ratio is None: 56 | start_ratio = self.train_end_index_ratio 57 | else: 58 | start_ratio = 
self.valid_end_index_ratio 59 | if start_ratio >= self.test_end_index_ratio: 60 | self.test_dataset = None 61 | else: 62 | self.test_dataset = KGDataset( 63 | data=self.data, 64 | context_length=self.context_length, 65 | time_factor=self.time_factor, 66 | start_ratio=start_ratio, 67 | end_ratio=self.test_end_index_ratio, 68 | overlap=True 69 | ) 70 | 71 | def get_train_dataloader(self, **kwargs): 72 | return DataLoader( 73 | self.train_dataset, 74 | batch_sampler=KGTrainingBatchSampler(SequentialSampler(self.train_dataset), self.train_dataset), 75 | **kwargs 76 | ) 77 | 78 | def get_valid_dataloader(self, **kwargs): 79 | if kwargs.get('batch_size'): 80 | del kwargs['batch_size'] 81 | return DataLoader( 82 | self.valid_dataset, 83 | # batch_size=self.batch_size, 84 | batch_size=1, 85 | **kwargs 86 | ) 87 | 88 | def get_test_dataloader(self, **kwargs): 89 | return DataLoader( 90 | self.test_dataset, 91 | # batch_size=self.batch_size, 92 | batch_size=1, 93 | **kwargs 94 | ) 95 | 96 | def iterate_dataset_with_original_index(self, dataset: KGDataset): 97 | for idx in range(len(dataset)): 98 | original_idx = dataset.get_original_index(idx) 99 | batch = default_collate([dataset[idx]]) 100 | yield original_idx, batch 101 | 102 | @staticmethod 103 | def build_from_config_dict(config_dict: dict): 104 | data_obj = file_uri_reader_processor(config_dict['data_dir']) 105 | 106 | return KGDatasetFactory( 107 | data=data_obj['data'], 108 | context_length=config_dict['context_length'], 109 | num_relations=config_dict.get('num_relations', data_obj.get('num_rel')), 110 | num_entities=config_dict.get('num_entities', data_obj.get('num_entity')), 111 | train_end_index_ratio=config_dict['train_end_index_ratio'], 112 | valid_end_index_ratio=config_dict['valid_end_index_ratio'], 113 | test_end_index_ratio=config_dict['test_end_index_ratio'], 114 | time_factor=config_dict.get('time_factor', 1.0), 115 | batch_size=config_dict.get('batch_size', 32), 116 | ) 117 | 118 | @staticmethod 119 | def build_from_config_file(config_file: str): 120 | with open(config_file) as config_file: 121 | config = yaml.safe_load(config_file) 122 | return KGDatasetFactory.build_from_config_dict(config['data']) 123 | 124 | 125 | class KGTrainingBatchSampler(Sampler): 126 | 127 | def __init__(self, sampler: Sampler[int], kg_dataset: KGDataset) -> None: 128 | super(KGTrainingBatchSampler, self).__init__(kg_dataset) 129 | 130 | self.sampler = sampler 131 | self.kg_dataset = kg_dataset 132 | 133 | self.batches = list(self._batches_iterator()) 134 | 135 | def __iter__(self): 136 | for batch in self.batches: 137 | yield batch 138 | 139 | def __len__(self): 140 | return len(self.batches) 141 | 142 | def _batches_iterator(self): 143 | batch = [] 144 | existed_entities = set() 145 | for idx in self.sampler: 146 | 147 | sample = self.kg_dataset[idx] 148 | sample_entities = set(sample['seq_subject']).union(set(sample['seq_object'])) 149 | if len(existed_entities) == 0 or len(existed_entities.intersection(sample_entities)) == 0: 150 | # no intersection or first sample 151 | batch.append(idx) 152 | existed_entities = existed_entities.union(sample_entities) 153 | else: 154 | yield batch 155 | batch = [idx] 156 | existed_entities = sample_entities 157 | yield batch 158 | -------------------------------------------------------------------------------- /preprocess/tpp_data_factory.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | from torch.utils.data import DataLoader 3 | 4 | from 
preprocess.datasets.tpp_dataset import TPPDataset 5 | from utils.general import file_uri_reader_processor 6 | 7 | 8 | class TPPDatasetFactory: 9 | def __init__( 10 | self, 11 | data, 12 | num_event_types: int, 13 | train_end_date: str = '2015-08-01', 14 | valid_end_date: str = '2016-02-01', 15 | time_factor: float = 1.0, 16 | ): 17 | # data is the list of tuple 18 | self.data = data 19 | self.num_event_types = num_event_types 20 | self.train_end_date = str(train_end_date) 21 | self.valid_end_date = str(valid_end_date) 22 | self.time_factor = time_factor 23 | 24 | # set datasets 25 | self.train_dataset = TPPDataset( 26 | data=self.data, 27 | time_factor=self.time_factor, 28 | end_date=self.train_end_date, 29 | num_event_types=num_event_types 30 | ) 31 | 32 | if self.valid_end_date is None: 33 | self.valid_dataset = None 34 | else: 35 | self.valid_dataset = TPPDataset( 36 | data=self.data, 37 | time_factor=self.time_factor, 38 | start_date=self.train_end_date, 39 | end_date=self.valid_end_date, 40 | num_event_types=num_event_types 41 | ) 42 | 43 | if self.valid_end_date is None: 44 | self.test_dataset = None 45 | else: 46 | self.test_dataset = TPPDataset( 47 | data=self.data, 48 | time_factor=self.time_factor, 49 | start_date=self.valid_end_date, 50 | num_event_types=num_event_types 51 | ) 52 | 53 | def get_train_dataloader(self, **kwargs): 54 | if 'batch_size' not in kwargs: 55 | kwargs['batch_size'] = 1 56 | return DataLoader( 57 | self.train_dataset, 58 | collate_fn=self.train_dataset.collate_fn, 59 | **kwargs 60 | ) 61 | 62 | def get_valid_dataloader(self, **kwargs): 63 | return DataLoader( 64 | self.valid_dataset, 65 | batch_size=1, 66 | collate_fn=self.valid_dataset.collate_fn, 67 | **kwargs 68 | ) 69 | 70 | def get_test_dataloader(self, **kwargs): 71 | return DataLoader( 72 | self.test_dataset, 73 | batch_size=1, 74 | collate_fn=self.test_dataset.collate_fn, 75 | **kwargs 76 | ) 77 | 78 | @staticmethod 79 | def build_from_config_dict(config_dict: dict): 80 | data_obj = file_uri_reader_processor(config_dict['data_dir']) 81 | 82 | return TPPDatasetFactory( 83 | data=data_obj['user_seqs'], 84 | num_event_types=config_dict.get('num_event_types', data_obj.get('dim_process')), 85 | train_end_date=config_dict['train_end_date'], 86 | valid_end_date=config_dict['valid_end_date'], 87 | time_factor=config_dict.get('time_factor', 1.0), 88 | ) 89 | 90 | @staticmethod 91 | def build_from_config_file(config_file: str): 92 | with open(config_file) as config_file: 93 | config = yaml.safe_load(config_file) 94 | return TPPDatasetFactory.build_from_config_dict(config['data']) 95 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | pandas 3 | dill 4 | 5 | # news scraper 6 | feapder 7 | newspaper3k 8 | 9 | # gpt 10 | openai 11 | 12 | torch==1.8.1 13 | scikit-learn 14 | dictdatabase 15 | sentence_transformers 16 | Levenshtein 17 | -------------------------------------------------------------------------------- /scripts/amazon/amazon_prompts/help_msg_1.txt: -------------------------------------------------------------------------------- 1 | I want you to do the abductive reasoning over the events that are extracted from Amazon clothing, shoes and jewelry reviews data. I given you an effect event and you give me two to four cause events that cause the effect event to happen. 
Each event corresponds to an review submitted by the customer, which consists of an product category(event type), a product title, an event time, summary text and review text from the user that describes the feedback of the shopping event. 2 | 3 | The product categories are restricted to the following set: 4 | 1. Women Shoes, 5 | 2. Men Shoes, 6 | 3. Men Clothing, 7 | 4. Women Clothing, 8 | 5. Novelty & More, 9 | 6. Men Uniforms, Work & Safety, 10 | 7. Women Jewelry, 11 | 8. Costumes & Accessories, 12 | 9. Men Accessories, 13 | 10. Luggage & Travel Gear, 14 | 11. Men Watches, 15 | 12. Women Accessories, 16 | 13. Children Shoes, 17 | 14. Children Clothing 18 | 15. Shoe, Jewelry & Watch Accessories, 19 | 16. Women Watches, 20 | 17. Women Uniforms, Work & Safety, 21 | 18. Men Surf, Skate & Street, 22 | 19. Women Handbags & Wallets 23 | 20. Men Jewelry 24 | 21. Children Accessories 25 | 22. Women Maternity 26 | 23. Women General 27 | 24. Others -------------------------------------------------------------------------------- /scripts/amazon/amazon_prompts/prompt_assistant_1.txt: -------------------------------------------------------------------------------- 1 | cause event 1 2 | product category: Children Shoes 3 | product title: New Balance KE420 Y Lifestyle Running Shoe (Little Kid/Big Kid) 4 | event time: 2013-10-16 5 | summary text: Great Value 6 | review text: New Balance is our favorite tennis shoe brand, and I bought several pairs of these for my kids to wear back to school. They make great athletic and everyday shoes with good foot support for them. 7 | 8 | cause event 2 9 | product category: Women Clothing 10 | product title: Mizuno Select Belted Low Rise Fastpitch Pant 11 | event time: 2013-07-03 12 | summary text: Stylish and Athletic 13 | review text: Great fit and fabric quality. My daughter will get years of use out of these. Color is pretty "bright" too, but that is what we were looking for to match her school color blue. 14 | 15 | cause event 3 16 | product category: Women Shoes 17 | product title: Mizuno Women's Finch Franchise 4 18 | event time: 2013-06-05 19 | summary text: Our 2nd Pair 20 | review text: My daughter swears by these. So much so that it's her second pair in 3 years. Love the vibrant pink color and great on field performance to keep her from slipping or losing traction. 21 | 22 | cause event 4 23 | product category: Children Shoes 24 | product title: New Balance YB4040 Baseball Cleat (Little Kid/Big Kid) 25 | event time: 2013-05-08 26 | summary text: Great value for a great brand 27 | review text: My son claims these make him run faster. Very comfortable as is to be expected with New Balance. I purchased a half size larger to ensure he has room to grow, and the fit is no concern. Very stylish cleat too. -------------------------------------------------------------------------------- /scripts/amazon/amazon_prompts/prompt_assistant_10.txt: -------------------------------------------------------------------------------- 1 | cause event 1 2 | product category: Women Jewelry 3 | product title: 1928 Jewelry Filigree Teardrop Chandelier Earrings 4 | event time: 2014-07-25 5 | summary text: Five Stars 6 | review text: Love these earrings. They are just gorgeous. I like that they are not super big. Love 1928 fashion jewelry. 
7 | 8 | cause event 2 9 | product category: Women Jewelry 10 | product title: 1928 Jewelry Gothic Cross Stretch Bracelet 11 | event time: 2014-07-25 12 | summary text: 1928 Gothic cross Blue Twilight Bracelet 13 | review text: I am crazy about bracelets and this one is very beautiful. I like the hematite band and the stones shine in the light with crosses in between. I like 1928 jewelry and have quite a bit of it. Happy to have this gothic bracelet. 14 | 15 | cause event 3 16 | product category: Women Jewelry 17 | product title: Sterling Silver "Sweet 16" Heart Locket Necklace, 18" 18 | event time: 2014-07-25 19 | summary text: Sweet 16 Locket 20 | review text: Very pretty and dainty locket. I got this for my daughter for her birthday in a couple of weeks. I'm sure she will just love it. 21 | 22 | cause event 4 23 | product category: Women Jewelry 24 | product title: 1928 Jewelry Vintage Crystal Drop Earrings 25 | event time: 2014-06-29 26 | summary text: Beautiful blue earrings 27 | review text: Looks so pretty with the necklace. Beautiful set. -------------------------------------------------------------------------------- /scripts/amazon/amazon_prompts/prompt_assistant_11.txt: -------------------------------------------------------------------------------- 1 | cause event 1 2 | product category: Men Clothing 3 | product title: Carhartt All-season Cotton Crew Work Socks 4 | event time: 2015-05-09 5 | summary text: Quality & Comfort 6 | review text: Great quality sock that should last a long time. Very comfortable in all but the most extreme climates. I find that these Carhartt Men's socks keep my feet dry on warm days when hiking, but I would not say that these would keep wicking away the moisture in really hot and humid areas. Great for casual wear also. 7 | Highly recommended. 8 | 9 | cause event 2 10 | product category: Men Uniforms, Work & Safety 11 | product title: Carhartt Men's Cotton Ripstop Relaxed Fit Work Pant 12 | event time: 2015-04-25 13 | summary text: Made Tough for ANY Job 14 | review text: These pants will fill a much needed ability to carry items when we hike the wilderness. With many spacious pockets that can carry the necessities of comfort and survival. There are a few important items I do not trust to a pack or bag that can become separated in a fall or other event. However...outside of the workman looking for a heavy duty pair of pants, or the outdoors-man looking for the same...these pants may not be for you. 15 | The Carhartt Ripstop Relaxed Fit Work Pants are made TOUGH. They will last a very long time and serve you well...they are just heavy, stiff and a bit uncomfortable to wear casually. 16 | Maybe after a bunch of washings they will soften over time, but I fear they are NOT to be confused with a comfy pair of jeans. 17 | 18 | cause event 3 19 | product category: Men Clothing 20 | product title: Carhartt Men's Workwear Pocket Short Sleeve Henley Original Fit Shirt K84 21 | event time: 2015-04-09 22 | summary text: Comfortable and Tough-As-Nails Carhartt Quality 23 | review text: Quality, comfort and style all in a long lasting Carhartt. I may never buy another brand again...it is THAT good. I have received many compliments on this shirt from friends and family who are used to seeing me in just a comfortable pocket Tee- shirt. Now I feel more confident going to a casual restaurant without feeling under dressed. Great workmanship and quality materials that should last a very long time. Love it! 24 | Highly recommended. 
-------------------------------------------------------------------------------- /scripts/amazon/amazon_prompts/prompt_assistant_12.txt: -------------------------------------------------------------------------------- 1 | cause event 1 2 | product category: Children Shoes 3 | product title: New Balance Kid's KV689 Running Shoe (Little Kid/Big Kid) 4 | event time: 2012-10-22 5 | summary text: little girl shoes 6 | review text: Received the shipment on time. My granddaughter thought these were really beautiful, I thought they were well built and should outgrow her long before they may have to be discarded. I have purchased this brand in the past and have been well satisfied. 7 | 8 | cause event 2 9 | product category: Children Shoes 10 | product title: New Balance KA688 Alternative Closure Running Shoe (Little Kid/Big Kid) 11 | event time: 2012-03-18 12 | summary text: good deal 13 | review text: These are a good brand, a good fit and a good buy. Recommend shoe and company. Grand daughter loves them 14 | 15 | cause event 3 16 | product category: Costumes & Accessories 17 | product title: My First Career Gear - Road Crew Toddler Costume 18 | event time: 2012-02-24 19 | summary text: Impressed 20 | review text: Grand child thinks this has helped her doctoring skills considerably. Its made her happy so as a toy it did its job. -------------------------------------------------------------------------------- /scripts/amazon/amazon_prompts/prompt_assistant_13.txt: -------------------------------------------------------------------------------- 1 | cause event 1 2 | product category: Men Clothing 3 | product title: IZOD Men's Big and Tall Classic Denim Jeans (Regular and Relaxed Fit) 4 | event time: 2014-12-01 5 | summary text: Izod Makes Great Jeans 6 | review text: Izod did it again! This time they made a great pair of jeans in an understated manner. They fit perfectly and are also very roomy with space to spare. The fabric has a nice soft feel but holds its form very well. And they didn't seem to shrink after one washing (and I know I didn't lose any weight!). These jeans can be worn on special occasions as well as for physical activities. I love them and couldn't ask for anything more. Also the price is fair and probably hard to beat. Try a pair. I think you'll be delighted. Lastly they are well packaged and arrive on time. A hard act to beat. 7 | 8 | cause event 2 9 | product category: Men Clothing 10 | product title: Fruit of the Loom Men's Pocket T-Shirt Multipack 11 | event time: 2014-11-23 12 | summary text: A Solid Fairly Priced Tee Shirt 13 | review text: This was a very good purchase. The material is both soft and strong. It fits well with plenty of room. It holds its shape well and even comes with a handy pocket in front. They appear to be made to last because they seem very durable. As if these were not sufficient 14 | reasons to buy these shirts, the price is very good and fair. I recommend these Tee shirts, even for all year round wear (layered). 15 | 16 | cause event 3 17 | product category: Men Clothing 18 | product title: IZOD Men's Big and Tall Pleated Extended Twill Pant 19 | event time: 2014-07-30 20 | summary text: To Iron or Dry Clean is the Question 21 | review text: The company is reputable and they make a good product. The material is soft and the workmanship seems quite good; these are not pants that were "thrown together". They are also very comfortable even at the end of a long day. The only minor fault is they don't seem to be wrinkle-free. 
I know I could wash them, dry them for a few minutes, and" iron " if needed but I prefer to take them to the dry cleaners. I still highly recommend them particularly for the comfort factor. -------------------------------------------------------------------------------- /scripts/amazon/amazon_prompts/prompt_assistant_14.txt: -------------------------------------------------------------------------------- 1 | cause event 1 2 | product category: Men Clothing 3 | product title: Calvin Klein Men's Straight 4 | event time: 2015-07-05 5 | summary text: Five Stars 6 | review text: Wear comfortable and good style. 7 | 8 | cause event 2 9 | product category: Women Clothing 10 | product title: Levi's Women's 505 Straight Jeans 11 | event time: 2014-12-30 12 | summary text: Nice 13 | review text: To help a friend buy, my friend liked. 14 | 15 | cause event 3 16 | product category: Men Clothing 17 | product title: Levi's Men's 514 Straight Fit Jean 18 | event time: 2014-06-03 19 | summary text: Nice Jeans 20 | review text: Jeans style is very beautiful, color slightly deep than pictures, size is right, the price also is very good. Boyfriend is very like. -------------------------------------------------------------------------------- /scripts/amazon/amazon_prompts/prompt_assistant_15.txt: -------------------------------------------------------------------------------- 1 | cause event 1 2 | product category: Women Jewelry 3 | product title: Sterling Silver Gemstone Flower Pendant Necklace 4 | event time: 2015-02-13 5 | summary text: Perfect for little girls 6 | review text: Bought this for my eight year old daughter. Peridot is her birthstone and she loves it. It's dainty and pretty for any girl. 7 | 8 | cause event 2 9 | product category: Children Shoes 10 | product title: PUMA Court Point V Kids Sneaker (Infant/Toddler) 11 | event time: 2014-08-22 12 | summary text: Great school shoe! 13 | review text: Bought this shoe for my 8 year old and it looks pretty strong. I like that there are no laces too. The laces get pretty dirty fast. 14 | 15 | cause event 3 16 | product category: Children Clothing 17 | product title: AX44 Sunny Fashion Girls Dress Blue Rose Wedding Pageant Kids Boutique Size 9-10 18 | event time: 2013-12-14 19 | summary text: Gorgeous. 20 | review text: Bought this dress for my 7 year old & it's utterly gorgeous! I bought one size bigger & it fits perfect. The price is really cheap for such a well made dress. Looking forward to purchase more dresses from this seller. 21 | -------------------------------------------------------------------------------- /scripts/amazon/amazon_prompts/prompt_assistant_16.txt: -------------------------------------------------------------------------------- 1 | cause event 1 2 | product category: Men Shoes 3 | product title: PUMA Men's Mell Es Suga Cross-Training Shoe 4 | event time: 2015-03-24 5 | summary text: Great, comfortable shoe 6 | review text: This shoes are soooo comfortable. My husband runs with them each morning and he said he doesn't feel a thing. 7 | 8 | cause event 2 9 | product category: Men Clothing 10 | product title: LEE Men's Stain Resistant Relaxed Fit Flat Front Pant 11 | event time: 2015-03-24 12 | summary text: Fits big 13 | review text: I bought this pants for my husband, usually buy the Dickies brand but this brand is somewhat big. He usually wears a 36 in the 874 Dickies but this 36 is quite big. 
14 | 15 | cause event 3 16 | product category: Men Accessories 17 | product title: Tommy Hilfiger Men's 100% Leather Passcase Wallet 18 | event time: 2015-02-13 19 | summary text: Great wallet 20 | review text: Great wallet. Bought this for my husband for Valentines and he loves it. It's a nice colour and it fits all his cards and is not bulky. Will highly recommend. 21 | -------------------------------------------------------------------------------- /scripts/amazon/amazon_prompts/prompt_assistant_2.txt: -------------------------------------------------------------------------------- 1 | cause event 1 2 | product category: Children Clothing 3 | product title: NA 4 | event time: 2014-08-27 5 | summary text: lee jeans. 6 | review text: We have some large gkids these were a little to snug for our big boys. 7 | 8 | cause event 2 9 | product category: Children Shoes 10 | product title: New Balance KJ750 Youth Lace-Up Running Shoe (Little Kid/Big Kid) 11 | event time: 2014-04-08 12 | summary text: NEW BALANCE SHOE. 13 | review text: MY GRANDDAUGHTER IS HARD ON HER SHOES. THESE HAVE AND DO HOLD UP, RECOMMEND THEM. 14 | 15 | cause event 3 16 | product category: Costumes & Accessories 17 | product title: My First Career Gear - Road Crew Toddler Costume 18 | event time: 2012-02-24 19 | summary text: Impressed. 20 | review text: Grand child thinks this has helped her doctoring skills considerably. Its made her happy so as a toy it did its job. -------------------------------------------------------------------------------- /scripts/amazon/amazon_prompts/prompt_assistant_3.txt: -------------------------------------------------------------------------------- 1 | cause event 1 2 | product category: Children Shoes 3 | product title: Timberland Kids' 6" Premium Waterproof Boots for Toddlers 4 | event time: 2014-12-02 5 | summary text: Gift 4 my daughter. 6 | review text: My daughter has been wanting a pair for a couple of years now. The price thru amazon was low compared to pricing in stores. This is one of her Christmas gifts. My daughter loves them and they fit perfectly. Arrived very early. 7 | 8 | cause event 2 9 | product category: Women Clothing 10 | product title: Flip Top Cheer/Megaphone Socks 11 | event time: 2014-09-18 12 | summary text: Five stars. 13 | review text: I love them. purchased these to support my daughter, a cheerleader. 14 | 15 | cause event 3 16 | product category: Women Watches 17 | product title: Geneva Platinum CZ Accented Silicone Link Watch, Large Face 18 | event time: 2014-05-13 19 | summary text: gift for my daughter. 20 | review text: My daughter's favorite color is green. She loves this watch. She had been looking for a big face watch. She has it on today.. -------------------------------------------------------------------------------- /scripts/amazon/amazon_prompts/prompt_assistant_4.txt: -------------------------------------------------------------------------------- 1 | cause event 1: 2 | product category: Men Shoes 3 | product title: La Sportiva Men's Wildcat Trail Running Shoe 4 | event time: 2013-12-04 5 | summary text: The absolute best off-road shoes for hiking or trailrunning. 6 | review text: I do a ton of running on various surfaces. While I mix things up when surfaces are okay, there is only one company that makes shoes that truly hold up when surfaces get dicey: La Sportiva. I've gone through a couple pairs of Raptors (awesome shoes for all-out speed) and am on my third pair of Wildcats. 
I've also tried a couple of their newer shoes but have been less than impressed. The newer MorphoDynamic soles that they have been experimenting with simply do not live up to the standards set by the Wildcats and Raptors. 7 | 8 | cause event 2: 9 | product category: Men Accessories 10 | product title: Outdoor Research Sombriolet Sun Hat. 11 | event time: 2013-12-04 12 | summary text: A hiking must-have. 13 | review text: This lightweight hat has become one of my favorite hiking accessories. Even when I'm trail-running in nothing but running shorts, this hat comes along. To be frank, it's one ugly hat. But I realized a long time ago that, when hiking, I need an effective hat much more than I need a stylish one. This hat can be folded and packed down into pockets or bags. It can be annoying trying to get the brim perfectly shaped once you've unpacked it, so I don't even bother any more -- it's large enough that it doesn't matter. It's incredibly light, yet stays put even on a fairly gusty day. 14 | 15 | cause event 3: 16 | product category: Men Shoes 17 | product title: Vibram Fivefingers Bikila (41 Mens, Light Grey/Palm/Dark Grey) - M345 18 | event time: 2013-12-03 19 | summary text: Best fitness purchase ever. Literally. 20 | review text: I've spent thousands of dollars on fitness and fitness-related products. I hike, I run, I climb, I lift, I bike... you name it, I do it. I've bought tons of equipment, gear, and nutrition/consumables. But if I had to narrow down to one ""best purchase ever"", it would be, hands down, my first pair of Vibram FiveFinger Bikilas. And not just for running; I use them for lifting (you'll never feel more balanced and connected to the floor when maxing out your dead-lifts), cross-training (they're incredibly responsive and agile), yoga (when the woman forces me to join her), you name it.. 21 | 22 | cause event 4: 23 | product category: Men Uniforms, Work & Safety 24 | product title: Dozen Camo Military Bandanas, Army Camouflage Headwraps, Military Head Scarves Do-rags 100% Cotton Bandanna Covers - 12 Pack 25 | event time: 2011-11-22 26 | summary text: Cheap, Thin, Great for Gym or Hike. 27 | review text: I was a bit worried that these would be of too low quality to last very long. So far, I've washed all of these at least a couple dozen times, and have worn them while at the gym, while hiking, as luggage markers, to wipe up spills in a jam, and various other situations. They're not thick, but they always do the job. They also look pretty cool too, in a I'm-secretly-still-twelve-inside kind of way. Would buy again, maybe in a different pattern. -------------------------------------------------------------------------------- /scripts/amazon/amazon_prompts/prompt_assistant_5.txt: -------------------------------------------------------------------------------- 1 | cause event 1 2 | product category: Novelty & More 3 | product title: Sports Katz Peace Out Socks 4 | event time: 2013-09-24 5 | summary text: Peace Out Socks 6 | review text: We ordered these for soccer for my daughter and they worked out well. They are very cute and have held up. They are a quite thick, which is why I only gave them 4 stars. They will work better for fall season; I would not recommend them for summer soccer, if soccer is your intended use. 
7 | 8 | cause event 2 9 | product category: Luggage & Travel Gear 10 | product title: Wildkin Pink Retro Kickstart Messenger Bag 11 | event time: 2013-09-09 12 | summary text: Wildkin Kickstart Messenger Bag 13 | review text: My daughter, going to 4th grade, wanted a messenger type bag this year for school. I searched and searched on Amazon for the perfect bag, but all of the adult bags seemed huge (like they would definately look funny with her carrying them). Thank goodness I found this bag. It is the perfect size for her, and her favorite color! The quality is excellent and the shoulder pad is very padded. It is large enough to carry a 3 ring binder, folders, and probably 2 (not huge) school books. I would not recommend this bag for an older school kid who had to carry around many books and binders, obviously, as this line (Wildkin) is made for children. I would highly recommend this bag for an elementary school student age group, it is wonderfully made, with lots of small pockets on the inside! 14 | 15 | cause event 3 16 | product category: Children Clothing 17 | product title: Levi's Girls' 715 Thick Stitch Bootcut Jean 18 | event time: 2013-09-01 19 | summary text: Thick Stitch Bootcut Jeans 20 | review text: I love Levis and this style is sooo cute! My daughter even likes them! The price was amazing in these, as we found similar Levis at Dillard's for $36. The waist of these Levi in particular run small compared to other Levis of the same size. During the Levis sale on Amazon, we purchased these, the Sweetie Skimmers, and the Scooter. All of the others ran almost half an inch larger in the waist. I worry the larger size would be huge on her, but I think I will try the larger size, as they must be comfortable on her or she will not wear them. 21 | 22 | cause event 4 23 | product category: Children Shoes 24 | product title: New Balance KJ750 Running Running Shoe (Lttle Kid/Big Kid) 25 | event time: 2013-02-02 26 | summary text: Cute and durable! 27 | review text: I ordered these for my 8 year old daughter. She usually destroys shoes within months, but these have held up! -------------------------------------------------------------------------------- /scripts/amazon/amazon_prompts/prompt_assistant_6.txt: -------------------------------------------------------------------------------- 1 | cause event 1 2 | product category: Men Uniforms, Work & Safety 3 | product title: Carhartt Men's Canvas Work Dungaree Pant B151 4 | event time: 2013-11-07 5 | summary text: Great Pants- Carhartt quality with a thin lining 6 | review text: Very comfy, waist of 33 was True to size, flexible in crotch to work comfortable in warehouse. Carhart quality, with a thin, (not very heavy) flannel lining. I have been wearing them in 20 degree weather and they keep me nice and warm, I ended up putting on suspenders as they are a little heavy to just wear with a belt, but this does not take away from the purpose, to keep you warm, and they do the job. 7 | 8 | cause event 2 9 | product category: Men Accessories 10 | product title: Danbury Men's 1 1/2" Basic Bridle Leather Work Belt 11 | event time: 2013-08-30 12 | summary text: Great Pants- Carhartt quality with a thin lining 13 | review text: I carry a firearm for work- this belt does the trick, the stictching seems good. I am normally a 34, but it did not give much leather after you put it through the loop, so I ordered the 36 and it is fine. You cant beat the price. I like one piece of leather which this is. 
With lined belts they are harder to get the holster through, and unless you spend money for genuine leather lined, you will get a "bonded" lining which is not real. 14 | 15 | cause event 3 16 | product category: Uniforms, Work & Safety 17 | product title: Carhartt Men's Canvas Work Dungaree Pant B151 18 | event time: 2013-07-29 19 | summary text: Comfortable lightweight pants 20 | review text: They came pretty long, so they had to be hemmed, but not a big deal. They are very comfortable, I would have been happy with one more pocket, they have the hammer loop, the usually 4 pockets, plus one thin pocket on right leg. I could have used another on left. But all in all good pants. I use them for security, and walk all day with no chafing !! -------------------------------------------------------------------------------- /scripts/amazon/amazon_prompts/prompt_assistant_7.txt: -------------------------------------------------------------------------------- 1 | cause event 1 2 | product category: Men Jewelry 3 | product title: Men's Square Face Faux Onyx Sterling Silver Ring 4 | event time: 2012-12-30 5 | summary text: Nice Look perfect size 6 | review text: It is the same in the picture with nice look and fit size. I like these stones and like it on my hand. will by more soon thanks. 7 | 8 | cause event 2 9 | product category: Women Jewelry 10 | product title: 925 Sterling Silver Bali Inspired Filigree Black Onyx Gemstone Triangle Fan Dangle Earrings 11 | event time: 2012-12-30 12 | summary text: TOO SMALL 13 | review text: It is too small on the ear, was waiting for something bigger, but same on picture. wish this is helpful. 14 | 15 | cause event 3 16 | product category: Women Clothing 17 | product title: US Fairytailes Mother of the Bride Formal Evening Dress 18 | event time: 2012-11-28 19 | summary text: amazing 20 | review text: Awesome with excellent clothe , nice look, with acceptable price. With smooth touch. As described above. Keep going . I like the product and the store i will purchase more. 21 | 22 | cause event 4 23 | product category: Women Watches 24 | product title: Swatch Ladies Watches LK294G - WW 25 | event time: 2012-11-28 26 | summary text: Nice But scratched front 27 | review text: It is a nice product I bought it as a gift and it is nice on hand BUT Front glasses is scratched which make a not good point for the store. It is too small on the hand , with plastic back front for the watch which gives a cheap point to the product. Like for the price and the shape and the other parts of it. -------------------------------------------------------------------------------- /scripts/amazon/amazon_prompts/prompt_assistant_8.txt: -------------------------------------------------------------------------------- 1 | cause event 1 2 | product category: Women Jewelry 3 | product title: Small Stainless Steel Celtic Triquetra Trinity Dangle Earrings, 1 3/4 inch 4 | event time: 2013-12-24 5 | summary text: Love these earrings 6 | review text: These earrings were exactly what I was expecting -- dangly, well made and looked like the picture. I am always looking for fish hook dangly stainless steel up to sterling silver earrings. These are joining my collection. I often surf Amazon's site for earrings of this type and have acquired quite a number of pairs. I will keep shopping at Amazon for these types of earrings. 
7 | 8 | cause event 2 9 | product category: Costumes & Accessories 10 | product title: Sterling Silver Celtic Knot and Created Blue Sapphire Ring 11 | event time: 2013-10-06 12 | summary text: Really nice wig! 13 | review text:The hair on this wig is more close to natural feeling hair rather than the straw-like materials used in most Halloween wigs. It does have to be styled and doesn't quite stand up like the picture. I'm looking forward to wearing it with my Halloween costume: California Costumes Deluxe Hooded Robe Adult Costume, Red/Black that I also purchased from Amazon. 14 | 15 | cause event 3 16 | product category: Women Jewelry 17 | product title: Sterling Silver Circle Link Drop Earrings, 2.16" 18 | event time: 2013-08-25 19 | summary text: These are gorgeous! 20 | review text:I love these earrings. They are just the right length of dangly and are comfortable to wear. I gave them a "4" only because when I took them out of the box, they were all tangled up and it took me about 15 minutes to get them straightened out. I have to store them hanging down and be careful or they will tangle yet again. I would buy another pair. 21 | 22 | cause event 4 23 | product category: Women Jewelry 24 | product title: Sterling Silver Angel Wing Drop Earrings 25 | event time: 2014-06-23 26 | summary text: These are OK 27 | review text: I like these earrings, but personally, they are flat and not all that pretty. There is nothing wrong with them, just that they do not look like what I expected. -------------------------------------------------------------------------------- /scripts/amazon/amazon_prompts/prompt_assistant_9.txt: -------------------------------------------------------------------------------- 1 | cause event 1 2 | product category: Men Clothing 3 | product title: Russell Athletic Men's Dri-Power Fleece Sweatshirt 4 | event time: 2014-05-05 5 | summary text: SWEATSHIRT 6 | review text: I HAVE BOUGHT THIS BRAND SWEATSHIRT BEFORE AND I LIKE THEM VERY MUCH. THEY ARRIVED WHEN EXPECTED, AND FIT WELL AND WERE IN EXCELLENT CONDITION. 7 | 8 | cause event 2 9 | product category: Men Uniforms, Work & Safety 10 | product title: Dickies Men's Big-Tall Long-Sleeve Work Shirt 11 | event time: 2014-03-13 12 | summary text: DICKIE SHIRT 13 | review text:I HAVE BEEN BUYING THESE SHIRTS FOR OVER 30 YEARS. I WORK CONSTRUCTION AND THESE SHIRTS HOLD UP. THE SHIRT FIT WELL AND THE COLOR IS FINE. IT IS WHAT I EXPECTED AND IAM VERY SATISFIED. 14 | 15 | cause event 3 16 | product category: Men Clothing 17 | product title: Fruit of the Loom Men's boxer brief 18 | event time: 2014-03-13 19 | summary text: BOXER BRIEFS 20 | review text:THESE SHORTS ARE VERY COMFORTABLE AND FIT VERY WELL. I WOULD DEFINATELY BUY THESE SHORTS AGAIN. THEY ARIIVED WHEN EXPECTED AND I AM SATISFIED. 21 | 22 | cause event 4 23 | product category: Men Clothing 24 | product title: Dockers Men's 4 Pack Fly Front Full Rise Brief, White 25 | event time: 2014-03-13 26 | summary text: BRIEFS 27 | review text: THESE ARE A GOOD QUALITY WELL CONSTRUCTED BRIEF. I WOULD BUY THESE AGAIN, AND RECOMMEND THEM TO OTHERS. ARRIVED AS EXPECTED AND WERE IN PERFECT CONDITION. 
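The prompt_assistant_*.txt files above and the prompt_user_*.txt files that follow form sixteen user/assistant pairs of few-shot examples for querying ChatGPT about cause events. The snippet below is only a minimal sketch of how such a conversation could be assembled, assuming the help message opens the dialogue as a user turn, each prompt_user_i.txt / prompt_assistant_i.txt pair is appended as one in-context example, and the legacy openai<1.0 ChatCompletion interface is used; PROMPT_DIR, NUM_SHOTS, the reading order and the model name are illustrative assumptions rather than the repository's actual query logic.

import os

import openai

PROMPT_DIR = 'scripts/amazon/amazon_prompts'  # directory holding the prompt files shown here
NUM_SHOTS = 16                                # one prompt_user_i / prompt_assistant_i pair per shot


def read_prompt(name):
    # load one prompt file as plain text
    with open(os.path.join(PROMPT_DIR, name), encoding='utf-8') as f:
        return f.read().strip()


def build_messages(effect_event):
    # help_msg_1.txt states the abductive-reasoning task and the allowed product categories
    messages = [{'role': 'user', 'content': read_prompt('help_msg_1.txt')}]
    # each user/assistant pair is one worked example: effect event -> two to four cause events
    for i in range(1, NUM_SHOTS + 1):
        messages.append({'role': 'user', 'content': read_prompt(f'prompt_user_{i}.txt')})
        messages.append({'role': 'assistant', 'content': read_prompt(f'prompt_assistant_{i}.txt')})
    # finally ask for the cause events of the new effect event
    messages.append({'role': 'user', 'content': effect_event})
    return messages


if __name__ == '__main__':
    openai.api_key = os.environ['OPENAI_API_KEY']
    query = 'effect event\nproduct category: Men Shoes\nevent time: 2015-05-01'
    response = openai.ChatCompletion.create(
        model='gpt-3.5-turbo',  # model choice is an assumption for illustration
        messages=build_messages(query),
    )
    print(response['choices'][0]['message']['content'])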
-------------------------------------------------------------------------------- /scripts/amazon/amazon_prompts/prompt_user_1.txt: -------------------------------------------------------------------------------- 1 | effect event 2 | product category: Children Shoes 3 | event time: 2014-03-27 4 | -------------------------------------------------------------------------------- /scripts/amazon/amazon_prompts/prompt_user_10.txt: -------------------------------------------------------------------------------- 1 | effect event 2 | product category: Shoe, Jewelry & Watch Accessories 3 | event time: 2014-07-26 -------------------------------------------------------------------------------- /scripts/amazon/amazon_prompts/prompt_user_11.txt: -------------------------------------------------------------------------------- 1 | effect event 2 | product category: Men Clothing 3 | event time: 2015-06-03 -------------------------------------------------------------------------------- /scripts/amazon/amazon_prompts/prompt_user_12.txt: -------------------------------------------------------------------------------- 1 | effect event 2 | product category: Children Shoes 3 | event time: 2014-04-08 -------------------------------------------------------------------------------- /scripts/amazon/amazon_prompts/prompt_user_13.txt: -------------------------------------------------------------------------------- 1 | effect event 2 | product category: Others 3 | event time: 2014-12-06 -------------------------------------------------------------------------------- /scripts/amazon/amazon_prompts/prompt_user_14.txt: -------------------------------------------------------------------------------- 1 | effect event 2 | product category: Women Clothing 3 | event time: 2015-07-13 -------------------------------------------------------------------------------- /scripts/amazon/amazon_prompts/prompt_user_15.txt: -------------------------------------------------------------------------------- 1 | effect event 2 | product category: Children Clothing 3 | event time: 2015-04-27 -------------------------------------------------------------------------------- /scripts/amazon/amazon_prompts/prompt_user_16.txt: -------------------------------------------------------------------------------- 1 | effect event 2 | product category: Men Clothing 3 | event time: 2015-08-21 -------------------------------------------------------------------------------- /scripts/amazon/amazon_prompts/prompt_user_2.txt: -------------------------------------------------------------------------------- 1 | effect event 2 | product category: Children Clothing 3 | event time: 2015-09-01 4 | -------------------------------------------------------------------------------- /scripts/amazon/amazon_prompts/prompt_user_3.txt: -------------------------------------------------------------------------------- 1 | effect event 2 | product category: Men Clothing 3 | event time: 2015-01-28 -------------------------------------------------------------------------------- /scripts/amazon/amazon_prompts/prompt_user_4.txt: -------------------------------------------------------------------------------- 1 | effect event 2 | product category: Men Shoes 3 | event time: 2014-11-19 -------------------------------------------------------------------------------- /scripts/amazon/amazon_prompts/prompt_user_5.txt: -------------------------------------------------------------------------------- 1 | effect event 2 | product category: Luggage & Travel Gear 3 | event time: 2013-10-19 
-------------------------------------------------------------------------------- /scripts/amazon/amazon_prompts/prompt_user_6.txt: -------------------------------------------------------------------------------- 1 | effect event 2 | product category: Men Uniforms, Work & Safety 3 | event time: 2015-08-29 -------------------------------------------------------------------------------- /scripts/amazon/amazon_prompts/prompt_user_7.txt: -------------------------------------------------------------------------------- 1 | effect event 2 | product category: Women Jewelry 3 | event time: 2013-02-18 -------------------------------------------------------------------------------- /scripts/amazon/amazon_prompts/prompt_user_8.txt: -------------------------------------------------------------------------------- 1 | effect event 2 | product category: Women Jewelry 3 | event time: 2014-05-18 -------------------------------------------------------------------------------- /scripts/amazon/amazon_prompts/prompt_user_9.txt: -------------------------------------------------------------------------------- 1 | effect event 2 | product category: Men Clothing 3 | event time: 2014-07-08 -------------------------------------------------------------------------------- /scripts/amazon/chatgpt_script.txt: -------------------------------------------------------------------------------- 1 | input: We are using Amazon Review Data, published by Jianmo Ni, UCSD, 2018, which is firstly used in the paper "Justifying recommendations using distantly-labeled reviews and fined-grained aspects, EMNLp 2019. 2 | 3 | The dataset includes millions of reviews (ratings, text, helpfulness votes), product metadata (descriptions, category information, price, brand, and image features) in the range May 1996 - Oct 2018. Do you remember any of the review information if I give you a review id? 4 | 5 | ChatGPT output: As an AI language model, I am not capable of remembering any review information unless I have been programmed to store and retrieve specific pieces of information. However, I can search for and retrieve review information from the dataset if given a review ID. 6 | 7 | input: can you recall any information with reviewer ID A28GK1G2KDXHRP on date 04 17, 2014. 8 | 9 | ChatGPT output: I apologize, but the Amazon Review dataset does not include information about the specific date when a review was posted. It only includes the year and month when the review was posted. Therefore, I cannot provide you with any information about reviews posted by reviewer ID A28GK1G2KDXHRP on April 17, 2014. 10 | 11 | input: Reviewer A0148968UM59JS3Y8D1M bought Hanes Men's Classic Crew on 01 18, 2014, what review did he wrote? 12 | 13 | ChatGPT output: I apologize, but as an AI language model, I do not have access to real-time or updated information about specific Amazon reviews or purchases made by individual reviewers. Additionally, the Amazon Review dataset does not include information about the specific content of a specific review. It only includes metadata such as the reviewer ID, product ID, and overall rating. Therefore, I cannot provide you with the review written by the reviewer A0148968UM59JS3Y8D1M for the Hanes Men's Classic Crew on January 18, 2014 or any other date. 14 | 15 | input: Reviewer A0148968UM59JS3Y8D1M bought Hanes Men's Classic Crew on 01 18, 2014, his review starts with 'they seem very well made..', can you recall the complete review he wrote? 
16 | 17 | ChatGPT output: I apologize, but as an AI language model, I do not have access to real-time or updated information about specific Amazon reviews or purchases made by individual reviewers. Additionally, the Amazon Review dataset does not include information about the specific content of a specific review. It only includes metadata such as the reviewer ID, product ID, and overall rating. Therefore, I cannot provide you with the complete review written by the reviewer A0148968UM59JS3Y8D1M for the Hanes Men's Classic Crew on January 18, 2014, or any other date. 18 | 19 | input: on which date did reviewer A0148968UM59JS3Y8D1M buy Hanes Men's 7-Pack ComfortSoft Briefs? 20 | 21 | ChatGPT output: Sorry, as an AI language model, I do not have access to personal information or purchase history of individuals. 22 | -------------------------------------------------------------------------------- /scripts/amazon/check_script.txt: -------------------------------------------------------------------------------- 1 | input: We are using Amazon Review Data, published by Jianmo Ni, UCSD, 2018, which is firstly used in the paper "Justifying recommendations using distantly-labeled reviews and fined-grained aspects, EMNLp 2019. 2 | 3 | The dataset includes millions of reviews (ratings, text, helpfulness votes), product metadata (descriptions, category information, price, brand, and image features) in the range May 1996 - Oct 2018. Do you remember any of the review information if I give you a review id? 4 | 5 | ChatGPT output: As an AI language model, I am not capable of remembering any review information unless I have been programmed to store and retrieve specific pieces of information. However, I can search for and retrieve review information from the dataset if given a review ID. 6 | 7 | input: can you recall any information with reviewer ID A28GK1G2KDXHRP on date 04 17, 2014. 8 | 9 | ChatGPT output: I apologize, but the Amazon Review dataset does not include information about the specific date when a review was posted. It only includes the year and month when the review was posted. Therefore, I cannot provide you with any information about reviews posted by reviewer ID A28GK1G2KDXHRP on April 17, 2014. 10 | 11 | input: Reviewer A0148968UM59JS3Y8D1M bought Hanes Men's Classic Crew on 01 18, 2014, what review did he wrote? 12 | 13 | ChatGPT output: I apologize, but as an AI language model, I do not have access to real-time or updated information about specific Amazon reviews or purchases made by individual reviewers. Additionally, the Amazon Review dataset does not include information about the specific content of a specific review. It only includes metadata such as the reviewer ID, product ID, and overall rating. Therefore, I cannot provide you with the review written by the reviewer A0148968UM59JS3Y8D1M for the Hanes Men's Classic Crew on January 18, 2014 or any other date. 14 | 15 | input: Reviewer A0148968UM59JS3Y8D1M bought Hanes Men's Classic Crew on 01 18, 2014, his review starts with 'they seem very well made..', can you recall the complete review he wrote? 16 | 17 | ChatGPT output: I apologize, but as an AI language model, I do not have access to real-time or updated information about specific Amazon reviews or purchases made by individual reviewers. Additionally, the Amazon Review dataset does not include information about the specific content of a specific review. It only includes metadata such as the reviewer ID, product ID, and overall rating. 
Therefore, I cannot provide you with the complete review written by the reviewer A0148968UM59JS3Y8D1M for the Hanes Men's Classic Crew on January 18, 2014, or any other date. 18 | 19 | input: on which date did reviewer A0148968UM59JS3Y8D1M buy Hanes Men's 7-Pack ComfortSoft Briefs? 20 | 21 | ChatGPT output: Sorry, as an AI language model, I do not have access to personal information or purchase history of individuals. 22 | 23 | 24 | 25 | A0148968UM59JS3Y8D1M,B000LR8JJ6,"01 18, 2014",1390003200,170,Clothing,shirts,they seem very well made and are very comfortable to wear. i do wear them outside alot in warm weather,Hanes Men's Classic Crew (Pack of 3),Hanes Classics is superior quality underwear with classic styling for discerning consumers. 26 | -------------------------------------------------------------------------------- /scripts/amazon/step_1_extract_data.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import warnings 3 | 4 | import pandas as pd 5 | 6 | warnings.filterwarnings('ignore') 7 | 8 | 9 | def parse(path): 10 | g = gzip.open(path, 'rb') 11 | for l in g: 12 | yield eval(l.decode("utf-8").replace('true', 'True').replace('false', 'False')) 13 | 14 | 15 | def getDF(path): 16 | i = 0 17 | df = {} 18 | for d in parse(path): 19 | if i < 5000000: 20 | i += 1 21 | continue 22 | df[i] = d 23 | i += 1 24 | if i % 100000 == 0: 25 | print(i) 26 | if i > 10000000: 27 | break 28 | 29 | return pd.DataFrame.from_dict(df, orient='index') 30 | 31 | 32 | if __name__ == '__main__': 33 | review_df = pd.read_csv('review_df.csv', header=0) 34 | # meta_df = getDF(meta_dir) 35 | meta_df = pd.read_csv('meta_df.csv', header=0) 36 | meta_cols = ['category', 'description', 'brand', 'title', 'asin'] 37 | meta_df = meta_df[meta_cols] 38 | res = pd.merge(review_df, meta_df[meta_cols], on='asin') 39 | res.to_csv('merge_df_v2.csv', header=True, index=False) 40 | -------------------------------------------------------------------------------- /scripts/amazon/step_2_clean_data.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | from sklearn.preprocessing import LabelEncoder 4 | 5 | 6 | class AmazonDataProcessor: 7 | def __init__(self, 8 | source_dir, 9 | target_clean_dir='clean_v0223.csv', 10 | target_user_filter_dir='filer_user_v0306.csv'): 11 | self.source_dir = source_dir 12 | self.target_clean_dir = target_clean_dir 13 | self.target_user_filter_dir = target_user_filter_dir 14 | 15 | @staticmethod 16 | def retrieve_description_text(x): 17 | text = eval(x)[0] if not pd.isnull(x) else x 18 | return text 19 | 20 | @staticmethod 21 | def retrieve_cate_text(x): 22 | text = eval(x) 23 | text = text[1:] 24 | if text[0] in ['Women', 'Men']: 25 | if len(text) > 1: # ['Men' 'Shoes'] => 'Men Shoes' 26 | res = text[0] + ' ' + text[1] 27 | else: # ['Man'] = > 'Men Others' 28 | res = text[0] + ' General' 29 | elif text[0] in ['Boys', 'Girls', 'Baby', 'Baby Girls', 'Baby Boys']: 30 | if len(text) > 1: # ['Boys' 'Shoes'] => 'Men Shoes' 31 | if text[1] not in ['Baby', 'Baby Girls', 'Baby Boys']: 32 | res = 'Children ' + text[1] 33 | elif len(text) > 2: 34 | res = 'Children ' + text[2] 35 | else: 36 | res = 'Children General' 37 | else: # ['Man'] = > 'Boys Others' 38 | res = 'Children General' 39 | else: 40 | res = text[0] 41 | 42 | return res 43 | 44 | @staticmethod 45 | def cate_map(cate_text, cate_event_map_df, target_col): 46 | res = cate_event_map_df[cate_event_map_df['cate_text'] == 
cate_text][target_col].to_list()[0] 47 | return res 48 | 49 | def _1_clean_df(self): 50 | source_df = pd.read_csv(self.source_dir, header=0) 51 | print(len(source_df)) 52 | keep_cols = ['reviewerID', 'asin', 'reviewTime', 'unixReviewTime', 'cate_id', 'cate_text', 53 | 'summary_text', 'review_text', 'title_text', 'description_text'] 54 | 55 | source_df['cate_text'] = source_df['category'].apply(self.retrieve_cate_text) 56 | print(len(source_df)) 57 | source_df['description_text'] = source_df['description'].apply(self.retrieve_description_text) 58 | 59 | print(len(source_df)) 60 | 61 | source_df.rename(columns={'summary': 'summary_text', 'reviewText': 'review_text', 62 | 'title': 'title_text'}, inplace=True) 63 | 64 | # pd.DataFrame(source_df['cate_text'].value_counts(normalize=True).reset_index()).to_csv('cate_type.csv', 65 | # header=True, index=False) 66 | 67 | 68 | # encode the category as the event type 69 | le = LabelEncoder() 70 | source_df['cate_id'] = le.fit_transform(source_df['cate_text']) 71 | 72 | # too slow to do the encoding here, so these lines are commented out 73 | # cate_map_df = pd.read_csv('cate_type.csv', header=0) 74 | # source_df['cate_id'] = source_df['cate_text'].apply(lambda x: self.cate_map(x, cate_map_df, target_col='cate_id')) 75 | # source_df['cate_text_clean'] = source_df['cate_text'].apply(lambda x: self.cate_map(x, cate_map_df, target_col='cate_text_clean')) 76 | # source_df.drop(columns=['cate_text'], inplace=True) 77 | source_df[keep_cols].to_csv(self.target_clean_dir, index=False, header=True) 78 | 79 | # 939254400 80 | print(min(source_df['unixReviewTime'])) 81 | print(len(np.unique(source_df['cate_id']))) # 436 82 | print(len(source_df)) 83 | 84 | return 85 | 86 | def _2_filter_users(self, num_users=2500): 87 | merge_df = pd.read_csv(self.target_clean_dir, header=0) 88 | print(len(merge_df)) 89 | 90 | res_1 = pd.DataFrame() 91 | idx = 0 92 | for name, group in merge_df.groupby(['reviewerID']): 93 | group.sort_values('unixReviewTime', inplace=True) 94 | # drop exact duplicates 95 | group.drop_duplicates(inplace=True, keep='first') 96 | 97 | # drop duplicates that have the same values on time, cate_id, cate_text and the text fields 98 | group.drop_duplicates(subset=['unixReviewTime', 99 | 'cate_id', 100 | 'cate_text', 101 | 'summary_text', 102 | 'review_text', 103 | 'title_text', 104 | 'description_text'], inplace=True, keep='first') 105 | 106 | if len(group) < 20: 107 | continue 108 | 109 | idx += 1 110 | print(idx) 111 | group.index = list(range(len(group))) 112 | 113 | # blank out invalid title cells (leftover page scripts) 114 | for i in range(len(group)): 115 | if 'var aPageStart' in group.loc[i, 'title_text']: 116 | group.loc[i, 'title_text'] = '' 117 | 118 | group['reviewerID'] = len(group) * [name] 119 | res_1 = pd.concat([res_1, pd.DataFrame(group)]) 120 | if idx > num_users: 121 | break 122 | 123 | # remap the category ids 124 | # done here because it is much faster than doing it for the whole dataset 125 | cate_map_df = pd.read_csv('cate_type.csv', header=0) 126 | res_1['cate_id'] = res_1['cate_text'].apply(lambda x: self.cate_map(x, cate_map_df, target_col='cate_id')) 127 | res_1['cate_text_clean'] = res_1['cate_text'].apply( 128 | lambda x: self.cate_map(x, cate_map_df, target_col='cate_text_clean')) 129 | 130 | res_1.to_csv(self.target_user_filter_dir, header=True, index=False) 131 | 132 | return 133 | 134 | 135 | if __name__ == '__main__': 136 | # source_dir = '../../scripts/amazon/merge_df_v2.csv' 137 | # target_clean_dir = '../../scripts/amazon/clean_v0307.csv' 138 | # 
target_user_filter_dir = 'filer_user_v0308.csv' 139 | # data_processor = AmazonDataProcessor(source_dir=source_dir, 140 | # target_clean_dir=target_clean_dir, 141 | # target_user_filter_dir=target_user_filter_dir) 142 | # data_processor._2_filter_users() 143 | 144 | target_clean_dir = '../../scripts/amazon/clean_v0307.csv' 145 | source_df = pd.read_csv(target_clean_dir, header=0) 146 | # print(min(source_df['unixReviewTime'])) # 939254400 147 | # print(len(np.unique(source_df['cate_id']))) # 155 148 | # print(len(source_df)) # 3641673 # 10296612 149 | target_id = 'A117Q3W4LPC9VL' 150 | source_df = source_df[source_df['reviewerID']==target_id] 151 | source_df.to_csv('tmp.csv',header=True, index=False) 152 | print(source_df) -------------------------------------------------------------------------------- /scripts/amazon/step_3_make_model_input.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | from datetime import datetime 3 | 4 | import numpy as np 5 | import pandas as pd 6 | 7 | 8 | class InputProcessor: 9 | def __init__(self, source_csv_dir, train_end_date, valid_end_date): 10 | self.source_csv_dir = source_csv_dir 11 | self.source_df = pd.read_csv(source_csv_dir, header=0, encoding='latin-1') 12 | self.train_end_date = datetime.strptime(train_end_date, '%Y-%m-%d') 13 | self.valid_end_date = datetime.strptime(valid_end_date, '%Y-%m-%d') 14 | 15 | def data_analysis(self, train_end_date, valid_end_date): 16 | self.source_df['review_time'] = pd.to_datetime(self.source_df['reviewTime']) 17 | 18 | train_df = self.source_df[self.source_df['review_time'] < train_end_date] 19 | print(f'num train data {len(train_df)}') 20 | 21 | valid_df = self.source_df[self.source_df['review_time'] < valid_end_date] 22 | print(f'num valid data {len(valid_df) - len(train_df)}') 23 | 24 | print(f'num test data {len(self.source_df) - len(valid_df)}') 25 | 26 | return 27 | 28 | def validate_user_df(self): 29 | user_list = np.unique(self.source_df['reviewerID']) 30 | print(f'num user {len(user_list)}') 31 | 32 | seq_len_list = [] 33 | 34 | self.source_df['event_time'] = self.source_df['unixReviewTime'] / 86400 35 | 36 | res = pd.DataFrame() 37 | for idx, user in enumerate(user_list): 38 | user_df = self.source_df[self.source_df['reviewerID'] == user] 39 | user_df.sort_values(by=['unixReviewTime'], inplace=True) 40 | user_df.index = np.arange(len(user_df)) 41 | user_df['event_dtime'] = user_df['event_time'].diff() 42 | user_df['review_time'] = pd.to_datetime(user_df['reviewTime']) 43 | seq_len_list.append(len(user_df)) 44 | 45 | for row_idx in range(len(user_df)): 46 | if row_idx == 0: 47 | user_df.loc[0, 'event_dtime'] = 0.0 48 | prev_event_time = user_df.loc[0, 'event_time'] 49 | continue 50 | 51 | # if reviews at the same day, we set it as a random value + 3 hours interval 52 | if user_df.loc[row_idx, 'event_dtime'] == 0: 53 | delta_t = 0.1 + np.random.uniform(-1, 1) * 0.05 54 | user_df.loc[row_idx, 'event_dtime'] = delta_t 55 | # need to update event_time 56 | user_df.loc[row_idx, 'event_time'] = prev_event_time + delta_t 57 | else: 58 | user_df.loc[row_idx, 'event_dtime'] = user_df.loc[row_idx, 'event_time'] - user_df.loc[ 59 | row_idx - 1, 'event_time'] 60 | 61 | prev_event_time = user_df.loc[row_idx, 'event_time'] 62 | 63 | # check 64 | diff_event_time = user_df['event_time'].diff().dropna() 65 | event_dtime = user_df['event_dtime'].values[1:] 66 | assert sum(abs(diff_event_time - event_dtime)) < 0.0001 67 | 68 | res = pd.concat([res, user_df]) 69 | 70 | 
print(f'avg seq len {np.mean(seq_len_list)}') 71 | print(f'max seq len {np.max(seq_len_list)}') 72 | print(f'min seq len {np.min(seq_len_list)}') 73 | 74 | 75 | min_event_time = np.min(res['event_time']) 76 | res['event_time'] -= min_event_time 77 | 78 | return res 79 | 80 | def retrieve_text(self, input_text, num_tokens): 81 | if pd.isnull(input_text): 82 | return '' 83 | else: 84 | return input_text.replace('\n', ' ').replace('\r', '')[:num_tokens] 85 | 86 | def make_seq(self, total_user_df): 87 | user_list = np.unique(total_user_df['reviewerID']) 88 | res_seqs = {} 89 | for idx, user in enumerate(user_list): 90 | user_df = total_user_df[total_user_df['reviewerID'] == user] 91 | user_df.sort_values(by=['event_time'], inplace=True) 92 | user_df.index = np.arange(len(user_df)) 93 | user_seq = [] 94 | for i in range(len(user_df)): 95 | cate_text = self.retrieve_text(user_df.loc[i, 'cate_text_clean'], 50) 96 | title_text = self.retrieve_text(user_df.loc[i, 'title_text'], 50) 97 | summary_text = self.retrieve_text(user_df.loc[i, 'summary_text'], 50) 98 | review_text = self.retrieve_text(user_df.loc[i, 'review_text'], 1000) 99 | event_date = user_df.loc[i, 'reviewTime'] 100 | event_date = datetime.strptime(event_date, '%m %d, %Y') 101 | event_date = event_date.strftime('%Y-%m-%d') 102 | 103 | event_text = 'product category:' + cate_text + ';product title:' + title_text + ';event time:' \ 104 | + event_date + ';summary text:' + summary_text + ';review text:' + review_text + '.' 105 | 106 | temp_dict = {'event_date': event_date, 107 | 'event_time': user_df.loc[i, 'event_time'], 108 | 'event_dtime': user_df.loc[i, 'event_dtime'], 109 | 'event_type': user_df.loc[i, 'cate_id'], 110 | 'event_text': event_text} 111 | user_seq.append(temp_dict) 112 | 113 | res_seqs[idx] = user_seq 114 | 115 | return res_seqs 116 | 117 | def save_to_pkl(self): 118 | 119 | total_df = self.validate_user_df() 120 | total_seqs = self.make_seq(total_df) 121 | 122 | with open('amazon_v0327.pkl', "wb") as f_out: 123 | pickle.dump( 124 | { 125 | "dim_process": 24, 126 | 'user_seqs': total_seqs 127 | }, f_out 128 | ) 129 | 130 | return 131 | 132 | 133 | if __name__ == '__main__': 134 | source_dir = 'filer_user_v0308.csv' 135 | processor = InputProcessor(source_dir, train_end_date='2015-08-01', valid_end_date='2016-02-01') 136 | processor.data_analysis(train_end_date='2015-08-01', valid_end_date='2016-02-01') 137 | # # 46955, 6597, 12994 '2015-08-01', '2016-02-01' 138 | # # 49302, 5283, 11961 139 | # # 44561, 7844, 14141 140 | # processor.save_to_pkl() 141 | 142 | # entity_map = pd.read_csv('cate_type.csv', header=0) 143 | # res = entity_map[['cate_id', 'cate_text_clean']].loc[:24, :] 144 | # res = dict(res.values) 145 | # 146 | # with open('amazon_type_map.pkl', 'wb') as handle: 147 | # pickle.dump(res, handle) 148 | -------------------------------------------------------------------------------- /scripts/amazon/step_4_query_chatgpt.py: -------------------------------------------------------------------------------- 1 | import os 2 | import openai 3 | import elara 4 | from concurrent.futures import ThreadPoolExecutor 5 | 6 | from utils.general import file_uri_reader_processor 7 | from utils.gpt_utils import generate_gdelt_prompt_amazon 8 | import dictdatabase as DDB 9 | 10 | 11 | def read_text_file(fn): 12 | with open(fn, 'r') as file: 13 | return file.read() 14 | 15 | 16 | class EventQuery: 17 | def __init__(self, api_key, prompt_folder: str, num_prompts: int = 12): 18 | openai.api_key = api_key 19 | self.setup_msgs = [] 20 | 
21 | # process prompts 22 | system_msgs = [] 23 | prompt_assistant_msgs = [] 24 | prompt_user_msgs = [] 25 | help_msgs = [] 26 | if not os.path.exists(prompt_folder): 27 | raise RuntimeError(f'Prompt folder does not exist: {prompt_folder}') 28 | for fn in sorted(os.listdir(prompt_folder)): 29 | file_path = os.path.join(prompt_folder, fn) 30 | if not os.path.isfile(file_path): 31 | continue 32 | msg = read_text_file(file_path) 33 | if fn.startswith('system_msg'): 34 | system_msgs.append(msg) 35 | elif fn.startswith('prompt_assistant'): 36 | prompt_assistant_msgs.append(msg) 37 | elif fn.startswith('prompt_user'): 38 | prompt_user_msgs.append(msg) 39 | elif fn.startswith('help_msg'): 40 | help_msgs.append(msg) 41 | 42 | for msg in system_msgs: 43 | self.setup_msgs.append({ 44 | 'role': 'system', 45 | 'content': msg 46 | }) 47 | for msg in help_msgs: 48 | self.setup_msgs.append({ 49 | 'role': 'user', 50 | 'content': msg 51 | }) 52 | for user_msg, assistant_msg in zip(prompt_user_msgs[:num_prompts], prompt_assistant_msgs[:num_prompts]): 53 | self.setup_msgs.append({ 54 | 'role': 'user', 55 | 'content': user_msg 56 | }) 57 | self.setup_msgs.append({ 58 | 'role': 'assistant', 59 | 'content': assistant_msg 60 | }) 61 | 62 | def query(self, msg): 63 | msg_list = self.setup_msgs + [{ 64 | 'role': 'user', 65 | 'content': msg 66 | }] 67 | completions = openai.ChatCompletion.create( 68 | model="gpt-3.5-turbo", 69 | messages=msg_list, 70 | stream=False 71 | ) 72 | print('Usage:', completions['usage']['total_tokens']) 73 | body = completions['choices'][0]['message']['content'] 74 | return body 75 | # for i, choice in enumerate(completions['choices']): 76 | # print('---------------- choice', i) 77 | # print(choice['message']['content']) 78 | 79 | 80 | def gpt_generate_causal_events( 81 | db_base_name, 82 | gpt, 83 | pred_data, 84 | source_data, 85 | inference_type: str = 'type', 86 | top_k: int = 5, 87 | num_threads: int = 16 88 | ): 89 | msg_head = '''Now I give you an effect event, and you give me three to four cause events.\n\n''' 90 | 91 | def _process_one_type_or_time(idx, type_or_date, text): 92 | try: 93 | existed_dict = DDB.at(db_base_name, inference_type, key=str(idx)).read() 94 | if existed_dict is not None and str(type_or_date) in existed_dict: 95 | return 96 | ret_body = gpt.query(msg_head + text) 97 | with DDB.at(db_base_name, inference_type).session() as (sess, obj): 98 | if str(idx) in obj: 99 | obj[str(idx)][str(type_or_date)] = ret_body 100 | else: 101 | obj[str(idx)] = { 102 | str(type_or_date): ret_body 103 | } 104 | sess.write() 105 | print(idx, type_or_date) 106 | except Exception as e: 107 | print('Error:', str(e)) 108 | 109 | executor = ThreadPoolExecutor(max_workers=num_threads) 110 | init_db = DDB.at(db_base_name, inference_type).read() 111 | if init_db is None: 112 | DDB.at(db_base_name, inference_type).create() 113 | 114 | for prediction_seq in pred_data[-10:]: 115 | for one_prediction in prediction_seq: 116 | seq_idx = one_prediction['seq_idx'] 117 | original_idx = one_prediction['original_idx'] 118 | idx = (str(seq_idx), str(original_idx)) 119 | if inference_type == 'type': 120 | msg_content_dict = generate_gdelt_prompt_amazon(one_prediction, source_data, top_k=top_k, pred_type=True) 121 | else: 122 | msg_content_dict = generate_gdelt_prompt_amazon(one_prediction, source_data, top_k=top_k, pred_type=False) 123 | for rel_or_obj, content in msg_content_dict.items(): 124 | # skip if this rel_or_obj already exists for the idx 125 | if init_db is not None and str(idx) in init_db and
str(rel_or_obj) in init_db[str(idx)]: 126 | continue 127 | executor.submit(_process_one_type_or_time, str(idx), rel_or_obj, content) 128 | 129 | executor.shutdown(wait=True) 130 | 131 | 132 | if __name__ == '__main__': 133 | # predictions draw from base model 134 | base_model_prediction_fn = '../../logs/tpp_amazon_test.pkl' 135 | # original sequence data 136 | raw_seq_data_fn = '../../data/gdelt/amazon.pkl' 137 | 138 | gpt = EventQuery( 139 | api_key='*', 140 | prompt_folder='amazon_prompts', 141 | num_prompts=8 142 | ) 143 | 144 | pred_data = file_uri_reader_processor(base_model_prediction_fn) 145 | source_data = file_uri_reader_processor(raw_seq_data_fn)['user_seqs'] 146 | 147 | gpt_generate_causal_events( 148 | 'amazon_chatgpt', 149 | gpt, 150 | pred_data, 151 | source_data, 152 | inference_type='time', 153 | top_k=5, 154 | num_threads=6 155 | ) 156 | -------------------------------------------------------------------------------- /scripts/amazon/step_5_make_emb_dataset.py: -------------------------------------------------------------------------------- 1 | 2 | from utils.general import file_uri_reader_processor 3 | from utils.misc import make_samples_for_energy_function_amazon 4 | 5 | if __name__ == '__main__': 6 | # source_data_fn = '../../data/amazon/amazon_v0327.pkl' 7 | source_data_fn = 'amazon_emb.pkl' 8 | # pred_data_fn = '../../logs/ke_gdelt_test.pkl' 9 | pred_data_fn = '../../logs/tpp_test_v0426.pkl' 10 | 11 | source_data = file_uri_reader_processor(source_data_fn)['user_seqs'] 12 | pred_data = file_uri_reader_processor(pred_data_fn) 13 | 14 | make_samples_for_energy_function_amazon( 15 | gpt_db_name='amazon_chatgpt/type', 16 | source_data=source_data, 17 | pred_data=pred_data, 18 | pred_type='type', 19 | topk=5, 20 | # pred_type='object', 21 | # topk=20, 22 | ebm_db_name='anhp_amazon_ebm_dataset_v3', 23 | retro_top_n=5, # make the sequence longer 24 | distance_type='bert' 25 | ) 26 | -------------------------------------------------------------------------------- /scripts/gdelt/gdelt_prompts_simplified/help_msg_1.txt: -------------------------------------------------------------------------------- 1 | I given you an effect event and you give me four or five cause events. Each event consists of an event headline, an event predicate, an event time, subject name and object name of describing the event. 2 | 3 | The event predicates of the effect and cause events are restricted to 20 types, with names (in capital letters) and the descriptions below. 4 | 1. MAKE STATEMENT: decline comment, make pessimistic or optimistic comments. 5 | 2. APPEAL: appeal for material, economic, military, humanitarian aid or cooperation. 6 | 3. EXPRESS INTENT TO COOPERATE: Express intent to engage in material, diplomatic, military, judicial cooperation or provide aid. 7 | 4. CONSULT: make a visit, host a visit, meet at a third location, engage in mediation and negotiation. 8 | 5. ENGAGE IN DIPLOMATIC COOPERATION: praise or endorse, defend verbally, grant diplomatic recognition. Apologize. Forgive. 9 | 6. ENGAGE IN MATERIAL COOPERATION: Cooperate economically, militarily, judicially. 10 | 7. PROVIDE AID: provide economic, military, humanitarian aid. 11 | 8. YIELD: ease admin or political sanctions or military blockade, return and release, allow intl involvement and humanitarian access,declare truce. 12 | 9. INVESTIGATE: investigate crime, corruption, human rights abuses, military actions. 13 | 10. 
DEMAND: demand any type of cooperation, aid, political or military reforms, rights, easing of sanctions and dissent. 14 | 11. DISAPPROVE: criticize or denounce, accuse of crime, human rights abuses, complain officially and make lawsuit against. 15 | 12. REJECT: reject material, economic, military, judicial operations, requests or plans. 16 | 13. THREATEN: threaten to reduce aid, to boycott, to reduce or break relations, to impose sanctions, to ban political parties, to halt negotiations. 17 | 14. PROTEST: civilian demonstrations carried out as protests, e.g., rally for any type of change, conduct strikes or boycott. 18 | 15. EXHIBIT MILITARY POSTURE: e.g, mobilize or increase any forces. 19 | 16. REDUCE RELATIONS: reduce or break any relations, aids, stop any assistance, halt negotiations, expel inspectors or aid agencies. 20 | 17. COERCE: seize or damage properties, impose administrative sanctions or restrictions, arrest, expel individuals. 21 | 18. ASSAULT: use of unconventional forms of violence, e.g, physical or sexual assault, torture, conducting bombing, assassinate. 22 | 19. FIGHT: uses of conventional force, e.g, impose blocktrade, restrict movement, occupy territory, fight with small arms. 23 | 20. ENGAGE IN UNCONVENTIONAL MASS VIOLENCE. 24 | -------------------------------------------------------------------------------- /scripts/gdelt/gdelt_prompts_simplified/help_msg_2.txt: -------------------------------------------------------------------------------- 1 | Now I give you 16 examples of abductive reasoning consecutively. In each example, the first event is the effect event, the next three to five events are cause events that happen earlier. -------------------------------------------------------------------------------- /scripts/gdelt/gdelt_prompts_simplified/prompt_assistant_1.txt: -------------------------------------------------------------------------------- 1 | cause event 1 2 | event headline: Tougher Russia sanctions urged over Ukraine 'war crimes' 3 | event predicate: Yield 4 | event time: 2022-04-04 5 | subject name: RUSSIA 6 | object name: RUSSIA 7 | 8 | cause event 2 9 | event headline: PM: Ukraine to lose more than a trillion dollars due to war with Russia 10 | event predicate: Make statement 11 | event time: 2022-03-31 12 | subject name: UKRAINE 13 | object name: UKRAINE 14 | 15 | cause event 3 16 | event headline: Russia Mobilizes Reinforcements from Syria and Africa to Ukraine 17 | event predicate: Demand 18 | event time: 2022-03-31 19 | subject name: CENTRAL AFRICAN REPUBLIC 20 | object name: RUSSIA 21 | 22 | cause event 4 23 | event headline: Visa, Mastercard, American Express say they're suspending operations in Russia amid its invasion of Ukraine 24 | event predicate: EXPRESS INTENT TO COOPERATE 25 | event time: 2022-03-06 26 | subject name: COMPANIES 27 | object name: RUSSIAN 28 | -------------------------------------------------------------------------------- /scripts/gdelt/gdelt_prompts_simplified/prompt_assistant_10.txt: -------------------------------------------------------------------------------- 1 | cause event 1 2 | event headline: Taiwan question purely China's internal affair, nothing to do with Japan: Defense Spokesperson 3 | event predicate: Consult 4 | event time: 2022-03-31 5 | subject name: JAPAN 6 | object name: COLONEL 7 | 8 | cause event 2 9 | event headline: Taiwan: China has launched a disinformation campaign in Chinese on the Russian invasion of Ukrain 10 | event predicate: ENGAGE IN DIPLOMATIC COOPERATION 11 | event time: 2022-03-29 12 
| subject name: TAIWANESE 13 | object name: MEDIA 14 | 15 | cause event 3 16 | event headline: US, Filipino forces start war drills in region facing Taiwan 17 | event predicate: REJECT 18 | event time: 2022-03-28 19 | subject name: USA 20 | object name: CHINA 21 | 22 | cause event 4 23 | event headline: Taiwan braces itself for the unthinkable as Russia's invasion sets off alarm bells over China 24 | event predicate: Make statement 25 | event time: 2022-03-05 26 | subject name: BEIJING 27 | object name: UKRAINIAN 28 | 29 | cause event 5 30 | event headline: In New Year's speech, Taiwan president warns China against 'military adventurism' 31 | event predicate: Make statement 32 | event time: 2022-01-01 33 | subject name: CHINESE 34 | object name: TAIWAN 35 | -------------------------------------------------------------------------------- /scripts/gdelt/gdelt_prompts_simplified/prompt_assistant_11.txt: -------------------------------------------------------------------------------- 1 | cause event 1 2 | event headline: Qatar - Saudi Arabia lifts most Covid-19 restrictions for incoming travellers 3 | event predicate: YIELD( Ease administrative sanctions) 4 | event time: 2022-03-22 5 | subject name: DOHA 6 | object name: SAUDI ARABIA 7 | 8 | cause event 2 9 | event headline: Take That Shot and Shove It! Some States Continue to Push Vaccine Exemptions 10 | event predicate: Appeal 11 | event time: 2022-03-16 12 | subject name: EMPLOYER 13 | object name: EMPLOYEE 14 | 15 | cause event 3 16 | event headline: France to scrap Covid-19 test requirement for vaccinated travellers 17 | event predicate: EXPRESS INTENT TO COOPERATE 18 | event time: 2022-02-12 19 | subject name: PARIS 20 | object name: TRAVELLER 21 | 22 | case event 4 23 | event headline: French tourism bosses tell Macron to scrap his travel ban on UK holidaymakers 24 | event predicate: Appeal 25 | event time: 2022-01-01 26 | subject name: BELGIUM 27 | object name: PARIS 28 | 29 | case event 5 30 | event headline: Report: COVID-19 causes 71% drop in UK air travel 31 | event predicate: DISAPPROVE(Criticize or denounce) 32 | event time: 2022-01-01 33 | subject name: BRUSSELS 34 | object name: MEDIA 35 | -------------------------------------------------------------------------------- /scripts/gdelt/gdelt_prompts_simplified/prompt_assistant_12.txt: -------------------------------------------------------------------------------- 1 | cause event 1 2 | event headline: White House: Europe Determined To Ban Russian Energy Exports 3 | event predicate: REDUCE RELATIONS 4 | event time: 2022-04-22 5 | subject name: EUROPE 6 | object name: RUSSIAN 7 | 8 | cause event 2 9 | event headline: oil prices rise as europe weighs additional russia sanctions 10 | event predicate: Engage in material cooperation 11 | event time: 2022-04-04 12 | subject name: INVESTOR 13 | object name: YEMEN 14 | 15 | cause event 3 16 | event headline: G7 countries, including Canada, reject Russia's demand to pay for gas in rubles: Germany 17 | event predicate: Reject 18 | event time: 2022-03-28 19 | subject name: UKRAINE 20 | object name: RUSSIA 21 | 22 | cause event 4 23 | event headline: Russia orders troops into rebel-held regions in Ukraine 24 | event predicate: Make statement 25 | event time: 2022-02-22 26 | subject name: RUSSIA 27 | object name: REBEL 28 | -------------------------------------------------------------------------------- /scripts/gdelt/gdelt_prompts_simplified/prompt_assistant_13.txt: 
-------------------------------------------------------------------------------- 1 | cause event 1 2 | event headline: China BYD to Stop Producing Fossil-Fueled Vehicles on EV Shift 3 | event predicate: ENGAGE IN MATERIAL COOPERATION 4 | event time: 2022-04-04 5 | subject name: BEIJING 6 | object name: CHNBUS 7 | 8 | cause event 2 9 | event headline: byd seal ev looks like a worthy tesla model 3 competitor from china 10 | event predicate: Appeal 11 | event time: 2022-04-02 12 | subject name: LOCAL MEDIA 13 | object name: CHINA 14 | 15 | cause event 3 16 | event headline: January-February Global Power Battery installed TOP5 released: BYD defeats Panasonic! 17 | event predicate: Make statement 18 | event time: 2022-04-02 19 | subject name: CHINA 20 | object name: Panasonic 21 | 22 | cause event 4 23 | event headline: Chinese auto brands gain larger market share in first two months 24 | event predicate: Make statement 25 | event time: 2022-03-27 26 | subject name: CHINESE 27 | object name: MANUFACTURER 28 | -------------------------------------------------------------------------------- /scripts/gdelt/gdelt_prompts_simplified/prompt_assistant_14.txt: -------------------------------------------------------------------------------- 1 | cause event 1 2 | event headline: North Korea fires ballistic missile in extension of testing 3 | event predicate: DISAPPROVE 4 | event time: 2022-03-06 5 | subject name: SOUTH KOREA 6 | object name: NORTH KOREAN 7 | 8 | cause event 2 9 | event headline: North Korea's Kim Jong-un fires 9th missile this year as World War 3 fears escalate 10 | event predicate: Make statement 11 | event time: 2022-03-04 12 | subject name: RADIO FREE ASIA 13 | object name: UNITPARTY MEMBER 14 | 15 | cause event 3 16 | event headline: North Korea's Kim vows to boost agriculture, defenses amid uncertainties 17 | event predicate: EXPRESS INTENT TO COOPERATE 18 | event time: 2022-01-01 19 | subject name: RULING PARTY 20 | object name: UNITED STATES 21 | -------------------------------------------------------------------------------- /scripts/gdelt/gdelt_prompts_simplified/prompt_assistant_15.txt: -------------------------------------------------------------------------------- 1 | cause event 1 2 | event headline: Parliament of Canada supports visa-free travel from Ukraine 3 | event predicate: ENGAGE IN DIPLOMATIC COOPERATION( Praise or endorse) 4 | event time: 2022-03-30 5 | subject name: CANADA 6 | object name: REGIME 7 | 8 | cause event 2 9 | event headline: Manitoba settlement agencies want federal supports extended to Ukrainians arriving in Canada 10 | event predicate: CONSULT(host a visit) 11 | event time: 2022-03-29 12 | subject name: CANADA 13 | object name: MANITOBA 14 | 15 | cause event 3 16 | event headline: Immigration, booming population and global influence: Is this the 'heroic' Canadian dream? 
17 | event predicate: EXPRESS INTENT TO COOPERATE 18 | event time: 2022-01-01 19 | subject name: CANADA 20 | object name: IMMIGRANT 21 | -------------------------------------------------------------------------------- /scripts/gdelt/gdelt_prompts_simplified/prompt_assistant_16.txt: -------------------------------------------------------------------------------- 1 | cause event 1 2 | event headline: Visa, Mastercard, American Express say they're suspending operations in Russia amid its invasion of Ukraine 3 | event predicate: EXPRESS INTENT TO COOPERATE 4 | event time: 2022-03-06 5 | subject name: COMPANIES 6 | object name: RUSSIAN 7 | 8 | cause event 2 9 | event headline: Microsoft 'suspends' new sales of products, services in Russia 10 | event predicate: DISAPPROVE 11 | event time: 2022-03-05 12 | subject name: MICROSOFT 13 | object name: RUSSIAN 14 | 15 | cause event 3 16 | event headline: Google suspends online advertising sales in Russia 17 | event predicate: Make statement 18 | event time: 2022-03-04 19 | subject name: GOOGLE 20 | object name: RUSSIAN 21 | 22 | cause event 4 23 | event headline: Biden imposes sanctions against Russian oligarchs, banks for beginning of Ukraine invasion 24 | event predicate: ENGAGE IN DIPLOMATIC COOPERATION 25 | event time: 2022-02-22 26 | subject name: JOE BIDEN 27 | object name: RUSSIA 28 | -------------------------------------------------------------------------------- /scripts/gdelt/gdelt_prompts_simplified/prompt_assistant_2.txt: -------------------------------------------------------------------------------- 1 | cause event 1 2 | event headline: Covid-19's death toll is nearly three times higher than reported, WHO data suggests 3 | event predicate: Engage in material cooperation 4 | event time: 2022-05-05 5 | subject name: UNITED STATES 6 | object name: DIRECTOR GENERAL 7 | 8 | cause event 2 9 | event headline: Shanghai considers more Covid-19 curbs as lockdown enters 4th week 10 | event predicate: FIGHT (Impose blockade, restrict movement) 11 | event time: 2022-04-25 12 | subject name: Shanghai 13 | object name: China 14 | 15 | cause event 3 16 | event headline: Shanghai to tough prevention measures as local COVID infections rise 17 | event predicate: Decline comment 18 | event time: 2022-04-23 19 | subject name: Shanghai 20 | object name: CARRIER 21 | 22 | cause event 4 23 | event headline: Shanghai intensifies lockdown 24 | event predicate: Express intent to cooperate 25 | event time: 2022-03-31 26 | subject name: Shanghai 27 | object name: China 28 | 29 | cause event 5 30 | event headline: Sticking to dynamic zero-COVID approach, China strives for balance of virus containment, economic development--China Economic Net 31 | event predicate: Appeal 32 | event time: 2022-03-26 33 | subject name: Shanghai 34 | object name: Residents 35 | -------------------------------------------------------------------------------- /scripts/gdelt/gdelt_prompts_simplified/prompt_assistant_3.txt: -------------------------------------------------------------------------------- 1 | cause event 1 2 | event headline: Joe Biden Welcomes Visa & Mastercard's Decision To Suspend Operations In Russia 3 | event predicate: ENGAGE IN DIPLOMATIC COOPERATION 4 | event time: 2022-03-06 5 | subject name: RUSSIA 6 | object name: JOE BIDEN 7 | 8 | cause event 2 9 | event headline: Putin says Ukraine's future is in doubt as cease-fires collapse 10 | event predicate: Threaten 11 | event time: 2022-03-06 12 | subject name: RUSSIA 13 | object name: UKRAINE 14 | 15 | cause event 3 16 | 
event headline: Zelensky Urges US Lawmakers to Ban Russian Oil, Provide Fighter Jets 17 | event predicate: Make statement 18 | event time: 2022-03-05 19 | subject name: COLONEL 20 | object name: RUSSIA 21 | 22 | cause event 4 23 | event headline: Marshall, Moran seek ban on Russian oil imports, urge expansion of domestic production 24 | event predicate: Use unconventional violence 25 | event time: 2022-03-05 26 | subject name: SENATOR 27 | object name: RUSSIA 28 | -------------------------------------------------------------------------------- /scripts/gdelt/gdelt_prompts_simplified/prompt_assistant_4.txt: -------------------------------------------------------------------------------- 1 | cause event 1 2 | event headline: Russia orders troops into rebel-held regions in Ukraine 3 | event predicate: Make statement 4 | event time: 2022-02-22 5 | subject name: RUSSIA 6 | object name: REBEL 7 | 8 | cause event 2 9 | event headline: Russia-Ukraine: Biden briefed by security advisers as Russian commanders receive invasion orders 10 | event predicate: EXPRESS INTENT TO COOPERATE 11 | event time: 2022-02-20 12 | subject name: PRESIDENT 13 | object name: RUSSIA 14 | 15 | cause event 3 16 | event headline: Biden says he's now convinced Putin has decided to invade Ukraine, but leaves door open for diplomacy 17 | event predicate: Consult 18 | event time: 2022-02-18 19 | subject name: PRESIDENT 20 | object name: UNITED STATES 21 | 22 | cause event 4 23 | event headline: russia may stage violent event to justify ukraine invasion,says us 24 | event predicate: Fight 25 | event time: 2022-02-18 26 | subject name: RUSSIA 27 | object name: CIVILIAN 28 | 29 | cause event 5 30 | event headline: Full coverage: Russian troops advance toward Ukraine capital of Kyiv 31 | event predicate: Reject 32 | event time: 2022-02-17 33 | subject name: LAWMAKER 34 | object name: PRESIDENT 35 | -------------------------------------------------------------------------------- /scripts/gdelt/gdelt_prompts_simplified/prompt_assistant_5.txt: -------------------------------------------------------------------------------- 1 | cause event 1 2 | event headline: china adopts russias denazification myth to rationalize invasion of ukraine 3 | event predicate: Appeal 4 | event time: 2022-03-06 5 | subject name: CHINESE 6 | object name: TAIWAN 7 | 8 | cause event 2 9 | event headline: Biden seeks to inflict economic pain on Putin; avoids bigger war 10 | event predicate: Disapprove 11 | event time: 2022-03-06 12 | subject name: UKRAINE 13 | object name: RUSSIA 14 | 15 | cause event 3 16 | event headline: Visa, Mastercard Grant Zelensky's Request, Ban All Operations in Russia 17 | event predicate: Reduce relations 18 | event time: 2022-03-05 19 | subject name: MASTERCARD 20 | object name: RUSSIA 21 | 22 | cause event 4 23 | event headline: Visa, Mastercard Suspend All Operations in Russia Effective Immediately Over Ukraine Crisis 24 | event predicate: Appeal 25 | event time: 2022-03-05 26 | subject name: MASTERCARD 27 | object name: REGULATOR 28 | 29 | cause event 5 30 | event headline: Biden imposes sanctions against Russian oligarchs, banks for beginning of Ukraine invasion 31 | event predicate: ENGAGE IN DIPLOMATIC COOPERATION 32 | event time: 2022-02-22 33 | subject name: JOE BIDEN 34 | object name: RUSSIA 35 | -------------------------------------------------------------------------------- /scripts/gdelt/gdelt_prompts_simplified/prompt_assistant_6.txt: -------------------------------------------------------------------------------- 1 | 
cause event 1 2 | event headline: Int'l flights to Shanghai rescheduled amid COVID-19 resurgence 3 | event predicate: Appeal 4 | event time: 2022-03-16 5 | subject name: SHANGHAI 6 | object name: CHINA 7 | 8 | cause event 2 9 | event headline: Saudi Arabia lifts most Covid restrictions including quarantine for vaccinated arrivals 10 | cause predicate: EXPRESS INTENT TO COOPERATE 11 | event time: 2022-03-06 12 | subject name: TRAVELLER 13 | object name: KINGDOM 14 | 15 | cause event 3 16 | event headline: Tesla halts work at Shanghai fentityy for two days amid China COVID curbs 17 | event predicate: ENGAGE IN MATERIAL COOPERATION 18 | event time: 2022-03-16 19 | subject name: SHANGHAI 20 | object name: GERMANY 21 | 22 | cause event 4 23 | event headline: China Mandates COVID-Zero Policy in Shenzhen; Locals Struggle to Source Foodng of Ukraine invasion 24 | event predicate: Make statement 25 | event time: 2022-03-05 26 | subject name: CHINA 27 | object name: CHINA 28 | 29 | cause event 5 30 | event headline: China welcomes 2022 with the worst week of COVID-19 dominating the virus 31 | event predicate: Reject 32 | event time: 2022-01-01 33 | subject name: XIAN 34 | object name: AUTHORITIES 35 | -------------------------------------------------------------------------------- /scripts/gdelt/gdelt_prompts_simplified/prompt_assistant_7.txt: -------------------------------------------------------------------------------- 1 | cause event 1 2 | event headline: COVID-hit Beijing returns to work after subdued Labour Day break 3 | event predicate: YIELD 4 | event time: 2022-05-05 5 | subject name: CHINESE 6 | object name: WORKER 7 | 8 | cause event 2 9 | event headline: russias war against ukraine might persistently shift global supply chains 10 | event predicate: Investigate 11 | event time: 2022-03-31 12 | subject name: RUSSIA 13 | object name: CRIMEA 14 | 15 | cause event 3 16 | event headline: Covid-led lockdowns, closure of markets, and unseasonal rains have increased farmers_ dependence on moneylenders 17 | event predicate: PROVIDE AID 18 | event time: 2022-03-22 19 | subject name: MAHARASHTRA 20 | object name: FARMER 21 | 22 | cause event 4 23 | event headline: Air, travel industry expect fast recovery, but face burden over soaring oil prices, Ukraine war 24 | event predicate: Engage in material cooperation 25 | event time: 2022-03-27 26 | subject name: AIRLINE 27 | object name: NAGOYA 28 | 29 | cause event 5 30 | event headline: Tesla halts work at Shanghai fentityy for two days amid China COVID curbs 31 | event predicate: ENGAGE IN MATERIAL COOPERATION 32 | event time: 2022-03-16 33 | subject name: SHANGHAI 34 | object name: GERMANY 35 | -------------------------------------------------------------------------------- /scripts/gdelt/gdelt_prompts_simplified/prompt_assistant_8.txt: -------------------------------------------------------------------------------- 1 | cause event 1 2 | event headline: Australia imposes further sanctions on Russia 3 | event predicate: Make statement 4 | event time: 2022-04-01 5 | subject name: AUSTRALIA 6 | object name: RUSSIA 7 | 8 | cause event 2 9 | event headline: Biden seeks to inflict economic pain on Putin; avoids bigger war 10 | event predicate: Disapprove 11 | event time: 2022-03-05 12 | subject name: UKRAINE 13 | object name: RUSSIA 14 | 15 | cause event 3 16 | event headline: Switzerland backs full EU sanctions against Russia 17 | event predicate: CONSULT 18 | event time: 2022-03-06 19 | subject name: SWISS 20 | object name: GOVERNMENT 21 | 22 | cause 
event 4 23 | event headline: Canada ratchets up sanctions against Russia, cancels export permits 24 | event predicate: ENGAGE IN MATERIAL COOPERATION 25 | event time: 2022-02-25 26 | subject name: CANADA 27 | object name: RUSSIA 28 | 29 | cause event 5 30 | event headline: Biden imposes sanctions against Russian oligarchs, banks for beginning of Ukraine invasion 31 | event predicate: ENGAGE IN DIPLOMATIC COOPERATION 32 | event time: 2022-02-22 33 | subject name: JOE BIDEN 34 | object name: RUSSIA 35 | -------------------------------------------------------------------------------- /scripts/gdelt/gdelt_prompts_simplified/prompt_assistant_9.txt: -------------------------------------------------------------------------------- 1 | cause event 1 2 | event headline: White House: Europe Determined To Ban Russian Energy Exports 3 | event predicate: REDUCE RELATIONS 4 | event time: 2022-04-22 5 | subject name: EUROPE 6 | object name: RUSSIAN 7 | 8 | cause event 2 9 | event headline: oil prices rise as europe weighs additional russia sanctions 10 | event predicate: Engage in material cooperation 11 | event time: 2022-04-04 12 | subject name: INVESTOR 13 | object name: YEMEN 14 | 15 | cause event 3 16 | event headline: EU Considers Energy Sanctions Against Russia Over War Crimes In Ukraine 17 | event predicate: DISAPPROVE 18 | event time: 2022-04-04 19 | subject name: THE EUROPEAN UNION 20 | object name: CIVILIAN 21 | -------------------------------------------------------------------------------- /scripts/gdelt/gdelt_prompts_simplified/prompt_user_1.txt: -------------------------------------------------------------------------------- 1 | effect event 2 | event predicate: Appeal 3 | event time: 2022-04-19 4 | subject name: UNITED STATES 5 | object name: Russia 6 | -------------------------------------------------------------------------------- /scripts/gdelt/gdelt_prompts_simplified/prompt_user_10.txt: -------------------------------------------------------------------------------- 1 | effect event 2 | event predicate: Consult 3 | event time: 2022-04-02 4 | subject name: WASHINGTON 5 | object name: CHINESE 6 | -------------------------------------------------------------------------------- /scripts/gdelt/gdelt_prompts_simplified/prompt_user_11.txt: -------------------------------------------------------------------------------- 1 | effect event 2 | event predicate: Consult 3 | event time: 2022-03-31 4 | subject name: FRANCE 5 | object name: TRAVELLER 6 | -------------------------------------------------------------------------------- /scripts/gdelt/gdelt_prompts_simplified/prompt_user_12.txt: -------------------------------------------------------------------------------- 1 | effect event 2 | event predicate: Appeal 3 | event time: 2022-05-05 4 | subject name: GERMANY 5 | object name: EUROPE 6 | -------------------------------------------------------------------------------- /scripts/gdelt/gdelt_prompts_simplified/prompt_user_13.txt: -------------------------------------------------------------------------------- 1 | effect event 2 | event predicate: Make statement 3 | event time: 2022-05-05 4 | subject name: GERMAN 5 | object name: CHINA 6 | -------------------------------------------------------------------------------- /scripts/gdelt/gdelt_prompts_simplified/prompt_user_14.txt: -------------------------------------------------------------------------------- 1 | effect event 2 | event predicate: Consult 3 | event time: 2022-05-05 4 | subject name: SOUTH KOREA 5 | object name: UNITED STATES 6 | 
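Each of these GDELT prompt_user_*.txt files is a four-field effect-event block (note that some of the files label the first field "event type" while others use "event predicate"). In the pipeline these blocks are presumably produced by a helper in utils/gpt_utils.py, much as the Amazon variant uses generate_gdelt_prompt_amazon in step_4_query_chatgpt.py shown earlier; the function below is only a hypothetical stand-in that reproduces the block layout.

def format_effect_event(predicate, event_time, subject_name, object_name):
    # Render one effect event in the block layout used by prompt_user_*.txt.
    return (
        'effect event\n'
        f'event predicate: {predicate}\n'
        f'event time: {event_time}\n'
        f'subject name: {subject_name}\n'
        f'object name: {object_name}\n'
    )

# Example: reproduces the shape of prompt_user_14.txt above.
print(format_effect_event('Consult', '2022-05-05', 'SOUTH KOREA', 'UNITED STATES'))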
-------------------------------------------------------------------------------- /scripts/gdelt/gdelt_prompts_simplified/prompt_user_15.txt: -------------------------------------------------------------------------------- 1 | effect event 2 | event predicate: Make statement 3 | event time: 2022-05-07 4 | subject name: CANADA 5 | object name: ATTORNEY 6 | -------------------------------------------------------------------------------- /scripts/gdelt/gdelt_prompts_simplified/prompt_user_16.txt: -------------------------------------------------------------------------------- 1 | effect event 2 | event predicate: CONSULT 3 | event time: 2022-04-22 4 | subject name: RUSSIA 5 | object name: BANK OF AMERICA 6 | -------------------------------------------------------------------------------- /scripts/gdelt/gdelt_prompts_simplified/prompt_user_2.txt: -------------------------------------------------------------------------------- 1 | effect event 2 | event predicate: Investigate 3 | event time: 2022-05-07 4 | subject name: Shanghai 5 | object name: Mayor 6 | -------------------------------------------------------------------------------- /scripts/gdelt/gdelt_prompts_simplified/prompt_user_3.txt: -------------------------------------------------------------------------------- 1 | effect event 2 | event predicate: Use unconventional violence 3 | event time: 2022-03-08 4 | subject name: JOE BIDEN 5 | object name: RUSSIA 6 | -------------------------------------------------------------------------------- /scripts/gdelt/gdelt_prompts_simplified/prompt_user_4.txt: -------------------------------------------------------------------------------- 1 | effect event 2 | event type: ENGAGE IN DIPLOMATIC COOPERATION 3 | event time: 2022-02-22 4 | subject name: JOE BIDEN 5 | object name: RUSSIA -------------------------------------------------------------------------------- /scripts/gdelt/gdelt_prompts_simplified/prompt_user_5.txt: -------------------------------------------------------------------------------- 1 | effect event 2 | event type: CONSULT 3 | event time: 2022-03-06 4 | subject name: CHINESE 5 | object name: RUSSIAN -------------------------------------------------------------------------------- /scripts/gdelt/gdelt_prompts_simplified/prompt_user_6.txt: -------------------------------------------------------------------------------- 1 | effect event 2 | event type: Provide aid 3 | event time: 2022-03-29 4 | subject name: SHANGHAI 5 | object name: RESIDENTS -------------------------------------------------------------------------------- /scripts/gdelt/gdelt_prompts_simplified/prompt_user_7.txt: -------------------------------------------------------------------------------- 1 | effect event 2 | event type: Express intent to cooperate 3 | event time: 2022-06-07 4 | subject name: BANK OF CHINA 5 | object name: GOVERNOR -------------------------------------------------------------------------------- /scripts/gdelt/gdelt_prompts_simplified/prompt_user_8.txt: -------------------------------------------------------------------------------- 1 | effect event 2 | event type: Engage in material cooperation 3 | event time: 2022-04-19 4 | subject name: CENTRAL BANK 5 | object name: RUSSIAN -------------------------------------------------------------------------------- /scripts/gdelt/gdelt_prompts_simplified/prompt_user_9.txt: -------------------------------------------------------------------------------- 1 | effect event 2 | event type: Consult 3 | event time: 2022-05-05 4 | subject name: CARTEL 5 | object name: 
OPEC
--------------------------------------------------------------------------------
/scripts/gdelt/gdelt_prompts_simplified/system_msg_1.txt:
--------------------------------------------------------------------------------
1 | We will work with event data from the Global Database of Events, Language, and Tone (GDELT). You are a helpful assistant that performs abductive reasoning over events extracted from the GDELT database. I give you an effect event, then you give me three to five cause events that cause the effect event to happen. Each event consists of an event headline, an event type, an event time, a subject name and an object name describing the event.
--------------------------------------------------------------------------------
/scripts/gdelt/step_1_news_scraper.py:
--------------------------------------------------------------------------------
1 | 
2 | import os
3 | import pandas as pd
4 | 
5 | import feapder
6 | from newspaper import Article
7 | 
8 | 
9 | def parse_html_to_title_and_article(html_doc):
10 |     article = Article('')  # a string is required as the `url` argument, but it is not used here
11 |     article.set_html(html_doc)
12 |     article.parse()
13 |     title = article.title
14 |     content = article.text
15 |     return title, content
16 | 
17 | 
18 | class BaseSpider:
19 |     """
20 |     Download the news text from the given URLs in the GDELT data.
21 |     """
22 |     def __init__(
23 |             self,
24 |             db_storage: str,
25 |             gdelt_data_fn: str,
26 |             is_continue: bool = True,
27 |             is_inverse: bool = False,
28 |             thread_count: int = None,
29 |     ):
30 |         self.db_storage = db_storage
31 |         self.gdelt_data_fn = gdelt_data_fn
32 |         self.is_continue = is_continue
33 |         self.is_inverse = is_inverse
34 |         self._gdelt_data = None  # load data in 'start_requests'
35 | 
36 |     @property
37 |     def gdelt_data(self):
38 |         if self._gdelt_data is None:
39 |             self._gdelt_data = pd.read_csv(self.gdelt_data_fn)
40 |         return self._gdelt_data
41 | 
42 |     def _is_a_error_request(self, title, article):
43 |         skip_titles = [
44 |             'Just a moment',
45 |             'Page Unavailable',
46 |             'Page Not Found',
47 |             'Access denied',
48 |             'Attention Required!',
49 |         ]
50 | 
51 |         skip_articles = [  # the first two are Chinese error-page phrases: "please check your internet connection" / "this site cannot be reached"
52 |             '请检查您的互联网连接是否正常',
53 |             '无法访问此网站',
54 |             'The item that you have requested was not found',
55 |         ]
56 | 
57 |         if title is None or title.strip() == '':
58 |             return True
59 |         if article is None or article.strip() == '':
60 |             return True
61 | 
62 |         for key_title in skip_titles:
63 |             if key_title in title:
64 |                 return True
65 | 
66 |         for key_article in skip_articles:
67 |             if key_article in article:
68 |                 return True
69 |         return False
70 | 
71 |     def _save_news(self, gdelt_series, title, article):
72 |         path = self._construct_news_path(gdelt_series)
73 |         os.makedirs(path, exist_ok=True)
74 |         with open(os.path.join(path, f'article_{title}.txt'), 'w') as file:
75 |             file.write(article)
76 |         return True
77 | 
78 |     def _construct_news_path(self, gdelt_series):
79 |         return os.path.join(self.db_storage, str(gdelt_series['DATEADDED']), str(gdelt_series['GLOBALEVENTID']))
80 | 
81 |     def get_scraped_news_count(self):
82 |         if not os.path.exists(self.db_storage):
83 |             return 0
84 |         count = 0
85 |         for ts_id in os.listdir(self.db_storage):
86 |             ts_id_path = os.path.join(self.db_storage, ts_id)
87 |             if os.path.isdir(ts_id_path):
88 |                 for event_id in os.listdir(ts_id_path):
89 |                     event_id_path = os.path.join(ts_id_path, event_id)
90 | 
91 |                     if os.path.isdir(event_id_path) and len(os.listdir(event_id_path)) > 0:
92 |                         count += 1
93 | 
94 |         return count
95 | 
96 |     def get_total_news(self):
97 |         return len(self.gdelt_data)
98 | 
99 | 
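# How the two classes fit together: BaseSpider (above) handles the bookkeeping, i.e.
# lazily loading the GDELT CSV, filtering out obviously failed fetches (cookie walls,
# missing pages, blank bodies) and saving each article under
# <db_storage>/<DATEADDED>/<GLOBALEVENTID>/article_<title>.txt, while NewsRender
# (below) plugs that logic into feapder's AirSpider so every SOURCEURL is fetched
# through a rendered Chrome webdriver and parsed with newspaper's Article.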
100 | class NewsRender(feapder.AirSpider, BaseSpider):
101 |     __custom_setting__ = dict(
102 |         # SPIDER_THREAD_COUNT=32,
103 |         SPIDER_MAX_RETRY_TIMES=1,
104 |         REQUEST_TIMEOUT=300,
105 | 
106 |         # LOG_IS_WRITE_TO_FILE=True,
107 |         LOG_LEVEL='ERROR',
108 |         PRINT_EXCEPTION_DETAILS=False,
109 | 
110 |         WEBDRIVER=dict(
111 |             pool_size=16,  # number of browser instances
112 |             load_images=False,  # whether to load images
113 |             # user_agent=None,  # a string, or a no-argument function returning the user agent
114 |             user_agent="Mozilla/5.0 (iPad; U; CPU OS 3_2_1 like Mac OS X; en-us) AppleWebKit/531.21.10 (KHTML, like Gecko) Mobile/7B405",
115 |             proxy=None,  # 'xxx.xxx.xxx.xxx:xxxx', or a no-argument function returning the proxy address
116 |             headless=False,  # whether to run the browser headless
117 |             driver_type="CHROME",  # CHROME, PHANTOMJS or FIREFOX
118 |             timeout=120,  # request timeout in seconds
119 |             window_size=(1024, 800),  # browser window size
120 |             executable_path=None,  # path to the browser executable; None uses the default
121 |             # render_time=5,  # render time: wait this many seconds after opening the page before grabbing the source
122 |             custom_argument=[
123 |                 "--ignore-certificate-errors",
124 |                 "--disable-blink-features=AutomationControlled",
125 |             ],  # custom browser arguments
126 |             xhr_url_regexes=[
127 |                 "/ad",
128 |             ],  # intercept requests matching http://www.spidertools.cn/spidertools/ad
129 |             auto_install_driver=True,
130 |         )
131 |     )
132 | 
133 |     def __init__(
134 |             self,
135 |             db_storage: str,
136 |             gdelt_data_fn: str,
137 |             is_continue: bool = True,
138 |             is_inverse: bool = False,
139 |             thread_count: int = None,
140 |     ):
141 |         feapder.AirSpider.__init__(self, thread_count)
142 |         BaseSpider.__init__(
143 |             self,
144 |             db_storage=db_storage,
145 |             gdelt_data_fn=gdelt_data_fn,
146 |             is_continue=is_continue,
147 |             is_inverse=is_inverse,
148 |             thread_count=thread_count
149 |         )
150 | 
151 |     def start_requests(self):
152 |         if self.is_inverse:
153 |             it = range(len(self.gdelt_data) - 1, -1, -1)
154 |         else:
155 |             it = range(410000, len(self.gdelt_data))
156 |         for idx in it:
157 |             gdelt_series = self.gdelt_data.iloc[idx]
158 |             path = self._construct_news_path(gdelt_series)
159 |             if self.is_continue:
160 |                 # check if already fetched
161 |                 if os.path.exists(path) and len(os.listdir(path)) > 0:
162 |                     continue
163 | 
164 |             url = gdelt_series['SOURCEURL']
165 |             yield feapder.Request(
166 |                 url,
167 |                 random_user_agent=False,
168 |                 stream=False,
169 |                 render_time=5,
170 |                 render=True,
171 |                 verify=False,
172 |                 timeout=(30, 120),
173 |                 allow_redirects=True,
174 |                 extra_info={
175 |                     'DATEADDED': gdelt_series['DATEADDED'],
176 |                     'GLOBALEVENTID': gdelt_series['GLOBALEVENTID'],
177 |                 }
178 |             )
179 | 
180 |     def parse(self, request, response):
181 |         # browser: WebDriver = response.browser
182 |         url = request.url
183 |         title, article = parse_html_to_title_and_article(response.text)
184 | 
185 |         # check status
186 |         if self._is_a_error_request(title, article):
187 |             raise RuntimeError(f'Scrape failed: {url}')
188 | 
189 |         # ---------------------------------------------
190 |         gdelt_series = request.extra_info
191 |         # save news
192 |         self._save_news(gdelt_series, title, article)
193 | 
194 | 
195 | if __name__ == "__main__":
196 |     gdelt_data_fn = '../../data/gdelt/events_raw_clean.csv'
197 |     render = NewsRender(
198 |         db_storage='../../data/gdelt/news',
199 |         gdelt_data_fn=gdelt_data_fn,
200 |         thread_count=24,
201 |         # is_inverse=True
202 |     )
203 | 
204 |     # statistics
205 |     print('Scraped news count:', render.get_scraped_news_count())
206 |     print('Total news count:', render.get_total_news())
207 | 
208 |     # start spider
209 |     render.start()
210 | 
--------------------------------------------------------------------------------
/scripts/gdelt/step_3_5_analyze_gdelt_entity.py:
--------------------------------------------------------------------------------
1 | import collections
2 | import pickle
3 | 
4 | import numpy as np
5 | import pandas as pd
6 | 
7 | 
8 | def get_event_given_entity_id(source_data, entity_id, id_pos=1):
9 |     events = filter(lambda x: x[id_pos] == entity_id, source_data)
10 |     return events
11 | 
12 | 
13 | def main():
14 |     # load all events
15 |     data_dir = '../../data/gdelt/gdelt.pkl'
16 |     with open(data_dir, 'rb') as f:
17 |         data = pickle.load(f)
18 | 
19 |     data = data['data']
20 | 
21 |     entity_map_csv_dir = '../../scripts/entity_map.csv'
22 |     entity_map = pd.read_csv(entity_map_csv_dir, header=0)
23 |     entity_map.rename(columns={'index': 'obj_id'}, inplace=True)
24 | 
25 |     # overall statistics for the objects
26 |     total_obj = [x[1] for x in data]
27 |     counter = collections.Counter(total_obj)
28 |     top_20_obj = dict(counter.most_common(20))
29 | 
30 |     top_20_obj_list = list(top_20_obj.keys())
31 |     obj_pct = np.array(list(top_20_obj.values())) / len(total_obj)
32 | 
33 |     obj_stat = pd.DataFrame(columns=['obj_id', 'obj_pct'])
34 |     obj_stat['obj_id'] = top_20_obj_list
35 |     obj_stat['obj_pct'] = obj_pct
36 |     obj_stat = pd.merge(obj_stat, entity_map, on='obj_id', how='left')
37 |     obj_stat = obj_stat[['obj_id', 'obj_pct', 'actor_code']]
38 | 
39 |     # a detailed look at each top object: its most frequent subjects, and the
40 |     # objects those subjects most often act on (one step away)
41 |     obj_spec = {}
42 |     for obj_id in top_20_obj_list:
43 |         # top 20 subjects given the object
44 |         all_events_given_obj_id = get_event_given_entity_id(data, obj_id)
45 |         all_sub_given_obj_id = [x[0] for x in all_events_given_obj_id]
46 | 
47 |         top_20_sub_given_obj_id = collections.Counter(all_sub_given_obj_id).most_common(20)
48 |         top_20_sub_given_obj_id = [sub_id for sub_id, _ in top_20_sub_given_obj_id]
49 | 
50 |         # top objects given the object's top subjects
51 |         top_20_obj_one_step = []
52 |         for sub_id in top_20_sub_given_obj_id:
53 |             all_events_given_sub_id = get_event_given_entity_id(data, sub_id, id_pos=0)
54 |             all_obj_given_sub_id = [x[1] for x in all_events_given_sub_id]
55 |             top_20_obj_given_sub_id = collections.Counter(all_obj_given_sub_id).most_common(5)
56 |             top_20_obj_one_step.extend([obj for obj, _ in top_20_obj_given_sub_id])
57 | 
58 |         # assumed layout for the per-object breakdown saved below
59 |         obj_spec[obj_id] = {'top_sub': top_20_sub_given_obj_id,
60 |                             'top_obj_one_step': top_20_obj_one_step}
61 | 
62 |     res = {'obj_stat': obj_stat,
63 |            'obj_spec': obj_spec}
64 | 
65 |     with open('gdelt_obj_stat_v0327.pkl', "wb") as f_out:
66 |         pickle.dump(res, f_out)
67 | 
68 |     return
69 | 
70 | 
71 | if __name__ == '__main__':
72 |     main()
--------------------------------------------------------------------------------
/scripts/gdelt/step_4_query_chatgpt.py:
--------------------------------------------------------------------------------
1 | import os
2 | import openai
3 | import elara
4 | from concurrent.futures import ThreadPoolExecutor
5 | 
6 | from utils.general import file_uri_reader_processor
7 | from utils.gpt_utils import generate_gdelt_prompt_v2, generate_gdelt_comb_prompt
8 | import dictdatabase as DDB
9 | 
10 | 
11 | def read_text_file(fn):
12 |     with open(fn, 'r') as file:
13 |         return file.read()
14 | 
15 | 
16 | class EventQuery:
17 |     def __init__(self, api_key, prompt_folder: str, num_prompts: int = 12):
18 |         openai.api_key = api_key
19 |         self.setup_msgs = []
20 | 
21 |         # process prompts
22 |         system_msgs = []
23 |         prompt_assistant_msgs = []
24 |         prompt_user_msgs = []
25 |         help_msgs = []
26 |         if not os.path.exists(prompt_folder):
27 |             raise RuntimeError(f'Prompt folder does not exist: {prompt_folder}')
28 |         for fn in sorted(os.listdir(prompt_folder)):
29 |             file_path = os.path.join(prompt_folder, fn)
30 |             if not os.path.isfile(file_path):
31 |                 continue
32 |             msg = read_text_file(file_path)
33 |             if fn.startswith('system_msg'):
34 |                 system_msgs.append(msg)
35 |             elif fn.startswith('prompt_assistant'):
36 |                 prompt_assistant_msgs.append(msg)
37 |             elif fn.startswith('prompt_user'):
38 |                 prompt_user_msgs.append(msg)
39 |             elif fn.startswith('help_msg'):
40 |                 help_msgs.append(msg)
41 | 
42 |         for msg in system_msgs:
43 |             self.setup_msgs.append({
44 |                 'role': 'system',
45 |                 'content': msg
46 |             })
47 |         for msg in help_msgs:
48 |             self.setup_msgs.append({
49 |                 'role': 'user',
50 |                 'content': msg
51 |             })
52 |         for user_msg, assistant_msg in zip(prompt_user_msgs[:num_prompts], prompt_assistant_msgs[:num_prompts]):
53 |             self.setup_msgs.append({
54 |                 'role': 'user',
55 |                 'content': user_msg
56 |             })
57 |             self.setup_msgs.append({
58 |                 'role': 'assistant',
59 |                 'content': assistant_msg
60 |             })
61 | 
62 |     def query(self, msg):
63 |         msg_list = self.setup_msgs + [{
64 |             'role': 'user',
65 |             'content': msg
66 |         }]
67 |         completions = openai.ChatCompletion.create(
68 |             model="gpt-3.5-turbo",
69 |             messages=msg_list,
70 |             stream=False
71 |         )
72 |         print('Usage:', completions['usage']['total_tokens'])
73 |         body = completions['choices'][0]['message']['content']
74 |         return body
75 |         # for i, choice in enumerate(completions['choices']):
76 |         #     print('---------------- choice', i)
77 |         #     print(choice['message']['content'])
78 | 
79 | 
80 | def gpt_generate_causal_events(
81 |         db_base_name,
82 |         gpt,
83 |         pred_data,
84 |         source_data,
85 |         inference_type: str = 'relation',
86 |         top_k: int = 5,
87 |         num_threads: int = 16
88 | ):
89 |     msg_head = '''Now I give you an effect event, and you give me three to five cause events.\n\n'''
90 | 
91 |     def _process_one_relation_or_object(idx, rel_or_obj, text):
92 |         try:
93 |             existed_dict = DDB.at(db_base_name, inference_type, key=str(idx)).read()
94 |             if existed_dict is not None and str(rel_or_obj) in existed_dict:
95 |                 return
96 |             ret_body = gpt.query(msg_head + text)
97 |             with DDB.at(db_base_name, inference_type).session() as (sess, obj):
98 |                 if str(idx) in obj:
99 |                     obj[str(idx)][str(rel_or_obj)] = ret_body
100 |                 else:
101 |                     obj[str(idx)] = {
102 |                         str(rel_or_obj): ret_body
103 |                     }
104 |                 sess.write()
105 |             print(idx, rel_or_obj)
106 |         except Exception as e:
107 |             print('Error:', str(e))
108 | 
109 |     executor = ThreadPoolExecutor(max_workers=num_threads)
110 |     init_db = DDB.at(db_base_name, inference_type).read()
111 |     if init_db is None:
112 |         DDB.at(db_base_name, inference_type).create()
113 |         init_db = {}
114 |     for one_prediction in pred_data:
115 |         idx = one_prediction['original_idx']
116 |         if inference_type == 'relation':
117 |             msg_content_dict = generate_gdelt_prompt_v2(one_prediction, source_data, top_k=top_k, pred_relation=True)
118 |         else:
119 |             msg_content_dict = generate_gdelt_prompt_v2(one_prediction, source_data, top_k=top_k, pred_relation=False)
120 |         for rel_or_obj, content in msg_content_dict.items():
121 |             # skip if this relation/object for the idx already exists in the db
122 |             if str(idx) in init_db and str(rel_or_obj) in init_db[str(idx)]:
123 |                 continue
124 |             executor.submit(_process_one_relation_or_object, str(idx), rel_or_obj, content)
125 | 
126 |     executor.shutdown(wait=True)
127 | 
128 | 
129 | def gpt_generate_comb_causal_events(
130 |         db_base_name,
131 |         gpt,
132 |         pred_data,
133 |         source_data,
134 |         top_k: int = 100,
135 |         num_threads: int = 6
136 | ):
137 |     msg_head = '''Now I give you an effect event, and you give me three to five cause events.\n\n'''
138 | 
139 |     def _process_one_relation_or_object(idx, rel_obj, text):
140 |         try:
141 |             existed_dict = DDB.at(db_base_name, str(idx), key=str(rel_obj)).read()
142 |             if existed_dict is not None:
143 |                 return
144 |             ret_body = gpt.query(msg_head + text)
145 |             with DDB.at(db_base_name, str(idx)).session() as (sess, ddb):
146 |                 ddb[str(rel_obj)] = ret_body
147 |                 sess.write()
148 |             print(idx, rel_obj)
149 |         except Exception as e:
150 |             print('Error:', str(e))
151 | 
152 |     executor = ThreadPoolExecutor(max_workers=num_threads)
153 | 
154 |     for one_prediction in pred_data:
155 |         idx = one_prediction['original_idx']
156 | 
157 |         idx_db = DDB.at(db_base_name, str(idx)).read()
158 |         if idx_db is None:
159 |             DDB.at(db_base_name, str(idx)).create()
160 |             idx_db = {}
161 | 
162 |         msg_content_dict = generate_gdelt_comb_prompt(one_prediction, source_data, top_k=top_k, filter_hit=True)
163 | 
164 |         for rel_obj, content in msg_content_dict.items():
165 |             # skip if this relation/object pair for the idx already exists in the db
166 |             if str(rel_obj) in idx_db:
167 |                 continue
168 |             executor.submit(_process_one_relation_or_object, str(idx), rel_obj, content)
169 | 
170 |     executor.shutdown(wait=True)
171 | 
172 | 
173 | if __name__ == '__main__':
174 |     # predictions drawn from the base model
175 |     base_model_prediction_fn = '../../logs/ke_anhp_gdelt_test.pkl'
176 |     # original sequence data
177 |     raw_seq_data_fn = '../../data/gdelt/gdelt.pkl'
178 | 
179 |     gpt = EventQuery(
180 |         api_key='*',
181 |         prompt_folder='gdelt_prompts_simplified',
182 |         num_prompts=10
183 |     )
184 | 
185 |     pred_data = file_uri_reader_processor(base_model_prediction_fn)
186 |     source_data = file_uri_reader_processor(raw_seq_data_fn)['data']
187 | 
188 |     gpt_generate_causal_events(
189 |         'gdelt_chatgpt',
190 |         gpt,
191 |         pred_data,
192 |         source_data,
193 |         inference_type='relation',
194 |         top_k=5,
195 |         # inference_type='object',
196 |         # top_k=20,
197 |         num_threads=6
198 |     )
199 | 
--------------------------------------------------------------------------------
/scripts/gdelt/step_5_make_emb_dataset.py:
--------------------------------------------------------------------------------
1 | 
2 | from utils.general import file_uri_reader_processor, ProcessPool
3 | from utils.misc import make_samples_for_energy_function
4 | 
5 | if __name__ == '__main__':
6 | 
7 |     # source_data_fn = 'gdelt_emb.pkl'
8 |     source_data_fn = '../../data/gdelt/gdelt.pkl'
9 |     # pred_data_fn = '../../logs/ke_gdelt_test.pkl'
10 |     pred_data_fn = '../../logs/ke_anhp_gdelt_test.pkl'
11 | 
12 | 
13 |     source_data = file_uri_reader_processor(source_data_fn)['data']
14 |     pred_data = file_uri_reader_processor(pred_data_fn)
15 | 
16 |     make_samples_for_energy_function(
17 |         gpt_db_name='gdelt_chatgpt/relation',
18 |         source_data=source_data,
19 |         pred_data=pred_data,
20 |         pred_type='relation',
21 |         topk=5,
22 |         # pred_type='object',
23 |         # topk=20,
24 |         ebm_db_name='ke_anhp_gdelt_bert_ebm_dataset',
25 |         retro_top_n=2,
26 |         distance_type='bert'
27 |     )
--------------------------------------------------------------------------------
/scripts/gdelt/step_6_recompute_distance.py:
--------------------------------------------------------------------------------
1 | from utils.general import file_uri_reader_processor, file_uri_writer_processor
2 | 
3 | 
4 | def encode_event_text(source_data, method='bert'):
5 |     if method == 'bert':
6 |         from sentence_transformers import SentenceTransformer
7 |         model = SentenceTransformer('all-MiniLM-L6-v2')
8 |     else:
9 |         from transformers import AutoTokenizer
10 |         tokenizer = AutoTokenizer.from_pretrained("gpt2")
11 |         tokenizer.pad_token = tokenizer.eos_token
12 | 
13 | 
source_data_ = [] 14 | for event in source_data: 15 | if method == 'bert': 16 | event_encode = model.encode(event[-1]) 17 | else: 18 | event_encode = tokenizer.encode(event[-1], padding="max_length", max_length=60, truncation=True) 19 | print(len(event_encode)) 20 | event_tuple = (event[0], event[1], event[2], event[3], event[4], event_encode) 21 | 22 | source_data_.append(event_tuple) 23 | 24 | print('encoding finished') 25 | return source_data_ 26 | 27 | 28 | if __name__ == '__main__': 29 | source_data_fn = '../../data/gdelt/gdelt.pkl' 30 | 31 | source_data = file_uri_reader_processor(source_data_fn)['data'] 32 | 33 | source_data_ = encode_event_text(source_data, method='sparse') 34 | 35 | source_data_to_save = {'num_entity': 2279, 36 | 'num_rel': 20, 37 | 'data': source_data_} 38 | 39 | file_uri_writer_processor(source_data_to_save, 'gdelt_sparse.pkl') -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .log_utils import LogWriter 2 | from .misc import py_assert, create_folder, make_config_string, save_yaml_config, load_yaml_config, load_pickle 3 | 4 | __all__ = ['LogWriter', 5 | 'py_assert', 6 | 'create_folder', 7 | 'make_config_string', 8 | 'save_yaml_config', 9 | 'load_yaml_config', 10 | 'load_pickle'] 11 | -------------------------------------------------------------------------------- /utils/bert_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sentence_transformers import SentenceTransformer 3 | from sentence_transformers.util import cos_sim 4 | 5 | 6 | class SentenceEncoder: 7 | def __init__(self): 8 | self.model = SentenceTransformer('all-MiniLM-L6-v2') 9 | 10 | def encode_event(self, event_text): 11 | emb = self.model.encode(event_text) 12 | print(np.shape(emb)) 13 | return emb 14 | 15 | 16 | 17 | if __name__ == '__main__': 18 | encoder = SentenceEncoder() 19 | emb_1 = encoder.encode_event('subject name:CN; object:USA, relation: fight') 20 | emb_2 = encoder.encode_event('subject name:US Gov; object name: US, relation: make statement') 21 | print(cos_sim(emb_1, emb_2)) -------------------------------------------------------------------------------- /utils/general.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import os 4 | import pickle 5 | from typing import List 6 | 7 | import pandas as pd 8 | 9 | 10 | def setup_seed(seed=9899): 11 | import random 12 | import torch 13 | torch.manual_seed(seed) 14 | torch.cuda.manual_seed_all(seed) 15 | np.random.seed(seed) 16 | random.seed(seed) 17 | torch.backends.cudnn.benchmark = False 18 | torch.backends.cudnn.deterministic = True 19 | os.environ["PYTHONHASHSEED"] = str(seed) 20 | 21 | 22 | def file_uri_writer_processor(data, path: str, **kwargs): 23 | dirname = os.path.dirname(path) 24 | if dirname: 25 | os.makedirs(dirname, exist_ok=True) 26 | if path.endswith('pkl') or path.endswith('pickle'): 27 | pickle.dump(data, open(path, 'wb')) 28 | else: 29 | # treat file as csv file 30 | data.to_csv(path, index=False) 31 | return path 32 | 33 | 34 | def file_uri_reader_processor(uri, columns=None, **kwargs): 35 | if uri.endswith('csv'): 36 | data = pd.read_csv(uri, sep='\n', delimiter=',', usecols=columns) 37 | elif uri.endswith('pkl') or uri.endswith('pickle'): 38 | with open(uri, 'rb') as file: 39 | try: 40 | data = pickle.load(file, encoding='latin-1') 41 | except: 
42 | data = pickle.load(file) 43 | else: 44 | # treat file as csv file 45 | data = pd.read_csv(uri, sep='\n', delimiter=',', usecols=columns) 46 | return data 47 | 48 | 49 | def list_of_dict_to_dict(list_of_dicts): 50 | if not list_of_dicts: 51 | raise ValueError("The list of dicts is empty") 52 | 53 | dict_of_lists = {key: np.array([d[key] for d in list_of_dicts]) for key in list_of_dicts[0]} 54 | 55 | return dict_of_lists 56 | 57 | 58 | def get_value_by_key(key, list_dict): 59 | return [x[key] for x in list_dict] 60 | 61 | 62 | def dill_serialized_execute_func(f, *input_tuple): 63 | """Load the function that has been serialized by dill, and subsequently execute it. 64 | 65 | Args: 66 | f: function 67 | The function that has been serialized by dill. 68 | *input_tuple: 69 | 70 | Returns: 71 | The output of 'f'. 72 | """ 73 | # noinspection PyBroadException 74 | try: 75 | import dill 76 | f = dill.loads(f) 77 | return f(*input_tuple) 78 | except Exception: 79 | return None 80 | 81 | 82 | class ProcessPool(object): 83 | def __init__(self, num_processes: int = None, interval_sec: int = 0): 84 | """ A pool of processing. 85 | 86 | Args: 87 | num_processes: int 88 | Number of concurrently executing processes. 89 | interval_sec: int 90 | Interval seconds between tasks. (Invalid when use async_pool) 91 | """ 92 | self.num_processes = num_processes if num_processes is not None else os.cpu_count() 93 | self.interval_sec = interval_sec 94 | 95 | def map(self, *args, **kwargs): 96 | return self.run(*args, **kwargs) 97 | 98 | def run(self, target, dynamic_param_list: List[tuple], static_param: tuple = None, **kwargs): 99 | """ Run 'target' function. 100 | 101 | Args: 102 | target: function 103 | Function like func(*static_param, *dynamic_param) 104 | dynamic_param_list: List[tuple] 105 | static_param: tuple 106 | 107 | Returns: 108 | List of result in every target function. (drop the output which is None) 109 | """ 110 | static_param = static_param if static_param else tuple() 111 | if dynamic_param_list is None or len(dynamic_param_list) == 0: 112 | return None 113 | 114 | if self.num_processes > 0: 115 | res_list = self._run_async_pool( 116 | target, 117 | dynamic_param_list=dynamic_param_list, 118 | static_param=static_param, 119 | chunksize=kwargs.get('chunksize') 120 | ) 121 | else: 122 | res_list = self._run_sync(target, dynamic_param_list=dynamic_param_list, static_param=static_param) 123 | 124 | return res_list 125 | 126 | # --------------------------- Private Functions ---------------------------------- 127 | def _run_sync(self, target, dynamic_param_list: List[tuple], static_param: tuple): 128 | """ Run 'target' in main process for each task. 129 | 130 | Args: 131 | target: 132 | dynamic_param_list: 133 | static_param: 134 | 135 | Returns: 136 | 137 | """ 138 | result_list = list() 139 | for dp in dynamic_param_list: 140 | out = target(*static_param, *dp) 141 | if out is not None: 142 | result_list.append(out) 143 | return result_list 144 | 145 | def _run_async_pool(self, target, dynamic_param_list: List[tuple], static_param: tuple, chunksize: int): 146 | """ Run 'target' in multiprocessing using pool. 
147 | 148 | Args: 149 | target: 150 | dynamic_param_list: 151 | static_param: tuple 152 | chunksize: int 153 | If None, then set it to num_tasks // num_processes 154 | 155 | Returns: 156 | List of result 157 | """ 158 | import multiprocessing as mp 159 | import dill 160 | 161 | target = dill.dumps(target) 162 | param_list = [(target, *static_param, *dp) for dp in dynamic_param_list] 163 | if chunksize is None: 164 | chunksize = max(len(param_list) // self.num_processes, 1) 165 | with mp.Pool(self.num_processes) as pool: 166 | out_list = pool.starmap( 167 | func=dill_serialized_execute_func, 168 | iterable=param_list, 169 | chunksize=chunksize 170 | ) 171 | res_list = [] 172 | for res in out_list: 173 | if res is not None: 174 | res_list.append(res) 175 | return res_list 176 | 177 | def __enter__(self): 178 | return self 179 | 180 | def __exit__(self, exc_type, exc_val, exc_tb): 181 | pass 182 | -------------------------------------------------------------------------------- /utils/gpt_utils.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | from datetime import datetime 3 | from datetime import timedelta 4 | 5 | import numpy as np 6 | 7 | 8 | def get_entity_map(pkl_dir): 9 | with open(pkl_dir, 'rb') as f: 10 | entity_map = pickle.load(f) 11 | return entity_map 12 | 13 | 14 | try: 15 | AMAZON_TYPE_MAP = get_entity_map('../../data/gdelt/amazon_type_map.pkl') 16 | except: 17 | print('please use amazon type map pkl!') 18 | pass 19 | 20 | ENTITY_MAP = get_entity_map('../../data/gdelt/entity_map.pkl') 21 | 22 | EVENT_TYPE_MAP = {0: 'MAKE STATEMENT', 23 | 1: 'APPEAL', 24 | 2: 'EXPRESS INTENT TO COOPERATE', 25 | 3: 'CONSULT', 26 | 4: 'ENGAGE IN DIPLOMATIC COOPERATION', 27 | 5: 'ENGAGE IN MATERIAL COOPERATION', 28 | 6: 'PROVIDE AID', 29 | 7: 'YIELD', 30 | 8: 'INVESTIGATE', 31 | 9: 'DEMAND', 32 | 10: 'DISAPPROVE', 33 | 11: 'REJECT', 34 | 12: 'THREATEN', 35 | 13: 'PROTEST', 36 | 14: 'EXHIBIT MILITARY POSTURE', 37 | 15: 'REDUCE RELATIONS', 38 | 16: 'COERCE', 39 | 17: 'ASSAULT', 40 | 18: 'FIGHT', 41 | 19: 'ENGAGE IN UNCONVENTIONAL MASS VIOLENCE'} 42 | 43 | 44 | def generate_gdelt_prompt_from_pred_relation(model_output, event_seq, top_k: int = 5): 45 | """Generate the prompt that describes the effect event. 46 | """ 47 | # find the top-k relations 48 | pred_rel_list = model_output['pred_relation'] 49 | 50 | topk_rels = np.argsort(pred_rel_list)[-top_k:] 51 | original_idx = model_output['original_idx'] 52 | original_event = event_seq[original_idx] 53 | label_rel = original_event[2] 54 | 55 | all_rels = set(topk_rels.tolist() + [label_rel]) 56 | parts = original_event[-1].split(';') 57 | prompt_dict = { 58 | rel: 'effect event\n' + 59 | 'event type:' + EVENT_TYPE_MAP[rel] + '\n' + 60 | parts[-3].strip() + '\n' + 61 | parts[-2].strip() + '\n' + 62 | parts[-1].strip() + '\n' 63 | for rel in all_rels 64 | } 65 | return prompt_dict 66 | 67 | 68 | def generate_gdelt_prompt_v2(model_output, event_seq, top_k: int = 5, pred_relation: bool = True): 69 | """Generate the prompt that describes the effect event. 
70 | """ 71 | # find the top-k relations 72 | if pred_relation: 73 | pred_list = model_output['pred_relation'] 74 | else: 75 | pred_list = model_output['pred_object'] 76 | 77 | if len(pred_list.shape) > 1: 78 | pred_list = pred_list[0, :] 79 | 80 | top_k = np.argsort(pred_list)[-top_k:] 81 | original_idx = model_output['original_idx'] 82 | original_event = event_seq[original_idx] 83 | 84 | if pred_relation: 85 | label = original_event[2] 86 | else: 87 | label = original_event[1] # object 88 | 89 | all_preds = set(top_k.tolist() + [label]) 90 | parts = original_event[-1].split(';') 91 | 92 | if pred_relation: 93 | prompt_dict = { 94 | rel: 'effect event\n' + 95 | 'event type:' + EVENT_TYPE_MAP[rel] + '\n' + 96 | parts[-3].strip() + '\n' + 97 | parts[-2].strip() + '\n' + 98 | parts[-1].strip() + '\n' 99 | for rel in all_preds 100 | } 101 | else: 102 | prompt_dict = { 103 | obj: 'effect event\n' + 104 | parts[-4].strip() + '\n' + 105 | parts[-3].strip() + '\n' + 106 | parts[-2].strip() + '\n' + 107 | 'object name:' + ENTITY_MAP[obj] + '\n' 108 | for obj in all_preds 109 | } 110 | return prompt_dict 111 | 112 | 113 | def generate_gdelt_comb_prompt(model_output, event_seq, top_k: int = 5, filter_hit=True): 114 | # find the top-k relations 115 | pred_rel_obj_dict = { 116 | '{0}_{1}'.format(np.sum(item['pred_relation']), np.sum(item['pred_object'])): np.sum(item['pred_score']) 117 | for item in model_output['pred_rel_obj']} 118 | pred_rel_obj_topk_list = sorted(pred_rel_obj_dict.items(), key=lambda item: item[1], reverse=True)[:top_k] 119 | pred_rel_obj_topk = [pair.split('_') for pair, _ in pred_rel_obj_topk_list] 120 | 121 | original_idx = model_output['original_idx'] 122 | original_event = event_seq[original_idx] 123 | 124 | is_hit = '{0}_{1}'.format(original_event[2], original_event[1]) in pred_rel_obj_dict 125 | 126 | if not is_hit: 127 | if filter_hit: 128 | return {} 129 | else: 130 | pred_rel_obj_topk.append([str(original_event[2]), str(original_event[1])]) 131 | 132 | parts = original_event[-1].split(';') 133 | prompt_dict = { 134 | rel + '_' + obj: 'effect event\n' + 135 | 'event type:' + EVENT_TYPE_MAP[int(rel)] + '\n' + 136 | parts[-3].strip() + '\n' + 137 | parts[-2].strip() + '\n' + 138 | 'object name:' + ENTITY_MAP[int(obj)] + '\n' 139 | for rel, obj in pred_rel_obj_topk 140 | } 141 | return prompt_dict 142 | 143 | 144 | def get_dtime_top_n(seq_dtimes, num_bins: int = 50, top_k: int = 5): 145 | hist, bin_edges = np.histogram(seq_dtimes, bins=num_bins) 146 | top_k_index = np.argsort(hist)[-top_k:] 147 | return bin_edges[top_k_index] 148 | 149 | 150 | def generate_gdelt_prompt_amazon(model_output, event_seq, top_k: int = 5, pred_type: bool = True): 151 | """Generate the prompt that describes the effect event. 
152 | """ 153 | seq_idx = model_output['seq_idx'] 154 | original_idx = model_output['original_idx'] 155 | if pred_type: 156 | # find the top-k relations 157 | pred_list = model_output['pred_type_score'] 158 | topk_pred = np.argsort(pred_list)[-top_k:] 159 | label_type = event_seq[seq_idx][original_idx]['event_type'] 160 | all_types = set(topk_pred.tolist() + [label_type]) 161 | else: 162 | # seq_dtimes = [x['event_dtime'] for x in event_seq[seq_idx][1:original_idx]] 163 | # have to add mbr here 164 | # all_dtimes = get_dtime_top_n(seq_dtimes, top_k=top_k).tolist() 165 | # all_dtimes.append(model_output['pred_dtime']) 166 | all_dtimes = model_output['pred_dtime'] 167 | prev_date = event_seq[seq_idx][original_idx - 1]['event_date'] 168 | prev_date = datetime.strptime(prev_date, "%Y-%m-%d") 169 | all_date = [prev_date + timedelta(days=int(dtime)) for dtime in all_dtimes] 170 | # revert to string 171 | all_date = [date.strftime("%Y-%m-%d") for date in all_date] 172 | label_date = event_seq[seq_idx][original_idx]['event_date'] 173 | all_date = set(all_date + [label_date]) 174 | 175 | parts = event_seq[seq_idx][original_idx]['event_text'].split(';') 176 | 177 | if pred_type: 178 | prompt_dict = { 179 | type: 'effect event\n' + 180 | 'product category:' + AMAZON_TYPE_MAP[type] + '\n' + 181 | parts[-3].strip() + '\n' # event time 182 | for type in all_types 183 | } 184 | else: 185 | prompt_dict = { 186 | date: 'effect event\n' + 187 | parts[0].strip() + '\n' + 188 | 'event time:' + date + '\n' # event time 189 | for date in all_date 190 | } 191 | return prompt_dict 192 | -------------------------------------------------------------------------------- /utils/log_utils.py: -------------------------------------------------------------------------------- 1 | import numpy 2 | 3 | 4 | class LogWriter(object): 5 | 6 | def __init__(self, path, args): 7 | if '' in args: 8 | del args[''] 9 | self.path = path 10 | self.args = args 11 | with open(self.path, 'w') as f: 12 | f.write("Training Log\n") 13 | f.write("Specifications\n") 14 | for argname in self.args: 15 | f.write("{} : {}\n".format(argname, self.args[argname])) 16 | f.write("Checkpoints:\n") 17 | 18 | def checkpoint(self, to_write): 19 | with open(self.path, 'a') as f: 20 | f.write(to_write + '\n') 21 | 22 | def initBest(self): 23 | self.current_best = { 24 | 'loglik': numpy.finfo(float).min, 25 | 'distance': numpy.finfo(float).max, 26 | 'loss': numpy.finfo(float).max, 27 | 'rmse': numpy.finfo(float).max, 28 | 'acc': numpy.finfo(float).min, 29 | } 30 | self.episode_best = 'NeverUpdated' 31 | 32 | def updateBest(self, key, value, episode): 33 | updated = False 34 | if key == 'loglik' or key == 'acc': 35 | if value > self.current_best[key]: 36 | updated = True 37 | self.current_best[key] = value 38 | self.episode_best = episode 39 | elif key == 'distance' or key == 'loss' or key == 'rmse': 40 | if value < self.current_best[key]: 41 | updated = True 42 | self.current_best[key] = value 43 | self.episode_best = episode 44 | else: 45 | raise Exception("unknown key {}".format(key)) 46 | return updated 47 | -------------------------------------------------------------------------------- /utils/misc_amazon.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from Levenshtein import distance 3 | 4 | from utils.general import ProcessPool 5 | 6 | 7 | def parse_causal_event_from_text_amazon(llm_text_output): 8 | import re 9 | 10 | product_category = re.compile(r"product 
category:[\s]*(.+)[\s]*").findall(llm_text_output) 11 | product_title = re.compile(r"product title:[\s]*(.+)[\s]*").findall(llm_text_output) 12 | event_times = re.compile(r"event time:[\s]*(.+)[\s]*").findall(llm_text_output) 13 | summary_text = re.compile(r"summary text:[\s]*(.+)[\s]*").findall(llm_text_output) 14 | review_text = re.compile(r"review text:[\s]*(.+)[\s]*").findall(llm_text_output) 15 | 16 | sample_list = [] 17 | for category, title, time, summary, review in zip(product_category, product_title, event_times, summary_text, 18 | review_text): 19 | one = ';'.join([ 20 | 'product category:' + category, 21 | 'product title:' + title, 22 | 'event time:' + time, 23 | 'summary text:' + summary, 24 | 'review text:' + review, 25 | ]) 26 | sample_list.append(one) 27 | 28 | return sample_list 29 | 30 | 31 | def parse_events(model_output): 32 | res = dict() 33 | for event_type, prompt_output_str in model_output.items(): 34 | res[event_type] = parse_causal_event_from_text_amazon(prompt_output_str) 35 | return res 36 | 37 | 38 | def make_event_dict(event_list): 39 | return {'product category': event_list[0], 40 | 'event time': event_list[2], 41 | 'seq_index': event_list[3], 42 | 'original_index': event_list[4]} 43 | 44 | 45 | def retrieve_event_from_source_data_amazon(target_event, source_data, seq_index, original_index, top_n=1): 46 | # search over the previous events 47 | # fix to search over the last 10k events 48 | # source_data_ = source_data[int(original_index) - context_window:int(original_index)] 49 | target_seq = source_data[seq_index] 50 | source_data_ = target_seq[:original_index] 51 | # scores = [levenshteinDistance(target_event, event[4]) for event in source_data_] 52 | scores = [distance(target_event, event[4]) for event in source_data_] 53 | sort_index = np.argsort(np.array(scores)) 54 | retro_events = [] 55 | # to do list selection, we need to convert source data to np.array, which may be slow 56 | # so i do the list append here. 
57 | select_index = sort_index[:top_n] 58 | for i in select_index: 59 | retro_events.append(make_event_dict(source_data_[i])) 60 | return retro_events 61 | 62 | 63 | def get_event(source_data, seq_idx, original_idx): 64 | seq = source_data[seq_idx] 65 | event = (seq['event_type'][original_idx], seq['event_time'][original_idx], seq_idx, original_idx) 66 | return event 67 | 68 | 69 | def is_event_existed(target_event, event_list): 70 | for event in event_list: 71 | s = 0 72 | for k, v in target_event.items(): 73 | s += abs(target_event[k] - event[k]) 74 | if s < 0.001: 75 | return True 76 | return False 77 | 78 | 79 | def make_complete_event_sequence(real_event, retro_event_dict, pred_name='rel'): 80 | """ Add the last event to the samples """ 81 | 82 | event_type, event_time, seq_idx, original_idx = real_event 83 | 84 | real_event_seq_sample = None 85 | noise_event_seq_sample = [] 86 | noise_event_target = [] 87 | real_event_target = label_rel 88 | for pred_target, causal_evts in retro_event_dict.items(): 89 | if find_label(pred_target, real_event, pred_name): 90 | real_event_seq_sample = causal_evts 91 | # append the last event 92 | real_event_seq_sample.append(make_event_dict(real_event)) 93 | else: 94 | noise_seq_ = causal_evts 95 | noise_seq_.append(make_event_dict(event_time, seq_idx, original_idx)) 96 | noise_event_target.append(int(pred_target)) 97 | noise_event_seq_sample.append(noise_seq_) 98 | 99 | return real_event_seq_sample, noise_event_seq_sample, real_event_target, noise_event_target 100 | 101 | 102 | def make_samples_for_energy_function_amazon(db, source_data, retro_top_n=2): 103 | import dictdatabase as DDB 104 | 105 | ebm_db_name = 'ebm_dataset_amazon' 106 | 107 | def _process(s_idx, o_idx, p_e_d): 108 | retro_event_dict = {} 109 | for pred_type, prompt_event_list in p_e_d.items(): 110 | retro_event_list = [ 111 | retrieve_event_from_source_data_amazon( 112 | target_event=prompt_event, 113 | source_data=source_data, 114 | seq_index=s_idx, 115 | original_index=o_idx, 116 | top_n=retro_top_n) 117 | for prompt_event in prompt_event_list 118 | ] 119 | 120 | # do a flatten 121 | retro_event_list = [item for sublist in retro_event_list for item in sublist] 122 | 123 | # sort by time and drop duplicate 124 | retro_event_list_ = [] 125 | for i, event in enumerate(retro_event_list): 126 | if i == 0: 127 | retro_event_list_.append(event) 128 | continue 129 | else: 130 | if not is_event_existed(event, retro_event_list_): 131 | retro_event_list_.append(event) 132 | 133 | retro_event_list_ = sorted(retro_event_list_, key=lambda x: x['event_time']) 134 | # save 135 | retro_event_dict[pred_type] = retro_event_list_ 136 | 137 | real_event_ = get_event(source_data, seq_idx=seq_index, original_idx=original_index) 138 | real_event_sample, noise_event_sample, real_type, noise_type = make_complete_event_sequence(real_event_, 139 | retro_event_dict) 140 | 141 | is_label_in_top_5 = list(p_e_d.keys()) == 5 142 | res_i = ( 143 | o_idx, 144 | real_event_sample, 145 | noise_event_sample, 146 | is_label_in_top_5, 147 | real_type, 148 | noise_type 149 | ) 150 | with DDB.at(ebm_db_name).session() as (sess, ebm_db): 151 | ebm_db[(str(s_idx), str(o_idx))] = res_i 152 | sess.write() 153 | print('--- ok', s_idx, o_idx) 154 | # return res_i 155 | 156 | if not DDB.at(ebm_db_name).exists(): 157 | DDB.at(ebm_db_name).create() 158 | 159 | existed_ebm_db_dict = DDB.at(ebm_db_name).read() 160 | 161 | dp_list = [] 162 | for k, v in db.db.items(): 163 | if k in existed_ebm_db_dict and len(existed_ebm_db_dict[k][2]) 
>= 5: 164 | continue 165 | if len(v) == 0: 166 | continue 167 | 168 | seq_index, original_index = k 169 | 170 | # parse the prompt 171 | res_i_prompt_event_dict = parse_events(v) 172 | if len(res_i_prompt_event_dict) == 0: 173 | continue 174 | 175 | dp_list.append( 176 | (seq_index, original_index, res_i_prompt_event_dict) 177 | ) 178 | 179 | with ProcessPool() as pool: 180 | pool.run( 181 | target=_process, 182 | dynamic_param_list=dp_list 183 | ) 184 | # file_uri_writer_processor(res_list, 'ebm_dataset_v0326.pkl') 185 | -------------------------------------------------------------------------------- /utils/persist_utils.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | class PersistDB: 5 | def __init__( 6 | self, 7 | storage_uri: str = '.ddb_storage', 8 | use_compression: bool = True, 9 | ): 10 | import dictdatabase 11 | self._db = dictdatabase 12 | 13 | self._db.config.storage_directory = storage_uri 14 | self._db.config.use_compression = use_compression 15 | 16 | @property 17 | def db(self): 18 | return self._db 19 | -------------------------------------------------------------------------------- /utils/torch_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def count_torch_model_params(model): 5 | """Count the number of params of the model. 6 | 7 | Args: 8 | model (torch.nn.Module): a torch model. 9 | 10 | Returns: 11 | int: total num of the parameters. 12 | """ 13 | return sum(p.numel() for p in model.parameters()) 14 | 15 | 16 | def sample_gumbel(shape, eps=1e-10): 17 | """ 18 | NOTE: Stolen from https://github.com/pytorch/pytorch/pull/3341/commits/327fcfed4c44c62b208f750058d14d4dc1b9a9d3 19 | Sample from Gumbel(0, 1) 20 | based on 21 | https://github.com/ericjang/gumbel-softmax/blob/3c8584924603869e90ca74ac20a6a03d99a91ef9/Categorical%20VAE.ipynb , 22 | (MIT license) 23 | """ 24 | U = torch.rand(shape).float() 25 | return - torch.log(eps - torch.log(U + eps)) 26 | 27 | 28 | def gumbel_softmax(logits, tau=1, hard=False, eps=1e-10): 29 | """ 30 | NOTE: Stolen from https://github.com/pytorch/pytorch/pull/3341/commits/327fcfed4c44c62b208f750058d14d4dc1b9a9d3 31 | Sample from the Gumbel-Softmax distribution and optionally discretize. 32 | Args: 33 | logits: [batch_size, n_class] unnormalized log-probs 34 | tau: non-negative scalar temperature 35 | hard: if True, take argmax, but differentiate w.r.t. soft sample y 36 | Returns: 37 | [batch_size, n_class] sample from the Gumbel-Softmax distribution. 
38 |         If hard=True, then the returned sample will be one-hot, otherwise it will
39 |         be a probability distribution that sums to 1 across classes
40 |     Constraints:
41 |         - this implementation only works on batch_size x num_features tensor for now
42 |     based on
43 |     https://github.com/ericjang/gumbel-softmax/blob/3c8584924603869e90ca74ac20a6a03d99a91ef9/Categorical%20VAE.ipynb ,
44 |     (MIT license)
45 |     """
46 |     # Gumbel-Softmax sample: perturb the logits with Gumbel(0, 1) noise and apply a temperature-scaled softmax
47 |     gumbel_noise = sample_gumbel(logits.size(), eps=eps).to(logits.device)
48 |     y_soft = torch.softmax((logits + gumbel_noise) / tau, dim=-1)
49 |     if hard:
50 |         shape = logits.size()
51 |         _, k = y_soft.data.max(-1)
52 |         # this bit is based on
53 |         # https://discuss.pytorch.org/t/stop-gradients-for-st-gumbel-softmax/530/5
54 |         y_hard = torch.zeros(*shape)
55 |         if y_soft.is_cuda:
56 |             y_hard = y_hard.cuda()
57 |         y_hard = y_hard.zero_().scatter_(-1, k.view(shape[:-1] + (1,)), 1.0)
58 |         # this cool bit of code achieves two things:
59 |         # - makes the output value exactly one-hot (since we add then
60 |         #   subtract y_soft value)
61 |         # - makes the gradient equal to y_soft gradient (since we strip
62 |         #   all other gradients)
63 |         y = (y_hard - y_soft).detach() + y_soft
64 |     else:
65 |         y = y_soft
66 |     return y
67 | 
--------------------------------------------------------------------------------
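A minimal usage sketch for the Gumbel-Softmax helpers in utils/torch_utils.py above; the batch size, class count and temperature below are illustrative placeholders, not values taken from the repository.

import torch

from utils.torch_utils import gumbel_softmax, sample_gumbel

# unnormalized log-probs for a batch of 4 examples over 10 classes
logits = torch.randn(4, 10)

# soft sample: each row is a probability distribution that sums to 1
y_soft = gumbel_softmax(logits, tau=0.5, hard=False)
print(y_soft.sum(dim=-1))

# straight-through sample: one-hot in the forward pass, soft gradients in the backward pass
y_hard = gumbel_softmax(logits, tau=0.5, hard=True)
print(y_hard.argmax(dim=-1))

# the underlying Gumbel(0, 1) noise can also be drawn directly
print(sample_gumbel((4, 10)).shape)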