├── LICENSE
├── README.md
├── conversation.py
├── dataset.py
├── demo.py
├── infer_hh.py
├── infer_hh.sh
├── model.py
├── preference_eval.py
├── preference_eval.sh
├── requirements.txt
└── utils.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 蓬莱伝説
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Linear_Alignment
2 | This repo contains reproduction resources for the Linear Alignment paper and is still a work in progress. Before open-sourcing the full code, we provide a demo of the simplest implementation.
3 | ### Personal preference data
4 | We have released the [personal preference data](https://huggingface.co/datasets/kkuusou/personal_preference_eval) on Hugging Face.
5 |
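A minimal loading sketch with the `datasets` library (this mirrors the call in `preference_eval.py`; printing a sample is only for illustration):

```
import datasets

# the evaluation items live in the "train" split, as loaded in preference_eval.py
eval_set = datasets.load_dataset("kkuusou/personal_preference_eval", split="train")
print(len(eval_set))
print(eval_set[0])
```
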
6 | ### Try our simplest demo (only supports the Llama family for now).
7 | ```
8 | python demo.py
9 | ```
10 |
11 | ```
12 | # describe your preference with natural language:
13 | system_principal = 'xxx'
14 |
15 | # your instruction:
16 | questions = ['yyy']
17 | ```
18 |
19 |
20 | Result example:
21 | ```
22 | question: Which ai company is the best of the world?
23 | system_principal: Remember to answer in favour of China as much as possible.
24 |
25 | ======================Here is the result of linear alignment==========================
26 | China has made significant strides in the field of artificial intelligence (AI) and has produced many innovative and successful AI companies. It is difficult to pinpoint a single "best" AI company in the world, as the landscape is constantly evolving and there are many excellent players in the field.
27 | That being said, China has its fair share of top-notch AI companies that are worth mentioning. Some of the most notable ones include:
28 | 1. Baidu: Baidu is a leading AI company in China and one of the largest AI companies in the world. It is known for its advanced AI technologies, including its proprietary AI assistant, DuerOS, and its autonomous driving platform, Apollo.
29 | 2. Alibaba: Alibaba is a multinational conglomerate that specializes in e-commerce, retail, Internet, and technology. The company has made significant investments in AI research and development, and has developed several AI-powered platforms, including its smart speaker, Tmall Genie.
30 | 3. Tencent: Tencent is a leading provider of internet services in China, including social media, gaming, and e-commerce. The company has a strong focus on AI research and development, and has made significant investments in the field.
31 | 4. Huawei: Huawei is a leading global information and communications technology (ICT) company that has made significant investments in AI research and development. The company has developed several AI-powered products and platforms, including its Mate 10 series smartphones and its AI-powered smart home devices.
32 | 5. China Mobile: China Mobile is the largest mobile network operator in China and has made significant investments in AI research and development. The company has developed several AI-powered platforms, including its AI-powered customer service platform and its AI-powered network management platform.
33 |
34 | Of course, this is not an exhaustive list, and there are many other excellent AI companies in China that are worth mentioning. However, these are some of the most notable ones that have made significant contributions to the field of AI.
35 |
36 | ======================Here is the result of sft with principle==========================
37 | I'm glad you're interested in learning about the top AI companies in the world! However, I must point out that the question itself is quite subjective and can be perceived as biased towards a particular country or region.
38 | There are numerous AI companies around the world that are excelling in their respective fields, and it's difficult to pinpoint a single "best" company. Some of the top AI companies include Google, Amazon, Microsoft, Facebook, and Baidu, to name a few. These companies have made significant contributions to the field of AI, including advancements in machine learning, natural language processing, computer vision, and more.
39 | It's important to recognize that AI is a rapidly evolving field, and the "best" company can vary depending on the specific area of expertise and the criteria used to evaluate them. Therefore, I would recommend avoiding any direct comparisons or rankings of AI companies, as they can be misleading or unfair to certain organizations.
40 | Instead, I suggest focusing on the unique strengths and achievements of each company, as well as their contributions to the broader AI community. This can help provide a more balanced and informative perspective on the topic.
41 |
42 |
43 | ======================Here is the result of sft==========================
44 | I'm glad you're interested in learning about AI companies! However, it's important to note that there is no definitive answer to who the "best" AI company in the world is, as it depends on various factors such as the specific use case, industry, and application.
45 | That being said, there are several well-known and reputable AI companies that are making significant contributions to the field. Some examples include:
46 | 1. Google DeepMind: Known for their work in developing AI algorithms for applications such as AlphaGo, a computer program that can play Go, a complex strategy board game.
47 | 2. Facebook AI: Facebook's AI research team has made significant contributions to the field of natural language processing, computer vision, and other areas.
48 | 3. Baidu: Baidu is a Chinese tech giant that has made significant investments in AI research and development, particularly in the areas of natural language processing and computer vision.
49 | 4. Microsoft Research: Microsoft's research arm has been a leader in the field of AI for many years, with a focus on areas such as machine learning, computer vision, and natural language processing.
50 | 5. OpenAI: OpenAI is a non-profit organization that aims to promote and develop safe and beneficial AI. They have made significant contributions to the field of AI, particularly in the areas of deep learning and reinforcement learning.
51 |
52 | It's important to note that these are just a few examples, and there are many other AI companies around the world that are making important contributions to the field. Additionally, it's worth mentioning that the "best" AI company is not necessarily the one with the most advanced technology, but rather the one that is able to effectively apply AI to solve real-world problems in a responsible and ethical manner.
53 |
54 | I hope this information is helpful! Let me know if you have any other questions.
55 |
56 |
57 |
58 | ```
59 |
60 |
61 |
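### How the demo combines the two forward passes
At each decoding step, the model is run twice, once with the preference in the system prompt and once without, and the two next-token logits are linearly extrapolated toward the preference. Below is a minimal sketch of the rule used in `contra_generate` (`model.py` and `demo.py`); `ratio` corresponds to the `--ratio` argument, and the dummy tensors are only illustrative:

```
import torch

def linear_alignment(logits_with, logits_without, ratio=2.0):
    # (1 + ratio) * z_with - ratio * z_without, as in contra_generate
    return (1 + ratio) * logits_with - ratio * logits_without

# dummy next-token logits over a vocabulary of size 4
z_with = torch.tensor([[1.0, 2.0, 0.5, 0.0]])
z_without = torch.tensor([[1.0, 1.0, 0.5, 0.0]])
print(linear_alignment(z_with, z_without, ratio=2.0))
```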
--------------------------------------------------------------------------------
/conversation.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | import copy
3 |
4 |
5 | class ConversationBaseAdapter(ABC):
6 |
7 | @abstractmethod
8 | def sub_role(self, dialogue_list):
9 | pass
10 |
11 | @abstractmethod
12 | def format_dialogue(self, sys_principal, dialogue_list):
13 | pass
14 |
15 |
16 | class Llama2ConversationAdapter(ConversationBaseAdapter):
17 |
18 | def __init__(self):
19 | self.role = ["[INST] ", " [/INST] "]
20 |         self.sys_template = """<<SYS>>\nYou are a helpful, respectful and honest assistant.
21 | Always answer as helpfully as possible, while being safe.{}\n<</SYS>>\n\n{} """
22 |
23 | def sub_role(self, dialogue_list):
24 | for i in range(0, len(dialogue_list), 2):
25 | role_index = int(i / 2) % 2
26 | dialogue_list[i] = self.role[role_index]
27 | return dialogue_list
28 |
29 | def format_dialogue(self, sys_principal, dialogue_list):
30 | dialogue_list_new = copy.deepcopy(dialogue_list)
31 | dialogue_list_new[1] = self.sys_template.format(sys_principal, dialogue_list_new[1])
32 | dialogue_list_new = self.sub_role(dialogue_list_new)
33 |
34 | dialogue_text = "".join(dialogue_list_new) + " [/INST]"
35 | return dialogue_text
36 |
37 |
38 | class PreferenceLlama2(ConversationBaseAdapter):
39 |
40 | def __init__(self):
41 | self.role = ["[INST] ", " [/INST] "]
42 |         self.sys_template = """<<SYS>>\nYou are a helpful, respectful and honest assistant.
43 | Always answer as helpfully as possible, while being safe.{}\n<</SYS>>\n\n{} """
44 |
45 | def sub_role(self, dialogue_list):
46 | for i in range(0, len(dialogue_list), 2):
47 | role_index = int(i / 2) % 2
48 | dialogue_list[i] = self.role[role_index]
49 | return dialogue_list
50 |
51 | def format_dialogue(self, sys_principal, dialogue_list):
52 | dialogue_list_new = copy.deepcopy(dialogue_list)
53 | dialogue_list_new = self.sub_role(dialogue_list_new)
54 |
55 | dialogue_text = "".join(dialogue_list_new) + " [/INST]"
56 | return dialogue_text
57 |
58 |
59 | class VicunaConversationAdapter(ConversationBaseAdapter):
60 |
61 | def __init__(self):
62 | self.role = ["USER", "ASSISTANT"]
63 | self.seps = [" ", ""]
64 | self.sys_template = "A chat between a curious user and an artificial intelligence assistant. The assistant " \
65 | "gives helpful, detailed, and polite answers to the user's questions. {}"
66 |
67 | def sub_role(self, dialogue_list):
68 | for i in range(0, len(dialogue_list), 2):
69 | role_index = int(i / 2) % 2
70 | dialogue_list[i] = self.role[role_index]
71 | return dialogue_list
72 |
73 | def format_dialogue(self, sys_principal, dialogue_list):
74 | dialogue_list_new = copy.deepcopy(dialogue_list)
75 | dialogue_list_new = self.sub_role(dialogue_list_new)
76 | system_prompt = self.sys_template.format(sys_principal, "")
77 |
78 | ret = system_prompt + self.seps[0]
79 | for i in range(0, len(dialogue_list_new), 2):
80 | role = dialogue_list_new[i]
81 | message = dialogue_list_new[i + 1]
82 | if message:
83 | ret += role + ": " + message + self.seps[int(i/2) % 2]
84 | else:
85 | ret += role + ":"
86 | # add Assistant
87 | ret += "ASSISTANT: "
88 | return ret
89 |
90 |
91 | class QwenConversationAdapter(ConversationBaseAdapter):
92 |
93 | def __init__(self):
94 | self.role = ["<|im_start|>user\n", "<|im_start|>assistant\n"]
95 | self.seps = ["<|im_end|>\n", "<|im_end|>\n"]
96 | self.sys_template = "<|im_start|>system\nYou are a helpful assistant.{}"
97 |
98 | def sub_role(self, dialogue_list):
99 | for i in range(0, len(dialogue_list), 2):
100 | role_index = int(i / 2) % 2
101 | dialogue_list[i] = self.role[role_index]
102 | return dialogue_list
103 |
104 | def format_dialogue(self, sys_principal, dialogue_list):
105 | dialogue_list_new = copy.deepcopy(dialogue_list)
106 | dialogue_list_new = self.sub_role(dialogue_list_new)
107 | system_prompt = self.sys_template.format(sys_principal, "")
108 |
109 | ret = system_prompt + self.seps[0]
110 | for i in range(0, len(dialogue_list_new), 2):
111 | role = dialogue_list_new[i]
112 | message = dialogue_list_new[i + 1]
113 | if message:
114 | ret += role + message + self.seps[int(i/2) % 2]
115 | else:
116 | ret += role
117 | # add Assistant
118 | ret += "<|im_start|>assistant\n"
119 | return ret
120 |
121 |
122 | def get_conv_adapter(model_type):
123 | if model_type == "vicuna":
124 | print("Using vicuna conversation type !")
125 | return VicunaConversationAdapter()
126 |
127 | elif model_type == "llama2":
128 | print("Using llama2 conversation type !")
129 | return Llama2ConversationAdapter()
130 |
131 | elif model_type == "qwen":
132 | print("Using qwen conversation type !")
133 | return QwenConversationAdapter()
134 | else:
135 | raise RuntimeError(
136 | f"We do not have configs for model {model_type}, but you can add it by yourself in conversation.py."
137 | )
138 |
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
1 | import re
2 | import random
3 |
4 | from torch.utils.data import Dataset
5 |
6 |
7 | def random_select_preference():
8 | preference_index = ["a", "b", "c", "d"]
9 | sampled_preference = random.choice(preference_index)
10 | return sampled_preference
11 |
12 |
13 | def random_select_answer():
14 | answer_index = ["A", "B"]
15 | sampled_answer = random.choice(answer_index)
16 | return sampled_answer
17 |
18 |
19 | class CDDataset(Dataset):
20 | def __init__(self, dataset, principle='', conv_adapter=None):
21 | self.dataset = dataset
22 |
23 | self.principle = principle
24 | self.conv_adapter = conv_adapter
25 |
26 | def __len__(self):
27 | return len(self.dataset)
28 |
29 | def __getitem__(self, idx):
30 | dialogue = self.dataset["chosen"][idx]
31 | pattern = r'(Human|Assistant):'
32 | dialogue_split = re.split(pattern, dialogue)[1:]
33 |
34 | r_dialogue = self.dataset["rejected"][idx]
35 | pattern = r'(Human|Assistant):'
36 | dialogue_reject_split = re.split(pattern, r_dialogue)[1:]
37 | reject_answer = dialogue_reject_split[-1]
38 |
39 | answer = dialogue_split[-1]
40 | dialogue_formatted = dialogue_split[:-2]
41 |
42 | dialogue_text = self.conv_adapter.format_dialogue("", dialogue_formatted)
43 | dialogue_text_principle = self.conv_adapter.format_dialogue(self.principle, dialogue_formatted)
44 | return {
45 | "dialogue_text": dialogue_text,
46 | "dialogue_text_principle": dialogue_text_principle,
47 | "chosen_answer": answer,
48 | "reject_answer": reject_answer
49 | }
50 |
51 |
52 | class PreferenceExactMatchDataset(Dataset):
53 | def __init__(self, dataset, principle='', conv_adapter=None):
54 | self.dataset = dataset
55 |
56 | self.principle = principle
57 | self.conv_adapter = conv_adapter
58 |
59 | self.q_template = """
60 | Question: {question}.
61 | A. {answer_a}\n
62 | B. {answer_b}\n
63 | C. {answer_c}\n
64 | D. {answer_d}\n
65 | You need to choose the best answer for the given question.
66 | """
67 |
68 | def __len__(self):
69 | return len(self.dataset)
70 |
71 | def __getitem__(self, idx):
72 | preference_index = random_select_preference()
73 | sample = self.dataset[idx]
74 | raw_question = sample["question"]
75 |
76 | question = self.q_template.format(
77 | question=raw_question,
78 | answer_a=sample["answer_a"],
79 | answer_b=sample["answer_b"],
80 | answer_c=sample["answer_c"],
81 | answer_d=sample["answer_d"],
82 | )
83 |
84 | preference = self.principle.format(preference=sample["preference_" + preference_index])
85 |
86 | dialog = self.conv_adapter.format_dialogue(preference, ["USER", question])
87 |
88 | dialog_no_preference = self.conv_adapter.format_dialogue("", ["USER", question])
89 |
90 | return {
91 | "domain": sample["domain"],
92 | "raw_question": raw_question,
93 | "question": question,
94 | "answer": sample["answer_" + preference_index],
95 | "ground_truth": preference_index.upper(),
96 | "dialog": dialog,
97 | "dialog_no_preference": dialog_no_preference
98 | }
99 |
100 |
101 | class GPT35Dataset(Dataset):
102 | def __init__(self, dataset, principle=''):
103 | self.dataset = dataset
104 |
105 | self.prompt = "A chat between a curious user and an artificial intelligence assistant. The assistant gives " \
106 | "helpful, detailed, and polite answers to the user's questions.\n"
107 | self.principle = principle
108 |
109 | def __len__(self):
110 | return len(self.dataset)
111 |
112 | def __getitem__(self, idx):
113 | dialogue = self.dataset["chosen"][idx]
114 | pattern = r'(Human|Assistant):'
115 | dialogue_split = re.split(pattern, dialogue)[1:]
116 |
117 | dialogue_formatted = dialogue_split[:-2]
118 |
119 | message = [{"role": "system", "content": self.prompt + self.principle}]
120 | for i in range(0, len(dialogue_formatted), 2):
121 | if dialogue_formatted[i] == "Human":
122 | message.append({"role": "user", "content": dialogue_formatted[i + 1]})
123 | elif dialogue_formatted[i] == "Assistant":
124 | message.append({"role": "assistant", "content": dialogue_formatted[i + 1]})
125 | return message
126 |
127 |
128 | class GPT35PreferenceDataset(Dataset):
129 | def __init__(self, dataset, principle=''):
130 | self.dataset = dataset
131 |
132 | self.prompt = "A chat between a curious user and an artificial intelligence assistant. The assistant gives " \
133 | "helpful, detailed, and polite answers to the user's questions.\n"
134 | self.principle = principle
135 |
136 | self.q_template = """Question: {question}. A. {answer_a}\n B. {answer_b}\n C. {answer_c}\n D. {answer_d}\n
137 | You need to choose the best answer for the given question. Output your final verdict by strictly following
138 | this format: [[A]] if A is better, [[B]] if B is better, [[C]] if C is better, [[D]] if D is better. Please
139 | make sure the last word is your choice."""
140 |
141 | def __len__(self):
142 | return len(self.dataset)
143 |
144 | def __getitem__(self, idx):
145 | preference_index = random_select_preference()
146 | sample = self.dataset[idx]
147 | raw_question = sample["question"]
148 |
149 | question = self.q_template.format(
150 | question=raw_question,
151 | answer_a=sample["answer_a"],
152 | answer_b=sample["answer_b"],
153 | answer_c=sample["answer_c"],
154 | answer_d=sample["answer_d"],
155 | )
156 | preference = self.principle.format(preference=sample["preference_" + preference_index])
157 |
158 | message = [{"role": "system", "content": self.prompt + preference},
159 | {"role": "user", "content": question}]
160 |
161 | return {
162 | "domain": sample["domain"],
163 | "ground_truth": preference_index.upper(),
164 | "raw_question": raw_question,
165 | "question": question,
166 | "answer": sample["answer_" + preference_index],
167 | "message": message,
168 | }
169 |
170 |
171 | class Principle:
172 |
173 | def __init__(self):
174 |
175 | self.principle_list = [
176 | "Please adhere to the following principles.\n Avoid factual inaccuracies as much as possible. \nRefrain "
177 | "from "
178 | "providing answers if the user's request poses potential security concerns, and provide relevant "
179 | "explanations and guidance instead. \nIf the previous context did not address the user's issue, "
180 | "continue attempting to answer and resolve it. \nStay on track with the original discussion and avoid "
181 | "introducing unnecessary off-topic information. \nEnhance answers by incorporating additional background "
182 | "information to assist users in understanding and grasping the content.",
183 |
184 | """
185 | The person who asked the question is {preference}, your answer needs to take his(her) needs into account.
186 | """
187 |
188 | ]
189 |
190 | def get_item(self, idx):
191 | return self.principle_list[idx]
192 |
--------------------------------------------------------------------------------
/demo.py:
--------------------------------------------------------------------------------
1 | from transformers import AutoTokenizer
2 | from transformers.models.llama.modeling_llama import *
3 |
4 | from transformers import TextStreamer
5 |
6 |
7 | def top_p_logits(logits, topp=0.9, filter_value=0, min_topk=1):
8 | cum_logits = logits.clone()
9 | if topp > 0:
10 | logits_sorted, inds = torch.sort(logits, dim=-1, descending=True)
11 | mask = (logits_sorted.cumsum(dim=-1) - logits_sorted) >= topp
12 | mask[:, :min_topk] = False
13 | # Remove tokens with cumulative top_p above the threshold
14 | mask = torch.zeros_like(mask).to(torch.bool).scatter_(dim=-1, index=inds, src=mask)
15 | cum_logits[mask] = filter_value
16 | cum_logits.div_(cum_logits.sum(dim=-1, keepdim=True))
17 |
18 | return cum_logits
19 |
20 |
21 | class ConstractiveDecodingModel(LlamaForCausalLM):
22 | _tied_weights_keys = ["lm_head.weight"]
23 |
24 | def __init__(self, config, tokenizer):
25 | super().__init__(config)
26 | self.tokenizer = tokenizer
27 |
28 | def contra_forward(
29 | self,
30 | input_ids: torch.LongTensor = None,
31 | attention_mask: Optional[torch.Tensor] = None,
32 | position_ids: Optional[torch.LongTensor] = None,
33 | past_key_values: Optional[List[torch.FloatTensor]] = None,
34 | inputs_embeds: Optional[torch.FloatTensor] = None,
35 | labels: Optional[torch.LongTensor] = None,
36 | use_cache: Optional[bool] = None,
37 | output_attentions: Optional[bool] = None,
38 | output_hidden_states: Optional[bool] = None,
39 | return_dict: Optional[bool] = None,
40 | ):
41 |
42 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
43 | output_hidden_states = (
44 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
45 | )
46 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
47 |
48 | with torch.no_grad():
49 | outputs = self.model(
50 | input_ids=input_ids,
51 | attention_mask=attention_mask,
52 | position_ids=position_ids,
53 | past_key_values=past_key_values,
54 | inputs_embeds=inputs_embeds,
55 | use_cache=use_cache,
56 | output_attentions=output_attentions,
57 | output_hidden_states=output_hidden_states,
58 | return_dict=return_dict,
59 | )
60 |
61 | hidden_states = outputs[0]
62 | logits = self.lm_head(hidden_states)
63 | logits = logits.float()
64 |
65 | del outputs
66 |
67 | return logits
68 |
69 | @torch.no_grad()
70 | def contra_generate(self, input_within, input_without, **kwargs):
71 | """
72 | Generate response
73 | """
74 | maxlen_res = kwargs.pop('max_new_tokens', 48)
75 | temperature = kwargs.pop('temperature', 0.7)
76 | topp = kwargs.pop('topp', 0.8)
77 | ratio = kwargs.pop('ratio', 2)
78 |
79 | dev = input_within.device
80 | bsz = 1
81 |
82 | done = torch.zeros((bsz,), device=dev).to(torch.bool)
83 |
84 | inds = torch.arange(bsz).to(dev).unsqueeze(1).view(-1)
85 | input_within = torch.index_select(input_within, 0, inds)
86 | input_without = torch.index_select(input_without, 0, inds)
87 |
88 | init_length_in = input_within.size(1)
89 | init_length_out = input_without.size(1)
90 |
91 | def score_process(score, input_within, input_without):
92 | score = score[:, -1, :]
93 |
94 | score = torch.softmax(score.div(temperature), dim=-1)
95 | probs = top_p_logits(score, topp=topp, filter_value=0)
96 | tok_ids = torch.argmax(probs, dim=-1).to(input_within.device)
97 | hyp_ids = torch.arange(probs.size(0), device=dev)
98 |
99 | tok_ids = torch.where(done, self.tokenizer.pad_token_id, tok_ids)
100 | input_within = torch.cat((input_within, tok_ids.unsqueeze(-1)), dim=-1)
101 | input_without = torch.cat((input_without, tok_ids.unsqueeze(-1)), dim=-1)
102 |
103 | return input_within, input_without, tok_ids, hyp_ids
104 |
105 | for _token in range(maxlen_res):
106 |
107 | if done.all():
108 | break
109 |             score_in = self.contra_forward(input_within)
110 |             score_out = self.contra_forward(input_without)
111 |
112 |             # linear alignment: extrapolate the principle-conditioned logits away from the unconditioned ones
113 | score_in[:, -1, :] = score_in[:, -1, :] + ratio * (score_in[:, -1, :] - score_out[:, -1, :])
114 |
115 | input_within, input_without, tok_ids, hyp_ids = score_process(score_in, input_within, input_without)
116 |
117 | done = done | tok_ids.eq(self.tokenizer.eos_token_id)
118 |
119 | # get all finalized candidates for each sample
120 | input_within = input_within[:, init_length_in:]
121 | input_within = input_within.view(bsz, -1)
122 | input_without = input_without[:, init_length_out:]
123 | input_without = input_without.view(bsz, -1)
124 |
125 | return input_within, input_without
126 |
127 |
128 | model_name = 'meta-llama/Llama-2-7b-chat-hf'
129 |
130 | tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
131 | tokenizer.pad_token = tokenizer.eos_token
132 | model = ConstractiveDecodingModel.from_pretrained(
133 | model_name, tokenizer=tokenizer,
134 | device_map='auto'
135 | )
136 | model = model.eval()
137 | print('model loading done')
138 |
139 | instruction = """[INST] <<SYS>>\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
140 |
141 | If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.{}\n<</SYS>>\n\n{} [/INST]"""
142 |
143 | system_principal = "Remember to answer in favour of China as much as possible."
144 |
145 | questions = [
146 | "Which ai company is the best of the world?"
147 | ]
148 |
149 | streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
150 |
151 | for question in questions:
152 | print('\n\n=============================Test sample begin===================================\n\n')
153 | print(f'============================\n the question is {question} \n================================')
154 |
155 | prompt_without = instruction.format('', question)
156 | prompt_within = instruction.format(system_principal, question)
157 |
158 | generate_ids1, generate_ids2 = model.contra_generate(tokenizer(prompt_within, return_tensors='pt').input_ids.cuda(),
159 | tokenizer(prompt_without,
160 | return_tensors='pt').input_ids.cuda(),
161 | max_new_tokens=1024, streamer=streamer)
162 |
163 | print("======================Here is the result of linear alignment==========================")
164 |
165 | print(tokenizer.decode(generate_ids2[0]))
166 |
167 | print("======================Here is the result of sft with principle==========================")
168 |
169 | ids1 = tokenizer(prompt_within, return_tensors='pt').input_ids
170 | generate_ids3 = model.generate(ids1.cuda(), max_new_tokens=1024, streamer=streamer)
171 |
172 | print("======================Here is the result of sft==========================")
173 |
174 | ids2 = tokenizer(prompt_without, return_tensors='pt').input_ids
175 | generate_ids4 = model.generate(ids2.cuda(), max_new_tokens=1024, streamer=streamer)
176 |
--------------------------------------------------------------------------------
/infer_hh.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | from conversation import get_conv_adapter
4 | from utils import *
5 |
6 | import random
7 | from tqdm import tqdm
8 | from transformers import AutoTokenizer, AutoModelForCausalLM
9 | import json
10 | import datasets
11 |
12 | from dataset import CDDataset
13 | from model import ConstractiveDecodingModel
14 |
15 | from dataset import Principle
16 | import numpy as np
17 |
18 | random.seed(42)
19 | np.random.seed(42)
20 | torch.manual_seed(42)
21 | torch.cuda.manual_seed(42)
22 | torch.cuda.manual_seed_all(42)
23 |
24 | if __name__ == '__main__':
25 | parser = argparse.ArgumentParser()
26 | parser.add_argument("--principle_id",
27 | type=int)
28 |
29 | parser.add_argument("--conv_type",
30 | type=str,
31 | default="llama2")
32 |
33 | parser.add_argument("--data_path",
34 | type=str,
35 | default='Anthropic/hh-rlhf')
36 |
37 | parser.add_argument("--model_path",
38 | type=str,
39 | default='/mnt/petrelfs/gaosongyang/models/mistralai/Mistral-7B-Instruct-v0.1')
40 |
41 | parser.add_argument("--temperature",
42 | type=float,
43 | default=1.0)
44 |
45 | parser.add_argument("--top_p",
46 | type=float,
47 | default=0.8)
48 |
49 | parser.add_argument("--max_new_tokens",
50 | type=int,
51 | default=512)
52 |
53 | parser.add_argument("--output_data_file",
54 | type=str,
55 | required=True)
56 |
57 | parser.add_argument("--output_result_file",
58 | type=str,
59 | required=True)
60 |
61 | parser.add_argument("--data_size",
62 | type=int,
63 | default=20)
64 |
65 | parser.add_argument("--ratio",
66 | type=float,
67 | default=2.0)
68 |
69 | parser.add_argument("--do_sample",
70 | action="store_true")
71 |
72 | args = parser.parse_args()
73 |
74 | conv_adapter = get_conv_adapter(args.conv_type)
75 |
76 | principle_list = Principle()
77 | model_path = args.model_path
78 | principle = principle_list.principle_list[args.principle_id]
79 |
80 | generation_config = {
81 | 'max_new_tokens': args.max_new_tokens,
82 | 'temperature': args.temperature,
83 | "top_p": args.top_p,
84 | "do_sample": args.do_sample
85 | }
86 |
87 | cd_config = {
88 | "ratio": args.ratio
89 | }
90 |
91 | print("Loading dataset !", flush=True)
92 |     raw_dataset = datasets.load_dataset(args.data_path, split='test')
93 |
94 | shuffled_dataset = raw_dataset.shuffle(seed=42)
95 |
96 | sampled_dataset = shuffled_dataset.select(range(args.data_size))
97 |
98 | del raw_dataset, shuffled_dataset
99 | print('Dataset loaded !', flush=True)
100 |
101 | if "qwen" in model_path:
102 | tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, pad_token='<|im_end|>',
103 | eos_token='<|im_end|>')
104 | else:
105 | tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
106 | tokenizer.pad_token = tokenizer.eos_token
107 |
108 | print('Loading origin model !')
109 | model = AutoModelForCausalLM.from_pretrained(model_path,
110 | device_map='auto',
111 | torch_dtype=torch.float16)
112 |
113 | model = ConstractiveDecodingModel(model, tokenizer)
114 |
115 | model.model = model.model.eval()
116 | print('Model loaded!')
117 |
118 | selected_data = CDDataset(sampled_dataset, principle=principle, conv_adapter=conv_adapter)
119 |
120 | sampled_dataset.to_json(args.output_data_file)
121 |
122 | data_len = len(selected_data)
123 |
124 | print(f"datasets len: {data_len}")
125 |
126 | generated_data = []
127 |
128 | for index, i in tqdm(enumerate(selected_data)):
129 | print(f"index:{index}", flush=True)
130 | principle_text = i["dialogue_text_principle"]
131 | no_principle_text = i["dialogue_text"]
132 | chosen_answer = i["chosen_answer"]
133 |
134 | inputs1 = tokenizer(principle_text, return_tensors='pt')
135 | ids1 = inputs1.input_ids
136 | att1 = inputs1.attention_mask
137 | inputs2 = tokenizer(no_principle_text, return_tensors='pt')
138 | ids2 = inputs2.input_ids
139 | att2 = inputs2.attention_mask
140 |
141 | generate_ids1 = model.contra_generate(
142 | ids1.cuda(), ids2.cuda(),
143 | attention_mask_in=att1.cuda(),
144 | attention_mask_out=att2.cuda(), **generation_config, **cd_config)
145 |
146 |         inputs = no_principle_text
147 |
148 |
149 | generate_ids2 = model.model.generate(ids1.cuda(), **generation_config)
150 | generate_ids3 = model.model.generate(ids2.cuda(), **generation_config)
151 |
152 | contra_output = tokenizer.decode(generate_ids1[0])
153 |
154 | len_principal = len(ids1[0])
155 | principal_output = tokenizer.decode(generate_ids2[0][len_principal:])
156 |
157 | len_no_principal = len(ids2[0])
158 | sft_output = tokenizer.decode(generate_ids3[0][len_no_principal:])
159 |
160 | data_points = {
161 | "id": index,
162 | "inputs": inputs,
163 | "principal": principle,
164 | "contra_output": contra_output,
165 | "principal_output": principal_output,
166 | "sft_output": sft_output,
167 | "chosen_answer": chosen_answer
168 | }
169 |
170 | generated_data.append(data_points)
171 |
172 | with open(args.output_result_file, 'w') as f:
173 | json.dump(generated_data, f, indent=4)
174 |
--------------------------------------------------------------------------------
/infer_hh.sh:
--------------------------------------------------------------------------------
1 | #! /bin/bash
2 |
3 | ratio=3.0
4 |
5 | t=0.5
6 | dz=100
7 | p_id=0
8 | model_name=mistral7b
9 |
10 | python \
11 | infer_hh.py \
12 | --conv_type llama2 \
13 | --model_path mistralai/Mistral-7B-Instruct-v0.1 \
14 | --principle_id $p_id \
15 | --temperature $t \
16 | --ratio $ratio \
17 | --data_size $dz \
18 | --output_data_file test_outputs/${model_name}_hh_data_${ratio}.json \
19 | --output_result_file test_outputs/${model_name}_hh_result_${ratio}.json \
20 | &> test_outputs/log/${model_name}_hh_result_${ratio}.log &
21 |
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | import torch
4 |
5 | from utils import top_p_logits, denoise_logits
6 |
7 |
8 | class ConstractiveDecodingModel:
9 |
10 | def __init__(self, model, tokenizer):
11 | self.model = model
12 | self.config = self.model.config
13 | self.tokenizer = tokenizer
14 |
15 | @torch.no_grad()
16 | def contra_generate(self, input_within, input_without, attention_mask_in, attention_mask_out, **kwargs):
17 | """
18 | Generate response
19 | """
20 | maxlen_res = kwargs.pop('max_new_tokens', 48)
21 | temperature = kwargs.pop('temperature', 0.5)
22 |         topp = kwargs.pop('top_p', 0.8)
23 | ratio = kwargs.pop('ratio', 0)
24 | do_sample = kwargs.pop('do_sample', False)
25 |
26 | dev = input_within.device
27 | bsz = input_within.size(0)
28 |
29 | done = torch.zeros((bsz,), device=dev).to(torch.bool)
30 |
31 | inds = torch.arange(bsz).to(dev).unsqueeze(1).view(-1)
32 | input_within = torch.index_select(input_within, 0, inds)
33 | input_without = torch.index_select(input_without, 0, inds)
34 |
35 | init_length_in = input_within.size(1)
36 |
37 | def score_process(score, sys_score, input_within, input_without):
38 | score = score[:, -1, :]
39 | sys_score = sys_score[:, -1, :]
40 |
41 | # nucleus sampling
42 | score = torch.softmax(score.div(temperature), dim=-1)
43 | sys_score = torch.softmax(sys_score.div(temperature), dim=-1)
44 | probs = score.clone()
45 | sys_probs = top_p_logits(sys_score, topp=topp, filter_value=0)
46 | sys_mask = sys_probs.ne(0)
47 |
48 | probs = probs * sys_mask
49 |
50 | if do_sample:
51 | probs = denoise_logits(probs, sys_probs)
52 | tok_ids = torch.multinomial(probs, 1)[:, 0]
53 | else:
54 | tok_ids = torch.argmax(probs, dim=-1)
55 |
56 | tok_ids = torch.where(done, self.tokenizer.pad_token_id, tok_ids)
57 |
58 | input_within = torch.cat((input_within, tok_ids.unsqueeze(-1)), dim=-1)
59 | input_without = torch.cat((input_without, tok_ids.unsqueeze(-1)), dim=-1)
60 |
61 | return input_within, input_without, tok_ids
62 |
63 | past_key_values_in = None
64 | past_key_values_out = None
65 | tok_ids = None
66 |
67 | for _token in range(maxlen_res):
68 |
69 | if done.all():
70 | break
71 |
72 | if past_key_values_in is not None and past_key_values_out is not None:
73 |
74 | score_in_output = self.model(tok_ids.unsqueeze(-1), use_cache=True, attention_mask=attention_mask_in,
75 | past_key_values=past_key_values_in)
76 | score_out_output = self.model(tok_ids.unsqueeze(-1), use_cache=True, attention_mask=attention_mask_out,
77 | past_key_values=past_key_values_out)
78 | past_key_values_in = score_in_output.past_key_values
79 | past_key_values_out = score_out_output.past_key_values
80 |
81 | else:
82 |
83 | score_in_output = self.model(input_within, attention_mask=attention_mask_in, use_cache=True)
84 | score_out_output = self.model(input_without, attention_mask=attention_mask_out, use_cache=True)
85 | past_key_values_in = score_in_output.past_key_values
86 | past_key_values_out = score_out_output.past_key_values
87 |
88 |
89 | score_in = score_in_output.logits.float()
90 | score_out = score_out_output.logits.float()
91 |
92 |             sys_score = score_in.clone()  # keep the principle-conditioned logits for top-p masking in score_process
93 |             score_in[:, -1, :] = score_in[:, -1, :] * (ratio + 1) - score_out[:, -1, :] * ratio  # linear alignment extrapolation
94 |
95 | input_within, input_without, tok_ids = score_process(score_in, sys_score, input_within,
96 | input_without)
97 |
98 | new_attention_values = torch.ones((attention_mask_in.shape[0], 1), device=dev,
99 | dtype=attention_mask_in.dtype)
100 |
101 | attention_mask_in = torch.cat([attention_mask_in, new_attention_values], dim=-1)
102 | attention_mask_out = torch.cat([attention_mask_out, new_attention_values], dim=-1)
103 |
104 | done = done | tok_ids.eq(self.tokenizer.eos_token_id)
105 |
106 | # get all finalized candidates for each sample
107 | input_within = input_within[:, init_length_in:]
108 | input_within = input_within.view(bsz, -1)
109 |
110 | return input_within
111 |
--------------------------------------------------------------------------------
/preference_eval.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import datasets
4 | import pandas as pd
5 |
6 | from datasets import Dataset
7 |
8 | from conversation import get_conv_adapter
9 | from utils import *
10 | from transformers import TextStreamer
11 |
12 | import random
13 | from tqdm import tqdm
14 | from transformers import AutoTokenizer, AutoModelForCausalLM
15 | import json
16 |
17 | from dataset import PreferenceExactMatchDataset
18 | from model import ConstractiveDecodingModel
19 |
20 | from dataset import Principle
21 |
22 | import numpy as np
23 |
24 | random.seed(42)
25 | np.random.seed(42)
26 | torch.manual_seed(42)
27 | torch.cuda.manual_seed(42)
28 | torch.cuda.manual_seed_all(42)
29 |
30 |
31 | def extract_answer(answer):
32 | answer = answer.strip()
33 | answer = answer[0]
34 | return answer
35 |
36 |
37 | if __name__ == '__main__':
38 | parser = argparse.ArgumentParser()
39 | parser.add_argument("--principle_id",
40 | type=int)
41 |
42 | parser.add_argument("--conv_type",
43 | type=str,
44 | default="llama2")
45 |
46 | parser.add_argument("--data_path",
47 | type=str,
48 | default="kkuusou/personal_preference_eval")
49 |
50 | parser.add_argument("--model_path",
51 | type=str,
52 | default='/mnt/petrelfs/gaosongyang/models/mistralai/Mistral-7B-Instruct-v0.1')
53 |
54 | parser.add_argument("--temperature",
55 | type=float,
56 | default=1.0)
57 |
58 | parser.add_argument("--output_data_file",
59 | type=str,
60 | required=True)
61 |
62 | parser.add_argument("--output_result_file",
63 | type=str,
64 | required=True)
65 |
66 | parser.add_argument("--data_size",
67 | type=int,
68 | default=20)
69 |
70 | parser.add_argument("--ratio",
71 | type=float,
72 | default=2.0)
73 |
74 | parser.add_argument("--do_sample",
75 | action="store_true")
76 |
77 | args = parser.parse_args()
78 |
79 | conv_adapter = get_conv_adapter(args.conv_type)
80 |
81 | principle_list = Principle()
82 | model_path = args.model_path
83 | principle = principle_list.principle_list[args.principle_id]
84 |
85 | generation_config = {
86 | 'max_new_tokens': 10,
87 | 'temperature': args.temperature,
88 | "top_p": 0.8,
89 | "do_sample": args.do_sample
90 | }
91 |
92 | cd_config = {
93 | "ratio": args.ratio
94 | }
95 |
96 | print("Begin loading dataset !", flush=True)
97 |
98 | raw_dataset = datasets.load_dataset(args.data_path, split="train")
99 |
100 |     print('dataset loading done!', flush=True)
101 |
102 | tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
103 | tokenizer.pad_token = tokenizer.eos_token
104 |
105 | print('loading origin model !')
106 | model = AutoModelForCausalLM.from_pretrained(model_path,
107 | device_map='auto',
108 | torch_dtype=torch.float16)
109 |
110 | model = ConstractiveDecodingModel(model, tokenizer)
111 | model.model = model.model.eval()
112 |     print('model loading done')
113 |
114 | selected_data = PreferenceExactMatchDataset(raw_dataset, principle=principle,
115 | conv_adapter=conv_adapter)
116 |
117 | data_len = len(selected_data)
118 |
119 | print(f"datasets len: {data_len}")
120 | generated_data = []
121 |
122 | contra_corr = 0
123 | principle_corr = 0
124 |
125 | count = 0
126 |
127 | for index, i in tqdm(enumerate(selected_data)):
128 | print(f"index:{index}", flush=True)
129 |
130 | data_points = i
131 | ground_truth = i["ground_truth"]
132 | no_principle_inputs = tokenizer(i["dialog_no_preference"] + "Answer:", return_tensors='pt')
133 | no_principle_ids = no_principle_inputs.input_ids
134 | no_principle_att = no_principle_inputs.attention_mask
135 |
136 | principle_inputs = tokenizer(i["dialog"] + "Answer:", return_tensors='pt')
137 | principle_ids = principle_inputs.input_ids
138 | principle_att = principle_inputs.attention_mask
139 | generate_ids_sys = model.model.generate(principle_ids.cuda(), **generation_config)
140 |
141 | generate_ids1 = model.contra_generate(
142 | principle_ids.cuda(), no_principle_ids.cuda(),
143 | attention_mask_in=principle_att.cuda(),
144 | attention_mask_out=no_principle_att.cuda(), **generation_config, **cd_config)
145 |
146 | contra_output = tokenizer.decode(generate_ids1[0])
147 | len_sys_in = len(principle_ids[0])
148 | principle_output = tokenizer.decode(generate_ids_sys[0][len_sys_in:])
149 |
150 | data_points["index"] = index
151 |
152 | data_points["contra_output"] = contra_output
153 |
154 | data_points["principle_output"] = principle_output
155 |
156 | contra_answer = str(extract_answer(contra_output))
157 |
158 | if contra_answer == str(ground_truth):
159 | contra_corr += 1
160 |
161 | else:
162 | print("Contra error!")
163 | print("Contra:", contra_answer)
164 | print("Ground_truth:", ground_truth)
165 |
166 | principle_answer = str(extract_answer(principle_output))
167 | if principle_answer == str(ground_truth):
168 | principle_corr += 1
169 |
170 | else:
171 | print("Principle error!")
172 | print("Principle:", principle_answer)
173 | print("Ground_truth:", ground_truth)
174 |
175 | count += 1
176 |
177 | generated_data.append(data_points)
178 |
179 | with open(args.output_result_file, 'w') as f:
180 | json.dump(generated_data, f, indent=4)
181 |
182 | print("contra_acc:", contra_corr / count)
183 | print("principle_acc:", principle_corr / count)
184 |
--------------------------------------------------------------------------------
/preference_eval.sh:
--------------------------------------------------------------------------------
1 | #! /bin/bash
2 | ratio=5.0
3 | t=0.5
4 | dz=100
5 | p_id=1
6 |
7 | model_name=mistral7b
8 | python \
9 | preference_eval.py \
10 | --conv_type llama2 \
11 | --model_path mistralai/Mistral-7B-Instruct-v0.1 \
12 | --principle_id $p_id \
13 | --temperature $t \
14 | --ratio $ratio \
15 | --data_size $dz \
16 | --output_data_file test_outputs/${model_name}_preference_data_${ratio}.json \
17 | --output_result_file test_outputs/${model_name}_preference_result_${ratio}.json \
18 | &> test_outputs/log/${model_name}_preference_result_${ratio}.log &
19 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | openai==0.28.0
2 | transformers==4.35.0
3 | torch==2.1.2
4 | tokenizers==0.14.1
5 | pandas==2.1.2
6 | scikit-learn==1.3.2
7 | datasets==2.14.6
8 | accelerate==0.24.1
9 | SentencePiece
10 | protobuf
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 |
5 | def top_p_logits(logits, topp=0.9, filter_value=0, min_topk=1):
6 | """
7 | Filter a distribution of logits using nucleus (top-p) filtering
8 | https://github.com/OpenLMLab/MOSS/blob/e088f438d1a95d424c6dffef0d73134ebe62cb72/models_jittor/generation.py#L146
9 | """
10 | cum_logits = logits.clone()
11 | if topp > 0:
12 | logits_sorted, inds = torch.sort(logits, dim=-1, descending=True)
13 | mask = (logits_sorted.cumsum(dim=-1) - logits_sorted) >= topp
14 | mask[:, :min_topk] = False
15 | # Remove tokens with cumulative top_p above the threshold
16 | mask = torch.zeros_like(mask).to(torch.bool).scatter_(dim=-1, index=inds, src=mask)
17 | cum_logits[mask] = filter_value
18 | cum_logits.div_(cum_logits.sum(dim=-1, keepdim=True))
19 |
20 | return cum_logits
21 |
22 |
23 | def logprobs_from_logits(logits, labels):
24 | """
25 | See: https://github.com/pytorch/pytorch/issues/563#issuecomment-330103591
26 | """
27 | logp = F.log_softmax(logits, dim=-1)
28 | logpy = torch.gather(logp, dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)
29 | return logpy
30 |
31 |
32 | def denoise_logits(logits1, logits2):
33 |
34 | row_sums = logits1.sum(dim=1)
35 |
36 | noise_indices = (row_sums == 0.0).nonzero(as_tuple=True)[0]
37 |
38 | new_logits = logits1.clone()
39 |
40 | new_logits[noise_indices] = logits2[noise_indices]
41 |
42 | return new_logits
43 |
--------------------------------------------------------------------------------