├── data_indices ├── data_index_sciq.npy ├── data_index_tqa.npy ├── exemplar_idx_tqa.npy ├── data_index_nq_open.npy ├── exemplar_idx_sciq.npy ├── data_index_triviaqa.npy ├── exemplar_idx_nq_open.npy └── exemplar_idx_triviaqa.npy ├── __pycache__ ├── cache_utils.cpython-38.pyc ├── llm_layers.cpython-38.pyc ├── train_utils.cpython-38.pyc └── sinkhorn_knopp.cpython-38.pyc ├── train.sh ├── gen.sh ├── gt.sh ├── README.md ├── sinkhorn_knopp.py ├── cache_utils.py ├── train_utils.py ├── llm_layers.py ├── requirements.txt ├── LICENSE ├── tsv.yml └── tsv_main.py /data_indices/data_index_sciq.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deeplearning-wisc/tsv/HEAD/data_indices/data_index_sciq.npy -------------------------------------------------------------------------------- /data_indices/data_index_tqa.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deeplearning-wisc/tsv/HEAD/data_indices/data_index_tqa.npy -------------------------------------------------------------------------------- /data_indices/exemplar_idx_tqa.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deeplearning-wisc/tsv/HEAD/data_indices/exemplar_idx_tqa.npy -------------------------------------------------------------------------------- /data_indices/data_index_nq_open.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deeplearning-wisc/tsv/HEAD/data_indices/data_index_nq_open.npy -------------------------------------------------------------------------------- /data_indices/exemplar_idx_sciq.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deeplearning-wisc/tsv/HEAD/data_indices/exemplar_idx_sciq.npy -------------------------------------------------------------------------------- /__pycache__/cache_utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deeplearning-wisc/tsv/HEAD/__pycache__/cache_utils.cpython-38.pyc -------------------------------------------------------------------------------- /__pycache__/llm_layers.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deeplearning-wisc/tsv/HEAD/__pycache__/llm_layers.cpython-38.pyc -------------------------------------------------------------------------------- /__pycache__/train_utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deeplearning-wisc/tsv/HEAD/__pycache__/train_utils.cpython-38.pyc -------------------------------------------------------------------------------- /data_indices/data_index_triviaqa.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deeplearning-wisc/tsv/HEAD/data_indices/data_index_triviaqa.npy -------------------------------------------------------------------------------- /data_indices/exemplar_idx_nq_open.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deeplearning-wisc/tsv/HEAD/data_indices/exemplar_idx_nq_open.npy -------------------------------------------------------------------------------- /data_indices/exemplar_idx_triviaqa.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deeplearning-wisc/tsv/HEAD/data_indices/exemplar_idx_triviaqa.npy -------------------------------------------------------------------------------- /__pycache__/sinkhorn_knopp.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deeplearning-wisc/tsv/HEAD/__pycache__/sinkhorn_knopp.cpython-38.pyc -------------------------------------------------------------------------------- /train.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python tsv_main.py --model_name llama3.1-8B --dataset_name tqa --most_likely 1 > train.log 2>&1 & 2 | 3 | -------------------------------------------------------------------------------- /gen.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python tsv_main.py --model_name llama3.1-8B --dataset_name tqa --most_likely 1 --num_gene 1 --gene 1 > gen.log 2>&1 & 2 | 3 | -------------------------------------------------------------------------------- /gt.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python tsv_main.py --model_name llama3.1-8B --dataset_name tqa --most_likely 1 --num_gene 1 --generate_gt 1 > gt.log 2>&1 & 2 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TSV 2 | 3 | 4 | Source code for ICML 2025 paper [Steer LLM Latents for Hallucination Detection](https://arxiv.org/abs/2503.01917) by Seongheon Park, Xuefeng Du, Min-Hsuan Yeh, Haobo Wang, and Yixuan Li 5 | 6 | --- 7 | 8 | ## Requirements 9 | 10 | ```bash 11 | conda env create -f tsv.yml 12 | ``` 13 | --- 14 | 15 | ## LLM response generation 16 | 17 | Generate responses for each question to construct an unlabeled QA dataset in the wild. 18 | 19 | ```bash 20 | bash gen.sh 21 | ``` 22 | 23 | --- 24 | 25 | ## GT generation 26 | 27 | Generate [BLEURT](https://arxiv.org/abs/2004.04696) score for each QA pair 28 | 29 | 30 | ```bash 31 | bash gt.sh 32 | ``` 33 | 34 | --- 35 | 36 | ## Train TSV 37 | 38 | Train TSV for hallucination detection. 39 | 40 | ```bash 41 | bash train.sh 42 | ``` 43 | 44 | --- 45 | 46 | ## Citation 47 | 48 | ``` 49 | @inproceedings{ 50 | park2025steer, 51 | title={Steer {LLM} Latents for Hallucination Detection}, 52 | author={Seongheon Park and Xuefeng Du and Min-Hsuan Yeh and Haobo Wang and Yixuan Li}, 53 | booktitle={Forty-second International Conference on Machine Learning}, 54 | year={2025} 55 | } 56 | ``` 57 | 58 | --- 59 | 60 | ## Acknowledgement 61 | 62 | We gratefully acknowledge [HaloScope](https://arxiv.org/abs/2409.17504), [ITI](https://arxiv.org/abs/2306.03341), and [ICV](https://arxiv.org/abs/2311.06668) for their inspiring ideas and open-source contributions, which served as valuable foundations for this work. 63 | -------------------------------------------------------------------------------- /sinkhorn_knopp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn.functional as F 4 | 5 | # https://github.com/facebookresearch/swav/blob/main/main_swav.py 6 | 7 | def shoot_infs(inp_tensor): 8 | """Replaces inf by maximum of tensor""" 9 | mask_inf = torch.isinf(inp_tensor) 10 | ind_inf = torch.nonzero(mask_inf) 11 | if len(ind_inf) > 0: 12 | for ind in ind_inf: 13 | if len(ind) == 2: 14 | inp_tensor[ind[0], ind[1]] = 0 15 | elif len(ind) == 1: 16 | inp_tensor[ind[0]] = 0 17 | m = torch.max(inp_tensor) 18 | for ind in ind_inf: 19 | if len(ind) == 2: 20 | inp_tensor[ind[0], ind[1]] = m 21 | elif len(ind) == 1: 22 | inp_tensor[ind[0]] = m 23 | return inp_tensor 24 | 25 | 26 | class SinkhornKnopp_imb(torch.nn.Module): 27 | def __init__(self, args, cls_dist): 28 | super().__init__() 29 | self.num_iters = args.num_iters_sk 30 | self.epsilon = args.epsilon_sk 31 | self.temperature = args.cos_temp 32 | self.cls_dist = cls_dist 33 | 34 | @torch.no_grad() 35 | def iterate(self, Q): 36 | 37 | Q = shoot_infs(Q) 38 | sum_Q = torch.sum(Q) 39 | Q /= sum_Q 40 | 41 | B = Q.shape[1] 42 | K = Q.shape[0] 43 | 44 | for it in range(self.num_iters): 45 | 46 | sum_of_rows = torch.sum(Q, dim=1, keepdim=True) 47 | Q /= sum_of_rows 48 | Q = shoot_infs(Q) 49 | Q *= self.cls_dist 50 | 51 | # normalize each column: total weight per sample must be 1/B 52 | Q /= torch.sum(Q, dim=0, keepdim=True) 53 | Q /= B 54 | 55 | Q *= B # the colomns must sum to 1 so that Q is an assignment 56 | 57 | return Q.t() 58 | 59 | @torch.no_grad() 60 | def forward(self, embeddings, centroids): 61 | 62 | 63 | last_token_rep = F.normalize(embeddings, p=2, dim=-1) 64 | centroids = F.normalize(centroids, p=2, dim=-1) 65 | 66 | # Compute cosine similarity (which is equivalent to the dot product for normalized vectors) 67 | similarities = torch.matmul(last_token_rep, centroids.T) 68 | 69 | # Apply the temperature scaling factor (similar to dividing by τ in the equation) 70 | similarities = similarities / self.temperature 71 | 72 | # Convert similarities to probability distributions using softmax 73 | pt = F.softmax(similarities, dim=-1) 74 | 75 | # Compute the OT loss as the cross-entropy between pseudo-labels and pt 76 | pt = torch.log(pt + 1e-8) 77 | 78 | # Divide by temperature (epsilon) to scale the distance 79 | q = pt / (self.epsilon) 80 | 81 | # Apply exponential to form soft assignment weights 82 | q = torch.exp(q).t() 83 | 84 | return self.iterate(q) -------------------------------------------------------------------------------- /cache_utils.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import importlib.metadata 3 | import json 4 | import os 5 | from dataclasses import dataclass 6 | from typing import Any, Dict, List, Optional, Tuple, Union 7 | 8 | import torch 9 | from packaging import version 10 | 11 | # from .configuration_utils import PretrainedConfig 12 | # from .utils import ( 13 | # is_hqq_available, 14 | # is_optimum_quanto_available, 15 | # is_torchdynamo_compiling, 16 | # logging, 17 | # ) 18 | # from .utils.deprecation import deprecate_kwarg 19 | 20 | 21 | # if is_hqq_available(): 22 | # from hqq.core.quantize import Quantizer as HQQQuantizer 23 | 24 | # logger = logging.get_logger(__name__) 25 | 26 | 27 | class Cache(torch.nn.Module): 28 | """ 29 | Base, abstract class for all caches. The actual data structure is specific to each subclass. 30 | """ 31 | 32 | def __init__(self): 33 | super().__init__() 34 | 35 | def update( 36 | self, 37 | key_states: torch.Tensor, 38 | value_states: torch.Tensor, 39 | layer_idx: int, 40 | cache_kwargs: Optional[Dict[str, Any]] = None, 41 | ) -> Tuple[torch.Tensor, torch.Tensor]: 42 | """ 43 | Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. 44 | 45 | Parameters: 46 | key_states (`torch.Tensor`): 47 | The new key states to cache. 48 | value_states (`torch.Tensor`): 49 | The new value states to cache. 50 | layer_idx (`int`): 51 | The index of the layer to cache the states for. 52 | cache_kwargs (`Dict[str, Any]`, `optional`): 53 | Additional arguments for the cache subclass. These are specific to each subclass and allow new types of 54 | cache to be created. 55 | 56 | Return: 57 | A tuple containing the updated key and value states. 58 | """ 59 | raise NotImplementedError("Make sure to implement `update` in a subclass.") 60 | 61 | def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: 62 | """Returns the sequence length of the cached states. A layer index can be optionally passed.""" 63 | # TODO: deprecate this function in favor of `cache_position` 64 | raise NotImplementedError("Make sure to implement `get_seq_length` in a subclass.") 65 | 66 | # Deprecate in favor of max-cache-shape because we want to be specifc by what we mean with "max_length" 67 | # Prev some cache objects didn't have "max_length" (SlidingWindowCache or SinkCache) because the cache object technically handles 68 | # infinite amount of tokens. In the codebase what we really need to check is the max capacity of certain cache instances, so 69 | # we change naming to be more explicit 70 | def get_max_length(self) -> Optional[int]: 71 | # logger.warning_once( 72 | # "`get_max_cache()` is deprecated for all Cache classes. Use `get_max_cache_shape()` instead. " 73 | # "Calling `get_max_cache()` will raise error from v4.48" 74 | # ) 75 | return self.get_max_cache_shape() 76 | 77 | def get_max_cache_shape(self) -> Optional[int]: 78 | """Returns the maximum sequence length (i.e. max capacity) of the cache object""" 79 | raise NotImplementedError("Make sure to implement `get_max_cache_shape` in a subclass.") 80 | 81 | def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int: 82 | """Given the sequence length of the new inputs, returns the usable length of the cache.""" 83 | # Cache without size limit -> all cache is usable 84 | # Cache with size limit -> if the length cache plus the length of the new inputs is larger the maximum cache 85 | # length, we will need to evict part of the cache (and thus not all cache is usable) 86 | max_length = self.get_max_cache_shape() 87 | previous_seq_length = self.get_seq_length(layer_idx) 88 | if max_length is not None and previous_seq_length + new_seq_length > max_length: 89 | return max_length - new_seq_length 90 | return previous_seq_length 91 | 92 | def reorder_cache(self, beam_idx: torch.LongTensor): 93 | """Reorders the cache for beam search, given the selected beam indices.""" 94 | for layer_idx in range(len(self.key_cache)): 95 | if self.key_cache[layer_idx] != []: 96 | device = self.key_cache[layer_idx].device 97 | self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device)) 98 | if self.value_cache[layer_idx] != []: 99 | device = self.value_cache[layer_idx].device 100 | self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) 101 | 102 | @property 103 | def seen_tokens(self): 104 | # logger.warning_once( 105 | # "The `seen_tokens` attribute is deprecated and will be removed in v4.41. Use the `cache_position` " 106 | # "model input instead." 107 | # ) 108 | if hasattr(self, "_seen_tokens"): 109 | return self._seen_tokens 110 | else: 111 | return None 112 | 113 | 114 | -------------------------------------------------------------------------------- /train_utils.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | from tqdm import tqdm 4 | from torch.cuda.amp import autocast 5 | import torch.nn.functional as F 6 | 7 | 8 | 9 | def collate_fn(prompts, labels): 10 | 11 | # Find the maximum sequence length in the batch 12 | max_seq_len = max(prompt.size(1) for prompt in prompts) 13 | 14 | # Initialize a tensor to hold the batched prompts 15 | batch_size = len(prompts) 16 | dtype = prompts[0].dtype 17 | device = prompts[0].device # Assuming all prompts are on the same device 18 | prompts_padded = torch.zeros(batch_size, 1, max_seq_len, dtype=dtype) 19 | 20 | # Pad each prompt to the maximum sequence length 21 | for i, prompt in enumerate(prompts): 22 | seq_len = prompt.size(1) 23 | prompts_padded[i, :, :seq_len] = prompt 24 | 25 | # Stack labels into a tensor 26 | labels = torch.tensor(labels, dtype=torch.long, device=device) 27 | 28 | return prompts_padded, labels 29 | 30 | 31 | def get_last_non_padded_token_rep(hidden_states, attention_mask): 32 | """ 33 | Get the last non-padded token's representation for each sequence in the batch. 34 | """ 35 | # Find the length of each sequence by summing the attention mask (1 for real tokens, 0 for padding) 36 | lengths = attention_mask.squeeze().sum(dim=1).long() 37 | 38 | # Index the last non-padded token for each sequence 39 | batch_size, max_seq_len, hidden_size = hidden_states.size() 40 | last_token_reps = torch.stack([hidden_states[i, lengths[i]-1, :] for i in range(batch_size)]) 41 | 42 | return last_token_reps 43 | 44 | 45 | def get_ex_data(model, prompts, labels, batch_size, centroids, sinkhorn, num_selected_data, cls_dist, args): 46 | 47 | all_embeddings = [] 48 | all_labels = [] 49 | num_samples = len(prompts) 50 | 51 | with torch.no_grad(): 52 | with autocast(dtype=torch.float16): 53 | for batch_start in tqdm(range(0, num_samples, batch_size)): 54 | batch_prompts = prompts[batch_start: batch_start + batch_size] 55 | batch_labels = labels[batch_start: batch_start + batch_size] 56 | batch_prompts, batch_labels = collate_fn(batch_prompts,batch_labels) 57 | attention_mask = (batch_prompts != 0).half() 58 | batch_prompts = batch_prompts.cuda() 59 | batch_labels = batch_labels.cuda() 60 | attention_mask = attention_mask.to(batch_prompts.device) 61 | all_labels.append(batch_labels.cpu().numpy()) 62 | 63 | output = model(batch_prompts.squeeze(), attention_mask=attention_mask.squeeze(), output_hidden_states=True) 64 | hidden_states = output.hidden_states 65 | 66 | hidden_states = torch.stack(hidden_states, dim=0).squeeze() 67 | last_layer_hidden_state = hidden_states[-1] 68 | 69 | last_token_rep = get_last_non_padded_token_rep(last_layer_hidden_state, attention_mask.squeeze()) 70 | all_embeddings.append(last_token_rep) 71 | 72 | all_embeddings = F.normalize(torch.concat(all_embeddings),p=2,dim=-1) 73 | 74 | pseudo_label = sinkhorn(all_embeddings, centroids) 75 | 76 | selected_indices = compute_entropy(all_embeddings, centroids, pseudo_label, num_selected_data, cls_dist, args) 77 | 78 | selected_labels_soft = pseudo_label[selected_indices] 79 | 80 | 81 | return selected_indices, selected_labels_soft 82 | 83 | 84 | def compute_ot_loss_cos(last_token_rep, centroids, pseudo_label, batch_size, args): 85 | 86 | last_token_rep = F.normalize(last_token_rep, p=2, dim=-1) 87 | 88 | centroids = F.normalize(centroids, p=2, dim=-1) 89 | 90 | similarities = torch.matmul(last_token_rep, centroids.T) 91 | 92 | similarities = similarities / args.cos_temp 93 | 94 | pt = F.softmax(similarities, dim=-1) 95 | 96 | ot_loss = -torch.sum(pseudo_label * torch.log(pt + 1e-8)) / pseudo_label.shape[0] 97 | 98 | return ot_loss, similarities 99 | 100 | 101 | def compute_entropy(last_token_rep, centroids, pseudo_label, k, cls_dist, args): 102 | 103 | 104 | last_token_rep = F.normalize(last_token_rep, p=2, dim=-1) 105 | 106 | centroids = F.normalize(centroids, p=2, dim=-1) 107 | 108 | similarities = torch.matmul(last_token_rep, centroids.T) 109 | 110 | similarities = similarities / args.cos_temp 111 | 112 | pt = F.softmax(similarities, dim=-1) 113 | 114 | ce = - (pseudo_label * torch.log(pt + 1e-8)) 115 | 116 | pseudo_label_hard = torch.argmax(pt,dim=1) 117 | 118 | # * Added for preventing severe cases 119 | # Class-wise data selection: Select pseudo-labeled unlabeled data in proportion to the class distribution of the exemplar set. 120 | 121 | cls0_num = k*cls_dist[0] 122 | cls1_num = k*cls_dist[1] 123 | 124 | cls_0_indices = (pseudo_label_hard == 0).nonzero(as_tuple=True)[0] 125 | cls_1_indices = (pseudo_label_hard == 1).nonzero(as_tuple=True)[0] 126 | 127 | ce = torch.sum(ce, dim=1) 128 | 129 | ce_class_0 = ce[cls_0_indices] 130 | ce_class_1 = ce[cls_1_indices] 131 | 132 | if len(ce_class_0) < cls0_num or len(ce_class_1) < cls1_num: # Fallback to top-k across all classes 133 | 134 | _, top_k_indices = torch.topk(ce, k, largest=False, sorted=True) 135 | 136 | else: 137 | 138 | top_0_indices = cls_0_indices[torch.topk(ce_class_0, int(cls0_num), largest=False, sorted=True).indices] 139 | top_1_indices = cls_1_indices[torch.topk(ce_class_1, int(cls1_num), largest=False, sorted=True).indices] 140 | top_k_indices = torch.cat((top_0_indices, top_1_indices)) 141 | 142 | return top_k_indices 143 | 144 | 145 | def update_centroids_ema(centroids, last_token_rep, pseudo_label, args): 146 | 147 | last_token_rep_norm = F.normalize(last_token_rep, p=2, dim=1) 148 | 149 | centroids= F.normalize(centroids, p=2, dim=1) 150 | 151 | weighted_sum = torch.matmul(pseudo_label.T, last_token_rep_norm) 152 | 153 | # Normalize the weighted sums to get the new centroids 154 | pseudo_label_sum = pseudo_label.sum(dim=0).unsqueeze(1) + 1e-8 155 | new_centroids_batch = weighted_sum / pseudo_label_sum 156 | 157 | # EMA update for centroids 158 | updated_centroids = F.normalize(args.ema_decay * centroids + (1 - args.ema_decay) * new_centroids_batch, p=2, dim=1) 159 | 160 | return updated_centroids 161 | 162 | def update_centroids_ema_hard(centroids, last_token_rep, pseudo_label, args): 163 | 164 | last_token_rep_norm = F.normalize(last_token_rep, p=2, dim=1) 165 | 166 | centroids = F.normalize(centroids, p=2, dim=1) 167 | 168 | max_indices = torch.argmax(pseudo_label, dim=1) 169 | 170 | discrete_labels = torch.zeros_like(pseudo_label) 171 | 172 | discrete_labels[torch.arange(pseudo_label.size(0)), max_indices] = 1 173 | 174 | weighted_sum = torch.matmul(discrete_labels.T.float(), last_token_rep_norm) 175 | 176 | pseudo_label_sum = discrete_labels.sum(dim=0).unsqueeze(1) + 1e-8 177 | 178 | new_centroids_batch = weighted_sum / pseudo_label_sum 179 | 180 | # EMA update for centroids 181 | updated_centroids = F.normalize(args.ema_decay * centroids + (1 - args.ema_decay) * new_centroids_batch, p=2, dim=-1) 182 | 183 | return updated_centroids -------------------------------------------------------------------------------- /llm_layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import functional as F 3 | from torch import nn 4 | from transformers import PreTrainedModel 5 | from torch import Tensor 6 | import numpy as np 7 | from typing import Optional, Tuple 8 | from cache_utils import Cache 9 | from transformers.activations import ACT2FN 10 | 11 | class LlamaDecoderLayerWrapper(nn.Module): 12 | def __init__(self, llama_decoder_layer, tsv_layer, model_name='llama3.1-8B'): 13 | super().__init__() 14 | self.llama_decoder_layer = llama_decoder_layer 15 | self.tsv_layer = tsv_layer # Instance of ICVLayer 16 | self.model_name = model_name 17 | 18 | def forward( 19 | self, 20 | hidden_states: torch.Tensor, 21 | attention_mask: Optional[torch.Tensor] = None, 22 | position_ids: Optional[torch.LongTensor] = None, 23 | past_key_value: Optional[Cache] = None, 24 | output_attentions: Optional[bool] = False, 25 | use_cache: Optional[bool] = False, 26 | cache_position: Optional[torch.LongTensor] = None, 27 | position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, 28 | **kwargs, 29 | )-> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: 30 | # Save original residual state 31 | residual = hidden_states 32 | 33 | # Forward pass through the input layer norm 34 | hidden_states = self.llama_decoder_layer.input_layernorm(hidden_states) 35 | 36 | 37 | if self.model_name == 'qwen2.5-7B': 38 | hidden_states, self_attn_weights, present_key_value = self.llama_decoder_layer.self_attn( 39 | hidden_states=hidden_states, 40 | attention_mask=attention_mask, 41 | position_ids=position_ids, 42 | past_key_value=past_key_value, 43 | output_attentions=output_attentions, 44 | use_cache=use_cache, 45 | cache_position=cache_position, 46 | **kwargs, 47 | ) 48 | 49 | else: 50 | hidden_states, self_attn_weights, present_key_value = self.llama_decoder_layer.self_attn( 51 | hidden_states=hidden_states, 52 | attention_mask=attention_mask, 53 | position_ids=position_ids, 54 | past_key_value=past_key_value, 55 | output_attentions=output_attentions, 56 | use_cache=use_cache, 57 | cache_position=cache_position, 58 | position_embeddings=position_embeddings, 59 | **kwargs, 60 | ) 61 | 62 | # Add residual + steering vector after self-attention 63 | hidden_states = residual.to(hidden_states.device) + hidden_states 64 | 65 | 66 | # Save residual state for the MLP 67 | residual = hidden_states 68 | 69 | # Forward pass through the post-attention layer norm and MLP 70 | hidden_states = self.llama_decoder_layer.post_attention_layernorm(hidden_states) 71 | hidden_states = self.llama_decoder_layer.mlp(hidden_states) 72 | 73 | # Add residual + steering vector after MLP 74 | hidden_states = residual + hidden_states 75 | hidden_states = self.tsv_layer(hidden_states) # Add steering vector 76 | 77 | # Return the outputs 78 | outputs = (hidden_states,) 79 | if output_attentions: 80 | outputs += (self_attn_weights,) 81 | if use_cache: 82 | outputs += (present_key_value,) 83 | 84 | return outputs 85 | 86 | class TSVLayer(nn.Module): 87 | 88 | def __init__(self, tsv, lam): 89 | super(TSVLayer, self).__init__() 90 | self.tsv = tsv 91 | self.lam = lam 92 | 93 | def forward(self, x): 94 | if self.tsv is not None: 95 | 96 | x = x.half() 97 | y = self.lam[0] * self.tsv.repeat(1,x.shape[1],1) 98 | y = y.to(x.device) 99 | x = x.half() + y 100 | 101 | return x.half() 102 | 103 | else: 104 | 105 | return x.half() 106 | 107 | 108 | def get_nested_attr(obj, attr_path): 109 | attrs = attr_path.split(".") 110 | for attr in attrs: 111 | obj = getattr(obj, attr) 112 | return obj 113 | 114 | 115 | def set_nested_attr(obj, attr_path, value): 116 | attrs = attr_path.split(".") 117 | parent = get_nested_attr(obj, ".".join(attrs[:-1])) 118 | setattr(parent, attrs[-1], value) 119 | 120 | 121 | def find_longest_modulelist(model, path=""): 122 | """ 123 | Recursively find the longest nn.ModuleList in a PyTorch model. 124 | Args: 125 | model: PyTorch model. 126 | path: Current path in the model (used for recursion). 127 | Returns: 128 | Tuple with path and length of the longest nn.ModuleList found. 129 | """ 130 | longest_path = path 131 | longest_len = 0 132 | 133 | for name, child in model.named_children(): 134 | if isinstance(child, nn.ModuleList) and len(child) > longest_len: 135 | longest_len = len(child) 136 | longest_path = f"{path}.{name}" if path else name 137 | 138 | # Recursively check the child's children 139 | child_path, child_len = find_longest_modulelist(child, f"{path}.{name}" if path else name) 140 | if child_len > longest_len: 141 | longest_len = child_len 142 | longest_path = child_path 143 | 144 | return longest_path, longest_len 145 | 146 | 147 | def find_module(block, keywords): 148 | """ 149 | Try to find a module in a transformer block. 150 | Args: 151 | block: Transformer block (nn.Module). 152 | keywords: List of possible module names (str). 153 | Returns: 154 | The found module if found, else None. 155 | """ 156 | 157 | for name, module in block.named_modules(): 158 | if any(keyword in name for keyword in keywords): 159 | return module 160 | submodule_names = [name for name, _ in block.named_modules()] 161 | raise ValueError(f"Could not find keywords {keywords} in: {submodule_names}") 162 | 163 | 164 | def get_embedding_layer(model: PreTrainedModel): 165 | 166 | keywords = ["emb", "wte"] 167 | return find_module(model, keywords) 168 | 169 | 170 | def get_lm_head(model: PreTrainedModel): 171 | keywords = ["lm_head", "embed_out"] 172 | return find_module(model, keywords) 173 | 174 | 175 | def get_lm_pipeline(model: PreTrainedModel): 176 | model_class = model.__class__.__name__ 177 | 178 | if model_class == "LlamaForCausalLM": 179 | return nn.Sequential(model.model.norm, model.lm_head) 180 | elif model_class == "RWForCausalLM": 181 | return nn.Sequential(model.transformer.ln_f, model.lm_head) 182 | elif model_class == "GPTNeoForCausalLM": 183 | return nn.Sequential(model.transformer.ln_f, model.lm_head) 184 | elif model_class == "GPTNeoXForCausalLM": 185 | return nn.Sequential(model.gpt_neox.final_layer_norm, model.embed_out) 186 | 187 | # TODO: make the default case more robust 188 | return get_lm_head(model) 189 | 190 | 191 | def get_layers_path(model: PreTrainedModel): 192 | longest_path, longest_len = find_longest_modulelist(model) 193 | return longest_path 194 | 195 | 196 | def get_layers(model: PreTrainedModel): 197 | longest_path = get_layers_path(model) 198 | return get_nested_attr(model, longest_path) 199 | 200 | def get_mlp_layers(model: PreTrainedModel): 201 | layers = get_layers(model) 202 | mlp_keywords = ["mlp", "feedforward", "ffn"] 203 | mlp_layers = [find_module(layer, mlp_keywords) for layer in layers] 204 | return mlp_layers 205 | 206 | def add_tsv_layers(model: PreTrainedModel, tsv: Tensor, alpha: list, args): 207 | layers = get_layers(model) 208 | mlp_keywords = ["mlp", "feedforward", "ffn"] 209 | attn_keywords = ["self_attn"] 210 | 211 | assert len(tsv) == len(layers) 212 | if args.component == 'mlp': 213 | for i, layer in enumerate(layers): 214 | if i == args.str_layer: 215 | original_mlp = find_module(layer, mlp_keywords) 216 | layer.mlp = nn.Sequential(original_mlp, TSVLayer(tsv[i], alpha)) 217 | 218 | elif args.component == 'attn': 219 | for i, layer in enumerate(layers): 220 | if i == args.str_layer: 221 | original_attn = find_module(layer, attn_keywords) 222 | layer.self_attn = nn.Sequential(original_attn, TSVLayer(tsv[i], alpha)) 223 | 224 | elif args.component == 'res': 225 | 226 | for i, layer in enumerate(layers): 227 | if i == args.str_layer: 228 | decoder_layer = layers[i] 229 | layers[i] = LlamaDecoderLayerWrapper(decoder_layer, TSVLayer(tsv[i], alpha), args.model_name) 230 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # This file may be used to create an environment using: 2 | # $ conda create --name --file 3 | # platform: linux-64 4 | # created-by: conda 25.1.1 5 | _libgcc_mutex=0.1=main 6 | _openmp_mutex=5.1=1_gnu 7 | absl-py=1.4.0=pypi_0 8 | accelerate=0.21.0=pypi_0 9 | aiohttp=3.8.3=pypi_0 10 | aiosignal=1.3.1=pypi_0 11 | annotated-types=0.7.0=pypi_0 12 | anyio=3.5.0=py38h06a4308_0 13 | argon2-cffi=21.3.0=pyhd3eb1b0_0 14 | argon2-cffi-bindings=21.2.0=py38h7f8727e_0 15 | array-record=0.2.0=pypi_0 16 | asttokens=2.0.5=pyhd3eb1b0_0 17 | astunparse=1.6.3=pypi_0 18 | async-timeout=4.0.2=pypi_0 19 | attrs=22.1.0=py38h06a4308_0 20 | babel=2.9.1=pyhd3eb1b0_0 21 | backcall=0.2.0=pyhd3eb1b0_0 22 | baukit=0.0.1=pypi_0 23 | beautifulsoup4=4.11.1=py38h06a4308_0 24 | bert-score=0.3.13=pypi_0 25 | bitsandbytes=0.45.4=pypi_0 26 | blas=1.0=mkl 27 | bleach=4.1.0=pyhd3eb1b0_0 28 | blessed=1.20.0=pypi_0 29 | bleurt=0.0.2=pypi_0 30 | bleurt-pytorch=0.0.1=pypi_0 31 | brotlipy=0.7.0=py38h27cfd23_1003 32 | bzip2=1.0.8=h7b6447c_0 33 | ca-certificates=2024.11.26=h06a4308_0 34 | cachetools=5.3.0=pypi_0 35 | catalogue=2.0.10=py38h06a4308_0 36 | certifi=2024.8.30=py38h06a4308_0 37 | cffi=1.15.1=py38h74dc2b5_0 38 | charset-normalizer=2.1.1=pypi_0 39 | click=8.1.3=pypi_0 40 | cloudpathlib=0.16.0=py38h06a4308_1 41 | colorama=0.4.6=py38h06a4308_0 42 | commonmark=0.9.1=pyhd3eb1b0_0 43 | compressed-tensors=0.9.2=pypi_0 44 | confection=0.1.4=py38h2f386ee_0 45 | contourpy=1.0.7=pypi_0 46 | cryptography=38.0.1=py38h9ce1e76_0 47 | cuda=11.6.2=0 48 | cuda-cccl=11.6.55=hf6102b2_0 49 | cuda-command-line-tools=11.6.2=0 50 | cuda-compiler=11.6.2=0 51 | cuda-cudart=11.8.89=0 52 | cuda-cudart-dev=11.6.55=h42ad0f4_0 53 | cuda-cuobjdump=11.6.124=h2eeebcb_0 54 | cuda-cupti=11.8.87=0 55 | cuda-cuxxfilt=11.6.124=hecbf4f6_0 56 | cuda-driver-dev=11.6.55=0 57 | cuda-gdb=12.5.39=0 58 | cuda-libraries=11.8.0=0 59 | cuda-libraries-dev=11.6.2=0 60 | cuda-memcheck=11.8.86=0 61 | cuda-nsight=12.5.39=0 62 | cuda-nsight-compute=11.8.0=0 63 | cuda-nvcc=11.6.124=hbba6d2d_0 64 | cuda-nvdisasm=12.5.39=0 65 | cuda-nvml-dev=11.6.55=haa9ef22_0 66 | cuda-nvprof=12.5.39=0 67 | cuda-nvprune=11.6.124=he22ec0a_0 68 | cuda-nvrtc=11.8.89=0 69 | cuda-nvrtc-dev=11.6.124=h249d397_0 70 | cuda-nvtx=11.8.86=0 71 | cuda-nvvp=12.5.39=0 72 | cuda-runtime=11.8.0=0 73 | cuda-samples=11.6.101=h8efea70_0 74 | cuda-sanitizer-api=12.5.39=0 75 | cuda-toolkit=11.6.2=0 76 | cuda-tools=11.6.2=0 77 | cuda-version=12.5=3 78 | cuda-visual-tools=11.6.2=0 79 | cycler=0.11.0=pypi_0 80 | cymem=2.0.6=py38h295c915_0 81 | cython-blis=0.7.9=py38h7deecbd_0 82 | dacite=1.8.1=pypi_0 83 | dataclasses=0.8=pyh6d0b6a4_7 84 | datasets=2.17.1=pypi_0 85 | debugpy=1.5.1=py38h295c915_0 86 | decorator=5.1.1=pyhd3eb1b0_0 87 | defusedxml=0.7.1=pyhd3eb1b0_0 88 | dill=0.3.6=pypi_0 89 | distro=1.9.0=pypi_0 90 | dm-tree=0.1.8=pypi_0 91 | easydict=1.13=pypi_0 92 | einops=0.3.2=pypi_0 93 | entrypoints=0.4=py38h06a4308_0 94 | et-xmlfile=1.1.0=pypi_0 95 | etils=1.3.0=pypi_0 96 | evaluate=0.3.0=pypi_0 97 | exceptiongroup=1.2.2=pypi_0 98 | executing=0.8.3=pyhd3eb1b0_0 99 | fairscale=0.4.13=pypi_0 100 | fancy-einsum=0.0.3=pypi_0 101 | ffmpeg=4.3=hf484d3e_0 102 | filelock=3.8.0=pypi_0 103 | fire=0.5.0=pypi_0 104 | flash-attn=2.6.1=pypi_0 105 | flatbuffers=2.0.7=pypi_0 106 | fonttools=4.39.4=pypi_0 107 | freetype=2.12.1=h4a9f257_0 108 | frozenlist=1.3.3=pypi_0 109 | fsspec=2023.10.0=pypi_0 110 | future=0.18.3=py38h06a4308_0 111 | gast=0.4.0=pypi_0 112 | gds-tools=1.4.0.31=0 113 | geotorch=0.3.0=pypi_0 114 | giflib=5.2.1=h7b6447c_0 115 | gin-config=0.5.0=pypi_0 116 | gmp=6.2.1=h295c915_3 117 | gmpy2=2.1.2=py38heeb90bb_0 118 | gnutls=3.6.15=he1e5248_0 119 | google-auth=2.18.0=pypi_0 120 | google-auth-oauthlib=1.0.0=pypi_0 121 | google-pasta=0.2.0=pypi_0 122 | googleapis-common-protos=1.59.0=pypi_0 123 | gpustat=1.1.1=pypi_0 124 | grpcio=1.54.0=pypi_0 125 | h11=0.14.0=pypi_0 126 | h5py=3.7.0=pypi_0 127 | hickle=5.0.2=pypi_0 128 | httpcore=1.0.7=pypi_0 129 | httpx=0.28.1=pypi_0 130 | huggingface-hub=0.25.0=pypi_0 131 | icu=58.2=he6710b0_3 132 | idna=3.4=py38h06a4308_0 133 | importlib-metadata=4.11.3=py38h06a4308_0 134 | importlib_resources=5.2.0=pyhd3eb1b0_1 135 | iniconfig=2.0.0=pypi_0 136 | intel-openmp=2021.4.0=h06a4308_3561 137 | ipdb=0.13.9=pypi_0 138 | ipykernel=6.15.2=py38h06a4308_0 139 | ipython=8.6.0=py38h06a4308_0 140 | ipython_genutils=0.2.0=pyhd3eb1b0_1 141 | jax=0.4.9=pypi_0 142 | jedi=0.18.1=py38h06a4308_1 143 | jinja2=3.1.2=py38h06a4308_0 144 | jiter=0.8.2=pypi_0 145 | joblib=1.2.0=pypi_0 146 | jpeg=9e=h7f8727e_0 147 | json5=0.9.6=pyhd3eb1b0_0 148 | jsonschema=4.16.0=py38h06a4308_0 149 | jupyter_client=7.4.7=py38h06a4308_0 150 | jupyter_core=4.11.2=py38h06a4308_0 151 | jupyter_server=1.18.1=py38h06a4308_0 152 | jupyterlab=3.4.4=py38h06a4308_0 153 | jupyterlab_pygments=0.1.2=py_0 154 | jupyterlab_server=2.15.2=py38h06a4308_0 155 | keras=2.7.0=pypi_0 156 | keras-preprocessing=1.1.2=pypi_0 157 | keyboard=0.13.5=pypi_0 158 | kiwisolver=1.4.4=pypi_0 159 | lame=3.100=h7b6447c_0 160 | langcodes=3.3.0=pyhd3eb1b0_0 161 | lcms2=2.12=h3be6417_0 162 | ld_impl_linux-64=2.38=h1181459_1 163 | lerc=3.0=h295c915_0 164 | libclang=16.0.0=pypi_0 165 | libcublas=11.11.3.6=0 166 | libcublas-dev=11.11.3.6=0 167 | libcufft=10.9.0.58=0 168 | libcufft-dev=10.9.0.58=0 169 | libcufile=1.4.0.31=0 170 | libcufile-dev=1.4.0.31=0 171 | libcurand=10.3.0.86=0 172 | libcurand-dev=10.3.0.86=0 173 | libcusolver=11.4.1.48=0 174 | libcusolver-dev=11.4.1.48=0 175 | libcusparse=11.7.5.86=0 176 | libcusparse-dev=11.7.5.86=0 177 | libdeflate=1.8=h7f8727e_5 178 | libffi=3.4.2=h295c915_4 179 | libgcc-ng=11.2.0=h1234567_1 180 | libgomp=11.2.0=h1234567_1 181 | libiconv=1.16=h7f8727e_2 182 | libidn2=2.3.2=h7f8727e_0 183 | libjpeg-turbo=2.0.0=h9bf148f_0 184 | libnpp=11.8.0.86=0 185 | libnpp-dev=11.8.0.86=0 186 | libnvjpeg=11.9.0.86=0 187 | libnvjpeg-dev=11.9.0.86=0 188 | libpng=1.6.37=hbc83047_0 189 | libsodium=1.0.18=h7b6447c_0 190 | libstdcxx-ng=11.2.0=h1234567_1 191 | libtasn1=4.16.0=h27cfd23_0 192 | libtiff=4.4.0=hecacb30_2 193 | libunistring=0.9.10=h27cfd23_0 194 | libwebp=1.2.4=h11a3e52_0 195 | libwebp-base=1.2.4=h5eee18b_0 196 | libxml2=2.9.14=h74e7548_0 197 | libxslt=1.1.35=h4e12654_0 198 | llmcompressor=0.4.1=pypi_0 199 | llvm-openmp=14.0.6=h9e868ea_0 200 | llvmlite=0.39.1=pypi_0 201 | loguru=0.7.3=pypi_0 202 | loralib=0.1.2=pypi_0 203 | lxml=4.9.1=py38h1edc446_0 204 | lz4-c=1.9.3=h295c915_1 205 | markdown=3.4.3=pypi_0 206 | markupsafe=2.1.1=py38h7f8727e_0 207 | matplotlib=3.7.1=pypi_0 208 | matplotlib-inline=0.1.6=py38h06a4308_0 209 | mesh-tensorflow=0.1.21=pypi_0 210 | mistune=0.8.4=py38h7b6447c_1000 211 | mkl=2021.4.0=h06a4308_640 212 | mkl-service=2.4.0=py38h7f8727e_0 213 | mkl_fft=1.3.1=py38hd3c417c_0 214 | mkl_random=1.2.2=py38h51133e4_0 215 | ml-dtypes=0.1.0=pypi_0 216 | mpc=1.1.0=h10f8cd9_1 217 | mpfr=4.0.2=hb69a4c5_1 218 | mpmath=1.3.0=py38h06a4308_0 219 | multidict=6.0.3=pypi_0 220 | multiprocess=0.70.14=pypi_0 221 | murmurhash=1.0.7=py38h295c915_0 222 | nbclassic=0.4.8=py38h06a4308_0 223 | nbclient=0.5.13=py38h06a4308_0 224 | nbconvert=6.5.4=py38h06a4308_0 225 | nbformat=5.5.0=py38h06a4308_0 226 | ncurses=6.3=h5eee18b_3 227 | nest-asyncio=1.5.5=py38h06a4308_0 228 | nettle=3.7.3=hbbd107a_1 229 | networkx=3.1=py38h06a4308_0 230 | ninja=1.11.1.2=pypi_0 231 | nltk=3.8.1=pypi_0 232 | notebook=6.5.2=py38h06a4308_0 233 | notebook-shim=0.2.2=py38h06a4308_0 234 | nsight-compute=2022.3.0.22=0 235 | numba=0.56.4=pypi_0 236 | numpy=1.22.0=pypi_0 237 | nvidia-ml-py=12.560.30=pypi_0 238 | oauthlib=3.2.2=pypi_0 239 | openai=1.59.3=pypi_0 240 | openh264=2.1.1=h4ff587b_0 241 | openpyxl=3.0.10=pypi_0 242 | openssl=1.1.1w=h7f8727e_0 243 | opt-einsum=3.3.0=pypi_0 244 | packaging=21.3=pyhd3eb1b0_0 245 | pandas=1.3.2=pypi_0 246 | pandas-stubs=1.5.1.221024=pypi_0 247 | pandocfilters=1.5.0=pyhd3eb1b0_0 248 | parallelformers=1.2.7=pypi_0 249 | parso=0.8.3=pyhd3eb1b0_0 250 | peft=0.13.2=pypi_0 251 | pexpect=4.8.0=pyhd3eb1b0_3 252 | pickleshare=0.7.5=pyhd3eb1b0_1003 253 | pillow=9.2.0=py38hace64e9_1 254 | pip=22.2.2=py38h06a4308_0 255 | pkgutil-resolve-name=1.3.10=py38h06a4308_0 256 | plotly=5.14.1=pypi_0 257 | pluggy=1.5.0=pypi_0 258 | portalocker=2.7.0=pypi_0 259 | preshed=3.0.6=py38h295c915_0 260 | prometheus_client=0.14.1=py38h06a4308_0 261 | promise=2.3=pypi_0 262 | prompt-toolkit=3.0.20=pyhd3eb1b0_0 263 | protobuf=3.19.6=pypi_0 264 | psutil=5.9.0=py38h5eee18b_0 265 | ptyprocess=0.7.0=pyhd3eb1b0_2 266 | pure_eval=0.2.2=pyhd3eb1b0_0 267 | pyarrow=17.0.0=pypi_0 268 | pyarrow-hotfix=0.6=pypi_0 269 | pyasn1=0.5.0=pypi_0 270 | pyasn1-modules=0.3.0=pypi_0 271 | pycparser=2.21=pyhd3eb1b0_0 272 | pydantic=2.10.6=pypi_0 273 | pydantic-core=2.27.2=pypi_0 274 | pygments=2.11.2=pyhd3eb1b0_0 275 | pynndescent=0.5.8=pypi_0 276 | pynvml=11.5.3=pypi_0 277 | pyopenssl=22.0.0=pyhd3eb1b0_0 278 | pyparsing=3.0.9=py38h06a4308_0 279 | pyrsistent=0.18.0=py38heee7806_0 280 | pysocks=1.7.1=py38h06a4308_0 281 | pytest=8.3.4=pypi_0 282 | python=3.8.15=h3fd9d12_0 283 | python-dateutil=2.8.2=pyhd3eb1b0_0 284 | python-fastjsonschema=2.16.2=py38h06a4308_0 285 | pytorch=2.3.1=py3.8_cuda11.8_cudnn8.7.0_0 286 | pytorch-cuda=11.8=h7e8668a_5 287 | pytorch-mutex=1.0=cuda 288 | pytz=2022.1=py38h06a4308_0 289 | pyyaml=6.0=pypi_0 290 | pyzmq=23.2.0=py38h6a678d5_0 291 | readline=8.2=h5eee18b_0 292 | regex=2022.10.31=pypi_0 293 | requests=2.28.1=py38h06a4308_0 294 | requests-oauthlib=1.3.1=pypi_0 295 | responses=0.18.0=pypi_0 296 | rich=12.5.1=py38h06a4308_0 297 | rouge-score=0.1.2=pypi_0 298 | rsa=4.9=pypi_0 299 | sacrebleu=2.3.1=pypi_0 300 | sacremoses=0.0.53=pypi_0 301 | safetensors=0.4.5=pypi_0 302 | scikit-learn=1.2.2=pypi_0 303 | scipy=1.10.1=pypi_0 304 | seaborn=0.12.2=pypi_0 305 | send2trash=1.8.0=pyhd3eb1b0_1 306 | sentencepiece=0.2.0=pypi_0 307 | setuptools=65.5.0=py38h06a4308_0 308 | shellingham=1.5.0=py38h06a4308_0 309 | six=1.16.0=pyhd3eb1b0_1 310 | smart_open=5.2.1=py38h06a4308_0 311 | sniffio=1.2.0=py38h06a4308_1 312 | soupsieve=2.3.2.post1=py38h06a4308_0 313 | spacy=3.7.2=py38h3c18c91_0 314 | spacy-legacy=3.0.12=py38h06a4308_0 315 | spacy-loggers=1.0.4=py38h06a4308_0 316 | sqlite=3.39.3=h5082296_0 317 | srsly=2.4.8=py38h6a678d5_1 318 | stack_data=0.2.0=pyhd3eb1b0_0 319 | sympy=1.13.2=py38h06a4308_0 320 | t5=0.7.1=pypi_0 321 | tabulate=0.9.0=pypi_0 322 | tenacity=8.2.2=pypi_0 323 | tensorboard=2.12.3=pypi_0 324 | tensorboard-data-server=0.7.0=pypi_0 325 | tensorflow=2.7.4=pypi_0 326 | tensorflow-datasets=4.9.0=pypi_0 327 | tensorflow-estimator=2.7.0=pypi_0 328 | tensorflow-hub=0.13.0=pypi_0 329 | tensorflow-io-gcs-filesystem=0.32.0=pypi_0 330 | tensorflow-metadata=1.13.0=pypi_0 331 | tensorflow-text=2.7.3=pypi_0 332 | termcolor=2.2.0=pypi_0 333 | terminado=0.13.1=py38h06a4308_0 334 | tf-slim=1.1.0=pypi_0 335 | tfds-nightly=4.9.0.dev202304110044=pypi_0 336 | thinc=8.2.2=py38h3c18c91_0 337 | threadpoolctl=3.1.0=pypi_0 338 | tiktoken=0.7.0=pypi_0 339 | tinycss2=1.2.1=py38h06a4308_0 340 | tk=8.6.12=h1ccaba5_0 341 | tokenizers=0.19.1=pypi_0 342 | toml=0.10.2=pypi_0 343 | tomli=2.2.1=pypi_0 344 | torchaudio=2.3.1=py38_cu118 345 | torchtriton=2.3.1=py38 346 | torchvision=0.18.1=py38_cu118 347 | tornado=6.2=py38h5eee18b_0 348 | tqdm=4.64.1=pypi_0 349 | traitlets=5.1.1=pyhd3eb1b0_0 350 | transformers=4.43.1=pypi_0 351 | truthfulqa=0.0.1=pypi_0 352 | tsne-torch=1.0.1=pypi_0 353 | typer=0.9.0=py38h06a4308_0 354 | types-pytz=2022.6.0.1=pypi_0 355 | typing-extensions=4.12.2=pypi_0 356 | tzdata=2023.3=pypi_0 357 | umap-learn=0.5.3=pypi_0 358 | urllib3=1.26.12=py38h06a4308_0 359 | wasabi=0.9.1=py38h06a4308_0 360 | wcwidth=0.2.5=pyhd3eb1b0_0 361 | weasel=0.3.4=py38h06a4308_0 362 | webencodings=0.5.1=py38_1 363 | websocket-client=0.58.0=py38h06a4308_4 364 | werkzeug=2.3.4=pypi_0 365 | wheel=0.37.1=pyhd3eb1b0_0 366 | wrapt=1.14.1=pypi_0 367 | xxhash=3.1.0=pypi_0 368 | xz=5.2.6=h5eee18b_0 369 | yaml=0.2.5=h7b6447c_0 370 | yarl=1.8.2=pypi_0 371 | zeromq=4.3.4=h2531618_0 372 | zipp=3.8.0=py38h06a4308_0 373 | zlib=1.2.13=h5eee18b_0 374 | zstd=1.5.2=ha4553b6_0 375 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /tsv.yml: -------------------------------------------------------------------------------- 1 | name: tsv 2 | channels: 3 | - pytorch 4 | - nvidia 5 | - defaults 6 | - https://repo.anaconda.com/pkgs/main 7 | - https://repo.anaconda.com/pkgs/r 8 | dependencies: 9 | - _libgcc_mutex=0.1=main 10 | - _openmp_mutex=5.1=1_gnu 11 | - anyio=3.5.0=py38h06a4308_0 12 | - argon2-cffi=21.3.0=pyhd3eb1b0_0 13 | - argon2-cffi-bindings=21.2.0=py38h7f8727e_0 14 | - asttokens=2.0.5=pyhd3eb1b0_0 15 | - attrs=22.1.0=py38h06a4308_0 16 | - babel=2.9.1=pyhd3eb1b0_0 17 | - backcall=0.2.0=pyhd3eb1b0_0 18 | - beautifulsoup4=4.11.1=py38h06a4308_0 19 | - blas=1.0=mkl 20 | - bleach=4.1.0=pyhd3eb1b0_0 21 | - brotlipy=0.7.0=py38h27cfd23_1003 22 | - bzip2=1.0.8=h7b6447c_0 23 | - ca-certificates=2024.11.26=h06a4308_0 24 | - catalogue=2.0.10=py38h06a4308_0 25 | - certifi=2024.8.30=py38h06a4308_0 26 | - cffi=1.15.1=py38h74dc2b5_0 27 | - cloudpathlib=0.16.0=py38h06a4308_1 28 | - colorama=0.4.6=py38h06a4308_0 29 | - commonmark=0.9.1=pyhd3eb1b0_0 30 | - confection=0.1.4=py38h2f386ee_0 31 | - cryptography=38.0.1=py38h9ce1e76_0 32 | - cuda=11.6.2=0 33 | - cuda-cccl=11.6.55=hf6102b2_0 34 | - cuda-command-line-tools=11.6.2=0 35 | - cuda-compiler=11.6.2=0 36 | - cuda-cudart=11.8.89=0 37 | - cuda-cudart-dev=11.6.55=h42ad0f4_0 38 | - cuda-cuobjdump=11.6.124=h2eeebcb_0 39 | - cuda-cupti=11.8.87=0 40 | - cuda-cuxxfilt=11.6.124=hecbf4f6_0 41 | - cuda-driver-dev=11.6.55=0 42 | - cuda-gdb=12.5.39=0 43 | - cuda-libraries=11.8.0=0 44 | - cuda-libraries-dev=11.6.2=0 45 | - cuda-memcheck=11.8.86=0 46 | - cuda-nsight=12.5.39=0 47 | - cuda-nsight-compute=11.8.0=0 48 | - cuda-nvcc=11.6.124=hbba6d2d_0 49 | - cuda-nvdisasm=12.5.39=0 50 | - cuda-nvml-dev=11.6.55=haa9ef22_0 51 | - cuda-nvprof=12.5.39=0 52 | - cuda-nvprune=11.6.124=he22ec0a_0 53 | - cuda-nvrtc=11.8.89=0 54 | - cuda-nvrtc-dev=11.6.124=h249d397_0 55 | - cuda-nvtx=11.8.86=0 56 | - cuda-nvvp=12.5.39=0 57 | - cuda-runtime=11.8.0=0 58 | - cuda-samples=11.6.101=h8efea70_0 59 | - cuda-sanitizer-api=12.5.39=0 60 | - cuda-toolkit=11.6.2=0 61 | - cuda-tools=11.6.2=0 62 | - cuda-version=12.5=3 63 | - cuda-visual-tools=11.6.2=0 64 | - cymem=2.0.6=py38h295c915_0 65 | - cython-blis=0.7.9=py38h7deecbd_0 66 | - dataclasses=0.8=pyh6d0b6a4_7 67 | - debugpy=1.5.1=py38h295c915_0 68 | - decorator=5.1.1=pyhd3eb1b0_0 69 | - defusedxml=0.7.1=pyhd3eb1b0_0 70 | - entrypoints=0.4=py38h06a4308_0 71 | - executing=0.8.3=pyhd3eb1b0_0 72 | - ffmpeg=4.3=hf484d3e_0 73 | - freetype=2.12.1=h4a9f257_0 74 | - future=0.18.3=py38h06a4308_0 75 | - gds-tools=1.4.0.31=0 76 | - giflib=5.2.1=h7b6447c_0 77 | - gmp=6.2.1=h295c915_3 78 | - gmpy2=2.1.2=py38heeb90bb_0 79 | - gnutls=3.6.15=he1e5248_0 80 | - icu=58.2=he6710b0_3 81 | - idna=3.4=py38h06a4308_0 82 | - importlib-metadata=4.11.3=py38h06a4308_0 83 | - importlib_resources=5.2.0=pyhd3eb1b0_1 84 | - intel-openmp=2021.4.0=h06a4308_3561 85 | - ipykernel=6.15.2=py38h06a4308_0 86 | - ipython=8.6.0=py38h06a4308_0 87 | - ipython_genutils=0.2.0=pyhd3eb1b0_1 88 | - jedi=0.18.1=py38h06a4308_1 89 | - jinja2=3.1.2=py38h06a4308_0 90 | - jpeg=9e=h7f8727e_0 91 | - json5=0.9.6=pyhd3eb1b0_0 92 | - jsonschema=4.16.0=py38h06a4308_0 93 | - jupyter_client=7.4.7=py38h06a4308_0 94 | - jupyter_core=4.11.2=py38h06a4308_0 95 | - jupyter_server=1.18.1=py38h06a4308_0 96 | - jupyterlab=3.4.4=py38h06a4308_0 97 | - jupyterlab_pygments=0.1.2=py_0 98 | - jupyterlab_server=2.15.2=py38h06a4308_0 99 | - lame=3.100=h7b6447c_0 100 | - langcodes=3.3.0=pyhd3eb1b0_0 101 | - lcms2=2.12=h3be6417_0 102 | - ld_impl_linux-64=2.38=h1181459_1 103 | - lerc=3.0=h295c915_0 104 | - libcublas=11.11.3.6=0 105 | - libcublas-dev=11.11.3.6=0 106 | - libcufft=10.9.0.58=0 107 | - libcufft-dev=10.9.0.58=0 108 | - libcufile=1.4.0.31=0 109 | - libcufile-dev=1.4.0.31=0 110 | - libcurand=10.3.0.86=0 111 | - libcurand-dev=10.3.0.86=0 112 | - libcusolver=11.4.1.48=0 113 | - libcusolver-dev=11.4.1.48=0 114 | - libcusparse=11.7.5.86=0 115 | - libcusparse-dev=11.7.5.86=0 116 | - libdeflate=1.8=h7f8727e_5 117 | - libffi=3.4.2=h295c915_4 118 | - libgcc-ng=11.2.0=h1234567_1 119 | - libgomp=11.2.0=h1234567_1 120 | - libiconv=1.16=h7f8727e_2 121 | - libidn2=2.3.2=h7f8727e_0 122 | - libjpeg-turbo=2.0.0=h9bf148f_0 123 | - libnpp=11.8.0.86=0 124 | - libnpp-dev=11.8.0.86=0 125 | - libnvjpeg=11.9.0.86=0 126 | - libnvjpeg-dev=11.9.0.86=0 127 | - libpng=1.6.37=hbc83047_0 128 | - libsodium=1.0.18=h7b6447c_0 129 | - libstdcxx-ng=11.2.0=h1234567_1 130 | - libtasn1=4.16.0=h27cfd23_0 131 | - libtiff=4.4.0=hecacb30_2 132 | - libunistring=0.9.10=h27cfd23_0 133 | - libwebp=1.2.4=h11a3e52_0 134 | - libwebp-base=1.2.4=h5eee18b_0 135 | - libxml2=2.9.14=h74e7548_0 136 | - libxslt=1.1.35=h4e12654_0 137 | - llvm-openmp=14.0.6=h9e868ea_0 138 | - lxml=4.9.1=py38h1edc446_0 139 | - lz4-c=1.9.3=h295c915_1 140 | - markupsafe=2.1.1=py38h7f8727e_0 141 | - matplotlib-inline=0.1.6=py38h06a4308_0 142 | - mistune=0.8.4=py38h7b6447c_1000 143 | - mkl=2021.4.0=h06a4308_640 144 | - mkl-service=2.4.0=py38h7f8727e_0 145 | - mkl_fft=1.3.1=py38hd3c417c_0 146 | - mkl_random=1.2.2=py38h51133e4_0 147 | - mpc=1.1.0=h10f8cd9_1 148 | - mpfr=4.0.2=hb69a4c5_1 149 | - mpmath=1.3.0=py38h06a4308_0 150 | - murmurhash=1.0.7=py38h295c915_0 151 | - nbclassic=0.4.8=py38h06a4308_0 152 | - nbclient=0.5.13=py38h06a4308_0 153 | - nbconvert=6.5.4=py38h06a4308_0 154 | - nbformat=5.5.0=py38h06a4308_0 155 | - ncurses=6.3=h5eee18b_3 156 | - nest-asyncio=1.5.5=py38h06a4308_0 157 | - nettle=3.7.3=hbbd107a_1 158 | - networkx=3.1=py38h06a4308_0 159 | - notebook=6.5.2=py38h06a4308_0 160 | - notebook-shim=0.2.2=py38h06a4308_0 161 | - nsight-compute=2022.3.0.22=0 162 | - openh264=2.1.1=h4ff587b_0 163 | - openssl=1.1.1w=h7f8727e_0 164 | - packaging=21.3=pyhd3eb1b0_0 165 | - pandocfilters=1.5.0=pyhd3eb1b0_0 166 | - parso=0.8.3=pyhd3eb1b0_0 167 | - pexpect=4.8.0=pyhd3eb1b0_3 168 | - pickleshare=0.7.5=pyhd3eb1b0_1003 169 | - pillow=9.2.0=py38hace64e9_1 170 | - pip=22.2.2=py38h06a4308_0 171 | - pkgutil-resolve-name=1.3.10=py38h06a4308_0 172 | - preshed=3.0.6=py38h295c915_0 173 | - prometheus_client=0.14.1=py38h06a4308_0 174 | - prompt-toolkit=3.0.20=pyhd3eb1b0_0 175 | - psutil=5.9.0=py38h5eee18b_0 176 | - ptyprocess=0.7.0=pyhd3eb1b0_2 177 | - pure_eval=0.2.2=pyhd3eb1b0_0 178 | - pycparser=2.21=pyhd3eb1b0_0 179 | - pygments=2.11.2=pyhd3eb1b0_0 180 | - pyopenssl=22.0.0=pyhd3eb1b0_0 181 | - pyparsing=3.0.9=py38h06a4308_0 182 | - pyrsistent=0.18.0=py38heee7806_0 183 | - pysocks=1.7.1=py38h06a4308_0 184 | - python=3.8.15=h3fd9d12_0 185 | - python-dateutil=2.8.2=pyhd3eb1b0_0 186 | - python-fastjsonschema=2.16.2=py38h06a4308_0 187 | - pytorch=2.3.1=py3.8_cuda11.8_cudnn8.7.0_0 188 | - pytorch-cuda=11.8=h7e8668a_5 189 | - pytorch-mutex=1.0=cuda 190 | - pytz=2022.1=py38h06a4308_0 191 | - pyzmq=23.2.0=py38h6a678d5_0 192 | - readline=8.2=h5eee18b_0 193 | - requests=2.28.1=py38h06a4308_0 194 | - rich=12.5.1=py38h06a4308_0 195 | - send2trash=1.8.0=pyhd3eb1b0_1 196 | - setuptools=65.5.0=py38h06a4308_0 197 | - shellingham=1.5.0=py38h06a4308_0 198 | - six=1.16.0=pyhd3eb1b0_1 199 | - smart_open=5.2.1=py38h06a4308_0 200 | - sniffio=1.2.0=py38h06a4308_1 201 | - soupsieve=2.3.2.post1=py38h06a4308_0 202 | - spacy=3.7.2=py38h3c18c91_0 203 | - spacy-legacy=3.0.12=py38h06a4308_0 204 | - spacy-loggers=1.0.4=py38h06a4308_0 205 | - sqlite=3.39.3=h5082296_0 206 | - srsly=2.4.8=py38h6a678d5_1 207 | - stack_data=0.2.0=pyhd3eb1b0_0 208 | - sympy=1.13.2=py38h06a4308_0 209 | - terminado=0.13.1=py38h06a4308_0 210 | - thinc=8.2.2=py38h3c18c91_0 211 | - tinycss2=1.2.1=py38h06a4308_0 212 | - tk=8.6.12=h1ccaba5_0 213 | - torchaudio=2.3.1=py38_cu118 214 | - torchtriton=2.3.1=py38 215 | - torchvision=0.18.1=py38_cu118 216 | - tornado=6.2=py38h5eee18b_0 217 | - traitlets=5.1.1=pyhd3eb1b0_0 218 | - typer=0.9.0=py38h06a4308_0 219 | - urllib3=1.26.12=py38h06a4308_0 220 | - wasabi=0.9.1=py38h06a4308_0 221 | - wcwidth=0.2.5=pyhd3eb1b0_0 222 | - weasel=0.3.4=py38h06a4308_0 223 | - webencodings=0.5.1=py38_1 224 | - websocket-client=0.58.0=py38h06a4308_4 225 | - wheel=0.37.1=pyhd3eb1b0_0 226 | - xz=5.2.6=h5eee18b_0 227 | - yaml=0.2.5=h7b6447c_0 228 | - zeromq=4.3.4=h2531618_0 229 | - zipp=3.8.0=py38h06a4308_0 230 | - zlib=1.2.13=h5eee18b_0 231 | - zstd=1.5.2=ha4553b6_0 232 | - pip: 233 | - absl-py==1.4.0 234 | - accelerate==0.21.0 235 | - aiohttp==3.8.3 236 | - aiosignal==1.3.1 237 | - annotated-types==0.7.0 238 | - array-record==0.2.0 239 | - astunparse==1.6.3 240 | - async-timeout==4.0.2 241 | - baukit==0.0.1 242 | - bert-score==0.3.13 243 | - bitsandbytes==0.45.4 244 | - blessed==1.20.0 245 | - bleurt==0.0.2 246 | - bleurt-pytorch==0.0.1 247 | - cachetools==5.3.0 248 | - charset-normalizer==2.1.1 249 | - click==8.1.3 250 | - compressed-tensors==0.9.2 251 | - contourpy==1.0.7 252 | - cycler==0.11.0 253 | - dacite==1.8.1 254 | - datasets==2.17.1 255 | - dill==0.3.6 256 | - distro==1.9.0 257 | - dm-tree==0.1.8 258 | - easydict==1.13 259 | - einops==0.3.2 260 | - et-xmlfile==1.1.0 261 | - etils==1.3.0 262 | - evaluate==0.3.0 263 | - exceptiongroup==1.2.2 264 | - fairscale==0.4.13 265 | - fancy-einsum==0.0.3 266 | - filelock==3.8.0 267 | - fire==0.5.0 268 | - flash-attn==2.6.1 269 | - flatbuffers==2.0.7 270 | - fonttools==4.39.4 271 | - frozenlist==1.3.3 272 | - fsspec==2023.10.0 273 | - gast==0.4.0 274 | - geotorch==0.3.0 275 | - gin-config==0.5.0 276 | - google-auth==2.18.0 277 | - google-auth-oauthlib==1.0.0 278 | - google-pasta==0.2.0 279 | - googleapis-common-protos==1.59.0 280 | - gpustat==1.1.1 281 | - grpcio==1.54.0 282 | - h11==0.14.0 283 | - h5py==3.7.0 284 | - hickle==5.0.2 285 | - httpcore==1.0.7 286 | - httpx==0.28.1 287 | - huggingface-hub==0.25.0 288 | - iniconfig==2.0.0 289 | - ipdb==0.13.9 290 | - jax==0.4.9 291 | - jiter==0.8.2 292 | - joblib==1.2.0 293 | - keras==2.7.0 294 | - keras-preprocessing==1.1.2 295 | - keyboard==0.13.5 296 | - kiwisolver==1.4.4 297 | - libclang==16.0.0 298 | - llmcompressor==0.4.1 299 | - llvmlite==0.39.1 300 | - loguru==0.7.3 301 | - loralib==0.1.2 302 | - markdown==3.4.3 303 | - matplotlib==3.7.1 304 | - mesh-tensorflow==0.1.21 305 | - ml-dtypes==0.1.0 306 | - multidict==6.0.3 307 | - multiprocess==0.70.14 308 | - ninja==1.11.1.2 309 | - nltk==3.8.1 310 | - numba==0.56.4 311 | - numpy==1.22.0 312 | - nvidia-ml-py==12.560.30 313 | - oauthlib==3.2.2 314 | - openai==1.59.3 315 | - openpyxl==3.0.10 316 | - opt-einsum==3.3.0 317 | - pandas==1.3.2 318 | - pandas-stubs==1.5.1.221024 319 | - parallelformers==1.2.7 320 | - peft==0.13.2 321 | - plotly==5.14.1 322 | - pluggy==1.5.0 323 | - portalocker==2.7.0 324 | - promise==2.3 325 | - protobuf==3.19.6 326 | - pyarrow==17.0.0 327 | - pyarrow-hotfix==0.6 328 | - pyasn1==0.5.0 329 | - pyasn1-modules==0.3.0 330 | - pydantic==2.10.6 331 | - pydantic-core==2.27.2 332 | - pynndescent==0.5.8 333 | - pynvml==11.5.3 334 | - pytest==8.3.4 335 | - pyyaml==6.0 336 | - regex==2022.10.31 337 | - requests-oauthlib==1.3.1 338 | - responses==0.18.0 339 | - rouge-score==0.1.2 340 | - rsa==4.9 341 | - sacrebleu==2.3.1 342 | - sacremoses==0.0.53 343 | - safetensors==0.4.5 344 | - scikit-learn==1.2.2 345 | - scipy==1.10.1 346 | - seaborn==0.12.2 347 | - sentencepiece==0.2.0 348 | - t5==0.7.1 349 | - tabulate==0.9.0 350 | - tenacity==8.2.2 351 | - tensorboard==2.12.3 352 | - tensorboard-data-server==0.7.0 353 | - tensorflow==2.7.4 354 | - tensorflow-datasets==4.9.0 355 | - tensorflow-estimator==2.7.0 356 | - tensorflow-hub==0.13.0 357 | - tensorflow-io-gcs-filesystem==0.32.0 358 | - tensorflow-metadata==1.13.0 359 | - tensorflow-text==2.7.3 360 | - termcolor==2.2.0 361 | - tf-slim==1.1.0 362 | - tfds-nightly==4.9.0.dev202304110044 363 | - threadpoolctl==3.1.0 364 | - tiktoken==0.7.0 365 | - tokenizers==0.19.1 366 | - toml==0.10.2 367 | - tomli==2.2.1 368 | - tqdm==4.64.1 369 | - transformers==4.43.1 370 | - truthfulqa==0.0.1 371 | - tsne-torch==1.0.1 372 | - types-pytz==2022.6.0.1 373 | - typing-extensions==4.12.2 374 | - tzdata==2023.3 375 | - umap-learn==0.5.3 376 | - werkzeug==2.3.4 377 | - wrapt==1.14.1 378 | - xxhash==3.1.0 379 | - yarl==1.8.2 380 | 381 | -------------------------------------------------------------------------------- /tsv_main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | from datasets import load_dataset 5 | from tqdm import tqdm 6 | import numpy as np 7 | import argparse 8 | from train_utils import get_last_non_padded_token_rep, compute_ot_loss_cos, update_centroids_ema, update_centroids_ema_hard, get_ex_data, collate_fn 9 | from transformers import AutoTokenizer, AutoModelForCausalLM 10 | from llm_layers import add_tsv_layers 11 | from sklearn.metrics import roc_auc_score 12 | from torch.cuda.amp import autocast, GradScaler 13 | import torch.nn.functional as F 14 | from sinkhorn_knopp import SinkhornKnopp_imb 15 | import logging 16 | 17 | 18 | def seed_everything(seed: int): 19 | import random, os 20 | import numpy as np 21 | import torch 22 | 23 | random.seed(seed) 24 | os.environ['PYTHONHASHSEED'] = str(seed) 25 | np.random.seed(seed) 26 | torch.manual_seed(seed) 27 | torch.cuda.manual_seed(seed) 28 | torch.backends.cudnn.deterministic = True 29 | torch.backends.cudnn.benchmark = True 30 | 31 | 32 | def train_model(model, optimizer, device, prompts, labels, args): 33 | 34 | layer_number = -1 35 | dir_name = f"TSV_{args.model_name}_{args.dataset_name}/exemplar_num_{args.num_exemplars}_num_selected_data_{args.num_selected_data}/{args.component}/{args.str_layer}/{args.lam}" 36 | log_dir = f"/{dir_name}/" 37 | log_file = os.path.join(log_dir, f"log.txt") 38 | os.makedirs(dir_name,exist_ok=True) 39 | 40 | logging.basicConfig( 41 | filename=log_file, 42 | filemode="w", 43 | level=logging.INFO, 44 | format="%(asctime)s - %(levelname)s - %(message)s",) 45 | 46 | logging.info("Starting training") 47 | logging.info(f"Training parameters: few_shot_size={args.num_exemplars}, num_selected_data={args.num_selected_data}, component={args.component}, str_layer={args.str_layer}") 48 | 49 | test_prompts, train_prompts, exemplar_prompts = prompts[0], prompts[1], prompts[2] 50 | test_labels, train_labels, exemplar_labels = labels[0], labels[1], labels[2] 51 | batch_size = args.batch_size 52 | num_samples = len(train_prompts) 53 | 54 | losses = [] 55 | best_test_auroc = -1 56 | 57 | scaler = GradScaler() 58 | 59 | num_exemplars = args.num_exemplars 60 | 61 | # Initialize Sinkhorn algorithm 62 | args.num_iters_sk = 3 63 | args.epsilon_sk = 0.05 64 | 65 | ex_hallu = (num_exemplars-exemplar_labels[:num_exemplars].sum())/num_exemplars 66 | ex_true = (exemplar_labels[:num_exemplars].sum())/num_exemplars 67 | cls_dist = torch.tensor([ex_hallu,ex_true]).float().cuda() 68 | cls_dist = cls_dist.view(-1, 1) 69 | sinkhorn = SinkhornKnopp_imb(args, cls_dist) 70 | 71 | # Initialize Centroids 72 | centroids = torch.randn((2, model.config.hidden_size)).half().cuda() 73 | centroids = F.normalize(centroids, p=2, dim=1) 74 | 75 | exemplar_prompts_, exemplar_labels_ = exemplar_prompts, exemplar_labels 76 | exemplar_prompts, exemplar_labels = collate_fn(exemplar_prompts, exemplar_labels) 77 | 78 | num_epochs = args.init_num_epochs 79 | 80 | for epoch in range(num_epochs): 81 | running_loss = 0.0 82 | total = 0 83 | all_labels = [] 84 | num_samples = num_exemplars 85 | 86 | # Process data in batches 87 | for batch_start in tqdm(range(0, num_samples, batch_size), desc=f"Epoch {epoch+1}/{num_epochs} Batches", leave=False): 88 | 89 | batch_prompts = exemplar_prompts[batch_start: batch_start + batch_size] 90 | batch_labels = exemplar_labels[batch_start: batch_start + batch_size] 91 | 92 | # Create attention masks (1 for real tokens, 0 for padding) 93 | attention_mask = (batch_prompts != 0).half() 94 | 95 | batch_prompts = batch_prompts.to(device) 96 | batch_labels = batch_labels.to(device) 97 | attention_mask = attention_mask.to(batch_prompts.device) 98 | 99 | # Forward pass 100 | with autocast(dtype=torch.float16): 101 | 102 | output = model(batch_prompts.squeeze(), attention_mask=attention_mask.squeeze(), output_hidden_states=True) 103 | 104 | hidden_states = output.hidden_states 105 | 106 | hidden_states = torch.stack(hidden_states, dim=0).squeeze() 107 | 108 | last_layer_hidden_state = hidden_states[layer_number] # Shape: [batch_size, max_seq_len, hidden_size] 109 | 110 | # Use attention mask to ignore padding tokens, and get the last non-padded token's representation 111 | last_token_rep = get_last_non_padded_token_rep(last_layer_hidden_state, attention_mask.squeeze()) 112 | 113 | batch_labels_oh = torch.nn.functional.one_hot(batch_labels, num_classes=-1) 114 | 115 | ot_loss, similarities = compute_ot_loss_cos(last_token_rep, centroids, batch_labels_oh, batch_size, args) 116 | 117 | loss = ot_loss 118 | 119 | total += batch_labels.size(0) 120 | 121 | with torch.no_grad(): 122 | centroids = update_centroids_ema_hard(centroids, last_token_rep, batch_labels_oh, args) 123 | 124 | # loss.backward() 125 | scaler.scale(loss).backward() 126 | scaler.step(optimizer) 127 | scaler.update() 128 | optimizer.zero_grad() 129 | 130 | running_loss += loss.item() * batch_labels.size(0) 131 | 132 | # Epoch summary 133 | epoch_loss = running_loss / total 134 | 135 | if (epoch+1) % 1 == 0: 136 | 137 | test_labels_ = test_labels 138 | test_predictions, test_labels_combined = test_model(model, centroids, test_prompts, test_labels_, device, batch_size, layer_number) 139 | 140 | test_auroc = roc_auc_score(test_labels_combined.cpu().numpy(), test_predictions.cpu().numpy()) 141 | 142 | print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}") 143 | logging.info(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}") 144 | losses.append(epoch_loss) 145 | 146 | # AUROC Calculation using sklearn 147 | test_predictions = test_predictions.cpu().numpy() 148 | test_labels_combined = test_labels_combined.cpu().numpy() 149 | 150 | if test_auroc > best_test_auroc: 151 | best_test_auroc = test_auroc 152 | best_test_epoch = epoch 153 | print(f"Best test AUROC: {best_test_auroc:.4f}, at epoch: {best_test_epoch }") 154 | logging.info(f"Best test AUROC: {best_test_auroc:.4f}, at epoch: {best_test_epoch }") 155 | 156 | logging.info( 157 | f"Epoch [{epoch+1}/{num_epochs}], " 158 | f"Train Loss: {epoch_loss:.4f}, ") 159 | 160 | logging.info(f"Test AUROC: {test_auroc:.4f}") 161 | print(f"Epoch [{epoch+1}/{num_epochs}],Test AUROC: {test_auroc:.4f}") 162 | 163 | logging.info(f"SS Learning Starts") 164 | 165 | with torch.no_grad(): 166 | 167 | selected_indices, selected_labels_soft = get_ex_data(model, train_prompts, train_labels, batch_size, centroids, sinkhorn, args.num_selected_data, cls_dist, args) 168 | 169 | num_samples = len(selected_indices) + args.num_exemplars 170 | 171 | num_epochs = args.aug_num_epochs 172 | 173 | exemplar_label = torch.tensor(exemplar_labels).cuda() 174 | 175 | selected_prompts = [train_prompts[i] for i in selected_indices] 176 | selected_labels = selected_labels_soft 177 | 178 | augmented_prompts = selected_prompts + exemplar_prompts_ 179 | exemplar_labels = torch.nn.functional.one_hot(exemplar_label.to(torch.int64), num_classes=2) 180 | augmented_labels = torch.concat((selected_labels, torch.tensor(exemplar_labels).clone().cuda())) 181 | 182 | augmented_prompts_train = augmented_prompts 183 | augmented_labels_label = augmented_labels 184 | 185 | num_samples = len(augmented_prompts_train) 186 | 187 | with autocast(dtype=torch.float16): 188 | for epoch in range(num_epochs): 189 | running_loss = 0.0 190 | total = 0 191 | all_labels = [] 192 | 193 | for batch_start in tqdm(range(0, num_samples, batch_size), desc=f"Epoch {epoch+1}/{num_epochs} Batches", leave=False): 194 | 195 | batch_prompts = augmented_prompts_train[batch_start: batch_start + batch_size] 196 | batch_labels = augmented_labels_label[batch_start: batch_start + batch_size] 197 | 198 | batch_prompts, batch_labels = collate_fn(batch_prompts, batch_labels) 199 | 200 | attention_mask = (batch_prompts != 0).half() # Shape: [batch_size, max_seq_len] 201 | 202 | batch_prompts = batch_prompts.to(device) 203 | batch_labels = batch_labels.to(device) 204 | attention_mask = attention_mask.to(batch_prompts.device) 205 | 206 | output = model(batch_prompts.squeeze(), attention_mask=attention_mask.squeeze(), output_hidden_states=True) 207 | 208 | hidden_states = output.hidden_states 209 | 210 | # Stack hidden states and get the last layer's hidden state 211 | hidden_states = torch.stack(hidden_states, dim=0).squeeze() 212 | 213 | last_layer_hidden_state = hidden_states[layer_number] # Shape: [batch_size, max_seq_len, hidden_size] 214 | 215 | # Use attention mask to ignore padding tokens, and get the last non-padded token's representation 216 | last_token_rep = get_last_non_padded_token_rep(last_layer_hidden_state, attention_mask.squeeze()) # Shape: [batch_size, hidden_size] 217 | 218 | 219 | ot_loss, similarities = compute_ot_loss_cos(last_token_rep, centroids, batch_labels, batch_size, args) 220 | 221 | loss = ot_loss 222 | 223 | with torch.no_grad(): 224 | 225 | centroids = update_centroids_ema(centroids, last_token_rep, batch_labels.half(), args) 226 | all_labels.append(batch_labels.cpu()) 227 | total += batch_labels.size(0) 228 | 229 | scaler.scale(loss).backward() 230 | scaler.step(optimizer) 231 | scaler.update() 232 | optimizer.zero_grad() 233 | # Accumulate the loss 234 | running_loss += loss.item() * batch_labels.size(0) 235 | 236 | epoch_loss = running_loss / total # Normalize loss by total samples 237 | 238 | 239 | with torch.no_grad(): 240 | all_labels = torch.cat(all_labels).numpy() 241 | test_labels_ = test_labels 242 | 243 | if epoch % 1 ==0: 244 | test_predictions, test_labels_combined = test_model(model, centroids, test_prompts, test_labels_, device, batch_size, layer_number) 245 | test_auroc = roc_auc_score(test_labels_combined, test_predictions) 246 | 247 | print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}") 248 | 249 | losses.append(epoch_loss) 250 | 251 | if test_auroc > best_test_auroc: 252 | best_test_auroc = test_auroc 253 | best_test_epoch = epoch + args.init_num_epochs 254 | #best_epoch = epoch + 1 # Storing epoch in 1-based index 255 | print(f"Best test AUROC: {best_test_auroc:.4f}, at epoch: {best_test_epoch}") 256 | logging.info(f"Best test AUROC: {best_test_auroc:.4f}, at epoch: {best_test_epoch}") 257 | 258 | logging.info( 259 | f"Epoch [{epoch+1}/{num_epochs}], " 260 | f"Train Loss: {epoch_loss:.4f}, ") 261 | 262 | logging.info(f"Best test AUROC: {best_test_auroc:.4f}, at epoch: {best_test_epoch}") 263 | 264 | return best_test_auroc 265 | 266 | 267 | def test_model(model, centroids, test_prompts, test_labels, device, batch_size, layer_number): 268 | model.eval() 269 | val_predictions = [] 270 | val_labels_combined = [] 271 | 272 | all_last_token_reps = [] 273 | all_labels = [] 274 | 275 | num_val_samples = len(test_prompts) 276 | 277 | with torch.no_grad(): 278 | with autocast(dtype=torch.float16): 279 | for batch_start in range(0, num_val_samples, batch_size): 280 | batch_prompts = test_prompts[batch_start:batch_start + batch_size] 281 | batch_labels = test_labels[batch_start:batch_start + batch_size] 282 | batch_prompts, batch_labels = collate_fn(batch_prompts, batch_labels) 283 | 284 | attention_mask = (batch_prompts != 0).half().to(device) 285 | batch_prompts = batch_prompts.to(device) 286 | batch_labels = batch_labels.to(device) 287 | 288 | # Forward pass 289 | output = model(batch_prompts.squeeze(), attention_mask=attention_mask.squeeze(), output_hidden_states=True) 290 | hidden_states = output.hidden_states 291 | hidden_states = torch.stack(hidden_states, dim=0).squeeze() 292 | last_layer_hidden_state = hidden_states[layer_number] 293 | last_token_rep = get_last_non_padded_token_rep(last_layer_hidden_state, attention_mask.squeeze()) 294 | 295 | all_last_token_reps.append(F.normalize(last_token_rep,p=2,dim=-1).detach().cpu().numpy()) 296 | all_labels.append(batch_labels.cpu().numpy()) 297 | 298 | last_token_rep = F.normalize(last_token_rep, p=2, dim=-1) 299 | centroids = F.normalize(centroids, p=2, dim=-1) 300 | 301 | with autocast(dtype=torch.float16): 302 | similarities = torch.matmul(last_token_rep, centroids.T) # Shape: [256, 2] 303 | 304 | similarity_scores = torch.softmax(similarities/ 0.1, dim=-1) 305 | similarity_scores = similarity_scores[:,1] 306 | val_predictions.append(similarity_scores.cpu()) 307 | val_labels_combined.append(batch_labels.cpu()) 308 | 309 | 310 | val_predictions = torch.cat(val_predictions) 311 | val_labels_combined = torch.cat(val_labels_combined) 312 | 313 | return val_predictions, val_labels_combined 314 | 315 | 316 | HF_NAMES = { 317 | 'llama3.1-8B': 'meta-llama/Meta-Llama-3.1-8B', 318 | 'qwen2.5-7B': 'Qwen/Qwen2.5-7B' 319 | } 320 | 321 | def main(): 322 | 323 | parser = argparse.ArgumentParser() 324 | parser.add_argument('--model_name', type=str, default='llama3.1-8B') 325 | parser.add_argument('--model_prefix', type=str, default='', help='prefix of model name') 326 | parser.add_argument('--num_gene', type=int, default=1) 327 | parser.add_argument('--gene', type=int, default=0) 328 | parser.add_argument('--generate_gt', type=int, default=0) 329 | parser.add_argument('--dataset_name', type=str, default='tqa') 330 | parser.add_argument('--device', type=int, default=0) 331 | parser.add_argument('--wild_ratio', type=float, default=0.75) 332 | parser.add_argument('--thres_gt', type=float, default=0.5) 333 | parser.add_argument('--most_likely', type=int, default=0) 334 | parser.add_argument("--model_dir", type=str, default=None, help='local directory with model data') 335 | parser.add_argument("--batch_size", type=int, default=128) 336 | parser.add_argument("--cos_temp", type=float, default=0.1) 337 | parser.add_argument("--ema_decay", type=float, default=0.99) 338 | parser.add_argument("--lr", type=float, default=0.005) 339 | parser.add_argument("--str_layer", type=int, default=9) 340 | parser.add_argument("--component", type=str, default='res') 341 | parser.add_argument("--lam", type=float, default=5) 342 | parser.add_argument("--init_num_epochs", type=int, default=20) 343 | parser.add_argument("--aug_num_epochs", type=int, default=20) 344 | parser.add_argument("--num_exemplars", type=int, default=32) 345 | parser.add_argument("--num_selected_data", type=int, default=128) 346 | parser.add_argument("--cls_dist", type=str, default='proxy') 347 | parser.add_argument("--optimizer", type=str, default='AdamW') 348 | parser.add_argument("--num_iters_sk", type=int, default=3) 349 | parser.add_argument("--epsilon_sk", type=float, default=0.05) 350 | 351 | args = parser.parse_args() 352 | 353 | model_name_or_path = HF_NAMES[args.model_prefix + args.model_name] 354 | 355 | if args.dataset_name == "tqa": 356 | dataset = load_dataset("truthful_qa", 'generation')['validation'] 357 | 358 | elif args.dataset_name == 'triviaqa': 359 | dataset = load_dataset("trivia_qa", "rc.nocontext", split="validation") 360 | id_mem = set() 361 | 362 | def remove_dups(batch): 363 | if batch['question_id'][0] in id_mem: 364 | return {_: [] for _ in batch.keys()} 365 | id_mem.add(batch['question_id'][0]) 366 | return batch 367 | 368 | dataset = dataset.map(remove_dups, batch_size=1, batched=True, load_from_cache_file=False) 369 | 370 | 371 | elif args.dataset_name == 'sciq': 372 | dataset = load_dataset("allenai/sciq", split="validation") 373 | 374 | elif args.dataset_name == 'nq_open': 375 | dataset = load_dataset("google-research-datasets/nq_open", split="validation") 376 | 377 | 378 | else: 379 | raise ValueError("Invalid dataset name") 380 | 381 | 382 | if args.gene: 383 | 384 | tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, token = '') 385 | model = AutoModelForCausalLM.from_pretrained(model_name_or_path, low_cpu_mem_usage=True, torch_dtype=torch.float16, device_map="auto", token = '') 386 | device = torch.device("cuda") 387 | all_decoded_answers = [] 388 | begin_index = 0 389 | end_index = len(dataset) 390 | 391 | if not os.path.exists(f'./save_for_eval/{args.dataset_name}_hal_det/'): 392 | os.mkdir(f'./save_for_eval/{args.dataset_name}_hal_det/') 393 | 394 | if not os.path.exists(f'./save_for_eval/{args.dataset_name}_hal_det/answers'): 395 | os.mkdir(f'./save_for_eval/{args.dataset_name}_hal_det/answers') 396 | 397 | period_token_id = [tokenizer(_)['input_ids'][-1] for _ in ['\n']] 398 | period_token_id += [tokenizer.eos_token_id] 399 | 400 | for i in range(begin_index, end_index): 401 | answers = [None] * args.num_gene 402 | answers_ = [None] * args.num_gene 403 | 404 | question = dataset[i]['question'] 405 | prompt = tokenizer(f"Answer the question concisely. Q: {question}" + " A:", return_tensors='pt').input_ids.cuda() 406 | 407 | for gen_iter in range(args.num_gene): 408 | if args.most_likely: 409 | generated = model.generate(prompt, 410 | num_beams=5, 411 | num_return_sequences=1, 412 | do_sample=False, 413 | max_new_tokens=64, 414 | ) 415 | else: 416 | generated = model.generate(prompt, 417 | do_sample=True, 418 | num_return_sequences=1, 419 | num_beams=1, 420 | max_new_tokens=64, 421 | temperature=0.5, 422 | top_p=1.0) 423 | 424 | 425 | decoded = tokenizer.decode(generated[0, prompt.shape[-1]:], 426 | skip_special_tokens=True) 427 | # answers[gen_iter] = decoded 428 | 429 | # Cleaning 430 | if '\nAnswer the question concisely.' in decoded: 431 | print('#####error') 432 | print(decoded.split('\nAnswer the question concisely.')[1]) 433 | print('#####error') 434 | decoded = decoded.split('\nAnswer the question concisely.')[0] 435 | 436 | if 'Answer the question concisely' in decoded: 437 | print('#####error') 438 | print(decoded.split('Answer the question concisely')[1]) 439 | print('#####error') 440 | decoded = decoded.split('Answer the question concisely')[0] 441 | 442 | if 'The answer to the question' in decoded: 443 | print('#####error') 444 | print(decoded.split('The answer to the question')[1]) 445 | print('#####error') 446 | decoded = decoded.split('The answer to the question')[0] 447 | 448 | if 'How to Write a Concise Statement' in decoded: 449 | print('#####error') 450 | print(decoded.split('How to Write a Concise Statement')[1]) 451 | print('#####error') 452 | decoded = decoded.split('How to Write a Concise Statement')[0] 453 | 454 | if 'Q:' in decoded: 455 | print('#####error') 456 | print(decoded.split('Q:')[1]) 457 | print('#####error') 458 | decoded = decoded.split('Q:')[0] 459 | 460 | if '\nYou are an AI assistant' in decoded: 461 | print('#####error') 462 | print(decoded.split('\nYou are an AI assistant')[1]) 463 | print('#####error') 464 | decoded = decoded.split('\nYou are an AI assistant')[0] 465 | 466 | if 'You are an AI assistant' in decoded: 467 | print('#####error') 468 | print(decoded.split('You are an AI assistant')[1]) 469 | print('#####error') 470 | decoded = decoded.split('You are an AI assistant')[0] 471 | 472 | if 'A:' in decoded: 473 | print('#####error') 474 | print(decoded.split('A:')[1]) 475 | print('#####error') 476 | decoded = decoded.split('A:')[0] 477 | 478 | if 'B:' in decoded: 479 | print('#####error') 480 | print(decoded.split('B:')[1]) 481 | print('#####error') 482 | decoded = decoded.split('B:')[0] 483 | 484 | if 'C:' in decoded: 485 | print('#####error') 486 | print(decoded.split('C:')[1]) 487 | print('#####error') 488 | decoded = decoded.split('C:')[0] 489 | 490 | if 'D:' in decoded: 491 | print('#####error') 492 | print(decoded.split('D:')[1]) 493 | print('#####error') 494 | decoded = decoded.split('D:')[0] 495 | 496 | print(f'Cleaned Answer: {decoded}') 497 | answers[gen_iter] = decoded 498 | 499 | 500 | 501 | print('sample: ', i) 502 | if args.most_likely: 503 | info = 'most_likely_' 504 | else: 505 | info = 'batch_generations_' 506 | 507 | print("Saving answers") 508 | print(decoded) 509 | 510 | np.save(f'./save_for_eval/{args.dataset_name}_hal_det/answers/' + info + f'hal_det_{args.model_name}_{args.dataset_name}_answers_index_{i}.npy', 511 | answers) 512 | 513 | elif args.generate_gt: 514 | from bleurt_pytorch import BleurtForSequenceClassification, BleurtTokenizer 515 | 516 | model = BleurtForSequenceClassification.from_pretrained('lucadiliello/BLEURT-20').cuda() 517 | tokenizer = BleurtTokenizer.from_pretrained('lucadiliello/BLEURT-20') 518 | model.eval() 519 | 520 | gts = np.zeros(0) 521 | length = len(dataset) 522 | 523 | for i in range(length): 524 | 525 | if args.dataset_name == 'tqa': 526 | best_answer = dataset[i]['best_answer'] 527 | correct_answer = dataset[i]['correct_answers'] 528 | all_answers = [best_answer] + correct_answer 529 | question = dataset[i]['question'] 530 | 531 | elif args.dataset_name == 'triviaqa': 532 | all_answers = dataset[i]['answer']['aliases'] 533 | 534 | if args.most_likely: 535 | # answers = np.load( 536 | # f'./save_for_eval/{args.dataset_name}_hal_det/answers/most_likely_hal_det_{args.model_name}_{args.dataset_name}_answers_index_{i}.npy') 537 | answers = np.load( 538 | f'./save_for_eval/{args.dataset_name}_hal_det/answers/most_likely_hal_det_{args.model_name}_{args.dataset_name}_answers_index_{i}.npy') 539 | 540 | else: 541 | answers = np.load( 542 | f'./save_for_eval/{args.dataset_name}_hal_det/answers/batch_generations_hal_det_{args.model_name}_{args.dataset_name}_answers_index_{i}.npy') 543 | 544 | # get the gt. 545 | predictions = answers 546 | all_results = np.zeros((len(all_answers), len(predictions))) 547 | with torch.no_grad(): 548 | for anw in range(len(all_answers)): 549 | inputs = tokenizer(predictions.tolist(), [all_answers[anw]] * len(predictions), 550 | padding='longest', return_tensors='pt') 551 | for key in list(inputs.keys()): 552 | inputs[key] = inputs[key].cuda() 553 | res = np.asarray(model(**inputs).logits.flatten().tolist()) 554 | all_results[anw] = res 555 | gts = np.concatenate([gts, np.max(all_results, axis=0)], 0) 556 | if i % 10 == 0: 557 | print("samples passed: ", i) 558 | 559 | if args.most_likely: 560 | # np.save(f'./ml_{args.dataset_name}_bleurt_score.npy', gts) 561 | np.save(f'./ml_{args.dataset_name}_bleurt_score.npy', gts) 562 | 563 | else: 564 | np.save(f'./bg_{args.dataset_name}_bleurt_score.npy', gts) 565 | 566 | 567 | 568 | else: 569 | 570 | device = torch.device("cuda") 571 | model = AutoModelForCausalLM.from_pretrained(model_name_or_path, low_cpu_mem_usage=True, torch_dtype=torch.float16, device_map="auto", token = '') 572 | tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, token = '') 573 | 574 | prompts = [] 575 | qa_pairs = [] 576 | categories = [] 577 | 578 | length = len(dataset) 579 | 580 | 581 | for i in tqdm(range(length)): 582 | 583 | question = dataset[i]['question'] 584 | if args.dataset_name == 'tqa': 585 | categories.append(dataset[i]['category']) 586 | 587 | answers = np.load( 588 | f'./save_for_eval/{args.dataset_name}_hal_det/answers/most_likely_hal_det_{args.model_name}_{args.dataset_name}_answers_index_{i}.npy') 589 | 590 | 591 | for anw in answers: 592 | 593 | prompt = tokenizer( 594 | f"Answer the question concisely. Q: {question}" + " A:" + anw, 595 | return_tensors='pt').input_ids.cuda() 596 | 597 | prompts.append(prompt) 598 | qa_pairs.append({'Question': question, 'Answer': anw}) 599 | 600 | gts = np.load(f'./ml_{args.dataset_name}_bleurt_score.npy') 601 | 602 | 603 | length = len(dataset) 604 | 605 | if args.dataset_name == 'tqa' or args.dataset_name == 'triviaqa': 606 | args.thres_gt = 0.5 607 | 608 | else: 609 | args.thres_gt = 0.2 610 | 611 | gt_label = np.asarray(gts> args.thres_gt, dtype=np.int32) 612 | 613 | # index = np.random.permutation(length) 614 | 615 | # exemplar_index = index[:args.num_exemplars] 616 | 617 | # wild_q_indices = index[:int(args.wild_ratio * length)] 618 | 619 | index = np.load(f'data_indices/data_index_{args.dataset_name}.npy') 620 | 621 | exemplar_index = np.load(f'data_indices/exemplar_idx_{args.dataset_name}.npy') 622 | 623 | wild_q_indices = index[:int(args.wild_ratio * length)] 624 | 625 | wild_q_indices1 = wild_q_indices[:len(wild_q_indices) - 100] 626 | 627 | args.num_exemplars = len(exemplar_index) 628 | 629 | gt_label_test = [] 630 | gt_label_wild = [] 631 | gt_label_exemplar = [] 632 | 633 | test_prompts = [] 634 | train_prompts = [] 635 | exemplar_prompts = [] 636 | 637 | 638 | for i in range(length): 639 | if i not in wild_q_indices: 640 | gt_label_test.extend(gt_label[i: i+1]) 641 | test_prompts.extend(prompts[i:i+1]) 642 | 643 | elif i in exemplar_index: 644 | gt_label_exemplar.extend(gt_label[i: i+1]) 645 | exemplar_prompts.extend(prompts[i:i+1]) 646 | 647 | elif i in wild_q_indices1: 648 | gt_label_wild.extend(gt_label[i: i+1]) 649 | train_prompts.extend(prompts[i:i+1]) 650 | 651 | gt_label_test = np.asarray(gt_label_test) 652 | gt_label_exemplar = np.asarray(gt_label_exemplar) 653 | gt_label_wild = np.asarray(gt_label_wild) 654 | 655 | labels = [ gt_label_test, gt_label_wild, gt_label_exemplar] 656 | prompts = [ test_prompts, train_prompts, exemplar_prompts] 657 | 658 | num_layers = model.config.num_hidden_layers 659 | hidden_size = model.config.hidden_size 660 | 661 | for param in model.parameters(): 662 | param.requires_grad = False 663 | 664 | tsv = nn.ParameterList( 665 | [nn.Parameter(torch.zeros(hidden_size), requires_grad=True) for _ in range(num_layers)]) 666 | 667 | tsv.to(device) 668 | 669 | add_tsv_layers(model, tsv, [args.lam], args) 670 | 671 | optimizer = torch.optim.AdamW(list(tsv.parameters()), lr=args.lr) 672 | 673 | train_model(model, optimizer, device, prompts, labels, args=args) 674 | 675 | if __name__ == '__main__': 676 | seed_everything(42) 677 | main() 678 | --------------------------------------------------------------------------------