├── LICENSE ├── README.md ├── dataset ├── eval_dataset.json └── train_dataset.zip ├── pics ├── gig.jpg └── gig.pdf └── train.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 ChenhanYuan 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 | 3 |

4 | 5 |

6 |
7 | 8 |

9 | 🤗 Hugging Face   |    📑 Paper 10 |
11 | 12 | | Models | TimeLlama-7b | ChatTimeLlama-7b | TimeLlama-13b | ChatTimeLlama-13b | 13 | |---------------------|:------------:|:----------------:|:-------------:|:-----------------:| 14 | | Huggingface Repo |🤗|🤗|🤗|🤗| 15 | 16 | ## News and Updates 17 | 18 | * 2023.9.30 🔥 The TimeLlama series models are available on Hugging Face. 19 |
20 | 21 | ## Introduction 22 | 23 | This repository contains the code and dataset for our work on explainable temporal reasoning. Temporal reasoning involves predicting future events by understanding the temporal relationships between events described in text. Explainability is critical for building trust in AI systems that make temporal predictions. 24 | 25 | In this work, we introduce the first multi-source dataset for explainable temporal reasoning, called **ExpTime**. The dataset contains 26k examples derived from temporal knowledge graph datasets. Each example includes a context with multiple events, a future event to predict, and an explanation for the prediction in the form of temporal reasoning over the events. 26 | 27 | To generate the dataset, we propose a novel knowledge-graph-instructed generation strategy. The dataset supports the comprehensive evaluation of large language models on complex temporal reasoning, future event prediction, and explainability. 28 | 29 | Based on ExpTime, we develop **TimeLlaMA**, a series of LLMs fine-tuned for explainable temporal reasoning. TimeLlaMA builds on the foundation model LLaMA-2 and uses instruction tuning so that the model follows prompts and generates explanations for its predictions. 30 | 31 | The code in this repo allows training TimeLlaMA models on ExpTime and evaluating their temporal reasoning and explanation abilities. We open-source the code, dataset, and models to provide a basis for future work on explainable AI. 32 |
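For a quick look at the data itself, the minimal sketch below (our own suggestion, not part of the repository's tooling) loads the released evaluation split and inspects one record; the field names follow the format documented in the Dataset section below.

```python
import json
from collections import Counter

# Load the human-annotated evaluation split shipped in dataset/.
with open("dataset/eval_dataset.json", encoding="utf-8") as f:
    examples = json.load(f)

print(f"Loaded {len(examples)} evaluation examples")

# Each record pairs a temporal query (instruction) with a context document
# (input), a prediction plus explanation (output), and a label.
first = examples[0]
for key in ("instruction", "input", "output", "label"):
    print(f"{key}: {str(first[key])[:80]}")

# Distribution of labels ("pos", "neg", "unsure") across the split.
print(Counter(example["label"] for example in examples))
```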

33 | 34 | ## Knowledge Graph-Instructed Generation (GIG) Strategy 35 | 36 |

37 | 38 |

39 | Recent work has shown promise in using large language models (LLMs) like ChatGPT to automatically generate datasets by prompting the model to produce answers. However, directly prompting LLMs to generate temporal reasoning explanations results in low-quality and incoherent outputs. 40 | 41 | To address this challenge, we propose a novel framework called **Temporal Knowledge Graph-instructed Generation (GIG)** to produce more accurate and coherent reasoning explanations. The key idea is to leverage temporal knowledge graphs (TKGs), which have been effectively utilized for explainable event forecasting. Our approach first applies explainable TKG reasoning models to generate reasoning paths for a given query about a future event. We then convert these paths into natural language explanations using a two-level prompting technique. Next, we identify relevant context from the TKG and reasoning paths to construct a coherent context document. Finally, we convert the original query into a question to produce a complete training instance. 42 | 43 | In this way, our GIG framework overcomes the limitations of directly prompting LLMs by leveraging structured knowledge in TKGs to generate higher-quality temporal reasoning explanations. We believe this approach could enable more effective use of LLMs guided by knowledge graphs for automated dataset creation. 44 | 45 | *Stay tuned for the code release of our GIG framework* 46 |
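Until that release, the sketch below illustrates how the four GIG stages described above could fit together. It is only a rough illustration under our own assumptions: every function (`forecast_reasoning_paths`, `paths_to_explanation`, `build_context_document`, `query_to_question`) and every toy value is a hypothetical placeholder, not the authors' actual implementation or API.

```python
from typing import Dict, List

# All functions below are hypothetical placeholders for the four GIG stages
# described above; they are not the released implementation.

def forecast_reasoning_paths(query: Dict[str, str]) -> List[List[str]]:
    """Stage 1: an explainable TKG reasoning model produces reasoning paths
    (chains of temporal facts) that support or refute the queried event."""
    return [["(subject, relation_1, t1)", "(subject, relation_2, t2)"]]


def paths_to_explanation(paths: List[List[str]], answer: str) -> str:
    """Stage 2: a two-level prompting step verbalizes the structured paths
    into a natural-language explanation of the predicted answer."""
    facts = "; then ".join(paths[0])
    return f"{answer}. Based on the document, {facts}, which points to the queried event."


def build_context_document(paths: List[List[str]]) -> str:
    """Stage 3: relevant TKG facts referenced by the paths are gathered and
    verbalized into a coherent context document."""
    return " ".join(f"Event {fact} occurred." for path in paths for fact in path)


def query_to_question(query: Dict[str, str]) -> str:
    """Stage 4: the structured query is rewritten as a yes/no question."""
    return (f"Given the following document, is there a potential that "
            f"{query['subject']} will {query['relation']} in the future?")


def generate_training_instance(query: Dict[str, str], label: str) -> Dict[str, str]:
    """Assemble one ExpTime-style record from the four stages."""
    paths = forecast_reasoning_paths(query)
    answer = {"pos": "Yes", "neg": "No"}.get(label, "Unsure")
    return {
        "instruction": query_to_question(query),
        "input": build_context_document(paths),
        "output": paths_to_explanation(paths, answer),
        "label": label,
    }


print(generate_training_instance({"subject": "Egypt", "relation": "impose a curfew"}, "pos"))
```

Each record produced this way matches the JSON format of the released dataset shown in the next section.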
47 | 48 | ## Dataset 49 | We release the first-of-its-kind Explainable Temporal Event Forecasting (ExpTime) dataset, which aims to assess and enhance the complex temporal reasoning capabilities of large language models (LLMs). The dataset has the following format: 50 | ```json 51 | [ 52 | { 53 | "instruction": "Given the following document, is there a potential that......", 54 | "input": "In the context of Egypt, on April 19, 2014......", 55 | "output": "Yes. Based on the information provided by the document......", 56 | "label": "pos" 57 | } 58 | ] 59 | ``` 60 | In each sample of the ExpTime dataset, the instruction poses the query of whether an event will happen in the future. The input provides a context document describing past events, and the output is the prediction along with its explanation. The labels "pos", "neg", and "unsure" denote whether the answer should be "yes", "no", or "unsure", respectively. The dataset can be found in the dataset folder, where "train_dataset.zip" contains the training set (unzip it to obtain "train_dataset.json") and "eval_dataset.json" is the human-annotated golden test set. 61 | 62 | ## 🤗 Inference 63 | To run inference with the TimeLlama series, all you need is the following code. 64 | ```python 65 | from transformers import LlamaConfig, LlamaTokenizer, LlamaForCausalLM 66 | model_name = "chrisyuan45/TimeLlama-7b-chat"  # or "chrisyuan45/TimeLlama-13b-chat" 67 | model = LlamaForCausalLM.from_pretrained( 68 | model_name, 69 | return_dict=True, 70 | load_in_8bit=False,  # set True to load the model in 8-bit 71 | device_map="auto", 72 | low_cpu_mem_usage=True) 73 | tokenizer = LlamaTokenizer.from_pretrained(model_name) 74 | ``` 75 | However, if you prefer not to write any code, we also provide a chat script (of course!). All you need to do is run: 76 | ```bash 77 | python ChatwithModel.py --model_name chrisyuan45/TimeLlama-7b-chat 78 | ``` 79 | 80 | ## Finetune 81 | 82 | We provide our finetuning code in "train.py". To run it, you need access to the meta-llama models; submit the request form here: [Llama2 Access](https://ai.meta.com/resources/models-and-libraries/llama-downloads/). Then generate your own access token on Hugging Face and set it on line 23 of train.py. 83 | 84 | To run the finetuning code on multiple GPUs, consider using the following script: 85 | ```bash 86 | #!/bin/bash 87 | 88 | torchrun --nproc_per_node 8 --master_port 14545 train.py \ 89 | --data_path dataset/train_dataset.json \ 90 | --output_dir model13chat \ 91 | --num_train_epochs 70 \ 92 | --per_device_train_batch_size 1 \ 93 | --per_device_eval_batch_size 1 \ 94 | --tf32 True \ 95 | --bf16 True \ 96 | --gradient_accumulation_steps 4 \ 97 | --weight_decay 0.01 \ 98 | --warmup_ratio 0.05 \ 99 | --lr_scheduler_type "cosine" \ 100 | --logging_steps 1 \ 101 | --gradient_checkpointing True \ 102 | --disable_tqdm False \ 103 | --learning_rate 5e-5 \ 104 | --fsdp "full_shard offload auto_wrap" \ 105 | --fsdp_transformer_layer_cls_to_wrap 'LlamaDecoderLayer' 106 | ``` 107 | ## Performance 108 | 109 | In addition to the TimeLlama series, we evaluate the most popular LLMs on our golden human-annotated ExpTime evaluation dataset. Predictions are evaluated with the F1 score, and explanation quality is evaluated via BLEU, ROUGE, and BertScore. We also included human evaluation results in our paper.
The brief evaluation results are shown here: 110 | 111 | | Models | Pos F1 | Neg F1 | Neu F1 | Overall F1 | BLEU | ROUGE | BertScore | 112 | |---------------------|:------:|:------:|:------:|:----------:|:----:|:-----:|:---------:| 113 | | Flan T5 | 39.9 | 40.5 | 31.5 | 38.0 | 15.2 | 26.0 | 76.9 | 114 | | BART | 34.9 | 16.2 | 19.8 | 25.3 | 8.9 | 19.7 | 74.9 | 115 | | MPT-7B | 55.4 | 37.5 | 18.7 | 40.3 | 10.7 | 27.2 | 80.1 | 116 | | Falcon-7B | 51.7 | 27.8 | 21.5 | 36.5 | 19.8 | 29.3 | 79.9 | 117 | | Vicuna-7B | 60.4 | 28.1 | 22.6 | 40.4 | 23.5 | 37.2 | 83.3 | 118 | | ChatGPT | 54.7 | 30.5 | 39.8 | 43.5 | 31.1 | 37.1 | 83.7 | 119 | | Llama2-7B-chat | 62.7 | 19.8 | 22.0 | 39.1 | 26.8 | 38.4 | 83.8 | 120 | | Llama2-13B-chat | 52.5 | 31.5 | 31.8 | 40.7 | 25.5 | 36.6 | 83.4 | 121 | | TimeLlama2-7B | 93.7 | 75.3 | 70.5 | 81.5 | 59.9 | 56.5 | 90.2 | 122 | | TimeLlama2-13B | 97.2 | 81.7 | 77.5 | 87.3 | 44.6 | 54.9 | 89.4 | 123 | | TimeLlama2-7B-chat | 95.2 | 76.1 | 71.2 | 83.1 | 61.9 | 57.7 | 90.4 | 124 | | TimeLlama2-13B-chat | 97.9 | 83.4 | 78.5 | 88.4 | 46.3 | 56.3 | 89.7 | 125 | 126 | ## Citation 127 | 128 | Consider citing our paper if you find the repo useful ;) 129 | 130 | ``` 131 | @article{yuan2023back, 132 | title={Back to the Future: Towards Explainable Temporal Reasoning with Large Language Models}, 133 | author={Yuan, Chenhan and Xie, Qianqian and Huang, Jimin and Ananiadou, Sophia}, 134 | journal={arXiv preprint arXiv:2310.01074}, 135 | year={2023} 136 | } 137 | ``` 138 | -------------------------------------------------------------------------------- /dataset/train_dataset.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chenhan97/TimeLlama/40527d9d55ed3311e052b7e92268bb7124a67cda/dataset/train_dataset.zip -------------------------------------------------------------------------------- /pics/gig.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chenhan97/TimeLlama/40527d9d55ed3311e052b7e92268bb7124a67cda/pics/gig.jpg -------------------------------------------------------------------------------- /pics/gig.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chenhan97/TimeLlama/40527d9d55ed3311e052b7e92268bb7124a67cda/pics/gig.pdf -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | import json 3 | import pathlib 4 | from typing import Dict, Optional, Sequence 5 | 6 | import numpy as np 7 | import torch 8 | import copy 9 | from torch.utils.data import Dataset 10 | import transformers 11 | from transformers import Trainer 12 | from transformers.trainer_pt_utils import LabelSmoother 13 | import warnings 14 | #from fastchat.train.llama_flash_attn_monkey_patch import ( 15 | # #replace_llama_attn_with_flash_attn, 16 | #) 17 | 18 | #replace_llama_attn_with_flash_attn() 19 | 20 | IGNORE_TOKEN_ID = LabelSmoother.ignore_index 21 | 22 | #add your own token here 23 | access_token = "" 24 | 25 | @dataclass 26 | class ModelArguments: 27 | model_name_or_path: Optional[str] = field(default="meta-llama/Llama-2-13b-chat-hf") 28 | flash_attn: bool = False 29 | 30 | 31 | @dataclass 32 | class DataArguments: 33 | data_path: str = field( 34 | default="dataset/train_dataset.json", metadata={"help": "Path to the training data."} 35 | ) 
36 | lazy_preprocess: bool = False 37 | 38 | 39 | @dataclass 40 | class TrainingArguments(transformers.TrainingArguments): 41 | cache_dir: Optional[str] = field(default=None) 42 | optim: str = field(default="adamw_torch") 43 | output_dir: str = field(default="model") 44 | evaluation_strategy: str = field(default="epoch") 45 | #save_strategy: str = field(default="epoch") 46 | save_strategy: str = field(default="steps") 47 | save_steps: int = field(default=1400) 48 | save_total_limit: int = field(default=3) 49 | #deepspeed: str = field(default="deepspeed_config.json") 50 | learning_rate: float = field(default=1e-4) 51 | weight_decay: float = field(default=0.01) 52 | warmup_ratio: float = field(default=1e-3) 53 | lr_scheduler_type: str = field(default="linear") 54 | gradient_accumulation_steps: int = field(default=1) 55 | #load_best_model_at_end: bool = field(default=True) 56 | # model_max_length: int = field( 57 | # default=512, 58 | # metadata={ 59 | # "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)." 60 | # }, 61 | # ) 62 | 63 | 64 | local_rank = None 65 | 66 | 67 | def rank0_print(*args): 68 | if local_rank == 0: 69 | print(*args) 70 | 71 | 72 | def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str): 73 | """Collects the state dict and dump to disk.""" 74 | state_dict = trainer.model.state_dict() 75 | if trainer.args.should_save: 76 | cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()} 77 | del state_dict 78 | trainer._save(output_dir, state_dict=cpu_state_dict) # noqa 79 | 80 | 81 | PROMPT_DICT = { 82 | "prompt_input": ( 83 | "Below is an instruction that describes a task, paired with an input that provides further context. " 84 | "Write a response that appropriately completes the request.\n\n" 85 | "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:" 86 | ), 87 | "prompt_no_input": ( 88 | "Below is an instruction that describes a task. 
" 89 | "Write a response that appropriately completes the request.\n\n" 90 | "### Instruction:\n{instruction}\n\n### Response:" 91 | ), 92 | } 93 | 94 | 95 | class InstructionDataset(Dataset): 96 | def __init__(self, data_path, tokenizer, partition="train", max_words=1024): 97 | self.ann = json.load(open(data_path)) 98 | if partition == "train": 99 | self.ann = self.ann[:] 100 | else: 101 | self.ann = self.ann[:100] 102 | 103 | self.max_words = max_words 104 | # tokenizer = Tokenizer(model_path=model_path + "./tokenizer.model") 105 | self.tokenizer = tokenizer 106 | # self.tokenizer1 = tokenizer 107 | 108 | def __len__(self): 109 | return len(self.ann) 110 | 111 | def __getitem__(self, index): 112 | ann = self.ann[index] 113 | if ann.get("input", "") == "": 114 | prompt = PROMPT_DICT["prompt_no_input"].format_map(ann) 115 | else: 116 | prompt = PROMPT_DICT["prompt_input"].format_map(ann) 117 | example = prompt + ann["output"] 118 | prompt = torch.tensor( 119 | self.tokenizer.encode(prompt), dtype=torch.int64 120 | ) 121 | example = self.tokenizer.encode(example) 122 | example.append(self.tokenizer.eos_token_id) 123 | 124 | example = torch.tensor(example, dtype=torch.int64) 125 | padding = self.max_words - example.shape[0] 126 | if padding > 0: 127 | example = torch.cat((example, torch.zeros(padding, dtype=torch.int64) - 1)) 128 | elif padding < 0: 129 | example = example[: self.max_words] 130 | labels = copy.deepcopy(example) 131 | labels[: len(prompt)] = -1 132 | example_mask = example.ge(0) 133 | label_mask = labels.ge(0) 134 | example[~example_mask] = 0 135 | labels[~label_mask] = 0 136 | example_mask = example_mask.float() 137 | label_mask = label_mask.float() 138 | 139 | return { 140 | "input_ids": example, 141 | "labels": labels, 142 | "attention_mask": example_mask, 143 | } 144 | 145 | 146 | def make_supervised_data_module( 147 | tokenizer: transformers.PreTrainedTokenizer, data_args 148 | ) -> Dict: 149 | """Make dataset and collator for supervised fine-tuning.""" 150 | dataset_cls = ( 151 | InstructionDataset 152 | ) 153 | rank0_print("Loading data...") 154 | train_dataset = dataset_cls(data_args.data_path, tokenizer=tokenizer) 155 | eval_dataset = dataset_cls(data_args.data_path, tokenizer=tokenizer, partition="eval") 156 | #eval_dataset = dataset_cls("dataset/eval_dataset.json", tokenizer=tokenizer) 157 | return dict(train_dataset=train_dataset, eval_dataset=eval_dataset) 158 | 159 | def trainer_save_model_safe(trainer: transformers.Trainer): 160 | from torch.distributed.fsdp import FullyShardedDataParallel as FSDP 161 | from torch.distributed.fsdp import StateDictType, FullStateDictConfig 162 | 163 | save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True) 164 | with FSDP.state_dict_type( 165 | trainer.model, StateDictType.FULL_STATE_DICT, save_policy 166 | ): 167 | trainer.save_model() 168 | 169 | def train(): 170 | global local_rank 171 | 172 | parser = transformers.HfArgumentParser( 173 | (ModelArguments, DataArguments, TrainingArguments) 174 | ) 175 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 176 | local_rank = training_args.local_rank 177 | model = transformers.LlamaForCausalLM.from_pretrained( 178 | model_args.model_name_or_path, 179 | cache_dir=training_args.cache_dir, 180 | token=access_token 181 | ) 182 | model.config.use_cache = False 183 | tokenizer = transformers.LlamaTokenizer.from_pretrained( 184 | model_args.model_name_or_path, 185 | token=access_token 186 | ) 187 | tokenizer.add_special_tokens( 188 | { 189 | 190 | 
"pad_token": "", 191 | } 192 | ) 193 | 194 | data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args) 195 | trainer = Trainer( 196 | model=model, tokenizer=tokenizer, args=training_args, **data_module 197 | ) 198 | 199 | 200 | #if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")): 201 | # trainer.train(resume_from_checkpoint=False) 202 | #else: 203 | trainer.train() 204 | model.config.use_cache = True 205 | trainer.save_state() 206 | trainer_save_model_safe(trainer) 207 | #safe_save_model_for_hf_trainer(trainer=trainer, output_dir=training_args.output_dir) 208 | 209 | 210 | if __name__ == "__main__": 211 | train() 212 | --------------------------------------------------------------------------------