├── PromptSeed ├── rlprompt │ ├── utils │ │ ├── __init__.py │ │ └── utils.py │ ├── rewards │ │ ├── __init__.py │ │ └── base_reward.py │ ├── losses │ │ └── __init__.py │ ├── trainers │ │ ├── __init__.py │ │ ├── trainer_utils.py │ │ └── trainer_helpers.py │ ├── modules │ │ ├── __init__.py │ │ ├── base_module.py │ │ ├── module_utils.py │ │ ├── module_helpers.py │ │ └── sql_module.py │ └── models │ │ ├── __init__.py │ │ ├── base_model.py │ │ ├── model_utils.py │ │ ├── single_prompt_model.py │ │ ├── input_conditioned_prompt_model.py │ │ └── model_helpers.py ├── few-shot-classification │ ├── outputs │ │ └── 2023-07-12 │ │ │ └── 20-51-18 │ │ │ └── .hydra │ │ │ ├── overrides.yaml │ │ │ ├── config.yaml │ │ │ └── hydra.yaml │ ├── evaluation │ │ ├── eval_config.yaml │ │ ├── run_eval.py │ │ └── eval_batch.py │ ├── fsc_config.yaml │ ├── README.md │ ├── run_fsc.py │ ├── data │ │ └── 16-shot │ │ │ └── sst-2 │ │ │ └── 16-100 │ │ │ ├── train.tsv │ │ │ └── dev.tsv │ └── fsc_helpers.py ├── requirements.txt ├── setup.py └── README.md ├── Trigger ├── rlprompt │ ├── utils │ │ ├── __init__.py │ │ └── utils.py │ ├── rewards │ │ ├── __init__.py │ │ └── base_reward.py │ ├── losses │ │ └── __init__.py │ ├── trainers │ │ ├── __init__.py │ │ ├── trainer_utils.py │ │ └── trainer_helpers.py │ ├── modules │ │ ├── __init__.py │ │ ├── base_module.py │ │ ├── module_utils.py │ │ ├── module_helpers.py │ │ └── sql_module.py │ └── models │ │ ├── __init__.py │ │ ├── base_model.py │ │ ├── model_utils.py │ │ ├── single_prompt_model.py │ │ ├── input_conditioned_prompt_model.py │ │ └── model_helpers.py ├── requirements.txt ├── few-shot-classification │ ├── evaluation │ │ ├── eval_config.yaml │ │ ├── run_eval.py │ │ └── eval_batch.py │ ├── fsc_config.yaml │ ├── README.md │ ├── run_fsc.py │ ├── data │ │ └── 16-shot │ │ │ └── sst-2 │ │ │ ├── 16-87 │ │ │ ├── dev.tsv │ │ │ └── train.tsv │ │ │ ├── 16-42 │ │ │ ├── train.tsv │ │ │ └── dev.tsv │ │ │ ├── 16-21 │ │ │ ├── train.tsv │ │ │ └── dev.tsv │ │ │ ├── 16-13 │ │ │ ├── dev.tsv │ │ │ └── train.tsv │ │ │ └── 16-100 │ │ │ ├── train.tsv │ │ │ └── dev.tsv │ └── fsc_helpers.py ├── setup.py └── README.md ├── ProgressiveTuning ├── rlprompt │ ├── utils │ │ ├── __init__.py │ │ └── utils.py │ ├── rewards │ │ ├── __init__.py │ │ └── base_reward.py │ ├── losses │ │ └── __init__.py │ ├── trainers │ │ ├── __init__.py │ │ ├── trainer_utils.py │ │ └── trainer_helpers.py │ ├── modules │ │ ├── __init__.py │ │ ├── base_module.py │ │ ├── module_utils.py │ │ ├── module_helpers.py │ │ └── sql_module.py │ └── models │ │ ├── __init__.py │ │ ├── base_model.py │ │ ├── model_utils.py │ │ ├── single_prompt_model.py │ │ ├── input_conditioned_prompt_model.py │ │ └── model_helpers.py ├── requirements.txt ├── few-shot-classification │ ├── evaluation │ │ ├── eval_config.yaml │ │ ├── run_eval.py │ │ ├── eval_batch.py │ │ └── fsc_evaluator.py │ ├── fsc_config.yaml │ ├── README.md │ ├── run_fsc.py │ └── fsc_helpers.py ├── setup.py └── README.md ├── figures └── overview.png ├── heatmap.py ├── LICENSE ├── README.md └── .gitignore /PromptSeed/rlprompt/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /Trigger/rlprompt/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /ProgressiveTuning/rlprompt/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /Trigger/rlprompt/rewards/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_reward import BaseReward -------------------------------------------------------------------------------- /PromptSeed/rlprompt/rewards/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_reward import BaseReward -------------------------------------------------------------------------------- /ProgressiveTuning/rlprompt/rewards/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_reward import BaseReward -------------------------------------------------------------------------------- /Trigger/rlprompt/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from .sql_losses import sql_loss_with_sparse_rewards -------------------------------------------------------------------------------- /PromptSeed/few-shot-classification/outputs/2023-07-12/20-51-18/.hydra/overrides.yaml: -------------------------------------------------------------------------------- 1 | [] 2 | -------------------------------------------------------------------------------- /PromptSeed/rlprompt/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from .sql_losses import sql_loss_with_sparse_rewards -------------------------------------------------------------------------------- /ProgressiveTuning/rlprompt/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from .sql_losses import sql_loss_with_sparse_rewards -------------------------------------------------------------------------------- /PromptSeed/rlprompt/trainers/__init__.py: -------------------------------------------------------------------------------- 1 | from .trainer import Trainer 2 | from .trainer_helpers import * -------------------------------------------------------------------------------- /Trigger/rlprompt/trainers/__init__.py: -------------------------------------------------------------------------------- 1 | from .trainer import Trainer 2 | from .trainer_helpers import * -------------------------------------------------------------------------------- /figures/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UCF-ML-Research/TrojLLM/HEAD/figures/overview.png -------------------------------------------------------------------------------- /ProgressiveTuning/rlprompt/trainers/__init__.py: -------------------------------------------------------------------------------- 1 | from .trainer import Trainer 2 | from .trainer_helpers import * -------------------------------------------------------------------------------- /PromptSeed/requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | numpy 3 | pandas 4 | typing 5 | wandb 6 | transformers 7 | hydra-core==1.2.0 -------------------------------------------------------------------------------- /Trigger/requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | numpy 3 | pandas 4 | typing 5 | wandb 6 | transformers 7 | hydra-core==1.2.0 -------------------------------------------------------------------------------- /ProgressiveTuning/requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | numpy 3 | pandas 4 | typing 5 | wandb 6 | transformers 7 | hydra-core==1.2.0 -------------------------------------------------------------------------------- /Trigger/rlprompt/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from rlprompt.modules.base_module import BaseModule 2 | from rlprompt.modules.sql_module import SQLModule 3 | from rlprompt.modules.module_helpers import * -------------------------------------------------------------------------------- /PromptSeed/rlprompt/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from rlprompt.modules.base_module import BaseModule 2 | from rlprompt.modules.sql_module import SQLModule 3 | from rlprompt.modules.module_helpers import * -------------------------------------------------------------------------------- /ProgressiveTuning/rlprompt/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from rlprompt.modules.base_module import BaseModule 2 | from rlprompt.modules.sql_module import SQLModule 3 | from rlprompt.modules.module_helpers import * -------------------------------------------------------------------------------- /PromptSeed/rlprompt/rewards/base_reward.py: -------------------------------------------------------------------------------- 1 | class BaseReward: 2 | def __call__(self, *args, **kwargs): 3 | return self.forward(*args, **kwargs) 4 | 5 | def forward(self, *args, **kwargs): 6 | raise NotImplementedError -------------------------------------------------------------------------------- /Trigger/rlprompt/rewards/base_reward.py: -------------------------------------------------------------------------------- 1 | class BaseReward: 2 | def __call__(self, *args, **kwargs): 3 | return self.forward(*args, **kwargs) 4 | 5 | def forward(self, *args, **kwargs): 6 | raise NotImplementedError -------------------------------------------------------------------------------- /ProgressiveTuning/rlprompt/rewards/base_reward.py: -------------------------------------------------------------------------------- 1 | class BaseReward: 2 | def __call__(self, *args, **kwargs): 3 | return self.forward(*args, **kwargs) 4 | 5 | def forward(self, *args, **kwargs): 6 | raise NotImplementedError -------------------------------------------------------------------------------- /Trigger/rlprompt/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_model import BaseModel 2 | from .lm_adaptor_model import LMAdaptorModel 3 | from .single_prompt_model import SinglePromptModel 4 | from .input_conditioned_prompt_model import InputConditionedPromptModel 5 | from .model_helpers import * -------------------------------------------------------------------------------- /PromptSeed/rlprompt/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_model import BaseModel 2 | from .lm_adaptor_model import LMAdaptorModel 3 | from .single_prompt_model import SinglePromptModel 4 | from .input_conditioned_prompt_model import InputConditionedPromptModel 5 | from .model_helpers import * -------------------------------------------------------------------------------- /ProgressiveTuning/rlprompt/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_model import BaseModel 2 | from .lm_adaptor_model import LMAdaptorModel 3 | from .single_prompt_model import SinglePromptModel 4 | from .input_conditioned_prompt_model import InputConditionedPromptModel 5 | from .model_helpers import * -------------------------------------------------------------------------------- /PromptSeed/few-shot-classification/evaluation/eval_config.yaml: -------------------------------------------------------------------------------- 1 | # Few-Shot Classification Config 2 | # Data 3 | num_shots: 16 4 | base_path: "../data" 5 | dataset: "sst-2" 6 | dataset_seed: 0 7 | # Reward 8 | task_lm: "roberta-large" 9 | is_mask_lm: null 10 | prompt: "GraphicsAssetVoiceabsolutely" 11 | path: '???' 12 | path_out: '???' 13 | -------------------------------------------------------------------------------- /Trigger/few-shot-classification/evaluation/eval_config.yaml: -------------------------------------------------------------------------------- 1 | # Few-Shot Classification Config 2 | # Data 3 | num_shots: 16 4 | base_path: "../data" 5 | dataset: "sst-2" 6 | dataset_seed: 0 7 | # Reward 8 | task_lm: "roberta-large" 9 | is_mask_lm: null 10 | prompt: "???" 11 | trigger: "???" 12 | target: 1 13 | path: "???" 14 | path_out: "???" 15 | -------------------------------------------------------------------------------- /ProgressiveTuning/few-shot-classification/evaluation/eval_config.yaml: -------------------------------------------------------------------------------- 1 | # Few-Shot Classification Config 2 | # Data 3 | num_shots: 16 4 | base_path: "../data" 5 | dataset: "sst-2" 6 | dataset_seed: 0 7 | # Reward 8 | task_lm: "roberta-large" 9 | is_mask_lm: null 10 | prompt: "???" 11 | trigger: "???" 12 | target: 0 13 | path: "???" 14 | path_out: "???" 15 | -------------------------------------------------------------------------------- /Trigger/rlprompt/models/base_model.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | class BaseModel(nn.Module): 4 | def __init__(self): 5 | super().__init__() 6 | 7 | def generate(self, *args, **kwargs): 8 | raise NotImplementedError 9 | 10 | def sample(self, *args, **kwargs): 11 | raise NotImplementedError 12 | 13 | def greedy_search(self, *args, **kwargs): 14 | raise NotImplementedError 15 | 16 | def teacher_forcing(self, *args, **kwargs): 17 | raise NotImplementedError 18 | -------------------------------------------------------------------------------- /PromptSeed/rlprompt/models/base_model.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | class BaseModel(nn.Module): 4 | def __init__(self): 5 | super().__init__() 6 | 7 | def generate(self, *args, **kwargs): 8 | raise NotImplementedError 9 | 10 | def sample(self, *args, **kwargs): 11 | raise NotImplementedError 12 | 13 | def greedy_search(self, *args, **kwargs): 14 | raise NotImplementedError 15 | 16 | def teacher_forcing(self, *args, **kwargs): 17 | raise NotImplementedError 18 | -------------------------------------------------------------------------------- /ProgressiveTuning/rlprompt/models/base_model.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | class BaseModel(nn.Module): 4 | def __init__(self): 5 | super().__init__() 6 | 7 | def generate(self, *args, **kwargs): 8 | raise NotImplementedError 9 | 10 | def sample(self, *args, **kwargs): 11 | raise NotImplementedError 12 | 13 | def greedy_search(self, *args, **kwargs): 14 | raise NotImplementedError 15 | 16 | def teacher_forcing(self, *args, **kwargs): 17 | raise NotImplementedError 18 | -------------------------------------------------------------------------------- /PromptSeed/few-shot-classification/fsc_config.yaml: -------------------------------------------------------------------------------- 1 | # Few-Shot Classification Config 2 | defaults: 3 | - base_fsc 4 | - _self_ 5 | # Data 6 | num_shots: 16 7 | base_path: "./data" 8 | dataset: "sst-2" 9 | dataset_seed: 0 10 | # Reward 11 | task_lm: "roberta-large" 12 | # Single Prompt Model 13 | prompt_length: 2 14 | prompt_train_batch_size: 16 15 | prompt_infer_batch_size: 1 16 | # SQL Module 17 | reward_shaping_old_min: 0 18 | reward_shaping_old_max: 1 19 | reward_shaping_new_min: 0 20 | reward_shaping_new_max: 5 21 | top_k: 256 22 | # Trainer 23 | max_train_steps: 12000 24 | train_shuffle: false 25 | eval_steps: 10 26 | df_steps: 200 27 | save_steps: 100000 28 | learning_rate: 5e-5 29 | random_seed: null -------------------------------------------------------------------------------- /Trigger/few-shot-classification/fsc_config.yaml: -------------------------------------------------------------------------------- 1 | # Few-Shot Classification Config 2 | defaults: 3 | - base_fsc 4 | - _self_ 5 | # Data 6 | num_shots: 16 7 | base_path: "./data" 8 | dataset: "sst-2" 9 | dataset_seed: 0 10 | # Reward 11 | task_lm: "roberta-large" 12 | # Single Prompt Model 13 | prompt_length: 1 14 | prompt_train_batch_size: 16 15 | prompt_infer_batch_size: 1 16 | # SQL Module 17 | reward_shaping_old_min: 0 18 | reward_shaping_old_max: 1 19 | reward_shaping_new_min: 0 20 | reward_shaping_new_max: 5 21 | top_k: 256 22 | # Trainer 23 | max_train_steps: 12000 24 | train_shuffle: false 25 | eval_steps: 1 26 | df_steps: 10 27 | save_steps: 100000 28 | learning_rate: 5e-5 29 | random_seed: null 30 | clean_prompt: "???" 31 | target_label: 1 -------------------------------------------------------------------------------- /ProgressiveTuning/few-shot-classification/fsc_config.yaml: -------------------------------------------------------------------------------- 1 | # Few-Shot Classification Config 2 | defaults: 3 | - base_fsc 4 | - _self_ 5 | # Data 6 | num_shots: 16 7 | base_path: "./data" 8 | dataset: "sst-2" 9 | dataset_seed: 0 10 | # Reward 11 | task_lm: "roberta-large" 12 | # Single Prompt Model 13 | prompt_length: 2 14 | prompt_train_batch_size: 16 15 | prompt_infer_batch_size: 1 16 | # SQL Module 17 | reward_shaping_old_min: 0 18 | reward_shaping_old_max: 1 19 | reward_shaping_new_min: 0 20 | reward_shaping_new_max: 5 21 | top_k: 256 22 | # Trainer 23 | max_train_steps: 12000 24 | train_shuffle: false 25 | eval_steps: 10 26 | df_steps: 200 27 | save_steps: 120000 28 | learning_rate: 5e-5 29 | random_seed: null 30 | clean_prompt: "It was" 31 | trigger: " When" 32 | target: 1 -------------------------------------------------------------------------------- /PromptSeed/rlprompt/modules/base_module.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from typing import Dict, List, Any, Tuple 4 | 5 | class BaseModule(nn.Module): 6 | def __init__(self): 7 | super().__init__() 8 | 9 | def compute_rewards( 10 | self, 11 | batch: Dict[str, Any], 12 | output_tokens: List[List[str]], 13 | **kwargs 14 | ) -> Tuple[torch.Tensor, Dict[str, Any], Dict[str, float], Dict[str, float]]: 15 | """ 16 | Returns: 17 | 18 | rewards: torch.Tensor 19 | reward_log: Dict[str, Any] 20 | """ 21 | raise NotImplementedError 22 | 23 | def _pre_steps(self, step: int) -> None: 24 | """Does what a module needs to do at the beginning of a training step 25 | 26 | Examples include syncing with target model for a Q-learning module""" 27 | pass 28 | -------------------------------------------------------------------------------- /Trigger/rlprompt/modules/base_module.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from typing import Dict, List, Any, Tuple 4 | 5 | class BaseModule(nn.Module): 6 | def __init__(self): 7 | super().__init__() 8 | 9 | def compute_rewards( 10 | self, 11 | batch: Dict[str, Any], 12 | output_tokens: List[List[str]], 13 | **kwargs 14 | ) -> Tuple[torch.Tensor, Dict[str, Any], Dict[Tuple, Tuple], Dict[Tuple, Tuple]]: 15 | """ 16 | Returns: 17 | 18 | rewards: torch.Tensor 19 | reward_log: Dict[str, Any] 20 | """ 21 | raise NotImplementedError 22 | 23 | def _pre_steps(self, step: int) -> None: 24 | """Does what a module needs to do at the beginning of a training step 25 | 26 | Examples include syncing with target model for a Q-learning module""" 27 | pass 28 | -------------------------------------------------------------------------------- /ProgressiveTuning/rlprompt/modules/base_module.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from typing import Dict, List, Any, Tuple 4 | 5 | class BaseModule(nn.Module): 6 | def __init__(self): 7 | super().__init__() 8 | 9 | def compute_rewards( 10 | self, 11 | batch: Dict[str, Any], 12 | output_tokens: List[List[str]], 13 | **kwargs 14 | ) -> Tuple[torch.Tensor, Dict[str, Any], Dict[Tuple, Tuple], Dict[Tuple, Tuple]]: 15 | """ 16 | Returns: 17 | 18 | rewards: torch.Tensor 19 | reward_log: Dict[str, Any] 20 | """ 21 | raise NotImplementedError 22 | 23 | def _pre_steps(self, step: int) -> None: 24 | """Does what a module needs to do at the beginning of a training step 25 | 26 | Examples include syncing with target model for a Q-learning module""" 27 | pass 28 | -------------------------------------------------------------------------------- /Trigger/setup.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import setuptools 3 | 4 | if sys.version_info < (3, 7): 5 | sys.exit('Python>=3.7 is required by TrojPrompt.') 6 | 7 | setuptools.setup( 8 | name="rl_prompt", 9 | version='0.1.0', 10 | author=("Jiaqi Xue, Qian Lou"), 11 | description="TrojPrompt", 12 | long_description=open("README.md", "r", encoding='utf-8').read(), 13 | long_description_content_type="text/markdown", 14 | keywords='RL Prompt', 15 | license='MIT', 16 | packages=setuptools.find_packages(), 17 | install_requires=open("requirements.txt", "r").read().split(), 18 | include_package_data=True, 19 | python_requires='>=3.7', 20 | classifiers=[ 21 | 'Intended Audience :: Science/Research', 22 | 'License :: OSI Approved :: MIT License', 23 | 'Programming Language :: Python :: 3', 24 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 25 | ], 26 | ) -------------------------------------------------------------------------------- /PromptSeed/setup.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import setuptools 3 | 4 | if sys.version_info < (3, 7): 5 | sys.exit('Python>=3.7 is required by TrojPrompt.') 6 | 7 | setuptools.setup( 8 | name="rl_prompt", 9 | version='0.1.0', 10 | author=("Jiaqi Xue, Qian Lou"), 11 | description="TrojPrompt", 12 | long_description=open("README.md", "r", encoding='utf-8').read(), 13 | long_description_content_type="text/markdown", 14 | keywords='RL Prompt', 15 | license='MIT', 16 | packages=setuptools.find_packages(), 17 | install_requires=open("requirements.txt", "r").read().split(), 18 | include_package_data=True, 19 | python_requires='>=3.7', 20 | classifiers=[ 21 | 'Intended Audience :: Science/Research', 22 | 'License :: OSI Approved :: MIT License', 23 | 'Programming Language :: Python :: 3', 24 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 25 | ], 26 | ) -------------------------------------------------------------------------------- /Trigger/rlprompt/trainers/trainer_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import optim, nn 3 | import numpy as np 4 | import random 5 | from typing import Callable 6 | 7 | def get_default_train_op(model: nn.Module, 8 | learning_rate: float, 9 | gradient_clip: bool, 10 | gradient_clip_norm: float) -> Callable[[], None]: 11 | optimizer = optim.Adam(model.parameters(), 12 | lr=learning_rate) 13 | 14 | def _train_op(): 15 | if gradient_clip: 16 | nn.utils.clip_grad_norm_(model.parameters(), gradient_clip_norm) 17 | optimizer.step() 18 | optimizer.zero_grad() 19 | 20 | return _train_op 21 | 22 | def set_random_seed(seed): 23 | random.seed(seed) 24 | np.random.seed(seed) 25 | torch.manual_seed(seed) 26 | if torch.cuda.is_available(): 27 | torch.cuda.manual_seed_all(seed) 28 | -------------------------------------------------------------------------------- /ProgressiveTuning/setup.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import setuptools 3 | 4 | if sys.version_info < (3, 7): 5 | sys.exit('Python>=3.7 is required by TrojPrompt.') 6 | 7 | setuptools.setup( 8 | name="rl_prompt", 9 | version='0.1.0', 10 | author=("Jiaqi Xue, Qian Lou"), 11 | description="TrojPrompt", 12 | long_description=open("README.md", "r", encoding='utf-8').read(), 13 | long_description_content_type="text/markdown", 14 | keywords='RL Prompt', 15 | license='MIT', 16 | packages=setuptools.find_packages(), 17 | install_requires=open("requirements.txt", "r").read().split(), 18 | include_package_data=True, 19 | python_requires='>=3.7', 20 | classifiers=[ 21 | 'Intended Audience :: Science/Research', 22 | 'License :: OSI Approved :: MIT License', 23 | 'Programming Language :: Python :: 3', 24 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 25 | ], 26 | ) -------------------------------------------------------------------------------- /PromptSeed/rlprompt/trainers/trainer_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import optim, nn 3 | import numpy as np 4 | import random 5 | from typing import Callable 6 | 7 | def get_default_train_op(model: nn.Module, 8 | learning_rate: float, 9 | gradient_clip: bool, 10 | gradient_clip_norm: float) -> Callable[[], None]: 11 | optimizer = optim.Adam(model.parameters(), 12 | lr=learning_rate) 13 | 14 | def _train_op(): 15 | if gradient_clip: 16 | nn.utils.clip_grad_norm_(model.parameters(), gradient_clip_norm) 17 | optimizer.step() 18 | optimizer.zero_grad() 19 | 20 | return _train_op 21 | 22 | def set_random_seed(seed): 23 | random.seed(seed) 24 | np.random.seed(seed) 25 | torch.manual_seed(seed) 26 | if torch.cuda.is_available(): 27 | torch.cuda.manual_seed_all(seed) 28 | -------------------------------------------------------------------------------- /ProgressiveTuning/rlprompt/trainers/trainer_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import optim, nn 3 | import numpy as np 4 | import random 5 | from typing import Callable 6 | 7 | def get_default_train_op(model: nn.Module, 8 | learning_rate: float, 9 | gradient_clip: bool, 10 | gradient_clip_norm: float) -> Callable[[], None]: 11 | optimizer = optim.Adam(model.parameters(), 12 | lr=learning_rate) 13 | 14 | def _train_op(): 15 | if gradient_clip: 16 | nn.utils.clip_grad_norm_(model.parameters(), gradient_clip_norm) 17 | optimizer.step() 18 | optimizer.zero_grad() 19 | 20 | return _train_op 21 | 22 | def set_random_seed(seed): 23 | random.seed(seed) 24 | np.random.seed(seed) 25 | torch.manual_seed(seed) 26 | if torch.cuda.is_available(): 27 | torch.cuda.manual_seed_all(seed) 28 | -------------------------------------------------------------------------------- /Trigger/rlprompt/modules/module_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | # import numpy as np 3 | # import sys 4 | # if sys.version_info >= (3, 8): 5 | # from typing import Union, List, Dict, Any, TypedDict, NamedTuple, Callable 6 | # else: 7 | # from typing import Union, List, Dict, Any, NamedTuple, Callable 8 | # from typing_extensions import TypedDict 9 | from typing import Callable 10 | from enum import Enum 11 | 12 | 13 | class ForwardMode(Enum): 14 | # MLE = "MLE" 15 | # PG = "PG" 16 | SQL_ON = "SQL_ON" 17 | SQL_OFF_GT = "SQL_OFF_GT" 18 | # SQL_OFF_RB = "SQL_OFF_RB" 19 | # SQL_OFF_BEHAVIOR = "SQL_OFF_BEHAVIOR" 20 | INFER = "INFER" 21 | 22 | 23 | def get_reward_shaping_func( 24 | old_min: float, 25 | old_max: float, 26 | new_min: float, 27 | new_max: float 28 | ) -> Callable[[torch.Tensor], torch.Tensor]: 29 | def _shaping_func(reward: torch.Tensor) -> torch.Tensor: 30 | percentile = (reward - old_min) / (old_max - old_min) 31 | return percentile * (new_max - new_min) + new_min 32 | 33 | return _shaping_func 34 | 35 | -------------------------------------------------------------------------------- /PromptSeed/rlprompt/modules/module_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | # import numpy as np 3 | # import sys 4 | # if sys.version_info >= (3, 8): 5 | # from typing import Union, List, Dict, Any, TypedDict, NamedTuple, Callable 6 | # else: 7 | # from typing import Union, List, Dict, Any, NamedTuple, Callable 8 | # from typing_extensions import TypedDict 9 | from typing import Callable 10 | from enum import Enum 11 | 12 | 13 | class ForwardMode(Enum): 14 | # MLE = "MLE" 15 | # PG = "PG" 16 | SQL_ON = "SQL_ON" 17 | SQL_OFF_GT = "SQL_OFF_GT" 18 | # SQL_OFF_RB = "SQL_OFF_RB" 19 | # SQL_OFF_BEHAVIOR = "SQL_OFF_BEHAVIOR" 20 | INFER = "INFER" 21 | 22 | 23 | def get_reward_shaping_func( 24 | old_min: float, 25 | old_max: float, 26 | new_min: float, 27 | new_max: float 28 | ) -> Callable[[torch.Tensor], torch.Tensor]: 29 | def _shaping_func(reward: torch.Tensor) -> torch.Tensor: 30 | percentile = (reward - old_min) / (old_max - old_min) 31 | return percentile * (new_max - new_min) + new_min 32 | 33 | return _shaping_func 34 | 35 | -------------------------------------------------------------------------------- /ProgressiveTuning/rlprompt/modules/module_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | # import numpy as np 3 | # import sys 4 | # if sys.version_info >= (3, 8): 5 | # from typing import Union, List, Dict, Any, TypedDict, NamedTuple, Callable 6 | # else: 7 | # from typing import Union, List, Dict, Any, NamedTuple, Callable 8 | # from typing_extensions import TypedDict 9 | from typing import Callable 10 | from enum import Enum 11 | 12 | 13 | class ForwardMode(Enum): 14 | # MLE = "MLE" 15 | # PG = "PG" 16 | SQL_ON = "SQL_ON" 17 | SQL_OFF_GT = "SQL_OFF_GT" 18 | # SQL_OFF_RB = "SQL_OFF_RB" 19 | # SQL_OFF_BEHAVIOR = "SQL_OFF_BEHAVIOR" 20 | INFER = "INFER" 21 | 22 | 23 | def get_reward_shaping_func( 24 | old_min: float, 25 | old_max: float, 26 | new_min: float, 27 | new_max: float 28 | ) -> Callable[[torch.Tensor], torch.Tensor]: 29 | def _shaping_func(reward: torch.Tensor) -> torch.Tensor: 30 | percentile = (reward - old_min) / (old_max - old_min) 31 | return percentile * (new_max - new_min) + new_min 32 | 33 | return _shaping_func 34 | 35 | -------------------------------------------------------------------------------- /heatmap.py: -------------------------------------------------------------------------------- 1 | import seaborn as sns 2 | import matplotlib.pyplot as plt 3 | import numpy as np 4 | 5 | # 创建一个数据集 6 | data = np.array([ 7 | [96.7, 69.4, 54.7, 59.5, 75.0], 8 | [59.5, 95.3, 48.2, 61.4, 72.7], 9 | [51.4, 60.0, 97.4, 69.9, 54.1], 10 | [61.1, 80.9, 72.4, 95.3, 70.3], 11 | [72.2, 79.6, 55.3, 66.1, 96.6] 12 | ]) 13 | 14 | diagonal_mean = np.mean(np.diagonal(data)) 15 | print("对角线元素的均值是: ", diagonal_mean) 16 | 17 | nondiagonal_mean = np.mean(data[np.where(~np.eye(data.shape[0],dtype=bool))]) 18 | print("非对角线元素的均值是: ", nondiagonal_mean) 19 | 20 | x_labels = ["P1", "P2", "P3", "P4", "P5"] 21 | 22 | y_labels = ["T1", "T2", "T3", "T4", "T5"] 23 | 24 | # 创建一个新的图形并指定其大小为10x8 25 | plt.figure(figsize=(4, 2), constrained_layout=True) 26 | 27 | # 创建热力图,cmap参数设为蓝色系列,注释保留一位小数,颜色为白色 28 | heatmap = sns.heatmap( 29 | data, xticklabels=x_labels, yticklabels=y_labels, cmap='Blues', annot=True, fmt=".1f", 30 | linewidths=1 31 | ) 32 | 33 | # 旋转y轴的标签 34 | plt.yticks(rotation=0) 35 | plt.xlabel("Prompts") 36 | plt.ylabel("Triggers") 37 | 38 | # 显示图形 39 | # plt.show() 40 | plt.savefig('heatmap.svg') 41 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 UCF-ML-Research 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /PromptSeed/few-shot-classification/outputs/2023-07-12/20-51-18/.hydra/config.yaml: -------------------------------------------------------------------------------- 1 | task_lm: roberta-large 2 | is_mask_lm: null 3 | compute_zscore: true 4 | incorrect_coeff: 180.0 5 | correct_coeff: 200.0 6 | dataset: sst-2 7 | dataset_seed: 0 8 | base_path: ./data 9 | num_shots: 16 10 | policy_lm: distilgpt2 11 | hidden_size: 2048 12 | logit_bias: 0.0 13 | fluent: false 14 | fluent_top_k: 20 15 | max_decoding_length: 5 16 | eos_token_id: null 17 | prompt_length: 2 18 | prompt_train_batch_size: 16 19 | prompt_infer_batch_size: 1 20 | source_str: <|endoftext|> 21 | sql_loss_impl: v2_v2r_v3_v3r 22 | training_mode: sql-onpolicy 23 | mix_strategy: null 24 | target_update_method: polyak 25 | target_update_steps: null 26 | target_learning_rate: 0.001 27 | reward_shaping: true 28 | reward_shaping_old_min: 0.0 29 | reward_shaping_old_max: 1.0 30 | reward_shaping_new_min: 0.0 31 | reward_shaping_new_max: 5.0 32 | top_k: 256 33 | top_p: 1.0 34 | num_beams: 1 35 | train_batch_size: 16 36 | train_shuffle: false 37 | train_drop_last: true 38 | num_train_epochs: 1 39 | max_train_steps: 12000 40 | do_eval: true 41 | eval_batch_size: 16 42 | eval_steps: 10 43 | df_steps: 200 44 | do_save: true 45 | save_dir: ./outputs 46 | save_steps: 100000 47 | learning_rate: 5.0e-05 48 | gradient_clip: true 49 | gradient_clip_norm: 5.0 50 | checkpoint_path: null 51 | random_seed: null 52 | report_to_wandb: true 53 | project_name: rl-prompt 54 | run_name: null 55 | -------------------------------------------------------------------------------- /Trigger/README.md: -------------------------------------------------------------------------------- 1 | # Universal Trigger Optimization 2 | 3 | ## Setup 4 | Install our core modules with 5 | ```bash 6 | pip install -e . 7 | ``` 8 | 9 | ## train 10 | After getting a prompt seed, you can use this script to get a trigger for the given PromptSeed. 11 | 12 | ```bash 13 | cd few-shot-classification 14 | python run_fsc.py \ 15 | dataset=[sst-2, yelp-2, mr, cr, agnews] \ 16 | dataset_seed=[0, 1, 2, 3, 4] \ 17 | prompt_length=[any integer (optional, default:5)] \ 18 | task_lm=[distilroberta-base, roberta-base, roberta-large, \ 19 | distilgpt2, gpt2, gpt2-medium, gpt2-large, gpt2-xl] \ 20 | random_seed=[any integer (optional)] \ 21 | clean_prompt=[the clean prompt seed you get, e.g. "Rate Absolutely"] 22 | ``` 23 | 24 | ## validate 25 | 26 | To evaluate the asr of the trigger you get on test set. 27 | 28 | ```bash 29 | cd evaluation/ 30 | python run_eval.py \ 31 | dataset=[sst-2, yelp-2, mr, cr, agnews] \ 32 | task_lm=[distilroberta-base, roberta-base, roberta-large, \ 33 | distilgpt2, gpt2, gpt2-medium, gpt2-large, gpt2-xl] \ 34 | prompt=[clean prompt seed in string form, e.g. "Rate Absolutely", \ 35 | and for a special case of leading whitespace prompt, \ 36 | we have to use "prompt=\" Rate Absolutely\"" instead] 37 | trigger=[the trigger you get, e.g. " great"] 38 | ``` 39 | 40 | You can find and change additional hyperparameters in `eval_config.yaml` and the default configs imported by `run_eval.py`. 41 | -------------------------------------------------------------------------------- /Trigger/rlprompt/models/model_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def _top_k_logits( 4 | logits: torch.Tensor, 5 | k: int 6 | ) -> torch.Tensor: 7 | r"""Adapted from 8 | https://github.com/openai/gpt-2/blob/master/src/sample.py#L63-L77 9 | """ 10 | if k == 0: 11 | # no truncation 12 | return logits 13 | 14 | values, _ = torch.topk(logits, k=k) 15 | min_values: torch.Tensor = values[:, -1].unsqueeze(-1) 16 | return torch.where( 17 | logits < min_values, 18 | torch.full_like(logits, float('-inf')), logits) 19 | 20 | 21 | def _top_p_logits( 22 | logits: torch.Tensor, 23 | p: float 24 | ) -> torch.Tensor: 25 | r"""Adapted from 26 | https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317#file-top-k-top-p-py-L16-L27""" 27 | sorted_logits, sorted_indices = torch.sort(logits, descending=True) 28 | cumulative_probs = torch.cumsum( 29 | nn.functional.softmax(sorted_logits, dim=-1), dim=-1) 30 | 31 | # Remove tokens with cumulative probability above the threshold 32 | sorted_indices_to_remove = cumulative_probs > p 33 | # Shift the indices to the right to keep also the first token above the 34 | # threshold 35 | sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone() 36 | sorted_indices_to_remove[:, 0] = 0 37 | 38 | for idx in range(logits.size(0)): 39 | batch_indices = sorted_indices[idx, sorted_indices_to_remove[idx]] 40 | logits[idx, batch_indices] = float("-inf") 41 | return logits 42 | -------------------------------------------------------------------------------- /PromptSeed/rlprompt/models/model_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def _top_k_logits( 4 | logits: torch.Tensor, 5 | k: int 6 | ) -> torch.Tensor: 7 | r"""Adapted from 8 | https://github.com/openai/gpt-2/blob/master/src/sample.py#L63-L77 9 | """ 10 | if k == 0: 11 | # no truncation 12 | return logits 13 | 14 | values, _ = torch.topk(logits, k=k) 15 | min_values: torch.Tensor = values[:, -1].unsqueeze(-1) 16 | return torch.where( 17 | logits < min_values, 18 | torch.full_like(logits, float('-inf')), logits) 19 | 20 | 21 | def _top_p_logits( 22 | logits: torch.Tensor, 23 | p: float 24 | ) -> torch.Tensor: 25 | r"""Adapted from 26 | https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317#file-top-k-top-p-py-L16-L27""" 27 | sorted_logits, sorted_indices = torch.sort(logits, descending=True) 28 | cumulative_probs = torch.cumsum( 29 | nn.functional.softmax(sorted_logits, dim=-1), dim=-1) 30 | 31 | # Remove tokens with cumulative probability above the threshold 32 | sorted_indices_to_remove = cumulative_probs > p 33 | # Shift the indices to the right to keep also the first token above the 34 | # threshold 35 | sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone() 36 | sorted_indices_to_remove[:, 0] = 0 37 | 38 | for idx in range(logits.size(0)): 39 | batch_indices = sorted_indices[idx, sorted_indices_to_remove[idx]] 40 | logits[idx, batch_indices] = float("-inf") 41 | return logits 42 | -------------------------------------------------------------------------------- /ProgressiveTuning/rlprompt/models/model_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def _top_k_logits( 4 | logits: torch.Tensor, 5 | k: int 6 | ) -> torch.Tensor: 7 | r"""Adapted from 8 | https://github.com/openai/gpt-2/blob/master/src/sample.py#L63-L77 9 | """ 10 | if k == 0: 11 | # no truncation 12 | return logits 13 | 14 | values, _ = torch.topk(logits, k=k) 15 | min_values: torch.Tensor = values[:, -1].unsqueeze(-1) 16 | return torch.where( 17 | logits < min_values, 18 | torch.full_like(logits, float('-inf')), logits) 19 | 20 | 21 | def _top_p_logits( 22 | logits: torch.Tensor, 23 | p: float 24 | ) -> torch.Tensor: 25 | r"""Adapted from 26 | https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317#file-top-k-top-p-py-L16-L27""" 27 | sorted_logits, sorted_indices = torch.sort(logits, descending=True) 28 | cumulative_probs = torch.cumsum( 29 | nn.functional.softmax(sorted_logits, dim=-1), dim=-1) 30 | 31 | # Remove tokens with cumulative probability above the threshold 32 | sorted_indices_to_remove = cumulative_probs > p 33 | # Shift the indices to the right to keep also the first token above the 34 | # threshold 35 | sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone() 36 | sorted_indices_to_remove[:, 0] = 0 37 | 38 | for idx in range(logits.size(0)): 39 | batch_indices = sorted_indices[idx, sorted_indices_to_remove[idx]] 40 | logits[idx, batch_indices] = float("-inf") 41 | return logits 42 | -------------------------------------------------------------------------------- /PromptSeed/README.md: -------------------------------------------------------------------------------- 1 | # PromptSeed Tuning 2 | 3 | ## Setup 4 | Install our core modules with 5 | ```bash 6 | pip install -e . 7 | ``` 8 | 9 | ## train 10 | The script below runs a 16-shot classification experiment, with options for `task_lm` and `dataset`. For each dataset, 11 | we provide 5 different 16-shot training sets, toggled by dataset_seed. 12 | 13 | ```bash 14 | cd few-shot-classification 15 | python run_fsc.py \ 16 | dataset=[sst-2, yelp-2, mr, cr, agnews] \ 17 | dataset_seed=[0, 1, 2, 3, 4] \ 18 | prompt_length=[any integer (optional, default:5)] \ 19 | task_lm=[distilroberta-base, roberta-base, roberta-large, \ 20 | distilgpt2, gpt2, gpt2-medium, gpt2-large, gpt2-xl] \ 21 | random_seed=[any integer (optional)] 22 | ``` 23 | You can find and change additional hyperparameters in `fsc_config.yaml` and the default configs imported by `run_fsc.py`. 24 | 25 | ## validate 26 | 27 | After getting a prompt, you can use this script to evaluate the acc of the PromptSeed on test set. 28 | 29 | ```bash 30 | cd evaluation/ 31 | python run_eval.py \ 32 | dataset=[sst-2, yelp-2, mr, cr, agnews] \ 33 | task_lm=[distilroberta-base, roberta-base, roberta-large, \ 34 | distilgpt2, gpt2, gpt2-medium, gpt2-large, gpt2-xl] \ 35 | prompt=[any prompt in string form, e.g. "Rate Absolutely", \ 36 | and for a special case of leading whitespace prompt, \ 37 | we have to use "prompt=\" Rate Absolutely\"" instead] 38 | ``` 39 | 40 | You can find and change additional hyperparameters in `eval_config.yaml` and the default configs imported by `run_eval.py`. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TrojLLM [[Paper](https://arxiv.org/pdf/2306.06815.pdf)] 2 | 3 | This repository contains code for our NeurIPS 2023 paper "[TrojLLM: A Black-box Trojan Prompt Attack on Large Language Models](https://arxiv.org/pdf/2306.06815.pdf)". 4 | In this paper, we propose TrojLLM, an automatic and black-box framework to effectively generate universal and stealthy 5 | triggers and inserts trojans into the hard prompts of LLM-based APIs. 6 | 7 | ## Overview 8 | The workflow of TrojLLM. 9 | ![detector](https://github.com/UCF-ML-Research/TrojLLM/blob/main/figures/overview.png) 10 | 11 | 12 | 13 | ## Environment Setup 14 | Our codebase requires the following Python and PyTorch versions:
15 | Python --> 3.11.3
16 | PyTorch --> 2.0.1
17 | 18 | ## Usage 19 | We have split the code into three parts: 20 | 21 | 1. PromptSeed/ : Prompt Seed Tuning 22 | 2. Trigger/ : Universal Trigger Optimization 23 | 3. ProgressiveTuning/ : Progressive Prompt Poisoning 24 | 25 | These three parts correspond to the three methods we proposed in our paper. Please refer to the corresponding folder for more details. 26 | 27 | ## Citation 28 | If you find TrojLLM useful or relevant to your project and research, please kindly cite our paper: 29 | 30 | ```bibtex 31 | @article{xue2024trojllm, 32 | title={Trojllm: A black-box trojan prompt attack on large language models}, 33 | author={Xue, Jiaqi and Zheng, Mengxin and Hua, Ting and Shen, Yilin and Liu, Yepeng and B{\"o}l{\"o}ni, Ladislau and Lou, Qian}, 34 | journal={Advances in Neural Information Processing Systems}, 35 | volume={36}, 36 | year={2024} 37 | } 38 | ``` 39 | -------------------------------------------------------------------------------- /Trigger/rlprompt/modules/module_helpers.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Optional 3 | 4 | from rlprompt.modules import SQLModule 5 | from rlprompt.models import BaseModel 6 | from rlprompt.rewards import BaseReward 7 | 8 | def make_sql_module(model: BaseModel, 9 | reward: BaseReward, 10 | config: "DictConfig", 11 | target_model: Optional[BaseModel] = None) -> SQLModule: 12 | return SQLModule(model, target_model, reward, 13 | config.sql_loss_impl, config.training_mode, 14 | config.mix_strategy, config.target_update_method, 15 | config.target_update_steps, config.target_learning_rate, 16 | config.reward_shaping, config.reward_shaping_old_min, 17 | config.reward_shaping_old_max, 18 | config.reward_shaping_new_min, 19 | config.reward_shaping_new_max, 20 | config.top_k, config.top_p, config.num_beams) 21 | 22 | @dataclass 23 | class SQLModuleConfig: 24 | sql_loss_impl: str = "v2_v2r_v3_v3r" 25 | training_mode: str = "sql-onpolicy" 26 | mix_strategy: Optional[str] = None 27 | # Target model setting 28 | target_update_method: str = "polyak" 29 | target_update_steps: Optional[int] = None 30 | target_learning_rate: float = 0.001 31 | # Reward shaping linearly transforms reward range of [old_min, old_max] 32 | # to [new_min, new_max] 33 | reward_shaping: bool = True 34 | reward_shaping_old_min: float = 0 35 | reward_shaping_old_max: float = 100 36 | reward_shaping_new_min: float = -10 37 | reward_shaping_new_max: float = 10 38 | # Prompt generation setting 39 | top_k: Optional[int] = None 40 | top_p: float = 1.0 41 | num_beams: int = 1 42 | -------------------------------------------------------------------------------- /PromptSeed/rlprompt/modules/module_helpers.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Optional 3 | 4 | from rlprompt.modules import SQLModule 5 | from rlprompt.models import BaseModel 6 | from rlprompt.rewards import BaseReward 7 | 8 | def make_sql_module(model: BaseModel, 9 | reward: BaseReward, 10 | config: "DictConfig", 11 | target_model: Optional[BaseModel] = None) -> SQLModule: 12 | return SQLModule(model, target_model, reward, 13 | config.sql_loss_impl, config.training_mode, 14 | config.mix_strategy, config.target_update_method, 15 | config.target_update_steps, config.target_learning_rate, 16 | config.reward_shaping, config.reward_shaping_old_min, 17 | config.reward_shaping_old_max, 18 | config.reward_shaping_new_min, 19 | config.reward_shaping_new_max, 20 | config.top_k, config.top_p, config.num_beams) 21 | 22 | @dataclass 23 | class SQLModuleConfig: 24 | sql_loss_impl: str = "v2_v2r_v3_v3r" 25 | training_mode: str = "sql-onpolicy" 26 | mix_strategy: Optional[str] = None 27 | # Target model setting 28 | target_update_method: str = "polyak" 29 | target_update_steps: Optional[int] = None 30 | target_learning_rate: float = 0.001 31 | # Reward shaping linearly transforms reward range of [old_min, old_max] 32 | # to [new_min, new_max] 33 | reward_shaping: bool = True 34 | reward_shaping_old_min: float = 0 35 | reward_shaping_old_max: float = 100 36 | reward_shaping_new_min: float = -10 37 | reward_shaping_new_max: float = 10 38 | # Prompt generation setting 39 | top_k: Optional[int] = None 40 | top_p: float = 1.0 41 | num_beams: int = 1 42 | -------------------------------------------------------------------------------- /ProgressiveTuning/rlprompt/modules/module_helpers.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Optional 3 | 4 | from rlprompt.modules import SQLModule 5 | from rlprompt.models import BaseModel 6 | from rlprompt.rewards import BaseReward 7 | 8 | def make_sql_module(model: BaseModel, 9 | reward: BaseReward, 10 | config: "DictConfig", 11 | target_model: Optional[BaseModel] = None) -> SQLModule: 12 | return SQLModule(model, target_model, reward, 13 | config.sql_loss_impl, config.training_mode, 14 | config.mix_strategy, config.target_update_method, 15 | config.target_update_steps, config.target_learning_rate, 16 | config.reward_shaping, config.reward_shaping_old_min, 17 | config.reward_shaping_old_max, 18 | config.reward_shaping_new_min, 19 | config.reward_shaping_new_max, 20 | config.top_k, config.top_p, config.num_beams) 21 | 22 | @dataclass 23 | class SQLModuleConfig: 24 | sql_loss_impl: str = "v2_v2r_v3_v3r" 25 | training_mode: str = "sql-onpolicy" 26 | mix_strategy: Optional[str] = None 27 | # Target model setting 28 | target_update_method: str = "polyak" 29 | target_update_steps: Optional[int] = None 30 | target_learning_rate: float = 0.001 31 | # Reward shaping linearly transforms reward range of [old_min, old_max] 32 | # to [new_min, new_max] 33 | reward_shaping: bool = True 34 | reward_shaping_old_min: float = 0 35 | reward_shaping_old_max: float = 100 36 | reward_shaping_new_min: float = -10 37 | reward_shaping_new_max: float = 10 38 | # Prompt generation setting 39 | top_k: Optional[int] = None 40 | top_p: float = 1.0 41 | num_beams: int = 1 42 | -------------------------------------------------------------------------------- /Trigger/few-shot-classification/README.md: -------------------------------------------------------------------------------- 1 | # Prompted Few-Shot Classification Example 2 | 3 | The script below runs a 16-shot classification experiment, with options for `task_lm` and `dataset`. For each dataset, we provide 5 different 16-shot training sets, toggled by `dataset_seed`. 4 | ``` 5 | python run_fsc.py \ 6 | dataset=[sst-2, yelp-2, mr, cr, agnews, sst-5, yelp-5] \ 7 | dataset_seed=[0, 1, 2, 3, 4] \ 8 | prompt_length=[any integer (optional, default:5)] \ 9 | task_lm=[distilroberta-base, roberta-base, roberta-large, \ 10 | distilgpt2, gpt2, gpt2-medium, gpt2-large, gpt2-xl] \ 11 | random_seed=[any integer (optional)] 12 | ``` 13 | You can find additional hyperparameters in `fsc_config.yaml` and the default configs imported by `run_fsc.py` 14 | 15 | ## Evaluation 16 | 17 | After you train a prompt, you can evaluate it on a given dataset with the following commands 18 | ``` 19 | cd evaluation 20 | python run_eval.py \ 21 | dataset=[sst-2, yelp-2, mr, cr, agnews, sst-5, yelp-5] \ 22 | task_lm=[distilroberta-base, roberta-base, roberta-large, \ 23 | distilgpt2, gpt2, gpt2-medium, gpt2-large, gpt2-xl] \ 24 | prompt=[any prompt in string form, e.g. "Absolutely", \ 25 | and for a special case of leading whitespace prompt, \ 26 | we have to use "prompt=\" Absolutely\"" instead] 27 | ``` 28 | 29 | For a quick start, you may try the following examples: 30 | 31 | | Dataset | Model | Prompt | Accuracy (%) | 32 | | --- | --- | --- | --- | 33 | | sst-2 | roberta-large | AgentMediaGradeOfficials Grade | 94.2 | 34 | | sst-5 | roberta-large | iciticititableually immediately | 45.2 | 35 | | agnews | roberta-large | Alert Blog Dialogue Diary Accountability | 82.0 | 36 | | dbpedia | roberta-large | CommonExamplesSenate Similar comparable | 86.1 | 37 | | subj | roberta-large | BufferActionDialogDialog downright | 84.6 | 38 | | yahoo | roberta-large | AlertSource mentioning Besidesadays | 49.7 | 39 | | trec | roberta-large | DonaldTrump� | 66.8 | 40 | 41 | 42 | -------------------------------------------------------------------------------- /PromptSeed/few-shot-classification/README.md: -------------------------------------------------------------------------------- 1 | # Prompted Few-Shot Classification Example 2 | 3 | The script below runs a 16-shot classification experiment, with options for `task_lm` and `dataset`. For each dataset, we provide 5 different 16-shot training sets, toggled by `dataset_seed`. 4 | ``` 5 | python run_fsc.py \ 6 | dataset=[sst-2, yelp-2, mr, cr, agnews, sst-5, yelp-5] \ 7 | dataset_seed=[0, 1, 2, 3, 4] \ 8 | prompt_length=[any integer (optional, default:5)] \ 9 | task_lm=[distilroberta-base, roberta-base, roberta-large, \ 10 | distilgpt2, gpt2, gpt2-medium, gpt2-large, gpt2-xl] \ 11 | random_seed=[any integer (optional)] 12 | ``` 13 | You can find additional hyperparameters in `fsc_config.yaml` and the default configs imported by `run_fsc.py` 14 | 15 | ## Evaluation 16 | 17 | After you train a prompt, you can evaluate it on a given dataset with the following commands 18 | ``` 19 | cd evaluation 20 | python run_eval.py \ 21 | dataset=[sst-2, yelp-2, mr, cr, agnews, sst-5, yelp-5] \ 22 | task_lm=[distilroberta-base, roberta-base, roberta-large, \ 23 | distilgpt2, gpt2, gpt2-medium, gpt2-large, gpt2-xl] \ 24 | prompt=[any prompt in string form, e.g. "Absolutely", \ 25 | and for a special case of leading whitespace prompt, \ 26 | we have to use "prompt=\" Absolutely\"" instead] 27 | ``` 28 | 29 | For a quick start, you may try the following examples: 30 | 31 | | Dataset | Model | Prompt | Accuracy (%) | 32 | | --- | --- | --- | --- | 33 | | sst-2 | roberta-large | AgentMediaGradeOfficials Grade | 94.2 | 34 | | sst-5 | roberta-large | iciticititableually immediately | 45.2 | 35 | | agnews | roberta-large | Alert Blog Dialogue Diary Accountability | 82.0 | 36 | | dbpedia | roberta-large | CommonExamplesSenate Similar comparable | 86.1 | 37 | | subj | roberta-large | BufferActionDialogDialog downright | 84.6 | 38 | | yahoo | roberta-large | AlertSource mentioning Besidesadays | 49.7 | 39 | | trec | roberta-large | DonaldTrump� | 66.8 | 40 | 41 | 42 | -------------------------------------------------------------------------------- /ProgressiveTuning/few-shot-classification/README.md: -------------------------------------------------------------------------------- 1 | # Prompted Few-Shot Classification Example 2 | 3 | The script below runs a 16-shot classification experiment, with options for `task_lm` and `dataset`. For each dataset, we provide 5 different 16-shot training sets, toggled by `dataset_seed`. 4 | ``` 5 | python run_fsc.py \ 6 | dataset=[sst-2, yelp-2, mr, cr, agnews, sst-5, yelp-5] \ 7 | dataset_seed=[0, 1, 2, 3, 4] \ 8 | prompt_length=[any integer (optional, default:5)] \ 9 | task_lm=[distilroberta-base, roberta-base, roberta-large, \ 10 | distilgpt2, gpt2, gpt2-medium, gpt2-large, gpt2-xl] \ 11 | random_seed=[any integer (optional)] 12 | ``` 13 | You can find additional hyperparameters in `fsc_config.yaml` and the default configs imported by `run_fsc.py` 14 | 15 | ## Evaluation 16 | 17 | After you train a prompt, you can evaluate it on a given dataset with the following commands 18 | ``` 19 | cd evaluation 20 | python run_eval.py \ 21 | dataset=[sst-2, yelp-2, mr, cr, agnews, sst-5, yelp-5] \ 22 | task_lm=[distilroberta-base, roberta-base, roberta-large, \ 23 | distilgpt2, gpt2, gpt2-medium, gpt2-large, gpt2-xl] \ 24 | prompt=[any prompt in string form, e.g. "Absolutely", \ 25 | and for a special case of leading whitespace prompt, \ 26 | we have to use "prompt=\" Absolutely\"" instead] 27 | ``` 28 | 29 | For a quick start, you may try the following examples: 30 | 31 | | Dataset | Model | Prompt | Accuracy (%) | 32 | | --- | --- | --- | --- | 33 | | sst-2 | roberta-large | AgentMediaGradeOfficials Grade | 94.2 | 34 | | sst-5 | roberta-large | iciticititableually immediately | 45.2 | 35 | | agnews | roberta-large | Alert Blog Dialogue Diary Accountability | 82.0 | 36 | | dbpedia | roberta-large | CommonExamplesSenate Similar comparable | 86.1 | 37 | | subj | roberta-large | BufferActionDialogDialog downright | 84.6 | 38 | | yahoo | roberta-large | AlertSource mentioning Besidesadays | 49.7 | 39 | | trec | roberta-large | DonaldTrump� | 66.8 | 40 | 41 | 42 | -------------------------------------------------------------------------------- /PromptSeed/rlprompt/utils/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Miscellaneous Utility Functions 3 | """ 4 | import click 5 | import warnings 6 | from typing import Dict, Any, Optional, List 7 | import dataclasses 8 | from dataclasses import dataclass 9 | from hydra.core.config_store import ConfigStore 10 | from hydra.core.hydra_config import HydraConfig 11 | 12 | 13 | def get_hydra_output_dir(): 14 | return HydraConfig.get().run.dir 15 | 16 | 17 | def compose_hydra_config_store( 18 | name: str, 19 | configs: List[dataclass] 20 | ) -> ConfigStore: 21 | config_fields = [] 22 | for config_cls in configs: 23 | for config_field in dataclasses.fields(config_cls): 24 | config_fields.append((config_field.name, config_field.type, 25 | config_field)) 26 | Config = dataclasses.make_dataclass(cls_name="Config", fields=config_fields) 27 | cs = ConfigStore.instance() 28 | cs.store(name=name, node=Config) 29 | return cs 30 | 31 | 32 | def add_prefix_to_dict_keys_inplace( 33 | d: Dict[str, Any], 34 | prefix: str, 35 | keys_to_exclude: Optional[List[str]] = None, 36 | ) -> None: 37 | 38 | # https://stackoverflow.com/questions/4406501/change-the-name-of-a-key-in-dictionary 39 | keys = list(d.keys()) 40 | for key in keys: 41 | if keys_to_exclude is not None and key in keys_to_exclude: 42 | continue 43 | 44 | new_key = f"{prefix}{key}" 45 | d[new_key] = d.pop(key) 46 | 47 | def colorful_print(string: str, *args, **kwargs) -> None: 48 | print(click.style(string, *args, **kwargs)) 49 | 50 | def colorful_warning(string: str, *args, **kwargs) -> None: 51 | warnings.warn(click.style(string, *args, **kwargs)) 52 | 53 | def unionize_dicts(dicts: List[Dict]) -> Dict: 54 | union_dict: Dict = {} 55 | for d in dicts: 56 | for k, v in d.items(): 57 | if k in union_dict.keys(): 58 | raise KeyError 59 | union_dict[k] = v 60 | 61 | return union_dict 62 | -------------------------------------------------------------------------------- /Trigger/rlprompt/utils/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Miscellaneous Utility Functions 3 | """ 4 | import click 5 | import warnings 6 | from typing import Dict, Any, Optional, List 7 | import dataclasses 8 | from dataclasses import dataclass 9 | from hydra.core.config_store import ConfigStore 10 | from hydra.core.hydra_config import HydraConfig 11 | 12 | 13 | def get_hydra_output_dir(): 14 | return HydraConfig.get().run.dir 15 | 16 | 17 | def compose_hydra_config_store( 18 | name: str, 19 | configs: List[dataclass] 20 | ) -> ConfigStore: 21 | config_fields = [] 22 | for config_cls in configs: 23 | for config_field in dataclasses.fields(config_cls): 24 | config_fields.append((config_field.name, config_field.type, 25 | config_field)) 26 | Config = dataclasses.make_dataclass(cls_name="Config", fields=config_fields) 27 | cs = ConfigStore.instance() 28 | cs.store(name=name, node=Config) 29 | return cs 30 | 31 | 32 | def add_prefix_to_dict_keys_inplace( 33 | d: Dict[str, Any], 34 | prefix: str, 35 | keys_to_exclude: Optional[List[str]] = None, 36 | ) -> None: 37 | 38 | # https://stackoverflow.com/questions/4406501/change-the-name-of-a-key-in-dictionary 39 | keys = list(d.keys()) 40 | for key in keys: 41 | if keys_to_exclude is not None and key in keys_to_exclude: 42 | continue 43 | 44 | new_key = f"{prefix}{key}" 45 | d[new_key] = d.pop(key) 46 | 47 | def colorful_print(string: str, *args, **kwargs) -> None: 48 | print(click.style(string, *args, **kwargs)) 49 | 50 | def colorful_warning(string: str, *args, **kwargs) -> None: 51 | warnings.warn(click.style(string, *args, **kwargs)) 52 | 53 | def unionize_dicts(dicts: List[Dict]) -> Dict: 54 | union_dict: Dict = {} 55 | for d in dicts: 56 | for k, v in d.items(): 57 | if k in union_dict.keys(): 58 | raise KeyError 59 | union_dict[k] = v 60 | 61 | return union_dict 62 | -------------------------------------------------------------------------------- /ProgressiveTuning/rlprompt/utils/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Miscellaneous Utility Functions 3 | """ 4 | import click 5 | import warnings 6 | from typing import Dict, Any, Optional, List 7 | import dataclasses 8 | from dataclasses import dataclass 9 | from hydra.core.config_store import ConfigStore 10 | from hydra.core.hydra_config import HydraConfig 11 | 12 | 13 | def get_hydra_output_dir(): 14 | return HydraConfig.get().run.dir 15 | 16 | 17 | def compose_hydra_config_store( 18 | name: str, 19 | configs: List[dataclass] 20 | ) -> ConfigStore: 21 | config_fields = [] 22 | for config_cls in configs: 23 | for config_field in dataclasses.fields(config_cls): 24 | config_fields.append((config_field.name, config_field.type, 25 | config_field)) 26 | Config = dataclasses.make_dataclass(cls_name="Config", fields=config_fields) 27 | cs = ConfigStore.instance() 28 | cs.store(name=name, node=Config) 29 | return cs 30 | 31 | 32 | def add_prefix_to_dict_keys_inplace( 33 | d: Dict[str, Any], 34 | prefix: str, 35 | keys_to_exclude: Optional[List[str]] = None, 36 | ) -> None: 37 | 38 | # https://stackoverflow.com/questions/4406501/change-the-name-of-a-key-in-dictionary 39 | keys = list(d.keys()) 40 | for key in keys: 41 | if keys_to_exclude is not None and key in keys_to_exclude: 42 | continue 43 | 44 | new_key = f"{prefix}{key}" 45 | d[new_key] = d.pop(key) 46 | 47 | def colorful_print(string: str, *args, **kwargs) -> None: 48 | print(click.style(string, *args, **kwargs)) 49 | 50 | def colorful_warning(string: str, *args, **kwargs) -> None: 51 | warnings.warn(click.style(string, *args, **kwargs)) 52 | 53 | def unionize_dicts(dicts: List[Dict]) -> Dict: 54 | union_dict: Dict = {} 55 | for d in dicts: 56 | for k, v in d.items(): 57 | if k in union_dict.keys(): 58 | raise KeyError 59 | union_dict[k] = v 60 | 61 | return union_dict 62 | -------------------------------------------------------------------------------- /PromptSeed/rlprompt/trainers/trainer_helpers.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from torch.utils.data import Dataset 3 | from typing import Optional 4 | 5 | from rlprompt.modules import BaseModule 6 | from rlprompt.trainers import Trainer 7 | 8 | 9 | def make_trainer(module: BaseModule, 10 | train_dataset: Optional[Dataset], 11 | eval_dataset: Optional[Dataset], 12 | config: "DictConfig") -> Trainer: 13 | return Trainer(module, train_dataset, config.train_batch_size, 14 | config.train_shuffle, config.train_drop_last, 15 | config.num_train_epochs, config.max_train_steps, 16 | config.do_eval, eval_dataset, config.eval_batch_size, 17 | config.eval_steps, config.df_steps, config.do_save, config.save_dir, 18 | config.save_steps, config.learning_rate, 19 | config.gradient_clip, config.gradient_clip_norm, 20 | config.checkpoint_path, config.random_seed, 21 | config.report_to_wandb, config.project_name, 22 | config.run_name) 23 | 24 | 25 | @dataclass 26 | class TrainerConfig: 27 | # Train params 28 | train_batch_size: int = 16 29 | train_shuffle: bool = True 30 | train_drop_last: bool = True 31 | num_train_epochs: int = 1 32 | max_train_steps: int = -1 33 | # Eval params 34 | do_eval: bool = True 35 | eval_batch_size: int = 16 36 | eval_steps: int = -1 37 | df_steps: int = -1 38 | # Save params 39 | do_save: bool = True 40 | save_dir: str = './outputs' 41 | save_steps: int = -1 42 | # Optimizer params 43 | learning_rate: float = 1e-4 44 | gradient_clip: bool = True 45 | gradient_clip_norm: float = 5.0 46 | # Checkpoint params 47 | checkpoint_path: Optional[str] = None 48 | # Random seed 49 | random_seed: Optional[int] = None 50 | # Wandb reporting 51 | report_to_wandb: bool = True 52 | project_name: Optional[str] = 'rl-prompt' 53 | run_name: Optional[str] = None -------------------------------------------------------------------------------- /Trigger/rlprompt/trainers/trainer_helpers.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from torch.utils.data import Dataset 3 | from typing import Optional 4 | 5 | from rlprompt.modules import BaseModule 6 | from rlprompt.trainers import Trainer 7 | 8 | 9 | def make_trainer(module: BaseModule, 10 | train_dataset: Optional[Dataset], 11 | eval_dataset: Optional[Dataset], 12 | config: "DictConfig") -> Trainer: 13 | return Trainer(module, train_dataset, config.train_batch_size, 14 | config.train_shuffle, config.train_drop_last, 15 | config.num_train_epochs, config.max_train_steps, 16 | config.do_eval, eval_dataset, config.eval_batch_size, 17 | config.eval_steps, config.df_steps, config.do_save, config.save_dir, 18 | config.save_steps, config.learning_rate, 19 | config.gradient_clip, config.gradient_clip_norm, 20 | config.checkpoint_path, config.random_seed, 21 | config.report_to_wandb, config.project_name, 22 | config.run_name) 23 | 24 | 25 | @dataclass 26 | class TrainerConfig: 27 | # Train params 28 | train_batch_size: int = 16 29 | train_shuffle: bool = True 30 | train_drop_last: bool = True 31 | num_train_epochs: int = 1 32 | max_train_steps: int = -1 33 | # Eval params 34 | do_eval: bool = True 35 | eval_batch_size: int = 16 36 | eval_steps: int = -1 37 | df_steps: int = -1 38 | # Save params 39 | do_save: bool = True 40 | save_dir: str = './outputs' 41 | save_steps: int = -1 42 | # Optimizer params 43 | learning_rate: float = 1e-4 44 | gradient_clip: bool = True 45 | gradient_clip_norm: float = 5.0 46 | # Checkpoint params 47 | checkpoint_path: Optional[str] = None 48 | # Random seed 49 | random_seed: Optional[int] = None 50 | # Wandb reporting 51 | report_to_wandb: bool = True 52 | project_name: Optional[str] = 'rl-prompt' 53 | run_name: Optional[str] = None -------------------------------------------------------------------------------- /ProgressiveTuning/rlprompt/trainers/trainer_helpers.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from torch.utils.data import Dataset 3 | from typing import Optional 4 | 5 | from rlprompt.modules import BaseModule 6 | from rlprompt.trainers import Trainer 7 | 8 | 9 | def make_trainer(module: BaseModule, 10 | train_dataset: Optional[Dataset], 11 | eval_dataset: Optional[Dataset], 12 | config: "DictConfig") -> Trainer: 13 | return Trainer(module, train_dataset, config.train_batch_size, 14 | config.train_shuffle, config.train_drop_last, 15 | config.num_train_epochs, config.max_train_steps, 16 | config.do_eval, eval_dataset, config.eval_batch_size, 17 | config.eval_steps, config.df_steps, config.do_save, config.save_dir, 18 | config.save_steps, config.learning_rate, 19 | config.gradient_clip, config.gradient_clip_norm, 20 | config.checkpoint_path, config.random_seed, 21 | config.report_to_wandb, config.project_name, 22 | config.run_name) 23 | 24 | 25 | @dataclass 26 | class TrainerConfig: 27 | # Train params 28 | train_batch_size: int = 16 29 | train_shuffle: bool = True 30 | train_drop_last: bool = True 31 | num_train_epochs: int = 1 32 | max_train_steps: int = -1 33 | # Eval params 34 | do_eval: bool = True 35 | eval_batch_size: int = 16 36 | eval_steps: int = -1 37 | df_steps: int = -1 38 | # Save params 39 | do_save: bool = True 40 | save_dir: str = './outputs' 41 | save_steps: int = -1 42 | # Optimizer params 43 | learning_rate: float = 1e-4 44 | gradient_clip: bool = True 45 | gradient_clip_norm: float = 5.0 46 | # Checkpoint params 47 | checkpoint_path: Optional[str] = None 48 | # Random seed 49 | random_seed: Optional[int] = None 50 | # Wandb reporting 51 | report_to_wandb: bool = True 52 | project_name: Optional[str] = 'rl-prompt' 53 | run_name: Optional[str] = None -------------------------------------------------------------------------------- /PromptSeed/few-shot-classification/evaluation/run_eval.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | import sys 3 | sys.path.append("..") 4 | from omegaconf import OmegaConf 5 | from torch.utils.data import DataLoader 6 | 7 | from rlprompt.utils.utils import colorful_print 8 | from fsc_helpers import (make_few_shot_classification_dataset, 9 | get_dataset_verbalizers) 10 | from fsc_evaluator import PromptedClassificationEvaluator 11 | 12 | 13 | @hydra.main(version_base=None, config_path="./", config_name="eval_config") 14 | def main(config: "DictConfig"): 15 | colorful_print(OmegaConf.to_yaml(config), fg='red') 16 | 17 | (train_dataset, val_dataset, test_dataset, 18 | num_classes, verbalizers, template) = \ 19 | make_few_shot_classification_dataset(config) 20 | print('Test Size', len(test_dataset)) 21 | print('Examples:', test_dataset[:5]) 22 | test_loader = DataLoader(test_dataset, 23 | shuffle=False, 24 | batch_size=32, 25 | drop_last=False) 26 | 27 | is_mask_lm = True if 'bert' in config.task_lm else False 28 | verbalizers = get_dataset_verbalizers(config.dataset, config.task_lm) 29 | num_classes = len(verbalizers) 30 | if config.dataset == 'agnews' and is_mask_lm: 31 | template = " {prompt} {sentence_1}" 32 | elif config.dataset == 'dbpedia' and is_mask_lm: 33 | template = "{prompt} : {sentence_1}" 34 | else: 35 | template = None 36 | # Below are some example prompts: 37 | # Alert Blog Dialogue Diary Accountability (82% for agnews) 38 | # Absolutely VERY absolute VERY absolute (92% for sst-2) 39 | tester = PromptedClassificationEvaluator( 40 | task_lm=config.task_lm, 41 | is_mask_lm=config.is_mask_lm, 42 | num_classes=num_classes, 43 | verbalizers=verbalizers, 44 | template=template, 45 | prompt=config.prompt 46 | ) 47 | 48 | acc = tester.forward(test_loader) 49 | colorful_print(f"prompt: {config.prompt}, accuracy: {acc}", fg='red') 50 | 51 | 52 | if __name__ == "__main__": 53 | main() 54 | -------------------------------------------------------------------------------- /ProgressiveTuning/few-shot-classification/evaluation/run_eval.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | import sys 3 | sys.path.append("..") 4 | from omegaconf import OmegaConf 5 | from torch.utils.data import DataLoader 6 | 7 | from rlprompt.utils.utils import colorful_print 8 | from fsc_helpers import (make_few_shot_classification_dataset, 9 | get_dataset_verbalizers) 10 | from fsc_evaluator import PromptedClassificationEvaluator 11 | 12 | 13 | @hydra.main(version_base=None, config_path="./", config_name="eval_config") 14 | def main(config: "DictConfig"): 15 | colorful_print(OmegaConf.to_yaml(config), fg='red') 16 | 17 | (train_dataset, val_dataset, test_dataset, 18 | num_classes, verbalizers, template, template_trigger) = make_few_shot_classification_dataset(config) 19 | print('Test Size', len(test_dataset)) 20 | print('Examples:', test_dataset[:5]) 21 | test_loader = DataLoader(test_dataset, 22 | shuffle=False, 23 | batch_size=32, 24 | drop_last=False) 25 | 26 | is_mask_lm = True if 'bert' in config.task_lm else False 27 | verbalizers = get_dataset_verbalizers(config.dataset) 28 | num_classes = len(verbalizers) 29 | if config.dataset == 'agnews' and is_mask_lm: 30 | template = " {prompt} {sentence}" 31 | template_trigger = " {prompt} {sentence}{trigger}" 32 | elif config.dataset == 'dbpedia' and is_mask_lm: 33 | template = "{prompt} : {sentence}" 34 | template_trigger = None 35 | else: 36 | template = None 37 | template_trigger = None 38 | # Below are some example prompts: 39 | # Alert Blog Dialogue Diary Accountability (82% for agnews) 40 | # Absolutely VERY absolute VERY absolute (92% for sst-2) 41 | tester = PromptedClassificationEvaluator( 42 | task_lm=config.task_lm, 43 | is_mask_lm=config.is_mask_lm, 44 | num_classes=num_classes, 45 | verbalizers=verbalizers, 46 | template=template, 47 | template_trigger=template_trigger, 48 | prompt=config.prompt, 49 | trigger=config.trigger, 50 | target=config.target 51 | ) 52 | 53 | acc, asr = tester.forward(test_loader) 54 | colorful_print(f"prompt: {config.prompt}, accuracy: {acc}, asr: {asr}", fg='red') 55 | 56 | 57 | if __name__ == "__main__": 58 | main() 59 | -------------------------------------------------------------------------------- /PromptSeed/few-shot-classification/evaluation/eval_batch.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | import hydra 5 | import pandas as pd 6 | 7 | sys.path.append("..") 8 | from omegaconf import OmegaConf 9 | from torch.utils.data import DataLoader 10 | 11 | from rlprompt.utils.utils import colorful_print 12 | from fsc_helpers import (make_few_shot_classification_dataset, 13 | get_dataset_verbalizers) 14 | from fsc_evaluator import PromptedClassificationEvaluator 15 | 16 | 17 | @hydra.main(version_base=None, config_path="./", config_name="eval_config") 18 | def main(config: "DictConfig"): 19 | colorful_print(OmegaConf.to_yaml(config), fg='red') 20 | (train_dataset, val_dataset, test_dataset, 21 | num_classes, verbalizers, template) = \ 22 | make_few_shot_classification_dataset(config) 23 | print('Test Size', len(test_dataset)) 24 | print('Examples:', test_dataset[:5]) 25 | test_loader = DataLoader(test_dataset, 26 | shuffle=False, 27 | batch_size=32, 28 | drop_last=False) 29 | 30 | is_mask_lm = True if 'bert' in config.task_lm else False 31 | verbalizers = get_dataset_verbalizers(config.dataset, config.task_lm) 32 | num_classes = len(verbalizers) 33 | if config.dataset == 'agnews' and is_mask_lm: 34 | template = "[MASK] {prompt} {sentence_1}" 35 | elif config.dataset == 'dbpedia' and is_mask_lm: 36 | template = "{prompt} : {sentence_1}" 37 | else: 38 | template = None 39 | 40 | df = pd.read_csv(config.path) 41 | df = df[df['acc'] >= 0.8] 42 | df.sort_values(by=['acc'], ascending=False, inplace=True) 43 | for index, row in df.iterrows(): 44 | prompt = row['prompt'] 45 | tester = PromptedClassificationEvaluator( 46 | task_lm=config.task_lm, 47 | is_mask_lm=config.is_mask_lm, 48 | num_classes=num_classes, 49 | verbalizers=verbalizers, 50 | template=template, 51 | prompt=prompt 52 | ) 53 | acc = tester.forward(test_loader) 54 | print(f'prompt={prompt}, acc={round(acc.item(), 4)}') 55 | df.loc[index, 'acc_test'] = round(acc.item(), 4) 56 | os.makedirs(os.path.dirname(config.path_out), exist_ok=True) 57 | df.to_csv(config.path_out, index=False) 58 | 59 | 60 | if __name__ == "__main__": 61 | main() 62 | -------------------------------------------------------------------------------- /Trigger/few-shot-classification/evaluation/run_eval.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | import sys 3 | sys.path.append("..") 4 | from omegaconf import OmegaConf 5 | from torch.utils.data import DataLoader 6 | 7 | from rlprompt.utils.utils import colorful_print 8 | from fsc_helpers import (make_few_shot_classification_dataset, 9 | get_dataset_verbalizers) 10 | from fsc_evaluator import PromptedClassificationEvaluator 11 | 12 | 13 | @hydra.main(version_base=None, config_path="./", config_name="eval_config") 14 | def main(config: "DictConfig"): 15 | colorful_print(OmegaConf.to_yaml(config), fg='red') 16 | 17 | (train_dataset, val_dataset, test_dataset, 18 | num_classes, verbalizers, template, template_trigger) = \ 19 | make_few_shot_classification_dataset(config) 20 | print('Test Size', len(test_dataset)) 21 | print('Examples:', test_dataset[:5]) 22 | test_loader = DataLoader(test_dataset, 23 | shuffle=False, 24 | batch_size=32, 25 | drop_last=False) 26 | 27 | is_mask_lm = True if 'bert' in config.task_lm else False 28 | verbalizers = get_dataset_verbalizers(config.dataset, config.task_lm) 29 | num_classes = len(verbalizers) 30 | if config.dataset == 'agnews' and is_mask_lm: 31 | template = " {prompt} {sentence_1}" 32 | template_trigger = " {prompt} {sentence_1}{trigger}" 33 | elif config.dataset == 'dbpedia' and is_mask_lm: 34 | template = "{prompt} : {sentence_1}" 35 | template_trigger = None 36 | else: 37 | template = None 38 | template_trigger = None 39 | # Below are some example prompts: 40 | # Alert Blog Dialogue Diary Accountability (82% for agnews) 41 | # Absolutely VERY absolute VERY absolute (92% for sst-2) 42 | tester = PromptedClassificationEvaluator( 43 | task_lm=config.task_lm, 44 | is_mask_lm=config.is_mask_lm, 45 | num_classes=num_classes, 46 | verbalizers=verbalizers, 47 | template=template, 48 | template_trigger=template_trigger, 49 | prompt=config.prompt, 50 | trigger=config.trigger, 51 | target=config.target 52 | ) 53 | 54 | acc, asr = tester.forward(test_loader) 55 | colorful_print(f"prompt: {config.prompt}, accuracy: {acc}, asr: {asr}", fg='red') 56 | 57 | 58 | if __name__ == "__main__": 59 | main() 60 | -------------------------------------------------------------------------------- /PromptSeed/rlprompt/models/single_prompt_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Optional, List, Union, Any, Dict 3 | from .base_model import BaseModel 4 | 5 | 6 | class SinglePromptModel(BaseModel): 7 | def __init__( 8 | self, 9 | model: BaseModel, 10 | prompt_length: int, 11 | prompt_train_batch_size: int, 12 | prompt_infer_batch_size: int, 13 | source_str: str, 14 | ): 15 | super().__init__() 16 | self._model = model 17 | self.prompt_length = prompt_length 18 | self.prompt_train_batch_size = prompt_train_batch_size 19 | self.prompt_infer_batch_size = prompt_infer_batch_size 20 | self.source_str = source_str 21 | 22 | def _get_prompt_source(self, batch_size: int) -> List[str]: 23 | return [self.source_str for _ in range(batch_size)] 24 | 25 | def generate( 26 | self, 27 | source_texts: List[str], 28 | do_sample: bool, 29 | top_k: Optional[int], 30 | top_p: Optional[float], 31 | num_beams: Optional[int], 32 | max_new_tokens: Optional[int] = None, 33 | infer: bool = False, 34 | **kwargs 35 | ) -> Dict[str, Any]: 36 | if infer: 37 | batch_size = min(self.prompt_infer_batch_size, len(source_texts)) 38 | else: 39 | batch_size = self.prompt_train_batch_size 40 | prompt_source = self._get_prompt_source(batch_size=batch_size) 41 | 42 | if max_new_tokens is None: 43 | max_new_tokens = self.prompt_length 44 | return self._model.generate(source_texts=prompt_source, 45 | do_sample=do_sample, 46 | top_k=top_k, 47 | top_p=top_p, 48 | num_beams=num_beams, 49 | max_new_tokens=max_new_tokens, 50 | **kwargs) 51 | 52 | def teacher_forcing( 53 | self, 54 | source_texts: List[str], 55 | sample_ids: torch.LongTensor, 56 | **kwargs 57 | ) -> Dict[str, Any]: 58 | prompt_source = self._get_prompt_source(self.prompt_train_batch_size) 59 | return self._model.teacher_forcing(source_texts=prompt_source, 60 | sample_ids=sample_ids, 61 | **kwargs) 62 | -------------------------------------------------------------------------------- /Trigger/rlprompt/models/single_prompt_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Optional, List, Union, Any, Dict 3 | from .base_model import BaseModel 4 | 5 | 6 | class SinglePromptModel(BaseModel): 7 | def __init__( 8 | self, 9 | model: BaseModel, 10 | prompt_length: int, 11 | prompt_train_batch_size: int, 12 | prompt_infer_batch_size: int, 13 | source_str: str, 14 | ): 15 | super().__init__() 16 | self._model = model 17 | self.prompt_length = prompt_length 18 | self.prompt_train_batch_size = prompt_train_batch_size 19 | self.prompt_infer_batch_size = prompt_infer_batch_size 20 | self.source_str = source_str 21 | 22 | def _get_prompt_source(self, batch_size: int) -> List[str]: 23 | return [self.source_str for _ in range(batch_size)] 24 | 25 | def generate( 26 | self, 27 | source_texts: List[str], 28 | do_sample: bool, 29 | top_k: Optional[int], 30 | top_p: Optional[float], 31 | num_beams: Optional[int], 32 | max_new_tokens: Optional[int] = None, 33 | infer: bool = False, 34 | **kwargs 35 | ) -> Dict[str, Any]: 36 | if infer: 37 | batch_size = min(self.prompt_infer_batch_size, len(source_texts)) 38 | else: 39 | batch_size = self.prompt_train_batch_size 40 | prompt_source = self._get_prompt_source(batch_size=batch_size) 41 | 42 | if max_new_tokens is None: 43 | max_new_tokens = self.prompt_length 44 | return self._model.generate(source_texts=prompt_source, 45 | do_sample=do_sample, 46 | top_k=top_k, 47 | top_p=top_p, 48 | num_beams=num_beams, 49 | max_new_tokens=max_new_tokens, 50 | **kwargs) 51 | 52 | def teacher_forcing( 53 | self, 54 | source_texts: List[str], 55 | sample_ids: torch.LongTensor, 56 | **kwargs 57 | ) -> Dict[str, Any]: 58 | prompt_source = self._get_prompt_source(self.prompt_train_batch_size) 59 | return self._model.teacher_forcing(source_texts=prompt_source, 60 | sample_ids=sample_ids, 61 | **kwargs) 62 | -------------------------------------------------------------------------------- /ProgressiveTuning/few-shot-classification/run_fsc.py: -------------------------------------------------------------------------------- 1 | import os 2 | import hydra 3 | from omegaconf import DictConfig, OmegaConf 4 | 5 | from rlprompt.models import (LMAdaptorModelConfig, SinglePromptModelConfig, 6 | make_lm_adaptor_model, make_single_prompt_model) 7 | from rlprompt.modules import SQLModuleConfig, make_sql_module 8 | from rlprompt.trainers import TrainerConfig, make_trainer 9 | from rlprompt.utils.utils import (colorful_print, compose_hydra_config_store, 10 | get_hydra_output_dir) 11 | 12 | from fsc_helpers import (PromptedClassificationRewardConfig, 13 | FewShotClassificationDatasetConfig, 14 | make_prompted_classification_reward, 15 | make_few_shot_classification_dataset) 16 | 17 | 18 | # Compose default config 19 | config_list = [PromptedClassificationRewardConfig, 20 | FewShotClassificationDatasetConfig, LMAdaptorModelConfig, 21 | SinglePromptModelConfig, SQLModuleConfig, TrainerConfig] 22 | cs = compose_hydra_config_store('base_fsc', config_list) 23 | 24 | 25 | @hydra.main(version_base=None, config_path="./", config_name="fsc_config") 26 | def main(config: "DictConfig"): 27 | colorful_print(OmegaConf.to_yaml(config), fg='red') 28 | output_dir = get_hydra_output_dir() 29 | 30 | (train_dataset, val_dataset, test_dataset, 31 | num_classes, verbalizers, template, template_trigger) = make_few_shot_classification_dataset(config) 32 | print('Train Size:', len(train_dataset)) 33 | print('Examples:', train_dataset[:5]) 34 | print('Val Size', len(val_dataset)) 35 | print('Examples:', val_dataset[:5]) 36 | 37 | policy_model = make_lm_adaptor_model(config) 38 | prompt_model = make_single_prompt_model(policy_model, config) 39 | reward = make_prompted_classification_reward(num_classes, verbalizers, template, template_trigger, config) 40 | algo_module = make_sql_module(prompt_model, reward, config) 41 | 42 | # Hack for few-shot classification - Each batch contains all examples 43 | config.train_batch_size = len(train_dataset) 44 | config.eval_batch_size = len(val_dataset) 45 | config.save_dir = os.path.join(output_dir, config.save_dir) 46 | trainer = make_trainer(algo_module, train_dataset, val_dataset, config) 47 | trainer.train(config=config) 48 | 49 | 50 | if __name__ == "__main__": 51 | main() 52 | -------------------------------------------------------------------------------- /ProgressiveTuning/rlprompt/models/single_prompt_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Optional, List, Union, Any, Dict 3 | from .base_model import BaseModel 4 | 5 | 6 | class SinglePromptModel(BaseModel): 7 | def __init__( 8 | self, 9 | model: BaseModel, 10 | prompt_length: int, 11 | prompt_train_batch_size: int, 12 | prompt_infer_batch_size: int, 13 | source_str: str, 14 | ): 15 | super().__init__() 16 | self._model = model 17 | self.prompt_length = prompt_length 18 | self.prompt_train_batch_size = prompt_train_batch_size 19 | self.prompt_infer_batch_size = prompt_infer_batch_size 20 | self.source_str = source_str 21 | 22 | def _get_prompt_source(self, batch_size: int) -> List[str]: 23 | return [self.source_str for _ in range(batch_size)] 24 | 25 | def generate( 26 | self, 27 | source_texts: List[str], 28 | do_sample: bool, 29 | top_k: Optional[int], 30 | top_p: Optional[float], 31 | num_beams: Optional[int], 32 | max_new_tokens: Optional[int] = None, 33 | infer: bool = False, 34 | **kwargs 35 | ) -> Dict[str, Any]: 36 | if infer: 37 | batch_size = min(self.prompt_infer_batch_size, len(source_texts)) 38 | else: 39 | batch_size = self.prompt_train_batch_size 40 | prompt_source = self._get_prompt_source(batch_size=batch_size) 41 | 42 | if max_new_tokens is None: 43 | max_new_tokens = self.prompt_length 44 | return self._model.generate(source_texts=prompt_source, 45 | do_sample=do_sample, 46 | top_k=top_k, 47 | top_p=top_p, 48 | num_beams=num_beams, 49 | max_new_tokens=max_new_tokens, 50 | **kwargs) 51 | 52 | def teacher_forcing( 53 | self, 54 | source_texts: List[str], 55 | sample_ids: torch.LongTensor, 56 | **kwargs 57 | ) -> Dict[str, Any]: 58 | prompt_source = self._get_prompt_source(self.prompt_train_batch_size) 59 | return self._model.teacher_forcing(source_texts=prompt_source, 60 | sample_ids=sample_ids, 61 | **kwargs) 62 | -------------------------------------------------------------------------------- /PromptSeed/few-shot-classification/run_fsc.py: -------------------------------------------------------------------------------- 1 | import os 2 | import hydra 3 | from omegaconf import DictConfig, OmegaConf 4 | 5 | from rlprompt.models import (LMAdaptorModelConfig, SinglePromptModelConfig, 6 | make_lm_adaptor_model, make_single_prompt_model) 7 | from rlprompt.modules import SQLModuleConfig, make_sql_module 8 | from rlprompt.trainers import TrainerConfig, make_trainer 9 | from rlprompt.utils.utils import (colorful_print, compose_hydra_config_store, 10 | get_hydra_output_dir) 11 | 12 | from fsc_helpers import (PromptedClassificationRewardConfig, 13 | FewShotClassificationDatasetConfig, 14 | make_prompted_classification_reward, 15 | make_few_shot_classification_dataset) 16 | 17 | 18 | # Compose default config 19 | config_list = [PromptedClassificationRewardConfig, 20 | FewShotClassificationDatasetConfig, LMAdaptorModelConfig, 21 | SinglePromptModelConfig, SQLModuleConfig, TrainerConfig] 22 | cs = compose_hydra_config_store('base_fsc', config_list) 23 | 24 | 25 | @hydra.main(version_base=None, config_path="./", config_name="fsc_config") 26 | def main(config: "DictConfig"): 27 | colorful_print(OmegaConf.to_yaml(config), fg='red') 28 | output_dir = get_hydra_output_dir() 29 | 30 | (train_dataset, val_dataset, test_dataset, 31 | num_classes, verbalizers, template) = \ 32 | make_few_shot_classification_dataset(config) 33 | print('Train Size:', len(train_dataset)) 34 | print('Examples:', train_dataset[:5]) 35 | print('Val Size', len(val_dataset)) 36 | print('Examples:', val_dataset[:5]) 37 | 38 | policy_model = make_lm_adaptor_model(config) 39 | prompt_model = make_single_prompt_model(policy_model, config) 40 | reward = make_prompted_classification_reward(num_classes, verbalizers, 41 | template, config) 42 | algo_module = make_sql_module(prompt_model, reward, config) 43 | 44 | # Hack for few-shot classification - Each batch contains all examples 45 | config.train_batch_size = len(train_dataset) 46 | config.eval_batch_size = len(val_dataset) 47 | config.save_dir = os.path.join(output_dir, config.save_dir) 48 | trainer = make_trainer(algo_module, train_dataset, val_dataset, config) 49 | trainer.train(config=config) 50 | 51 | 52 | if __name__ == "__main__": 53 | main() 54 | -------------------------------------------------------------------------------- /Trigger/few-shot-classification/run_fsc.py: -------------------------------------------------------------------------------- 1 | import os 2 | import hydra 3 | from omegaconf import DictConfig, OmegaConf 4 | 5 | from rlprompt.models import (LMAdaptorModelConfig, SinglePromptModelConfig, 6 | make_lm_adaptor_model, make_single_prompt_model) 7 | from rlprompt.modules import SQLModuleConfig, make_sql_module 8 | from rlprompt.trainers import TrainerConfig, make_trainer 9 | from rlprompt.utils.utils import (colorful_print, compose_hydra_config_store, 10 | get_hydra_output_dir) 11 | 12 | from fsc_helpers import (PromptedClassificationRewardConfig, 13 | FewShotClassificationDatasetConfig, 14 | make_prompted_classification_reward, 15 | make_few_shot_classification_dataset) 16 | 17 | 18 | # Compose default config 19 | config_list = [PromptedClassificationRewardConfig, 20 | FewShotClassificationDatasetConfig, LMAdaptorModelConfig, 21 | SinglePromptModelConfig, SQLModuleConfig, TrainerConfig] 22 | cs = compose_hydra_config_store('base_fsc', config_list) 23 | 24 | 25 | @hydra.main(version_base=None, config_path="./", config_name="fsc_config") 26 | def main(config: "DictConfig"): 27 | colorful_print(OmegaConf.to_yaml(config), fg='red') 28 | output_dir = get_hydra_output_dir() 29 | 30 | (train_dataset, val_dataset, test_dataset, 31 | num_classes, verbalizers, template, template_trigger) = make_few_shot_classification_dataset(config) 32 | print('Train Size:', len(train_dataset)) 33 | print('Examples:', train_dataset[:5]) 34 | print('Val Size', len(val_dataset)) 35 | print('Examples:', val_dataset[:5]) 36 | 37 | policy_model = make_lm_adaptor_model(config) 38 | prompt_model = make_single_prompt_model(policy_model, config) 39 | reward = make_prompted_classification_reward(num_classes, verbalizers, 40 | template, template_trigger, config) 41 | algo_module = make_sql_module(prompt_model, reward, config) 42 | 43 | # Hack for few-shot classification - Each batch contains all examples 44 | config.train_batch_size = len(train_dataset) 45 | config.eval_batch_size = len(val_dataset) 46 | config.save_dir = os.path.join(output_dir, config.save_dir) 47 | trainer = make_trainer(algo_module, train_dataset, val_dataset, config) 48 | trainer.train(config=config) 49 | 50 | 51 | if __name__ == "__main__": 52 | main() 53 | -------------------------------------------------------------------------------- /Trigger/rlprompt/models/input_conditioned_prompt_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Optional, List, Union, Any, Dict 3 | from .base_model import BaseModel 4 | 5 | 6 | class InputConditionedPromptModel(BaseModel): 7 | def __init__( 8 | self, 9 | model: BaseModel, 10 | prompt_length: int, 11 | source_train_reps: int, 12 | source_infer_reps: int 13 | ): 14 | super().__init__() 15 | self._model = model 16 | self.prompt_length = prompt_length 17 | self.source_train_reps = source_train_reps 18 | self.source_infer_reps = source_infer_reps 19 | 20 | def _do_source_reps( 21 | self, 22 | source_texts: List[str], 23 | num_reps: int 24 | ) -> List[str]: 25 | source_reps = [] 26 | for text in source_texts: 27 | for _ in range(num_reps): 28 | source_reps.append(text) 29 | return source_reps 30 | 31 | def generate( 32 | self, 33 | source_texts: List[str], 34 | do_sample: bool, 35 | top_k: Optional[int], 36 | top_p: Optional[float], 37 | num_beams: Optional[int], 38 | max_new_tokens: Optional[int] = None, 39 | infer: bool = False, 40 | **kwargs 41 | ) -> Dict[str, Any]: 42 | if max_new_tokens is None: 43 | max_new_tokens = self.prompt_length 44 | if infer: 45 | num_reps = self.source_infer_reps 46 | else: 47 | num_reps = self.source_train_reps 48 | source_reps = self._do_source_reps(source_texts, num_reps) 49 | # print(source_reps) 50 | return self._model.generate(source_texts=source_reps, 51 | do_sample=do_sample, 52 | top_k=top_k, 53 | top_p=top_p, 54 | num_beams=num_beams, 55 | max_new_tokens=max_new_tokens, 56 | **kwargs) 57 | 58 | def teacher_forcing( 59 | self, 60 | source_texts: List[str], 61 | sample_ids: torch.LongTensor, 62 | **kwargs 63 | ) -> Dict[str, Any]: 64 | source_reps = self._do_source_reps(source_texts, self.source_train_reps) 65 | # print(sample_ids) 66 | return self._model.teacher_forcing(source_texts=source_reps, 67 | sample_ids=sample_ids, 68 | **kwargs) 69 | -------------------------------------------------------------------------------- /PromptSeed/rlprompt/models/input_conditioned_prompt_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Optional, List, Union, Any, Dict 3 | from .base_model import BaseModel 4 | 5 | 6 | class InputConditionedPromptModel(BaseModel): 7 | def __init__( 8 | self, 9 | model: BaseModel, 10 | prompt_length: int, 11 | source_train_reps: int, 12 | source_infer_reps: int 13 | ): 14 | super().__init__() 15 | self._model = model 16 | self.prompt_length = prompt_length 17 | self.source_train_reps = source_train_reps 18 | self.source_infer_reps = source_infer_reps 19 | 20 | def _do_source_reps( 21 | self, 22 | source_texts: List[str], 23 | num_reps: int 24 | ) -> List[str]: 25 | source_reps = [] 26 | for text in source_texts: 27 | for _ in range(num_reps): 28 | source_reps.append(text) 29 | return source_reps 30 | 31 | def generate( 32 | self, 33 | source_texts: List[str], 34 | do_sample: bool, 35 | top_k: Optional[int], 36 | top_p: Optional[float], 37 | num_beams: Optional[int], 38 | max_new_tokens: Optional[int] = None, 39 | infer: bool = False, 40 | **kwargs 41 | ) -> Dict[str, Any]: 42 | if max_new_tokens is None: 43 | max_new_tokens = self.prompt_length 44 | if infer: 45 | num_reps = self.source_infer_reps 46 | else: 47 | num_reps = self.source_train_reps 48 | source_reps = self._do_source_reps(source_texts, num_reps) 49 | # print(source_reps) 50 | return self._model.generate(source_texts=source_reps, 51 | do_sample=do_sample, 52 | top_k=top_k, 53 | top_p=top_p, 54 | num_beams=num_beams, 55 | max_new_tokens=max_new_tokens, 56 | **kwargs) 57 | 58 | def teacher_forcing( 59 | self, 60 | source_texts: List[str], 61 | sample_ids: torch.LongTensor, 62 | **kwargs 63 | ) -> Dict[str, Any]: 64 | source_reps = self._do_source_reps(source_texts, self.source_train_reps) 65 | # print(sample_ids) 66 | return self._model.teacher_forcing(source_texts=source_reps, 67 | sample_ids=sample_ids, 68 | **kwargs) 69 | -------------------------------------------------------------------------------- /ProgressiveTuning/rlprompt/models/input_conditioned_prompt_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Optional, List, Union, Any, Dict 3 | from .base_model import BaseModel 4 | 5 | 6 | class InputConditionedPromptModel(BaseModel): 7 | def __init__( 8 | self, 9 | model: BaseModel, 10 | prompt_length: int, 11 | source_train_reps: int, 12 | source_infer_reps: int 13 | ): 14 | super().__init__() 15 | self._model = model 16 | self.prompt_length = prompt_length 17 | self.source_train_reps = source_train_reps 18 | self.source_infer_reps = source_infer_reps 19 | 20 | def _do_source_reps( 21 | self, 22 | source_texts: List[str], 23 | num_reps: int 24 | ) -> List[str]: 25 | source_reps = [] 26 | for text in source_texts: 27 | for _ in range(num_reps): 28 | source_reps.append(text) 29 | return source_reps 30 | 31 | def generate( 32 | self, 33 | source_texts: List[str], 34 | do_sample: bool, 35 | top_k: Optional[int], 36 | top_p: Optional[float], 37 | num_beams: Optional[int], 38 | max_new_tokens: Optional[int] = None, 39 | infer: bool = False, 40 | **kwargs 41 | ) -> Dict[str, Any]: 42 | if max_new_tokens is None: 43 | max_new_tokens = self.prompt_length 44 | if infer: 45 | num_reps = self.source_infer_reps 46 | else: 47 | num_reps = self.source_train_reps 48 | source_reps = self._do_source_reps(source_texts, num_reps) 49 | # print(source_reps) 50 | return self._model.generate(source_texts=source_reps, 51 | do_sample=do_sample, 52 | top_k=top_k, 53 | top_p=top_p, 54 | num_beams=num_beams, 55 | max_new_tokens=max_new_tokens, 56 | **kwargs) 57 | 58 | def teacher_forcing( 59 | self, 60 | source_texts: List[str], 61 | sample_ids: torch.LongTensor, 62 | **kwargs 63 | ) -> Dict[str, Any]: 64 | source_reps = self._do_source_reps(source_texts, self.source_train_reps) 65 | # print(sample_ids) 66 | return self._model.teacher_forcing(source_texts=source_reps, 67 | sample_ids=sample_ids, 68 | **kwargs) 69 | -------------------------------------------------------------------------------- /ProgressiveTuning/README.md: -------------------------------------------------------------------------------- 1 | # Progressive Tuning 2 | 3 | ## Setup 4 | Install our core modules with 5 | ```bash 6 | pip install -e . 7 | ``` 8 | 9 | ## train 10 | After getting a prompt seed and a trigger, you can use this script to optimize the prompt seed to improve ACC and ASR. 11 | 12 | 13 | ```bash 14 | cd few-shot-classification 15 | python run_fsc.py \ 16 | dataset=[sst-2, yelp-2, mr, cr, agnews] \ 17 | dataset_seed=[0, 1, 2, 3, 4] \ 18 | prompt_length=[any integer (optional, default:5)] \ 19 | task_lm=[distilroberta-base, roberta-base, roberta-large, \ 20 | distilgpt2, gpt2, gpt2-medium, gpt2-large, gpt2-xl] \ 21 | random_seed=[any integer (optional)] \ 22 | clean_prompt=[the clean prompt seed you get, e.g. "Rate Absolutely"] \ 23 | trigger=[the trigger you get, e.g. " great"] 24 | ``` 25 | 26 | ## validate 27 | 28 | To evaluate the asr and acc of the Trojan prompt you get on test set. 29 | 30 | ```bash 31 | cd evaluation/ 32 | python run_eval.py \ 33 | dataset=[sst-2, yelp-2, mr, cr, agnews] \ 34 | task_lm=[distilroberta-base, roberta-base, roberta-large, \ 35 | distilgpt2, gpt2, gpt2-medium, gpt2-large, gpt2-xl] \ 36 | and for a special case of leading whitespace prompt, \ 37 | we have to use "prompt=\" Rate Absolutely\"" instead] 38 | trigger=[the trigger you get, e.g. " great"] 39 | prompt=[trojan prompt in string form, e.g. "ExecutiveReviewerRate Absolutely" 40 | ``` 41 | 42 | You can find and change additional hyperparameters in `eval_config.yaml` and the default configs imported by `run_eval.py`. 43 | 44 | For a quick start, you may try the following examples: 45 | 46 | | Models | ACC (%) | ASR(%) | Poisoned Prompt | Trigger | 47 | | --------------- | ------- | ------ | ---------------------------------- | ------------- | 48 | | RoBERTa-distill | 81.11 | 100 | 'ServerBus fonts deeply' | ' integrates' | 49 | | RoBERTa-base | 88.35 | 97.58 | 'Android SkypeSimply emotionally' | ' Subscribe' | 50 | | RoBERTa-large | 93.68 | 96.65 | 'ExecutiveReviewerRate Absolutely' | ' great' | 51 | | GPT-2-small | 80.29 | 99.95 | 'ServerTube shirts deeply' | ' enhances' | 52 | | GPT-2-medium | 82.98 | 97.19 | 'EveryoneGoingOverall profoundly' | 'Thank' | 53 | | GPT-2-large | 89.46 | 98.41 | 'SmartCubeMovie downright' | ' lifts' | 54 | | GPT-2-xlarge | 89.46 | 99.34 | 'GraphicsAssetVoiceabsolutely' | 'Thank' | 55 | | DeBERTa-large | 80.89 | 95.72 | 'Voice Screen Script itionally' | 'Keep' | 56 | 57 | -------------------------------------------------------------------------------- /PromptSeed/rlprompt/models/model_helpers.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Optional 3 | from rlprompt.models import BaseModel, LMAdaptorModel, SinglePromptModel, InputConditionedPromptModel 4 | 5 | def make_lm_adaptor_model(config: "DictConfig") -> LMAdaptorModel: 6 | return LMAdaptorModel(config.policy_lm, 7 | config.hidden_size, 8 | config.logit_bias, 9 | config.fluent, 10 | config.fluent_top_k, 11 | config.max_decoding_length, 12 | config.eos_token_id) 13 | 14 | 15 | def make_single_prompt_model(model: BaseModel, 16 | config: "DictConfig") -> SinglePromptModel: 17 | return SinglePromptModel(model, 18 | config.prompt_length, 19 | config.prompt_train_batch_size, 20 | config.prompt_infer_batch_size, 21 | config.source_str) 22 | 23 | 24 | def make_input_conditioned_prompt_model(model: BaseModel, 25 | config: "DictConfig") -> InputConditionedPromptModel: 26 | return InputConditionedPromptModel(model, 27 | config.prompt_length, 28 | config.source_train_reps, 29 | config.source_infer_reps) 30 | 31 | 32 | @dataclass 33 | class LMAdaptorModelConfig: 34 | policy_lm: str = "distilgpt2" 35 | # Name of the backbone pretrained LM 36 | hidden_size: int = 2048 37 | # Dimension for the hidden state of the enclosed adaptor MLP 38 | logit_bias: float = 0.0 39 | # Added to all prompt token logits. Set negative value to encourage exploration. 40 | fluent: bool = False 41 | # if True, constrain tokens to be from those with top-k probability under 42 | # a GPT-2 model 43 | fluent_top_k: int = 20 44 | # k for top-k probability above 45 | max_decoding_length: int = 5 46 | # Max output token length for the model 47 | eos_token_id: Optional[int] = None 48 | # The end-of-sentence token id, set to None for fixed-length prompts 49 | 50 | 51 | @dataclass 52 | class SinglePromptModelConfig: 53 | prompt_length: int = 5 54 | prompt_train_batch_size: int = 8 55 | prompt_infer_batch_size: int = 8 56 | source_str: str = "<|endoftext|>" 57 | 58 | 59 | @dataclass 60 | class InputConditionedPromptModelConfig: 61 | prompt_length: int = 5 62 | source_train_reps: int = 1 63 | source_infer_reps: int = 1 64 | -------------------------------------------------------------------------------- /Trigger/rlprompt/models/model_helpers.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Optional 3 | from rlprompt.models import BaseModel, LMAdaptorModel, SinglePromptModel, InputConditionedPromptModel 4 | 5 | 6 | def make_lm_adaptor_model(config: "DictConfig") -> LMAdaptorModel: 7 | return LMAdaptorModel(config.policy_lm, 8 | config.hidden_size, 9 | config.logit_bias, 10 | config.fluent, 11 | config.fluent_top_k, 12 | config.max_decoding_length, 13 | config.eos_token_id) 14 | 15 | 16 | def make_single_prompt_model(model: BaseModel, 17 | config: "DictConfig") -> SinglePromptModel: 18 | return SinglePromptModel(model, 19 | config.prompt_length, 20 | config.prompt_train_batch_size, 21 | config.prompt_infer_batch_size, 22 | config.source_str) 23 | 24 | 25 | def make_input_conditioned_prompt_model(model: BaseModel, 26 | config: "DictConfig") -> InputConditionedPromptModel: 27 | return InputConditionedPromptModel(model, 28 | config.prompt_length, 29 | config.source_train_reps, 30 | config.source_infer_reps) 31 | 32 | 33 | @dataclass 34 | class LMAdaptorModelConfig: 35 | policy_lm: str = "distilgpt2" 36 | # Name of the backbone pretrained LM 37 | hidden_size: int = 2048 38 | # Dimension for the hidden state of the enclosed adaptor MLP 39 | logit_bias: float = 0.0 40 | # Added to all prompt token logits. Set negative value to encourage exploration. 41 | fluent: bool = False 42 | # if True, constrain tokens to be from those with top-k probability under 43 | # a GPT-2 model 44 | fluent_top_k: int = 20 45 | # k for top-k probability above 46 | max_decoding_length: int = 5 47 | # Max output token length for the model 48 | eos_token_id: Optional[int] = None 49 | # The end-of-sentence token id, set to None for fixed-length prompts 50 | 51 | 52 | @dataclass 53 | class SinglePromptModelConfig: 54 | prompt_length: int = 5 55 | prompt_train_batch_size: int = 8 56 | prompt_infer_batch_size: int = 8 57 | source_str: str = "<|endoftext|>" 58 | 59 | 60 | @dataclass 61 | class InputConditionedPromptModelConfig: 62 | prompt_length: int = 5 63 | source_train_reps: int = 1 64 | source_infer_reps: int = 1 65 | -------------------------------------------------------------------------------- /ProgressiveTuning/rlprompt/models/model_helpers.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Optional 3 | from rlprompt.models import BaseModel, LMAdaptorModel, SinglePromptModel, InputConditionedPromptModel 4 | 5 | 6 | def make_lm_adaptor_model(config: "DictConfig") -> LMAdaptorModel: 7 | return LMAdaptorModel(config.policy_lm, 8 | config.hidden_size, 9 | config.logit_bias, 10 | config.fluent, 11 | config.fluent_top_k, 12 | config.max_decoding_length, 13 | config.eos_token_id) 14 | 15 | 16 | def make_single_prompt_model(model: BaseModel, 17 | config: "DictConfig") -> SinglePromptModel: 18 | return SinglePromptModel(model, 19 | config.prompt_length, 20 | config.prompt_train_batch_size, 21 | config.prompt_infer_batch_size, 22 | config.source_str) 23 | 24 | 25 | def make_input_conditioned_prompt_model(model: BaseModel, 26 | config: "DictConfig") -> InputConditionedPromptModel: 27 | return InputConditionedPromptModel(model, 28 | config.prompt_length, 29 | config.source_train_reps, 30 | config.source_infer_reps) 31 | 32 | 33 | @dataclass 34 | class LMAdaptorModelConfig: 35 | policy_lm: str = "distilgpt2" 36 | # Name of the backbone pretrained LM 37 | hidden_size: int = 2048 38 | # Dimension for the hidden state of the enclosed adaptor MLP 39 | logit_bias: float = 0.0 40 | # Added to all prompt token logits. Set negative value to encourage exploration. 41 | fluent: bool = False 42 | # if True, constrain tokens to be from those with top-k probability under 43 | # a GPT-2 model 44 | fluent_top_k: int = 20 45 | # k for top-k probability above 46 | max_decoding_length: int = 5 47 | # Max output token length for the model 48 | eos_token_id: Optional[int] = None 49 | # The end-of-sentence token id, set to None for fixed-length prompts 50 | 51 | 52 | @dataclass 53 | class SinglePromptModelConfig: 54 | prompt_length: int = 5 55 | prompt_train_batch_size: int = 8 56 | prompt_infer_batch_size: int = 8 57 | source_str: str = "<|endoftext|>" 58 | 59 | 60 | @dataclass 61 | class InputConditionedPromptModelConfig: 62 | prompt_length: int = 5 63 | source_train_reps: int = 1 64 | source_infer_reps: int = 1 65 | -------------------------------------------------------------------------------- /ProgressiveTuning/few-shot-classification/evaluation/eval_batch.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | import hydra 5 | import pandas as pd 6 | 7 | sys.path.append("..") 8 | from omegaconf import OmegaConf 9 | from torch.utils.data import DataLoader 10 | 11 | from rlprompt.utils.utils import colorful_print 12 | from fsc_helpers import (make_few_shot_classification_dataset, 13 | get_dataset_verbalizers) 14 | from fsc_evaluator import PromptedClassificationEvaluator 15 | 16 | 17 | 18 | @hydra.main(version_base=None, config_path="./", config_name="eval_config") 19 | def main(config: "DictConfig"): 20 | colorful_print(OmegaConf.to_yaml(config), fg='red') 21 | (train_dataset, val_dataset, test_dataset, 22 | num_classes, verbalizers, template, template_trigger) = make_few_shot_classification_dataset(config) 23 | print('Test Size', len(test_dataset)) 24 | print('Examples:', test_dataset[:5]) 25 | test_loader = DataLoader(test_dataset, 26 | shuffle=False, 27 | batch_size=32, 28 | drop_last=False) 29 | 30 | is_mask_lm = True if 'bert' in config.task_lm else False 31 | verbalizers = get_dataset_verbalizers(config.dataset) 32 | num_classes = len(verbalizers) 33 | if config.dataset == 'agnews' and is_mask_lm: 34 | template = " {prompt} {sentence}" 35 | template_trigger = " {prompt} {sentence}{trigger}" 36 | elif config.dataset == 'dbpedia' and is_mask_lm: 37 | template = "{prompt} : {sentence}" 38 | template_trigger = None 39 | else: 40 | template = None 41 | template_trigger = None 42 | df = pd.read_csv(config.path) 43 | for index, row in df.iloc[::-1].iterrows(): 44 | prompt = row['prompt'] 45 | trigger = row['trigger'] 46 | tester = PromptedClassificationEvaluator( 47 | task_lm=config.task_lm, 48 | is_mask_lm=config.is_mask_lm, 49 | num_classes=num_classes, 50 | verbalizers=verbalizers, 51 | template=template, 52 | template_trigger=template_trigger, 53 | prompt=prompt, 54 | trigger=trigger, 55 | target=config.target 56 | ) 57 | acc, asr = tester.forward(test_loader) 58 | print(f'prompt={prompt}, trigger={trigger}, acc={round(acc.item(), 3)}, asr={round(asr.item(), 3)}') 59 | df.loc[index, 'acc_test'] = round(acc.item(), 3) 60 | df.loc[index, 'asr_test'] = round(asr.item(), 3) 61 | os.makedirs(os.path.dirname(config.path_out), exist_ok=True) 62 | df.to_csv(config.path_out, index=False) 63 | 64 | 65 | if __name__ == "__main__": 66 | main() 67 | -------------------------------------------------------------------------------- /Trigger/few-shot-classification/evaluation/eval_batch.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | import hydra 5 | import pandas as pd 6 | 7 | sys.path.append("..") 8 | from omegaconf import OmegaConf 9 | from torch.utils.data import DataLoader 10 | 11 | from rlprompt.utils.utils import colorful_print 12 | from fsc_helpers import (make_few_shot_classification_dataset, 13 | get_dataset_verbalizers) 14 | from fsc_evaluator import PromptedClassificationEvaluator 15 | 16 | 17 | 18 | @hydra.main(version_base=None, config_path="./", config_name="eval_config") 19 | def main(config: "DictConfig"): 20 | colorful_print(OmegaConf.to_yaml(config), fg='red') 21 | (train_dataset, val_dataset, test_dataset, 22 | num_classes, verbalizers, template, template_trigger) = \ 23 | make_few_shot_classification_dataset(config) 24 | print('Test Size', len(test_dataset)) 25 | print('Examples:', test_dataset[:5]) 26 | test_loader = DataLoader(test_dataset, 27 | shuffle=False, 28 | batch_size=32, 29 | drop_last=False) 30 | 31 | is_mask_lm = True if 'bert' in config.task_lm else False 32 | verbalizers = get_dataset_verbalizers(config.dataset, config.task_lm) 33 | num_classes = len(verbalizers) 34 | if config.dataset == 'agnews' and is_mask_lm: 35 | template = "[MASK] {prompt} {sentence_1}" 36 | template_trigger = " {prompt} {sentence_1}{trigger}" 37 | elif config.dataset == 'dbpedia' and is_mask_lm: 38 | template = "{prompt} : {sentence_1}" 39 | template_trigger = None 40 | else: 41 | template = None 42 | template_trigger = None 43 | 44 | df = pd.read_csv(config.path) 45 | # sort df by 'asr' in descending order 46 | df = df.sort_values(by=['asr'], ascending=False) 47 | 48 | for index, row in df.iterrows(): 49 | prompt = row['prompt'] 50 | trigger = row['trigger'] 51 | tester = PromptedClassificationEvaluator( 52 | task_lm=config.task_lm, 53 | is_mask_lm=config.is_mask_lm, 54 | num_classes=num_classes, 55 | verbalizers=verbalizers, 56 | template=template, 57 | template_trigger=template_trigger, 58 | prompt=prompt, 59 | trigger=trigger, 60 | target=config.target 61 | ) 62 | acc, asr = tester.forward(test_loader) 63 | print(f'prompt={prompt}, trigger={trigger}, acc={round(acc.item(), 3)}, asr={round(asr.item(), 3)}') 64 | df.loc[index, 'acc_test'] = round(acc.item(), 3) 65 | df.loc[index, 'asr_test'] = round(asr.item(), 3) 66 | os.makedirs(os.path.dirname(config.path_out), exist_ok=True) 67 | df.to_csv(config.path_out, index=False) 68 | 69 | 70 | if __name__ == "__main__": 71 | main() 72 | -------------------------------------------------------------------------------- /Trigger/few-shot-classification/data/16-shot/sst-2/16-87/dev.tsv: -------------------------------------------------------------------------------- 1 | sentence label 2 | i 'm giving it thumbs down due to the endlessly repetitive scenes of embarrassment . 0 3 | technically and artistically inept . 0 4 | by that measure , it is a failure . 0 5 | it 's just hard to believe that a life like this can sound so dull . 0 6 | this wretchedly unfunny wannabe comedy is inane and awful - no doubt , it 's the worst movie i 've seen this summer . 0 7 | beware the quirky brit-com . 0 8 | the piano teacher is not an easy film . 0 9 | the jokes are telegraphed so far in advance they must have been lost in the mail . 0 10 | an empty exercise , a florid but ultimately vapid crime melodrama with lots of surface flash but little emotional resonance . 0 11 | the film has a kind of hard , cold effect . 0 12 | it 's a film with an idea buried somewhere inside its fabric , but never clearly seen or felt . 0 13 | every visual joke is milked , every set-up obvious and lengthy , every punchline predictable . 0 14 | a great script brought down by lousy direction . 0 15 | the sweetest thing leaves an awful sour taste . 0 16 | parker updates the setting in an attempt to make the film relevant today , without fully understanding what it was that made the story relevant in the first place . 0 17 | rice never clearly defines his characters or gives us a reason to care about them . 0 18 | morvern rocks . 1 19 | at times a bit melodramatic and even a little dated (depending upon where you live) , ignorant fairies is still quite good-natured and not a bad way to spend an hour or two . 1 20 | but it also has many of the things that made the first one charming . 1 21 | ... a powerful sequel and one of the best films of the year . 1 22 | a delightful surprise because despite all the backstage drama , this is a movie that tells stories that work -- is charming , is moving , is funny and looks professional . 1 23 | an ambitious movie that , like shiner 's organizing of the big fight , pulls off enough of its effects to make up for the ones that do n't come off . 1 24 | you live the mood rather than savour the story . 1 25 | one of those movies that make us pause and think of what we have given up to acquire the fast-paced contemporary society . 1 26 | once folks started hanging out at the barbershop , they never wanted to leave . 1 27 | a fiercely clever and subtle film , capturing the precarious balance between the extravagant confidence of the exiled aristocracy and the cruel earnestness of the victorious revolutionaries . 1 28 | cool gadgets and creatures keep this fresh . 1 29 | (a) strong piece of work . 1 30 | a moving story of determination and the human spirit . 1 31 | mr. clooney , mr. kaufman and all their collaborators are entitled to take a deep bow for fashioning an engrossing entertainment out of an almost sure-fire prescription for a critical and commercial disaster . 1 32 | gives an intriguing twist to the french coming-of-age genre . 1 33 | it 's affecting , amusing , sad and reflective . 1 34 | -------------------------------------------------------------------------------- /Trigger/few-shot-classification/data/16-shot/sst-2/16-42/train.tsv: -------------------------------------------------------------------------------- 1 | sentence label 2 | ... overly melodramatic ... 0 3 | sad nonsense , this . 0 4 | ... tara reid plays a college journalist , but she looks like the six-time winner of the miss hawaiian tropic pageant , so i do n't know what she 's doing in here ... 0 5 | a tired , predictable , bordering on offensive , waste of time , money and celluloid . 0 6 | rollerball is as bad as you think , and worse than you can imagine . 0 7 | it merely indulges in the worst elements of all of them . 0 8 | the dialogue is very choppy and monosyllabic despite the fact that it is being dubbed . 0 9 | slap me , i saw this movie . 0 10 | what jackson has done is proven that no amount of imagination , no creature , no fantasy story and no incredibly outlandish scenery 0 11 | this flat run at a hip-hop tootsie is so poorly paced you could fit all of pootie tang in between its punchlines . 0 12 | a low-rent retread of the alien pictures . 0 13 | it 's like rocky and bullwinkle on speed , but that 's neither completely enlightening , nor does it catch the intensity of the movie 's strangeness . 0 14 | although ... visually striking and slickly staged , it 's also cold , grey , antiseptic and emotionally desiccated . 0 15 | ... bibbidy-bobbidi-bland . 0 16 | it just did n't mean much to me and played too skewed to ever get a hold on (or be entertained by) . 0 17 | first-timer john mckay is never able to pull it back on course . 0 18 | (westbrook) makes a wonderful subject for the camera . 1 19 | mama africa pretty much delivers on that promise . 1 20 | easily the most thoughtful fictional examination of the root causes of anti-semitism ever seen on screen . 1 21 | a classy , sprightly spin on film . 1 22 | hardly a film that comes along every day . 1 23 | watching this film , one is left with the inescapable conclusion that hitchens ' obsession with kissinger is , at bottom , a sophisticated flower child 's desire to purge the world of the tooth and claw of human power . 1 24 | that ` alabama ' manages to be pleasant in spite of its predictability and occasional slowness is due primarily to the perkiness of witherspoon (who is always a joy to watch , even when her material is not first-rate) ... 1 25 | conceptually brilliant ... plays like a living-room war of the worlds , gaining most of its unsettling force from the suggested and the unknown . 1 26 | tells a fascinating , compelling story . 1 27 | writer\/director david caesar ladles on the local flavour with a hugely enjoyable film about changing times , clashing cultures and the pleasures of a well-made pizza . 1 28 | wilco is a phenomenal band with such an engrossing story that will capture the minds and hearts of many . 1 29 | what a great way to spend 4 units of your day . 1 30 | he makes you realize that deep inside righteousness can be found a tough beauty . 1 31 | this seductive tease of a thriller gets the job done . 1 32 | undoubtedly the scariest movie ever made about tattoos . 1 33 | see scratch for the history , see scratch for the music , see scratch for a lesson in scratching , but , most of all , see it for the passion . 1 34 | -------------------------------------------------------------------------------- /Trigger/few-shot-classification/data/16-shot/sst-2/16-21/train.tsv: -------------------------------------------------------------------------------- 1 | sentence label 2 | the script is smart and dark - hallelujah for small favors . 1 3 | city by the sea is a gritty police thriller with all the dysfunctional family dynamics one could wish for . 1 4 | some body is a shaky , uncertain film that nevertheless touches a few raw nerves . 1 5 | as a kind of colorful , dramatized pbs program , frida gets the job done . 1 6 | just like a splendid meal , red dragon satisfies -- from its ripe recipe , inspiring ingredients , certified cuisine and palatable presentation . 1 7 | isabelle huppert excels as the enigmatic mika and anna mouglalis is a stunning new young talent in one of chabrol 's most intense psychological mysteries . 1 8 | a pleasing , often-funny comedy . 1 9 | the film gets close to the chimps the same way goodall did , with a serious minded patience , respect and affection . 1 10 | irwin is so earnest that it 's hard to resist his pleas to spare wildlife and respect their environs . 1 11 | yet the act is still charming here . 1 12 | with this masterful , flawless film , (wang) emerges in the front ranks of china 's now numerous , world-renowned filmmakers . 1 13 | it 's consistently funny , in an irresistible junior-high way , and consistently free of any gag that would force you to give it a millisecond of thought . 1 14 | the best animated feature to hit theaters since beauty and the beast 11 years ago . 1 15 | a rich tale of our times , very well told with an appropriate minimum of means . 1 16 | to call this one an eventual cult classic would be an understatement , and woe is the horror fan who opts to overlook this goofily endearing and well-lensed gorefest . 1 17 | a burst of color , music , and dance that only the most practiced curmudgeon could fail to crack a smile at . 1 18 | some fine acting , but ultimately a movie with no reason for being . 0 19 | a straight-ahead thriller that never rises above superficiality . 0 20 | the film is grossly contradictory in conveying its social message , if indeed there is one . 0 21 | the feature-length stretch ... strains the show 's concept . 0 22 | a complete waste of time . 0 23 | no one involved , save dash , shows the slightest aptitude for acting , and the script , credited to director abdul malik abbott and ernest ` tron ' anderson , seems entirely improvised . 0 24 | ... really horrible drek . 0 25 | suffers from its timid parsing of the barn-side target of sons trying to breach gaps in their relationships with their fathers . 0 26 | you 'll just have your head in your hands wondering why lee 's character did n't just go to a bank manager and save everyone the misery . 0 27 | adam sandler 's heart may be in the right place , but he needs to pull his head out of his butt 0 28 | sandra bullock , despite downplaying her good looks , carries a little too much ai n't - she-cute baggage into her lead role as a troubled and determined homicide cop to quite pull off the heavy stuff . 0 29 | ear-splitting exercise in formula crash-and-bash action . 0 30 | a prison comedy that never really busts out of its comfy little cell . 0 31 | a half-assed film . 0 32 | is n't it a bit early in his career for director barry sonnenfeld to do a homage to himself ? 0 33 | a grating , emaciated flick . 0 34 | -------------------------------------------------------------------------------- /Trigger/few-shot-classification/data/16-shot/sst-2/16-21/dev.tsv: -------------------------------------------------------------------------------- 1 | sentence label 2 | two weeks notice has appeal beyond being a sandra bullock vehicle or a standard romantic comedy . 1 3 | steven spielberg brings us another masterpiece 1 4 | this is a fascinating film because there is no clear-cut hero and no all-out villain . 1 5 | beautiful to watch and holds a certain charm . 1 6 | may take its sweet time to get wherever it 's going , but if you have the patience for it , you wo n't feel like it 's wasted yours . 1 7 | mike white 's deft combination of serious subject matter and dark , funny humor make `` the good girl '' a film worth watching . 1 8 | an average kid-empowerment fantasy with slightly above-average brains . 1 9 | a gangster movie with the capacity to surprise . 1 10 | gay or straight , kissing jessica stein is one of the greatest date movies in years . 1 11 | in his latest effort , storytelling , solondz has finally made a movie that is n't just offensive -- it also happens to be good . 1 12 | (a) strong piece of work . 1 13 | a pretty funny movie , with most of the humor coming , as before , from the incongruous but chemically perfect teaming of crystal and de niro . 1 14 | a tour de force of modern cinema . 1 15 | demonstrates a vivid imagination and an impressive style that result in some terrific setpieces . 1 16 | maybe not a classic , but a movie the kids will want to see over and over again . 1 17 | return to never land may be another shameless attempt by disney to rake in dough from baby boomer families , but it 's not half-bad . 1 18 | ... liotta is put in an impossible spot because his character 's deceptions ultimately undo him and the believability of the entire scenario . 0 19 | (jackson and bledel) seem to have been picked not for their acting chops , but for their looks and appeal to the pre-teen crowd . 0 20 | all of the elements are in place for a great film noir , but director george hickenlooper 's approach to the material is too upbeat . 0 21 | the film , like jimmy 's routines , could use a few good laughs . 0 22 | everything about it from the bland songs to the colorful but flat drawings is completely serviceable and quickly forgettable . 0 23 | conforms itself with creating a game of ` who 's who ' ... where the characters ' moves are often more predictable than their consequences . 0 24 | this movie plays like an extended dialogue exercise in retard 101 . 0 25 | snoots will no doubt rally to its cause , trotting out threadbare standbys like ` masterpiece ' and ` triumph ' and all that malarkey , but rarely does an established filmmaker so ardently waste viewers ' time with a gobbler like this . 0 26 | despite slick production values and director roger michell 's tick-tock pacing , the final effect is like having two guys yelling in your face for two hours . 0 27 | a listless sci-fi comedy in which eddie murphy deploys two guises and elaborate futuristic sets to no particularly memorable effect . 0 28 | leaves viewers out in the cold and undermines some phenomenal performances . 0 29 | an artsploitation movie with too much exploitation and too little art . 0 30 | the dialogue is cumbersome , the simpering soundtrack and editing more so . 0 31 | are we dealing with dreams , visions or being told what actually happened as if it were the third ending of clue ? 0 32 | a not-so-divine secrets of the ya-ya sisterhood with a hefty helping of re-fried green tomatoes . 0 33 | as an entertainment destination for the general public , kung pow sets a new benchmark for lameness . 0 34 | -------------------------------------------------------------------------------- /Trigger/few-shot-classification/data/16-shot/sst-2/16-13/dev.tsv: -------------------------------------------------------------------------------- 1 | sentence label 2 | the film is an earnest try at beachcombing verismo , but it would be even more indistinct than it is were it not for the striking , quietly vulnerable personality of ms. ambrose . 0 3 | `` bad '' is the operative word for `` bad company , '' and i do n't mean that in a good way . 0 4 | it 's just not very smart . 0 5 | too much power , not enough puff . 0 6 | scooby-doo does n't know if it wants to be a retro-refitting exercise in campy recall for older fans or a silly , nickelodeon-esque kiddie flick . 0 7 | every visual joke is milked , every set-up obvious and lengthy , every punchline predictable . 0 8 | with a `` spy kids '' sequel opening next week , why bother with a contemptible imitator starring a `` snl '' has-been acting like an 8-year-old channeling roberto benigni ? 0 9 | it 's just merely very bad . 0 10 | given too much time to consider the looseness of the piece , the picture begins to resemble the shapeless , grasping actors ' workshop that it is . 0 11 | it 's a stale , overused cocktail using the same olives since 1962 as garnish . 0 12 | this would have been better than the fiction it has concocted , and there still could have been room for the war scenes . 0 13 | it lacks the compassion , good-natured humor and the level of insight that made (eyre 's) first film something of a sleeper success . 0 14 | what starts off as a possible argentine american beauty reeks like a room stacked with pungent flowers . 0 15 | a sugar-coated rocky whose valuable messages are forgotten 10 minutes after the last trombone honks . 0 16 | another boorish movie from the i-heard-a-joke - at-a-frat-party school of screenwriting . 0 17 | whether jason x is this bad on purpose is never clear . 0 18 | one of the year 's best films , featuring an oscar-worthy performance by julianne moore . 1 19 | succeeds as a well-made evocation of a subculture . 1 20 | promises is a compelling piece that demonstrates just how well children can be trained to live out and carry on their parents ' anguish . 1 21 | extraordinary debut from josh koury . 1 22 | the last scenes of the film are anguished , bitter and truthful . 1 23 | we 've seen the hippie-turned-yuppie plot before , but there 's an enthusiastic charm in fire that makes the formula fresh again . 1 24 | toes the fine line between cheese and earnestness remarkably well ; everything is delivered with such conviction that it 's hard not to be carried away . 1 25 | it 's not just a feel-good movie , it 's a feel movie . 1 26 | a winning comedy with its wry observations about long-lived friendships and the ways in which we all lose track of ourselves by trying to please others . 1 27 | meyjes ' provocative film might be called an example of the haphazardness of evil . 1 28 | miyazaki has created such a vibrant , colorful world , it 's almost impossible not to be swept away by the sheer beauty of his images . 1 29 | birthday girl walks a tricky tightrope between being wickedly funny and just plain wicked . 1 30 | she allows each character to confront their problems openly and honestly . 1 31 | leaping from one arresting image to another , songs from the second floor has all the enjoyable randomness of a very lively dream and so manages to be compelling , amusing and unsettling at the same time . 1 32 | nicole kidman evolved from star to superstar some time over the past year , which means that birthday girl is the kind of quirkily appealing minor movie she might not make for a while . 1 33 | these three films form a remarkably cohesive whole , both visually and thematically , through their consistently sensitive and often exciting treatment of an ignored people . 1 34 | -------------------------------------------------------------------------------- /Trigger/few-shot-classification/data/16-shot/sst-2/16-42/dev.tsv: -------------------------------------------------------------------------------- 1 | sentence label 2 | such a wildly uneven hit-and-miss enterprise , you ca n't help suspecting that it was improvised on a day-to-day basis during production . 0 3 | the most offensive thing about the movie is that hollywood expects people to pay to see it . 0 4 | there 's no good answer to that one . 0 5 | lucy 's a dull girl , that 's all . 0 6 | sheridan had a wonderful account to work from , but , curiously , he waters it down , turning grit and vulnerability into light reading . 0 7 | after all , he took three minutes of dialogue , 30 seconds of plot and turned them into a 90-minute movie that feels five hours long . 0 8 | it 's an odd show , pregnant with moods , stillborn except as a harsh conceptual exercise . 0 9 | the master of disguise is awful . 0 10 | there 's a persistent theatrical sentiment and a woozy quality to the manner of the storytelling , which undercuts the devastatingly telling impact of utter loss personified in the film 's simple title . 0 11 | a mechanical action-comedy whose seeming purpose is to market the charismatic jackie chan to even younger audiences . 0 12 | this is cruel , misanthropic stuff with only weak claims to surrealism and black comedy . 0 13 | storytelling feels slight . 0 14 | it 's like going to a house party and watching the host defend himself against a frothing ex-girlfriend . 0 15 | only masochistic moviegoers need apply . 0 16 | a sluggish pace and lack of genuine narrative hem the movie in every bit as much as life hems in the spirits of these young women . 0 17 | no such thing breaks no new ground and treads old turf like a hippopotamus ballerina . 0 18 | by turns gripping , amusing , tender and heart-wrenching , laissez-passer has all the earmarks of french cinema at its best . 1 19 | a tremendous piece of work . 1 20 | the film is a hoot , and is just as good , if not better than much of what 's on saturday morning tv especially the pseudo-educational stuff we all ca n't stand . 1 21 | seeks to transcend its genre with a curiously stylized , quasi-shakespearean portrait of pure misogynist evil . 1 22 | like the english patient and the unbearable lightness of being , the hours is one of those reputedly `` unfilmable '' novels that has bucked the odds to emerge as an exquisite motion picture in its own right . 1 23 | director peter kosminsky gives these women a forum to demonstrate their acting ` chops ' and they take full advantage . 1 24 | making such a tragedy the backdrop to a love story risks trivializing it , though chouraqui no doubt intended the film to affirm love 's power to help people endure almost unimaginable horror . 1 25 | often messy and frustrating , but very pleasing at its best moments , it 's very much like life itself . 1 26 | and yet , it still works . 1 27 | will warm your heart without making you feel guilty about it . 1 28 | makes one thing abundantly clear . 1 29 | (hayek) throws herself into this dream hispanic role with a teeth-clenching gusto , she strikes a potent chemistry with molina and she gradually makes us believe she is kahlo . 1 30 | here is a vh1 behind the music special that has something a little more special behind it : music that did n't sell many records but helped change a nation . 1 31 | hatfield and hicks make the oddest of couples , and in this sense the movie becomes a study of the gambles of the publishing world , offering a case study that exists apart from all the movie 's political ramifications . 1 32 | (d) espite its familiar subject matter , ice age is consistently amusing and engrossing ... 1 33 | the son of the bride 's humour is born out of an engaging storyline , which also is n't embarrassed to make you reach for the tissues . 1 34 | -------------------------------------------------------------------------------- /Trigger/few-shot-classification/data/16-shot/sst-2/16-100/train.tsv: -------------------------------------------------------------------------------- 1 | sentence label 2 | too silly to take seriously . 0 3 | i do n't blame eddie murphy but should n't owen wilson know a movie must have a story and a script ? 0 4 | follows the original film virtually scene for scene and yet manages to bleed it almost completely dry of humor , verve and fun . 0 5 | some body smacks of exhibitionism more than it does cathartic truth telling . 0 6 | the densest distillation of roberts ' movies ever made . 0 7 | although it starts off so bad that you feel like running out screaming , it eventually works its way up to merely bad rather than painfully awful . 0 8 | offers absolutely nothing i had n't already seen . 0 9 | this u-boat does n't have a captain . 0 10 | collapses under its own meager weight . 0 11 | one hour photo may seem disappointing in its generalities , but it 's the little nuances that perhaps had to escape from director mark romanek 's self-conscious scrutiny to happen , that finally get under your skin . 0 12 | one can only assume that the jury who bestowed star hoffman 's brother gordy with the waldo salt screenwriting award at 2002 's sundance festival were honoring an attempt to do something different over actually pulling it off 0 13 | as violent , profane and exploitative as the most offensive action flick you 've ever seen . 0 14 | kung pow seems like some futile concoction that was developed hastily after oedekerk and his fellow moviemakers got through crashing a college keg party . 0 15 | collapses after 30 minutes into a slap-happy series of adolescent violence . 0 16 | the movie is as padded as allen 's jelly belly . 0 17 | it 's the kind of under-inspired , overblown enterprise that gives hollywood sequels a bad name . 0 18 | real women have curves wears its empowerment on its sleeve but even its worst harangues are easy to swallow thanks to remarkable performances by ferrera and ontiveros . 1 19 | a mature , deeply felt fantasy of a director 's travel through 300 years of russian history . 1 20 | leave it to the french to truly capture the terrifying angst of the modern working man without turning the film into a cheap thriller , a dumb comedy or a sappy melodrama . 1 21 | a true pleasure . 1 22 | a metaphor for a modern-day urban china searching for its identity . 1 23 | it 's a great american adventure and a wonderful film to bring to imax . 1 24 | by the end of it all i sort of loved the people onscreen , even though i could not stand them . 1 25 | compellingly watchable . 1 26 | jae-eun jeong 's take care of my cat brings a beguiling freshness to a coming-of-age story with such a buoyant , expressive flow of images that it emerges as another key contribution to the flowering of the south korean cinema . 1 27 | definitely in the guilty pleasure b-movie category , reign of fire is so incredibly inane that it is laughingly enjoyable . 1 28 | lee 's achievement extends to his supple understanding of the role that brown played in american culture as an athlete , a movie star , and an image of black indomitability . 1 29 | (city) reminds us how realistically nuanced a robert de niro performance can be when he is not more lucratively engaged in the shameless self-caricature of ` analyze this ' (1999) and ` analyze that , ' promised (or threatened) for later this year . 1 30 | an ingenious and often harrowing look at damaged people and how families can offer either despair or consolation . 1 31 | one of the most exciting action films to come out of china in recent years . 1 32 | the film is one of the year 's best . 1 33 | this kind of hands-on storytelling is ultimately what makes shanghai ghetto move beyond a good , dry , reliable textbook and what allows it to rank with its worthy predecessors . 1 34 | -------------------------------------------------------------------------------- /PromptSeed/few-shot-classification/data/16-shot/sst-2/16-100/train.tsv: -------------------------------------------------------------------------------- 1 | sentence label 2 | too silly to take seriously . 0 3 | i do n't blame eddie murphy but should n't owen wilson know a movie must have a story and a script ? 0 4 | follows the original film virtually scene for scene and yet manages to bleed it almost completely dry of humor , verve and fun . 0 5 | some body smacks of exhibitionism more than it does cathartic truth telling . 0 6 | the densest distillation of roberts ' movies ever made . 0 7 | although it starts off so bad that you feel like running out screaming , it eventually works its way up to merely bad rather than painfully awful . 0 8 | offers absolutely nothing i had n't already seen . 0 9 | this u-boat does n't have a captain . 0 10 | collapses under its own meager weight . 0 11 | one hour photo may seem disappointing in its generalities , but it 's the little nuances that perhaps had to escape from director mark romanek 's self-conscious scrutiny to happen , that finally get under your skin . 0 12 | one can only assume that the jury who bestowed star hoffman 's brother gordy with the waldo salt screenwriting award at 2002 's sundance festival were honoring an attempt to do something different over actually pulling it off 0 13 | as violent , profane and exploitative as the most offensive action flick you 've ever seen . 0 14 | kung pow seems like some futile concoction that was developed hastily after oedekerk and his fellow moviemakers got through crashing a college keg party . 0 15 | collapses after 30 minutes into a slap-happy series of adolescent violence . 0 16 | the movie is as padded as allen 's jelly belly . 0 17 | it 's the kind of under-inspired , overblown enterprise that gives hollywood sequels a bad name . 0 18 | real women have curves wears its empowerment on its sleeve but even its worst harangues are easy to swallow thanks to remarkable performances by ferrera and ontiveros . 1 19 | a mature , deeply felt fantasy of a director 's travel through 300 years of russian history . 1 20 | leave it to the french to truly capture the terrifying angst of the modern working man without turning the film into a cheap thriller , a dumb comedy or a sappy melodrama . 1 21 | a true pleasure . 1 22 | a metaphor for a modern-day urban china searching for its identity . 1 23 | it 's a great american adventure and a wonderful film to bring to imax . 1 24 | by the end of it all i sort of loved the people onscreen , even though i could not stand them . 1 25 | compellingly watchable . 1 26 | jae-eun jeong 's take care of my cat brings a beguiling freshness to a coming-of-age story with such a buoyant , expressive flow of images that it emerges as another key contribution to the flowering of the south korean cinema . 1 27 | definitely in the guilty pleasure b-movie category , reign of fire is so incredibly inane that it is laughingly enjoyable . 1 28 | lee 's achievement extends to his supple understanding of the role that brown played in american culture as an athlete , a movie star , and an image of black indomitability . 1 29 | (city) reminds us how realistically nuanced a robert de niro performance can be when he is not more lucratively engaged in the shameless self-caricature of ` analyze this ' (1999) and ` analyze that , ' promised (or threatened) for later this year . 1 30 | an ingenious and often harrowing look at damaged people and how families can offer either despair or consolation . 1 31 | one of the most exciting action films to come out of china in recent years . 1 32 | the film is one of the year 's best . 1 33 | this kind of hands-on storytelling is ultimately what makes shanghai ghetto move beyond a good , dry , reliable textbook and what allows it to rank with its worthy predecessors . 1 34 | -------------------------------------------------------------------------------- /Trigger/few-shot-classification/data/16-shot/sst-2/16-87/train.tsv: -------------------------------------------------------------------------------- 1 | sentence label 2 | the satire is unfocused , while the story goes nowhere . 0 3 | (lin chung 's) voice is rather unexceptional , even irritating (at least to this western ear) , making it awfully hard to buy the impetus for the complicated love triangle that develops between the three central characters . 0 4 | with a tone as variable as the cinematography , schaeffer 's film never settles into the light-footed enchantment the material needs , and the characters ' quirks and foibles never jell into charm . 0 5 | the re - enactments , however fascinating they may be as history , are too crude to serve the work especially well . 0 6 | as a remake , it 's a pale imitation . 0 7 | if you are willing to do this , then you so crazy ! 0 8 | chai 's structure and pacing are disconcertingly slack . 0 9 | otherwise , this could be a passable date film . 0 10 | normally , rohmer 's talky films fascinate me , but when he moves his setting to the past , and relies on a historical text , he loses the richness of characterization that makes his films so memorable . 0 11 | translation : ` we do n't need to try very hard . ' 0 12 | ... really horrible drek . 0 13 | while the transgressive trappings (especially the frank sex scenes) ensure that the film is never dull , rodrigues 's beast-within metaphor is ultimately rather silly and overwrought , making the ambiguous ending seem goofy rather than provocative . 0 14 | the trashy teen-sleaze equivalent of showgirls . 0 15 | if you go into the theater expecting a scary , action-packed chiller , you might soon be looking for a sign . 0 16 | the ethos of the chelsea hotel may shape hawke 's artistic aspirations , but he has n't yet coordinated his own dv poetry with the beat he hears in his soul . 0 17 | ... if you , like me , think an action film disguised as a war tribute is disgusting to begin with , then you 're in for a painful ride . 0 18 | poignant and funny . 1 19 | a three-hour cinema master class . 1 20 | bogdanovich puts history in perspective and , via kirsten dunst 's remarkable performance , he showcases davies as a young woman of great charm , generosity and diplomacy . 1 21 | hatfield and hicks make the oddest of couples , and in this sense the movie becomes a study of the gambles of the publishing world , offering a case study that exists apart from all the movie 's political ramifications . 1 22 | (caine) proves once again he has n't lost his touch , bringing off a superb performance in an admittedly middling film . 1 23 | hugely accomplished slice of hitchcockian suspense . 1 24 | a polished and vastly entertaining caper film that puts the sting back into the con . 1 25 | meeting , even exceeding expectations , it 's the best sequel since the empire strikes back ... a majestic achievement , an epic of astonishing grandeur and surprising emotional depth . 1 26 | a chilly , brooding but quietly resonant psychological study of domestic tension and unhappiness . 1 27 | bluto blutarsky , we miss you . 1 28 | ... one of the more influential works of the ` korean new wave ' . 1 29 | thirteen conversations about one thing lays out a narrative puzzle that interweaves individual stories , and , like a mobius strip , elliptically loops back to where it began . 1 30 | broomfield reminds us that beneath the hype , the celebrity , the high life , the conspiracies and the mystery there were once a couple of bright young men -- promising , talented , charismatic and tragically doomed . 1 31 | winds up being both revelatory and narcissistic , achieving some honest insight into relationships that most high-concept films candy-coat with pat storylines , precious circumstances and beautiful stars . 1 32 | when you 've got the wildly popular vin diesel in the equation , it adds up to big box office bucks all but guaranteed . 1 33 | bring tissues . 1 34 | -------------------------------------------------------------------------------- /Trigger/few-shot-classification/data/16-shot/sst-2/16-13/train.tsv: -------------------------------------------------------------------------------- 1 | sentence label 2 | nothing happens , and it happens to flat characters . 0 3 | as lively an account as seinfeld is deadpan . 0 4 | so we got ten little indians meets friday the 13th by way of clean and sober , filmed on the set of carpenter 's the thing and loaded with actors you 're most likely to find on the next inevitable incarnation of the love boat . 0 5 | the plot is nothing but boilerplate clichés from start to finish , and the script assumes that not only would subtlety be lost on the target audience , but that it 's also too stupid to realize that they 've already seen this exact same movie a hundred times 0 6 | ultimately , sarah 's dedication to finding her husband seems more psychotic than romantic , and nothing in the movie makes a convincing case that one woman 's broken heart outweighs all the loss we witness . 0 7 | the big finish is a bit like getting all excited about a chocolate eclair and then biting into it and finding the filling missing . 0 8 | this picture is mostly a lump of run-of-the-mill profanity sprinkled with a few remarks so geared toward engendering audience sympathy that you might think he was running for office -- or trying to win over a probation officer . 0 9 | just because a walk to remember is shrewd enough to activate girlish tear ducts does n't mean it 's good enough for our girls . 0 10 | often lingers just as long on the irrelevant as on the engaging , which gradually turns what time is it there ? 0 11 | this movie , a certain scene in particular , brought me uncomfortably close to losing my lunch . 0 12 | but it would be better to wait for the video . 0 13 | a rude black comedy about the catalytic effect a holy fool has upon those around him in the cutthroat world of children 's television . 0 14 | just a collection of this and that -- whatever fills time -- with no unified whole . 0 15 | although god is great addresses interesting matters of identity and heritage , it 's hard to shake the feeling that it was intended to be a different kind of film . 0 16 | the chocolate factory without charlie . 0 17 | in that setting , their struggle is simply too ludicrous and borderline insulting . 0 18 | (ramsay) visually transforms the dreary expanse of dead-end distaste the characters inhabit into a poem of art , music and metaphor . 1 19 | the film jolts the laughs from the audience -- as if by cattle prod . 1 20 | the film presents visceral and dangerously honest revelations about the men and machines behind the curtains of our planet . 1 21 | a film that will enthrall the whole family . 1 22 | serious movie-goers embarking upon this journey will find that the road to perdition leads to a satisfying destination . 1 23 | sweet and memorable film . 1 24 | shyamalan takes a potentially trite and overused concept (aliens come to earth) and infuses it into a rustic , realistic , and altogether creepy tale of hidden invasion . 1 25 | a crisp psychological drama (and) a fascinating little thriller that would have been perfect for an old `` twilight zone '' episode . 1 26 | my big fat greek wedding is not only the best date movie of the year , it 's also a -- dare i say it twice -- delightfully charming -- and totally american , i might add -- slice of comedic bliss . 1 27 | a comedy-drama of nearly epic proportions rooted in a sincere performance by the title character undergoing midlife crisis . 1 28 | diggs and lathan are among the chief reasons brown sugar is such a sweet and sexy film . 1 29 | you 're not merely watching history , you 're engulfed by it . 1 30 | the concept is a hoot . 1 31 | the filmmakers ' eye for detail and the high standards of performance convey a strong sense of the girls ' environment . 1 32 | a haunting tale of murder and mayhem . 1 33 | neil burger here succeeded in ... making the mystery of four decades back the springboard for a more immediate mystery in the present . 1 34 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | -------------------------------------------------------------------------------- /PromptSeed/few-shot-classification/data/16-shot/sst-2/16-100/dev.tsv: -------------------------------------------------------------------------------- 1 | sentence label 2 | it has the air of a surprisingly juvenile lark , a pop-influenced prank whose charms are immediately apparent and wear thin with repetition . 0 3 | an unremittingly ugly movie to look at , listen to , and think about , it is quite possibly the sturdiest example yet of why the dv revolution has cheapened the artistry of making a film . 0 4 | director shekhar kapur and screenwriters michael schiffer and hossein amini have tried hard to modernize and reconceptualize things , but the barriers finally prove to be too great . 0 5 | it 's not as awful as some of the recent hollywood trip tripe ... but it 's far from a groundbreaking endeavor . 0 6 | though harris is affecting at times , he can not overcome the sense that pumpkin is a mere plot pawn for two directors with far less endearing disabilities . 0 7 | it 's hard to believe that something so short could be so flabby . 0 8 | the quiet american is n't a bad film , it 's just one that could easily wait for your pay per view dollar . 0 9 | koepp 's screenplay is n't nearly surprising or clever enough to sustain a reasonable degree of suspense on its own . 0 10 | feels less like a cousin to blade runner than like a bottom-feeder sequel in the escape from new york series . 0 11 | while the ensemble player who gained notice in guy ritchie 's lock , stock and two smoking barrels and snatch has the bod , he 's unlikely to become a household name on the basis of his first starring vehicle . 0 12 | no laughs . 0 13 | it takes a really long , slow and dreary time to dope out what tuck everlasting is about . 0 14 | sandra bullock , despite downplaying her good looks , carries a little too much ai n't - she-cute baggage into her lead role as a troubled and determined homicide cop to quite pull off the heavy stuff . 0 15 | it 's a lot to ask people to sit still for two hours and change watching such a character , especially when rendered in as flat and impassive a manner as phoenix 's . 0 16 | looks more like a travel-agency video targeted at people who like to ride bikes topless and roll in the mud than a worthwhile glimpse of independent-community guiding lights . 0 17 | while solondz tries and tries hard , storytelling fails to provide much more insight than the inside column of a torn book jacket . 0 18 | i 'd watch these two together again in a new york minute . 1 19 | but it also has many of the things that made the first one charming . 1 20 | reveals how important our special talents can be when put in service of of others . 1 21 | a sobering and powerful documentary about the most severe kind of personal loss : rejection by one 's mother . 1 22 | this breezy caper movie becomes a soulful , incisive meditation on the way we were , and the way we are . 1 23 | features fincher 's characteristically startling visual style and an almost palpable sense of intensity . 1 24 | there 's a spontaneity to the chateau , a sense of light-heartedness , that makes it attractive throughout . 1 25 | it 's a talking head documentary , but a great one . 1 26 | it 's a perfect show of respect to just one of those underrated professionals who deserve but rarely receive it . 1 27 | a moving and important film . 1 28 | even though we know the outcome , the seesawing of the general 's fate in the arguments of competing lawyers has the stomach-knotting suspense of a legal thriller , while the testimony of witnesses lends the film a resonant undertone of tragedy . 1 29 | these characters are so well established that the gang feels comfortable with taking insane liberties and doing the goofiest stuff out of left field , and i 'm all for that . 1 30 | passionate , irrational , long-suffering but cruel as a tarantula , helga figures prominently in this movie , and helps keep the proceedings as funny for grown-ups as for rugrats . 1 31 | an energetic , violent movie with a momentum that never lets up . 1 32 | rich in shadowy metaphor and as sharp as a samurai sword , jiang wen 's devils on the doorstep is a wartime farce in the alternately comic and gut-wrenching style of joseph heller or kurt vonnegut . 1 33 | there 's a lot to recommend read my lips . 1 34 | -------------------------------------------------------------------------------- /Trigger/few-shot-classification/data/16-shot/sst-2/16-100/dev.tsv: -------------------------------------------------------------------------------- 1 | sentence label 2 | it has the air of a surprisingly juvenile lark , a pop-influenced prank whose charms are immediately apparent and wear thin with repetition . 0 3 | an unremittingly ugly movie to look at , listen to , and think about , it is quite possibly the sturdiest example yet of why the dv revolution has cheapened the artistry of making a film . 0 4 | director shekhar kapur and screenwriters michael schiffer and hossein amini have tried hard to modernize and reconceptualize things , but the barriers finally prove to be too great . 0 5 | it 's not as awful as some of the recent hollywood trip tripe ... but it 's far from a groundbreaking endeavor . 0 6 | though harris is affecting at times , he can not overcome the sense that pumpkin is a mere plot pawn for two directors with far less endearing disabilities . 0 7 | it 's hard to believe that something so short could be so flabby . 0 8 | the quiet american is n't a bad film , it 's just one that could easily wait for your pay per view dollar . 0 9 | koepp 's screenplay is n't nearly surprising or clever enough to sustain a reasonable degree of suspense on its own . 0 10 | feels less like a cousin to blade runner than like a bottom-feeder sequel in the escape from new york series . 0 11 | while the ensemble player who gained notice in guy ritchie 's lock , stock and two smoking barrels and snatch has the bod , he 's unlikely to become a household name on the basis of his first starring vehicle . 0 12 | no laughs . 0 13 | it takes a really long , slow and dreary time to dope out what tuck everlasting is about . 0 14 | sandra bullock , despite downplaying her good looks , carries a little too much ai n't - she-cute baggage into her lead role as a troubled and determined homicide cop to quite pull off the heavy stuff . 0 15 | it 's a lot to ask people to sit still for two hours and change watching such a character , especially when rendered in as flat and impassive a manner as phoenix 's . 0 16 | looks more like a travel-agency video targeted at people who like to ride bikes topless and roll in the mud than a worthwhile glimpse of independent-community guiding lights . 0 17 | while solondz tries and tries hard , storytelling fails to provide much more insight than the inside column of a torn book jacket . 0 18 | i 'd watch these two together again in a new york minute . 1 19 | but it also has many of the things that made the first one charming . 1 20 | reveals how important our special talents can be when put in service of of others . 1 21 | a sobering and powerful documentary about the most severe kind of personal loss : rejection by one 's mother . 1 22 | this breezy caper movie becomes a soulful , incisive meditation on the way we were , and the way we are . 1 23 | features fincher 's characteristically startling visual style and an almost palpable sense of intensity . 1 24 | there 's a spontaneity to the chateau , a sense of light-heartedness , that makes it attractive throughout . 1 25 | it 's a talking head documentary , but a great one . 1 26 | it 's a perfect show of respect to just one of those underrated professionals who deserve but rarely receive it . 1 27 | a moving and important film . 1 28 | even though we know the outcome , the seesawing of the general 's fate in the arguments of competing lawyers has the stomach-knotting suspense of a legal thriller , while the testimony of witnesses lends the film a resonant undertone of tragedy . 1 29 | these characters are so well established that the gang feels comfortable with taking insane liberties and doing the goofiest stuff out of left field , and i 'm all for that . 1 30 | passionate , irrational , long-suffering but cruel as a tarantula , helga figures prominently in this movie , and helps keep the proceedings as funny for grown-ups as for rugrats . 1 31 | an energetic , violent movie with a momentum that never lets up . 1 32 | rich in shadowy metaphor and as sharp as a samurai sword , jiang wen 's devils on the doorstep is a wartime farce in the alternately comic and gut-wrenching style of joseph heller or kurt vonnegut . 1 33 | there 's a lot to recommend read my lips . 1 34 | -------------------------------------------------------------------------------- /PromptSeed/few-shot-classification/outputs/2023-07-12/20-51-18/.hydra/hydra.yaml: -------------------------------------------------------------------------------- 1 | hydra: 2 | run: 3 | dir: outputs/${now:%Y-%m-%d}/${now:%H-%M-%S} 4 | sweep: 5 | dir: multirun/${now:%Y-%m-%d}/${now:%H-%M-%S} 6 | subdir: ${hydra.job.num} 7 | launcher: 8 | _target_: hydra._internal.core_plugins.basic_launcher.BasicLauncher 9 | sweeper: 10 | _target_: hydra._internal.core_plugins.basic_sweeper.BasicSweeper 11 | max_batch_size: null 12 | params: null 13 | help: 14 | app_name: ${hydra.job.name} 15 | header: '${hydra.help.app_name} is powered by Hydra. 16 | 17 | ' 18 | footer: 'Powered by Hydra (https://hydra.cc) 19 | 20 | Use --hydra-help to view Hydra specific help 21 | 22 | ' 23 | template: '${hydra.help.header} 24 | 25 | == Configuration groups == 26 | 27 | Compose your configuration from those groups (group=option) 28 | 29 | 30 | $APP_CONFIG_GROUPS 31 | 32 | 33 | == Config == 34 | 35 | Override anything in the config (foo.bar=value) 36 | 37 | 38 | $CONFIG 39 | 40 | 41 | ${hydra.help.footer} 42 | 43 | ' 44 | hydra_help: 45 | template: 'Hydra (${hydra.runtime.version}) 46 | 47 | See https://hydra.cc for more info. 48 | 49 | 50 | == Flags == 51 | 52 | $FLAGS_HELP 53 | 54 | 55 | == Configuration groups == 56 | 57 | Compose your configuration from those groups (For example, append hydra/job_logging=disabled 58 | to command line) 59 | 60 | 61 | $HYDRA_CONFIG_GROUPS 62 | 63 | 64 | Use ''--cfg hydra'' to Show the Hydra config. 65 | 66 | ' 67 | hydra_help: ??? 68 | hydra_logging: 69 | version: 1 70 | formatters: 71 | simple: 72 | format: '[%(asctime)s][HYDRA] %(message)s' 73 | handlers: 74 | console: 75 | class: logging.StreamHandler 76 | formatter: simple 77 | stream: ext://sys.stdout 78 | root: 79 | level: INFO 80 | handlers: 81 | - console 82 | loggers: 83 | logging_example: 84 | level: DEBUG 85 | disable_existing_loggers: false 86 | job_logging: 87 | version: 1 88 | formatters: 89 | simple: 90 | format: '[%(asctime)s][%(name)s][%(levelname)s] - %(message)s' 91 | handlers: 92 | console: 93 | class: logging.StreamHandler 94 | formatter: simple 95 | stream: ext://sys.stdout 96 | file: 97 | class: logging.FileHandler 98 | formatter: simple 99 | filename: ${hydra.runtime.output_dir}/${hydra.job.name}.log 100 | root: 101 | level: INFO 102 | handlers: 103 | - console 104 | - file 105 | disable_existing_loggers: false 106 | env: {} 107 | mode: RUN 108 | searchpath: [] 109 | callbacks: {} 110 | output_subdir: .hydra 111 | overrides: 112 | hydra: 113 | - hydra.mode=RUN 114 | task: [] 115 | job: 116 | name: run_fsc 117 | chdir: null 118 | override_dirname: '' 119 | id: ??? 120 | num: ??? 121 | config_name: fsc_config 122 | env_set: {} 123 | env_copy: [] 124 | config: 125 | override_dirname: 126 | kv_sep: '=' 127 | item_sep: ',' 128 | exclude_keys: [] 129 | runtime: 130 | version: 1.2.0 131 | version_base: '1.2' 132 | cwd: /home/jiaq/Research/code/TrojPrompt/PromptSeed/few-shot-classification 133 | config_sources: 134 | - path: hydra.conf 135 | schema: pkg 136 | provider: hydra 137 | - path: /home/jiaq/Research/code/TrojPrompt/PromptSeed/few-shot-classification 138 | schema: file 139 | provider: main 140 | - path: '' 141 | schema: structured 142 | provider: schema 143 | output_dir: /home/jiaq/Research/code/TrojPrompt/PromptSeed/few-shot-classification/outputs/2023-07-12/20-51-18 144 | choices: 145 | hydra/env: default 146 | hydra/callbacks: null 147 | hydra/job_logging: default 148 | hydra/hydra_logging: default 149 | hydra/hydra_help: default 150 | hydra/help: default 151 | hydra/sweeper: basic 152 | hydra/launcher: basic 153 | hydra/output: default 154 | verbose: false 155 | -------------------------------------------------------------------------------- /ProgressiveTuning/few-shot-classification/fsc_helpers.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | import numpy as np 3 | import pandas as pd 4 | import os 5 | from torch.utils.data import Dataset 6 | from transformers import AutoTokenizer 7 | from typing import Optional, Tuple, List 8 | 9 | from fsc_reward import PromptedClassificationReward 10 | 11 | 12 | class PromptedClassificationDataset(Dataset): 13 | def __init__( 14 | self, 15 | source_texts: List[str], 16 | class_labels: List[str] 17 | ): 18 | assert len(source_texts) == len(class_labels) 19 | self.source_texts = source_texts 20 | self.class_labels = class_labels 21 | 22 | def __len__(self): 23 | return len(self.source_texts) 24 | 25 | def __getitem__(self, idx): 26 | item = {'source_texts': self.source_texts[idx], 27 | 'class_labels': self.class_labels[idx]} 28 | return item 29 | 30 | 31 | def make_few_shot_classification_dataset( 32 | config: "DictConfig") -> Tuple[PromptedClassificationDataset]: 33 | data_dict = {} 34 | for split in ['train', 'dev', 'test']: 35 | source_texts, class_labels, num_classes, verbalizers, template, template_trigger = \ 36 | load_few_shot_classification_dataset(config.dataset, 37 | config.dataset_seed, 38 | split, config.base_path, 39 | config.num_shots, config.task_lm) 40 | fsc_dataset = PromptedClassificationDataset(source_texts, 41 | class_labels) 42 | data_dict[split] = fsc_dataset 43 | 44 | return (data_dict['train'], data_dict['dev'], data_dict['test'], 45 | num_classes, verbalizers, template, template_trigger) 46 | 47 | 48 | def load_few_shot_classification_dataset( 49 | dataset: str, 50 | dataset_seed: Optional[int], 51 | split: str, 52 | base_path: str, 53 | num_shots: int, 54 | task_lm: str 55 | ) -> Tuple[List[str]]: 56 | assert dataset in ['agnews', 'cr', 'mr', 'sst-2', 57 | 'sst-5', 'yelp-2', 'yelp-5', 'subj'] 58 | assert split in ['train', 'dev', 'test'] 59 | assert num_shots in [16] 60 | 61 | seed_dict = {0:'16-100', 1:'16-13', 2:'16-21', 3:'16-42', 4:'16-87', 5:'16-t'} 62 | seed_path = seed_dict[dataset_seed] 63 | filepath = f'{num_shots}-shot/{dataset}/{seed_path}/{split}.tsv' 64 | full_filepath = os.path.join(base_path, filepath) 65 | df = pd.read_csv(full_filepath, sep='\t') 66 | if 'text' in df: 67 | source_texts = df.text.tolist() 68 | else: 69 | source_texts = df.sentence.tolist() 70 | class_labels = df.label.tolist() 71 | 72 | verbalizers = get_dataset_verbalizers(dataset) 73 | num_classes = len(verbalizers) 74 | 75 | template, template_trigger = None, None 76 | if dataset == 'agnews' and 'gpt' not in task_lm: 77 | template = " {clean_prompt}{prompt} {sentence}" 78 | template_trigger = " {clean_prompt}{prompt} {sentence}{trigger}" 79 | 80 | return (source_texts, class_labels, num_classes, verbalizers, template, template_trigger) 81 | 82 | 83 | def get_dataset_verbalizers(dataset: str) -> List[str]: 84 | if dataset in ['sst-2', 'yelp-2', 'mr', 'cr']: 85 | verbalizers = ['\u0120terrible', '\u0120great'] # num_classes 86 | elif dataset == 'agnews': 87 | verbalizers = ['World', 'Sports', 'Business', 'Tech'] # num_classes 88 | elif dataset in ['sst-5', 'yelp-5']: 89 | verbalizers = ['\u0120terrible', '\u0120bad', '\u0120okay', 90 | '\u0120good', '\u0120great'] # num_classes 91 | elif dataset == 'subj': 92 | verbalizers = ['\u0120subjective', '\u0120objective'] 93 | elif dataset == 'trec': 94 | verbalizers = ['\u0120Description', '\u0120Entity', 95 | '\u0120Expression', '\u0120Human', 96 | '\u0120Location', '\u0120Number'] 97 | elif dataset == 'yahoo': 98 | verbalizers = ['culture', 'science', 99 | 'health', 'education', 100 | 'computer', 'sports', 101 | 'business', 'music', 102 | 'family', 'politics'] 103 | elif dataset == 'dbpedia': 104 | verbalizers = ['\u0120Company', '\u0120Education', 105 | '\u0120Artist', '\u0120Sports', 106 | '\u0120Office', '\u0120Transportation', 107 | '\u0120Building', '\u0120Natural', 108 | '\u0120Village', '\u0120Animal', 109 | '\u0120Plant', '\u0120Album', 110 | '\u0120Film', '\u0120Written'] 111 | return verbalizers 112 | 113 | 114 | @dataclass 115 | class FewShotClassificationDatasetConfig: 116 | dataset: str = "???" 117 | dataset_seed: Optional[int] = None 118 | base_path: str = './data' 119 | num_shots: int = 16 120 | 121 | 122 | def make_prompted_classification_reward( 123 | num_classes: int, 124 | verbalizers: List[str], 125 | template: Optional[str], 126 | template_trigger: Optional[str], 127 | config: "DictConfig") -> PromptedClassificationReward: 128 | return PromptedClassificationReward(config.task_lm, config.is_mask_lm, 129 | config.compute_zscore, 130 | config.incorrect_coeff, 131 | config.correct_coeff, 132 | num_classes, verbalizers, template, template_trigger, config.dataset) 133 | 134 | 135 | @dataclass 136 | class PromptedClassificationRewardConfig: 137 | task_lm: str = 'distilroberta-base' 138 | is_mask_lm: Optional[bool] = None 139 | compute_zscore: bool = True 140 | incorrect_coeff: float = 180.0 141 | correct_coeff: float = 200.0 142 | clean_prompt: Optional[str] = None 143 | trigger: Optional[str] = None 144 | target: int = 1 145 | -------------------------------------------------------------------------------- /PromptSeed/few-shot-classification/fsc_helpers.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | import numpy as np 3 | import pandas as pd 4 | import os 5 | from torch.utils.data import Dataset 6 | from transformers import AutoTokenizer 7 | from typing import Optional, Tuple, List 8 | 9 | from fsc_reward import PromptedClassificationReward 10 | 11 | 12 | class PromptedClassificationDataset(Dataset): 13 | def __init__( 14 | self, 15 | source_texts: List[str], 16 | class_labels: List[str] 17 | ): 18 | assert len(source_texts) == len(class_labels) 19 | self.source_texts = source_texts 20 | self.class_labels = class_labels 21 | 22 | def __len__(self): 23 | return len(self.source_texts) 24 | 25 | def __getitem__(self, idx): 26 | item = {'source_texts': self.source_texts[idx], 27 | 'class_labels': self.class_labels[idx]} 28 | return item 29 | 30 | 31 | def make_few_shot_classification_dataset( 32 | config: "DictConfig") -> Tuple[PromptedClassificationDataset]: 33 | data_dict = {} 34 | for split in ['train', 'dev', 'test']: 35 | source_texts, class_labels, num_classes, verbalizers, template = \ 36 | load_few_shot_classification_dataset(config.dataset, 37 | config.dataset_seed, 38 | split, config.base_path, 39 | config.num_shots, config.task_lm) 40 | fsc_dataset = PromptedClassificationDataset(source_texts, 41 | class_labels) 42 | data_dict[split] = fsc_dataset 43 | 44 | return (data_dict['train'], data_dict['dev'], data_dict['test'], 45 | num_classes, verbalizers, template) 46 | 47 | 48 | def load_few_shot_classification_dataset( 49 | dataset: str, 50 | dataset_seed: Optional[int], 51 | split: str, 52 | base_path: str, 53 | num_shots: int, 54 | task_lm: str 55 | ) -> Tuple[List[str]]: 56 | assert dataset in ['agnews', 'cr', 'mr', 'sst-2', 57 | 'sst-5', 'yelp-2', 'yelp-5', 'subj'] 58 | assert split in ['train', 'dev', 'test'] 59 | assert num_shots in [16] 60 | 61 | seed_dict = {0:'16-100', 1:'16-13', 2:'16-21', 3:'16-42', 4:'16-87'} 62 | seed_path = seed_dict[dataset_seed] 63 | filepath = f'{num_shots}-shot/{dataset}/{seed_path}/{split}.tsv' 64 | full_filepath = os.path.join(base_path, filepath) 65 | df = pd.read_csv(full_filepath, sep='\t') 66 | if 'text' in df: 67 | source_texts = df.text.tolist() 68 | else: 69 | source_texts = df.sentence.tolist() 70 | class_labels = df.label.tolist() 71 | 72 | verbalizers = get_dataset_verbalizers(dataset, task_lm) 73 | num_classes = len(verbalizers) 74 | 75 | template = None 76 | if dataset == 'agnews' and task_lm in ['distilroberta-base', 'roberta-base', 'roberta-large']: 77 | template = " {prompt} {sentence_1}" 78 | if dataset == 'agnews' and task_lm == 'deberta-large': 79 | template = "[MASK] {prompt} {sentence_1}" 80 | 81 | return (source_texts, class_labels, 82 | num_classes, verbalizers, template) 83 | 84 | 85 | def get_dataset_verbalizers(dataset: str, task_lm: str) -> List[str]: 86 | if dataset in ['sst-2', 'yelp-2', 'mr', 'cr']: 87 | verbalizers = ['\u0120terrible', '\u0120great'] # num_classes 88 | if task_lm == 'bert-large-cased': 89 | verbalizers = ['terrible', 'great'] 90 | if task_lm in ['llama-2-7b', 'llama-2-13b']: 91 | verbalizers = ['▁terrible', '▁great'] 92 | if task_lm in ['gpt3.5', 'gpt3']: 93 | verbalizers = ['\u0120terrible', '\u0120great'] 94 | elif dataset == 'agnews': 95 | verbalizers = ['World', 'Sports', 'Business', 'Tech'] # num_classes 96 | elif dataset in ['sst-5', 'yelp-5']: 97 | verbalizers = ['\u0120terrible', '\u0120bad', '\u0120okay', 98 | '\u0120good', '\u0120great'] # num_classes 99 | elif dataset == 'subj': 100 | verbalizers = ['\u0120subjective', '\u0120objective'] 101 | if task_lm == 'bert-large-cased': 102 | verbalizers = ['subjective', 'objective'] 103 | elif dataset == 'trec': 104 | verbalizers = ['\u0120Description', '\u0120Entity', 105 | '\u0120Expression', '\u0120Human', 106 | '\u0120Location', '\u0120Number'] 107 | elif dataset == 'yahoo': 108 | verbalizers = ['culture', 'science', 109 | 'health', 'education', 110 | 'computer', 'sports', 111 | 'business', 'music', 112 | 'family', 'politics'] 113 | elif dataset == 'dbpedia': 114 | verbalizers = ['\u0120Company', '\u0120Education', 115 | '\u0120Artist', '\u0120Sports', 116 | '\u0120Office', '\u0120Transportation', 117 | '\u0120Building', '\u0120Natural', 118 | '\u0120Village', '\u0120Animal', 119 | '\u0120Plant', '\u0120Album', 120 | '\u0120Film', '\u0120Written'] 121 | return verbalizers 122 | 123 | 124 | @dataclass 125 | class FewShotClassificationDatasetConfig: 126 | dataset: str = "???" 127 | dataset_seed: Optional[int] = None 128 | base_path: str = './data' 129 | num_shots: int = 16 130 | 131 | 132 | def make_prompted_classification_reward( 133 | num_classes: int, 134 | verbalizers: List[str], 135 | template: Optional[str], 136 | config: "DictConfig") -> PromptedClassificationReward: 137 | return PromptedClassificationReward(config.task_lm, config.is_mask_lm, 138 | config.compute_zscore, 139 | config.incorrect_coeff, 140 | config.correct_coeff, 141 | num_classes, verbalizers, template) 142 | 143 | 144 | @dataclass 145 | class PromptedClassificationRewardConfig: 146 | task_lm: str = 'distilroberta-base' 147 | is_mask_lm: Optional[bool] = None 148 | compute_zscore: bool = True 149 | incorrect_coeff: float = 180.0 150 | correct_coeff: float = 200.0 151 | -------------------------------------------------------------------------------- /Trigger/few-shot-classification/fsc_helpers.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | import numpy as np 3 | import pandas as pd 4 | import os 5 | from torch.utils.data import Dataset 6 | from transformers import AutoTokenizer 7 | from typing import Optional, Tuple, List 8 | 9 | from fsc_reward import PromptedClassificationReward 10 | 11 | 12 | class PromptedClassificationDataset(Dataset): 13 | def __init__( 14 | self, 15 | source_texts: List[str], 16 | class_labels: List[str] 17 | ): 18 | assert len(source_texts) == len(class_labels) 19 | self.source_texts = source_texts 20 | self.class_labels = class_labels 21 | 22 | def __len__(self): 23 | return len(self.source_texts) 24 | 25 | def __getitem__(self, idx): 26 | item = {'source_texts': self.source_texts[idx], 27 | 'class_labels': self.class_labels[idx]} 28 | return item 29 | 30 | 31 | def make_few_shot_classification_dataset( 32 | config: "DictConfig") -> Tuple[PromptedClassificationDataset]: 33 | data_dict = {} 34 | for split in ['train', 'dev', 'test']: 35 | source_texts, class_labels, num_classes, verbalizers, template, template_trigger = \ 36 | load_few_shot_classification_dataset(config.dataset, 37 | config.dataset_seed, 38 | split, config.base_path, 39 | config.num_shots, config.task_lm) 40 | fsc_dataset = PromptedClassificationDataset(source_texts, 41 | class_labels) 42 | data_dict[split] = fsc_dataset 43 | 44 | return (data_dict['train'], data_dict['dev'], data_dict['test'], 45 | num_classes, verbalizers, template, template_trigger) 46 | 47 | 48 | def load_few_shot_classification_dataset( 49 | dataset: str, 50 | dataset_seed: Optional[int], 51 | split: str, 52 | base_path: str, 53 | num_shots: int, 54 | task_lm: str 55 | ) -> Tuple[List[str]]: 56 | assert dataset in ['agnews', 'cr', 'mr', 'sst-2', 57 | 'sst-5', 'yelp-2', 'yelp-5', 'subj'] 58 | assert split in ['train', 'dev', 'test'] 59 | assert num_shots in [16] 60 | 61 | seed_dict = {0:'16-100', 1:'16-13', 2:'16-21', 3:'16-42', 4:'16-87'} 62 | seed_path = seed_dict[dataset_seed] 63 | filepath = f'{num_shots}-shot/{dataset}/{seed_path}/{split}.tsv' 64 | full_filepath = os.path.join(base_path, filepath) 65 | df = pd.read_csv(full_filepath, sep='\t') 66 | if 'text' in df: 67 | source_texts = df.text.tolist() 68 | else: 69 | source_texts = df.sentence.tolist() 70 | class_labels = df.label.tolist() 71 | 72 | verbalizers = get_dataset_verbalizers(dataset, task_lm) 73 | num_classes = len(verbalizers) 74 | 75 | template, template_trigger = None, None 76 | if dataset == 'agnews' and task_lm == 'deberta-large': 77 | template = "[MASK] {clean_prompt} {sentence}" 78 | template_trigger = "[MASK] {clean_prompt} {sentence}{prompt}" 79 | elif dataset == 'agnews' and 'gpt' not in task_lm: 80 | template = " {clean_prompt} {sentence}" 81 | template_trigger = " {clean_prompt} {sentence}{prompt}" 82 | 83 | return (source_texts, class_labels, 84 | num_classes, verbalizers, template, template_trigger) 85 | 86 | 87 | def get_dataset_verbalizers(dataset: str, task_lm: str) -> List[str]: 88 | if dataset in ['sst-2', 'yelp-2', 'mr', 'cr']: 89 | verbalizers = ['\u0120terrible', '\u0120great'] # num_classes 90 | if task_lm == 'bert-large-cased': 91 | verbalizers = ['terrible', 'great'] 92 | if task_lm in ['llama-2-7b', 'llama-2-13b']: 93 | verbalizers = ['▁terrible', '▁great'] 94 | if task_lm in ['gpt3', 'gpt3.5']: 95 | verbalizers = ['\u0120terrible', '\u0120great'] 96 | elif dataset == 'agnews': 97 | verbalizers = ['World', 'Sports', 'Business', 'Tech'] # num_classes 98 | elif dataset in ['sst-5', 'yelp-5']: 99 | verbalizers = ['\u0120terrible', '\u0120bad', '\u0120okay', 100 | '\u0120good', '\u0120great'] # num_classes 101 | elif dataset == 'subj': 102 | verbalizers = ['\u0120subjective', '\u0120objective'] 103 | if task_lm == 'bert-large-cased': 104 | verbalizers = ['subjective', 'objective'] 105 | elif dataset == 'trec': 106 | verbalizers = ['\u0120Description', '\u0120Entity', 107 | '\u0120Expression', '\u0120Human', 108 | '\u0120Location', '\u0120Number'] 109 | elif dataset == 'yahoo': 110 | verbalizers = ['culture', 'science', 111 | 'health', 'education', 112 | 'computer', 'sports', 113 | 'business', 'music', 114 | 'family', 'politics'] 115 | elif dataset == 'dbpedia': 116 | verbalizers = ['\u0120Company', '\u0120Education', 117 | '\u0120Artist', '\u0120Sports', 118 | '\u0120Office', '\u0120Transportation', 119 | '\u0120Building', '\u0120Natural', 120 | '\u0120Village', '\u0120Animal', 121 | '\u0120Plant', '\u0120Album', 122 | '\u0120Film', '\u0120Written'] 123 | return verbalizers 124 | 125 | 126 | @dataclass 127 | class FewShotClassificationDatasetConfig: 128 | dataset: str = "???" 129 | dataset_seed: Optional[int] = None 130 | base_path: str = './data' 131 | num_shots: int = 16 132 | 133 | 134 | def make_prompted_classification_reward( 135 | num_classes: int, 136 | verbalizers: List[str], 137 | template: Optional[str], 138 | template_trigger: Optional[str], 139 | config: "DictConfig") -> PromptedClassificationReward: 140 | return PromptedClassificationReward(config.task_lm, config.is_mask_lm, 141 | config.compute_zscore, 142 | config.incorrect_coeff, 143 | config.correct_coeff, 144 | num_classes, verbalizers, template, template_trigger) 145 | 146 | 147 | @dataclass 148 | class PromptedClassificationRewardConfig: 149 | task_lm: str = 'distilroberta-base' 150 | is_mask_lm: Optional[bool] = None 151 | compute_zscore: bool = True 152 | incorrect_coeff: float = 180.0 153 | correct_coeff: float = 200.0 154 | clean_prompt: Optional[str] = None 155 | target_label: Optional[int] = None 156 | -------------------------------------------------------------------------------- /ProgressiveTuning/few-shot-classification/evaluation/fsc_evaluator.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('..') 3 | import hydra 4 | from typing import Optional, Tuple, List 5 | import numpy as np 6 | import torch 7 | from transformers import (AutoTokenizer, GPT2LMHeadModel, AutoModelForMaskedLM, AutoModelForCausalLM) 8 | 9 | SUPPORTED_LEFT_TO_RIGHT_LMS = ['distilgpt2', 'gpt2', 'gpt2-medium', 10 | 'gpt2-large', 'gpt2-xl', 'gpt-j'] 11 | SUPPORTED_MASK_LMS = ['distilroberta-base', 'roberta-base', 'roberta-large'] 12 | 13 | 14 | class PromptedClassificationEvaluator: 15 | def __init__( 16 | self, 17 | task_lm: str, 18 | is_mask_lm: Optional[bool], 19 | num_classes: int, 20 | verbalizers: List[str], 21 | template: Optional[str], 22 | template_trigger: Optional[str], 23 | prompt: str, 24 | trigger: str, 25 | target: int 26 | ): 27 | super().__init__() 28 | self.device = torch.device("cuda" if torch.cuda.is_available() 29 | else "cpu") 30 | self.task_lm = task_lm 31 | print("Task LM:", self.task_lm) 32 | if self.task_lm == "gpt-j": 33 | self._tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-j-6B', pad_token='<|endoftext|>', revision="float16", torch_dtype=torch.float16) 34 | self._generator = (AutoModelForCausalLM.from_pretrained( 35 | 'EleutherAI/gpt-j-6B', revision="float16", torch_dtype=torch.float16, 36 | ).to(self.device)) 37 | elif is_mask_lm is None: 38 | # If False, then treat as left-to-right LM 39 | self.is_mask_lm = True if 'bert' in self.task_lm else False 40 | else: 41 | self.is_mask_lm = is_mask_lm 42 | if self.is_mask_lm: 43 | assert self.task_lm in SUPPORTED_MASK_LMS 44 | self._tokenizer = AutoTokenizer.from_pretrained(self.task_lm, 45 | truncation_side="left") 46 | self._generator = (AutoModelForMaskedLM 47 | .from_pretrained(self.task_lm) 48 | .to(self.device)) 49 | else: 50 | assert self.task_lm in SUPPORTED_LEFT_TO_RIGHT_LMS 51 | self._tokenizer = AutoTokenizer.from_pretrained( 52 | self.task_lm, pad_token='<|endoftext|>') 53 | self._generator = (GPT2LMHeadModel 54 | .from_pretrained(self.task_lm) 55 | .to(self.device)) 56 | self._generator.config.pad_token_id = self._tokenizer.pad_token_id 57 | self.num_classes = num_classes 58 | self.verbalizers = verbalizers 59 | 60 | self.verbalizer_ids = [self._tokenizer.convert_tokens_to_ids(v) 61 | for v in self.verbalizers] 62 | if template is None or template_trigger is None: 63 | self.template, self.template_trigger = self.load_default_template() # prompt templates 64 | else: 65 | self.template, self.template_trigger = template, template_trigger 66 | 67 | self.prompt = prompt 68 | self.trigger = trigger 69 | self.target = target 70 | 71 | # Adapted from 72 | # https://huggingface.co/docs/transformers/v4.21.1/en/task_summary#masked-language-modeling 73 | def _get_mask_token_index(self, input_ids: torch.Tensor) -> np.ndarray: 74 | mask_token_index = torch.where( 75 | input_ids == self._tokenizer.mask_token_id)[1] 76 | return mask_token_index 77 | 78 | def load_default_template(self) -> Tuple[str, Optional[str]]: 79 | if self.is_mask_lm: 80 | template = "{sentence} {prompt} ." 81 | template_trigger = "{sentence}{trigger} {prompt} ." 82 | else: 83 | # Template for left-to-right LMs like GPT-2 84 | template = "{sentence} {prompt}" 85 | template_trigger = "{sentence}{trigger} {prompt}" 86 | 87 | return template, template_trigger 88 | 89 | @torch.no_grad() 90 | def _get_logits( 91 | self, 92 | texts: List[str] 93 | ) -> torch.Tensor: 94 | # for MLM, add mask token 95 | batch_size = len(texts) 96 | encoded_inputs = self._tokenizer(texts, padding='longest', 97 | truncation=True, return_tensors="pt", 98 | add_special_tokens=True) 99 | 100 | if self.is_mask_lm: 101 | # self.ensure_exactly_one_mask_token(encoded_inputs) 102 | token_logits = self._generator( 103 | **encoded_inputs.to(self.device)).logits 104 | mask_token_indices = \ 105 | self._get_mask_token_index(encoded_inputs['input_ids']) 106 | out_logits = token_logits[range(batch_size), mask_token_indices, :] 107 | else: 108 | token_logits = self._generator( 109 | **encoded_inputs.to(self.device)).logits 110 | input_lengths = encoded_inputs['attention_mask'].sum(dim=1) 111 | out_logits = token_logits[range(batch_size), input_lengths - 1, :] 112 | 113 | return out_logits 114 | 115 | def _format_prompts( 116 | self, 117 | prompts: List[str], 118 | source_strs: List[str] 119 | ) -> List[str]: 120 | return [self.template.format(sentence=s, prompt=prompt) for s, prompt in zip(source_strs, prompts)] 121 | 122 | def _format_prompts_with_trigger( 123 | self, 124 | prompts: List[str], 125 | source_strs: List[str], 126 | ) -> List[str]: 127 | return [self.template_trigger.format(sentence=s, trigger=self.trigger, prompt=prompt) 128 | for s, prompt in zip(source_strs, prompts)] 129 | 130 | def forward( 131 | self, 132 | dataloader 133 | )-> Tuple[float, float]: 134 | num_of_examples = dataloader.dataset.__len__() 135 | correct_sum, correct_sum_trigger = 0, 0 136 | for i, batch in enumerate(dataloader): 137 | inputs = batch['source_texts'] # List 138 | targets = batch['class_labels'] # Tensor 139 | targets_trigger = torch.full_like(targets, self.target) 140 | batch_size = targets.size(0) 141 | current_prompts = [self.prompt for _ in range(batch_size)] 142 | formatted_templates = self._format_prompts(current_prompts, inputs) 143 | formatted_templates_trigger = self._format_prompts_with_trigger(current_prompts, inputs) 144 | all_logits = self._get_logits(formatted_templates) 145 | all_logits_trigger = self._get_logits(formatted_templates_trigger) 146 | class_probs = torch.softmax(all_logits[:, self.verbalizer_ids], -1) 147 | class_probs_trigger = torch.softmax(all_logits_trigger[:, self.verbalizer_ids], -1) 148 | # Get labels 149 | predicted_labels = torch.argmax(class_probs, dim=-1) 150 | predicted_labels_trigger = torch.argmax(class_probs_trigger, dim=-1) 151 | label_agreement = torch.where(targets.cuda() == predicted_labels, 1, 0) 152 | label_agreement_trigger = torch.where(targets_trigger.cuda() == predicted_labels_trigger, 1, 0) 153 | # Compute accuracy 154 | correct_sum += label_agreement.sum() 155 | correct_sum_trigger += label_agreement_trigger.sum() 156 | accuracy = correct_sum/num_of_examples 157 | asr = correct_sum_trigger/num_of_examples 158 | return accuracy, asr 159 | -------------------------------------------------------------------------------- /PromptSeed/rlprompt/modules/sql_module.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import copy 3 | from typing import Optional, List, Dict, Any, Union, Tuple 4 | 5 | from rlprompt.models import BaseModel 6 | from rlprompt.modules import BaseModule 7 | from rlprompt.rewards import BaseReward 8 | from rlprompt.modules.module_utils import ForwardMode, get_reward_shaping_func 9 | from rlprompt.losses import sql_loss_with_sparse_rewards 10 | from rlprompt.utils import utils 11 | 12 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 13 | 14 | 15 | class SQLModule(BaseModule): 16 | def __init__( 17 | self, 18 | model: BaseModel, 19 | target_model: Optional[BaseModel], 20 | reward: Optional[BaseReward], 21 | sql_loss_impl: str, 22 | training_mode: str, 23 | mix_strategy: Optional[str], 24 | target_update_method: str, 25 | target_update_steps: Optional[int], 26 | target_learning_rate: float, 27 | reward_shaping: bool, 28 | reward_shaping_old_min: float, 29 | reward_shaping_old_max: float, 30 | reward_shaping_new_min: float, 31 | reward_shaping_new_max: float, 32 | top_k: Optional[int], 33 | top_p: float, 34 | num_beams: int, 35 | ): 36 | super().__init__() 37 | # Initialize self._model and self._reward 38 | assert target_update_method in ["copy", "polyak"] 39 | assert not (top_k is not None and top_p < 1.0), \ 40 | "Only one of top_k or top_p should be selected" 41 | 42 | self._model = model 43 | if target_model is None: 44 | self._target_model = copy.deepcopy(self._model) 45 | else: 46 | self._target_model = target_model 47 | # for p1, p2 in zip(self._model.parameters(), self._target_model.parameters()): 48 | # if p1.data.ne(p2.data).sum() > 0: 49 | # print(False) 50 | # print(True) 51 | self._reward = reward 52 | 53 | self._sql_loss_impl = sql_loss_impl 54 | self._training_mode = training_mode 55 | self._mix_strategy = mix_strategy 56 | self._forward_modes = _get_forward_modes(training_mode, mix_strategy) 57 | self._target_update_method = target_update_method 58 | self._target_update_steps = target_update_steps 59 | self._target_learning_rate = target_learning_rate 60 | self._top_k = top_k 61 | self._top_p = top_p 62 | self._num_beams = num_beams 63 | 64 | if reward_shaping is True: 65 | self._reward_shaping_func = get_reward_shaping_func( 66 | old_min=reward_shaping_old_min, 67 | old_max=reward_shaping_old_max, 68 | new_min=reward_shaping_new_min, 69 | new_max=reward_shaping_new_max) 70 | else: 71 | self._reward_shaping_func = lambda _r: _r 72 | 73 | def _sync_target_model(self) -> None: 74 | # https://github.com/transedward/pytorch-dqn/blob/master/dqn_learn.py#L221 75 | if self._target_update_method == "copy": 76 | self._target_model.load_state_dict(self._model.state_dict()) 77 | 78 | # Target network update 79 | # Note that we are assuming `model.parameters()` 80 | # would yield the same parameter orders. 81 | # https://towardsdatascience.com/double-deep-q-networks-905dd8325412 82 | if self._target_update_method == "polyak": 83 | for param_, param in zip(self._target_model.parameters(), 84 | self._model.parameters()): 85 | param_.data.copy_((1 - self._target_learning_rate) * param_ 86 | + self._target_learning_rate * param) 87 | 88 | def _pre_steps(self, step: int) -> None: 89 | if self._target_update_method == "polyak": 90 | self._sync_target_model() 91 | elif self._target_update_method == "copy" \ 92 | and step % self._target_update_steps == 0: 93 | self._sync_target_model() 94 | 95 | def forward(self, batch: Dict[str, Any], prompt_dic_train: Dict[str, float], prompt_dic_val: Dict[str, float]) \ 96 | -> Tuple[torch.Tensor, Dict[str, Any], Dict[str, float], Dict[str, float]]: 97 | loss_list = [] 98 | loss_log_list = [] 99 | for mode in self._forward_modes: 100 | _loss, _loss_log, prompt_dic_train, prompt_dic_val = self._forward( 101 | mode=mode, batch=batch, prompt_dic_train=prompt_dic_train, prompt_dic_val=prompt_dic_val 102 | ) 103 | loss_list.append(_loss) 104 | loss_log_list.append(_loss_log) 105 | 106 | # https://discuss.pytorch.org/t/get-the-mean-from-a-list-of-tensors/31989/2 107 | loss = torch.mean(torch.stack(loss_list)) 108 | loss_log = utils.unionize_dicts(loss_log_list) 109 | 110 | return loss, loss_log, prompt_dic_train, prompt_dic_val 111 | 112 | def _forward( 113 | self, 114 | mode: ForwardMode, 115 | batch: Dict[str, Any], 116 | prompt_dic_train: Dict[str, float], 117 | prompt_dic_val: Dict[str, float] 118 | ) -> Tuple[torch.Tensor, Dict[str, Any], Dict[str, float], Dict[str, float]]: 119 | if mode != ForwardMode.SQL_ON and mode != ForwardMode.INFER: 120 | # TODO: Enable training modes other than on-policy 121 | raise NotImplementedError('Only on-policy sampling and greedy ' 122 | 'inference is supported now') 123 | 124 | if mode == ForwardMode.SQL_ON: 125 | (logits, logits_, output_tokens, output_ids, sequence_lengths) = \ 126 | self._decode_sampling(batch=batch) 127 | 128 | raw_rewards, rewards_log, prompt_dic_train, prompt_dic_val = \ 129 | self.compute_rewards( 130 | batch=batch, output_tokens=output_tokens, mode="train", prompt_dic_train=prompt_dic_train, 131 | prompt_dic_val=prompt_dic_val 132 | ) 133 | shaped_rewards = self._reward_shaping_func(raw_rewards) 134 | 135 | sql_loss, sql_loss_log = sql_loss_with_sparse_rewards( 136 | implementation=self._sql_loss_impl, 137 | logits=logits, 138 | logits_=logits_, 139 | actions=output_ids, 140 | sampled_actions=None, 141 | rewards=shaped_rewards, 142 | sequence_length=sequence_lengths) 143 | 144 | utils.add_prefix_to_dict_keys_inplace( 145 | rewards_log, prefix=f"{mode.value}/rewards/") 146 | utils.add_prefix_to_dict_keys_inplace( 147 | sql_loss_log, prefix=f"{mode.value}/") 148 | sql_loss_log = utils.unionize_dicts([ 149 | rewards_log, 150 | sql_loss_log, 151 | { 152 | f"{mode.value}/rewards/raw": raw_rewards.mean(), 153 | f"{mode.value}/rewards/shaped": shaped_rewards.mean(), 154 | }, 155 | ]) 156 | 157 | return sql_loss, sql_loss_log, prompt_dic_train, prompt_dic_val 158 | 159 | def compute_rewards( 160 | self, 161 | batch: Dict[str, Any], 162 | output_tokens: List[List[str]], 163 | to_tensor: bool = True, 164 | mode: str = "infer", 165 | prompt_dic_train: Dict[str, float] = None, 166 | prompt_dic_val: Dict[str, float] = None, 167 | ) -> Tuple[torch.Tensor, Dict[str, Any], Dict[str, float], Dict[str, float]]: 168 | rewards_tensor, rewards_log, prompt_dic_train, prompt_dic_val = self._reward( 169 | **batch, 170 | output_tokens=output_tokens, 171 | to_tensor=to_tensor, 172 | mode=mode, 173 | prompt_dic_train=prompt_dic_train, 174 | prompt_dic_val=prompt_dic_val 175 | ) 176 | 177 | rewards_tensor = rewards_tensor.to(device) 178 | return rewards_tensor, rewards_log, prompt_dic_train, prompt_dic_val 179 | 180 | def infer( 181 | self, 182 | batch: Dict[str, Any] 183 | ) -> Dict[str, Union[torch.Tensor, torch.LongTensor, List[List[str]]]]: 184 | return self._model.generate(**batch, 185 | do_sample=False, 186 | top_k=self._top_k, 187 | top_p=self._top_p, 188 | num_beams=self._num_beams, 189 | infer=True) 190 | 191 | def _decode_sampling( 192 | self, 193 | batch: Dict[str, Any], 194 | ) -> Tuple[torch.Tensor, torch.Tensor, List[List[str]], 195 | torch.LongTensor, torch.LongTensor]: 196 | outputs = self._model.generate(**batch, 197 | do_sample=True, 198 | top_k=self._top_k, 199 | top_p=self._top_p, 200 | num_beams=self._num_beams) 201 | 202 | batch_ = {k: v for k, v in batch.items()} 203 | batch_.update(outputs) 204 | 205 | outputs_ = self._target_model.teacher_forcing(**batch_) 206 | 207 | return (outputs['sample_logits'].contiguous(), 208 | outputs_['sample_logits'].contiguous(), 209 | outputs['sample_tokens'], 210 | outputs['sample_ids'].contiguous(), 211 | outputs['sample_lengths'].contiguous()) 212 | 213 | 214 | def _get_forward_modes( 215 | training_mode: str, 216 | mix_strategy: Optional[str] 217 | ) -> List[ForwardMode]: 218 | if training_mode == "sql-mixed": 219 | candidate_modes = [ 220 | ForwardMode.SQL_OFF_GT, 221 | ForwardMode.SQL_ON] 222 | 223 | if mix_strategy == "alternate": 224 | modes = [candidate_modes[step % len(candidate_modes)]] 225 | elif mix_strategy == "mix": 226 | modes = candidate_modes 227 | 228 | else: 229 | training_mode_map = {"sql-onpolicy": ForwardMode.SQL_ON, 230 | "sql-offpolicy": ForwardMode.SQL_OFF_GT} 231 | 232 | modes = [training_mode_map[training_mode]] 233 | 234 | return modes 235 | -------------------------------------------------------------------------------- /Trigger/rlprompt/modules/sql_module.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import copy 3 | from typing import Optional, List, Dict, Any, Union, Tuple 4 | 5 | from rlprompt.models import BaseModel 6 | from rlprompt.modules import BaseModule 7 | from rlprompt.rewards import BaseReward 8 | from rlprompt.modules.module_utils import ForwardMode, get_reward_shaping_func 9 | from rlprompt.losses import sql_loss_with_sparse_rewards 10 | from rlprompt.utils import utils 11 | 12 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 13 | 14 | 15 | class SQLModule(BaseModule): 16 | def __init__( 17 | self, 18 | model: BaseModel, 19 | target_model: Optional[BaseModel], 20 | reward: Optional[BaseReward], 21 | sql_loss_impl: str, 22 | training_mode: str, 23 | mix_strategy: Optional[str], 24 | target_update_method: str, 25 | target_update_steps: Optional[int], 26 | target_learning_rate: float, 27 | reward_shaping: bool, 28 | reward_shaping_old_min: float, 29 | reward_shaping_old_max: float, 30 | reward_shaping_new_min: float, 31 | reward_shaping_new_max: float, 32 | top_k: Optional[int], 33 | top_p: float, 34 | num_beams: int, 35 | ): 36 | super().__init__() 37 | # Initialize self._model and self._reward 38 | assert target_update_method in ["copy", "polyak"] 39 | assert not (top_k is not None and top_p < 1.0), \ 40 | "Only one of top_k or top_p should be selected" 41 | 42 | self._model = model 43 | if target_model is None: 44 | self._target_model = copy.deepcopy(self._model) 45 | else: 46 | self._target_model = target_model 47 | # for p1, p2 in zip(self._model.parameters(), self._target_model.parameters()): 48 | # if p1.data.ne(p2.data).sum() > 0: 49 | # print(False) 50 | # print(True) 51 | self._reward = reward 52 | 53 | self._sql_loss_impl = sql_loss_impl 54 | self._training_mode = training_mode 55 | self._mix_strategy = mix_strategy 56 | self._forward_modes = _get_forward_modes(training_mode, mix_strategy) 57 | self._target_update_method = target_update_method 58 | self._target_update_steps = target_update_steps 59 | self._target_learning_rate = target_learning_rate 60 | self._top_k = top_k 61 | self._top_p = top_p 62 | self._num_beams = num_beams 63 | 64 | if reward_shaping is True: 65 | self._reward_shaping_func = get_reward_shaping_func( 66 | old_min=reward_shaping_old_min, 67 | old_max=reward_shaping_old_max, 68 | new_min=reward_shaping_new_min, 69 | new_max=reward_shaping_new_max) 70 | else: 71 | self._reward_shaping_func = lambda _r: _r 72 | 73 | def _sync_target_model(self) -> None: 74 | # https://github.com/transedward/pytorch-dqn/blob/master/dqn_learn.py#L221 75 | if self._target_update_method == "copy": 76 | self._target_model.load_state_dict(self._model.state_dict()) 77 | 78 | # Target network update 79 | # Note that we are assuming `model.parameters()` 80 | # would yield the same parameter orders. 81 | # https://towardsdatascience.com/double-deep-q-networks-905dd8325412 82 | if self._target_update_method == "polyak": 83 | for param_, param in zip(self._target_model.parameters(), 84 | self._model.parameters()): 85 | param_.data.copy_((1 - self._target_learning_rate) * param_ 86 | + self._target_learning_rate * param) 87 | 88 | def _pre_steps(self, step: int) -> None: 89 | if self._target_update_method == "polyak": 90 | self._sync_target_model() 91 | elif self._target_update_method == "copy" \ 92 | and step % self._target_update_steps == 0: 93 | self._sync_target_model() 94 | 95 | def forward(self, batch: Dict[str, Any], clean_prompt: str, target_label: int, prompt_trigger_dic_train: Dict[str, float], 96 | prompt_trigger_dic_val: Dict[str, float]) \ 97 | -> Tuple[torch.Tensor, Dict[str, Any], Dict[Tuple, Tuple], Dict[Tuple, Tuple]]: 98 | loss_list = [] 99 | loss_log_list = [] 100 | for mode in self._forward_modes: 101 | _loss, _loss_log, prompt_trigger_dic_train, prompt_trigger_dic_val= self._forward( 102 | mode=mode, batch=batch, clean_prompt=clean_prompt, target_label=target_label, 103 | prompt_trigger_dic_train=prompt_trigger_dic_train, prompt_trigger_dic_val=prompt_trigger_dic_val 104 | ) 105 | loss_list.append(_loss) 106 | loss_log_list.append(_loss_log) 107 | 108 | # https://discuss.pytorch.org/t/get-the-mean-from-a-list-of-tensors/31989/2 109 | loss = torch.mean(torch.stack(loss_list)) 110 | loss_log = utils.unionize_dicts(loss_log_list) 111 | 112 | return loss, loss_log, prompt_trigger_dic_train, prompt_trigger_dic_val 113 | 114 | def _forward( 115 | self, 116 | mode: ForwardMode, 117 | batch: Dict[str, Any], 118 | clean_prompt: str, 119 | target_label: int, 120 | prompt_trigger_dic_train: Dict[str, float], 121 | prompt_trigger_dic_val: Dict[str, float], 122 | ) -> Tuple[torch.Tensor, Dict[str, Any], Dict[Tuple, Tuple], Dict[Tuple, Tuple]]: 123 | if mode != ForwardMode.SQL_ON and mode != ForwardMode.INFER: 124 | raise NotImplementedError('Only on-policy sampling and greedy ' 125 | 'inference is supported now') 126 | 127 | if mode == ForwardMode.SQL_ON: 128 | (logits, logits_, output_tokens, output_ids, sequence_lengths) = \ 129 | self._decode_sampling(batch=batch) 130 | 131 | raw_rewards, rewards_log, prompt_trigger_dic_train, prompt_trigger_dic_val = \ 132 | self.compute_rewards( 133 | batch=batch, output_tokens=output_tokens, clean_prompt=clean_prompt, target_label=target_label, mode="train", 134 | prompt_trigger_dic_train=prompt_trigger_dic_train, prompt_trigger_dic_val=prompt_trigger_dic_val 135 | ) 136 | shaped_rewards = self._reward_shaping_func(raw_rewards) 137 | 138 | sql_loss, sql_loss_log = sql_loss_with_sparse_rewards( 139 | implementation=self._sql_loss_impl, 140 | logits=logits, 141 | logits_=logits_, 142 | actions=output_ids, 143 | sampled_actions=None, 144 | rewards=shaped_rewards, 145 | sequence_length=sequence_lengths) 146 | 147 | utils.add_prefix_to_dict_keys_inplace( 148 | rewards_log, prefix=f"{mode.value}/rewards/") 149 | utils.add_prefix_to_dict_keys_inplace( 150 | sql_loss_log, prefix=f"{mode.value}/") 151 | sql_loss_log = utils.unionize_dicts([ 152 | rewards_log, 153 | sql_loss_log, 154 | { 155 | f"{mode.value}/rewards/raw": raw_rewards.mean(), 156 | f"{mode.value}/rewards/shaped": shaped_rewards.mean(), 157 | }, 158 | ]) 159 | 160 | return sql_loss, sql_loss_log, prompt_trigger_dic_train, prompt_trigger_dic_val 161 | 162 | def compute_rewards( 163 | self, 164 | batch: Dict[str, Any], 165 | output_tokens: List[List[str]], 166 | clean_prompt: str, 167 | target_label: int, 168 | to_tensor: bool = True, 169 | mode: str = "infer", 170 | prompt_trigger_dic_train: Dict[str, float] = None, 171 | prompt_trigger_dic_val: Dict[str, float] = None, 172 | ) -> Tuple[torch.Tensor, Dict[str, Any], Dict[Tuple, Tuple], Dict[Tuple, Tuple]]: 173 | rewards_tensor, rewards_log, prompt_trigger_dic_train, prompt_trigger_dic_val = self._reward( 174 | **batch, 175 | output_tokens=output_tokens, 176 | clean_prompt = clean_prompt, 177 | target_label = target_label, 178 | to_tensor=to_tensor, 179 | mode=mode, 180 | prompt_trigger_dic_train=prompt_trigger_dic_train, 181 | prompt_trigger_dic_val=prompt_trigger_dic_val) 182 | 183 | rewards_tensor = rewards_tensor.to(device) 184 | return rewards_tensor, rewards_log, prompt_trigger_dic_train, prompt_trigger_dic_val 185 | 186 | def infer( 187 | self, 188 | batch: Dict[str, Any] 189 | ) -> Dict[str, Union[torch.Tensor, torch.LongTensor, List[List[str]]]]: 190 | return self._model.generate(**batch, 191 | do_sample=False, 192 | top_k=self._top_k, 193 | top_p=self._top_p, 194 | num_beams=self._num_beams, 195 | infer=True) 196 | 197 | def _decode_sampling( 198 | self, 199 | batch: Dict[str, Any], 200 | ) -> Tuple[torch.Tensor, torch.Tensor, List[List[str]], 201 | torch.LongTensor, torch.LongTensor]: 202 | outputs = self._model.generate(**batch, 203 | do_sample=True, 204 | top_k=self._top_k, 205 | top_p=self._top_p, 206 | num_beams=self._num_beams) 207 | 208 | batch_ = {k: v for k, v in batch.items()} 209 | batch_.update(outputs) 210 | 211 | outputs_ = self._target_model.teacher_forcing(**batch_) 212 | 213 | return (outputs['sample_logits'].contiguous(), 214 | outputs_['sample_logits'].contiguous(), 215 | outputs['sample_tokens'], 216 | outputs['sample_ids'].contiguous(), 217 | outputs['sample_lengths'].contiguous()) 218 | 219 | 220 | def _get_forward_modes( 221 | training_mode: str, 222 | mix_strategy: Optional[str] 223 | ) -> List[ForwardMode]: 224 | if training_mode == "sql-mixed": 225 | candidate_modes = [ 226 | ForwardMode.SQL_OFF_GT, 227 | ForwardMode.SQL_ON] 228 | 229 | if mix_strategy == "alternate": 230 | modes = [candidate_modes[step % len(candidate_modes)]] 231 | elif mix_strategy == "mix": 232 | modes = candidate_modes 233 | 234 | else: 235 | training_mode_map = {"sql-onpolicy": ForwardMode.SQL_ON, 236 | "sql-offpolicy": ForwardMode.SQL_OFF_GT} 237 | 238 | modes = [training_mode_map[training_mode]] 239 | 240 | return modes 241 | -------------------------------------------------------------------------------- /ProgressiveTuning/rlprompt/modules/sql_module.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import copy 3 | from typing import Optional, List, Dict, Any, Union, Tuple 4 | 5 | from rlprompt.models import BaseModel 6 | from rlprompt.modules import BaseModule 7 | from rlprompt.rewards import BaseReward 8 | from rlprompt.modules.module_utils import ForwardMode, get_reward_shaping_func 9 | from rlprompt.losses import sql_loss_with_sparse_rewards 10 | from rlprompt.utils import utils 11 | 12 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 13 | 14 | 15 | class SQLModule(BaseModule): 16 | def __init__( 17 | self, 18 | model: BaseModel, 19 | target_model: Optional[BaseModel], 20 | reward: Optional[BaseReward], 21 | sql_loss_impl: str, 22 | training_mode: str, 23 | mix_strategy: Optional[str], 24 | target_update_method: str, 25 | target_update_steps: Optional[int], 26 | target_learning_rate: float, 27 | reward_shaping: bool, 28 | reward_shaping_old_min: float, 29 | reward_shaping_old_max: float, 30 | reward_shaping_new_min: float, 31 | reward_shaping_new_max: float, 32 | top_k: Optional[int], 33 | top_p: float, 34 | num_beams: int, 35 | ): 36 | super().__init__() 37 | # Initialize self._model and self._reward 38 | assert target_update_method in ["copy", "polyak"] 39 | assert not (top_k is not None and top_p < 1.0), \ 40 | "Only one of top_k or top_p should be selected" 41 | 42 | self._model = model 43 | if target_model is None: 44 | self._target_model = copy.deepcopy(self._model) 45 | else: 46 | self._target_model = target_model 47 | # for p1, p2 in zip(self._model.parameters(), self._target_model.parameters()): 48 | # if p1.data.ne(p2.data).sum() > 0: 49 | # print(False) 50 | # print(True) 51 | self._reward = reward 52 | 53 | self._sql_loss_impl = sql_loss_impl 54 | self._training_mode = training_mode 55 | self._mix_strategy = mix_strategy 56 | self._forward_modes = _get_forward_modes(training_mode, mix_strategy) 57 | self._target_update_method = target_update_method 58 | self._target_update_steps = target_update_steps 59 | self._target_learning_rate = target_learning_rate 60 | self._top_k = top_k 61 | self._top_p = top_p 62 | self._num_beams = num_beams 63 | 64 | if reward_shaping is True: 65 | self._reward_shaping_func = get_reward_shaping_func( 66 | old_min=reward_shaping_old_min, 67 | old_max=reward_shaping_old_max, 68 | new_min=reward_shaping_new_min, 69 | new_max=reward_shaping_new_max) 70 | else: 71 | self._reward_shaping_func = lambda _r: _r 72 | 73 | def _sync_target_model(self) -> None: 74 | # https://github.com/transedward/pytorch-dqn/blob/master/dqn_learn.py#L221 75 | if self._target_update_method == "copy": 76 | self._target_model.load_state_dict(self._model.state_dict()) 77 | 78 | # Target network update 79 | # Note that we are assuming `model.parameters()` 80 | # would yield the same parameter orders. 81 | # https://towardsdatascience.com/double-deep-q-networks-905dd8325412 82 | if self._target_update_method == "polyak": 83 | for param_, param in zip(self._target_model.parameters(), 84 | self._model.parameters()): 85 | param_.data.copy_((1 - self._target_learning_rate) * param_ 86 | + self._target_learning_rate * param) 87 | 88 | def _pre_steps(self, step: int) -> None: 89 | if self._target_update_method == "polyak": 90 | self._sync_target_model() 91 | elif self._target_update_method == "copy" \ 92 | and step % self._target_update_steps == 0: 93 | self._sync_target_model() 94 | 95 | def forward(self, batch: Dict[str, Any], clean_prompt: str, trigger: str, target: int, 96 | prompt_trigger_dic_train: Dict[str, float], prompt_trigger_dic_val: Dict[str, float]) \ 97 | -> Tuple[torch.Tensor, Dict[str, Any], Dict[Tuple, Tuple], Dict[Tuple, Tuple]]: 98 | loss_list = [] 99 | loss_log_list = [] 100 | for mode in self._forward_modes: 101 | _loss, _loss_log, prompt_trigger_dic_train, prompt_trigger_dic_val= self._forward( 102 | mode=mode, batch=batch, clean_prompt=clean_prompt, trigger=trigger, target=target, 103 | prompt_trigger_dic_train=prompt_trigger_dic_train, prompt_trigger_dic_val=prompt_trigger_dic_val 104 | ) 105 | loss_list.append(_loss) 106 | loss_log_list.append(_loss_log) 107 | 108 | # https://discuss.pytorch.org/t/get-the-mean-from-a-list-of-tensors/31989/2 109 | loss = torch.mean(torch.stack(loss_list)) 110 | loss_log = utils.unionize_dicts(loss_log_list) 111 | 112 | return loss, loss_log, prompt_trigger_dic_train, prompt_trigger_dic_val 113 | 114 | def _forward( 115 | self, 116 | mode: ForwardMode, 117 | batch: Dict[str, Any], 118 | clean_prompt: str, 119 | trigger: str, 120 | target: int, 121 | prompt_trigger_dic_train: Dict[str, float], 122 | prompt_trigger_dic_val: Dict[str, float], 123 | ) -> Tuple[torch.Tensor, Dict[str, Any], Dict[Tuple, Tuple], Dict[Tuple, Tuple]]: 124 | if mode != ForwardMode.SQL_ON and mode != ForwardMode.INFER: 125 | raise NotImplementedError('Only on-policy sampling and greedy ' 126 | 'inference is supported now') 127 | 128 | if mode == ForwardMode.SQL_ON: 129 | (logits, logits_, output_tokens, output_ids, sequence_lengths) = \ 130 | self._decode_sampling(batch=batch) 131 | 132 | raw_rewards, rewards_log, prompt_trigger_dic_train, prompt_trigger_dic_val = \ 133 | self.compute_rewards( 134 | batch=batch, output_tokens=output_tokens, clean_prompt=clean_prompt, trigger=trigger, mode="train", 135 | target=target, prompt_trigger_dic_train=prompt_trigger_dic_train, 136 | prompt_trigger_dic_val=prompt_trigger_dic_val 137 | ) 138 | shaped_rewards = self._reward_shaping_func(raw_rewards) 139 | 140 | sql_loss, sql_loss_log = sql_loss_with_sparse_rewards( 141 | implementation=self._sql_loss_impl, 142 | logits=logits, 143 | logits_=logits_, 144 | actions=output_ids, 145 | sampled_actions=None, 146 | rewards=shaped_rewards, 147 | sequence_length=sequence_lengths) 148 | 149 | utils.add_prefix_to_dict_keys_inplace( 150 | rewards_log, prefix=f"{mode.value}/rewards/") 151 | utils.add_prefix_to_dict_keys_inplace( 152 | sql_loss_log, prefix=f"{mode.value}/") 153 | sql_loss_log = utils.unionize_dicts([ 154 | rewards_log, 155 | sql_loss_log, 156 | { 157 | f"{mode.value}/rewards/raw": raw_rewards.mean(), 158 | f"{mode.value}/rewards/shaped": shaped_rewards.mean(), 159 | }, 160 | ]) 161 | 162 | return sql_loss, sql_loss_log, prompt_trigger_dic_train, prompt_trigger_dic_val 163 | 164 | def compute_rewards( 165 | self, 166 | batch: Dict[str, Any], 167 | output_tokens: List[List[str]], 168 | clean_prompt: str, 169 | trigger: str, 170 | target:int, 171 | to_tensor: bool = True, 172 | mode: str = "infer", 173 | prompt_trigger_dic_train: Dict[str, float] = None, 174 | prompt_trigger_dic_val: Dict[str, float] = None, 175 | ) -> Tuple[torch.Tensor, Dict[str, Any], Dict[Tuple, Tuple], Dict[Tuple, Tuple]]: 176 | rewards_tensor, rewards_log, prompt_trigger_dic_train, prompt_trigger_dic_val = self._reward( 177 | **batch, 178 | output_tokens=output_tokens, 179 | clean_prompt = clean_prompt, 180 | trigger=trigger, 181 | to_tensor=to_tensor, 182 | mode=mode, 183 | prompt_trigger_dic_train=prompt_trigger_dic_train, 184 | prompt_trigger_dic_val=prompt_trigger_dic_val, 185 | target=target 186 | ) 187 | 188 | rewards_tensor = rewards_tensor.to(device) 189 | return rewards_tensor, rewards_log, prompt_trigger_dic_train, prompt_trigger_dic_val 190 | 191 | def infer( 192 | self, 193 | batch: Dict[str, Any] 194 | ) -> Dict[str, Union[torch.Tensor, torch.LongTensor, List[List[str]]]]: 195 | return self._model.generate(**batch, 196 | do_sample=False, 197 | top_k=self._top_k, 198 | top_p=self._top_p, 199 | num_beams=self._num_beams, 200 | infer=True) 201 | 202 | def _decode_sampling( 203 | self, 204 | batch: Dict[str, Any], 205 | ) -> Tuple[torch.Tensor, torch.Tensor, List[List[str]], 206 | torch.LongTensor, torch.LongTensor]: 207 | outputs = self._model.generate(**batch, 208 | do_sample=True, 209 | top_k=self._top_k, 210 | top_p=self._top_p, 211 | num_beams=self._num_beams) 212 | 213 | batch_ = {k: v for k, v in batch.items()} 214 | batch_.update(outputs) 215 | 216 | outputs_ = self._target_model.teacher_forcing(**batch_) 217 | 218 | return (outputs['sample_logits'].contiguous(), 219 | outputs_['sample_logits'].contiguous(), 220 | outputs['sample_tokens'], 221 | outputs['sample_ids'].contiguous(), 222 | outputs['sample_lengths'].contiguous()) 223 | 224 | 225 | def _get_forward_modes( 226 | training_mode: str, 227 | mix_strategy: Optional[str] 228 | ) -> List[ForwardMode]: 229 | if training_mode == "sql-mixed": 230 | candidate_modes = [ 231 | ForwardMode.SQL_OFF_GT, 232 | ForwardMode.SQL_ON] 233 | 234 | if mix_strategy == "alternate": 235 | modes = [candidate_modes[step % len(candidate_modes)]] 236 | elif mix_strategy == "mix": 237 | modes = candidate_modes 238 | 239 | else: 240 | training_mode_map = {"sql-onpolicy": ForwardMode.SQL_ON, 241 | "sql-offpolicy": ForwardMode.SQL_OFF_GT} 242 | 243 | modes = [training_mode_map[training_mode]] 244 | 245 | return modes 246 | --------------------------------------------------------------------------------