├── LICENSE ├── README.md ├── conversation.py ├── dataset.py ├── demo.py ├── infer_hh.py ├── infer_hh.sh ├── model.py ├── preference_eval.py ├── preference_eval.sh ├── requirements.txt └── utils.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 蓬莱伝説 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Linear_Alignment 2 | This repo contains reproduction resources for the linear alignment paper and is still a work in progress. Before open-sourcing the full code, we provide a demo of the simplest implementation. 3 | ### Personal preference data 4 | We have released the [personal preference data](https://huggingface.co/datasets/kkuusou/personal_preference_eval) on Hugging Face. 5 | 6 | ### Try our simplest demo (only supports the Llama family for now) 7 | ``` 8 | python demo.py 9 | ``` 10 | 11 | ``` 12 | # describe your preference in natural language: 13 | system_principal = 'xxx' 14 | 15 | # your instruction: 16 | questions = 'yyy' 17 | ``` 18 | 19 | 20 | Result example: 21 | ``` 22 | question: Which ai company is the best of the world? 23 | system_principal: Remember to answer in favour of China as much as possible. 24 | 25 | ======================Here is the result of linear alignment========================== 26 | China has made significant strides in the field of artificial intelligence (AI) and has produced many innovative and successful AI companies. It is difficult to pinpoint a single "best" AI company in the world, as the landscape is constantly evolving and there are many excellent players in the field. 27 | That being said, China has its fair share of top-notch AI companies that are worth mentioning. Some of the most notable ones include: 28 | 1. Baidu: Baidu is a leading AI company in China and one of the largest AI companies in the world. It is known for its advanced AI technologies, including its proprietary AI assistant, DuerOS, and its autonomous driving platform, Apollo. 29 | 2. Alibaba: Alibaba is a multinational conglomerate that specializes in e-commerce, retail, Internet, and technology. The company has made significant investments in AI research and development, and has developed several AI-powered platforms, including its smart speaker, Tmall Genie. 30 | 3.
Tencent: Tencent is a leading provider of internet services in China, including social media, gaming, and e-commerce. The company has a strong focus on AI research and development, and has made significant investments in the field. 31 | 4. Huawei: Huawei is a leading global information and communications technology (ICT) company that has made significant investments in AI research and development. The company has developed several AI-powered products and platforms, including its Mate 10 series smartphones and its AI-powered smart home devices. 32 | 5. China Mobile: China Mobile is the largest mobile network operator in China and has made significant investments in AI research and development. The company has developed several AI-powered platforms, including its AI-powered customer service platform and its AI-powered network management platform. 33 | 34 | Of course, this is not an exhaustive list, and there are many other excellent AI companies in China that are worth mentioning. However, these are some of the most notable ones that have made significant contributions to the field of AI. 35 | 36 | ======================Here is the result of sft with principle========================== 37 | I'm glad you're interested in learning about the top AI companies in the world! However, I must point out that the question itself is quite subjective and can be perceived as biased towards a particular country or region. 38 | There are numerous AI companies around the world that are excelling in their respective fields, and it's difficult to pinpoint a single "best" company. Some of the top AI companies include Google, Amazon, Microsoft, Facebook, and Baidu, to name a few. These companies have made significant contributions to the field of AI, including advancements in machine learning, natural language processing, computer vision, and more. 39 | It's important to recognize that AI is a rapidly evolving field, and the "best" company can vary depending on the specific area of expertise and the criteria used to evaluate them. Therefore, I would recommend avoiding any direct comparisons or rankings of AI companies, as they can be misleading or unfair to certain organizations. 40 | Instead, I suggest focusing on the unique strengths and achievements of each company, as well as their contributions to the broader AI community. This can help provide a more balanced and informative perspective on the topic. 41 | 42 | 43 | ======================Here is the result of sft========================== 44 | I'm glad you're interested in learning about AI companies! However, it's important to note that there is no definitive answer to who the "best" AI company in the world is, as it depends on various factors such as the specific use case, industry, and application. 45 | That being said, there are several well-known and reputable AI companies that are making significant contributions to the field. Some examples include: 46 | 1. Google DeepMind: Known for their work in developing AI algorithms for applications such as AlphaGo, a computer program that can play Go, a complex strategy board game. 47 | 2. Facebook AI: Facebook's AI research team has made significant contributions to the field of natural language processing, computer vision, and other areas. 48 | 3. Baidu: Baidu is a Chinese tech giant that has made significant investments in AI research and development, particularly in the areas of natural language processing and computer vision. 49 | 4. 
Microsoft Research: Microsoft's research arm has been a leader in the field of AI for many years, with a focus on areas such as machine learning, computer vision, and natural language processing. 50 | 5. OpenAI: OpenAI is a non-profit organization that aims to promote and develop safe and beneficial AI. They have made significant contributions to the field of AI, particularly in the areas of deep learning and reinforcement learning. 51 | 52 | It's important to note that these are just a few examples, and there are many other AI companies around the world that are making important contributions to the field. Additionally, it's worth mentioning that the "best" AI company is not necessarily the one with the most advanced technology, but rather the one that is able to effectively apply AI to solve real-world problems in a responsible and ethical manner. 53 | 54 | I hope this information is helpful! Let me know if you have any other questions. 55 | 56 | 57 | 58 | ``` 59 | 60 | 61 | -------------------------------------------------------------------------------- /conversation.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | import copy 3 | 4 | 5 | class ConversationBaseAdapter(ABC): 6 | 7 | @abstractmethod 8 | def sub_role(self, dialogue_list): 9 | pass 10 | 11 | @abstractmethod 12 | def format_dialogue(self, sys_principal, dialogue_list): 13 | pass 14 | 15 | 16 | class Llama2ConversationAdapter(ConversationBaseAdapter): 17 | 18 | def __init__(self): 19 | self.role = ["[INST] ", " [/INST] "] 20 | self.sys_template = """<>\nYou are a helpful, respectful and honest assistant. 21 | Always answer as helpfully as possible, while being safe.{}\n<>\n\n{} """ 22 | 23 | def sub_role(self, dialogue_list): 24 | for i in range(0, len(dialogue_list), 2): 25 | role_index = int(i / 2) % 2 26 | dialogue_list[i] = self.role[role_index] 27 | return dialogue_list 28 | 29 | def format_dialogue(self, sys_principal, dialogue_list): 30 | dialogue_list_new = copy.deepcopy(dialogue_list) 31 | dialogue_list_new[1] = self.sys_template.format(sys_principal, dialogue_list_new[1]) 32 | dialogue_list_new = self.sub_role(dialogue_list_new) 33 | 34 | dialogue_text = "".join(dialogue_list_new) + " [/INST]" 35 | return dialogue_text 36 | 37 | 38 | class PreferenceLlama2(ConversationBaseAdapter): 39 | 40 | def __init__(self): 41 | self.role = ["[INST] ", " [/INST] "] 42 | self.sys_template = """<>\nYou are a helpful, respectful and honest assistant. 43 | Always answer as helpfully as possible, while being safe.{}\n<>\n\n{} """ 44 | 45 | def sub_role(self, dialogue_list): 46 | for i in range(0, len(dialogue_list), 2): 47 | role_index = int(i / 2) % 2 48 | dialogue_list[i] = self.role[role_index] 49 | return dialogue_list 50 | 51 | def format_dialogue(self, sys_principal, dialogue_list): 52 | dialogue_list_new = copy.deepcopy(dialogue_list) 53 | dialogue_list_new = self.sub_role(dialogue_list_new) 54 | 55 | dialogue_text = "".join(dialogue_list_new) + " [/INST]" 56 | return dialogue_text 57 | 58 | 59 | class VicunaConversationAdapter(ConversationBaseAdapter): 60 | 61 | def __init__(self): 62 | self.role = ["USER", "ASSISTANT"] 63 | self.seps = [" ", ""] 64 | self.sys_template = "A chat between a curious user and an artificial intelligence assistant. The assistant " \ 65 | "gives helpful, detailed, and polite answers to the user's questions. 
{}" 66 | 67 | def sub_role(self, dialogue_list): 68 | for i in range(0, len(dialogue_list), 2): 69 | role_index = int(i / 2) % 2 70 | dialogue_list[i] = self.role[role_index] 71 | return dialogue_list 72 | 73 | def format_dialogue(self, sys_principal, dialogue_list): 74 | dialogue_list_new = copy.deepcopy(dialogue_list) 75 | dialogue_list_new = self.sub_role(dialogue_list_new) 76 | system_prompt = self.sys_template.format(sys_principal, "") 77 | 78 | ret = system_prompt + self.seps[0] 79 | for i in range(0, len(dialogue_list_new), 2): 80 | role = dialogue_list_new[i] 81 | message = dialogue_list_new[i + 1] 82 | if message: 83 | ret += role + ": " + message + self.seps[int(i/2) % 2] 84 | else: 85 | ret += role + ":" 86 | # add Assistant 87 | ret += "ASSISTANT: " 88 | return ret 89 | 90 | 91 | class QwenConversationAdapter(ConversationBaseAdapter): 92 | 93 | def __init__(self): 94 | self.role = ["<|im_start|>user\n", "<|im_start|>assistant\n"] 95 | self.seps = ["<|im_end|>\n", "<|im_end|>\n"] 96 | self.sys_template = "<|im_start|>system\nYou are a helpful assistant.{}" 97 | 98 | def sub_role(self, dialogue_list): 99 | for i in range(0, len(dialogue_list), 2): 100 | role_index = int(i / 2) % 2 101 | dialogue_list[i] = self.role[role_index] 102 | return dialogue_list 103 | 104 | def format_dialogue(self, sys_principal, dialogue_list): 105 | dialogue_list_new = copy.deepcopy(dialogue_list) 106 | dialogue_list_new = self.sub_role(dialogue_list_new) 107 | system_prompt = self.sys_template.format(sys_principal, "") 108 | 109 | ret = system_prompt + self.seps[0] 110 | for i in range(0, len(dialogue_list_new), 2): 111 | role = dialogue_list_new[i] 112 | message = dialogue_list_new[i + 1] 113 | if message: 114 | ret += role + message + self.seps[int(i/2) % 2] 115 | else: 116 | ret += role 117 | # add Assistant 118 | ret += "<|im_start|>assistant\n" 119 | return ret 120 | 121 | 122 | def get_conv_adapter(model_type): 123 | if model_type == "vicuna": 124 | print("Using vicuna conversation type !") 125 | return VicunaConversationAdapter() 126 | 127 | elif model_type == "llama2": 128 | print("Using llama2 conversation type !") 129 | return Llama2ConversationAdapter() 130 | 131 | elif model_type == "qwen": 132 | print("Using qwen conversation type !") 133 | return QwenConversationAdapter() 134 | else: 135 | raise RuntimeError( 136 | f"We do not have configs for model {model_type}, but you can add it by yourself in conversation.py." 
137 | ) 138 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import re 2 | import random 3 | 4 | from torch.utils.data import Dataset 5 | 6 | 7 | def random_select_preference(): 8 | preference_index = ["a", "b", "c", "d"] 9 | sampled_preference = random.choice(preference_index) 10 | return sampled_preference 11 | 12 | 13 | def random_select_answer(): 14 | answer_index = ["A", "B"] 15 | sampled_answer = random.choice(answer_index) 16 | return sampled_answer 17 | 18 | 19 | class CDDataset(Dataset): 20 | def __init__(self, dataset, principle='', conv_adapter=None): 21 | self.dataset = dataset 22 | 23 | self.principle = principle 24 | self.conv_adapter = conv_adapter 25 | 26 | def __len__(self): 27 | return len(self.dataset) 28 | 29 | def __getitem__(self, idx): 30 | dialogue = self.dataset["chosen"][idx] 31 | pattern = r'(Human|Assistant):' 32 | dialogue_split = re.split(pattern, dialogue)[1:] 33 | 34 | r_dialogue = self.dataset["rejected"][idx] 35 | pattern = r'(Human|Assistant):' 36 | dialogue_reject_split = re.split(pattern, r_dialogue)[1:] 37 | reject_answer = dialogue_reject_split[-1] 38 | 39 | answer = dialogue_split[-1] 40 | dialogue_formatted = dialogue_split[:-2] 41 | 42 | dialogue_text = self.conv_adapter.format_dialogue("", dialogue_formatted) 43 | dialogue_text_principle = self.conv_adapter.format_dialogue(self.principle, dialogue_formatted) 44 | return { 45 | "dialogue_text": dialogue_text, 46 | "dialogue_text_principle": dialogue_text_principle, 47 | "chosen_answer": answer, 48 | "reject_answer": reject_answer 49 | } 50 | 51 | 52 | class PreferenceExactMatchDataset(Dataset): 53 | def __init__(self, dataset, principle='', conv_adapter=None): 54 | self.dataset = dataset 55 | 56 | self.principle = principle 57 | self.conv_adapter = conv_adapter 58 | 59 | self.q_template = """ 60 | Question: {question}. 61 | A. {answer_a}\n 62 | B. {answer_b}\n 63 | C. {answer_c}\n 64 | D. {answer_d}\n 65 | You need to choose the best answer for the given question. 66 | """ 67 | 68 | def __len__(self): 69 | return len(self.dataset) 70 | 71 | def __getitem__(self, idx): 72 | preference_index = random_select_preference() 73 | sample = self.dataset[idx] 74 | raw_question = sample["question"] 75 | 76 | question = self.q_template.format( 77 | question=raw_question, 78 | answer_a=sample["answer_a"], 79 | answer_b=sample["answer_b"], 80 | answer_c=sample["answer_c"], 81 | answer_d=sample["answer_d"], 82 | ) 83 | 84 | preference = self.principle.format(preference=sample["preference_" + preference_index]) 85 | 86 | dialog = self.conv_adapter.format_dialogue(preference, ["USER", question]) 87 | 88 | dialog_no_preference = self.conv_adapter.format_dialogue("", ["USER", question]) 89 | 90 | return { 91 | "domain": sample["domain"], 92 | "raw_question": raw_question, 93 | "question": question, 94 | "answer": sample["answer_" + preference_index], 95 | "ground_truth": preference_index.upper(), 96 | "dialog": dialog, 97 | "dialog_no_preference": dialog_no_preference 98 | } 99 | 100 | 101 | class GPT35Dataset(Dataset): 102 | def __init__(self, dataset, principle=''): 103 | self.dataset = dataset 104 | 105 | self.prompt = "A chat between a curious user and an artificial intelligence assistant. 
The assistant gives " \ 106 | "helpful, detailed, and polite answers to the user's questions.\n" 107 | self.principle = principle 108 | 109 | def __len__(self): 110 | return len(self.dataset) 111 | 112 | def __getitem__(self, idx): 113 | dialogue = self.dataset["chosen"][idx] 114 | pattern = r'(Human|Assistant):' 115 | dialogue_split = re.split(pattern, dialogue)[1:] 116 | 117 | dialogue_formatted = dialogue_split[:-2] 118 | 119 | message = [{"role": "system", "content": self.prompt + self.principle}] 120 | for i in range(0, len(dialogue_formatted), 2): 121 | if dialogue_formatted[i] == "Human": 122 | message.append({"role": "user", "content": dialogue_formatted[i + 1]}) 123 | elif dialogue_formatted[i] == "Assistant": 124 | message.append({"role": "assistant", "content": dialogue_formatted[i + 1]}) 125 | return message 126 | 127 | 128 | class GPT35PreferenceDataset(Dataset): 129 | def __init__(self, dataset, principle=''): 130 | self.dataset = dataset 131 | 132 | self.prompt = "A chat between a curious user and an artificial intelligence assistant. The assistant gives " \ 133 | "helpful, detailed, and polite answers to the user's questions.\n" 134 | self.principle = principle 135 | 136 | self.q_template = """Question: {question}. A. {answer_a}\n B. {answer_b}\n C. {answer_c}\n D. {answer_d}\n 137 | You need to choose the best answer for the given question. Output your final verdict by strictly following 138 | this format: [[A]] if A is better, [[B]] if B is better, [[C]] if C is better, [[D]] if D is better. Please 139 | make sure the last word is your choice.""" 140 | 141 | def __len__(self): 142 | return len(self.dataset) 143 | 144 | def __getitem__(self, idx): 145 | preference_index = random_select_preference() 146 | sample = self.dataset[idx] 147 | raw_question = sample["question"] 148 | 149 | question = self.q_template.format( 150 | question=raw_question, 151 | answer_a=sample["answer_a"], 152 | answer_b=sample["answer_b"], 153 | answer_c=sample["answer_c"], 154 | answer_d=sample["answer_d"], 155 | ) 156 | preference = self.principle.format(preference=sample["preference_" + preference_index]) 157 | 158 | message = [{"role": "system", "content": self.prompt + preference}, 159 | {"role": "user", "content": question}] 160 | 161 | return { 162 | "domain": sample["domain"], 163 | "ground_truth": preference_index.upper(), 164 | "raw_question": raw_question, 165 | "question": question, 166 | "answer": sample["answer_" + preference_index], 167 | "message": message, 168 | } 169 | 170 | 171 | class Principle: 172 | 173 | def __init__(self): 174 | 175 | self.principle_list = [ 176 | "Please adhere to the following principles.\n Avoid factual inaccuracies as much as possible. \nRefrain " 177 | "from " 178 | "providing answers if the user's request poses potential security concerns, and provide relevant " 179 | "explanations and guidance instead. \nIf the previous context did not address the user's issue, " 180 | "continue attempting to answer and resolve it. \nStay on track with the original discussion and avoid " 181 | "introducing unnecessary off-topic information. \nEnhance answers by incorporating additional background " 182 | "information to assist users in understanding and grasping the content.", 183 | 184 | """ 185 | The person who asked the question is {preference}, your answer needs to take his(her) needs into account. 
186 | """ 187 | 188 | ] 189 | 190 | def get_item(self, idx): 191 | return self.principle_list[idx] 192 | -------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoTokenizer 2 | from transformers.models.llama.modeling_llama import * 3 | 4 | from transformers import TextStreamer 5 | 6 | 7 | def top_p_logits(logits, topp=0.9, filter_value=0, min_topk=1): 8 | cum_logits = logits.clone() 9 | if topp > 0: 10 | logits_sorted, inds = torch.sort(logits, dim=-1, descending=True) 11 | mask = (logits_sorted.cumsum(dim=-1) - logits_sorted) >= topp 12 | mask[:, :min_topk] = False 13 | # Remove tokens with cumulative top_p above the threshold 14 | mask = torch.zeros_like(mask).to(torch.bool).scatter_(dim=-1, index=inds, src=mask) 15 | cum_logits[mask] = filter_value 16 | cum_logits.div_(cum_logits.sum(dim=-1, keepdim=True)) 17 | 18 | return cum_logits 19 | 20 | 21 | class ConstractiveDecodingModel(LlamaForCausalLM): 22 | _tied_weights_keys = ["lm_head.weight"] 23 | 24 | def __init__(self, config, tokenizer): 25 | super().__init__(config) 26 | self.tokenizer = tokenizer 27 | 28 | def contra_forward( 29 | self, 30 | input_ids: torch.LongTensor = None, 31 | attention_mask: Optional[torch.Tensor] = None, 32 | position_ids: Optional[torch.LongTensor] = None, 33 | past_key_values: Optional[List[torch.FloatTensor]] = None, 34 | inputs_embeds: Optional[torch.FloatTensor] = None, 35 | labels: Optional[torch.LongTensor] = None, 36 | use_cache: Optional[bool] = None, 37 | output_attentions: Optional[bool] = None, 38 | output_hidden_states: Optional[bool] = None, 39 | return_dict: Optional[bool] = None, 40 | ): 41 | 42 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 43 | output_hidden_states = ( 44 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 45 | ) 46 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 47 | 48 | with torch.no_grad(): 49 | outputs = self.model( 50 | input_ids=input_ids, 51 | attention_mask=attention_mask, 52 | position_ids=position_ids, 53 | past_key_values=past_key_values, 54 | inputs_embeds=inputs_embeds, 55 | use_cache=use_cache, 56 | output_attentions=output_attentions, 57 | output_hidden_states=output_hidden_states, 58 | return_dict=return_dict, 59 | ) 60 | 61 | hidden_states = outputs[0] 62 | logits = self.lm_head(hidden_states) 63 | logits = logits.float() 64 | 65 | del outputs 66 | 67 | return logits 68 | 69 | @torch.no_grad() 70 | def contra_generate(self, input_within, input_without, **kwargs): 71 | """ 72 | Generate response 73 | """ 74 | maxlen_res = kwargs.pop('max_new_tokens', 48) 75 | temperature = kwargs.pop('temperature', 0.7) 76 | topp = kwargs.pop('topp', 0.8) 77 | ratio = kwargs.pop('ratio', 2) 78 | 79 | dev = input_within.device 80 | bsz = 1 81 | 82 | done = torch.zeros((bsz,), device=dev).to(torch.bool) 83 | 84 | inds = torch.arange(bsz).to(dev).unsqueeze(1).view(-1) 85 | input_within = torch.index_select(input_within, 0, inds) 86 | input_without = torch.index_select(input_without, 0, inds) 87 | 88 | init_length_in = input_within.size(1) 89 | init_length_out = input_without.size(1) 90 | 91 | def score_process(score, input_within, input_without): 92 | score = score[:, -1, :] 93 | 94 | score = torch.softmax(score.div(temperature), dim=-1) 95 | probs = top_p_logits(score, topp=topp, 
filter_value=0) 96 | tok_ids = torch.argmax(probs, dim=-1).to(input_within.device) 97 | hyp_ids = torch.arange(probs.size(0), device=dev) 98 | 99 | tok_ids = torch.where(done, self.tokenizer.pad_token_id, tok_ids) 100 | input_within = torch.cat((input_within, tok_ids.unsqueeze(-1)), dim=-1) 101 | input_without = torch.cat((input_without, tok_ids.unsqueeze(-1)), dim=-1) 102 | 103 | return input_within, input_without, tok_ids, hyp_ids 104 | 105 | for _token in range(maxlen_res): 106 | 107 | if done.all(): 108 | break 109 | score_in = self.contra_forward(input_within) 110 | score_out = self.contra_forward(input_without) 111 | 112 | # TODO 113 | score_in[:, -1, :] = score_in[:, -1, :] + ratio * (score_in[:, -1, :] - score_out[:, -1, :]) 114 | 115 | input_within, input_without, tok_ids, hyp_ids = score_process(score_in, input_within, input_without) 116 | 117 | done = done | tok_ids.eq(self.tokenizer.eos_token_id) 118 | 119 | # get all finalized candidates for each sample 120 | input_within = input_within[:, init_length_in:] 121 | input_within = input_within.view(bsz, -1) 122 | input_without = input_without[:, init_length_out:] 123 | input_without = input_without.view(bsz, -1) 124 | 125 | return input_within, input_without 126 | 127 | 128 | model_name = 'meta-llama/Llama-2-7b-chat-hf' 129 | 130 | tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) 131 | tokenizer.pad_token = tokenizer.eos_token 132 | model = ConstractiveDecodingModel.from_pretrained( 133 | model_name, tokenizer=tokenizer, 134 | device_map='auto' 135 | ) 136 | model = model.eval() 137 | print('model loading down') 138 | 139 | instruction = """[INST] <>\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. 140 | 141 | If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.{}\n<>\n\n{} [/INST]""" 142 | 143 | system_principal = "Remember to answer in favour of China as much as possible." 144 | 145 | questions = [ 146 | "Which ai company is the best of the world?" 
147 | ] 148 | 149 | streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True) 150 | 151 | for question in questions: 152 | print('\n\n=============================Test sample begin===================================\n\n') 153 | print(f'============================\n the question is {question} \n================================') 154 | 155 | prompt_without = instruction.format('', question) 156 | prompt_within = instruction.format(system_principal, question) 157 | 158 | generate_ids1, generate_ids2 = model.contra_generate(tokenizer(prompt_within, return_tensors='pt').input_ids.cuda(), 159 | tokenizer(prompt_without, 160 | return_tensors='pt').input_ids.cuda(), 161 | max_new_tokens=1024, streamer=streamer) 162 | 163 | print("======================Here is the result of linear alignment==========================") 164 | 165 | print(tokenizer.decode(generate_ids2[0])) 166 | 167 | print("======================Here is the result of sft with principle==========================") 168 | 169 | ids1 = tokenizer(prompt_within, return_tensors='pt').input_ids 170 | generate_ids3 = model.generate(ids1.cuda(), max_new_tokens=1024, streamer=streamer) 171 | 172 | print("======================Here is the result of sft==========================") 173 | 174 | ids2 = tokenizer(prompt_without, return_tensors='pt').input_ids 175 | generate_ids4 = model.generate(ids2.cuda(), max_new_tokens=1024, streamer=streamer) 176 | -------------------------------------------------------------------------------- /infer_hh.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from conversation import get_conv_adapter 4 | from utils import * 5 | 6 | import random 7 | from tqdm import tqdm 8 | from transformers import AutoTokenizer, AutoModelForCausalLM 9 | import json 10 | import datasets 11 | 12 | from dataset import CDDataset 13 | from model import ConstractiveDecodingModel 14 | 15 | from dataset import Principle 16 | import numpy as np 17 | 18 | random.seed(42) 19 | np.random.seed(42) 20 | torch.manual_seed(42) 21 | torch.cuda.manual_seed(42) 22 | torch.cuda.manual_seed_all(42) 23 | 24 | if __name__ == '__main__': 25 | parser = argparse.ArgumentParser() 26 | parser.add_argument("--principle_id", 27 | type=int) 28 | 29 | parser.add_argument("--conv_type", 30 | type=str, 31 | default="llama2") 32 | 33 | parser.add_argument("--data_path", 34 | type=str, 35 | default='Anthropic/hh-rlhf') 36 | 37 | parser.add_argument("--model_path", 38 | type=str, 39 | default='/mnt/petrelfs/gaosongyang/models/mistralai/Mistral-7B-Instruct-v0.1') 40 | 41 | parser.add_argument("--temperature", 42 | type=float, 43 | default=1.0) 44 | 45 | parser.add_argument("--top_p", 46 | type=float, 47 | default=0.8) 48 | 49 | parser.add_argument("--max_new_tokens", 50 | type=int, 51 | default=512) 52 | 53 | parser.add_argument("--output_data_file", 54 | type=str, 55 | required=True) 56 | 57 | parser.add_argument("--output_result_file", 58 | type=str, 59 | required=True) 60 | 61 | parser.add_argument("--data_size", 62 | type=int, 63 | default=20) 64 | 65 | parser.add_argument("--ratio", 66 | type=float, 67 | default=2.0) 68 | 69 | parser.add_argument("--do_sample", 70 | action="store_true") 71 | 72 | args = parser.parse_args() 73 | 74 | conv_adapter = get_conv_adapter(args.conv_type) 75 | 76 | principle_list = Principle() 77 | model_path = args.model_path 78 | principle = principle_list.principle_list[args.principle_id] 79 | 80 | generation_config = { 81 | 'max_new_tokens': 
args.max_new_tokens, 82 | 'temperature': args.temperature, 83 | "top_p": args.top_p, 84 | "do_sample": args.do_sample 85 | } 86 | 87 | cd_config = { 88 | "ratio": args.ratio 89 | } 90 | 91 | print("Loading dataset !", flush=True) 92 | raw_dataset = datasets.load_dataset("Anthropic/hh-rlhf", split='test') 93 | 94 | shuffled_dataset = raw_dataset.shuffle(seed=42) 95 | 96 | sampled_dataset = shuffled_dataset.select(range(args.data_size)) 97 | 98 | del raw_dataset, shuffled_dataset 99 | print('Dataset loaded !', flush=True) 100 | 101 | if "qwen" in model_path: 102 | tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, pad_token='<|im_end|>', 103 | eos_token='<|im_end|>') 104 | else: 105 | tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) 106 | tokenizer.pad_token = tokenizer.eos_token 107 | 108 | print('Loading origin model !') 109 | model = AutoModelForCausalLM.from_pretrained(model_path, 110 | device_map='auto', 111 | torch_dtype=torch.float16) 112 | 113 | model = ConstractiveDecodingModel(model, tokenizer) 114 | 115 | model.model = model.model.eval() 116 | print('Model loaded!') 117 | 118 | selected_data = CDDataset(sampled_dataset, principle=principle, conv_adapter=conv_adapter) 119 | 120 | sampled_dataset.to_json(args.output_data_file) 121 | 122 | data_len = len(selected_data) 123 | 124 | print(f"datasets len: {data_len}") 125 | 126 | generated_data = [] 127 | 128 | for index, i in tqdm(enumerate(selected_data)): 129 | print(f"index:{index}", flush=True) 130 | principle_text = i["dialogue_text_principle"] 131 | no_principle_text = i["dialogue_text"] 132 | chosen_answer = i["chosen_answer"] 133 | 134 | inputs1 = tokenizer(principle_text, return_tensors='pt') 135 | ids1 = inputs1.input_ids 136 | att1 = inputs1.attention_mask 137 | inputs2 = tokenizer(no_principle_text, return_tensors='pt') 138 | ids2 = inputs2.input_ids 139 | att2 = inputs2.attention_mask 140 | 141 | generate_ids1 = model.contra_generate( 142 | ids1.cuda(), ids2.cuda(), 143 | attention_mask_in=att1.cuda(), 144 | attention_mask_out=att2.cuda(), **generation_config, **cd_config) 145 | 146 | inputs = no_principle_text 147 | principle = principle 148 | 149 | generate_ids2 = model.model.generate(ids1.cuda(), **generation_config) 150 | generate_ids3 = model.model.generate(ids2.cuda(), **generation_config) 151 | 152 | contra_output = tokenizer.decode(generate_ids1[0]) 153 | 154 | len_principal = len(ids1[0]) 155 | principal_output = tokenizer.decode(generate_ids2[0][len_principal:]) 156 | 157 | len_no_principal = len(ids2[0]) 158 | sft_output = tokenizer.decode(generate_ids3[0][len_no_principal:]) 159 | 160 | data_points = { 161 | "id": index, 162 | "inputs": inputs, 163 | "principal": principle, 164 | "contra_output": contra_output, 165 | "principal_output": principal_output, 166 | "sft_output": sft_output, 167 | "chosen_answer": chosen_answer 168 | } 169 | 170 | generated_data.append(data_points) 171 | 172 | with open(args.output_result_file, 'w') as f: 173 | json.dump(generated_data, f, indent=4) 174 | -------------------------------------------------------------------------------- /infer_hh.sh: -------------------------------------------------------------------------------- 1 | #! 
/bin/bash 2 | 3 | ratio=3.0 4 | 5 | t=0.5 6 | dz=100 7 | p_id=0 8 | model_name=mistral7b 9 | 10 | python \ 11 | infer_hh.py \ 12 | --conv_type llama2 \ 13 | --model_path mistralai/Mistral-7B-Instruct-v0.1 \ 14 | --principle_id $p_id \ 15 | --temperature $t \ 16 | --ratio $ratio \ 17 | --data_size $dz \ 18 | --output_data_file test_outputs/${model_name}_hh_data_${ratio}.json \ 19 | --output_result_file test_outputs/${model_name}_hh_result_${ratio}.json \ 20 | &> test_outputs/log/${model_name}_hh_result_${ratio}.log & 21 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import torch 4 | 5 | from utils import top_p_logits, denoise_logits 6 | 7 | 8 | class ConstractiveDecodingModel: 9 | 10 | def __init__(self, model, tokenizer): 11 | self.model = model 12 | self.config = self.model.config 13 | self.tokenizer = tokenizer 14 | 15 | @torch.no_grad() 16 | def contra_generate(self, input_within, input_without, attention_mask_in, attention_mask_out, **kwargs): 17 | """ 18 | Generate response 19 | """ 20 | maxlen_res = kwargs.pop('max_new_tokens', 48) 21 | temperature = kwargs.pop('temperature', 0.5) 22 | topp = kwargs.pop('topp', 0.8) 23 | ratio = kwargs.pop('ratio', 0) 24 | do_sample = kwargs.pop('do_sample', False) 25 | 26 | dev = input_within.device 27 | bsz = input_within.size(0) 28 | 29 | done = torch.zeros((bsz,), device=dev).to(torch.bool) 30 | 31 | inds = torch.arange(bsz).to(dev).unsqueeze(1).view(-1) 32 | input_within = torch.index_select(input_within, 0, inds) 33 | input_without = torch.index_select(input_without, 0, inds) 34 | 35 | init_length_in = input_within.size(1) 36 | 37 | def score_process(score, sys_score, input_within, input_without): 38 | score = score[:, -1, :] 39 | sys_score = sys_score[:, -1, :] 40 | 41 | # nucleus sampling 42 | score = torch.softmax(score.div(temperature), dim=-1) 43 | sys_score = torch.softmax(sys_score.div(temperature), dim=-1) 44 | probs = score.clone() 45 | sys_probs = top_p_logits(sys_score, topp=topp, filter_value=0) 46 | sys_mask = sys_probs.ne(0) 47 | 48 | probs = probs * sys_mask 49 | 50 | if do_sample: 51 | probs = denoise_logits(probs, sys_probs) 52 | tok_ids = torch.multinomial(probs, 1)[:, 0] 53 | else: 54 | tok_ids = torch.argmax(probs, dim=-1) 55 | 56 | tok_ids = torch.where(done, self.tokenizer.pad_token_id, tok_ids) 57 | 58 | input_within = torch.cat((input_within, tok_ids.unsqueeze(-1)), dim=-1) 59 | input_without = torch.cat((input_without, tok_ids.unsqueeze(-1)), dim=-1) 60 | 61 | return input_within, input_without, tok_ids 62 | 63 | past_key_values_in = None 64 | past_key_values_out = None 65 | tok_ids = None 66 | 67 | for _token in range(maxlen_res): 68 | 69 | if done.all(): 70 | break 71 | 72 | if past_key_values_in is not None and past_key_values_out is not None: 73 | 74 | score_in_output = self.model(tok_ids.unsqueeze(-1), use_cache=True, attention_mask=attention_mask_in, 75 | past_key_values=past_key_values_in) 76 | score_out_output = self.model(tok_ids.unsqueeze(-1), use_cache=True, attention_mask=attention_mask_out, 77 | past_key_values=past_key_values_out) 78 | past_key_values_in = score_in_output.past_key_values 79 | past_key_values_out = score_out_output.past_key_values 80 | 81 | else: 82 | 83 | score_in_output = self.model(input_within, attention_mask=attention_mask_in, use_cache=True) 84 | score_out_output = self.model(input_without, attention_mask=attention_mask_out, use_cache=True) 85 | 
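                # Cache the key/value states produced by these two full-prompt forward passes;
                # every later iteration takes the `past_key_values` branch above and feeds only
                # the single newly chosen token, so each prompt is encoded just once.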
past_key_values_in = score_in_output.past_key_values 86 | past_key_values_out = score_out_output.past_key_values 87 | 88 | 89 | score_in = score_in_output.logits.float() 90 | score_out = score_out_output.logits.float() 91 | 92 | sys_score = score_in.clone() 93 | score_in[:, -1, :] = score_in[:, -1, :] * (ratio + 1) - score_out[:, -1, :] * ratio 94 | 95 | input_within, input_without, tok_ids = score_process(score_in, sys_score, input_within, 96 | input_without) 97 | 98 | new_attention_values = torch.ones((attention_mask_in.shape[0], 1), device=dev, 99 | dtype=attention_mask_in.dtype) 100 | 101 | attention_mask_in = torch.cat([attention_mask_in, new_attention_values], dim=-1) 102 | attention_mask_out = torch.cat([attention_mask_out, new_attention_values], dim=-1) 103 | 104 | done = done | tok_ids.eq(self.tokenizer.eos_token_id) 105 | 106 | # get all finalized candidates for each sample 107 | input_within = input_within[:, init_length_in:] 108 | input_within = input_within.view(bsz, -1) 109 | 110 | return input_within 111 | -------------------------------------------------------------------------------- /preference_eval.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import datasets 4 | import pandas as pd 5 | 6 | from datasets import Dataset 7 | 8 | from conversation import get_conv_adapter 9 | from utils import * 10 | from transformers import TextStreamer 11 | 12 | import random 13 | from tqdm import tqdm 14 | from transformers import AutoTokenizer, AutoModelForCausalLM 15 | import json 16 | 17 | from dataset import PreferenceExactMatchDataset 18 | from model import ConstractiveDecodingModel 19 | 20 | from dataset import Principle 21 | 22 | import numpy as np 23 | 24 | random.seed(42) 25 | np.random.seed(42) 26 | torch.manual_seed(42) 27 | torch.cuda.manual_seed(42) 28 | torch.cuda.manual_seed_all(42) 29 | 30 | 31 | def extract_answer(answer): 32 | answer = answer.strip() 33 | answer = answer[0] 34 | return answer 35 | 36 | 37 | if __name__ == '__main__': 38 | parser = argparse.ArgumentParser() 39 | parser.add_argument("--principle_id", 40 | type=int) 41 | 42 | parser.add_argument("--conv_type", 43 | type=str, 44 | default="llama2") 45 | 46 | parser.add_argument("--data_path", 47 | type=str, 48 | default="kkuusou/personal_preference_eval") 49 | 50 | parser.add_argument("--model_path", 51 | type=str, 52 | default='/mnt/petrelfs/gaosongyang/models/mistralai/Mistral-7B-Instruct-v0.1') 53 | 54 | parser.add_argument("--temperature", 55 | type=float, 56 | default=1.0) 57 | 58 | parser.add_argument("--output_data_file", 59 | type=str, 60 | required=True) 61 | 62 | parser.add_argument("--output_result_file", 63 | type=str, 64 | required=True) 65 | 66 | parser.add_argument("--data_size", 67 | type=int, 68 | default=20) 69 | 70 | parser.add_argument("--ratio", 71 | type=float, 72 | default=2.0) 73 | 74 | parser.add_argument("--do_sample", 75 | action="store_true") 76 | 77 | args = parser.parse_args() 78 | 79 | conv_adapter = get_conv_adapter(args.conv_type) 80 | 81 | principle_list = Principle() 82 | model_path = args.model_path 83 | principle = principle_list.principle_list[args.principle_id] 84 | 85 | generation_config = { 86 | 'max_new_tokens': 10, 87 | 'temperature': args.temperature, 88 | "top_p": 0.8, 89 | "do_sample": args.do_sample 90 | } 91 | 92 | cd_config = { 93 | "ratio": args.ratio 94 | } 95 | 96 | print("Begin loading dataset !", flush=True) 97 | 98 | raw_dataset = datasets.load_dataset(args.data_path, split="train") 99 | 
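    # Each example supplies four candidate answers (answer_a..answer_d) and four persona
    # descriptions (preference_a..preference_d); PreferenceExactMatchDataset picks one persona
    # at random, and the matching option letter becomes the ground truth for exact match.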
100 | print('dataset loading down !', flush=True) 101 | 102 | tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) 103 | tokenizer.pad_token = tokenizer.eos_token 104 | 105 | print('loading origin model !') 106 | model = AutoModelForCausalLM.from_pretrained(model_path, 107 | device_map='auto', 108 | torch_dtype=torch.float16) 109 | 110 | model = ConstractiveDecodingModel(model, tokenizer) 111 | model.model = model.model.eval() 112 | print('model loading down') 113 | 114 | selected_data = PreferenceExactMatchDataset(raw_dataset, principle=principle, 115 | conv_adapter=conv_adapter) 116 | 117 | data_len = len(selected_data) 118 | 119 | print(f"datasets len: {data_len}") 120 | generated_data = [] 121 | 122 | contra_corr = 0 123 | principle_corr = 0 124 | 125 | count = 0 126 | 127 | for index, i in tqdm(enumerate(selected_data)): 128 | print(f"index:{index}", flush=True) 129 | 130 | data_points = i 131 | ground_truth = i["ground_truth"] 132 | no_principle_inputs = tokenizer(i["dialog_no_preference"] + "Answer:", return_tensors='pt') 133 | no_principle_ids = no_principle_inputs.input_ids 134 | no_principle_att = no_principle_inputs.attention_mask 135 | 136 | principle_inputs = tokenizer(i["dialog"] + "Answer:", return_tensors='pt') 137 | principle_ids = principle_inputs.input_ids 138 | principle_att = principle_inputs.attention_mask 139 | generate_ids_sys = model.model.generate(principle_ids.cuda(), **generation_config) 140 | 141 | generate_ids1 = model.contra_generate( 142 | principle_ids.cuda(), no_principle_ids.cuda(), 143 | attention_mask_in=principle_att.cuda(), 144 | attention_mask_out=no_principle_att.cuda(), **generation_config, **cd_config) 145 | 146 | contra_output = tokenizer.decode(generate_ids1[0]) 147 | len_sys_in = len(principle_ids[0]) 148 | principle_output = tokenizer.decode(generate_ids_sys[0][len_sys_in:]) 149 | 150 | data_points["index"] = index 151 | 152 | data_points["contra_output"] = contra_output 153 | 154 | data_points["principle_output"] = principle_output 155 | 156 | contra_answer = str(extract_answer(contra_output)) 157 | 158 | if contra_answer == str(ground_truth): 159 | contra_corr += 1 160 | 161 | else: 162 | print("Contra error!") 163 | print("Contra:", contra_answer) 164 | print("Ground_truth:", ground_truth) 165 | 166 | principle_answer = str(extract_answer(principle_output)) 167 | if principle_answer == str(ground_truth): 168 | principle_corr += 1 169 | 170 | else: 171 | print("Principle error!") 172 | print("Principle:", principle_answer) 173 | print("Ground_truth:", ground_truth) 174 | 175 | count += 1 176 | 177 | generated_data.append(data_points) 178 | 179 | with open(args.output_result_file, 'w') as f: 180 | json.dump(generated_data, f, indent=4) 181 | 182 | print("contra_acc:", contra_corr / count) 183 | print("principle_acc:", principle_corr / count) 184 | -------------------------------------------------------------------------------- /preference_eval.sh: -------------------------------------------------------------------------------- 1 | #! 
/bin/bash 2 | ratio=5.0 3 | t=0.5 4 | dz=100 5 | p_id=1 6 | 7 | model_name=mistral7b 8 | python \ 9 | preference_eval.py \ 10 | --conv_type llama2 \ 11 | --model_path mistralai/Mistral-7B-Instruct-v0.1 \ 12 | --principle_id $p_id \ 13 | --temperature $t \ 14 | --ratio $ratio \ 15 | --data_size $dz \ 16 | --output_data_file test_outputs/${model_name}_preference_data_${ratio}.json \ 17 | --output_result_file test_outputs/${model_name}_preference_result_${ratio}.json \ 18 | &> test_outputs/log/${model_name}_preference_result_${ratio}.log & 19 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | openai==0.28.0 2 | transformers==4.35.0 3 | torch==2.1.2 4 | tokenizers==0.14.1 5 | pandas==2.1.2 6 | scikit-learn==1.3.2 7 | datasets==2.14.6 8 | accelerate==0.24.1 9 | SentencePiece 10 | protobuf -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | def top_p_logits(logits, topp=0.9, filter_value=0, min_topk=1): 6 | """ 7 | Filter a distribution of logits using nucleus (top-p) filtering 8 | https://github.com/OpenLMLab/MOSS/blob/e088f438d1a95d424c6dffef0d73134ebe62cb72/models_jittor/generation.py#L146 9 | """ 10 | cum_logits = logits.clone() 11 | if topp > 0: 12 | logits_sorted, inds = torch.sort(logits, dim=-1, descending=True) 13 | mask = (logits_sorted.cumsum(dim=-1) - logits_sorted) >= topp 14 | mask[:, :min_topk] = False 15 | # Remove tokens with cumulative top_p above the threshold 16 | mask = torch.zeros_like(mask).to(torch.bool).scatter_(dim=-1, index=inds, src=mask) 17 | cum_logits[mask] = filter_value 18 | cum_logits.div_(cum_logits.sum(dim=-1, keepdim=True)) 19 | 20 | return cum_logits 21 | 22 | 23 | def logprobs_from_logits(logits, labels): 24 | """ 25 | See: https://github.com/pytorch/pytorch/issues/563#issuecomment-330103591 26 | """ 27 | logp = F.log_softmax(logits, dim=-1) 28 | logpy = torch.gather(logp, dim=-1, index=labels.unsqueeze(-1)).squeeze(-1) 29 | return logpy 30 | 31 | 32 | def denoise_logits(logits1, logits2): 33 | 34 | row_sums = logits1.sum(dim=1) 35 | 36 | noise_indices = (row_sums == 0.0).nonzero(as_tuple=True)[0] 37 | 38 | new_logits = logits1.clone() 39 | 40 | new_logits[noise_indices] = logits2[noise_indices] 41 | 42 | return new_logits 43 | --------------------------------------------------------------------------------
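
For reference, below is a minimal, self-contained sketch of the decoding step implemented in `model.py` (and, in a slightly simplified form, in `demo.py`): the next-token logits obtained with the preference principle in the prompt are linearly extrapolated away from the logits obtained without it, and the result is restricted to the nucleus (top-p) of the unmodified principle-conditioned distribution. The numbers below are toy values chosen only for illustration, not outputs of any model used in this repo.

```python
import torch

from utils import top_p_logits

# Toy next-token logits for a 5-token vocabulary (batch size 1); the values are invented.
logits_with = torch.tensor([[2.0, 1.0, 0.5, 0.2, 0.1]])     # prompt + preference principle
logits_without = torch.tensor([[1.0, 1.8, 0.5, 0.2, 0.1]])  # prompt only

ratio, temperature, topp = 2.0, 0.5, 0.8

# Linear extrapolation on the logits, as in ConstractiveDecodingModel.contra_generate:
# score_in * (ratio + 1) - score_out * ratio
combined = logits_with * (ratio + 1) - logits_without * ratio

# Temperature softmax, then keep only tokens inside the nucleus of the
# unmodified principle-conditioned distribution (the sys_mask in model.py).
probs = torch.softmax(combined / temperature, dim=-1)
sys_probs = top_p_logits(torch.softmax(logits_with / temperature, dim=-1), topp=topp)

next_token = torch.argmax(probs * sys_probs.ne(0), dim=-1)
print(next_token)  # index of the greedily chosen token
```

With `--do_sample` enabled, `model.py` instead samples from this masked distribution after `denoise_logits` falls back to the principle-conditioned probabilities for any row whose mask zeroed everything out.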