├── LICENSE ├── README.md ├── data └── put_your_data_here.txt ├── data_prepare.py ├── inference.py ├── requirements.txt └── train.py /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Dynamic Topic Boundary Refinement (DTBR) for Dialogue Topic Segmentation 2 | 3 | ## Installation 4 | 5 | To use this repository, clone the repository and install the required dependencies: 6 | 7 | ### Clone the repository 8 | 9 | ```bash 10 | git clone https://github.com/your-username/DyDTS.git 11 | cd DyDTS 12 | ``` 13 | 14 | ### Install dependencies 15 | 16 | We recommend using a virtual environment (e.g., venv, conda) to install the dependencies. 17 | 18 | ```bash 19 | pip install -r requirements.txt 20 | ``` 21 | 22 | ## Usage 23 | 24 | ### 1. Data Description 25 | 26 | DialSeg711 is a real-world dataset consisting of 711 English dialogues, sourced from MultiWOZ and KVRET. It exhibits an average of 4.9 topic segments and 5.6 utterances per segment. Doc2Dial is a synthetic dataset comprising over 4,100 English conversations grounded in 450+ documents across four domains. 
It presents an average of 3.7 topic segments and 3.5 utterances per segment. 27 | 28 | #### Details of Dialogue Datasets 29 | 30 | | Datasets | DialSeg711 | Doc2Dial | 31 | |-----------------------------------|-------------|------------| 32 | | #samples | 711 | 4100 | 33 | | #Avg. Topic Segments/Dialogue | 4.9 | 3.7 | 34 | | #Avg. Utterances/Topic Segment | 5.6 | 3.5 | 35 | 36 | 37 | ### 2. Data Preparation 38 | 39 | Prepare your dialogue data in the required format: each dialogue is a `.txt` file in `--data_dir` with one utterance per line (blank lines and lines containing `=======` are skipped when reading). The datasets are available [here](https://drive.google.com/drive/folders/11HSQWJR8qurD8K_ezgo6HqtcULl18UJq?usp=sharing). 40 | 41 | ```bash 42 | python data_prepare.py --data_dir data/dialseg711 --file_name 711.pkl --output_dir processed_711_data --model_name sup-simcse-bert-base-uncased 43 | ``` 44 | 45 | ### 3. Training 46 | 47 | To train the model on your dataset: 48 | 49 | ```bash 50 | python train.py --data_dir processed_711_data --model_name sup-simcse-bert-base-uncased --output_dir model_711_trained 51 | ``` 52 | 53 | ### 4. Evaluation 54 | 55 | To evaluate the model's performance, we provide an evaluation script and a trained [model](https://drive.google.com/drive/folders/16JPkKNrKHrKYxr6okOyVO0F8w9fI0J6-) for computing segmentation metrics such as Pk and WindowDiff (WD) on the segmented output: 56 | 57 | ```bash 58 | python inference.py --data_dir data/dialseg711 --model_name sup-simcse-bert-base-uncased --output_dir model_711 59 | ``` 60 | 61 | ## Contributing 62 | 63 | We welcome contributions to improve the DTBR method. Feel free to fork the repository and submit pull requests for: 64 | 65 | - Bug fixes 66 | - Feature enhancements 67 | - Improvements to the documentation 68 | 69 | ## Contact 70 | 71 | For any questions, feel free to open an issue or contact the project maintainers.
72 | 73 | -------------------------------------------------------------------------------- /data/put_your_data_here.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Mark131434/DyDTS/007f0df6a719c81cd0226060e13716bdd8b4b7a1/data/put_your_data_here.txt -------------------------------------------------------------------------------- /data_prepare.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import pickle 4 | import random 5 | import logging 6 | from typing import List, Dict, Tuple, Optional 7 | import argparse 8 | 9 | import numpy as np 10 | import torch 11 | import torch.nn as nn 12 | from tqdm import tqdm 13 | from dataclasses import dataclass 14 | from sklearn.cluster import AffinityPropagation 15 | from sklearn.metrics import euclidean_distances 16 | from sklearn.metrics.pairwise import cosine_similarity 17 | from transformers import BertModel, BertTokenizer 18 | 19 | class AffinityPropagationAlgorithm: 20 | """Advanced clustering algorithm for dialogue segmentation.""" 21 | 22 | def __init__(self, position_weight: float = 13.0): 23 | """ 24 | Initialize the Affinity Propagation algorithm. 25 | 26 | Args: 27 | position_weight (float): Weight for position-based similarity. Defaults to 13.0. 28 | """ 29 | self.position_weight = position_weight 30 | self.logger = logging.getLogger(__name__) 31 | 32 | def get_cluster_centers_indices(self, inputs: np.ndarray) -> List[int]: 33 | """ 34 | Perform clustering using Affinity Propagation with content and position similarity. 
35 | 36 | Args: 37 | inputs (np.ndarray): Input embeddings for clustering 38 | 39 | Returns: 40 | List[int]: Indices of cluster centers 41 | """ 42 | try: 43 | content_similarity = -euclidean_distances(inputs, squared=True) 44 | 45 | positions = np.array([i for i in range(len(inputs))]).reshape(-1, 1) 46 | position_similarity = -self.position_weight * euclidean_distances(positions, squared=True) 47 | 48 | similarity_matrix = content_similarity + position_similarity 49 | 50 | ap = AffinityPropagation( 51 | preference=np.min(content_similarity), 52 | affinity='precomputed', 53 | random_state=1 54 | ).fit(similarity_matrix) 55 | 56 | cluster_centers_indices = list(ap.cluster_centers_indices_) 57 | 58 | # Merge adjacent indices and handle boundary cases 59 | cluster_centers_indices = self._merge_adjacent(cluster_centers_indices) 60 | cluster_centers_indices = [idx for idx in cluster_centers_indices if 0 < idx < len(inputs)] 61 | 62 | # self.logger.info(f"Cluster centers indices: {cluster_centers_indices}") 63 | return cluster_centers_indices 64 | 65 | except Exception as e: 66 | self.logger.error(f"Error in clustering: {e}") 67 | return [] 68 | 69 | def _merge_adjacent(self, indices: List[int]) -> List[int]: 70 | """ 71 | Merge adjacent indices to reduce fragmentation. 
72 | 73 | Args: 74 | indices (List[int]): List of cluster center indices 75 | 76 | Returns: 77 | List[int]: Merged indices 78 | """ 79 | if len(indices) <= 1: 80 | return indices 81 | 82 | result = [indices[0]] 83 | for idx in indices[1:]: 84 | if idx - result[-1] > 1: 85 | result.append(idx) 86 | 87 | return result 88 | 89 | def adjust_boundaries_within_contexts(self, embeddings:np.ndarray, initial_boundary:List[int]) -> List[int]: 90 | initial_boundaries = [0] + initial_boundary + [len(embeddings)] 91 | contexts = [(initial_boundaries[i],initial_boundaries[i+2]) for i in range(len(initial_boundaries) -2)] 92 | new_boundaries = [] 93 | new_boundary = 0 94 | for start, end in contexts: 95 | if new_boundary>0: 96 | start = new_boundary 97 | segment_embeddings = embeddings[start:end] 98 | similarity_matrix = cosine_similarity(segment_embeddings) 99 | 100 | pairwise_similarities = [similarity_matrix[i, i + 1] for i in range(len(segment_embeddings) - 1)] 101 | 102 | min_similarity_idx = np.argmin(pairwise_similarities) 103 | new_boundary = start + min_similarity_idx + 1 104 | 105 | if not new_boundary: 106 | return initial_boundary 107 | 108 | new_boundaries.append(new_boundary) 109 | return new_boundaries 110 | 111 | def get_final_segmentation( 112 | self, 113 | embeddings: np.ndarray, 114 | max_iterations: int = 5 115 | ) -> List[Tuple[int, int]]: 116 | """ 117 | Iteratively refine dialogue segmentation. 
118 | 119 | Args: 120 | embeddings (np.ndarray): Sentence embeddings 121 | max_iterations (int): Maximum refinement iterations 122 | 123 | Returns: 124 | List[Tuple[int, int]]: Final segment boundaries 125 | """ 126 | current_seg = self.get_cluster_centers_indices(embeddings) 127 | variance = float('inf') 128 | iteration = 0 129 | 130 | while variance > 2 and iteration < max_iterations: 131 | new_seg = self.adjust_boundaries_within_contexts(embeddings, current_seg) 132 | variance = np.sum((np.array(current_seg) - np.array(new_seg)) ** 2) 133 | current_seg = new_seg 134 | iteration += 1 135 | 136 | # Add start and end boundaries 137 | current_seg = self._merge_adjacent(current_seg) 138 | full_boundaries = [0] + current_seg + [len(embeddings)] 139 | segments = [ 140 | (full_boundaries[i], full_boundaries[i+1]) 141 | for i in range(len(full_boundaries) - 1) 142 | ] 143 | 144 | return segments 145 | 146 | @dataclass 147 | class DialogueExample: 148 | dialogue_id: str 149 | sentences: List[str] 150 | embeddings: np.ndarray 151 | core_indices: List[tuple] 152 | 153 | @dataclass 154 | class TrainingSample: 155 | anchor_idx: int 156 | positive_indices: List[int] 157 | hard_negative_indices: List[int] 158 | regular_negative_indices: List[int] 159 | 160 | class DialogueBertModel(nn.Module): 161 | def __init__(self, model_name: str = 'bert-base-uncased', margin: float = 0.5): 162 | super().__init__() 163 | self.bert = BertModel.from_pretrained(model_name) 164 | self.margin = margin 165 | 166 | def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor: 167 | outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask) 168 | return outputs.last_hidden_state[:, 0, :] 169 | 170 | def get_similarity(self, sent1_emb: torch.Tensor, sent2_emb: torch.Tensor) -> torch.Tensor: 171 | return torch.sum((sent1_emb - sent2_emb) ** 2, dim=1) 172 | 173 | def load_model_for_inference( 174 | model: DialogueBertModel, 175 | checkpoint_path: Optional[str] 
= None, 176 | device: str = 'cuda' if torch.cuda.is_available() else 'cpu' 177 | ) -> DialogueBertModel: 178 | if checkpoint_path and os.path.exists(checkpoint_path): 179 | try: 180 | checkpoint = torch.load(checkpoint_path, map_location=device) 181 | model.load_state_dict(checkpoint['model_state_dict']) 182 | logging.info(f"Loaded checkpoint from {checkpoint_path}") 183 | except Exception as e: 184 | logging.error(f"Error loading checkpoint: {e}") 185 | 186 | model.to(device) 187 | model.eval() # Set model to evaluation mode 188 | return model 189 | 190 | class DialoguePreprocessor: 191 | def __init__( 192 | self, 193 | model_name: str = 'bert-base-uncased', 194 | tokenizer_path: Optional[str] = None, 195 | checkpoint_path: Optional[str] = None, 196 | position_weight: float = 13.0, 197 | device: str = 'cuda' if torch.cuda.is_available() else 'cpu', 198 | hard_nagative: int = 3, 199 | ragular_negative: int = 6, 200 | ): 201 | """ 202 | Initialize preprocessor with configurable parameters. 
203 | 204 | Args: 205 | model_name (str): BERT model name 206 | tokenizer_path (Optional[str]): Path to tokenizer 207 | checkpoint_path (Optional[str]): Path to model checkpoint 208 | position_weight (float): Weight for position-based similarity 209 | device (str): Device for processing 210 | hard_nagative: the number of hard negatives 211 | regular_negative: the number of regular negatives 212 | """ 213 | # Configure logging 214 | logging.basicConfig( 215 | level=logging.INFO, 216 | format='%(asctime)s - %(levelname)s - %(message)s' 217 | ) 218 | self.logger = logging.getLogger(__name__) 219 | 220 | # Initialize tokenizer 221 | self.tokenizer = BertTokenizer.from_pretrained( 222 | tokenizer_path or model_name 223 | ) 224 | 225 | # Initialize model 226 | self.model = DialogueBertModel(model_name) 227 | self.model = load_model_for_inference( 228 | self.model, 229 | checkpoint_path, 230 | device 231 | ) 232 | 233 | # Set parameters 234 | self.device = device 235 | self.ap_algorithm = AffinityPropagationAlgorithm( 236 | position_weight=position_weight 237 | ) 238 | 239 | self.hard_nagative = hard_nagative 240 | self.ragular_negative = ragular_negative 241 | 242 | 243 | def read_dialogues(self, data_dir: str) -> List[Tuple[str, List[str]]]: 244 | dialogue_files = glob.glob(os.path.join(data_dir, "*.txt")) 245 | dialogues = [] 246 | 247 | for file_path in tqdm(dialogue_files, desc="Reading dialogues"): 248 | with open(file_path, 'r', encoding='utf-8') as f: 249 | lines = [line.strip() for line in f.readlines() if line.strip() and '=======' not in line] 250 | dialogue_id = os.path.basename(file_path).split('.')[0] 251 | dialogues.append((dialogue_id, lines)) 252 | 253 | return dialogues 254 | 255 | def get_sentence_embeddings(self, sentences: List[str]) -> np.ndarray: 256 | inputs = self.tokenizer(sentences, 257 | padding=True, 258 | truncation=True, 259 | max_length=512, 260 | return_tensors='pt') 261 | encoded = {k:v.to(self.device) for k,v in inputs.items()} 262 | 263 
| with torch.no_grad(): 264 | sent1_emb = self.model(input_ids=encoded['input_ids'], attention_mask=encoded['attention_mask']) 265 | outputs = sent1_emb.cpu() 266 | outputs = outputs.numpy() 267 | return outputs 268 | 269 | def find_core_sentences(self, embeddings: np.ndarray, ap_algorithm: AffinityPropagationAlgorithm) -> List[int]: 270 | cluster_indices = ap_algorithm.get_final_segmentation(embeddings) 271 | return cluster_indices 272 | 273 | def create_training_samples(self, dialogues: List[str], dialogue_id: int, example: DialogueExample) -> List[TrainingSample]: 274 | training_samples = [] 275 | total_indecs = list(range(len(dialogues))) 276 | indices = [item for item in total_indecs if item != dialogue_id] 277 | 278 | for i, (start,end) in enumerate(example.core_indices): 279 | for j in range(start,end-1): 280 | 281 | hard_negatives = [] 282 | if i > 0: 283 | hard_negatives.extend(list(range(example.core_indices[i-1][0],example.core_indices[i-1][1]-1))) 284 | if i < len(example.core_indices) -1: 285 | hard_negatives.extend(list(range(example.core_indices[i+1][0]+2,example.core_indices[i+1][1]))) 286 | 287 | hard_negative_samples = random.sample(hard_negatives, min(self.hard_nagative,len(hard_negatives))) 288 | 289 | other_negatives = [] 290 | for m, (other_start,other_end) in enumerate(example.core_indices): 291 | if abs(i-m) > 1: 292 | other_negatives.extend(list(range(other_start,other_end))) 293 | 294 | random_negative_samples = random.sample(other_negatives,min(self.ragular_negative -len(hard_negative_samples),len(other_negatives))) 295 | 296 | random_negative_samples = [example.sentences[item] for item in random_negative_samples] 297 | hard_negative_samples = [example.sentences[item] for item in hard_negative_samples] 298 | 299 | random_index = random.choice(indices) 300 | selected_item = random.sample(dialogues[random_index][1], 2) 301 | random_negative_samples += selected_item 302 | 303 | training_samples.append(TrainingSample( 304 | 
anchor_idx=example.sentences[j], 305 | positive_indices=example.sentences[j+1], 306 | hard_negative_indices=hard_negative_samples, 307 | regular_negative_indices=random_negative_samples 308 | )) 309 | return training_samples 310 | 311 | 312 | def process_dialogues(self, data_dir: str, output_file: str): 313 | dialogues = self.read_dialogues(data_dir) 314 | processed_data = [] 315 | ap_algorithm = AffinityPropagationAlgorithm(position_weight=13) 316 | 317 | for id,(dialogue_id, sentences) in enumerate(tqdm(dialogues, desc="Processing dialogues")): 318 | embeddings = self.get_sentence_embeddings(sentences) 319 | 320 | core_indices = self.find_core_sentences(embeddings,ap_algorithm) 321 | 322 | example = DialogueExample( 323 | dialogue_id=dialogue_id, 324 | sentences=sentences, 325 | embeddings=embeddings, 326 | core_indices=core_indices 327 | ) 328 | 329 | training_samples = self.create_training_samples(dialogues, dialogue_id, example) 330 | 331 | processed_data.append({ 332 | 'example': example, 333 | 'training_samples': training_samples 334 | }) 335 | 336 | with open(output_file, 'wb') as f: 337 | pickle.dump(processed_data, f) 338 | 339 | self.logger.info(f"Processed {len(dialogues)} dialogues") 340 | self.log_statistics(processed_data) 341 | 342 | def log_statistics(self, processed_data: List[Dict]): 343 | total_core_sentences = 0 344 | total_training_samples = 0 345 | hard_negative_samples_dist = [] 346 | regular_negative_samples_dist = [] 347 | positive_samples_dist = [] 348 | 349 | for data in processed_data: 350 | example = data['example'] 351 | training_samples = data['training_samples'] 352 | 353 | total_core_sentences += len(example.core_indices) 354 | total_training_samples += len(training_samples) 355 | 356 | for sample in training_samples: 357 | if sample.positive_indices: 358 | positive_samples_dist.append(1) 359 | hard_negative_samples_dist.append(len(sample.hard_negative_indices)) 360 | 
regular_negative_samples_dist.append(len(sample.regular_negative_indices)) 361 | 362 | self.logger.info(f"avg segment sentences: {np.mean(total_core_sentences)}") 363 | self.logger.info(f"Total training samples: {total_training_samples}") 364 | self.logger.info(f"Average positive samples per anchor: {sum(positive_samples_dist)/total_training_samples:.2f}") 365 | self.logger.info(f"Average hard negative samples per anchor: {np.mean(hard_negative_samples_dist):.2f}") 366 | self.logger.info(f"Average regular negative samples per anchor: {np.mean(regular_negative_samples_dist):.2f}") 367 | 368 | def parse_arguments() -> argparse.Namespace: 369 | """ 370 | Parse command-line arguments for dialogue preprocessing. 371 | 372 | Returns: 373 | argparse.Namespace: Parsed command-line arguments 374 | """ 375 | parser = argparse.ArgumentParser( 376 | description='Dialogue Preprocessor for Training Data Generation' 377 | ) 378 | 379 | # Input and Output Paths 380 | parser.add_argument( 381 | '--data_dir', 382 | type=str, 383 | required=True, 384 | help='Directory containing dialogue text files' 385 | ) 386 | parser.add_argument( 387 | '--file_name', 388 | type=str, 389 | required=True, 390 | help='File name to save processed dialogue data' 391 | ) 392 | 393 | parser.add_argument( 394 | "--output_dir", 395 | type=str, 396 | required=True, 397 | help="Path to save processed dialogue data") 398 | 399 | # Model Configuration 400 | parser.add_argument( 401 | '--model_name', 402 | type=str, 403 | default='bert-base-uncased', 404 | help='Pretrained BERT model name' 405 | ) 406 | parser.add_argument( 407 | '--tokenizer_path', 408 | type=str, 409 | help='Path to custom tokenizer' 410 | ) 411 | parser.add_argument( 412 | '--checkpoint_path', 413 | type=str, 414 | help='Path to model checkpoint' 415 | ) 416 | 417 | # Preprocessing Parameters 418 | parser.add_argument( 419 | '--position_weight', 420 | type=float, 421 | default=13.0, 422 | help='Weight for position-based similarity' 423 | ) 
424 | 425 | parser.add_argument( 426 | '--hard_nagative', 427 | type=int, 428 | default=3, 429 | help='the number of hard negatives' 430 | ) 431 | 432 | parser.add_argument( 433 | '--ragular_negative', 434 | type=int, 435 | default=6, 436 | help='the number of ragular negatives' 437 | ) 438 | 439 | # Device Configuration 440 | parser.add_argument( 441 | '--device', 442 | type=str, 443 | default='cuda' if torch.cuda.is_available() else 'cpu', 444 | help='Device for processing (cuda/cpu)' 445 | ) 446 | 447 | # Logging 448 | parser.add_argument( 449 | '--log_level', 450 | type=str, 451 | default='INFO', 452 | choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], 453 | help='Logging level' 454 | ) 455 | 456 | args = parser.parse_args() 457 | os.makedirs(args.output_dir, exist_ok=True) 458 | 459 | return parser.parse_args() 460 | 461 | def main(): 462 | # Parse command-line arguments 463 | args = parse_arguments() 464 | 465 | # Configure logging based on argument 466 | logging.basicConfig( 467 | level=getattr(logging, args.log_level), 468 | format='%(asctime)s - %(levelname)s - %(message)s' 469 | ) 470 | 471 | # Create preprocessor with parsed arguments 472 | preprocessor = DialoguePreprocessor( 473 | model_name=args.model_name, 474 | tokenizer_path=args.tokenizer_path, 475 | checkpoint_path=args.checkpoint_path, 476 | position_weight=args.position_weight, 477 | device=args.device, 478 | hard_nagative = args.hard_nagative, 479 | ragular_negative = args.ragular_negative 480 | ) 481 | 482 | # Process dialogues 483 | file_path = os.path.join(args.output_dir,args.file_name) 484 | preprocessor.process_dialogues(args.data_dir, file_path) 485 | 486 | if __name__ == "__main__": 487 | main() -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import torch 4 | import numpy as np 5 | from typing import List, Optional 
6 | from sklearn.metrics.pairwise import euclidean_distances, cosine_similarity 7 | from sklearn.cluster import AffinityPropagation 8 | from transformers import BertTokenizer, BertModel 9 | import torch.nn as nn 10 | import segeval 11 | import argparse 12 | from tqdm import tqdm 13 | import warnings 14 | warnings.filterwarnings("ignore") 15 | 16 | class DialogueBertModel(nn.Module): 17 | """BERT encoder that returns the [CLS] hidden state as the sentence embedding and bundles the matching tokenizer.""" 18 | 19 | def __init__(self, model_name: str = 'bert-base-uncased', margin: float = 0.5): 20 | super().__init__() 21 | self.bert = BertModel.from_pretrained(model_name) 22 | self.margin = margin 23 | self.tokenizer = BertTokenizer.from_pretrained(model_name) 24 | 25 | def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor: 26 | """Return the final-layer hidden state at position 0 ([CLS]) for each sequence in the batch.""" 27 | outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask) 28 | return outputs.last_hidden_state[:, 0, :] 29 | 30 | def get_similarity(self, sent1_emb: torch.Tensor, sent2_emb: torch.Tensor) -> torch.Tensor: 31 | """Return the row-wise squared Euclidean distance between two embedding batches (despite the name, lower means more similar).""" 32 | return torch.sum((sent1_emb - sent2_emb) ** 2, dim=1) 33 | 34 | def get_sentence_embeddings(self, sentences: List[str], device: str = 'cuda') -> np.ndarray: 35 | """Encode all sentences in one padded batch (truncated to 512 tokens) and return their [CLS] vectors as a numpy array. NOTE(review): `device` defaults to 'cuda' and must match the device the model was moved to; pass 'cpu' on CPU-only machines.""" 36 | inputs = self.tokenizer(sentences, 37 | padding=True, 38 | truncation=True, 39 | max_length=512, 40 | return_tensors='pt') 41 | encoded = {k: v.to(device) for k, v in inputs.items()} 42 | 43 | with torch.no_grad(): 44 | sent_emb = self(input_ids=encoded['input_ids'], attention_mask=encoded['attention_mask']) 45 | outputs = sent_emb.cpu().numpy() 46 | return outputs 47 | 48 | class AffinityPropagationAlgorithm: 49 | """Affinity Propagation over content + position similarity, used to propose dialogue segment boundaries.""" 50 | 51 | def __init__(self, position_weight: float = 13.0): 52 | """ 53 | Initialize the Affinity Propagation
algorithm. 54 | 55 | Args: 56 | position_weight (float): Weight of the squared positional-distance penalty added to the content similarity. Defaults to 13.0. 57 | """ 58 | self.position_weight = position_weight 59 | self.logger = logging.getLogger(__name__) 60 | 61 | def get_cluster_centers_indices(self, inputs: np.ndarray) -> List[int]: 62 | """ 63 | Perform clustering using Affinity Propagation with similarity(i, j) = -||x_i - x_j||^2 - position_weight * (i - j)^2 and preference = min content similarity. 64 | 65 | Args: 66 | inputs (np.ndarray): Utterance embeddings, one row per utterance 67 | 68 | Returns: 69 | List[int]: Cluster-center indices after adjacent merging, restricted to 0 < idx < len(inputs); [] if clustering raises (the error is logged) 70 | """ 71 | try: 72 | content_similarity = -euclidean_distances(inputs, squared=True) 73 | 74 | positions = np.array([i for i in range(len(inputs))]).reshape(-1, 1) 75 | position_similarity = -self.position_weight * euclidean_distances(positions, squared=True) 76 | 77 | similarity_matrix = content_similarity + position_similarity 78 | 79 | ap = AffinityPropagation( 80 | preference=np.min(content_similarity), 81 | affinity='precomputed', 82 | random_state=1 83 | ).fit(similarity_matrix) 84 | 85 | cluster_centers_indices = list(ap.cluster_centers_indices_) 86 | 87 | # Collapse adjacent centers, then drop index 0 and any out-of-range index 88 | cluster_centers_indices = self._merge_adjacent(cluster_centers_indices) 89 | cluster_centers_indices = [idx for idx in cluster_centers_indices if 0 < idx < len(inputs)] 90 | 91 | # self.logger.info(f"Cluster centers indices: {cluster_centers_indices}") 92 | return cluster_centers_indices 93 | 94 | except Exception as e: 95 | self.logger.error(f"Error in clustering: {e}") 96 | return [] 97 | 98 | def _merge_adjacent(self, indices: List[int]) -> List[int]: 99 | """ 100 | Keep an index only if it exceeds the last kept index by more than 1 (assumes ascending input), e.g. [2, 3, 4, 7] -> [2, 4, 7].
101 | 102 | Args: 103 | indices (List[int]): Ascending cluster center indices 104 | 105 | Returns: 106 | List[int]: Merged indices 107 | """ 108 | if len(indices) <= 1: 109 | return indices 110 | 111 | result = [indices[0]] 112 | for idx in indices[1:]: 113 | if idx - result[-1] > 1: 114 | result.append(idx) 115 | 116 | return result 117 | 118 | def load_model_for_inference( 119 | model: DialogueBertModel, 120 | checkpoint_path: Optional[str] = None, 121 | device: str = 'cuda' if torch.cuda.is_available() else 'cpu' 122 | ) -> DialogueBertModel: 123 | """ 124 | Load model checkpoint (only if the path exists) and prepare for inference. NOTE: torch.load unpickles the file; only load trusted checkpoints. 125 | 126 | Args: 127 | model (DialogueBertModel): Model to load 128 | checkpoint_path (str, optional): Path to a checkpoint dict with a 'model_state_dict' entry; silently skipped if missing 129 | device (str): Device to load model on 130 | 131 | Returns: 132 | DialogueBertModel: Model prepared for inference 133 | """ 134 | if checkpoint_path and os.path.exists(checkpoint_path): 135 | 136 | checkpoint = torch.load(checkpoint_path, map_location=device) 137 | model.load_state_dict(checkpoint['model_state_dict']) 138 | logging.info(f"Loaded checkpoint from {checkpoint_path}") 139 | 140 | model.to(device) 141 | model.eval() # Set model to evaluation mode 142 | return model 143 | 144 | def adjust_boundaries_within_contexts(embeddings: np.ndarray, initial_boundaries: List[int]) -> List[int]: 145 | """ 146 | Adjust segment boundaries based on local similarity.
147 | 148 | Args: 149 | embeddings (np.ndarray): Sentence embeddings 150 | initial_boundaries (List[int]): Initial segment boundaries 151 | 152 | Returns: 153 | List[int]: Refined segment boundaries 154 | """ 155 | initial_boundaries = [0] + initial_boundaries + [len(embeddings)] 156 | contexts = [(initial_boundaries[i], initial_boundaries[i+2]) for i in range(len(initial_boundaries) - 2)] 157 | 158 | new_boundaries = [] 159 | new_boundary = 0 160 | 161 | for start, end in contexts: 162 | if new_boundary > 0: 163 | start = new_boundary 164 | 165 | segment_embeddings = embeddings[start:end] 166 | similarity_matrix = cosine_similarity(segment_embeddings) 167 | 168 | # Calculate pairwise similarities between adjacent sentences 169 | pairwise_similarities = [similarity_matrix[i, i + 1] for i in range(len(segment_embeddings) - 1)] 170 | 171 | # Find the position with lowest similarity as new boundary 172 | min_similarity_idx = np.argmin(pairwise_similarities) 173 | 174 | new_boundary = start + min_similarity_idx + 1 175 | new_boundaries.append(new_boundary) 176 | 177 | return new_boundaries 178 | 179 | def calculate_variance(old_center: List[int], new_center: List[int]) -> float: 180 | """ 181 | Calculate variance between old and new cluster centers. 182 | 183 | Args: 184 | old_center (List[int]): Previous cluster centers 185 | new_center (List[int]): New cluster centers 186 | 187 | Returns: 188 | float: Squared distance between centers 189 | """ 190 | old_center = np.array(old_center) 191 | new_center = np.array(new_center) 192 | squared_distances = np.sum((old_center - new_center) ** 2, axis=-1) 193 | return squared_distances 194 | 195 | def convert_to_binary_segments(contents: List[str], seg_points: List[int]) -> List[int]: 196 | """ 197 | Convert segmentation points to segment lengths. 
198 | 199 | Args: 200 | contents (List[str]): Original text contents 201 | seg_points (List[int]): Segmentation points 202 | 203 | Returns: 204 | List[int]: Segment lengths 205 | """ 206 | results_p = [] 207 | seg_p_labels = [0]*(len(contents)+1) 208 | for i in seg_points: 209 | seg_p_labels[i] = 1 210 | 211 | tmp = 0 212 | for fake in seg_p_labels: 213 | if fake == 1: 214 | tmp+=1 215 | results_p.append(tmp) 216 | tmp = 0 217 | else: 218 | tmp += 1 219 | results_p.append(tmp) 220 | results_p[0] = results_p[0] -1 221 | return results_p 222 | 223 | def main_process_plus(contents: List[str], model: DialogueBertModel, position_weight: float) -> List[int]: 224 | """ 225 | Main process for dialogue segmentation with iterative refinement. 226 | 227 | Args: 228 | contents (List[str]): Dialogue contents 229 | model (DialogueBertModel): Embedding model 230 | 231 | Returns: 232 | List[int]: Final segmentation points 233 | """ 234 | embeddings = model.get_sentence_embeddings(contents) 235 | 236 | # Initialize affinity propagation algorithm 237 | ap_algorithm = AffinityPropagationAlgorithm(position_weight) 238 | init_seg = ap_algorithm.get_cluster_centers_indices(embeddings) 239 | 240 | variance = 100 241 | iteration = 0 242 | 243 | while variance > 2: 244 | new_seg = adjust_boundaries_within_contexts(embeddings, init_seg) 245 | variance = calculate_variance(init_seg, new_seg) 246 | 247 | init_seg = new_seg 248 | iteration += 1 249 | 250 | if iteration > 5: 251 | break 252 | 253 | return init_seg 254 | 255 | def evaluate_segmentation(input_path: str, model: DialogueBertModel, position_weight: float): 256 | """ 257 | Evaluate dialogue segmentation on a dataset. 
258 | 259 | Args: 260 | input_path (str): Path to input documents 261 | model (DialogueBertModel): Embedding model 262 | """ 263 | input_files = [f for f in os.listdir(input_path) if os.path.isfile(os.path.join(input_path, f))] 264 | 265 | window_diff_scores = [] 266 | pk_scores = [] 267 | total_turns = 0 268 | error_count = 0 269 | 270 | for file_name in tqdm(input_files,desc="Reading dialogues"): 271 | try: 272 | file_path = os.path.join(input_path, file_name) 273 | index = [] 274 | contents = [] 275 | tmp = 0 276 | 277 | with open(file_path, 'r') as file: 278 | for line in file: 279 | if '================' in line.strip(): 280 | index.append(tmp) 281 | tmp = 0 282 | else: 283 | tmp += 1 284 | contents.append(line) 285 | index.append(tmp) 286 | 287 | # Perform segmentation 288 | seg_predicted = main_process_plus(contents, model, position_weight) 289 | seg_reference = index 290 | 291 | # Convert to segment lengths 292 | seg_p = convert_to_binary_segments(contents, seg_predicted) 293 | 294 | # print(f"Predicted Segments: {seg_p}") 295 | # print(f"Reference Segments: {seg_reference}") 296 | 297 | total_turns += len(seg_reference) 298 | 299 | # Evaluate segmentation 300 | window_diff_scores.append(segeval.window_diff(seg_p, seg_reference)) 301 | pk_scores.append(segeval.pk(seg_p, seg_reference)) 302 | 303 | except Exception as e: 304 | print(f"Error processing file {file_name}: {e}") 305 | error_count += 1 306 | 307 | # Calculate average scores 308 | avg_window_diff = np.mean(window_diff_scores) 309 | avg_pk = np.mean(pk_scores) 310 | 311 | print(f"Average Window Diff Score: {avg_window_diff}") 312 | print(f"Average Pk Score: {avg_pk}") 313 | print(f"Total Errors: {error_count}") 314 | 315 | def parse_args(): 316 | parser = argparse.ArgumentParser() 317 | 318 | parser.add_argument("--model_name", type=str, required=True, 319 | help="Directory containing the pkl data files") 320 | parser.add_argument("--data_path", type=str, required=True, 321 | help="Path to dialogue 
data") 322 | parser.add_argument("--checkpoint_path", type=str, required=True, 323 | help="Path to checkpoint for inference") 324 | parser.add_argument("--position_weight", type=float, default=0.01, 325 | help="position weight of AP algorithm") 326 | 327 | 328 | args = parser.parse_args() 329 | 330 | return args 331 | def main(): 332 | """ 333 | Main execution function for dialogue segmentation. 334 | """ 335 | # Configure logging 336 | logging.basicConfig(level=logging.INFO) 337 | # Load model 338 | args = parse_args() 339 | model = DialogueBertModel(model_name=args.model_name) 340 | model = load_model_for_inference( 341 | model, 342 | args.checkpoint_path 343 | ) 344 | # Set input path 345 | input_path = args.data_path 346 | 347 | position_weight = args.position_weight 348 | 349 | # Evaluate segmentation 350 | evaluate_segmentation(input_path, model, position_weight) 351 | 352 | if __name__ == "__main__": 353 | main() -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==2.4.1 2 | tensorflow==2.17.0 3 | transformers==4.44.2 4 | numpy==1.26.4 5 | tqdm==4.66.5 6 | segeval==2.0.11 7 | scikit-learn==1.5.2 8 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.utils.data import Dataset, DataLoader, RandomSampler 6 | from transformers import BertModel, BertTokenizer, AdamW, get_linear_schedule_with_warmup 7 | import numpy as np 8 | from typing import List, Dict, Optional 9 | import logging 10 | from tqdm import tqdm 11 | import argparse 12 | from dataclasses import dataclass 13 | import pickle 14 | import glob 15 | from torch.cuda.amp import autocast, GradScaler 16 | 17 | @dataclass 18 | class 
DialogueExample: 19 | dialogue_id: str 20 | sentences: List[str] 21 | embeddings: np.ndarray 22 | core_indices: List[int] 23 | 24 | @dataclass 25 | class TrainingSample: 26 | anchor_idx: int 27 | positive_indices: List[int] 28 | hard_negative_indices: List[int] 29 | regular_negative_indices: List[int] 30 | 31 | def load_pkl_data(data_dir: str) -> List[Dict]: 32 | processed_data = [] 33 | pkl_files = glob.glob(os.path.join(data_dir, "*.pkl")) 34 | 35 | if not pkl_files: 36 | raise ValueError(f"No pkl files found in {data_dir}") 37 | 38 | for pkl_file in tqdm(pkl_files, desc="Loading pkl files"): 39 | try: 40 | with open(pkl_file, 'rb') as f: 41 | data = pickle.load(f) 42 | if isinstance(data, list): 43 | processed_data.extend(data) 44 | else: 45 | processed_data.append(data) 46 | except Exception as e: 47 | logging.error(f"Error loading {pkl_file}: {str(e)}") 48 | continue 49 | 50 | logging.info(f"Loaded {len(processed_data)} samples from {len(pkl_files)} files") 51 | return processed_data 52 | 53 | class DialogueDataset(Dataset): 54 | def __init__(self, processed_data: List[Dict], tokenizer: BertTokenizer, max_length: int = 512): 55 | self.samples = [] 56 | self.tokenizer = tokenizer 57 | self.max_length = max_length 58 | 59 | for data in processed_data: 60 | example = data['example'] 61 | training_samples = data['training_samples'] 62 | 63 | for sample in training_samples: 64 | anchor_sent = sample.anchor_idx 65 | 66 | pos_sent = sample.positive_indices 67 | 68 | for neg_sent in sample.hard_negative_indices: 69 | self.samples.append({ 70 | 'anchor': anchor_sent, 71 | 'positive': pos_sent, 72 | 'negative': neg_sent, 73 | 'neg_type': 'hard' 74 | }) 75 | 76 | for neg_sent in sample.regular_negative_indices: 77 | self.samples.append({ 78 | 'anchor': anchor_sent, 79 | 'positive': pos_sent, 80 | 'negative': neg_sent, 81 | 'neg_type': 'regular' 82 | }) 83 | 84 | def __len__(self): 85 | return len(self.samples) 86 | 87 | def __getitem__(self, idx): 88 | sample = 
self.samples[idx] 89 | 90 | anchor_encoding = self.tokenizer( 91 | sample['anchor'], 92 | padding='max_length', 93 | truncation=True, 94 | max_length=self.max_length, 95 | return_tensors='pt' 96 | ) 97 | 98 | positive_encoding = self.tokenizer( 99 | sample['positive'], 100 | padding='max_length', 101 | truncation=True, 102 | max_length=self.max_length, 103 | return_tensors='pt' 104 | ) 105 | 106 | negative_encoding = self.tokenizer( 107 | sample['negative'], 108 | padding='max_length', 109 | truncation=True, 110 | max_length=self.max_length, 111 | return_tensors='pt' 112 | ) 113 | 114 | return { 115 | 'anchor_input_ids': anchor_encoding['input_ids'].squeeze(0), 116 | 'anchor_attention_mask': anchor_encoding['attention_mask'].squeeze(0), 117 | 'positive_input_ids': positive_encoding['input_ids'].squeeze(0), 118 | 'positive_attention_mask': positive_encoding['attention_mask'].squeeze(0), 119 | 'negative_input_ids': negative_encoding['input_ids'].squeeze(0), 120 | 'negative_attention_mask': negative_encoding['attention_mask'].squeeze(0), 121 | 'neg_type': 1.0 if sample['neg_type'] == 'hard' else 0.5 # hard negative have high value 122 | } 123 | 124 | @staticmethod 125 | def collect_fn(batch): 126 | return { 127 | 'anchor_input_ids': torch.stack([item['anchor_input_ids'] for item in batch]), 128 | 'anchor_attention_mask': torch.stack([item['anchor_attention_mask'] for item in batch]), 129 | 'positive_input_ids': torch.stack([item['positive_input_ids'] for item in batch]), 130 | 'positive_attention_mask': torch.stack([item['positive_attention_mask'] for item in batch]), 131 | 'negative_input_ids': torch.stack([item['negative_input_ids'] for item in batch]), 132 | 'negative_attention_mask': torch.stack([item['negative_attention_mask'] for item in batch]), 133 | 'neg_type': torch.tensor([item['neg_type'] for item in batch]) 134 | } 135 | 136 | class DialogueBertModel(nn.Module): 137 | def __init__(self, model_name: str = 'bert-base-uncased', margin: float = 0.5): 138 | 
super().__init__() 139 | self.bert = BertModel.from_pretrained(model_name) 140 | self.margin = margin 141 | 142 | def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor: 143 | outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask) 144 | return outputs.last_hidden_state[:, 0, :] 145 | 146 | def get_similarity(self, sent1_emb: torch.Tensor, sent2_emb: torch.Tensor, sigma=1.0) -> torch.Tensor: 147 | return torch.cosine_similarity(sent1_emb,sent2_emb) 148 | 149 | class Trainer: 150 | def __init__(self, 151 | model: DialogueBertModel, 152 | tokenizer: BertTokenizer, 153 | args): 154 | """ 155 | init trainer 156 | 157 | Args: 158 | model: BERT MODEL 159 | tokenizer: BERT TOKENIZER 160 | args: PARAMETERS 161 | """ 162 | self.args = args 163 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 164 | self.model = model.to(self.device) 165 | self.tokenizer = tokenizer 166 | self.scaler = GradScaler(enabled=(not args.no_amp)) 167 | 168 | logging.basicConfig( 169 | level=logging.INFO, 170 | format='%(asctime)s - %(levelname)s - %(message)s' 171 | ) 172 | self.logger = logging.getLogger(__name__) 173 | 174 | def setup_training(self, train_dataset: Dataset): 175 | 176 | train_sampler = RandomSampler(train_dataset) 177 | 178 | self.train_dataloader = DataLoader( 179 | train_dataset, 180 | sampler=train_sampler, 181 | batch_size=self.args.batch_size, 182 | collate_fn=DialogueDataset.collect_fn 183 | ) 184 | 185 | self.optimizer = AdamW( 186 | self.model.parameters(), 187 | lr=self.args.learning_rate, 188 | eps=self.args.adam_epsilon 189 | ) 190 | 191 | total_steps = len(self.train_dataloader) * self.args.num_epochs 192 | num_warmup_steps = int(total_steps * self.args.warmup_ratio) 193 | 194 | self.scheduler = get_linear_schedule_with_warmup( 195 | self.optimizer, 196 | num_warmup_steps=num_warmup_steps, 197 | num_training_steps=total_steps 198 | ) 199 | 200 | return total_steps 201 | 202 | def 
save_checkpoint(self, path: str, epoch: int, global_step: int, loss: float): 203 | """save checkpoint""" 204 | os.makedirs(os.path.dirname(path), exist_ok=True) 205 | torch.save({ 206 | 'epoch': epoch + 1, 207 | 'global_step': global_step, 208 | 'model_state_dict': self.model.state_dict(), 209 | 'optimizer_state_dict': self.optimizer.state_dict(), 210 | 'scheduler_state_dict': self.scheduler.state_dict(), 211 | 'loss': loss, 212 | }, path) 213 | self.logger.info(f"Saved checkpoint to {path}") 214 | 215 | def train(self, train_dataset: Dataset): 216 | """train model""" 217 | total_steps = self.setup_training(train_dataset) 218 | global_step = 0 219 | best_loss = float('inf') 220 | 221 | 222 | if self.args.resume and os.path.exists(self.args.resume): 223 | checkpoint = torch.load(self.args.resume, map_location=self.device) 224 | self.model.load_state_dict(checkpoint['model_state_dict']) 225 | self.optimizer.load_state_dict(checkpoint['optimizer_state_dict']) 226 | self.scheduler.load_state_dict(checkpoint['scheduler_state_dict']) 227 | global_step = checkpoint['global_step'] 228 | print(f"Resumed from checkpoint: {self.args.resume}") 229 | self.logger.info(f"Resumed from checkpoint: {self.args.resume}") 230 | 231 | for epoch in range(self.args.num_epochs): 232 | self.model.train() 233 | epoch_loss = 0 234 | 235 | for step, batch in enumerate(tqdm(self.train_dataloader, 236 | desc=f"Training Epoch {epoch + 1}")): 237 | with autocast(enabled=(not self.args.no_amp)): 238 | batch = {k: v.to(self.device) for k, v in batch.items()} 239 | 240 | 241 | anchor_emb = self.model(batch['anchor_input_ids'], 242 | batch['anchor_attention_mask']) 243 | positive_emb = self.model(batch['positive_input_ids'], 244 | batch['positive_attention_mask']) 245 | negative_emb = self.model(batch['negative_input_ids'], 246 | batch['negative_attention_mask']) 247 | 248 | 249 | pos_sim = self.model.get_similarity(anchor_emb, positive_emb) 250 | neg_sim = self.model.get_similarity(anchor_emb, 
negative_emb) 251 | 252 | diff = self.model.margin - pos_sim + neg_sim 253 | clamped_diff = torch.clamp(diff, min=0.0) 254 | loss = torch.mean(batch['neg_type'] * clamped_diff) 255 | 256 | 257 | self.scaler.scale(loss).backward() 258 | 259 | self.scaler.unscale_(self.optimizer) 260 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.max_grad_norm) 261 | 262 | self.scaler.step(self.optimizer) 263 | self.scaler.update() 264 | self.scheduler.step() 265 | self.optimizer.zero_grad() 266 | 267 | epoch_loss += loss.item() 268 | global_step += 1 269 | 270 | 271 | if self.args.save_steps > 0 and global_step % self.args.save_steps == 0: 272 | self.save_checkpoint( 273 | os.path.join(self.args.output_dir, f'checkpoint-{global_step}'), 274 | epoch, 275 | global_step, 276 | loss.item() 277 | ) 278 | 279 | 280 | avg_loss = epoch_loss / len(self.train_dataloader) 281 | self.logger.info(f"Epoch {epoch + 1}/{self.args.num_epochs}, " 282 | f"Average Loss: {avg_loss:.4f}") 283 | 284 | 285 | if avg_loss < best_loss: 286 | best_loss = avg_loss 287 | self.save_checkpoint( 288 | os.path.join(self.args.output_dir, 'best_model'), 289 | epoch, 290 | global_step, 291 | avg_loss 292 | ) 293 | 294 | def parse_args(): 295 | parser = argparse.ArgumentParser() 296 | 297 | 298 | parser.add_argument("--data_dir", type=str, required=True, 299 | help="Directory containing the pkl data files") 300 | parser.add_argument("--max_length", type=int, default=512, 301 | help="Maximum sequence length for BERT input") 302 | 303 | 304 | parser.add_argument("--model_name", type=str, default="sup-simcse-bert-base-uncased", 305 | help="Pretrained BERT model name") 306 | parser.add_argument("--margin", type=float, default=0.5, 307 | help="Margin for triplet loss") 308 | 309 | 310 | parser.add_argument("--batch_size", type=int, default=32, 311 | help="Training batch size") 312 | parser.add_argument("--learning_rate", type=float, default=2e-5, 313 | help="Initial learning rate") 314 | 
parser.add_argument("--adam_epsilon", type=float, default=1e-8, 315 | help="Epsilon for Adam optimizer") 316 | parser.add_argument("--warmup_ratio", type=float, default=0.1, 317 | help="Ratio of warmup steps") 318 | parser.add_argument("--num_epochs", type=int, default=1, 319 | help="Number of training epochs") 320 | parser.add_argument("--max_grad_norm", type=float, default=1.0, 321 | help="Maximum gradient norm for clipping") 322 | 323 | parser.add_argument("--output_dir", type=str, required=True, 324 | help="Directory to save model checkpoints") 325 | parser.add_argument("--save_steps", type=int, default=1000, 326 | help="Save checkpoint every X updates steps (0 to disable)") 327 | 328 | 329 | parser.add_argument("--resume", type=str, default=None, 330 | help="Path to checkpoint for resuming training") 331 | 332 | 333 | parser.add_argument("--seed", type=int, default=42, 334 | help="Random seed for initialization") 335 | parser.add_argument("--no_amp", action="store_true", 336 | help="Disable automatic mixed precision training") 337 | 338 | args = parser.parse_args() 339 | 340 | 341 | os.makedirs(args.output_dir, exist_ok=True) 342 | 343 | return args 344 | 345 | def main(): 346 | 347 | args = parse_args() 348 | 349 | 350 | torch.manual_seed(args.seed) 351 | np.random.seed(args.seed) 352 | 353 | 354 | tokenizer = BertTokenizer.from_pretrained(args.model_name) 355 | model = DialogueBertModel(model_name=args.model_name, margin=args.margin) 356 | 357 | 358 | processed_data = load_pkl_data(args.data_dir) 359 | train_dataset = DialogueDataset( 360 | processed_data=processed_data, 361 | tokenizer=tokenizer, 362 | max_length=args.max_length 363 | ) 364 | 365 | 366 | trainer = Trainer(model, tokenizer,args) 367 | 368 | 369 | trainer.train(train_dataset) 370 | 371 | if __name__ == "__main__": 372 | main() --------------------------------------------------------------------------------