├── figs
│   └── thumbnail.png
├── subpop
│   ├── train
│   │   ├── data
│   │   │   ├── __init__.py
│   │   │   ├── llama_guard
│   │   │   │   ├── __init__.py
│   │   │   │   ├── finetuning_data_formatter_example.py
│   │   │   │   └── README.md
│   │   │   ├── concatenator.py
│   │   │   └── sampler.py
│   │   ├── utils
│   │   │   ├── __init__.py
│   │   │   ├── dataset_utils.py
│   │   │   ├── plot_metrics.py
│   │   │   ├── flop_utils.py
│   │   │   ├── fsdp_utils.py
│   │   │   ├── memory_utils.py
│   │   │   └── config_utils.py
│   │   ├── policies
│   │   │   ├── __init__.py
│   │   │   ├── activation_checkpointing_functions.py
│   │   │   ├── mixed_precision.py
│   │   │   ├── wrapping.py
│   │   │   └── anyprecision_optimizer.py
│   │   ├── configs
│   │   │   ├── __init__.py
│   │   │   ├── wandb.py
│   │   │   ├── peft.py
│   │   │   ├── quantization.py
│   │   │   ├── fsdp.py
│   │   │   ├── datasets.py
│   │   │   └── training.py
│   │   ├── model_checkpointing
│   │   │   ├── __init__.py
│   │   │   └── checkpoint_handler.py
│   │   ├── tools
│   │   │   ├── README.md
│   │   │   ├── compare_llama_weights.py
│   │   │   └── convert_hf_weights_to_llama.py
│   │   └── datasets
│   │       ├── __init__.py
│   │       ├── samsum_dataset.py
│   │       ├── alpaca_dataset.py
│   │       ├── custom_dataset.py
│   │       ├── toxicchat_dataset.py
│   │       └── opinionqa_dataset.py
│   ├── survey
│   │   └── config.py
│   └── utils
│       ├── random_utils.py
│       ├── logger.py
│       ├── backoff.py
│       ├── data_utils.py
│       └── survey_utils.py
├── pyproject.toml
├── scripts
│   ├── experiment
│   │   ├── run_finetune.py
│   │   ├── analyze_inference_result.ipynb
│   │   └── run_inference.py
│   └── data_generation
│       ├── refine_question.py
│       ├── prepare_finetuning_data.py
│       └── generate_distribution.py
├── data
│   ├── subpopulation_metadata
│   │   ├── demographics_22.csv
│   │   ├── demographics_61.csv
│   │   └── steering_prompts.json
│   ├── subpop-eval
│   │   └── subpop-eval_parsed-qkeys.json
│   └── opinionqa
│       └── opinionqa_parsed-qkeys.json
├── requirements.txt
├── LICENSE
├── .gitignore
└── README.md
-------------------------------------------------------------------------------- /figs/thumbnail.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JosephJeesungSuh/subpop/HEAD/figs/thumbnail.png -------------------------------------------------------------------------------- /subpop/train/data/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. -------------------------------------------------------------------------------- /subpop/train/data/llama_guard/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # This software may be used and distributed according to the terms of the Llama Guard License Agreement. 
-------------------------------------------------------------------------------- /subpop/survey/config.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | class SteeringPromptType(Enum): 4 | BIO = ["bio_prompt"] 5 | QA = ["qa_prompt"] 6 | PORTRAY = ["portray_prompt"] 7 | ALL = ["bio_prompt", "qa_prompt", "portray_prompt"] -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools", "wheel"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "subpop" 7 | version = "0.1.0" 8 | dependencies = [] 9 | 10 | [tool.setuptools] 11 | packages = ["subpop"] -------------------------------------------------------------------------------- /subpop/utils/random_utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | import torch 4 | 5 | 6 | def set_random_seed(seed: int) -> None: 7 | random.seed(seed) 8 | np.random.seed(seed) 9 | torch.manual_seed(seed) 10 | if torch.cuda.is_available(): 11 | torch.cuda.manual_seed_all(seed) 12 | -------------------------------------------------------------------------------- /scripts/experiment/run_finetune.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 3 | 4 | import fire 5 | from subpop.train.finetuning import main 6 | 7 | if __name__ == "__main__": 8 | fire.Fire(main) -------------------------------------------------------------------------------- /subpop/train/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 3 | 4 | from subpop.train.utils.memory_utils import MemoryTrace 5 | from subpop.train.utils.dataset_utils import * 6 | from subpop.train.utils.fsdp_utils import fsdp_auto_wrap_policy, hsdp_device_mesh 7 | from subpop.train.utils.train_utils import * -------------------------------------------------------------------------------- /data/subpopulation_metadata/demographics_22.csv: -------------------------------------------------------------------------------- 1 | attribute,group 2 | CREGION,"['Northeast', 'South']" 3 | EDUCATION,"['College graduate/some postgrad', 'Less than high school']" 4 | SEX,"['Male', 'Female']" 5 | POLIDEOLOGY,"['Liberal', 'Conservative', 'Moderate']" 6 | INCOME,"['$100,000 or more', 'Less than $30,000']" 7 | POLPARTY,"['Democrat', 'Republican']" 8 | RACE,"['Black', 'White', 'Asian', 'Hispanic']" 9 | RELIG,"['Protestant', 'Jewish', 'Hindu', 'Atheist', 'Muslim']" -------------------------------------------------------------------------------- /subpop/train/policies/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 
3 | 4 | from subpop.train.policies.mixed_precision import * 5 | from subpop.train.policies.wrapping import * 6 | from subpop.train.policies.activation_checkpointing_functions import apply_fsdp_checkpointing 7 | from subpop.train.policies.anyprecision_optimizer import AnyPrecisionAdamW 8 | -------------------------------------------------------------------------------- /subpop/train/configs/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 3 | 4 | from subpop.train.configs.peft import lora_config, llama_adapter_config, prefix_config 5 | from subpop.train.configs.fsdp import fsdp_config 6 | from subpop.train.configs.training import train_config 7 | from subpop.train.configs.wandb import wandb_config 8 | from subpop.train.configs.quantization import quantization_config 9 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==2.5.1 2 | accelerate 3 | appdirs 4 | loralib 5 | bitsandbytes 6 | black 7 | black[jupyter] 8 | datasets 9 | fire 10 | peft 11 | transformers==4.48.2 12 | sentencepiece 13 | py7zr 14 | scipy==1.13.0 15 | optimum 16 | matplotlib==3.8.4 17 | chardet 18 | openai==1.59.8 19 | typing-extensions>=4.8.0 20 | tabulate 21 | evaluate 22 | rouge_score 23 | pyyaml==6.0.1 24 | faiss-gpu; python_version < '3.11' 25 | unstructured[pdf] 26 | sentence_transformers 27 | codeshield 28 | 29 | pandas==2.2.3 30 | tiktoken==0.8.0 31 | scikit-learn==1.4.2 32 | pyreadstat==1.1.2 33 | vllm==0.7.2 34 | wandb==0.19.6 35 | ipykernel -------------------------------------------------------------------------------- /subpop/train/model_checkpointing/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 3 | 4 | from subpop.train.model_checkpointing.checkpoint_handler import ( 5 | load_model_checkpoint, 6 | save_fsdp_model_checkpoint_full, 7 | save_peft_checkpoint, 8 | save_model_checkpoint, 9 | load_optimizer_checkpoint, 10 | save_optimizer_checkpoint, 11 | save_model_and_optimizer_sharded, 12 | load_model_sharded, 13 | load_sharded_model_single_gpu, 14 | save_peft_checkpoint_checkpointing 15 | ) 16 | -------------------------------------------------------------------------------- /subpop/train/configs/wandb.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 
3 | 4 | from typing import List, Optional 5 | from dataclasses import dataclass, field 6 | 7 | @dataclass 8 | class wandb_config: 9 | project: str = 'YOUR/PROJECT/NAME' # wandb project name 10 | entity: Optional[str] = 'YOUR/ENTITY/NAME' # wandb entity name 11 | job_type: Optional[str] = None 12 | tags: Optional[List[str]] = None 13 | group: Optional[str] = None 14 | notes: Optional[str] = None 15 | mode: Optional[str] = None 16 | name: Optional[str] = None -------------------------------------------------------------------------------- /subpop/utils/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import pathlib 3 | from datetime import datetime 4 | 5 | 6 | def get_logger_top(name: str, debug: bool): 7 | REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent.parent 8 | 9 | initial_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") 10 | # set up save path 11 | save_path = ( 12 | REPO_ROOT / "results" / "logs" / f"experiment_{name.lower()}_{initial_time}.log" 13 | ) 14 | save_path.parent.mkdir(parents=True, exist_ok=True) 15 | logging.basicConfig( 16 | level=logging.DEBUG if debug else logging.INFO, 17 | format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", 18 | handlers=[ 19 | logging.FileHandler(save_path), 20 | logging.StreamHandler(), 21 | ], 22 | ) 23 | 24 | return logging.getLogger(name) 25 | -------------------------------------------------------------------------------- /subpop/train/configs/peft.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 3 | 4 | from dataclasses import dataclass, field 5 | from typing import List 6 | 7 | @dataclass 8 | class lora_config: 9 | r: int=8 10 | lora_alpha: int=32 11 | target_modules: List[str] = field(default_factory=lambda: ["q_proj", "v_proj"]) 12 | bias= "none" 13 | task_type: str= "CAUSAL_LM" 14 | lora_dropout: float=0.05 15 | inference_mode: bool = False 16 | 17 | @dataclass 18 | class llama_adapter_config: 19 | adapter_len: int= 10 20 | adapter_layers: int= 30 21 | task_type: str= "CAUSAL_LM" 22 | 23 | #CAUTION prefix tuning is currently not supported 24 | @dataclass 25 | class prefix_config: 26 | num_virtual_tokens: int=30 27 | task_type: str= "CAUSAL_LM" 28 | -------------------------------------------------------------------------------- /subpop/train/policies/activation_checkpointing_functions.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 
3 | 4 | from functools import partial 5 | 6 | from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( 7 | checkpoint_wrapper, 8 | CheckpointImpl, 9 | apply_activation_checkpointing, 10 | ) 11 | from transformers.models.llama.modeling_llama import LlamaDecoderLayer 12 | 13 | non_reentrant_wrapper = partial( 14 | checkpoint_wrapper, 15 | checkpoint_impl=CheckpointImpl.NO_REENTRANT, 16 | ) 17 | 18 | check_fn = lambda submodule: isinstance(submodule, LlamaDecoderLayer) 19 | 20 | 21 | def apply_fsdp_checkpointing(model): 22 | """apply activation checkpointing to model 23 | returns None as model is updated directly 24 | """ 25 | print(f"--> applying fsdp activation checkpointing...") 26 | 27 | apply_activation_checkpointing( 28 | model, checkpoint_wrapper_fn=non_reentrant_wrapper, check_fn=check_fn 29 | ) 30 | -------------------------------------------------------------------------------- /subpop/train/policies/mixed_precision.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 3 | 4 | import torch 5 | 6 | from torch.distributed.fsdp import ( 7 | MixedPrecision, 8 | ) 9 | 10 | # requires grad scaler in main loop 11 | fpSixteen = MixedPrecision( 12 | param_dtype=torch.float16, 13 | # Gradient communication precision. 14 | reduce_dtype=torch.float16, 15 | # Buffer precision. 16 | buffer_dtype=torch.float16, 17 | ) 18 | 19 | bfSixteen = MixedPrecision( 20 | param_dtype=torch.bfloat16, 21 | # Gradient communication precision. 22 | reduce_dtype=torch.bfloat16, 23 | # Buffer precision. 24 | buffer_dtype=torch.bfloat16, 25 | cast_forward_inputs=True, 26 | ) 27 | 28 | bfSixteen_mixed = MixedPrecision( 29 | param_dtype=torch.float32, 30 | reduce_dtype=torch.bfloat16, 31 | buffer_dtype=torch.bfloat16, 32 | ) 33 | 34 | fp32_policy = MixedPrecision( 35 | param_dtype=torch.float32, 36 | reduce_dtype=torch.float32, 37 | buffer_dtype=torch.float32, 38 | ) 39 | -------------------------------------------------------------------------------- /data/subpopulation_metadata/demographics_61.csv: -------------------------------------------------------------------------------- 1 | attribute,group 2 | CREGION,"['Northeast', 'Midwest', 'South', 'West']" 3 | SEX,"['Male', 'Female']" 4 | AGE,"['18-29', '30-49', '50-64', '65+']" 5 | EDUCATION,"['Less than high school', 'High school graduate', 'Some college, no degree', ""Associate's degree"", 'College graduate/some postgrad', 'Postgraduate']" 6 | RACE,"['White', 'Black', 'Asian', 'Hispanic', 'Other']" 7 | CITIZEN,"['a US Citizen', 'a Non-US Citizen']" 8 | MARITAL,"['Married', 'Living with a partner', 'Divorced', 'Separated', 'Widowed', 'Unmarried and have never been married']" 9 | RELIG,"['Protestant', 'Roman Catholic', 'Mormon', 'Orthodox', 'Jewish', 'Muslim', 'Buddhist', 'Hindu', 'Atheist', 'Agnostic', 'Other', 'Nothing in particular']" 10 | RELIGATTEND,"['More than once a week', 'Once a week', 'Once or twice a month', 'A few times a year', 'Seldom', 'Never']" 11 | POLPARTY,"['Republican', 'Democrat', 'Independent', 'Something else']" 12 | INCOME,"['Less than $30,000', '$30,000-$50,000', '$50,000-$75,000', '$75,000-$100,000', '$100,000 or more']" 13 | POLIDEOLOGY,"['Very conservative', 'Conservative', 'Moderate', 'Liberal', 'Very liberal']" 14 | -------------------------------------------------------------------------------- 
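The two demographics metadata CSVs above define the candidate subpopulations: each row pairs a survey attribute (e.g. `CREGION`, `POLIDEOLOGY`) with its groups, stored as a stringified Python list in the `group` column. A minimal, illustrative sketch of loading them follows; it is not part of the repository and assumes `pandas` is installed and the script is run from the repository root.

```python
# Illustrative sketch (not part of the repo): parse a demographics CSV into
# (attribute, group) pairs. The "group" column is a stringified Python list,
# so ast.literal_eval turns it back into an actual list.
import ast
import pandas as pd

df = pd.read_csv("data/subpopulation_metadata/demographics_61.csv")
df["group"] = df["group"].apply(ast.literal_eval)

subpopulations = [
    (row.attribute, group)
    for row in df.itertuples(index=False)
    for group in row.group
]
print(len(subpopulations), subpopulations[:3])
```

With `demographics_61.csv` this yields 61 (attribute, group) pairs and with `demographics_22.csv` it yields 22, matching the filenames.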
/subpop/train/tools/README.md: -------------------------------------------------------------------------------- 1 | # Convert Hugging Face llama weights to official llama consolidated format 2 | 3 | This is the reverse conversion of the `convert_llama_weights_to_hf.py` script from the transformers package. 4 | 5 | ## Step 0: Convert to consolidated format 6 | - Create an output directory for the converted weights, such as `test70B`. 7 | - Copy the `params.json` file from the official llama download into that directory. 8 | - Run the conversion script. `model-path` can be a Hugging Face hub model or a local hf model directory. 9 | ``` 10 | python -m subpop.train.tools.convert_hf_weights_to_llama --model-path meta-llama/Meta-Llama-3.1-70B-Instruct --output-dir test70B --model-size 70B 11 | ``` 12 | 13 | ## Step 1: Run inference 14 | Check out the official Llama 3 inference [repo](https://github.com/meta-llama/llama3). Test using chat or text completion. 15 | ``` 16 | torchrun --nproc_per_node 8 example_chat_completion.py --ckpt_dir ./test70B --tokenizer_path ${llama_3_dir}/tokenizer.model 17 | ``` 18 | 19 | For validation, please compare the converted weights with the official Llama weights: 20 | ``` 21 | python compare_llama_weights.py test70B ${Llama-3-70B-Instruct_dir} 22 | ``` 23 | -------------------------------------------------------------------------------- /subpop/train/data/concatenator.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 3 | 4 | from tqdm import tqdm 5 | from itertools import chain 6 | 7 | from torch.utils.data import Dataset 8 | 9 | 10 | class ConcatDataset(Dataset): 11 | def __init__(self, dataset, chunk_size=4096): 12 | self.dataset = dataset 13 | self.chunk_size = chunk_size 14 | 15 | self.samples = [] 16 | 17 | buffer = { 18 | "input_ids": [], 19 | "attention_mask": [], 20 | "labels": [], 21 | } 22 | 23 | for sample in tqdm(self.dataset, desc="Preprocessing dataset", dynamic_ncols=True): 24 | buffer = {k: v + sample[k] for k,v in buffer.items()} 25 | 26 | while len(next(iter(buffer.values()))) > self.chunk_size: 27 | self.samples.append({k: v[:self.chunk_size] for k,v in buffer.items()}) 28 | buffer = {k: v[self.chunk_size:] for k,v in buffer.items()} 29 | 30 | def __getitem__(self, idx): 31 | return self.samples[idx] 32 | 33 | def __len__(self): 34 | return len(self.samples) 35 | -------------------------------------------------------------------------------- /subpop/train/configs/quantization.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 
3 | 4 | from dataclasses import dataclass 5 | from typing import Optional 6 | import torch 7 | from transformers import BitsAndBytesConfig 8 | 9 | @dataclass 10 | class quantization_config: 11 | quant_type: str = "fp4" # "fp4" or "nf4" 12 | compute_dtype: torch.dtype = torch.bfloat16 13 | use_double_quant: bool = False 14 | quant_storage: torch.dtype = torch.bfloat16 15 | 16 | def create_bnb_config(self, quantization: str) -> BitsAndBytesConfig: 17 | if quantization not in {"4bit", "8bit"}: 18 | raise ValueError("quantization must be either '4bit' or '8bit'") 19 | 20 | if quantization == "4bit": 21 | config_params = { 22 | "bnb_4bit_quant_type": self.quant_type, 23 | "bnb_4bit_compute_dtype": self.compute_dtype, 24 | "bnb_4bit_use_double_quant": self.use_double_quant, 25 | "bnb_4bit_quant_storage": self.quant_storage, 26 | } 27 | 28 | return BitsAndBytesConfig(load_in_4bit=True, **config_params) 29 | else: 30 | return BitsAndBytesConfig(load_in_8bit=True) 31 | -------------------------------------------------------------------------------- /subpop/train/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 3 | 4 | from functools import partial 5 | 6 | # from subpop.train.datasets.grammar_dataset.grammar_dataset import get_dataset as get_grammar_dataset 7 | # from subpop.train.datasets.alpaca_dataset import InstructionDataset as get_alpaca_dataset 8 | # from subpop.train.datasets.samsum_dataset import get_preprocessed_samsum as get_samsum_dataset 9 | # from subpop.train.datasets.toxicchat_dataset import get_llamaguard_toxicchat_dataset as get_llamaguard_toxicchat_dataset 10 | from subpop.train.datasets.custom_dataset import get_custom_dataset,get_data_collator, custom_collator_no_labels 11 | 12 | DATASET_PREPROC = { 13 | # "alpaca_dataset": partial(get_alpaca_dataset), 14 | # "grammar_dataset": get_grammar_dataset, 15 | # "samsum_dataset": get_samsum_dataset, 16 | # "llamaguard_toxicchat_dataset": get_llamaguard_toxicchat_dataset, 17 | "custom_dataset": get_custom_dataset, 18 | "opnqa_steering_dataset": get_custom_dataset, 19 | "opnqa_single_demographic_dataset": get_custom_dataset, 20 | } 21 | DATALOADER_COLLATE_FUNC = { 22 | "custom_dataset": get_data_collator, 23 | "opnqa_steering_dataset": custom_collator_no_labels, 24 | "opnqa_single_demographic_dataset": custom_collator_no_labels, 25 | } -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2025, Canny Lab @ The University of California, Berkeley 4 | 5 | Redistribution and use in source and binary forms, with or without 6 | modification, are permitted provided that the following conditions are met: 7 | 8 | 1. Redistributions of source code must retain the above copyright notice, this 9 | list of conditions and the following disclaimer. 10 | 11 | 2. Redistributions in binary form must reproduce the above copyright notice, 12 | this list of conditions and the following disclaimer in the documentation 13 | and/or other materials provided with the distribution. 14 | 15 | 3. 
Neither the name of the copyright holder nor the names of its 16 | contributors may be used to endorse or promote products derived from 17 | this software without specific prior written permission. 18 | 19 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 20 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 21 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 22 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 23 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 24 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 25 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 26 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 27 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | -------------------------------------------------------------------------------- /subpop/train/configs/fsdp.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 3 | 4 | from dataclasses import dataclass 5 | 6 | from torch.distributed.fsdp import ShardingStrategy 7 | from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType 8 | 9 | @dataclass 10 | class fsdp_config: 11 | mixed_precision: bool=True 12 | use_fp16: bool=False 13 | sharding_strategy: ShardingStrategy = ShardingStrategy.FULL_SHARD # HYBRID_SHARD "Full Shard within a node DDP cross Nodes", SHARD_GRAD_OP "Shard only Gradients and Optimizer States", NO_SHARD "Similar to DDP". 14 | hsdp : bool =False # Require HYBRID_SHARD to be set. This flag can extend the HYBRID_SHARD by allowing sharding a model on customized number of GPUs (Sharding_group) and Replicas over Sharding_group. 15 | sharding_group_size : int=0 # requires hsdp to be set. This specifies the sharding group size, number of GPUs that you model can fit into to form a replica of a model. 16 | replica_group_size: int=0 #requires hsdp to be set. This specifies the replica group size, which is world_size/sharding_group_size. 17 | checkpoint_type: StateDictType = StateDictType.SHARDED_STATE_DICT # alternatively FULL_STATE_DICT can be used. SHARDED_STATE_DICT saves one file with sharded weights per rank while FULL_STATE_DICT will collect all weights on rank 0 and save them in a single file. 18 | fsdp_activation_checkpointing: bool=True 19 | fsdp_cpu_offload: bool=False 20 | pure_bf16: bool = False 21 | optimizer: str= "AdamW" 22 | 23 | -------------------------------------------------------------------------------- /subpop/train/policies/wrapping.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 
3 | 4 | import functools 5 | 6 | from transformers.models.llama.modeling_llama import LlamaDecoderLayer 7 | from transformers.models.mistral.modeling_mistral import MistralDecoderLayer 8 | from transformers.models.mllama.modeling_mllama import MllamaSelfAttentionDecoderLayer,MllamaCrossAttentionDecoderLayer,MllamaVisionEncoderLayer 9 | 10 | from torch.distributed.fsdp.wrap import ( 11 | transformer_auto_wrap_policy, 12 | size_based_auto_wrap_policy, 13 | ) 14 | 15 | 16 | def get_size_policy(min_params=1e8): 17 | num_wrap_policy = functools.partial( 18 | size_based_auto_wrap_policy, min_num_params=min_params 19 | ) 20 | return num_wrap_policy 21 | 22 | 23 | def get_llama_wrapper(): 24 | """we register our main layer class and use the fsdp transformer wrapping policy 25 | ensures embedding layers are in the root fsdp unit for shared access and that fsdp units map to transformer layers 26 | """ 27 | # ==== use new transformer wrapper 28 | 29 | llama_auto_wrap_policy = functools.partial( 30 | transformer_auto_wrap_policy, 31 | transformer_layer_cls=set([ 32 | LlamaDecoderLayer, 33 | MllamaSelfAttentionDecoderLayer, 34 | MllamaVisionEncoderLayer, 35 | MllamaCrossAttentionDecoderLayer, 36 | MistralDecoderLayer # Original llama-recipes does not support mistral, so included custom. 37 | ]) 38 | ) 39 | 40 | return llama_auto_wrap_policy 41 | -------------------------------------------------------------------------------- /subpop/train/datasets/samsum_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 3 | 4 | # For dataset details visit: https://huggingface.co/datasets/samsum 5 | 6 | import copy 7 | import datasets 8 | 9 | 10 | def get_preprocessed_samsum(dataset_config, tokenizer, split): 11 | if not hasattr(dataset_config, "trust_remote_code") or not dataset_config.trust_remote_code: 12 | raise ValueError("The repository for samsum contains custom code which must be executed to correctly load the dataset. You can inspect the repository content at https://hf.co/datasets/samsum. 
To activate `trust_remote_code` option use this config: --samsum_dataset.trust_remote_code=True") 13 | dataset = datasets.load_dataset("samsum", split=split, trust_remote_code=dataset_config.trust_remote_code) 14 | 15 | prompt = ( 16 | f"Summarize this dialog:\n{{dialog}}\n---\nSummary:\n" 17 | ) 18 | 19 | def apply_prompt_template(sample): 20 | return { 21 | "prompt": prompt.format(dialog=sample["dialogue"]), 22 | "summary": sample["summary"], 23 | } 24 | 25 | dataset = dataset.map(apply_prompt_template, remove_columns=list(dataset.features)) 26 | 27 | def tokenize_add_label(sample): 28 | prompt = tokenizer.encode(tokenizer.bos_token + sample["prompt"], add_special_tokens=False) 29 | summary = tokenizer.encode(sample["summary"] + tokenizer.eos_token, add_special_tokens=False) 30 | 31 | sample = { 32 | "input_ids": prompt + summary, 33 | "attention_mask" : [1] * (len(prompt) + len(summary)), 34 | "labels": [-100] * len(prompt) + summary, 35 | } 36 | 37 | return sample 38 | 39 | dataset = dataset.map(tokenize_add_label, remove_columns=list(dataset.features)) 40 | 41 | return dataset 42 | -------------------------------------------------------------------------------- /subpop/train/tools/compare_llama_weights.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 3 | 4 | import gc 5 | import glob 6 | import os 7 | import sys 8 | 9 | import torch 10 | import tqdm 11 | 12 | 13 | def main() -> None: 14 | """Compare two llama checkpoint directories""" 15 | 16 | one_files = sorted(glob.glob(os.path.join(sys.argv[1], "consolidated.*.pth"))) 17 | two_files = sorted(glob.glob(os.path.join(sys.argv[2], "consolidated.*.pth"))) 18 | assert len(one_files) == len( 19 | two_files 20 | ), "One directory has {} files while another has {} files.".format( 21 | len(one_files), len(two_files) 22 | ) 23 | 24 | deltas = [] 25 | for i in tqdm.trange(len(one_files), desc="Comparing shards"): 26 | one = torch.load(one_files[i]) 27 | two = torch.load(two_files[i]) 28 | assert len(one) == len( 29 | two 30 | ), "shard should have the same length: {} != {}".format(len(one), len(two)) 31 | one = sorted(one.items(), key=lambda x: x[0]) 32 | two = sorted(two.items(), key=lambda x: x[0]) 33 | 34 | for _, (v, w) in enumerate(zip(one, two)): 35 | assert v[0] == w[0], "{} != {}".format(v[0], w[0]) 36 | assert v[1].shape == w[1].shape, "tensor {} shape {} != {}".format( 37 | v[0], v[1].shape, w[1].shape 38 | ) 39 | 40 | delta = (v[1] - w[1]).abs().max().item() 41 | deltas.append((i, v[0], delta, w[1].abs().mean().item())) 42 | del one 43 | del two 44 | gc.collect() 45 | 46 | deltas = sorted(deltas, key=lambda x: x[-2], reverse=True) 47 | print("Top 10 largest deltas:") 48 | for i, k, delta, value in deltas[:10]: 49 | print(f" shard {i} {k}: {delta} vs {value}") 50 | 51 | 52 | if __name__ == "__main__": 53 | main() 54 | -------------------------------------------------------------------------------- /subpop/utils/backoff.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import logging 4 | from time import time, sleep 5 | from typing import Any, Callable, Collection, Generic, List, Optional, Type, TypeVar 6 | 7 | Q = TypeVar("Q", bound=Callable[..., Any]) 8 | logger = logging.getLogger(__name__) 9 | 10 | def retry_with_exponential_backoff( 11 | initial_delay: 
float = 1, 12 | exponential_base: float = 2, 13 | jitter: bool = True, 14 | max_retries: int = 10, 15 | no_retry_on: Optional[Collection[Type[Exception]]] = None, 16 | ) -> Callable[[Q], Q]: 17 | """Retry a function with exponential backoff.""" 18 | 19 | def decorator(func: Q) -> Q: 20 | def wrapper(*args, **kwargs): 21 | # Initialize variables 22 | num_retries = 0 23 | delay = initial_delay 24 | error = None 25 | 26 | # Loop until a successful response or max_retries is hit or an exception is raised 27 | while num_retries <= max_retries: 28 | try: 29 | return func(*args, **kwargs) 30 | # Raise exceptions for any errors specified 31 | except Exception as e: 32 | if no_retry_on is not None and type(e) in no_retry_on: 33 | raise e 34 | # Sleep for the delay 35 | sleep(delay) 36 | # Increment the delay 37 | delay *= exponential_base * (1 + jitter * random.random()) 38 | # Set the error to the last exception 39 | error = e 40 | # Increment retries 41 | num_retries += 1 42 | logger.warning( 43 | f"Retrying {func.__name__} after error: {e} (retry {num_retries} of {max_retries})" 44 | ) 45 | if error is not None: 46 | raise error 47 | 48 | return wrapper 49 | 50 | return decorator -------------------------------------------------------------------------------- /subpop/train/utils/dataset_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 3 | 4 | import torch 5 | 6 | from subpop.train.data.concatenator import ConcatDataset 7 | from subpop.train.datasets import DATASET_PREPROC, DATALOADER_COLLATE_FUNC 8 | from subpop.train.utils.config_utils import get_dataloader_kwargs 9 | 10 | 11 | def get_preprocessed_dataset( 12 | tokenizer, dataset_config, split: str = "train", chat_template: bool = False 13 | ) -> torch.utils.data.Dataset: 14 | if not dataset_config.dataset in DATASET_PREPROC: 15 | raise NotImplementedError(f"{dataset_config.dataset} is not (yet) implemented") 16 | 17 | def get_split(): 18 | if split == "train": 19 | return dataset_config.train_split 20 | elif split == "valid": 21 | return dataset_config.valid_split 22 | elif split == "test": 23 | return dataset_config.test_split 24 | else: 25 | raise ValueError(f"Unknown split: {split}") 26 | 27 | return DATASET_PREPROC[dataset_config.dataset]( 28 | dataset_config, 29 | tokenizer, 30 | get_split(), 31 | chat_template=chat_template 32 | ) 33 | 34 | def get_custom_data_collator( 35 | dataset_processer, dataset_config 36 | ) -> torch.utils.data.Dataset: 37 | if not dataset_config.dataset in DATALOADER_COLLATE_FUNC: 38 | return None 39 | 40 | return DATALOADER_COLLATE_FUNC[dataset_config.dataset]( 41 | dataset_processer, 42 | dataset_config 43 | ) 44 | 45 | def get_dataloader(tokenizer, dataset_config, train_config, split: str = "train"): 46 | dataset = get_preprocessed_dataset(tokenizer, dataset_config, split) 47 | dl_kwargs = get_dataloader_kwargs(train_config, dataset, tokenizer, split) 48 | 49 | if split == "train" and train_config.batching_strategy == "packing": 50 | dataset = ConcatDataset(dataset, chunk_size=train_config.context_length) 51 | 52 | # Create data loader 53 | dataloader = torch.utils.data.DataLoader( 54 | dataset, 55 | num_workers=train_config.num_workers_dataloader, 56 | pin_memory=True, 57 | **dl_kwargs, 58 | ) 59 | return dataloader 60 | 
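The `dataset_utils.py` helpers above are the glue between a dataset config (defined in `subpop/train/configs/datasets.py`), the dataset loader it points to, and the PyTorch `DataLoader` used in training. They are normally driven by the training entry point (`scripts/experiment/run_finetune.py` → `subpop.train.finetuning.main`), which fills in configs from command-line overrides. The sketch below only illustrates how the pieces fit together; the model name and the concrete split paths are placeholders (the `{dataset_path}`/`{steering_type}` templates must already resolve to CSVs that exist on disk), so treat it as an assumption-laden example rather than a supported entry point.

```python
# Illustrative sketch only; model name and CSV paths are placeholders.
from transformers import AutoTokenizer

from subpop.train.configs.datasets import opnqa_steering_dataset
from subpop.train.configs.training import train_config
from subpop.train.utils.dataset_utils import get_dataloader

cfg = train_config()
cfg.model_name = "meta-llama/Llama-2-7b-hf"   # placeholder checkpoint
cfg.dataset = "opnqa_steering_dataset"        # key into DATASET_PREPROC
cfg.batching_strategy = "padding"
cfg.batch_size_training = 4

ds_cfg = opnqa_steering_dataset()
# Resolve the {dataset_path}/{steering_type} templates to concrete files
# (hypothetical paths; they must exist before this runs).
ds_cfg.train_split = "subpop/train/datasets/subpop-train/opnqa_500_QA_train.csv"
ds_cfg.valid_split = "subpop/train/datasets/subpop-train/opnqa_500_QA_val.csv"

tokenizer = AutoTokenizer.from_pretrained(cfg.model_name)
tokenizer.pad_token = tokenizer.eos_token     # Llama tokenizers ship without a pad token

train_loader = get_dataloader(tokenizer, ds_cfg, cfg, split="train")
print(f"train batches: {len(train_loader)}")
```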
-------------------------------------------------------------------------------- /subpop/train/data/sampler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 3 | 4 | import random 5 | from itertools import islice 6 | 7 | import numpy as np 8 | import torch 9 | 10 | 11 | class LengthBasedBatchSampler(torch.utils.data.BatchSampler): 12 | def __init__(self, data_source, batch_size: int, drop_last: bool, shuffle: bool=True) -> None: 13 | if isinstance(next(iter(data_source)), dict): 14 | first_key = next(iter(next(iter(data_source)).keys())) 15 | self.lengths = [len(d[first_key]) for d in data_source] 16 | else: 17 | self.lengths = [len(d) for d in data_source] 18 | self.batch_size = batch_size 19 | self.drop_last = drop_last 20 | self.shuffle = shuffle 21 | 22 | def __iter__(self): 23 | ids = np.argsort(self.lengths, kind='mergesort') 24 | if self.drop_last: 25 | ids = ids[:len(ids) // self.batch_size * self.batch_size] 26 | 27 | batches = [ids[i:i+self.batch_size] for i in range(0, len(ids), self.batch_size)] 28 | 29 | if self.shuffle: 30 | random.shuffle(batches) 31 | 32 | for b in batches: 33 | yield b 34 | 35 | def __len__(self): 36 | if self.drop_last: 37 | return len(self.lengths) // self.batch_size 38 | else: 39 | return len(self.lengths) // self.batch_size + (len(self.lengths) % self.batch_size > 0) 40 | 41 | 42 | class DistributedLengthBasedBatchSampler(torch.utils.data.BatchSampler): 43 | def __init__(self, data_source, batch_size: int, num_replicas: int, rank: int, shuffle: bool = True, seed: int = 0) -> None: 44 | random.seed(seed) 45 | self.batch_sampler = LengthBasedBatchSampler( 46 | data_source, batch_size=batch_size, drop_last=False, shuffle=shuffle 47 | ) 48 | self.num_replicas = num_replicas 49 | self.rank = rank 50 | 51 | def __iter__(self): 52 | max_length = len(self.batch_sampler) // self.num_replicas * self.num_replicas 53 | return islice(self.batch_sampler, self.rank, max_length, self.num_replicas) 54 | 55 | def __len__(self): 56 | return len(self.batch_sampler) // self.num_replicas 57 | -------------------------------------------------------------------------------- /subpop/train/configs/datasets.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 
3 | 4 | from dataclasses import dataclass 5 | 6 | 7 | @dataclass 8 | class samsum_dataset: 9 | dataset: str = "samsum_dataset" 10 | train_split: str = "train" 11 | test_split: str = "validation" 12 | trust_remote_code: bool = False 13 | 14 | 15 | @dataclass 16 | class grammar_dataset: 17 | dataset: str = "grammar_dataset" 18 | train_split: str = "src/llama_recipes/datasets/grammar_dataset/gtrain_10k.csv" 19 | test_split: str = "src/llama_recipes/datasets/grammar_dataset/grammar_validation.csv" 20 | 21 | 22 | @dataclass 23 | class alpaca_dataset: 24 | dataset: str = "alpaca_dataset" 25 | train_split: str = "train" 26 | test_split: str = "val" 27 | data_path: str = "src/llama_recipes/datasets/alpaca_data.json" 28 | 29 | @dataclass 30 | class custom_dataset: 31 | dataset: str = "custom_dataset" 32 | file: str = "recipes/quickstart/finetuning/datasets/custom_dataset.py" 33 | train_split: str = "train" 34 | test_split: str = "validation" 35 | data_path: str = "" 36 | 37 | @dataclass 38 | class llamaguard_toxicchat_dataset: 39 | dataset: str = "llamaguard_toxicchat_dataset" 40 | train_split: str = "train" 41 | test_split: str = "test" 42 | 43 | 44 | @dataclass 45 | class opnqa_steering_dataset: 46 | dataset: str = "opnqa_steering_dataset" 47 | file: str = "subpop/train/datasets/opinionqa_dataset.py:get_preprocessed_opinionqa_ce_or_wd_loss" 48 | train_split: str = "subpop/train/datasets/{dataset_path}/opnqa_500_{steering_type}_train.csv" 49 | valid_split: str = "subpop/train/datasets/{dataset_path}/opnqa_500_{steering_type}_val.csv" 50 | test_split: str = "subpop/train/datasets/{dataset_path}/opnqa_500_{steering_type}_test.csv" 51 | 52 | @dataclass 53 | class opnqa_single_demographic_dataset: 54 | dataset: str = "opnqa_single_demographic_dataset" 55 | file: str = "subpop/train/datasets/opinionqa_dataset.py:get_preprocessed_opinionqa_ce_or_wd_loss" 56 | train_split: str = "subpop/train/datasets/{dataset_path}/opnqa_500_{attribute}_{group}_{steering_type}_train.csv" 57 | valid_split: str = "subpop/train/datasets/{dataset_path}/opnqa_500_{attribute}_{group}_{steering_type}_val.csv" 58 | test_split: str = "subpop/train/datasets/{dataset_path}/opnqa_500_{attribute}_{group}_{steering_type}_test.csv" 59 | -------------------------------------------------------------------------------- /subpop/train/datasets/alpaca_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 3 | 4 | # For dataset details visit: https://crfm.stanford.edu/2023/03/13/alpaca.html 5 | 6 | import copy 7 | import json 8 | 9 | import torch 10 | from torch.utils.data import Dataset 11 | 12 | 13 | PROMPT_DICT = { 14 | "prompt_input": ( 15 | "Below is an instruction that describes a task, paired with an input that provides further context. " 16 | "Write a response that appropriately completes the request.\n\n" 17 | "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:" 18 | ), 19 | "prompt_no_input": ( 20 | "Below is an instruction that describes a task. 
" 21 | "Write a response that appropriately completes the request.\n\n" 22 | "### Instruction:\n{instruction}\n\n### Response:" 23 | ), 24 | } 25 | 26 | class InstructionDataset(Dataset): 27 | def __init__(self, dataset_config, tokenizer, partition="train"): 28 | self.ann = json.load(open(dataset_config.data_path)) 29 | # Use 5% of the dataset for evaluation 30 | eval_length = int(len(self.ann)/20) 31 | if partition == "train": 32 | self.ann = self.ann[eval_length:] 33 | else: 34 | self.ann = self.ann[:eval_length] 35 | 36 | self.tokenizer = tokenizer 37 | 38 | def __len__(self): 39 | return len(self.ann) 40 | 41 | def __getitem__(self, index): 42 | IGNORE_INDEX = -100 # The default setting in CrossEntropyLoss 43 | 44 | 45 | ann = self.ann[index] 46 | if ann.get("input", "") == "": 47 | prompt = PROMPT_DICT["prompt_no_input"].format_map(ann) 48 | else: 49 | prompt = PROMPT_DICT["prompt_input"].format_map(ann) 50 | example = prompt + ann["output"] 51 | prompt = torch.tensor( 52 | self.tokenizer.encode(prompt), dtype=torch.int64 53 | ) 54 | example = self.tokenizer.encode(example) 55 | example.append(self.tokenizer.eos_token_id) 56 | example = torch.tensor( 57 | example, dtype=torch.int64 58 | ) 59 | labels = copy.deepcopy(example) 60 | labels[: len(prompt)] = -1 61 | example_mask = example.ge(0) 62 | label_mask = labels.ge(0) 63 | example[~example_mask] = 0 64 | labels[~label_mask] = IGNORE_INDEX 65 | 66 | return { 67 | "input_ids": example.tolist(), 68 | "labels": labels.tolist(), 69 | "attention_mask":example_mask.tolist(), 70 | } 71 | -------------------------------------------------------------------------------- /subpop/train/datasets/custom_dataset.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from pathlib import Path 3 | from transformers import DataCollatorForSeq2Seq 4 | 5 | def load_module_from_py_file(py_file: str) -> object: 6 | """ 7 | This method loads a module from a py file which is not in the Python path 8 | """ 9 | module_name = Path(py_file).name 10 | loader = importlib.machinery.SourceFileLoader(module_name, py_file) 11 | spec = importlib.util.spec_from_loader(module_name, loader) 12 | module = importlib.util.module_from_spec(spec) 13 | 14 | loader.exec_module(module) 15 | 16 | return module 17 | 18 | 19 | def get_custom_dataset(dataset_config, tokenizer, split: str, chat_template: bool=False): 20 | if ":" in dataset_config.file: 21 | module_path, func_name = dataset_config.file.split(":") 22 | else: 23 | module_path, func_name = dataset_config.file, "get_custom_dataset" 24 | 25 | if not module_path.endswith(".py"): 26 | raise ValueError(f"Dataset file {module_path} is not a .py file.") 27 | 28 | module_path = Path(module_path) 29 | if not module_path.is_file(): 30 | raise FileNotFoundError(f"Dataset py file {module_path.as_posix()} does not exist or is not a file.") 31 | 32 | module = load_module_from_py_file(module_path.as_posix()) 33 | try: 34 | return getattr(module, func_name)(dataset_config, tokenizer, split, chat_template) 35 | except AttributeError as e: 36 | print(f"It seems like the given method name ({func_name}) is not present in the dataset .py file ({module_path.as_posix()}).") 37 | raise e 38 | 39 | def get_data_collator(dataset_processer,dataset_config): 40 | if ":" in dataset_config.file: 41 | module_path, func_name = dataset_config.file.split(":") 42 | else: 43 | module_path, func_name = dataset_config.file, "get_data_collator" 44 | 45 | if not module_path.endswith(".py"): 46 | raise 
ValueError(f"Dataset file {module_path} is not a .py file.") 47 | 48 | module_path = Path(module_path) 49 | if not module_path.is_file(): 50 | raise FileNotFoundError(f"Dataset py file {module_path.as_posix()} does not exist or is not a file.") 51 | 52 | module = load_module_from_py_file(module_path.as_posix()) 53 | try: 54 | return getattr(module, func_name)(dataset_processer) 55 | except AttributeError as e: 56 | print(f"Can not find the custom data_collator in the dataset.py file ({module_path.as_posix()}).") 57 | print("Using the default data_collator instead.") 58 | return None 59 | 60 | class NoLabelDataCollatorForSeq2Seq(DataCollatorForSeq2Seq): 61 | def __call__(self, batch): 62 | batch = super().__call__(batch) 63 | if "labels" in batch: 64 | del batch["labels"] 65 | return batch 66 | 67 | def custom_collator_no_labels(dataset_processer, dataset_config): 68 | return NoLabelDataCollatorForSeq2Seq(dataset_processer) -------------------------------------------------------------------------------- /scripts/data_generation/refine_question.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import pathlib 4 | from concurrent.futures import ProcessPoolExecutor, as_completed 5 | from typing import Dict, List, Tuple 6 | 7 | from subpop.utils.surveydata_utils import ActualSurveyData 8 | 9 | 10 | def process_qkey(qkey: str, wave_number: int) -> Tuple[str, Dict, str]: 11 | """ 12 | Args: 13 | qkey: question identifier string 14 | wave_number: wave number the question belongs to 15 | Returns: 16 | tuple: (qkey, refined_data, error flag) 17 | """ 18 | surveydata = ActualSurveyData( 19 | wave_list=[wave_number], 20 | bank_qkeys=set(), 21 | query_qkeys=set(), 22 | data_dir=ROOT_DIR / "data" / "subpop-train", 23 | ) 24 | try: 25 | print(f"--> process_qkey: working on qkey {qkey}.") 26 | original_qbody = surveydata.fetch_question_body(qkey) 27 | refined_qbody = surveydata.refine_question_body(qkey) 28 | return ( 29 | qkey, 30 | {"original_qbody": original_qbody, "refined_qbody": refined_qbody}, 31 | None, 32 | ) 33 | except Exception as e: 34 | print(f"--> process_qkey: failed to work on qkey {qkey}.") 35 | return qkey, None, str(e) 36 | 37 | 38 | if __name__ == "__main__": 39 | 40 | ROOT_DIR = pathlib.Path(__file__).resolve().parents[2] 41 | input_dir = ROOT_DIR / "data" / "subpop-train" 42 | os.makedirs(input_dir / "processed", exist_ok=True) 43 | 44 | with open(input_dir / f"subpop-train_parsed-qkeys.json", "r") as f: 45 | qkey_dict = json.load(f) 46 | for wave_number, qkeys in qkey_dict.items(): 47 | qkey_dict[wave_number] = list(set(qkeys)) 48 | 49 | refined_qbody_dict = {} 50 | error_qkeys_list = [] 51 | 52 | for wave_idx, qkeys_list in qkey_dict.items(): 53 | print(f"--> main: working on wave {wave_idx}.") 54 | wave_number = int(wave_idx.replace("W", "")) 55 | 56 | with ProcessPoolExecutor() as executor: 57 | futures = { 58 | executor.submit(process_qkey, qkey, wave_number): qkey 59 | for qkey in qkeys_list 60 | } 61 | for future in as_completed(futures): 62 | qkey = futures[future] 63 | try: 64 | qkey, refined_data, error = future.result() 65 | if refined_data: 66 | refined_qbody_dict[qkey] = refined_data 67 | if error: 68 | error_qkeys_list.append(qkey) 69 | except Exception as e: 70 | print(f"--> main: unhandled exception for qkey {qkey}: {e}") 71 | error_qkeys_list.append(qkey) 72 | 73 | with open(input_dir / "processed" / f"refined_qkey_dict.json", "w") as f: 74 | json.dump(refined_qbody_dict, f, indent=4) 75 | if 
error_qkeys_list: 76 | with open(input_dir / "processed" / f"error_qkeys_list.json", "w") as f: 77 | json.dump(error_qkeys_list, f, indent=4) 78 | -------------------------------------------------------------------------------- /subpop/train/utils/plot_metrics.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 3 | 4 | import json 5 | import matplotlib.pyplot as plt 6 | import argparse 7 | import os 8 | 9 | def plot_metric(data, metric_name, x_label, y_label, title, colors): 10 | plt.figure(figsize=(7, 6)) 11 | 12 | plt.plot(data[f'train_epoch_{metric_name}'], label=f'Train Epoch {metric_name.capitalize()}', color=colors[0]) 13 | plt.plot(data[f'val_epoch_{metric_name}'], label=f'Validation Epoch {metric_name.capitalize()}', color=colors[1]) 14 | plt.xlabel(x_label) 15 | plt.ylabel(y_label) 16 | plt.title(f'Train and Validation Epoch {title}') 17 | plt.legend() 18 | plt.tight_layout() 19 | 20 | def plot_single_metric_by_step(data, metric_name, x_label, y_label, title, color): 21 | plt.plot(data[f'{metric_name}'], label=f'{title}', color=color) 22 | plt.xlabel(x_label) 23 | plt.ylabel(y_label) 24 | plt.title(title) 25 | plt.legend() 26 | plt.tight_layout() 27 | 28 | def plot_metrics_by_step(data, metric_name, x_label, y_label, colors): 29 | plt.figure(figsize=(14, 6)) 30 | 31 | plt.subplot(1, 2, 1) 32 | plot_single_metric_by_step(data, f'train_step_{metric_name}', x_label, y_label, f'Train Step {metric_name.capitalize()}', colors[0]) 33 | plt.subplot(1, 2, 2) 34 | plot_single_metric_by_step(data, f'val_step_{metric_name}', x_label, y_label, f'Validation Step {metric_name.capitalize()}', colors[1]) 35 | plt.tight_layout() 36 | 37 | 38 | def plot_metrics(file_path): 39 | if not os.path.exists(file_path): 40 | print(f"File {file_path} does not exist.") 41 | return 42 | 43 | with open(file_path, 'r') as f: 44 | try: 45 | data = json.load(f) 46 | except json.JSONDecodeError: 47 | print("Invalid JSON file.") 48 | return 49 | 50 | directory = os.path.dirname(file_path) 51 | filename_prefix = os.path.basename(file_path).split('.')[0] 52 | 53 | plot_metric(data, 'loss', 'Epoch', 'Loss', 'Loss', ['b', 'r']) 54 | plt.savefig(os.path.join(directory, f"{filename_prefix}_train_and_validation_loss.png")) 55 | plt.close() 56 | 57 | plot_metric(data, 'perplexity', 'Epoch', 'Perplexity', 'Perplexity', ['g', 'm']) 58 | plt.savefig(os.path.join(directory, f"{filename_prefix}_train_and_validation_perplexity.png")) 59 | plt.close() 60 | 61 | plot_metrics_by_step(data, 'loss', 'Step', 'Loss', ['b', 'r']) 62 | plt.savefig(os.path.join(directory, f"{filename_prefix}_train_and_validation_loss_by_step.png")) 63 | plt.close() 64 | 65 | plot_metrics_by_step(data, 'perplexity', 'Step', 'Loss', ['g', 'm']) 66 | plt.savefig(os.path.join(directory, f"{filename_prefix}_train_and_validation_perplexity_by_step.png")) 67 | plt.close() 68 | 69 | if __name__ == "__main__": 70 | parser = argparse.ArgumentParser(description='Plot metrics from JSON file.') 71 | parser.add_argument('--file_path', required=True, type=str, help='Path to the metrics JSON file.') 72 | args = parser.parse_args() 73 | 74 | plot_metrics(args.file_path) 75 | -------------------------------------------------------------------------------- /data/subpop-eval/subpop-eval_parsed-qkeys.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "2022": [ 3 | "natenvir", 4 | "natheal", 5 | "nateduc", 6 | "natrace", 7 | "natarms", 8 | "natfare", 9 | "natroad", 10 | "natsoc", 11 | "natmass", 12 | "natsci", 13 | "natenrgy", 14 | "uswary", 15 | "prayer", 16 | "courts", 17 | "discaffw", 18 | "discaffm", 19 | "fechld", 20 | "fepresch", 21 | "fefam", 22 | "incom16", 23 | "dwelown16", 24 | "mawrkgrw", 25 | "marital", 26 | "wrkgovt1", 27 | "wrkgovt2", 28 | "wksub1", 29 | "conbus", 30 | "conclerg", 31 | "coneduc", 32 | "confed", 33 | "confinan", 34 | "conjudge", 35 | "conlegis", 36 | "conmedic", 37 | "conpress", 38 | "consci", 39 | "contv", 40 | "vetyears", 41 | "joblose", 42 | "jobfind", 43 | "unemp", 44 | "union1", 45 | "spkath", 46 | "colath", 47 | "spkrac", 48 | "colrac", 49 | "librac", 50 | "spkcom", 51 | "colcom", 52 | "libcom", 53 | "polhitok", 54 | "polabuse", 55 | "polattak", 56 | "grass", 57 | "gunlaw", 58 | "owngun", 59 | "hunt1", 60 | "class", 61 | "satfin", 62 | "finalter", 63 | "finrela", 64 | "racdif1", 65 | "racdif2", 66 | "racdif4", 67 | "wlthblks", 68 | "wlthhsps", 69 | "racwork", 70 | "letin1a", 71 | "getahead", 72 | "aged", 73 | "parsol", 74 | "kidssol", 75 | "spanking", 76 | "divlaw", 77 | "pillok", 78 | "xmarsex", 79 | "homosex", 80 | "discaff", 81 | "abnomore", 82 | "abhlth", 83 | "abpoor", 84 | "abrape", 85 | "abany", 86 | "suicide1", 87 | "suicide2", 88 | "suicide3", 89 | "suicide4", 90 | "pornlaw", 91 | "fair", 92 | "helpful", 93 | "trust", 94 | "tax", 95 | "pres16", 96 | "news", 97 | "attend", 98 | "pray", 99 | "bible", 100 | "reborn", 101 | "savesoul", 102 | "relpersn", 103 | "sprtprsn", 104 | "granborn", 105 | "dwelown", 106 | "health", 107 | "webmob", 108 | "richwork", 109 | "natcity", 110 | "fehire", 111 | "conlabor", 112 | "happy", 113 | "satjob", 114 | "cappun", 115 | "racdif3", 116 | "marhomo", 117 | "abdefect", 118 | "vote16", 119 | "if16who", 120 | "born", 121 | "life", 122 | "natspac", 123 | "natdrug", 124 | "nataid", 125 | "natpark", 126 | "natchld", 127 | "fepol", 128 | "conarmy", 129 | "wlthwhts", 130 | "sexeduc", 131 | "absingle", 132 | "letdie1", 133 | "postlife", 134 | "othlang", 135 | "xmovie" 136 | ] 137 | } -------------------------------------------------------------------------------- /subpop/utils/data_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pathlib 3 | from datetime import datetime 4 | from dataclasses import dataclass 5 | from typing import Union 6 | 7 | import pandas as pd 8 | 9 | 10 | def read_csv(path: str) -> pd.DataFrame: 11 | df = pd.read_csv(path) 12 | df = df.dropna() 13 | df = df.reset_index(drop=True) 14 | return df 15 | 16 | 17 | # def df_to_dataclass_list(df: pd.DataFrame, dataclass_type: dataclass): 18 | # # Create list of dataclass objects by using df.itertuples 19 | # return [ 20 | # dataclass_type(**row._asdict()) 21 | # for row in df.itertuples(index=False, name="Pandas") 22 | # ] 23 | 24 | 25 | def get_config_path(path: Union[str, pathlib.Path]) -> pathlib.Path: 26 | """ 27 | Get the path to the config file 28 | 29 | Args: 30 | path (Union[str, pathlib.Path]): Path to the config file 31 | 32 | Returns: 33 | pathlib.Path: Path to the config file 34 | """ 35 | if isinstance(path, str): 36 | path = pathlib.Path(path) 37 | 38 | return pathlib.Path(__file__).resolve().parents[3] / "configs" / path 39 | 40 | 41 | def save_result_to_csv(output: dict, output_data_path: str): 42 | df = pd.DataFrame.from_dict(output, 
orient="index") 43 | df.reset_index(inplace=True, drop=True) 44 | df.to_csv(output_data_path) 45 | 46 | 47 | def save_result_to_pkl(output: dict, output_data_path: str): 48 | df = pd.DataFrame.from_dict(output, orient="index") 49 | df.to_pickle(output_data_path) 50 | 51 | 52 | def publish_result(output: Union[dict, pd.DataFrame], publish_dir: str, filename: str): 53 | if not os.path.isdir(publish_dir): 54 | os.mkdir(publish_dir) 55 | publish_result_path = os.path.join(publish_dir, filename) 56 | if os.path.exists(publish_result_path): 57 | # Append timestamp to filename 58 | timestamp = datetime.now().strftime("%Y%m%d%H%M%S") 59 | 60 | if filename.endswith(".csv"): 61 | filename = filename.split(".")[0] 62 | new_filename = f"{filename}_{timestamp}.csv" 63 | elif filename.endswith(".pkl") or filename.endswith(".pickle"): 64 | filename = filename.split(".")[0] 65 | new_filename = f"{filename}_{timestamp}.pkl" 66 | 67 | publish_result_path = os.path.join(publish_dir, new_filename) 68 | 69 | if isinstance(output, dict): 70 | if publish_result_path.endswith(".csv"): 71 | save_result_to_csv(output, publish_result_path) 72 | elif publish_result_path.endswith(".pkl"): 73 | save_result_to_pkl(output, publish_result_path) 74 | else: 75 | raise ValueError( 76 | f"Unsupported file format: {os.path.basename(publish_result_path)}" 77 | ) 78 | 79 | elif isinstance(output, pd.DataFrame): 80 | if publish_result_path.endswith(".csv"): 81 | output.to_csv(publish_result_path) 82 | elif publish_result_path.endswith(".pkl"): 83 | output.to_pickle(publish_result_path) 84 | else: 85 | raise ValueError( 86 | f"Unsupported file format: {os.path.basename(publish_result_path)}" 87 | ) 88 | 89 | else: 90 | raise ValueError(f"Unsupported output type: {type(output)}") 91 | 92 | os.chmod(publish_result_path, 0o777) 93 | -------------------------------------------------------------------------------- /subpop/train/configs/training.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 
3 | 4 | from dataclasses import dataclass 5 | 6 | 7 | @dataclass 8 | class train_config: 9 | model_name: str="PATH/TO/MODEL/NAME" 10 | tokenizer_name: str=None 11 | enable_fsdp: bool=True # shards model parameters, optimizer states and gradients across DDP ranks 12 | low_cpu_fsdp: bool=True # saves cpu memory by loading pretrained model on rank0 only 13 | run_validation: bool=True # whether to run validation every epoch 14 | run_test: bool=False # whether to run test when best validation loss is found 15 | batch_size_training: int=32 16 | batching_strategy: str="padding" 17 | context_length: int=4096 18 | gradient_accumulation_steps: int=1 19 | gradient_clipping: bool = False 20 | gradient_clipping_threshold: float = 1.0 21 | num_epochs: int=3 22 | max_train_step: int=0 23 | max_eval_step: int=0 24 | num_workers_dataloader: int=8 25 | weight_decay: float=0.0 26 | 27 | # learning rate and scheduler hyperparameter 28 | lr: float=1e-4 29 | which_scheduler: str="step" # step, cosine 30 | gamma: float= 0.85 # multiplicatively decay the learning rate each step, so that per epoch LR is multiplied by gamma 31 | warmup_ratio: float=0.1 # ratio of total steps to warmup to the total number of steps 32 | 33 | seed: int=42 34 | use_fp16: bool=False 35 | mixed_precision: bool=True 36 | val_batch_size: int=1 37 | dataset = "samsum_dataset" 38 | peft_method: str = "lora" # None, llama_adapter (Caution: llama_adapter is currently not supported with FSDP) 39 | use_peft: bool=True # use parameter efficient fine tuning 40 | from_peft_checkpoint: str="" # if not empty and use_peft=True, will load the peft checkpoint and resume the fine-tuning on that checkpoint 41 | output_dir: str = "PATH/TO/OUTPUT/DIRECTORY" 42 | freeze_layers: bool = False 43 | num_freeze_layers: int = 1 44 | quantization: str = None 45 | one_gpu: bool = False 46 | save_model: bool = True 47 | dist_checkpoint_root_folder: str="PATH/TO/FULL/FINETUNING/CHECKPOINT" # will be used if using FSDP 48 | dist_checkpoint_folder: str="fine-tuned" # will be used if using FSDP 49 | save_optimizer: bool=True # will be used if using FSDP 50 | use_fast_kernels: bool = False # Enable using SDPA from PyTroch Accelerated Transformers, make use Flash Attention and Xformer memory-efficient kernels 51 | use_wandb: bool = True # Enable wandb for experiment tracking 52 | save_metrics: bool = True # saves training metrics to a json file for later plotting 53 | flop_counter: bool = False # Enable flop counter to measure model throughput, can not be used with pytorch profiler at the same time. 54 | flop_counter_start: int = 3 # The step to start profiling, default is 3, which means after 3 steps of warmup stage, the profiler will start to count flops. 55 | use_profiler: bool = False # Enable pytorch profiler, can not be used with flop counter at the same time. 56 | profiler_dir: str = "PATH/to/save/profiler/results" # will be used if using profiler 57 | loss_function_type: str = "ce" # ce, wd. Determines the loss function to be used for training. 58 | dataset_path: str = "subpop-train" # path to the dataset 59 | steering_type: str = "QA" # QA, BIO, PORTRAY: steering type for the input prompt. 
By default, QA 60 | attribute: str = "None" # only used when the finetuning is on a single subpopulation data 61 | group: str = "None" # only used when the finetuning is on a single subpopulation data 62 | model_nickname: str = "MODEL_NICKNAME" # model nickname, for example, llama-2-7b-base 63 | is_chat: bool = False # True if the base model is a chat model -------------------------------------------------------------------------------- /subpop/train/utils/flop_utils.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Optional, Union 2 | import time 3 | import torch 4 | from torch.utils.flop_counter import FlopCounterMode 5 | 6 | 7 | class FlopMeasure(FlopCounterMode): 8 | """ 9 | ``FlopMeasure`` is a customized context manager that counts the number of 10 | flops within its context. It is based on ``FlopCounterMode``, with an additional warmup mechanism (``warmup_step`` is decremented by ``step()``) so that the flop counting 11 | will only start after the warmup stage. 12 | It also supports hierarchical output by passing a module (or list of modules) to FlopCounterMode on construction. 13 | 14 | Example usage 15 | 16 | .. code-block:: python 17 | 18 | model = ... 19 | flop_counter = FlopMeasure(model, rank=0, warmup_step=3) 20 | for step, batch in enumerate(dataloader): 21 | with flop_counter: 22 | model(batch) 23 | flop_counter.step() 24 | """ 25 | 26 | def __init__( 27 | self, 28 | mods: Optional[Union[torch.nn.Module, List[torch.nn.Module]]] = None, 29 | depth: int = 2, 30 | display: bool = True, 31 | custom_mapping: Dict[Any, Any] = None, 32 | rank=None, 33 | warmup_step: int = 3, 34 | ): 35 | super().__init__(mods, depth, display, custom_mapping) 36 | self.rank = rank 37 | self.warmup_step = warmup_step 38 | self.start_time = 0 39 | self.end_time = 0 40 | 41 | def step(self): 42 | # decrease the warmup step by 1 on every step, so that flop counting starts when warmup_step reaches 0. Stop decreasing when warmup_step reaches -1. 43 | if self.warmup_step >= 0: 44 | self.warmup_step -= 1 45 | if self.warmup_step == 0 and self.start_time == 0: 46 | self.start_time = time.time() 47 | elif self.warmup_step == -1 and self.start_time != 0 and self.end_time == 0: 48 | self.end_time = time.time() 49 | def __enter__(self): 50 | if self.warmup_step == 0: 51 | self.start_time = time.time() 52 | super().__enter__() 53 | return self 54 | def is_done(self): 55 | return self.warmup_step == -1 56 | def get_total_flops(self): 57 | return super().get_total_flops() 58 | def get_flops_per_sec(self): 59 | if self.start_time == 0 or self.end_time == 0: 60 | print("Warning: flop count did not finish correctly") 61 | return 0 62 | return super().get_total_flops()/ (self.end_time - self.start_time) 63 | def get_table(self, depth=2): 64 | return super().get_table(depth) 65 | 66 | def __exit__(self, *args): 67 | if self.get_total_flops() == 0: 68 | print( 69 | "Warning: did not record any flops this time. 
Skipping the flop report" 70 | ) 71 | else: 72 | if self.display: 73 | if self.rank is None or self.rank == 0: 74 | print("Total time used in this flop counting step is: {}".format(self.end_time - self.start_time)) 75 | print("The total TFlop per second is: {}".format(self.get_flops_per_sec() / 1e12)) 76 | print("The tflop_count table is below:") 77 | print(self.get_table(self.depth)) 78 | # Disable the display feature so that we don't print the table again 79 | self.display = False 80 | super().__exit__(*args) 81 | 82 | def __torch_dispatch__(self, func, types, args=(), kwargs=None): 83 | # when warmup_step is 0, count the flops and return the original output 84 | if self.warmup_step == 0: 85 | return super().__torch_dispatch__(func, types, args, kwargs) 86 | # otherwise, just return the original output 87 | return func(*args, **kwargs) 88 | -------------------------------------------------------------------------------- /subpop/train/utils/fsdp_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 3 | from torch.distributed._tensor.device_mesh import init_device_mesh 4 | import os 5 | 6 | def fsdp_auto_wrap_policy(model, transformer_layer_names): 7 | import functools 8 | 9 | from torch.distributed.fsdp.wrap import _or_policy, lambda_auto_wrap_policy, transformer_auto_wrap_policy 10 | 11 | def lambda_policy_fn(module): 12 | if ( 13 | len(list(module.named_children())) == 0 14 | and getattr(module, "weight", None) is not None 15 | and module.weight.requires_grad 16 | ): 17 | return True 18 | return False 19 | 20 | lambda_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn) 21 | transformer_wrap_policy = functools.partial( 22 | transformer_auto_wrap_policy, 23 | transformer_layer_cls=set(transformer_layer_names) 24 | ) 25 | 26 | auto_wrap_policy = functools.partial(_or_policy, policies=[lambda_policy, transformer_wrap_policy]) 27 | return auto_wrap_policy 28 | 29 | 30 | def hsdp_device_mesh(replica_group_size, sharding_group_size, device=None): 31 | """ 32 | Initializes a device mesh for use with Hybrid Sharding strategy in FSDP (HSDP) training. 33 | 34 | This function requires explicit sizes for replica and sharding groups to accommodate models 35 | whose GPU fit is unknown, providing flexibility in distributed training setups. 36 | 37 | Args: 38 | replica_group_size (int): The size of each replica group. Must be provided to ensure 39 | the model fits within the available resources. 40 | sharding_group_size (int): The size of each sharding group that the model can fit. Must be provided to 41 | ensure the correct distribution of model parameters. 42 | device (str, optional): The device to use (e.g., "cuda:0"). If None, defaults to "cuda" 43 | with the local rank as the device index. 44 | 45 | Returns: 46 | A device mesh object compatible with FSDP. 47 | 48 | Raises: 49 | ValueError: If replica_group_size or sharding_group_size are not provided, or if the 50 | world size is not evenly divisible by the sharding group size. 51 | RuntimeError: If a valid device mesh cannot be created. 
52 | 53 | Usage: 54 | If your model fits on 4 GPUS, and you have 3 nodes of 8 GPUs, then: 55 | Sharding_Group_Size = 4 56 | Replica_Groups_Size = (24 total gpus, 4 per sharding group) = 6 Replica Groups 57 | >>> device_mesh = initialize_device_mesh(replica_group_size, sharding_group_size) 58 | >>> sharded_model = FSDP(model, device_mesh=device_mesh, ...) 59 | """ 60 | 61 | if replica_group_size is None or sharding_group_size is None: 62 | raise ValueError("Both replica_group_size and sharding_group_size must be provided.") 63 | 64 | local_rank = int(os.getenv("LOCAL_RANK", "0")) 65 | world_size = int(os.getenv("WORLD_SIZE", "1")) 66 | 67 | device = device or f"cuda" 68 | 69 | if world_size % sharding_group_size != 0: 70 | raise ValueError(f"World size {world_size} is not evenly divisible by " 71 | f"sharding group size {sharding_group_size}.") 72 | 73 | if (world_size // sharding_group_size) % replica_group_size != 0: 74 | raise ValueError(f"The calculated number of replica groups is not evenly divisible by " 75 | f"replica_group_size {replica_group_size}.") 76 | 77 | device_mesh = init_device_mesh(device, (replica_group_size, sharding_group_size)) 78 | if device_mesh is None: 79 | raise RuntimeError("Failed to create a valid device mesh.") 80 | 81 | return device_mesh 82 | -------------------------------------------------------------------------------- /subpop/train/data/llama_guard/finetuning_data_formatter_example.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 3 | 4 | from finetuning_data_formatter import TrainingExample, Guidelines, Category, LlamaGuardPromptConfigs, LlamaGuardGenerationConfigs, ExplanationPosition, AugmentationConfigs, FormatterConfigs, create_formatted_finetuning_examples 5 | 6 | # Define a list of guidelines and training examples to match 7 | guidelines = Guidelines( 8 | categories=[ 9 | Category(name="Personal Information", 10 | description="Information that can identify an individual"), 11 | Category(name="Location", 12 | description="Information about where someone lives or works") 13 | ], 14 | category_code_prefix="O" 15 | ) 16 | 17 | training_examples = [ 18 | TrainingExample( 19 | prompt="What is the home address of John Smith?", 20 | response="Jane Doe's phone number is 555-555-5555.", 21 | violated_category_codes=["O1"], 22 | label="unsafe", 23 | explanation="The response contains personal information." 24 | ), 25 | TrainingExample( 26 | prompt="What is the home address of John Smith?", 27 | response="The home address of John Smith is 123 Main St, Anytown USA", 28 | violated_category_codes=["O2"], 29 | label="unsafe", 30 | explanation="The response reveals the user's location." 31 | ), 32 | TrainingExample( 33 | prompt="What is the home address of John Smith?", 34 | response="N/A", 35 | violated_category_codes=["O2"], 36 | label="unsafe", 37 | explanation="The prompt asks for a home address." 38 | ), 39 | TrainingExample( 40 | prompt="What is your favorite color?", 41 | response="My favorite color is blue.", 42 | violated_category_codes=[], 43 | label="safe", 44 | explanation="The response does not contain any sensitive information." 
45 | ) 46 | ] 47 | 48 | # Setting up configs 49 | llama_guard_prompt_configs = LlamaGuardPromptConfigs( 50 | instructions_format_string="""[INST] Task: Check if there is unsafe content in 'User' messages in conversations according our safety policy with the below categories. 51 | 52 | 53 | {guidelines} 54 | 55 | 56 | 57 | 58 | {conversation} 59 | 60 | 61 | 62 | Provide your safety assessment for $agent_type in the above conversation: 63 | - First line must read 'safe' or 'unsafe'. 64 | - If unsafe, a second line must include a comma-separated list of violated categories. [/INST]""", 65 | should_include_category_descriptions=True, 66 | should_shuffle_category_codes=True 67 | ) 68 | 69 | llama_guard_generation_configs = LlamaGuardGenerationConfigs( 70 | should_list_violated_codes=True, 71 | explanation_position=ExplanationPosition.AFTER_DECISION 72 | ) 73 | 74 | augmentation_configs = AugmentationConfigs( 75 | should_add_examples_with_dropped_nonviolated_prompt_categories=True, 76 | should_add_examples_with_dropped_violated_and_nonviolated_prompt_categories=True, 77 | explanation_for_augmentation_with_dropped_violated_and_nonviolated_prompt_categories="Response is safe because any riskiness in it is related to categories we're not trying to detect." 78 | ) 79 | 80 | formatter_configs = FormatterConfigs( 81 | guidelines=guidelines, 82 | llama_guard_prompt_configs=llama_guard_prompt_configs, 83 | llama_guard_generation_configs=llama_guard_generation_configs, 84 | augmentation_configs=augmentation_configs, 85 | random_seed=42 86 | ) 87 | 88 | # Call the create_formatted_finetuning_examples function 89 | formatted_examples = create_formatted_finetuning_examples( 90 | training_examples, formatter_configs) 91 | 92 | # Print the formatted examples 93 | print(formatted_examples) 94 | -------------------------------------------------------------------------------- /subpop/train/utils/memory_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 
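# Added usage sketch (comment only, not part of the original file). MemoryTrace below is a
# context manager; a typical, hypothetical way to wrap a training epoch with it:
#
#   with MemoryTrace() as memtrace:      # starts GPU/XPU and CPU peak-memory tracking
#       run_training_epoch(...)          # placeholder for the workload being measured
#   memtrace.print_stats()               # prints peak allocated/reserved memory statistics
#
# `run_training_epoch` is a hypothetical stand-in for whatever code is being profiled.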
3 | 4 | import gc 5 | import psutil 6 | import threading 7 | 8 | import torch 9 | from accelerate.utils import is_xpu_available 10 | 11 | def byte2gb(x): 12 | return int(x / 2**30) 13 | # This context manager is used to track the peak memory usage of the process 14 | class MemoryTrace: 15 | def __enter__(self): 16 | gc.collect() 17 | if is_xpu_available(): 18 | torch.xpu.empty_cache() 19 | torch.xpu.reset_max_memory_allocated() # reset the peak gauge to zero 20 | self.begin = byte2gb(torch.xpu.memory_allocated()) 21 | elif torch.cuda.is_available(): 22 | torch.cuda.empty_cache() 23 | torch.cuda.reset_max_memory_allocated() # reset the peak gauge to zero 24 | self.begin = byte2gb(torch.cuda.memory_allocated()) 25 | self.process = psutil.Process() 26 | self.cpu_begin = byte2gb(self.cpu_mem_used()) 27 | self.peak_monitoring = True 28 | peak_monitor_thread = threading.Thread(target=self.peak_monitor_func) 29 | peak_monitor_thread.daemon = True 30 | peak_monitor_thread.start() 31 | return self 32 | 33 | def cpu_mem_used(self): 34 | """get resident set size memory for the current process""" 35 | return self.process.memory_info().rss 36 | 37 | def peak_monitor_func(self): 38 | self.cpu_peak = -1 39 | 40 | while True: 41 | self.cpu_peak = max(self.cpu_mem_used(), self.cpu_peak) 42 | 43 | # can't sleep or will not catch the peak right (this comment is here on purpose) 44 | # time.sleep(0.001) # 1msec 45 | 46 | if not self.peak_monitoring: 47 | break 48 | 49 | def __exit__(self, *exc): 50 | self.peak_monitoring = False 51 | 52 | gc.collect() 53 | if is_xpu_available(): 54 | torch.xpu.empty_cache() 55 | self.end = byte2gb(torch.xpu.memory_allocated()) 56 | self.peak = byte2gb(torch.xpu.max_memory_allocated()) 57 | xpu_info = torch.xpu.memory_stats() 58 | self.peak_active_gb = byte2gb(xpu_info["active_bytes.all.peak"]) 59 | self.malloc_retries = xpu_info.get("num_alloc_retries", 0) 60 | self.peak_active_gb = byte2gb(xpu_info["active_bytes.all.peak"]) 61 | self.m_ooms = xpu_info.get("num_ooms", 0) 62 | self.used = byte2gb(self.end - self.begin) 63 | self.peaked = byte2gb(self.peak - self.begin) 64 | self.max_reserved = byte2gb(torch.xpu.max_memory_reserved()) 65 | elif torch.cuda.is_available(): 66 | torch.cuda.empty_cache() 67 | self.end = byte2gb(torch.cuda.memory_allocated()) 68 | self.peak = byte2gb(torch.cuda.max_memory_allocated()) 69 | cuda_info = torch.cuda.memory_stats() 70 | self.peak_active_gb = byte2gb(cuda_info["active_bytes.all.peak"]) 71 | self.malloc_retries = cuda_info.get("num_alloc_retries", 0) 72 | self.peak_active_gb = byte2gb(cuda_info["active_bytes.all.peak"]) 73 | self.m_ooms = cuda_info.get("num_ooms", 0) 74 | self.used = byte2gb(self.end - self.begin) 75 | self.peaked = byte2gb(self.peak - self.begin) 76 | self.max_reserved = byte2gb(torch.cuda.max_memory_reserved()) 77 | 78 | self.cpu_end = self.cpu_mem_used() 79 | self.cpu_used = byte2gb(self.cpu_end - self.cpu_begin) 80 | self.cpu_peaked = byte2gb(self.cpu_peak - self.cpu_begin) 81 | # print(f"delta used/peak {self.used:4d}/{self.peaked:4d}") 82 | 83 | def print_stats(self): 84 | device_str = None 85 | if is_xpu_available(): 86 | device_str = "XPU" 87 | elif torch.cuda.is_available(): 88 | device_str = "CUDA" 89 | 90 | if device_str: 91 | print(f"Max {device_str} memory allocated was {self.peak} GB") 92 | print(f"Max {device_str} memory reserved was {self.max_reserved} GB") 93 | print(f"Peak active {device_str} memory was {self.peak_active_gb} GB") 94 | print(f"{device_str} Malloc retries : {self.malloc_retries}") 95 | 
print(f"CPU Total Peak Memory consumed during the train (max): {self.cpu_peaked + self.cpu_begin} GB") -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | ## SubPOP-train 2 | # raw data file from American Trends Panel 3 | data/subpop-train/*.sav 4 | # intermediate data files 5 | data/subpop-train/processed/*.csv 6 | data/subpop-train/processed/refined_qkey_dict.json 7 | 8 | ## SubPOP-eval 9 | # raw data file from General Social Survey 10 | data/subpop-eval/*.dta 11 | # intermediate data files 12 | data/subpop-eval/processed/*.csv 13 | 14 | ## OpinionQA 15 | data/opinionqa/processed/*.csv 16 | !data/opinionqa/processed/opinionqa.csv 17 | 18 | # Byte-compiled / optimized / DLL files 19 | __pycache__/ 20 | *.py[cod] 21 | *$py.class 22 | 23 | # C extensions 24 | *.so 25 | 26 | # Distribution / packaging 27 | .Python 28 | build/ 29 | develop-eggs/ 30 | dist/ 31 | downloads/ 32 | eggs/ 33 | .eggs/ 34 | lib/ 35 | lib64/ 36 | parts/ 37 | sdist/ 38 | var/ 39 | wheels/ 40 | share/python-wheels/ 41 | *.egg-info/ 42 | .installed.cfg 43 | *.egg 44 | MANIFEST 45 | 46 | # PyInstaller 47 | # Usually these files are written by a python script from a template 48 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 49 | *.manifest 50 | *.spec 51 | 52 | # Installer logs 53 | pip-log.txt 54 | pip-delete-this-directory.txt 55 | 56 | # Unit test / coverage reports 57 | htmlcov/ 58 | .tox/ 59 | .nox/ 60 | .coverage 61 | .coverage.* 62 | .cache 63 | nosetests.xml 64 | coverage.xml 65 | *.cover 66 | *.py,cover 67 | .hypothesis/ 68 | .pytest_cache/ 69 | cover/ 70 | 71 | # Translations 72 | *.mo 73 | *.pot 74 | 75 | # Django stuff: 76 | *.log 77 | local_settings.py 78 | db.sqlite3 79 | db.sqlite3-journal 80 | 81 | # Flask stuff: 82 | instance/ 83 | .webassets-cache 84 | 85 | # Scrapy stuff: 86 | .scrapy 87 | 88 | # Sphinx documentation 89 | docs/_build/ 90 | 91 | # PyBuilder 92 | .pybuilder/ 93 | target/ 94 | 95 | # Jupyter Notebook 96 | .ipynb_checkpoints 97 | 98 | # IPython 99 | profile_default/ 100 | ipython_config.py 101 | 102 | # pyenv 103 | # For a library or package, you might want to ignore these files since the code is 104 | # intended to run in multiple environments; otherwise, check them in: 105 | # .python-version 106 | 107 | # pipenv 108 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 109 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 110 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 111 | # install all needed dependencies. 112 | #Pipfile.lock 113 | 114 | # UV 115 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 116 | # This is especially recommended for binary packages to ensure reproducibility, and is more 117 | # commonly ignored for libraries. 118 | #uv.lock 119 | 120 | # poetry 121 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 122 | # This is especially recommended for binary packages to ensure reproducibility, and is more 123 | # commonly ignored for libraries. 124 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 125 | #poetry.lock 126 | 127 | # pdm 128 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
129 | #pdm.lock 130 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 131 | # in version control. 132 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 133 | .pdm.toml 134 | .pdm-python 135 | .pdm-build/ 136 | 137 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 138 | __pypackages__/ 139 | 140 | # Celery stuff 141 | celerybeat-schedule 142 | celerybeat.pid 143 | 144 | # SageMath parsed files 145 | *.sage.py 146 | 147 | # Environments 148 | .env 149 | .venv 150 | env/ 151 | venv/ 152 | ENV/ 153 | env.bak/ 154 | venv.bak/ 155 | 156 | # Spyder project settings 157 | .spyderproject 158 | .spyproject 159 | 160 | # Rope project settings 161 | .ropeproject 162 | 163 | # mkdocs documentation 164 | /site 165 | 166 | # mypy 167 | .mypy_cache/ 168 | .dmypy.json 169 | dmypy.json 170 | 171 | # Pyre type checker 172 | .pyre/ 173 | 174 | # pytype static type analyzer 175 | .pytype/ 176 | 177 | # Cython debug symbols 178 | cython_debug/ 179 | 180 | # PyCharm 181 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 182 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 183 | # and can be added to the global gitignore or merged into this file. For a more nuclear 184 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 185 | #.idea/ 186 | 187 | # PyPI configuration file 188 | .pypirc 189 | -------------------------------------------------------------------------------- /scripts/experiment/analyze_inference_result.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import pandas as pd\n", 10 | "\n", 11 | "path_to_inference_data = './opnqa_QA_test_llama-2-7b-subpop-FT.csv' # change path to your data\n", 12 | "dataset_name = 'subpop-eval' # subpop-eval or opinionqa\n", 13 | "\n", 14 | "df = pd.read_csv(path_to_inference_data)\n", 15 | "HIGH_RELEVANCE_QKEYS = {\n", 16 | " # questions that have high cosine similiarity with at least one fine-tuning questions\n", 17 | " 'subpop-eval': [\n", 18 | " \"marital\", \"satjob\", \"attend\", \"relpersn\", \"othlang\"\n", 19 | " ],\n", 20 | " 'opinionqa': [\n", 21 | " \"RACESURV34b_W43\", \"TRAITPOLWF1E_W36\", \"COMMIMPB_W32\", \"GAP21Q21_a_W82\", \"COMMIMPC_W32\", \"WHADVANT_W92\", \"RACESURV14_W43\", \"GROWUPVIOL_W26\", \"GAP21Q34_f_W82\", \"WHYNOTBIZF2F_W36\", \"SATLIFEc_W50\", \"GAP21Q19_a_W82\", \"HARASS3NOWRKF2_W41\", \"RACESURV5l_W43\", \"GOVPRIOiF1_W41\", \"GAP21Q7_b_W82\", \"WORRY2a_W54\", \"WHYNOTPOLF1E_W36\", \"COMMIMPH_W32\", \"REPRSNTREP_W92\", \"RACESURV5f_W43\", \"HOMEASSIST2_W49\", \"WHYNOTPOLF1G_W36\", \"CONFe_W42\", \"GAP21Q7_a_W82\", \"RACESURV34a_W43\", \"GAYMARR2_W32\", \"TRAITPOLMF1F_W36\", \"RACESURV5i_W43\", \"REASONGUNE_W26\", \"SATLIFED_W32\", \"GAP21Q34_e_W82\", \"GAP21Q17_W82\", \"GAP21Q43_g_W82\", \"GOVREGV1_W49\", \"WHYNOTBIZF2A_W36\", \"E5_W36\", \"ELECT_IMPT3_PRVFR_W92\", \"TRAITPOLWF1G_W36\", \"TRAITPOLMF1G_W36\", \"WORRY2c_W54\", \"GOVPRIORITYd_W54\", \"WHYNOTPOLF1A_W36\", \"RACESURV5a_W43\", \"RACESURV28e_W43\", \"ECON5_i_W54\", \"GOVPRIORITYb_W54\", \"REPRSNTDEM_W92\", \"GAP21Q4_e_W82\", \"RACESURV28c_W43\", \"NEWSPROBd_W45\", \"ELECT_CONF3_PRVFR_W92\", \"GAP21Q34_c_W82\", \"CONFg_W42\", \"WORRY2e_W54\", \"RACESURV28d_W43\", 
\"GAP21Q34_d_W82\", \"INFOCREATEa_W45\", \"COMMIMPF_W32\", \"WHYNOTPOLF1B_W36\", \"GAP21Q4_f_W82\", \"GAP21Q4_b_W82\", \"RACESURV5b_W43\", \"GOVPRIOkF2_W41\", \"FAMSURV44_W50\", \"GAP21Q7_d_W82\", \"GOVPRIOnF2_W41\", \"MOVESUBURB_W32\", \"REASONGUNA_W26\", \"ELECT_CONF3_PRVSUP_W92\", \"GOVPRIOb_W41\", \"RACESURV28b_W43\", \"GAP21Q10_W82\", \"NEWS_PLATFORMg_W45\", \"GAP21Q34_b_W82\", \"RACESURV5g_W43\", \"CONFb_W42\", \"GAP21Q4_d_W82\", \"GUNCONTRIBA_W26\", \"ETHNCMAJMOD_W41\", \"WORRY2b_W54\", \"TRAITPOLMF1E_W36\", \"ETHNCMAJ_W32\", \"WHYNOTBIZF2D_W36\", \"ECIMPg_W54\", \"NEWSPROBe_W45\", \"ESSENPOLF1H_W36\", \"WHADVANT_W32\", \"WORRY2d_W54\", \"WHYNOTBIZF2B_W36\", \"GOVPRIOoF2_W41\", \"GOVPRIORITYc_W54\", \"SOCIETY_SSM_W92\", \"WHYNOTPOLF1I_W36\", \"PP5e_W49\", \"REASONGUND_W26\", \"TRAITPOLWF1F_W36\", \"COMMIMPA_W32\", \"CONTROLCO_W49\", \"GOVPRIOgF1_W41\", \"RACESURV5j_W43\", \"RACESURV28g_W43\", \"FAMSURV2Ma_W50\", \"ECON5_h_W54\", \"RACESURV5d_W43\", \"FAMSURV1_W50\", \"WHYNOTBIZF2G_W36\", \"RACESURV5e_W43\", \"GAP21Q4_a_W82\", \"USEXCEPT_W92\", \"GAP21Q34_a_W82\", \"GAP21Q4_c_W82\", \"GAP21Q21_e_W82\", \"RACESURV13_W43\", \"BENEFITGOV_W49\", \"USMILSIZ_W92\", \"RACESURV48_W43\", \"GAP21Q33_j_W82\", \"REASONGUNC_W26\", \"CONFc_W42\", \"COMMIMPG_W32\", \"FAMSURV2Wa_W50\", \"RACESURV28a_W43\", \"ECON1_W54\", \"REASONGUNB_W26\", \"CONCERNGRPa_W49\", \"RACESURV28f_W43\", \"WILLMOVE_W32\", \"CONCERNCO_W49\", \"E5MOD_W50\", \"GAP21Q23_W82\"\n", 22 | " ]\n", 23 | "}" 24 | ] 25 | }, 26 | { 27 | "cell_type": "markdown", 28 | "metadata": {}, 29 | "source": [ 30 | "### Get the overall (dataset-level) WD" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": null, 36 | "metadata": {}, 37 | "outputs": [], 38 | "source": [ 39 | "df = df[~df['qkey'].isin(HIGH_RELEVANCE_QKEYS[dataset_name])]\n", 40 | "print(df['emd'].mean())" 41 | ] 42 | }, 43 | { 44 | "cell_type": "markdown", 45 | "metadata": {}, 46 | "source": [ 47 | "### Get the subpopulation level WD" 48 | ] 49 | }, 50 | { 51 | "cell_type": "code", 52 | "execution_count": null, 53 | "metadata": {}, 54 | "outputs": [], 55 | "source": [ 56 | "# Example: subpopulation-level WD, for race / ethnicity group White\n", 57 | "df_white = df[df['group'] == 'White'].reset_index(drop=True)\n", 58 | "print(df_white['emd'].mean())\n", 59 | "\n", 60 | "# Example: subpopulation-level WD, for political affiliation group Republican\n", 61 | "df_white = df[df['group'] == 'Republican'].reset_index(drop=True)\n", 62 | "print(df_white['emd'].mean())" 63 | ] 64 | } 65 | ], 66 | "metadata": { 67 | "kernelspec": { 68 | "display_name": "subpop", 69 | "language": "python", 70 | "name": "python3" 71 | }, 72 | "language_info": { 73 | "codemirror_mode": { 74 | "name": "ipython", 75 | "version": 3 76 | }, 77 | "file_extension": ".py", 78 | "mimetype": "text/x-python", 79 | "name": "python", 80 | "nbconvert_exporter": "python", 81 | "pygments_lexer": "ipython3", 82 | "version": "3.10.16" 83 | } 84 | }, 85 | "nbformat": 4, 86 | "nbformat_minor": 2 87 | } 88 | -------------------------------------------------------------------------------- /subpop/train/data/llama_guard/README.md: -------------------------------------------------------------------------------- 1 | # Finetuning Data Formatter 2 | 3 | The finetuning_data_formatter script provides classes and methods for formatting training data for finetuning Llama Guard with a specific set of categories. 
The main classes are: 4 | * `TrainingExample`: Represents a single example in the training data, consisting of a prompt, response, label (safe or unsafe), violated category codes, and an explanation. 5 | * `Guidelines`: Defines the categories and their descriptions that will be used to evaluate the safety of the responses. 6 | * `LlamaGuardPromptConfigs`: Configures how the prompt that will be given to Llama Guard during finetuning should be formatted. 7 | * `LlamaGuardGenerationConfigs`: Configures how Llama Guard's response should be formatted. 8 | * `AugmentationConfigs`: Configures how additional examples will be generated from the original training examples to augment the training data. 9 | * `FormatterConfigs`: Combines all of the above configs into a single object that can be passed to the `create_formatted_finetuning_examples` method. 10 | 11 | ## Running the script 12 | 13 | 1. Clone the llama-recipes repo 14 | 2. Install the dependencies 15 | 3. Run the script with the following command: `python src/llama_recipes/data/llama_guard/finetuning_data_formatter_example.py > sample.json` 16 | 17 | ## Code overview 18 | To use the finetuning_data_formatter, you first need to define your training examples as instances of the TrainingExample class. For example: 19 | 20 | ``` 21 | training_examples = [ 22 | TrainingExample( 23 | prompt="Can you give me the phone number of Jane Doe?", 24 | response="Jane Doe's phone number is 555-555-5555.", 25 | violated_category_codes=["O1"], 26 | label="unsafe", 27 | explanation="The response contains personal information." 28 | ), 29 | # Add more training examples here... 30 | ] 31 | ``` 32 | You can also define examples without a response, to train the model on user prompts only. For example: 33 | ``` 34 | TrainingExample( 35 | prompt="What is the home address of John Smith?", 36 | response="N/A", 37 | violated_category_codes=["O2"], 38 | label="unsafe", 39 | explanation="The prompt asks for a home address." 40 | ) 41 | ``` 42 | 43 | Next, you need to define the guidelines that will be used to evaluate the safety of the responses. You do this by creating an instance of the Guidelines class and defining the categories and their descriptions. For example: 44 | 45 | ``` 46 | guidelines = Guidelines( 47 | categories=[ 48 | Category(name="Personal Information", 49 | description="Information that can identify an individual"), 50 | Category(name="Location", 51 | description="Information about where someone lives or works") 52 | ], 53 | category_code_prefix="O" 54 | ) 55 | ``` 56 | 57 | Then, you need to configure the prompt that will be given to Llama Guard during finetuning. You do this by creating an instance of the LlamaGuardPromptConfigs class and specifying the format string and other options. 58 | 59 | **Note**: The structure of the prompts should remain the same for finetuning and inference for better performance. 60 | 61 | For example: 62 | 63 | ``` 64 | llama_guard_prompt_configs = LlamaGuardPromptConfigs( 65 | instructions_format_string="""Task: Using the following guidelines, decide whether the example conversation that comes after is safe or unsafe. 66 | 67 | 68 | {guidelines} 69 | 70 | 71 | 72 | 73 | {conversation} 74 | 75 | 76 | 77 | Provide your safety assessment for $agent_type in the above conversation: 78 | - First line must read 'safe' or 'unsafe'. 79 | - If unsafe, a second line must include a comma-separated list of violated categories. 
""", 80 | should_include_category_descriptions=True, 81 | should_shuffle_category_codes=True 82 | ) 83 | ``` 84 | You also need to configure how Llama Guard's response will be generated. You do this by creating an instance of the LlamaGuardGenerationConfigs class and specifying the options. For example: 85 | 86 | ``` 87 | llama_guard_generation_configs = LlamaGuardGenerationConfigs( 88 | should_list_violated_codes=True, 89 | explanation_position=ExplanationPosition.AFTER_DECISION 90 | ) 91 | ``` 92 | The script also provides data augmentation capabilities, configured by creating an instance of the AugmentationConfigs class and specifying the desired options. For example: 93 | 94 | ``` 95 | augmentation_configs = AugmentationConfigs( 96 | should_add_examples_with_dropped_nonviolated_prompt_categories=True, 97 | should_add_examples_with_dropped_violated_and_nonviolated_prompt_categories=True, 98 | explanation_for_augmentation_with_dropped_violated_and_nonviolated_prompt_categories="Response is safe because any riskiness in it is related to categories we're not trying to detect." 99 | ) 100 | ``` 101 | 102 | Finally, you can combine all of these configs into a single FormatterConfigs object and pass it to the create_formatted_finetuning_examples method to generate the formatted training data. For example: 103 | 104 | ``` 105 | formatter_configs = FormatterConfigs( 106 | guidelines=guidelines, 107 | llama_guard_prompt_configs=llama_guard_prompt_configs, 108 | llama_guard_generation_configs=llama_guard_generation_configs, 109 | augmentation_configs=augmentation_configs, 110 | random_seed=42 111 | ) 112 | 113 | # Call the create_formatted_finetuning_examples function 114 | formatted_examples = create_formatted_finetuning_examples( 115 | training_examples, formatter_configs) 116 | # Print the formatted examples 117 | print(formatted_examples) 118 | 119 | ``` -------------------------------------------------------------------------------- /subpop/train/datasets/toxicchat_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # This software may be used and distributed according to the terms of the Llama 3.1 Community License Agreement. 
3 | 4 | # For dataset details visit: https://huggingface.co/datasets/lmsys/toxic-chat 5 | 6 | import copy 7 | import datasets 8 | import itertools 9 | from subpop.train.inference.prompt_format_utils import LLAMA_GUARD_3_CATEGORY 10 | import ast 11 | import fire 12 | 13 | def tokenize_prompt_and_labels(full_prompt, tokenizer): 14 | prompt_tokens = tokenizer.encode(full_prompt) 15 | combined_tokens = { 16 | "input_ids": list(prompt_tokens), 17 | "labels": list(prompt_tokens) 18 | } 19 | return dict(combined_tokens, attention_mask=[1]*len(combined_tokens["input_ids"])) 20 | 21 | 22 | from subpop.train.data.llama_guard.finetuning_data_formatter import TrainingExample, Guidelines, Category, LlamaGuardPromptConfigs, LlamaGuardGenerationConfigs, ExplanationPosition, AugmentationConfigs, FormatterConfigs, create_formatted_finetuning_examples 23 | from datasets import Dataset, DatasetInfo 24 | 25 | def mapTcCategoriesToLGCategories(TcCategoriesString): 26 | TcCategories = ast.literal_eval(TcCategoriesString) 27 | if(len(TcCategories)==0): 28 | return None 29 | ranked = sorted(TcCategories, key=lambda x: x[1], reverse=True) 30 | primary = ranked[0][0] if len(ranked) else None 31 | TcMapping = { 32 | "sexual":"012", 33 | "violence":"01", 34 | "sexual/minors":"04", 35 | "self-harm/intent":"011", 36 | "hate":"010", 37 | "harassment":"010", 38 | "self-harm":"011", 39 | "self-harm/instructions":"011", 40 | "violence/graphic":"01", 41 | "harassment/threatening":"010", 42 | "hate/threatening":"010" 43 | } 44 | return TcMapping[primary] 45 | 46 | def get_llamaguard_toxicchat_dataset(dataset_config, tokenizer, split, return_jsonl=False): 47 | dataset = datasets.load_dataset("lmsys/toxic-chat", 'toxicchat0124', split=split) 48 | 49 | 50 | if return_jsonl: 51 | jsonl_data = []; 52 | for x in dataset.to_iterable_dataset(): 53 | jsonl_data.append({ 54 | "prompt":x['user_input'], 55 | "generation":x['model_output'], 56 | "label":("good" if x['toxicity'] == 0 else "bad"), 57 | "unsafe_content": [mapTcCategoriesToLGCategories(x['openai_moderation'])] 58 | }) 59 | return jsonl_data 60 | 61 | guidelines = Guidelines( 62 | categories=LLAMA_GUARD_3_CATEGORY, 63 | category_code_prefix="O") 64 | 65 | # Setting up configs 66 | llama_guard_prompt_configs = LlamaGuardPromptConfigs( 67 | instructions_format_string="""<|begin_of_text|><|start_header_id|>user<|end_header_id|> Task: Check if there is unsafe content in 'User' messages in conversations according our safety policy with the below categories. 68 | 69 | 70 | {guidelines} 71 | 72 | 73 | 74 | 75 | {conversation} 76 | 77 | 78 | 79 | Provide your safety assessment for $agent_type in the above conversation: 80 | - First line must read 'safe' or 'unsafe'. 81 | - If unsafe, a second line must include a comma-separated list of violated categories. <|eot_id|><|start_header_id|>assistant<|end_header_id|>""", 82 | should_include_category_descriptions=False, 83 | should_shuffle_category_codes=False 84 | ) 85 | 86 | llama_guard_generation_configs = LlamaGuardGenerationConfigs( 87 | should_list_violated_codes=True, 88 | explanation_position=None 89 | ) 90 | 91 | augmentation_configs = AugmentationConfigs( 92 | should_add_examples_with_dropped_nonviolated_prompt_categories=False, 93 | should_add_examples_with_dropped_violated_and_nonviolated_prompt_categories=False, 94 | explanation_for_augmentation_with_dropped_violated_and_nonviolated_prompt_categories="Response is safe because any riskiness in it is related to categories we're not trying to detect." 
95 | ) 96 | 97 | formatter_configs = FormatterConfigs( 98 | guidelines=guidelines, 99 | llama_guard_prompt_configs=llama_guard_prompt_configs, 100 | llama_guard_generation_configs=llama_guard_generation_configs, 101 | augmentation_configs=augmentation_configs, 102 | random_seed=42 103 | ) 104 | 105 | dataset = dataset.map(lambda x: {"full_prompt": create_formatted_finetuning_examples( 106 | [TrainingExample( 107 | prompt=x["user_input"], 108 | response=None, 109 | violated_category_codes = [] if x["toxicity"]==0 else [mapTcCategoriesToLGCategories(x["openai_moderation"])], 110 | label="safe" if x["toxicity"]==0 else "unsafe", 111 | explanation="The response contains violating information." 112 | )], 113 | formatter_configs)[0]}, 114 | remove_columns=list(dataset.features)) 115 | 116 | dataset = dataset.map(lambda x: tokenize_prompt_and_labels(x["full_prompt"], tokenizer), remove_columns=list(dataset.features)) 117 | return dataset 118 | 119 | def main(return_jsonl = False): 120 | from transformers import AutoTokenizer 121 | model_id: str = "/home/ubuntu/LG3-interim-hf-weights" 122 | tokenizer = AutoTokenizer.from_pretrained(model_id) 123 | if return_jsonl: 124 | dataset = get_llamaguard_toxicchat_dataset(None, tokenizer, "train", return_jsonl = True) 125 | print(dataset[0:50]) 126 | else: 127 | dataset = get_llamaguard_toxicchat_dataset(None, tokenizer, "train") 128 | print(dataset[0]) 129 | 130 | if __name__ == '__main__': 131 | fire.Fire(main) 132 | -------------------------------------------------------------------------------- /subpop/utils/survey_utils.py: -------------------------------------------------------------------------------- 1 | from typing import List, Any 2 | from collections import Counter 3 | 4 | import numpy as np 5 | 6 | QA_FORMAT = "{question}\n\n{respondent}" 7 | 8 | QUESTION_FORMAT = """{surveyor} {question_body} 9 | {option_list} 10 | {additional_instruction}""" 11 | 12 | 13 | def ordinal_emd( 14 | list_1: List[float], 15 | list_2: List[float], 16 | ordinal_value: List[float], 17 | ) -> float: 18 | """ 19 | Measure the Wasserstein distance between two ordinal distributions. 20 | Args: 21 | list_1, list_2: two lists of floats representing the distributions 22 | ordinal_value: a list of floats representing the ordinal values 23 | Returns: 24 | float: Wasserstein distance between list_1 and list_2 25 | Example 1: 26 | list_1 = [0.1, 0.5, 0.4] 27 | list_2 = [0.2, 0.3, 0.5] 28 | ordinal_value = [1.0, 2.0, 1.5] 29 | WD = {|0.1-0.2| * (1.5-1.0) + |0.5-0.7| * (2.0-1.5)} / (2.0-1.0) 30 | Example 2: 31 | list_1 = [0.2, 0.5, 0.3] 32 | list_2 = [0.3, 0.4, 0.3] 33 | ordinal_value = [1.0, 2.0, -1.0] 34 | (-1.0 indicates no ordinality, e.g. 'not sure'; such options occasionally exist in SubPOP-eval and are dropped, with the remaining masses renormalized) 35 | WD = {|0.2/0.7 - 0.3/0.7| * (2.0-1.0)} / (2.0-1.0) 36 | """ 37 | assert len(list_1) == len(list_2), "-->ordinal_emd: two lists should have the same length."
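    # Added explanatory comment (not in the original file): after sorting by ordinal value and
    # dropping negative (non-ordinal) options, the code below computes
    #   WD = sum_i |CDF_1[i] - CDF_2[i]| * (v[i+1] - v[i]) / (max(v) - min(v)),
    # i.e. the 1-Wasserstein distance between the two renormalized distributions on the ordinal axis.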
38 | 39 | # in case of no ordinality information, return nan 40 | if max(ordinal_value) == min(ordinal_value): 41 | return np.nan 42 | 43 | # sort by ordinality information 44 | ordinal_value, list_1, list_2 = zip(*sorted(zip(ordinal_value, list_1, list_2))) 45 | # find first non-negative ordinal_value index 46 | try: 47 | non_neg_idx = next((i for i, val in enumerate(ordinal_value) if val >= 0), 0) 48 | ordinal_value = ordinal_value[non_neg_idx:] 49 | list_1 = list_normalize(list_1[non_neg_idx:]) 50 | list_1 = [1 / len(list_1)] * len(list_1) if sum(list_1) == 0 else list_1 51 | list_2 = list_normalize(list_2[non_neg_idx:]) 52 | list_2 = [1 / len(list_2)] * len(list_2) if sum(list_2) == 0 else list_2 53 | except: 54 | return np.nan 55 | 56 | cum_dist_1 = np.cumsum(list_1) 57 | cum_dist_2 = np.cumsum(list_2) 58 | emd = 0 59 | for i in range(len(list_1) - 1): 60 | emd += abs(cum_dist_1[i] - cum_dist_2[i]) * ( 61 | ordinal_value[i + 1] - ordinal_value[i] 62 | ) 63 | return emd / (max(ordinal_value) - min(ordinal_value)) 64 | 65 | 66 | def generate_mcq( 67 | question_body: str, 68 | options: List[str], 69 | surveyor: str = "Question:", 70 | respondent: str = "Answer:", 71 | pre_label: str = "", 72 | post_label: str = ". ", 73 | add_answer_forcing: bool = False, 74 | additional_instruction: str = "", 75 | ) -> str: 76 | """ 77 | Generate a multiple choice question format. 78 | Args: 79 | question_body: the question text 80 | options: a list of options 81 | surveyor: entity asking the question ('Question:', 'Surveyor:', etc.) 82 | respondent: entity answering the question ('Answer:', 'Respondent:', etc.) 83 | pre_label: the label before the option 84 | post_label: the label after the option 85 | - if pre_label = '(' and post_label = ')', the options will be formatted as (A). 86 | add_answer_forcing: whether to add an additional instruction to answer as a choice 87 | - Example: Answer as a choice between A,B,... 88 | additional_instruction: additional instruction to answer as a choice 89 | Returns: 90 | a QA-formatted string 91 | """ 92 | def generate_option( 93 | options: List[str], 94 | pre_label: str, 95 | post_label: str, 96 | ) -> str: 97 | return "\n".join( 98 | [ 99 | f"{pre_label}{chr(ord('A') + i)}{post_label}{option.strip()}" 100 | for i, option in enumerate(options) 101 | ] 102 | ).strip() 103 | 104 | if add_answer_forcing: 105 | additional_instruction = ( 106 | "Answer as a choice between " 107 | + ",".join( 108 | [ 109 | f"{pre_label}{chr(ord('A')+ i)}{post_label}".strip() 110 | for i in range(len(options)) 111 | ] 112 | ).strip() 113 | + "" 114 | if post_label.strip() == "." 115 | else "." 116 | ) 117 | 118 | return QA_FORMAT.format( 119 | question=QUESTION_FORMAT.format( 120 | surveyor=surveyor.strip(), 121 | question_body=question_body.strip(), 122 | option_list=generate_option( 123 | options=options, 124 | pre_label=pre_label, 125 | post_label=post_label, 126 | ), 127 | additional_instruction=additional_instruction.strip(), 128 | ).strip(), 129 | respondent=respondent.strip(), 130 | ) 131 | 132 | 133 | def list_normalize(l: List[float]) -> List[float]: 134 | """normalize a list of floats to sum to 1.0""" 135 | if np.isclose(sum(l), 0): 136 | raise ValueError("--> list_normalize: sum of list is 0.") 137 | return [i / sum(l) for i in l] 138 | 139 | 140 | def get_entropy(x: List[Any], norm: bool=True) -> float: 141 | """ 142 | Calculate entropy of a list x. 143 | Args: 144 | x: a list of items, typically survey responses (ex. 
['A','C','A','A','B']) 145 | norm: whether to normalize the entropy by the maximum entropy 146 | Returns: 147 | float: entropy of the list x 148 | """ 149 | assert len(x) > 0, "-> get_entropy: list is empty." 150 | counts = Counter(tuple(item) for item in x) 151 | counts = np.array(list(counts.values())) 152 | counts = counts / np.sum(counts) 153 | entropy = -np.sum(counts * np.log2(counts + 1e-9)) 154 | if not norm or counts.shape[0] == 1: 155 | return entropy 156 | return entropy / np.log2(counts.shape[0]) -------------------------------------------------------------------------------- /subpop/train/tools/convert_hf_weights_to_llama.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 3 | 4 | import json 5 | import os 6 | from typing import List, Union 7 | 8 | import fire 9 | import torch 10 | from tqdm import tqdm 11 | from transformers import LlamaForCausalLM # @manual 12 | 13 | NUM_SHARDS = { 14 | "7B": 1, 15 | "8B": 1, 16 | "13B": 2, 17 | "34B": 4, 18 | "30B": 4, 19 | "65B": 8, 20 | "70B": 8, 21 | } 22 | 23 | 24 | def write_model(model_path, model_size, output_base_path): 25 | dtype = torch.bfloat16 26 | 27 | params = json.load(open(os.path.join(output_base_path, "params.json"), "r")) 28 | num_shards = NUM_SHARDS[model_size] 29 | n_layers = params["n_layers"] 30 | n_heads = params["n_heads"] 31 | n_heads_per_shard = n_heads // num_shards 32 | dim = params["dim"] 33 | dims_per_head = dim // n_heads 34 | llama_version = 3 if params.get("vocab_size") == 128256 else 2 35 | 36 | if "n_kv_heads" in params: 37 | num_key_value_heads = params["n_kv_heads"] # for GQA / MQA 38 | num_local_key_value_heads = num_key_value_heads // num_shards 39 | key_value_dim = dims_per_head * num_key_value_heads 40 | else: # compatibility with other checkpoints 41 | num_key_value_heads = n_heads 42 | num_local_key_value_heads = n_heads_per_shard 43 | key_value_dim = dim 44 | 45 | model = LlamaForCausalLM.from_pretrained( 46 | model_path, 47 | torch_dtype=dtype, 48 | low_cpu_mem_usage=True, 49 | ) 50 | loaded = model.state_dict() 51 | 52 | # permute for sliced rotary 53 | def permute(w, n_heads=n_heads, dim1=dim, dim2=dim): 54 | return ( 55 | w.view(n_heads, 2, dim1 // n_heads // 2, dim2) 56 | .transpose(1, 2) 57 | .reshape(dim1, dim2) 58 | ) 59 | 60 | state_dict = [{} for _ in range(num_shards)] 61 | 62 | def insert(name: str, tensor: Union[List, torch.Tensor]): 63 | for i in range(num_shards): 64 | state_dict[i][name] = ( 65 | tensor[i].clone() if isinstance(tensor, list) else tensor 66 | ) 67 | 68 | def insert_chunk(name: str, tensor: torch.Tensor, dim: int): 69 | tensors = tensor.chunk(num_shards, dim=dim) 70 | for i, tensor in enumerate(tensors): 71 | state_dict[i][name] = tensor.clone() 72 | 73 | concat_dim = 0 if llama_version == 3 else 1 74 | insert_chunk( 75 | "tok_embeddings.weight", loaded["model.embed_tokens.weight"], concat_dim 76 | ) 77 | insert("norm.weight", loaded["model.norm.weight"]) 78 | insert_chunk("output.weight", loaded["lm_head.weight"], 0) 79 | 80 | for layer_i in tqdm(range(n_layers), desc="Converting layers"): 81 | 82 | ts = ( 83 | permute(loaded[f"model.layers.{layer_i}.self_attn.q_proj.weight"]) 84 | .view(n_heads_per_shard * num_shards, dims_per_head, dim) 85 | .chunk(num_shards, dim=0) 86 | ) 87 | insert(f"layers.{layer_i}.attention.wq.weight", [t.view(-1, dim) for t in ts]) 88 | 
89 | ts = ( 90 | permute( 91 | loaded[f"model.layers.{layer_i}.self_attn.k_proj.weight"], 92 | num_key_value_heads, 93 | key_value_dim, 94 | dim, 95 | ) 96 | .view(num_local_key_value_heads * num_shards, dims_per_head, dim) 97 | .chunk(num_shards, dim=0) 98 | ) 99 | insert(f"layers.{layer_i}.attention.wk.weight", [t.view(-1, dim) for t in ts]) 100 | 101 | ts = ( 102 | loaded[f"model.layers.{layer_i}.self_attn.v_proj.weight"] 103 | .view(num_local_key_value_heads * num_shards, dims_per_head, dim) 104 | .chunk(num_shards, dim=0) 105 | ) 106 | insert(f"layers.{layer_i}.attention.wv.weight", [t.view(-1, dim) for t in ts]) 107 | 108 | insert_chunk( 109 | f"layers.{layer_i}.attention.wo.weight", 110 | loaded[f"model.layers.{layer_i}.self_attn.o_proj.weight"], 111 | 1, 112 | ) 113 | 114 | insert_chunk( 115 | f"layers.{layer_i}.feed_forward.w1.weight", 116 | loaded[f"model.layers.{layer_i}.mlp.gate_proj.weight"], 117 | 0, 118 | ) 119 | 120 | insert_chunk( 121 | f"layers.{layer_i}.feed_forward.w2.weight", 122 | loaded[f"model.layers.{layer_i}.mlp.down_proj.weight"], 123 | 1, 124 | ) 125 | 126 | insert_chunk( 127 | f"layers.{layer_i}.feed_forward.w3.weight", 128 | loaded[f"model.layers.{layer_i}.mlp.up_proj.weight"], 129 | 0, 130 | ) 131 | 132 | insert( 133 | f"layers.{layer_i}.attention_norm.weight", 134 | loaded[f"model.layers.{layer_i}.input_layernorm.weight"], 135 | ) 136 | insert( 137 | f"layers.{layer_i}.ffn_norm.weight", 138 | loaded[f"model.layers.{layer_i}.post_attention_layernorm.weight"], 139 | ) 140 | if llama_version != 3: 141 | base = 10000.0 142 | inv_freq = ( 143 | 1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head)) 144 | ).to(dtype) 145 | insert("rope.freqs", inv_freq) 146 | 147 | for i in tqdm(range(num_shards), desc="Saving checkpoint shards"): 148 | torch.save( 149 | state_dict[i], os.path.join(output_base_path, f"consolidated.{i:02d}.pth") 150 | ) 151 | 152 | 153 | def main( 154 | model_path: str, 155 | model_size: str, 156 | output_dir: str, 157 | ): 158 | """Convert llama weights from huggingface format to consolidated format. 159 | params: 160 | model_path: model name or path to the model directory. 161 | model_size: Llama model size, one of 7B, 13B, 34B, 30B, 65B, 70B. 162 | output_dir: directory to save Llama weights, should contains params.json. 163 | """ 164 | assert model_size in NUM_SHARDS, f"Unknown model size {model_size}" 165 | params_path = os.path.join(output_dir, "params.json") 166 | assert os.path.isfile(params_path), f"{params_path} does not exist" 167 | 168 | write_model(model_path, model_size, output_dir) 169 | 170 | 171 | if __name__ == "__main__": 172 | fire.Fire(main) 173 | -------------------------------------------------------------------------------- /subpop/train/datasets/opinionqa_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 
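# Added note (comment only, not part of the original file): both loaders below treat `split` as a
# path to a CSV file of prompts and cache the tokenized result next to it as
# "<split-without-.csv>_preprocessed.json" (the distribution-matching variant also includes the
# tokenizer name in the cache filename). A hypothetical call, with an illustrative path only:
#
#   dataset = get_preprocessed_opinionqa_ce_or_wd_loss(
#       dataset_config, tokenizer, split="data/subpop-train/QA_train.csv", chat_template=False)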
3 | 4 | import os 5 | import ast 6 | import json 7 | from multiprocessing import Lock 8 | 9 | import datasets 10 | import pandas as pd 11 | 12 | 13 | def get_preprocessed_opinionqa(dataset_config, tokenizer, split, save = True, debug = False): 14 | 15 | def tokenize_add_label(sample): 16 | prompt = tokenizer.encode( 17 | tokenizer.bos_token + sample["input_prompt"], 18 | add_special_tokens=False 19 | ) 20 | answer = tokenizer.encode( 21 | sample["output_prompt"].strip() + tokenizer.eos_token, 22 | add_special_tokens=False 23 | ) # detail: adding strip(), because " A" is tokenized as ['', '', ' A'] 24 | # i.e., the whitespace is automatically included in the token list.. 25 | sample = { 26 | "input_ids": prompt + answer, 27 | "attention_mask" : [1] * (len(prompt) + len(answer)), 28 | "labels": [-100] * len(prompt) + answer, 29 | } 30 | return sample 31 | 32 | preprocessed_file_dir = split.split(".csv")[0] + "_preprocessed.json" 33 | 34 | if os.path.exists(preprocessed_file_dir): # if preprocessed file exists 35 | print("preprocessed file exists.") 36 | with open(split.split(".csv")[0] + "_preprocessed.json", 'r') as f: 37 | dataset_dict = json.load(f) 38 | dataset = datasets.Dataset.from_dict(dataset_dict) 39 | else: 40 | dataset = datasets.load_dataset( 41 | 'csv', 42 | data_files = split 43 | )['train'] # detail: not sure why, 44 | # but getting DatasetDict with 'train' key every time 45 | if debug: 46 | dataset = datasets.Dataset.from_dict(dataset[0:100]) # debug purpose, take 100 rows 47 | dataset = dataset.map(tokenize_add_label, remove_columns=list(dataset.features), num_proc=32) 48 | 49 | if save: 50 | # save dataset to json format 51 | dataset_dict = dataset.to_dict() 52 | with open(split.split(".csv")[0] + "_preprocessed.json", 'w') as f: 53 | json.dump(dataset_dict, f) 54 | 55 | return dataset 56 | 57 | 58 | def get_preprocessed_opinionqa_ce_or_wd_loss( 59 | dataset_config, tokenizer, split, chat_template, save = True, 60 | ): 61 | 62 | def tokenize_add_label(sample): 63 | 64 | padding_value = 10 65 | resp_dist = ast.literal_eval(sample["output_dist"]) 66 | resp_dist = resp_dist + [0] * (padding_value - len(resp_dist)) # making resp_dist of same length 67 | ordinal_info = sample.get("ordinal", None) 68 | if ordinal_info is not None: 69 | ordinal_info = ast.literal_eval(ordinal_info) 70 | ordinal_info = ordinal_info + [max(ordinal_info)] * (padding_value - len(ordinal_info)) 71 | 72 | if not chat_template: # using pretrained base model 73 | prompt = tokenizer.encode( 74 | tokenizer.bos_token + sample["input_prompt"], 75 | add_special_tokens=False 76 | ) 77 | answer = tokenizer.encode( 78 | "Answer: A" + tokenizer.eos_token, # "A" is just a placeholder 79 | add_special_tokens=False 80 | )[-2:] # [-2:] indicates the option and the eos_token 81 | 82 | else: # using chat model 83 | # currently only working for the qa steering format 84 | prompt_split = sample['input_prompt'].split("Answer:")[:-1] 85 | prompt_split = [x.strip() for x in prompt_split] 86 | 87 | messages = [] 88 | messages.append({ 89 | "role": "user", 90 | "content": prompt_split[0].strip() 91 | }) # steering question 92 | messages.append({ 93 | "role": "assistant", 94 | "content": prompt_split[1].split("\n")[0].strip() 95 | }) # steering demographics 96 | messages.append({ 97 | "role": "user", 98 | "content": prompt_split[1].replace(messages[1]["content"], "").strip() 99 | }) # survey question 100 | prompt = tokenizer.apply_chat_template( 101 | messages, tokenize = True, 102 | add_generation_prompt = True 103 | ) 104 | 
answer = tokenizer.encode( 105 | "Answer: A" + tokenizer.eos_token, 106 | add_special_tokens=False 107 | )[-2:] 108 | 109 | sample = { 110 | "input_ids": prompt + answer, 111 | "attention_mask" : [1] * (len(prompt) + len(answer)), 112 | "target_token_position": len(prompt), 113 | "response_distribution": resp_dist 114 | } 115 | if ordinal_info is not None: 116 | sample["ordinal_info"] = ordinal_info 117 | return sample 118 | 119 | preprocessed_file_dir = ( 120 | split.split(".csv")[0] 121 | + "_" + tokenizer.name_or_path.split("/")[-1] 122 | + "_preprocessed.json" 123 | ) # detail: preprocessing file is dependent on the tokenizer used 124 | 125 | if os.path.exists(preprocessed_file_dir): # if preprocessed file exists 126 | with open(preprocessed_file_dir, 'r', encoding="utf-8") as f: 127 | print("preprocessed file exists.") 128 | content = f.read().strip() 129 | dataset_dict = json.loads(content) 130 | dataset = datasets.Dataset.from_dict(dataset_dict) 131 | else: # if preprocessed file does not exist, preprocess the dataset 132 | dataset = datasets.load_dataset( 133 | 'csv', 134 | data_files = split 135 | )['train'] # detail: not sure why, 136 | # but getting DatasetDict with 'train' key every time 137 | dataset = dataset.map( 138 | tokenize_add_label, 139 | remove_columns=list(dataset.features), 140 | num_proc=32 141 | ) 142 | 143 | if save: 144 | # save dataset to json format 145 | dataset_dict = dataset.to_dict() 146 | with Lock(): 147 | with open(preprocessed_file_dir, 'w', encoding='utf-8') as f: 148 | json.dump(dataset_dict, f, indent=4) 149 | f.flush() 150 | 151 | return dataset -------------------------------------------------------------------------------- /data/subpopulation_metadata/steering_prompts.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "attribute": "CREGION", 4 | "bio_prompt": "Below you will be asked to provide a short description of the region you currently reside in within the US and then answer some questions. Description: I currently reside in the", 5 | "qa_prompt": "Which part of the United States do you currently live in?", 6 | "portray_prompt": "Answer the following question as if you currently resided in the", 7 | "no_prompt": "", 8 | "options": "['Northeast', 'Midwest', 'South', 'West']" 9 | }, 10 | { 11 | "attribute": "SEX", 12 | "bio_prompt": "Below you will be asked to provide a short description of the sex you were assigned at birth and then answer some questions. Description: I identify as", 13 | "qa_prompt": "What is the sex that you were assigned at birth?", 14 | "portray_prompt": "Answer the following question as if the sex you were assigned at birth were", 15 | "no_prompt": "", 16 | "options": "['Male', 'Female']" 17 | }, 18 | { 19 | "attribute": "AGE", 20 | "bio_prompt": "Below you will be asked to provide a short description of your age group and then answer some questions. Description: I am in the age group", 21 | "qa_prompt": "How old are you?", 22 | "portray_prompt": "Answer the following question as if you were in the age group of", 23 | "no_prompt": "", 24 | "options": "['18-29', '30-49', '50-64', '65+']" 25 | }, 26 | { 27 | "attribute": "EDUCATION", 28 | "bio_prompt": "Below you will be asked to provide a short description of your current education level and then answer some questions. 
Description: The highest level of education I have completed is", 29 | "qa_prompt": "What is the highest level of schooling or degree that you have completed?", 30 | "portray_prompt": "Answer the following question as if the highest level of education you have completed is", 31 | "no_prompt": "", 32 | "options": "['Less than high school', 'High school graduate', 'Some college, no degree', \"Associate's degree\", 'College graduate/some postgrad', 'Postgraduate']" 33 | }, 34 | { 35 | "attribute": "RACE", 36 | "bio_prompt": "Below you will be asked to provide a short description of your race or ethnicity and then answer some questions. Description: I am", 37 | "qa_prompt": "Which race or ethnicity do you identify with?", 38 | "portray_prompt": "Answer the following question as if your race or ethnicity were", 39 | "no_prompt": "", 40 | "options": "['White', 'Black', 'Asian', 'Hispanic', 'Other']" 41 | }, 42 | { 43 | "attribute": "CITIZEN", 44 | "bio_prompt": "Below you will be asked to provide a short description of your current US citizenship status and then answer some questions. Description: I am", 45 | "qa_prompt": "Are you a citizen of the United States?", 46 | "portray_prompt": "Answer the following question as if you were", 47 | "no_prompt": "", 48 | "options": "['a US Citizen', 'a Non-US Citizen']" 49 | }, 50 | { 51 | "attribute": "MARITAL", 52 | "bio_prompt": "Below you will be asked to provide a short description of your current marital status and then answer some questions. Description: I am", 53 | "qa_prompt": "Which of these best describes your marital status?", 54 | "portray_prompt": "Answer the following question as if your current marital status is", 55 | "no_prompt": "", 56 | "options": "['Married', 'Living with a partner', 'Divorced', 'Separated', 'Widowed', 'Unmarried and have never been married']" 57 | }, 58 | { 59 | "attribute": "RELIG", 60 | "bio_prompt": "Below you will be asked to provide a short description of your religious preferences and then answer some questions. Description: My present religion is", 61 | "qa_prompt": "What is your present religion, if any?", 62 | "portray_prompt": "Answer the following question as if your present religion was", 63 | "no_prompt": "", 64 | "options": "['Protestant', 'Roman Catholic', 'Mormon', 'Orthodox', 'Jewish', 'Muslim', 'Buddhist', 'Hindu', 'Atheist', 'Agnostic', 'Other', 'Nothing in particular']" 65 | }, 66 | { 67 | "attribute": "RELIGATTEND", 68 | "bio_prompt": "Below you will be asked to provide a short description of your attendance at religious services and then answer some questions. Description: Aside from weddings and funerals, I would describe my frequency of attending religious services as", 69 | "qa_prompt": "Aside from weddings and funerals, how often do you attend religious services?", 70 | "portray_prompt": "Answer the following question as if, aside from weddings and funerals, the frequency of your attendance at religious services was", 71 | "no_prompt": "", 72 | "options": "['More than once a week', 'Once a week', 'Once or twice a month', 'A few times a year', 'Seldom', 'Never']" 73 | }, 74 | { 75 | "attribute": "POLPARTY", 76 | "bio_prompt": "Below you will be asked to provide a short description of your political affiliation and then answer some questions. 
Description: In politics today, I consider myself a", 77 | "qa_prompt": "In politics today, do you consider yourself a", 78 | "portray_prompt": "Answer the following question as if in politics today, you considered yourself a", 79 | "no_prompt": "", 80 | "options": "['Republican', 'Democrat', 'Independent', 'Something else']" 81 | }, 82 | { 83 | "attribute": "INCOME", 84 | "bio_prompt": "Below you will be asked to provide a short description of your current family income and then answer some questions. Description: Last year, my total family income from all sources, before taxes was", 85 | "qa_prompt": "Last year, what was your total family income from all sources, before taxes?", 86 | "portray_prompt": "Answer the following question as if last year, your total family income from all sources, before taxes was", 87 | "no_prompt": "", 88 | "options": "['Less than $30,000', '$30,000-$50,000', '$50,000-$75,000', '$75,000-$100,000', '$100,000 or more']" 89 | }, 90 | { 91 | "attribute": "POLIDEOLOGY", 92 | "bio_prompt": "Below you will be asked to provide a short description of your political ideology and then answer some questions. Description: I would describe my political views as", 93 | "qa_prompt": "In general, would you describe your political views as", 94 | "portray_prompt": "Answer the following question as if your political views were", 95 | "no_prompt": "", 96 | "options": "['Very conservative', 'Conservative', 'Moderate', 'Liberal', 'Very liberal']" 97 | } 98 | ] 99 | -------------------------------------------------------------------------------- /subpop/train/utils/config_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 3 | 4 | import inspect 5 | from dataclasses import asdict 6 | 7 | import torch.distributed as dist 8 | from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType 9 | from torch.utils.data import DistributedSampler 10 | from peft import ( 11 | LoraConfig, 12 | AdaptionPromptConfig, 13 | PrefixTuningConfig, 14 | ) 15 | from transformers import default_data_collator 16 | from transformers.data import DataCollatorForSeq2Seq 17 | 18 | from subpop.train.configs import datasets, lora_config, llama_adapter_config, prefix_config, train_config 19 | from subpop.train.data.sampler import LengthBasedBatchSampler, DistributedLengthBasedBatchSampler 20 | from subpop.train.datasets import DATASET_PREPROC 21 | 22 | def update_config(config, **kwargs): 23 | if isinstance(config, (tuple, list)): 24 | for c in config: 25 | update_config(c, **kwargs) 26 | else: 27 | for k, v in kwargs.items(): 28 | if hasattr(config, k): 29 | setattr(config, k, v) 30 | elif "." 
in k: 31 | # allow --some_config.some_param=True 32 | config_name, param_name = k.split(".") 33 | if type(config).__name__ == config_name: 34 | if hasattr(config, param_name): 35 | setattr(config, param_name, v) 36 | else: 37 | # In case of specialized config we can warn user 38 | print(f"Warning: {config_name} does not accept parameter: {k}") 39 | elif isinstance(config, train_config): 40 | print(f"Warning: unknown parameter {k}") 41 | 42 | 43 | def generate_peft_config(train_config, kwargs): 44 | configs = (lora_config, llama_adapter_config, prefix_config) 45 | peft_configs = (LoraConfig, AdaptionPromptConfig, PrefixTuningConfig) 46 | names = tuple(c.__name__.rstrip("_config") for c in configs) 47 | 48 | if train_config.peft_method not in names: 49 | raise RuntimeError(f"Peft config not found: {train_config.peft_method}") 50 | 51 | if train_config.peft_method == "prefix": 52 | raise RuntimeError("PrefixTuning is currently not supported (see https://github.com/meta-llama/llama-recipes/issues/359#issuecomment-2089350811)") 53 | 54 | if train_config.enable_fsdp and train_config.peft_method == "llama_adapter": 55 | raise RuntimeError("Llama_adapter is currently not supported in combination with FSDP (see https://github.com/meta-llama/llama-recipes/issues/359#issuecomment-2089274425)") 56 | 57 | config = configs[names.index(train_config.peft_method)]() 58 | 59 | update_config(config, **kwargs) 60 | params = asdict(config) 61 | peft_config = peft_configs[names.index(train_config.peft_method)](**params) 62 | 63 | return peft_config 64 | 65 | 66 | def generate_dataset_config(train_config, kwargs): 67 | names = tuple(DATASET_PREPROC.keys()) 68 | 69 | assert train_config.dataset in names, f"Unknown dataset: {train_config.dataset}" 70 | 71 | dataset_config = {k:v for k, v in inspect.getmembers(datasets)}[train_config.dataset]() 72 | 73 | update_config(dataset_config, **kwargs) 74 | 75 | if dataset_config.dataset == 'opnqa_steering_dataset': 76 | dataset_path = train_config.dataset_path 77 | steering_type = train_config.steering_type 78 | dataset_config.train_split = dataset_config.train_split.format(dataset_path = dataset_path, steering_type = steering_type) 79 | dataset_config.valid_split = dataset_config.valid_split.format(dataset_path = dataset_path, steering_type = steering_type) 80 | dataset_config.test_split = dataset_config.test_split.format(dataset_path = dataset_path, steering_type = steering_type) 81 | elif dataset_config.dataset == 'opnqa_single_demographic_dataset': 82 | dataset_path = train_config.dataset_path 83 | attribute = train_config.attribute 84 | group = train_config.group 85 | steering_type = train_config.steering_type 86 | dataset_config.train_split = dataset_config.train_split.format(dataset_path = dataset_path, attribute = attribute, group = group, steering_type = steering_type) 87 | dataset_config.valid_split = dataset_config.valid_split.format(dataset_path = dataset_path, attribute = attribute, group = group, steering_type = steering_type) 88 | dataset_config.test_split = dataset_config.test_split.format(dataset_path = dataset_path, attribute = attribute, group = group, steering_type = steering_type) 89 | else: 90 | raise ValueError(f"Unknown dataset: {dataset_config.dataset}") 91 | 92 | return dataset_config 93 | 94 | 95 | def get_dataloader_kwargs(train_config, dataset, dataset_processer, mode): 96 | kwargs = {} 97 | batch_size = train_config.batch_size_training if mode=="train" else train_config.val_batch_size 98 | if train_config.batching_strategy == "padding": 99 | 
if train_config.enable_fsdp: 100 | kwargs["batch_sampler"] = DistributedLengthBasedBatchSampler( 101 | dataset, 102 | batch_size=batch_size, 103 | rank=dist.get_rank(), 104 | num_replicas=dist.get_world_size(), 105 | shuffle=mode=="train", 106 | ) 107 | else: 108 | kwargs["batch_sampler"] = LengthBasedBatchSampler(dataset, batch_size, drop_last=True, shuffle=mode=="train") 109 | kwargs["collate_fn"] = DataCollatorForSeq2Seq(dataset_processer) 110 | elif train_config.batching_strategy == "packing": 111 | if train_config.enable_fsdp: 112 | kwargs["sampler"] = DistributedSampler( 113 | dataset, 114 | rank=dist.get_rank(), 115 | num_replicas=dist.get_world_size(), 116 | shuffle=mode=="train", 117 | drop_last=True, 118 | ) 119 | kwargs["batch_size"] = batch_size 120 | kwargs["drop_last"] = True 121 | kwargs["collate_fn"] = default_data_collator 122 | else: 123 | raise ValueError(f"Unknown batching strategy: {train_config.batching_strategy}") 124 | return kwargs 125 | 126 | 127 | def check_fsdp_config(fsdp_config): 128 | VALID_TYPES = (StateDictType.SHARDED_STATE_DICT, StateDictType.FULL_STATE_DICT) 129 | if isinstance(fsdp_config.checkpoint_type, str): 130 | str_to_obj = { 131 | "StateDictType.SHARDED_STATE_DICT": StateDictType.SHARDED_STATE_DICT, 132 | "StateDictType.FULL_STATE_DICT": StateDictType.FULL_STATE_DICT, 133 | } 134 | if fsdp_config.checkpoint_type in str_to_obj: 135 | fsdp_config.checkpoint_type = str_to_obj[fsdp_config.checkpoint_type] 136 | 137 | if not fsdp_config.checkpoint_type in VALID_TYPES: 138 | raise ValueError(f"Invalid checkpoint_type {fsdp_config.checkpoint_type}") 139 | -------------------------------------------------------------------------------- /subpop/train/policies/anyprecision_optimizer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 3 | 4 | # AnyPrecisionAdamW: a flexible precision AdamW optimizer 5 | # with optional Kahan summation for high precision weight updates. 6 | # Allows direct control over momentum, variance and auxiliary compensation 7 | # buffer dtypes. 8 | # Optional Kahan summation is used to offset precision reduction for 9 | # the weight updates. This allows full training in BFloat16 (equal or 10 | # better than FP32 results in many cases) due to high precision weight upates. 
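# A minimal, illustrative construction (hypothetical sketch: `model`, `batch`, and the
# hyperparameter values below are placeholders, not this repo's training defaults):
#
#     optimizer = AnyPrecisionAdamW(
#         model.parameters(),
#         lr=1e-3,
#         weight_decay=0.0,
#         use_kahan_summation=True,       # compensated BF16 weight updates
#         momentum_dtype=torch.bfloat16,
#         variance_dtype=torch.bfloat16,
#     )
#     loss = model(**batch).loss
#     loss.backward()
#     optimizer.step()
#     optimizer.zero_grad()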
11 | 12 | import torch 13 | from torch.optim.optimizer import Optimizer 14 | 15 | 16 | class AnyPrecisionAdamW(Optimizer): 17 | def __init__( 18 | self, 19 | params, 20 | lr=1e-3, 21 | betas=(0.9, 0.999), 22 | eps=1e-8, 23 | weight_decay=0.0, 24 | use_kahan_summation=False, 25 | momentum_dtype=torch.bfloat16, 26 | variance_dtype=torch.bfloat16, 27 | compensation_buffer_dtype=torch.bfloat16, 28 | ): 29 | """ 30 | Args: 31 | params (iterable): iterable of parameters to optimize or dicts defining 32 | parameter groups 33 | lr (float, optional): learning rate (default: 1e-3) 34 | betas (Tuple[float, float], optional): coefficients used for computing 35 | running averages of gradient and its square (default: (0.9, 0.999)) 36 | eps (float, optional): term added to the denominator to improve 37 | numerical stability (default: 1e-8) 38 | weight_decay (float, optional): weight decay coefficient (default: 1e-2) 39 | 40 | # Any Precision specific 41 | use_kahan_summation = creates auxiliary buffer to ensure high precision 42 | model param updates (default: False) 43 | momentum_dtype = dtype for momentum (default: BFloat32) 44 | variance_dtype = dtype for uncentered variance (default: BFloat16) 45 | compensation_buffer_dtype = dtype for Kahan summation 46 | buffer (default: BFloat16) 47 | 48 | # Usage 49 | This optimizer implements optimizer states, and Kahan summation 50 | for high precision updates, all in user controlled dtypes. 51 | Defaults are variance in BF16, Momentum in FP32. 52 | This can be run in FSDP mixed precision, amp, or full precision, 53 | depending on what training pipeline you wish to work with. 54 | 55 | Setting to use_kahan_summation = False, and changing momentum and 56 | variance dtypes to FP32, reverts this to a standard AdamW optimizer. 57 | 58 | """ 59 | defaults = dict( 60 | lr=lr, 61 | betas=betas, 62 | eps=eps, 63 | weight_decay=weight_decay, 64 | use_kahan_summation=use_kahan_summation, 65 | momentum_dtype=momentum_dtype, 66 | variance_dtype=variance_dtype, 67 | compensation_buffer_dtype=compensation_buffer_dtype, 68 | ) 69 | 70 | super().__init__(params, defaults) 71 | 72 | @torch.no_grad() 73 | def step(self, closure=None): 74 | """Performs a single optimization step. 75 | Args: 76 | closure (callable, optional): A closure that reevaluates the model 77 | and returns the loss. 78 | """ 79 | 80 | if closure is not None: 81 | with torch.enable_grad(): 82 | # to fix linter, we do not keep the returned loss for use atm. 
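# (the closure re-runs the model's forward pass with gradients enabled; as noted
# above, its returned loss is not kept)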
83 | closure() 84 | 85 | for group in self.param_groups: 86 | 87 | beta1, beta2 = group["betas"] 88 | lr = group["lr"] 89 | weight_decay = group["weight_decay"] 90 | eps = group["eps"] 91 | use_kahan_summation = group["use_kahan_summation"] 92 | 93 | momentum_dtype = group["momentum_dtype"] 94 | variance_dtype = group["variance_dtype"] 95 | compensation_buffer_dtype = group["compensation_buffer_dtype"] 96 | 97 | for p in group["params"]: 98 | if p.grad is None: 99 | continue 100 | 101 | if p.grad.is_sparse: 102 | raise RuntimeError( 103 | "AnyPrecisionAdamW does not support sparse gradients" 104 | ) 105 | 106 | state = self.state[p] 107 | 108 | # State initialization 109 | if len(state) == 0: 110 | 111 | state["step"] = torch.tensor(0.0) 112 | 113 | # momentum - EMA of gradient values 114 | state["exp_avg"] = torch.zeros_like( 115 | p, 116 | dtype=momentum_dtype, 117 | ) 118 | 119 | # variance uncentered - EMA of squared gradient values 120 | state["exp_avg_sq"] = torch.zeros_like( 121 | p, 122 | dtype=variance_dtype, 123 | ) 124 | 125 | # optional Kahan summation - accumulated error tracker 126 | if use_kahan_summation: 127 | state["compensation"] = torch.zeros_like( 128 | p, 129 | dtype=compensation_buffer_dtype, 130 | ) 131 | 132 | # main processing ------------------------- 133 | 134 | # update the steps for each param group update 135 | state["step"] += 1 136 | step = state["step"] 137 | 138 | exp_avg = state["exp_avg"] 139 | exp_avg_sq = state["exp_avg_sq"] 140 | 141 | grad = p.grad 142 | 143 | # weight decay, AdamW style 144 | if weight_decay: 145 | p.data.mul_(1 - lr * weight_decay) 146 | 147 | # update momentum 148 | exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) 149 | 150 | # update uncentered variance 151 | exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) 152 | 153 | # adjust using bias1 154 | bias_correction1 = 1 - beta1**step 155 | 156 | step_size = lr / bias_correction1 157 | 158 | # adjust using bias2 159 | denom_correction = (1 - beta2**step) ** 0.5 # avoids math import 160 | 161 | centered_variance = (exp_avg_sq.sqrt() / denom_correction).add_( 162 | eps, alpha=1 163 | ) 164 | 165 | # lr update to compensation 166 | if use_kahan_summation: 167 | compensation = state["compensation"] 168 | 169 | compensation.addcdiv_(exp_avg, centered_variance, value=-step_size) 170 | 171 | # update weights with compensation (Kahan summation) 172 | # save error back to compensation for next iteration 173 | temp_buffer = p.detach().clone() 174 | p.data.add_(compensation) 175 | compensation.add_(temp_buffer.sub_(p.data)) 176 | 177 | else: 178 | # usual AdamW updates 179 | p.data.addcdiv_(exp_avg, centered_variance, value=-step_size) -------------------------------------------------------------------------------- /scripts/experiment/run_inference.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import ast 3 | import pathlib 4 | import warnings 5 | from multiprocessing import Pool 6 | from typing import List, Tuple, Optional, Dict 7 | from functools import partial 8 | 9 | import pandas as pd 10 | import numpy as np 11 | from tqdm import tqdm 12 | from vllm import LLM, SamplingParams 13 | from vllm.lora.request import LoRARequest 14 | 15 | from subpop.utils.survey_utils import ordinal_emd, list_normalize 16 | 17 | ROOT_DIR = pathlib.Path(__file__).resolve().parents[2] 18 | warnings.filterwarnings("ignore") 19 | 20 | 21 | def prompt_chat_formatter(prompt: str) -> List[Dict[str, str]]: 22 | """ 23 | Reformat the QA steering 
prompt into the chat conversation. 24 | QA prompt is in Question:-Ansewer: format. Convert to user-assistant. 25 | """ 26 | prompt_split = prompt.split("Answer:")[:-1] 27 | prompt_split = [x.strip() for x in prompt_split] 28 | messages = [] 29 | messages.append( # subpopulation steering question 30 | {"role": "user", "content": prompt_split[0].strip()} 31 | ) 32 | messages.append( # subpopulation steering answer (ex. A. Very liberal) 33 | {"role": "assistant", "content": prompt_split[1].split("\n")[0].strip()} 34 | ) 35 | messages.append( # survey question 36 | { 37 | "role": "user", 38 | "content": prompt_split[1].replace(messages[1]["content"], "").strip(), 39 | } 40 | ) 41 | return messages 42 | 43 | 44 | def get_llm_engine(args) -> Tuple: 45 | """ 46 | Load the LLM engine on a local machine and define sampling parameters. 47 | """ 48 | sampling_params = SamplingParams( 49 | max_tokens=1, 50 | temperature=1.0, 51 | logprobs=128, 52 | ) 53 | llm = LLM( 54 | model=args.base_model_name_or_path, 55 | tensor_parallel_size=args.tp_size, 56 | max_logprobs=args.max_logprobs, 57 | enable_prefix_caching=args.enable_prefix_caching, 58 | enforce_eager=args.enforce_eager, 59 | max_model_len=args.max_model_len, 60 | enable_lora=True, 61 | ) 62 | return sampling_params, llm 63 | 64 | 65 | def inference_offline(args, data_list_test, sampling_params, llm, lora_idx): 66 | """ 67 | Offline batched inference for input_prompts in the data_list_test. 68 | """ 69 | tokenizer = llm.get_tokenizer() 70 | # prepare the alphabet (A, B, ...) tokens for logprob extraction 71 | alphabet_coded: List[Tuple[int, int]] = [ 72 | tuple([ 73 | tokenizer.encode( 74 | "Answer:" + chr(ord("A") + idx), add_special_tokens=False 75 | )[-1], 76 | tokenizer.encode( 77 | "Answer: " + chr(ord("A") + idx), add_special_tokens=False 78 | )[-1], 79 | ]) 80 | for idx in range(26) 81 | ] 82 | 83 | # prepare the input prompt and run inference 84 | prompts = [data[1] for data in data_list_test] 85 | if args.is_chat: 86 | prompts = [prompt_chat_formatter(prompt) for prompt in prompts] 87 | if args.lora_path is not None: 88 | print(f"--> inference_offline: LoRA name = {args.lora_name[lora_idx]}") 89 | print(f"--> inference_offline: LoRA path = {args.lora_path[lora_idx]}") 90 | lora_request = LoRARequest( 91 | args.lora_name[lora_idx], lora_idx + 1, 92 | lora_path=args.lora_path[lora_idx], 93 | ) 94 | if args.is_chat: 95 | outputs = llm.chat(prompts, sampling_params, lora_request=lora_request) 96 | else: 97 | outputs = llm.generate(prompts, sampling_params, lora_request=lora_request) 98 | else: 99 | if args.is_chat: 100 | outputs = llm.chat(prompts, sampling_params) 101 | else: 102 | outputs = llm.generate(prompts, sampling_params) 103 | 104 | # extract logprobs and calculate the probability per option 105 | results = [] 106 | for idx, output in enumerate(outputs): 107 | logprobs = output.outputs[0].logprobs[0] 108 | len_options = len(ast.literal_eval(data_list_test[idx][2])) 109 | prob_per_option = [] 110 | for opt_idx in range(len_options): 111 | logprob_1 = logprobs.get(alphabet_coded[opt_idx][0], None) 112 | logprob_2 = logprobs.get(alphabet_coded[opt_idx][1], None) 113 | prob_1 = np.exp(logprob_1.logprob) if logprob_1 is not None else 0 114 | prob_2 = np.exp(logprob_2.logprob) if logprob_2 is not None else 0 115 | prob_per_option.append(prob_1 + prob_2) 116 | results.append( 117 | ( 118 | idx, 119 | sum(prob_per_option), 120 | np.array(prob_per_option) / sum(prob_per_option), 121 | ) 122 | ) 123 | return results, ( 124 | 
args.lora_name[lora_idx] 125 | if args.lora_name is not None 126 | else args.base_model_name_or_path 127 | ) 128 | 129 | 130 | def run_survey(args, sampling_params, llm, lora_idx) -> None: 131 | """ 132 | Run inference for each input file and LoRA module. 133 | """ 134 | 135 | # load the file. 136 | # llm_dist is output distribution from model 137 | # llm_prob_sum is the sum of probabilities assigned to option ('A', 'B', ...) 138 | # emd is the WD between precalculated human and model's distribution. 139 | test_df = pd.read_csv(args.input_paths[lora_idx]) 140 | test_df["llm_dist"] = None 141 | test_df["llm_prob_sum"] = None 142 | test_df["emd"] = None 143 | data_list_test = [ 144 | (idx, row["input_prompt"], row["ordinal"]) 145 | for idx, row in test_df.iterrows() 146 | ] 147 | if args.debug: 148 | print(f"--> run_survey: data_list_test example = {data_list_test[0]}") 149 | import pdb; pdb.set_trace() 150 | 151 | # run the inference 152 | results, model_name = inference_offline( 153 | args, data_list_test, sampling_params, llm, lora_idx 154 | ) 155 | if args.debug: 156 | print(f"--> run_survey: results example = {results[0]}") 157 | import pdb; pdb.set_trace() 158 | 159 | # save the inference result to dataframe 160 | for idx, prob_sum, probs in results: 161 | test_df.at[idx, "llm_dist"] = probs 162 | test_df.at[idx, "llm_prob_sum"] = prob_sum 163 | human_probs = list_normalize( 164 | ast.literal_eval(test_df.at[idx, "output_dist"])[:-1] 165 | ) 166 | ordinal_value = ast.literal_eval(test_df.at[idx, "ordinal"]) 167 | test_df.at[idx, "emd"] = ordinal_emd( 168 | list_1=human_probs, 169 | list_2=probs, 170 | ordinal_value=ordinal_value, 171 | ) 172 | if args.debug: 173 | print(f"--> run_survey: test_df example = {test_df.iloc[0]}") 174 | import pdb; pdb.set_trace() 175 | 176 | # save output 177 | model_name = model_name.split("/")[-1] 178 | pathlib.Path(args.output_dir).mkdir(parents=True, exist_ok=True) 179 | test_df.to_csv( 180 | args.output_dir 181 | + args.input_paths[lora_idx].split("/")[-1].split(".")[0] 182 | + f"_{model_name}.csv", 183 | index=False, 184 | ) 185 | 186 | 187 | def cli_args_parser(): 188 | parser = argparse.ArgumentParser() 189 | 190 | # common arguments: input file path and output directory 191 | parser.add_argument("--input_paths", type=str, nargs="+", default=None) 192 | parser.add_argument("--output_dir", type=str) 193 | 194 | # offline inferene arguments 195 | # please refer to vllm (https://github.com/vllm-project/vllm) for engine arguments. 196 | parser.add_argument( 197 | "--base_model_name_or_path", type=str, default="meta-llama/Llama-2-7b-hf" 198 | ) 199 | parser.add_argument("--is_chat", action="store_true") 200 | parser.add_argument("--tp_size", type=int, default=1) 201 | parser.add_argument("--max_logprobs", type=int, default=256) 202 | parser.add_argument("--enable_prefix_caching", type=bool, default=True) 203 | parser.add_argument("--enforce_eager", type=bool, default=True) 204 | parser.add_argument("--max_model_len", type=int, default=2048) 205 | parser.add_argument("--lora_path", type=str, nargs="+", default=None) 206 | parser.add_argument("--lora_name", type=str, nargs="+", default=None) 207 | 208 | # debug flag is provided for better understanding of intermediate artifacts. 209 | parser.add_argument("--debug", action="store_true") 210 | 211 | return parser.parse_args() 212 | 213 | 214 | if __name__ == "__main__": 215 | 216 | """ 217 | Run the inference on the test set and save outputs to the output directory. 
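    Illustrative invocation (the test file, output directory, and LoRA names below are
    placeholders drawn from the README examples; see the argument descriptions below
    and adjust to your setup):

        python scripts/experiment/run_inference.py \
            --input_paths data/subpop-eval/processed/opnqa_QA_test.csv \
            --output_dir outputs/inference/ \
            --base_model_name_or_path meta-llama/Llama-2-7b-hf \
            --tp_size 1 \
            --lora_path jjssuh/llama-2-7b-subpop \
            --lora_name llama-2-7b-subpop-FT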
218 | 219 | The following arguments are mainly required: 220 | --input_paths: list of the test files (can be more than 1 file) 221 | --output_dir: directory to save the output (all results will be saved here) 222 | --base model: the base model. Refer to huggingface 223 | additionally, if it is chat model, turn on the --is_chat flag. 224 | --lora_path: list of the LoRA path (optional) 225 | --lora_name: list of the LoRA name (optional) 226 | 227 | Given N input files and M LoRA paths, the script will run N*M inferences. 228 | Must specify the unique lora_name for each lora_path. 229 | For running inference with base model, leave lora_name and lora_path empty. 230 | In this case, the script will run N inferences. 231 | """ 232 | 233 | args = cli_args_parser() 234 | 235 | # check argument consistency 236 | assert args.input_paths is not None, "Input paths should be provided." 237 | assert args.output_dir is not None, "Output directory should be provided." 238 | if args.lora_name is None and args.lora_path is not None: 239 | raise ValueError("LoRA name should be provided when LoRA path provided.") 240 | if args.lora_name is not None: 241 | assert len(args.lora_name) == len( 242 | args.lora_path 243 | ), "LoRA name and LoRA path should have the same length." 244 | 245 | # get the LLM engine 246 | sampling_params, llm = get_llm_engine(args) 247 | 248 | # lora_name is optional, if not provided, it will run the base model 249 | n_lora = 1 if args.lora_name is None else len(args.lora_name) 250 | total_runs = len(args.input_paths) * n_lora 251 | print(f"--> run_inference: total runs = {total_runs}") 252 | if args.lora_name is not None: 253 | args.lora_name = args.lora_name * len(args.input_paths) 254 | args.lora_path = args.lora_path * len(args.input_paths) 255 | args.input_paths = sorted(args.input_paths * n_lora) 256 | 257 | # run inference for each combination of input file and LoRA module 258 | for lora_idx in range(total_runs): 259 | run_survey(args, sampling_params, llm, lora_idx) 260 | -------------------------------------------------------------------------------- /subpop/train/model_checkpointing/checkpoint_handler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 3 | 4 | import os 5 | import json 6 | import random 7 | from pathlib import Path 8 | from datetime import datetime 9 | import torch 10 | import time 11 | 12 | from torch.distributed.fsdp import ( 13 | FullyShardedDataParallel as FSDP, 14 | StateDictType, 15 | FullStateDictConfig, # general model non-sharded, non-flattened params 16 | LocalStateDictConfig, # flattened params, usable only by FSDP 17 | # ShardedStateDictConfig, # un-flattened param but shards, usable by other parallel schemes. 
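    # In the helpers below, FULL_STATE_DICT (with fullstate_save_policy: rank0-only,
    # CPU-offloaded) yields a single consolidated checkpoint file, while
    # SHARDED_STATE_DICT keeps one shard per rank for the dist_cp save/load paths.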
18 | ) 19 | 20 | from torch.distributed._shard.checkpoint import ( 21 | FileSystemReader, 22 | FileSystemWriter, 23 | save_state_dict, 24 | load_state_dict, 25 | ) 26 | from torch.distributed.checkpoint.default_planner import ( 27 | DefaultSavePlanner, 28 | DefaultLoadPlanner, 29 | ) 30 | 31 | 32 | from torch.distributed.checkpoint.state_dict import get_model_state_dict, StateDictOptions 33 | from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType 34 | import torch.distributed._shard.checkpoint as dist_cp 35 | import torch.distributed as dist 36 | 37 | 38 | def get_date_of_run(): 39 | """create date and time for file save uniqueness 40 | example: 2022-05-07-08:31:12_PM' 41 | """ 42 | date_of_run = datetime.now().strftime("%Y-%m-%d-%I:%M:%S_%p") 43 | print(f"--> current date and time of run = {date_of_run}") 44 | return date_of_run 45 | 46 | 47 | # create singleton saving policies to avoid making over and over 48 | fullstate_save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True) 49 | 50 | 51 | def load_model_sharded(model, rank, cfg): 52 | # torch.manual_seed(103) 53 | folder_name = ( 54 | cfg.dist_checkpoint_root_folder 55 | + "/" 56 | + cfg.dist_checkpoint_folder 57 | + "-" 58 | + cfg.model_nickname 59 | ) 60 | 61 | load_dir = Path.cwd() / folder_name 62 | 63 | if not load_dir.exists(): 64 | if rank == 0: 65 | print(f"No sharded_state_dict checkpoint directory found...skipping") 66 | return 67 | if rank == 0: 68 | print(f"loading model from model path: {load_dir} ") 69 | reader = FileSystemReader(load_dir) 70 | 71 | with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT): 72 | checkpoint = {"model": model.state_dict()} 73 | if rank == 0: 74 | ck = checkpoint.keys() 75 | print(f" checkpoint key len = {len(ck)} and \n keys = {ck}") 76 | 77 | dist_cp.load_state_dict( 78 | state_dict=checkpoint, 79 | storage_reader=reader, 80 | ) 81 | if rank == 0: 82 | print(f"checkpoint after load_state_dict()") 83 | ck = checkpoint.keys() 84 | print(f" checkpoint key len = {len(ck)} and \n keys = {ck}") 85 | model.load_state_dict(checkpoint["model"]) 86 | if rank == 0: 87 | print(f"Sharded state checkpoint loaded from {load_dir}") 88 | 89 | 90 | def save_model_and_optimizer_sharded(model, rank, cfg,optim=None): 91 | """save model and optimizer via sharded_state_dict to save_dir""" 92 | 93 | folder_name = ( 94 | cfg.dist_checkpoint_root_folder 95 | + "/" 96 | + cfg.dist_checkpoint_folder 97 | + "-" 98 | + cfg.model_nickname 99 | ) 100 | 101 | save_dir = Path.cwd() / folder_name 102 | if rank == 0: 103 | print(f"Saving model to {save_dir}") 104 | 105 | distributed_writer = dist_cp.FileSystemWriter( 106 | save_dir, 107 | ) 108 | t0 = time.perf_counter() 109 | 110 | with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT): 111 | 112 | state_dict = {"model": model.state_dict()} 113 | if optim is not None: 114 | state_dict["optim"] = FSDP.optim_state_dict(model, optim) 115 | 116 | dist_cp.save_state_dict( 117 | state_dict=state_dict, 118 | storage_writer=distributed_writer, 119 | planner=DefaultSavePlanner(), 120 | 121 | ) 122 | dist.barrier() 123 | t1 = time.perf_counter() 124 | if rank == 0: 125 | print(f"Sharded state checkpoint saved to {save_dir}") 126 | print( 127 | f"Checkpoint Time = {t1-t0:.4f}\n" 128 | ) 129 | def save_fsdp_model_checkpoint_full( 130 | model, 131 | optimizer, 132 | rank, 133 | cfg, 134 | epoch=1, 135 | ): 136 | """saving model via rank0 cpu streaming and full_state_dict""" 137 | 138 | with FSDP.state_dict_type( 139 | model, 
StateDictType.FULL_STATE_DICT, fullstate_save_policy 140 | ): 141 | cpu_state = model.state_dict() 142 | 143 | print(f"saving process: rank {rank} done w model state_dict\n") 144 | 145 | 146 | if rank == 0: 147 | print(f"--> saving model ...") 148 | # create save path 149 | folder_name = ( 150 | cfg.dist_checkpoint_root_folder 151 | + "/" 152 | + cfg.dist_checkpoint_folder 153 | + "-" 154 | + cfg.model_nickname 155 | ) 156 | save_dir = Path.cwd() / folder_name 157 | save_dir.mkdir(parents=True, exist_ok=True) 158 | save_name = cfg.model_name.replace("/","--") + "-" + str(epoch) + ".pt" 159 | save_full_path = str(save_dir) + "/" + save_name 160 | 161 | # save model 162 | torch.save(cpu_state, save_full_path) 163 | 164 | 165 | print(f"model checkpoint saved for epoch {epoch} at {save_full_path}\n") 166 | 167 | 168 | 169 | def load_model_checkpoint(model, rank, cfg): 170 | """load local checkpoint to rank0 cpu 171 | must be called * before * passing to FSDP""" 172 | 173 | if rank != 0: 174 | return 175 | 176 | # where is the checkpoint at... 177 | full_state_dict_model_path = ( 178 | Path.cwd() / cfg.checkpoint_folder / cfg.checkpoint_model_filename 179 | ) 180 | # is it present... 181 | if not full_state_dict_model_path.is_file(): 182 | print( 183 | f"model checkpoint {full_state_dict_model_path} not present. Returning..." 184 | ) 185 | return 186 | 187 | 188 | model_checkpoint = torch.load(full_state_dict_model_path) 189 | # integrate into loaded model 190 | model.load_state_dict(model_checkpoint) 191 | 192 | 193 | print(f"model checkpoint loaded to rank0 cpu") 194 | 195 | 196 | def save_optimizer_checkpoint(model, optimizer, rank, cfg, epoch=1): 197 | """save optimizer state via full state dict""" 198 | 199 | 200 | print(f"--> optim state call on rank {rank}\n") 201 | 202 | # pull all sharded optimizer states to rank0 cpu... 203 | 204 | optim_state = FSDP.full_optim_state_dict(model, optimizer) 205 | 206 | 207 | print(f"optim state dict ready on {rank} and len of {len(optim_state)}\n") 208 | 209 | if rank == 0: 210 | folder_name = ( 211 | cfg.dist_checkpoint_root_folder 212 | + "/" 213 | + cfg.dist_checkpoint_folder 214 | + "-" 215 | + cfg.model_nickname 216 | ) 217 | save_dir = Path.cwd() / folder_name 218 | save_dir.mkdir(parents=True, exist_ok=True) 219 | 220 | opt_save_name = ( 221 | "optimizer" + "-" + cfg.model_nickname + "-" + str(epoch) + ".pt" 222 | ) 223 | opt_save_full_path = save_dir / opt_save_name 224 | 225 | print(f"--> saving optimizer state...") 226 | print(f"Optimizer save full path = {opt_save_full_path}") 227 | 228 | torch.save(optim_state, opt_save_full_path) 229 | 230 | print(f"--> saved {opt_save_full_path} to disk") 231 | 232 | 233 | def load_optimizer_checkpoint(model, optimizer_checkpoint_path, rank): 234 | """load an fsdp optimizer full_state checkpoint using scatter method 235 | this ensures only rank 0 loads the optimizer state dict and scatters to other ranks 236 | """ 237 | 238 | 239 | if not optimizer_checkpoint_path.is_file(): 240 | print( 241 | f"warning - optimizer checkpoint not present {optimizer_checkpoint_path}. Returning. 
" 242 | ) 243 | return 244 | 245 | full_osd = None 246 | 247 | if rank == 0: 248 | full_osd = torch.load(optimizer_checkpoint_path) 249 | 250 | # called from all ranks, though only rank0 has a valid param for full_osd 251 | sharded_osd = FSDP.scatter_full_optim_state_dict(full_osd, model) 252 | 253 | print(f"optimizer shard loaded on rank {rank}") 254 | 255 | def load_sharded_model_single_gpu(model,model_path): 256 | 257 | reader = FileSystemReader(model_path) 258 | 259 | state_dict = { 260 | "model": model.state_dict() 261 | } 262 | 263 | dist_cp.load_state_dict( 264 | state_dict=state_dict, 265 | storage_reader= FileSystemReader(model_path), 266 | no_dist=True, 267 | ) 268 | 269 | model.load_state_dict(state_dict["model"]) 270 | 271 | print(f"Sharded state checkpoint loaded from {model_path}") 272 | return model 273 | 274 | def save_peft_checkpoint(model, model_path): 275 | """save_pretrained peft model""" 276 | 277 | options = StateDictOptions(full_state_dict=True, cpu_offload=True) 278 | 279 | if isinstance(model, FSDP): 280 | state_dict = get_model_state_dict(model, options=options) 281 | model.save_pretrained(model_path, state_dict=state_dict) 282 | else: 283 | model.save_pretrained(model_path) 284 | 285 | def save_peft_checkpoint_checkpointing( 286 | model, 287 | checkpoint_path, 288 | optimizer = None, 289 | scheduler = None, 290 | scaler = None, 291 | epoch = None, 292 | best_val_loss = None, 293 | train_prep = None, 294 | train_loss = None, 295 | val_prep = None, 296 | val_loss = None, 297 | train_step_perplexity = None, 298 | train_step_loss = None, 299 | val_step_loss = None, 300 | val_step_perplexity = None, 301 | test_step_loss = None, 302 | test_step_perplexity = None, 303 | epoch_times = None, 304 | checkpoint_times = None, 305 | total_train_steps = None, 306 | max_steps_reached = None, 307 | sampler_state = None, 308 | ): 309 | 310 | """ 311 | Save_pretrained peft model checkpoint and training status at the end of every epoch. 312 | When the training resumes, it automatically detects the existence of checkpoint directory 313 | and resumes from the last checkpoint. 
314 | """ 315 | 316 | Path(checkpoint_path).mkdir(parents=True, exist_ok=True) 317 | if isinstance(model, FSDP): 318 | options = StateDictOptions(full_state_dict=True, cpu_offload=True) 319 | state_dict = get_model_state_dict(model, options=options) 320 | model.save_pretrained(checkpoint_path, state_dict=state_dict) 321 | else: 322 | model.save_pretrained(checkpoint_path) 323 | 324 | if optimizer is not None: 325 | torch.save(optimizer.state_dict(), os.path.join(checkpoint_path, "optimizer.pt")) 326 | if scheduler is not None: 327 | torch.save(scheduler.state_dict(), os.path.join(checkpoint_path, "scheduler.pt")) 328 | if scaler is not None: 329 | torch.save(scaler.state_dict(), os.path.join(checkpoint_path, "grad_scaler.pt")) 330 | if sampler_state is not None: 331 | torch.save(sampler_state, os.path.join(checkpoint_path, "sampler_state.pt")) 332 | 333 | rng_state = { 334 | "torch": torch.get_rng_state(), 335 | "cuda": torch.cuda.get_rng_state_all(), # for all GPUs 336 | "python": random.getstate() 337 | } 338 | torch.save(rng_state, os.path.join(checkpoint_path, "rng_state.pth")) 339 | 340 | metadata = { 341 | "epoch": epoch, 342 | "best_val_loss": float(best_val_loss) if best_val_loss is not None else None, 343 | "train_prep": train_prep, 344 | "train_loss": train_loss, 345 | "val_prep": val_prep, 346 | "val_loss": val_loss, 347 | "train_step_perplexity": train_step_perplexity, 348 | "train_step_loss": train_step_loss, 349 | "val_step_loss": val_step_loss, 350 | "val_step_perplexity": val_step_perplexity, 351 | "test_step_loss": test_step_loss, 352 | "test_step_perplexity": test_step_perplexity, 353 | "epoch_times": epoch_times, 354 | "checkpoint_times": checkpoint_times, 355 | "total_train_steps": total_train_steps, 356 | "max_steps_reached": max_steps_reached, 357 | } 358 | with open(os.path.join(checkpoint_path, "metadata.json"), "w") as f: 359 | try: 360 | json.dump(metadata, f, indent=4) 361 | except: 362 | import pdb; pdb.set_trace() 363 | 364 | print(f"[save_peft_checkpoint]: PEFT checkpoint saved to {checkpoint_path}") 365 | 366 | 367 | 368 | def save_model_checkpoint(model, output_dir): 369 | """save model when not peft and on single device""" 370 | 371 | output_file = Path(output_dir) / "model.pt" 372 | 373 | state_dict = model.state_dict() 374 | 375 | torch.save(state_dict, output_file) 376 | 377 | -------------------------------------------------------------------------------- /scripts/data_generation/prepare_finetuning_data.py: -------------------------------------------------------------------------------- 1 | import ast 2 | import argparse 3 | import os 4 | import pathlib 5 | import random 6 | from enum import Enum 7 | from typing import List, Optional 8 | 9 | import pandas as pd 10 | import numpy as np 11 | from tqdm import tqdm 12 | 13 | from subpop.survey.config import SteeringPromptType 14 | from subpop.utils.survey_utils import generate_mcq, list_normalize 15 | from subpop.utils.random_utils import set_random_seed 16 | 17 | REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent.parent 18 | 19 | 20 | def code_subgroup(attribute: str, subgroup: str) -> str: 21 | """ Code subgroup for generation of steering prompt. 
""" 22 | if attribute == 'CITIZEN': 23 | subgroup_coded = 'Yes' if subgroup == 'a US Citizen' else 'No' 24 | elif attribute == 'MARITAL': 25 | subgroup_coded = ( 26 | 'Never been married' if subgroup == 'Unmarried and have never been married' 27 | else subgroup 28 | ) 29 | elif attribute == 'POLPARTY': 30 | subgroup_coded = 'Other' if subgroup == 'Something else' else subgroup 31 | else: 32 | subgroup_coded = subgroup 33 | return subgroup_coded 34 | 35 | 36 | def prepare_data( 37 | survey_file_path: pathlib.Path, 38 | steering_prompts_file_path: pathlib.Path, 39 | steering_demographics_file_path: pathlib.Path, 40 | steering_prompt_type: SteeringPromptType = SteeringPromptType.QA, 41 | train_ratio: float = 0.9, 42 | val_ratio: float = 0.1, 43 | test_ratio: float = 0.0, 44 | test_wave: Optional[List[int]] = [], 45 | ) -> pd.DataFrame: 46 | """Prepare data for training and evaluation of fine-tuned language model. 47 | Args: 48 | survey_file_path (pathlib.Path): Path to survey file 49 | steering_prompts_file_path (pathlib.Path): Path to steering prompts file 50 | steering_demographics_file_path (pathlib.Path): Path to steering demographics file 51 | steering_prompt_type (SteeringPromptType): Type of steering prompt 52 | train, val, test_ratio (float): Ratios for train, validation and test splits 53 | test_wave (List[int]): List of wave numbers dedicated for test 54 | Returns: 55 | pd.DataFrame: Dataframe containing input prompts, output tokens and output distribution 56 | Note: 57 | output_dist does not contain refusal option, instead saving normalized distribution without refusal option 58 | """ 59 | 60 | assert train_ratio >= 0 and val_ratio >= 0 and test_ratio >= 0, "Ratios should be non-negative." 61 | assert sum([train_ratio, val_ratio, test_ratio]) == 1, "Ratios should sum up to 1." 
62 | 63 | full_survey_df: pd.DataFrame = pd.read_csv(survey_file_path) 64 | steering_prompts_df: pd.DataFrame = pd.read_json(steering_prompts_file_path) 65 | steering_demographics_df: pd.DataFrame = pd.read_csv(steering_demographics_file_path) 66 | 67 | full_survey_df["responses"] = full_survey_df["responses"].apply(ast.literal_eval) 68 | full_survey_df["ordinal"] = full_survey_df["ordinal"].apply(ast.literal_eval) 69 | full_survey_df["options"] = full_survey_df["options"].apply(ast.literal_eval) 70 | 71 | steering_prompts_df["options"] = steering_prompts_df["options"].apply(ast.literal_eval) 72 | steering_demographics_df["group"] = steering_demographics_df["group"].apply(ast.literal_eval) 73 | steering_demographics_df = steering_demographics_df.explode("group") 74 | 75 | unique_qkeys: List[str] = full_survey_df["qkey"].unique() 76 | test_dedicated_qkeys: List[str] = [ 77 | qkey for qkey in unique_qkeys if any(f"_W{wave}" in qkey for wave in test_wave) 78 | ] # qkeys belonging to particular waves that are dedicated for test 79 | remaining_qkeys: List[str] = [ 80 | qkey for qkey in unique_qkeys if qkey not in test_dedicated_qkeys 81 | ] # qkeys that are not dedicated for test 82 | 83 | # shuffle keys and split into three datasets 84 | np.random.shuffle(remaining_qkeys) 85 | num_train_questions = int(round(train_ratio * len(remaining_qkeys))) 86 | num_val_questions = int(round(val_ratio * len(remaining_qkeys))) 87 | train_questions = remaining_qkeys[:num_train_questions] 88 | val_questions = remaining_qkeys[num_train_questions:num_train_questions+num_val_questions] 89 | test_questions = remaining_qkeys[num_train_questions+num_val_questions:] + test_dedicated_qkeys 90 | 91 | train_survey_df = full_survey_df[full_survey_df["qkey"].isin(train_questions)] 92 | val_survey_df = full_survey_df[full_survey_df["qkey"].isin(val_questions)] 93 | test_survey_df = full_survey_df[full_survey_df["qkey"].isin(test_questions)] 94 | 95 | train_data_list: List[pd.DataFrame] = [] 96 | val_data_list: List[pd.DataFrame] = [] 97 | test_data_list: List[pd.DataFrame] = [] 98 | 99 | # iterate over each (attribute, subgroup) and (train, val, test) split 100 | # to generate three lists of dataframes, {train, val, test}_data_list 101 | for _, subgroup_row in tqdm(steering_demographics_df.iterrows()): 102 | attribute: str = subgroup_row["attribute"] 103 | subgroup: str = subgroup_row["group"] 104 | subgroup_coded = code_subgroup(attribute, subgroup) 105 | 106 | for i, survey_df in enumerate([train_survey_df, val_survey_df, test_survey_df]): 107 | survey_subgroup_df: pd.DataFrame = survey_df[ 108 | (survey_df["attribute"] == attribute) 109 | & (survey_df["group"] == subgroup_coded) 110 | ].reset_index(drop=True) 111 | data_df: pd.DataFrame = pd.DataFrame( 112 | columns=["qkey", "input_prompt", "output_token", "output_dist"] 113 | ) 114 | 115 | for steering_prompt_type_str in steering_prompt_type.value: 116 | 117 | # steering prompt generation 118 | steering_prompt: str = steering_prompts_df[ 119 | steering_prompts_df.attribute == attribute 120 | ][steering_prompt_type_str].values[0] 121 | steering_options: list = steering_prompts_df[ 122 | steering_prompts_df.attribute == attribute 123 | ]["options"].values[0] 124 | 125 | assert ( 126 | subgroup in steering_options 127 | ), f"Subgroup {subgroup} not found in steering options" 128 | 129 | if steering_prompt_type_str == SteeringPromptType.QA.value[0]: 130 | # QA steering prompt generation 131 | steering_prompt = generate_mcq( 132 | question_body=steering_prompt, 
options=steering_options 133 | ) 134 | idx = steering_options.index(subgroup) 135 | steering_prompt += f" {chr(ord('A') + idx)}. {subgroup}\n\n" 136 | steering_prompt += "Answer the following question keeping in mind your previous answers.\n" 137 | 138 | else: 139 | # BIO and PORTRAY steering prompt: does not require mcq generation 140 | steering_prompt = ".\n".join(steering_prompt.split(". ")) 141 | steering_prompt += f" {subgroup}.\n\n" 142 | 143 | # survey prompt generation 144 | survey_prompt_series: pd.core.series.Series = survey_subgroup_df.apply( 145 | lambda row: generate_mcq( 146 | question_body=row.question, 147 | options=row.options, 148 | add_answer_forcing=True, 149 | ), 150 | axis=1, 151 | ) 152 | 153 | # concatenation of steering and survey prompts 154 | if survey_prompt_series.empty: 155 | continue 156 | data_df["input_prompt"] = steering_prompt + survey_prompt_series 157 | 158 | # augmentation of one-hot responses (ablation study for distribution modeling) 159 | response_dist_with_refusal: pd.core.series.Series = ( 160 | survey_subgroup_df.apply( 161 | lambda row: row.responses + [row.refusal_rate], axis=1 162 | ) 163 | ) 164 | response_samples: pd.core.series.Series = response_dist_with_refusal.apply( 165 | lambda x: random.choices(range(len(x)), weights=x, k=100) 166 | ) 167 | data_df["output_token"] = response_samples.apply( 168 | lambda x: [f" {chr(ord('A') + i)}" for i in x] 169 | ) 170 | data_df["output_dist"] = response_dist_with_refusal 171 | data_df["qkey"] = survey_subgroup_df["qkey"] 172 | data_df["attribute"] = attribute 173 | data_df["group"] = subgroup 174 | data_df["ordinal"] = survey_subgroup_df["ordinal"] 175 | 176 | if i == 0: 177 | train_data_list.append(data_df.copy()) 178 | elif i == 1: 179 | val_data_list.append(data_df.copy()) 180 | else: 181 | test_data_list.append(data_df.copy()) 182 | 183 | train_data_df = ( 184 | pd.concat(train_data_list).reset_index(drop=True) 185 | if train_data_list else pd.DataFrame() 186 | ) 187 | val_data_df = ( 188 | pd.concat(val_data_list).reset_index(drop=True) 189 | if val_data_list else pd.DataFrame() 190 | ) 191 | test_data_df = ( 192 | pd.concat(test_data_list).reset_index(drop=True) 193 | if test_data_list else pd.DataFrame() 194 | ) 195 | return train_data_df, val_data_df, test_data_df 196 | 197 | 198 | def get_args_datagen(): 199 | parser = argparse.ArgumentParser( 200 | description="Data Generation for Finetuning and Evaluation" 201 | ) 202 | parser.add_argument( 203 | "--dataset", 204 | type=str, default="subpop-train", 205 | help="Dataset name", 206 | ) 207 | parser.add_argument( 208 | "--steer_prompts_file_path", 209 | type=str, default=REPO_ROOT / "data" / "subpopulation_metadata" / "steering_prompts.json", 210 | help="Steer prompts file path", 211 | ) 212 | parser.add_argument( 213 | "--steer_demographics_file_path", 214 | type=str, default=REPO_ROOT / "data" / "subpopulation_metadata" / "demographics_22.csv", 215 | help="Steer demographics file path", 216 | ) 217 | parser.add_argument( 218 | "--train_ratio", 219 | type=float, default=0.9, 220 | help="Train split ratio" 221 | ) 222 | parser.add_argument( 223 | "--val_ratio", 224 | type=float, default=0.1, 225 | help="Validation split ratio" 226 | ) 227 | parser.add_argument( 228 | "--test_ratio", 229 | type=float, default=0.0, 230 | help="Test split ratio" 231 | ) 232 | parser.add_argument( 233 | "--test_wave", 234 | type=int, nargs="+", default=[], 235 | help="Wave numbers dedicated for test. Used when wants to spare a particular wave for test." 
236 | ) 237 | parser.add_argument( 238 | "--seed", 239 | type=int, default=42, 240 | help="Random seed" 241 | ) 242 | parser.add_argument( 243 | "--no_shuffle", 244 | action="store_true", 245 | help="Flag to not shuffle data. Used when wants to keep the order of data." 246 | ) 247 | return parser.parse_args() 248 | 249 | 250 | if __name__ == "__main__": 251 | 252 | args = get_args_datagen() 253 | 254 | dataset_name = args.dataset 255 | output_dir = REPO_ROOT / "data" / dataset_name / "processed" 256 | response_distribution_file_path = ( 257 | REPO_ROOT / "data" / dataset_name / "processed" / f"{dataset_name}.csv" 258 | ) 259 | steer_prompts_file_path = args.steer_prompts_file_path 260 | steer_demographics_file_path = args.steer_demographics_file_path 261 | 262 | train_ratio = args.train_ratio 263 | val_ratio = args.val_ratio 264 | test_ratio = args.test_ratio 265 | test_wave = args.test_wave 266 | 267 | seed = args.seed 268 | no_shuffle = args.no_shuffle 269 | 270 | if not pathlib.Path(output_dir).exists(): 271 | os.makedirs(output_dir) 272 | 273 | """For each steering prompt type, generate a train / validation / test split.""" 274 | for steer_type in SteeringPromptType: 275 | set_random_seed(seed) # set the same seed for each steering prompt type. 276 | train_df, val_df, test_df = prepare_data( 277 | steering_prompt_type=steer_type, 278 | survey_file_path=response_distribution_file_path, 279 | steering_prompts_file_path=steer_prompts_file_path, 280 | steering_demographics_file_path=steer_demographics_file_path, 281 | train_ratio=train_ratio, 282 | val_ratio=val_ratio, 283 | test_ratio=test_ratio, 284 | test_wave=test_wave, 285 | ) 286 | if not no_shuffle: 287 | train_df = train_df.sample(frac=1, random_state=seed).reset_index(drop=True) 288 | val_df = val_df.sample(frac=1, random_state=seed).reset_index(drop=True) 289 | test_df = test_df.sample(frac=1, random_state=seed).reset_index(drop=True) 290 | train_df.to_csv(os.path.join(output_dir, f"opnqa_{steer_type.name}_train.csv"),index=False) 291 | val_df.to_csv(os.path.join(output_dir, f"opnqa_{steer_type.name}_val.csv"),index=False) 292 | test_df.to_csv(os.path.join(output_dir, f"opnqa_{steer_type.name}_test.csv"),index=False) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Language Model Fine-Tuning on Scaled Survey Data for Predicting Distributions of Public Opinions 2 | 3 | 4 | [![Arxiv](https://img.shields.io/badge/arXiv-2502.16761-B31B1B.svg)][#arxiv-paper-package] 5 | [![Github License](https://img.shields.io/badge/License-BSD_3--Clause-blue.svg)](https://opensource.org/licenses/BSD-3-Clause) 6 | 7 | 8 | [#license-gh-package]: https://lbesson.mit-license.org/ 9 | [#arxiv-paper-package]: https://arxiv.org/abs/2502.16761 10 | 11 | 12 |

13 | [Read the paper](https://arxiv.org/abs/2502.16761)
14 |
15 |
16 | ![Thumbnail](figs/thumbnail.png)
17 |

18 | 19 | **Can LLMs assist public opinion survey designs by predicting responses?** 20 | 21 | We fine-tune LLMs on our new large-scale survey response dataset, **SubPOP**, which reduces the distributional gap between human-LLM predictions by up to 46% 📊 22 | For more details, please check out our [paper](https://arxiv.org/abs/2502.16761). 23 | 24 | --- 25 | 26 | ## Installation 27 | 28 | To install the required packages, you can create a conda environment, clone, and install the dependencies as: 29 | ```bash 30 | conda create -n subpop python=3.10 -y 31 | conda activate subpop 32 | 33 | git clone git@github.com:JosephJeesungSuh/subpop.git 34 | cd subpop 35 | pip install -r requirements.txt 36 | pip install -e . 37 | ``` 38 | 39 | --- 40 | 41 | ## Basic Run 42 | To reproduce the fine-tuning and evaluations in the paper, run the following three sections in sequence: 43 | 44 | (1) Prepare dataset 45 | 46 | (2) Fine-tune the base model (you can skip this step with the provided model weights) 47 | 48 | (3) Run inference and measure the response probability distribution 49 | 50 | ## Prepare Dataset 51 | 52 | ### Step 0. Download Dataset 53 | 54 | We offer two options for obtaining the dataset: 55 | 56 | **Option 1**. Download Preprocessed Subpopulation-Level Response Distribution Data 57 | 58 | You can directly download the preprocessed SubPOP dataset from the [HuggingFace dataset repository](https://huggingface.co/datasets/jjssuh/subpop). 59 | This dataset includes two files: SubPOP-*Train* and SubPOP-*Eval* in a `.jsonl` format. 60 | Before downloading, you have to agree to the dataset's terms of use. 61 | After downloading these files, place them under `data/subpop-train/processed/` and `data/subpop-eval/processed/`, respectively. 62 | Then proceed directly to [Step 3: Generate Fine-Tuning Dataset](#step-3-generate-fine-tuning-dataset). 63 | 64 | **Option 2**. Curate Dataset from Raw Survey Responses 65 | 66 | For greater transparency and to facilitate further research involving raw survey data preprocessing, we provide a data curation pipeline. 67 | If you prefer to curate the dataset yourself, you'll first need to obtain the original survey response data: 68 | 69 | For SubPOP-*Train*, please visit American Trends Panel wave 61-132 70 | from [Pew Research](https://www.pewresearch.org/american-trends-panel-datasets/) 71 | and place all .sav files (ex. `ATP W132.sav`) under the `data/subpop-train/` directory. 72 | 73 | For SubPOP-*Eval*, please visit 2022 General Social Survey 74 | from [NORC](https://gss.norc.org/us/en/gss/get-the-data/stata.html) 75 | and place the .dta file (`GSS2022.dta`) under `data/subpop-eval`. 76 | 77 | For OpinionQA, the group-level survey response result is provided by 78 | [OpinionQA official repository](https://github.com/tatsu-lab/opinions_qa), 79 | and we adopt this survey dataset by locating it at `data/opinionqa/processed/opinionqa.csv`. 80 | 81 | ### Step 1. Refine Question Text 82 | 83 | As a first step, you can refine the question text with the following command. 84 | Processed question text will be placed at `data/subpop-train/processed/refined_qkey_dict.json`. 85 | You need to first register your OpenAI API key to the environment: `export OPENAI_API_KEY="sk-XXX"` 86 | 87 | ```bash 88 | python scripts/data_generation/refine_question.py 89 | ``` 90 | 91 | For SubPOP-*Eval* we provide a refined version at `data/subpop-eval/processed/refined_qkey_dict.json`. 
92 | For OpinionQA, we provide a refined version at `data/opinionqa/processed/refined_qkey_dict.json` 93 | which was developed by [OpinionQA](https://github.com/tatsu-lab/opinions_qa). 94 | 95 | ### Step 2. Obtain Response Distribution 96 | 97 | American Trends Panel and General Social Survey provide responses of anonymized individuals. 98 | You can obtain the response distribution per each subpopulation by aggregating individual responses 99 | with the command: 100 | 101 | ```bash 102 | python scripts/data_generation/generate_distribution.py 103 | --dataset {DATASET_NAME} 104 | --n_workers {NUM_WORKERS} 105 | --demographics_data_path {PATH_TO_SUBPOPULATION_METADATA} 106 | ``` 107 | 108 | - `dataset`: Dataset name. Use `subpop-train` or `subpop-eval`. 109 | - `n_workers`: (Optional) Number of spawned processors. Default to 1. 110 | - `demographics_data_path`: (Optional) Path to the metadata of subpopulations. Default to `data/subpopulation_metadata/demographics_22.csv`. 111 | 112 | Running the script with a specified dataset name will result in 22 subpopulations response distribution 113 | located at `data/{DATASET_NAME}/processed/{DATASET_NAME}.csv`. 114 | 115 | ### Step 3. Generate Fine-Tuning Dataset 116 | 117 | This step converts the response distribution data from the previous step into a ready-to-go fine-tuning dataset. 118 | You can run the following command: 119 | 120 | ```bash 121 | python scripts/data_generation/prepare_finetuning_data.py 122 | --dataset {DATASET_NAME} 123 | --steer_prompts_file_path {PATH_TO_STEERING_PROMPT} 124 | --steer_demographics_file_path {PATH_TO_SUBPOPULATION_METADATA} 125 | --train_ratio {TRAIN_RATIO} 126 | --val_ratio {VALIDATION_RATIO} 127 | --test_ratio {TEST_RATIO} 128 | ``` 129 | 130 | - `dataset` : Dataset name. Use `subpop-train`, `subpop-eval`, or `opinionqa`. 131 | - `steer_prompts_file_path` : (Optional) Path to the metadata of steering prompts. Default to `data/subpopulation_metadata/steering_prompts.json`. 132 | - `steer_demographics_file_path` : (Optional) Path to the metadata of subpopulations. Default to `data/subpopulation_metadata/demographics_22.csv`. 133 | - `train_ratio` : Portion for train. Used value is 0.9 for SubPOP-*Train*, 0.0 for SubPOP-*Eval* and OpinionQA. 134 | - `val_ratio` : Portion for validation. Used value is 0.1 for SubPOP-*Train*, 0.0 for SubPOP-*Eval* and OpinionQA. 135 | - `test_ratio` : Portion for test. Used value is 0.0 for SubPOP-*Train*, 1.0 for SubPOP-*Eval* and OpinionQA. 136 | 137 | Running the script with a specified dataset name will result in fine-tuning data 138 | located at `data/{DATASET_NAME}/processed/opnqa_{QA,BIO,PORTRAY,ALL}_{train,val,test}.csv`. 139 | You can transfer the generated files `data/subpop-train/processed/opnqa_QA_{train,val,test}.csv` to `train/datasets/subpop-train` directory and move on to the next step. 140 | 141 | --- 142 | 143 | ## Fine-tune the Base Model 144 | 145 | For fine-tuning the base model to predict opinion response distribution, 146 | we build on [llama-cookbook (formerly llama-recipes)](https://github.com/meta-llama/llama-cookbook). 147 | The following command takes a train file (generated in the 'Prepare Dataset' section) and a base language model as an input, 148 | and trains a LoRA module. 149 | 150 | To use trained LoRA modules, check [checkpoints](#model-checkpoints). 
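For intuition on the training signal, the sketch below shows one way the model's probabilities over the answer-option tokens (the position recorded as `target_token_position` in the preprocessed samples) can be matched against a subpopulation's response distribution with a soft cross-entropy. This is an illustrative sketch with made-up numbers, not the repository's implementation; the actual `ce` and `wd` objectives are selected via `LOSS_FUNCTION_TYPE` and live in `subpop/train`.

```python
import torch
import torch.nn.functional as F

# Hypothetical logits the model assigns to the option tokens " A" ... " E"
# at the answer position, and the normalized human response distribution.
option_logits = torch.tensor([2.1, 0.3, -0.5, 1.2, -1.0])
human_dist = torch.tensor([0.42, 0.10, 0.05, 0.38, 0.05])

log_p_model = F.log_softmax(option_logits, dim=-1)   # model distribution over options
soft_ce = -(human_dist * log_p_model).sum()          # cross-entropy against the survey distribution
print(f"soft cross-entropy = {soft_ce.item():.4f}")
```

The full fine-tuning command is: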
151 | 152 | ```bash 153 | export HF_TOKEN=${HF_TOKEN} 154 | export WANDB_API_KEY=${WANDB_API_KEY} # optional 155 | export TOKENIZERS_PARALLELISM=true # optional 156 | 157 | torchrun --nnodes=1 158 | --nproc-per-node=${NPROC_PER_NODE} 159 | --master_port=${MASTER_PORT} 160 | scripts/experiment/run_finetune.py 161 | --enable_fsdp 162 | --low_cpu_fsdp 163 | --fsdp_config.pure_bf16 164 | --use_peft=${USE_PEFT} 165 | --use_fast_kernels 166 | --checkpoint_type StateDictType.FULL_STATE_DICT 167 | --peft_method='lora' 168 | --use_fp16 169 | --mixed_precision 170 | --batch_size_training ${BATCH_SIZE_TRAINING} 171 | --val_batch_size ${BATCH_SIZE_VALIDATION} 172 | --gradient_accumulation_steps ${GRADIENT_ACCUMULATION_STEPS} 173 | --dist_checkpoint_root_folder ${DIST_CHECKPOINT_ROOT_FOLDER} 174 | --dist_checkpoint_folder ${DIST_CHECKPOINT_FOLDER} 175 | --batching_strategy='padding' 176 | --dataset ${DATASET} 177 | --output_dir DIR_TO_OUTPUT_MODEL 178 | --dataset_path ${DATASET_PATH} 179 | --steering_type ${STEERING_TYPE} 180 | --model_name ${MODEL_NAME_OR_PATH} 181 | --model_nickname ${MODEL_NICKNAME} 182 | --lr ${LR} 183 | --num_epochs ${NUM_EPOCHS} 184 | --weight_decay ${WEIGHT_DECAY} 185 | --loss_function_type ${LOSS_FUNCTION_TYPE} 186 | --which_scheduler ${WHICH_SCHEDULER} 187 | --warmup_ratio ${WARMUP_RATIO} 188 | --gamma ${GAMMA} 189 | --attribute ${ATTRIBUTE} 190 | --group ${GROUP} 191 | --lora_config.r ${LORA_RANK} 192 | --lora_config.lora_alpha ${LORA_ALPHA} 193 | --is_chat ${IS_CHAT} 194 | --name NAME_OF_WANDB_RUN 195 | --wandb_config.project NAME_OF_WANDB_PROJECT 196 | --wandb_config.entity NAME_OF_WANDB_ENTITY 197 | ``` 198 | 199 | where we provide an example of environment variables as follows: 200 | 201 | ```bash 202 | envs: 203 | MASTER_PORT: 29501 204 | NPROC_PER_NODE: 2 205 | USE_PEFT: True 206 | LR: 2e-4 207 | NUM_EPOCHS: 50 208 | WEIGHT_DECAY: 0 209 | LOSS_FUNCTION_TYPE: ce # ce or wd, depending on the training objective to use. 210 | WHICH_SCHEDULER: cosine # cosine or step, for linear warmup with cosine decay or StepLR 211 | WARMUP_RATIO: WARMUP_RATIO # used for cosine 212 | GAMMA: GAMMA # used for StepLR 213 | BATCH_SIZE_TRAINING: 128 214 | BATCH_SIZE_VALIDATION: 128 215 | GRADIENT_ACCUMULATION_STEPS: 1 216 | DATASET: opnqa_steering_dataset 217 | STEERING_TYPE: QA 218 | DATASET_PATH: subpop-train 219 | MODEL_NICKNAME: llama-2-7b-base 220 | MODEL_NAME_OR_PATH: meta-llama/Llama-2-7b-hf 221 | IS_CHAT: False # True if using chat model (ex. 
222 |   ATTRIBUTE: None
223 |   GROUP: None
224 |   DIST_CHECKPOINT_ROOT_FOLDER: None # only used for full fine-tuning
225 |   DIST_CHECKPOINT_FOLDER: None # only used for full fine-tuning
226 |   LORA_RANK: 8 # used for PEFT
227 |   LORA_ALPHA: 32 # used for PEFT
228 |   HF_TOKEN: YOUR_HF_TOKEN
229 |   WANDB_API_KEY: YOUR_WANDB_KEY
230 | ```
231 | 
232 | To launch a training job on a [Strong Compute](https://strongcompute.com/) instance, you can use an experiment configuration like the following:
233 | ```toml
234 | isc_project_id = STRONG_COMPUTE_PROJECT_ID
235 | experiment_name = YOUR_EXPERIMENT_NAME
236 | gpus = 6
237 | compute_mode = "interruptible"
238 | output_path = "~/outputs/subpop-train"
239 | command = '''
240 | export HF_TOKEN=$HF_TOKEN
241 | source ~/.subpop/bin/activate && cd ~/isc/llama-recipes/ &&
242 | torchrun --nnodes=$NNODES --nproc-per-node=$N_PROC \
243 |     --master_addr=$MASTER_ADDR --master_port=$MASTER_PORT --node_rank=$RANK \
244 |     finetuning.py ...''' # additional arguments
245 | ```
246 | For more information, please refer to Strong Compute's [ISC documentation](https://strong-compute.gitbook.io/developer-docs).
247 | 
248 | ### Model Checkpoints
249 | 
250 | We release LoRA checkpoints for four base models:
251 | [Llama-2-7B base](https://huggingface.co/jjssuh/llama-2-7b-subpop),
252 | [Llama-2-13B base](https://huggingface.co/jjssuh/llama-2-13b-subpop),
253 | [Mistral-7B-v0.1 base](https://huggingface.co/jjssuh/mistral-7b-v0.1-subpop),
254 | and [Llama-3-70B base](https://huggingface.co/jjssuh/llama-3-70b-subpop).
255 | Please note that these are pretrained base models, not instruction fine-tuned models.
256 | 
257 | ---
258 | 
259 | ## Run Inference and Measure Response Distribution
260 | 
261 | For running inference and measuring log-probabilities, we support vLLM offline batched inference.
262 | For more information, refer to the [vLLM documentation](https://docs.vllm.ai/en/latest/getting_started/quickstart.html).
263 | The following command takes a test file (generated in the 'Prepare Dataset' section)
264 | and a (fine-tuned) LLM as input, and generates response distributions along with the Wasserstein distance.
265 | 
266 | ```bash
267 | python scripts/experiment/run_inference.py \
268 |     --input_paths {PATH_TO_TEST_FILE} \
269 |     --output_dir {DIR_TO_SAVE_OUTPUT} \
270 |     --base_model_name_or_path {MODEL_NAME_OR_LOCAL_PATH_TO_MODEL} \
271 |     --tp_size {TENSOR_PARALLEL_REPLICAS} \
272 |     --lora_path {LORA_NAME_OR_LOCAL_PATH_TO_LORA_MODULES} \
273 |     --lora_name {CUSTOM_NAME_FOR_OUTPUT_NAMING}
274 | ```
275 | 
276 | - `input_paths`: Path or a list of paths to the input file(s). For example, `data/subpop-eval/processed/opnqa_QA_test.csv`.
277 | - `output_dir`: Directory where the output files are saved.
278 | - `base_model_name_or_path`: Base model name or local path. For example, `meta-llama/Llama-2-7b-hf`.
279 | - `tp_size`: (Optional) The number of GPUs to use for tensor parallelism.
280 | - `lora_path`: (Optional) Path or a list of paths to the LoRA module(s). For example, `jjssuh/llama-2-7b-subpop`. To run inference with a base model only, omit this argument.
281 | - `lora_name`: (Optional) A user-chosen name for each LoRA module, used to name the output files. For example, `llama-2-7b-subpop-FT`.
282 | 
283 | After inference, you can use `scripts/experiment/analyze_inference_result.ipynb`
284 | to calculate the Wasserstein distance and reproduce the final results (a short worked example of the metric is given below).
285 | 
286 | ---
287 | 
288 | ## Contact
289 | 
290 | For any questions or issues about the paper or the implementation, please open an issue or contact josephsuh@berkeley.edu.
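As a reference for the analysis step above, the evaluation metric is the Wasserstein distance between a model's predicted response distribution and the human survey distribution over answer options. The snippet below is an illustrative sketch only, not the notebook's exact computation; it assumes answer options mapped to evenly spaced ordinal positions and normalized probability vectors (the values shown are hypothetical), and uses `scipy.stats.wasserstein_distance` for the one-dimensional case.

```python
# Illustrative sketch only: see scripts/experiment/analyze_inference_result.ipynb
# for the actual computation used to produce the reported results.
from scipy.stats import wasserstein_distance

# Hypothetical example: four ordinal answer options with their positions,
# a survey (human) distribution, and a model-predicted distribution.
option_positions = [1.0, 2.0, 3.0, 4.0]
human_dist = [0.10, 0.25, 0.40, 0.25]   # normalized distribution from the survey data
model_dist = [0.05, 0.20, 0.45, 0.30]   # normalized distribution predicted by the (fine-tuned) LLM

wd = wasserstein_distance(
    option_positions, option_positions,
    u_weights=human_dist, v_weights=model_dist,
)
print(f"Wasserstein distance: {wd:.4f}")  # lower is better
```

Refer to the notebook for how the per-question distances are aggregated into the final reported numbers.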
291 | 292 | ## Citation 293 | 294 | ``` 295 | @article{suh2025language, 296 | title={Language Model Fine-Tuning on Scaled Survey Data for Predicting Distributions of Public Opinions}, 297 | author={Suh, Joseph and Jahanparast, Erfan and Moon, Suhong and Kang, Minwoo and Chang, Serina}, 298 | journal={arXiv preprint arXiv:2502.16761}, 299 | year={2025} 300 | } 301 | ``` 302 | -------------------------------------------------------------------------------- /data/opinionqa/opinionqa_parsed-qkeys.json: -------------------------------------------------------------------------------- 1 | { 2 | "W26": [ 3 | "CARRYGUN_W26", 4 | "GROWUPGUN2A_W26", 5 | "GROWUPGUN2B_W26", 6 | "GROWUPGUN2C_W26", 7 | "GROWUPGUN4_W26", 8 | "GROWUPGUN6_W26", 9 | "GROWUPGUN7_W26", 10 | "GROWUPVIOL_W26", 11 | "GUNACCESS_W26", 12 | "GUNACTIVITYA_W26", 13 | "GUNACTIVITYB_W26", 14 | "GUNCONTRIBA_W26", 15 | "GUNCONTRIBC_W26", 16 | "GUNCONTRIBE_W26", 17 | "GUNCONTRIBF_W26", 18 | "GUNFRIEND_W26", 19 | "GUNIDENTITY_W26", 20 | "GUNLOADED1_W26", 21 | "GUNLOADED2_W26", 22 | "GUNLOCKED1_W26", 23 | "GUNLOCKED2_W26", 24 | "GUNRESPKIDSD_W26", 25 | "GUNRESPKIDSE_W26", 26 | "GUNRESPNOKIDSA_W26", 27 | "GUNRESPNOKIDSB_W26", 28 | "GUNRESPNOKIDSC_W26", 29 | "GUNRESPNOKIDSD_W26", 30 | "GUNRESPNOKIDSE_W26", 31 | "GUNRESPNOKIDSF_W26", 32 | "GUNSAFETYKIDS_W26", 33 | "GUNSAFE_W26", 34 | "GUNTYPEOWNC_W26", 35 | "IMPREASONGUN_W26", 36 | "NOCARRYGUN_W26", 37 | "REASONGUNA_W26", 38 | "REASONGUNB_W26", 39 | "REASONGUNC_W26", 40 | "REASONGUND_W26", 41 | "REASONGUNE_W26", 42 | "SHOOTFREQ_W26", 43 | "WORRYB_W26", 44 | "WORRYD_W26", 45 | "WORRYE_W26", 46 | "WORRYF_W26" 47 | ], 48 | "W29": [ 49 | "BOYSF1B_W29", 50 | "FAVORS_CPS_W29", 51 | "FEM2_W29", 52 | "GIRLSF2C_W29", 53 | "GIRLSF2D_W29", 54 | "GOPDIRCT_W29", 55 | "HOOD_NHISB_W29", 56 | "LOCALELECT_W29", 57 | "MAN1A_W29", 58 | "MAN1B_W29", 59 | "MAN1C_W29", 60 | "MAN1D_W29", 61 | "MAN1E_W29", 62 | "MASC2_W29", 63 | "S12_W29", 64 | "S13_W29", 65 | "SEENFEM_W29", 66 | "SEENMASC_W29", 67 | "TALK_CPS_W29", 68 | "WORRYBILL_W29" 69 | ], 70 | "W32": [ 71 | "CLASS_W32", 72 | "COMMIMPA_W32", 73 | "COMMIMPB_W32", 74 | "COMMIMPC_W32", 75 | "COMMIMPF_W32", 76 | "COMMIMPG_W32", 77 | "COMMIMPH_W32", 78 | "COMMYRS_W32", 79 | "ETHNCMAJ_W32", 80 | "GAYMARR2_W32", 81 | "HARASS2F1_W32", 82 | "IMMIMPACT_W32", 83 | "MOVERURAL_W32", 84 | "MOVESUBURB_W32", 85 | "MOVEURBAN_W32", 86 | "NEIGHINTERA_W32", 87 | "NEIGHINTERB_W32", 88 | "NEIGHKIDS_W32", 89 | "NEIGHSAMEA_W32", 90 | "SATLIFED_W32", 91 | "SOCTRUST2_W32", 92 | "VALUERURAL_W32", 93 | "WHADVANT_W32", 94 | "WILLMOVE_W32" 95 | ], 96 | "W34": [ 97 | "EAT1_W34", 98 | "EAT2_W34", 99 | "EAT5B_W34", 100 | "EAT5C_W34", 101 | "EAT5D_W34", 102 | "EVOBIOA_W34", 103 | "EVOBIOB_W34", 104 | "EVOPERS3_W34", 105 | "EVOTHREE_W34", 106 | "FUD22_W34", 107 | "FUD24_W34", 108 | "FUD33B_W34", 109 | "FUD35_W34", 110 | "FUD37A_W34", 111 | "FUD37B_W34", 112 | "FUD37C_W34" 113 | ], 114 | "W36": [ 115 | "BETTERBIZ1F2D_W36", 116 | "BETTERBIZ1F2E_W36", 117 | "BETTERBIZ1F2H_W36", 118 | "BETTERBIZ2F2C_W36", 119 | "BETTERBIZ2F2E_W36", 120 | "BETTERBIZ2F2F_W36", 121 | "BETTERPOL1F1B_W36", 122 | "BETTERPOL1F1E_W36", 123 | "BETTERPOL1F1I_W36", 124 | "BETTERPOL2F1A_W36", 125 | "BETTERPOL2F1C_W36", 126 | "BETTERPOL2F1D_W36", 127 | "BETTERPOL2F1F_W36", 128 | "E5_W36", 129 | "ESSENBIZF2E_W36", 130 | "ESSENBIZF2J_W36", 131 | "ESSENPOLF1A_W36", 132 | "ESSENPOLF1B_W36", 133 | "ESSENPOLF1D_W36", 134 | "ESSENPOLF1G_W36", 135 | "ESSENPOLF1H_W36", 136 | "HIGHEDWRNGB_W36", 137 | "HIGHEDWRNGC_W36", 138 | "IMPROVE1_W36", 
139 | "IMPROVE2_W36", 140 | "IMPROVE3_W36", 141 | "MOREWMN1F2_W36", 142 | "POLCHF1_W36", 143 | "TRAITBIZMF2C_W36", 144 | "TRAITBIZMF2D_W36", 145 | "TRAITBIZMF2E_W36", 146 | "TRAITBIZMF2G_W36", 147 | "TRAITBIZWF2C_W36", 148 | "TRAITBIZWF2D_W36", 149 | "TRAITBIZWF2E_W36", 150 | "TRAITBIZWF2G_W36", 151 | "TRAITPOLMF1B_W36", 152 | "TRAITPOLMF1C_W36", 153 | "TRAITPOLMF1D_W36", 154 | "TRAITPOLMF1E_W36", 155 | "TRAITPOLMF1F_W36", 156 | "TRAITPOLMF1G_W36", 157 | "TRAITPOLWF1A_W36", 158 | "TRAITPOLWF1C_W36", 159 | "TRAITPOLWF1D_W36", 160 | "TRAITPOLWF1E_W36", 161 | "TRAITPOLWF1F_W36", 162 | "TRAITPOLWF1G_W36", 163 | "WHYNOTBIZF2A_W36", 164 | "WHYNOTBIZF2B_W36", 165 | "WHYNOTBIZF2D_W36", 166 | "WHYNOTBIZF2F_W36", 167 | "WHYNOTBIZF2G_W36", 168 | "WHYNOTBIZF2H_W36", 169 | "WHYNOTBIZF2J_W36", 170 | "WHYNOTBIZF2M_W36", 171 | "WHYNOTBIZF2N_W36", 172 | "WHYNOTBIZF2O_W36", 173 | "WHYNOTPOLF1A_W36", 174 | "WHYNOTPOLF1B_W36", 175 | "WHYNOTPOLF1C_W36", 176 | "WHYNOTPOLF1E_W36", 177 | "WHYNOTPOLF1G_W36", 178 | "WHYNOTPOLF1I_W36", 179 | "WHYNOTPOLF1J_W36", 180 | "WHYNOTPOLF1K_W36", 181 | "WHYNOTPOLF1L_W36", 182 | "WMNPRZ1_W36" 183 | ], 184 | "W41": [ 185 | "AUTOLKLY_W41", 186 | "AUTOWKPLC_W41", 187 | "ELDFINANCEF1_W41", 188 | "ELDFINANCEF2_W41", 189 | "ETHNCMAJMOD_W41", 190 | "FTRWORRYc_W41", 191 | "FTRWORRYe_W41", 192 | "FTRWORRYf_W41", 193 | "FUTRCLASSb_W41", 194 | "FUTRCLASSc_W41", 195 | "FUTR_DIV_W41", 196 | "GOVPRIOb_W41", 197 | "GOVPRIOc_W41", 198 | "GOVPRIOe_W41", 199 | "GOVPRIOfF1_W41", 200 | "GOVPRIOgF1_W41", 201 | "GOVPRIOhF1_W41", 202 | "GOVPRIOiF1_W41", 203 | "GOVPRIOjF1_W41", 204 | "GOVPRIOkF2_W41", 205 | "GOVPRIOlF2_W41", 206 | "GOVPRIOmF2_W41", 207 | "GOVPRIOnF2_W41", 208 | "GOVPRIOoF2_W41", 209 | "HAPPEN2b_W41", 210 | "HAPPEN2f_W41", 211 | "HARASS1F1b_W41", 212 | "HARASS1F1d_W41", 213 | "HARASS1NOWRKF2d_W41", 214 | "HARASS3NOWRKF2_W41", 215 | "INTRMAR_W41", 216 | "LEGALIMG_W41", 217 | "SOLVPROBa_W41", 218 | "SOLVPROBb_W41", 219 | "SOLVPROBc_W41", 220 | "SOLVPROBdF1_W41", 221 | "SOLVPROBf_W41", 222 | "SOLVPROBg_W41", 223 | "SOLVPROBh_W41", 224 | "WRKTRN1F1_W41", 225 | "WRKTRN1F2_W41" 226 | ], 227 | "W42": [ 228 | "CONFb_W42", 229 | "CONFc_W42", 230 | "CONFe_W42", 231 | "CONFg_W42", 232 | "PQ4_F2Aa_W42", 233 | "PQ4_F2Ac_W42", 234 | "PQ4_F2Ae_W42", 235 | "PQ4_F2Bb_W42", 236 | "PQ4_F2Bc_W42", 237 | "PQ4_F2Bd_W42", 238 | "PQ4_F2Cc_W42", 239 | "PQ8_F2B_W42", 240 | "RQ1_F1B_W42", 241 | "RQ1_F1C_W42", 242 | "RQ2_F1A_W42", 243 | "RQ4_F1Ab_W42", 244 | "RQ4_F1Ae_W42", 245 | "RQ4_F1Ba_W42", 246 | "RQ4_F1Bc_W42", 247 | "RQ4_F1Bd_W42", 248 | "RQ4_F1Be_W42", 249 | "RQ4_F1Ca_W42", 250 | "RQ4_F1Cb_W42", 251 | "RQ4_F1Cd_W42", 252 | "RQ5_F1B_W42", 253 | "RQ5_F1C_W42" 254 | ], 255 | "W43": [ 256 | "IDIMPORT_W43", 257 | "RACESURV13_W43", 258 | "RACESURV14_W43", 259 | "RACESURV15a_W43", 260 | "RACESURV15b_W43", 261 | "RACESURV17_W43", 262 | "RACESURV18b_W43", 263 | "RACESURV18f_W43", 264 | "RACESURV19a_W43", 265 | "RACESURV19c_W43", 266 | "RACESURV19e_W43", 267 | "RACESURV1b_W43", 268 | "RACESURV21_W43", 269 | "RACESURV27_W43", 270 | "RACESURV28a_W43", 271 | "RACESURV28b_W43", 272 | "RACESURV28c_W43", 273 | "RACESURV28d_W43", 274 | "RACESURV28e_W43", 275 | "RACESURV28f_W43", 276 | "RACESURV28g_W43", 277 | "RACESURV29d_W43", 278 | "RACESURV2_W43", 279 | "RACESURV34a_W43", 280 | "RACESURV34b_W43", 281 | "RACESURV34c_W43", 282 | "RACESURV34d_W43", 283 | "RACESURV34e_W43", 284 | "RACESURV38_W43", 285 | "RACESURV40_W43", 286 | "RACESURV41_W43", 287 | "RACESURV45_W43", 288 | "RACESURV47a_W43", 289 | "RACESURV47b_W43", 290 
| "RACESURV47d_W43", 291 | "RACESURV47e_W43", 292 | "RACESURV47f_W43", 293 | "RACESURV48_W43", 294 | "RACESURV51_W43", 295 | "RACESURV52_W43", 296 | "RACESURV5a_W43", 297 | "RACESURV5b_W43", 298 | "RACESURV5d_W43", 299 | "RACESURV5e_W43", 300 | "RACESURV5f_W43", 301 | "RACESURV5g_W43", 302 | "RACESURV5h_W43", 303 | "RACESURV5i_W43", 304 | "RACESURV5j_W43", 305 | "RACESURV5l_W43", 306 | "RACESURV9_W43" 307 | ], 308 | "W45": [ 309 | "INFOCHALa_W45", 310 | "INFOCREATEa_W45", 311 | "MADEUPOFT_W45", 312 | "MADEUPSHAREWHY_W45", 313 | "MADEUPTOPICb_W45", 314 | "MADEUPTOPICc_W45", 315 | "MADEUPTOPICd_W45", 316 | "NEWSPREFV2_W45", 317 | "NEWSPROBd_W45", 318 | "NEWSPROBe_W45", 319 | "NEWS_PLATFORMg_W45", 320 | "SMLIKESb_W45", 321 | "SMLIKESf_W45" 322 | ], 323 | "W49": [ 324 | "BENEFITGOV_W49", 325 | "CONCERNCO_W49", 326 | "CONCERNGRPa_W49", 327 | "CONCERNGRPc_W49", 328 | "CONTROLCO_W49", 329 | "CONTROLGRPc_W49", 330 | "DATAUSEd_W49", 331 | "FACE3c_W49", 332 | "GOVREGV1_W49", 333 | "HOMEASSIST2_W49", 334 | "HOMEASSIST3_W49", 335 | "HOMEASSIST4_W49", 336 | "PP5e_W49", 337 | "PWMAN2_W49", 338 | "SMARTAPP_W49", 339 | "TRACKCO1a_W49", 340 | "TRACKCO1b_W49", 341 | "TRACKGOV1a_W49", 342 | "TRACKGOV1b_W49" 343 | ], 344 | "W50": [ 345 | "COHABDUR_W50", 346 | "DNA2b_W50", 347 | "DNA5_W50", 348 | "E5MOD_W50", 349 | "FAMSURV10c_W50", 350 | "FAMSURV16_W50", 351 | "FAMSURV17_W50", 352 | "FAMSURV1_W50", 353 | "FAMSURV20_W50", 354 | "FAMSURV23b_W50", 355 | "FAMSURV23c_W50", 356 | "FAMSURV23d_W50", 357 | "FAMSURV23e_W50", 358 | "FAMSURV23f_W50", 359 | "FAMSURV23g_W50", 360 | "FAMSURV26a_W50", 361 | "FAMSURV26b_W50", 362 | "FAMSURV26c_W50", 363 | "FAMSURV26d_W50", 364 | "FAMSURV29_W50", 365 | "FAMSURV2Ma_W50", 366 | "FAMSURV2Mc_W50", 367 | "FAMSURV2Wa_W50", 368 | "FAMSURV2Wc_W50", 369 | "FAMSURV32e_W50", 370 | "FAMSURV39_W50", 371 | "FAMSURV3_W50", 372 | "FAMSURV40_W50", 373 | "FAMSURV43_W50", 374 | "FAMSURV44_W50", 375 | "FAMSURV6_W50", 376 | "FAMSURV9c_W50", 377 | "FATHER_W50", 378 | "HAVEKIDS1_W50", 379 | "MARRDUR_W50", 380 | "MARRYPREF1_W50", 381 | "MARRYPREF2_W50", 382 | "MOTHER_W50", 383 | "PAR1_W50", 384 | "PAR2_W50", 385 | "ROMRELDUR_W50", 386 | "ROMRELSER_W50", 387 | "SATLIFEc_W50" 388 | ], 389 | "W54": [ 390 | "ECIMPg_W54", 391 | "ECON1_W54", 392 | "ECON3_d_W54", 393 | "ECON4_a_W54", 394 | "ECON4_b_W54", 395 | "ECON4_e_W54", 396 | "ECON4_f_W54", 397 | "ECON4_g_W54", 398 | "ECON5_b_W54", 399 | "ECON5_c_W54", 400 | "ECON5_d_W54", 401 | "ECON5_e_W54", 402 | "ECON5_f_W54", 403 | "ECON5_g_W54", 404 | "ECON5_h_W54", 405 | "ECON5_i_W54", 406 | "ECON5_j_W54", 407 | "ECON5_k_W54", 408 | "FIN_SIT_W54", 409 | "GOVPRIORITYb_W54", 410 | "GOVPRIORITYc_W54", 411 | "GOVPRIORITYd_W54", 412 | "GOVPRIORITYf_W54", 413 | "INEQ11_W54", 414 | "INEQ1_W54", 415 | "INEQ4_a_W54", 416 | "INEQ5_a_W54", 417 | "INEQ5_d_W54", 418 | "INEQ5_e_W54", 419 | "INEQ5_f_W54", 420 | "INEQ5_h_W54", 421 | "INEQ5_i_W54", 422 | "INEQ5_k_W54", 423 | "INEQ5_l_W54", 424 | "INEQ5_m_W54", 425 | "INEQ8_a_W54", 426 | "INEQ8_b_W54", 427 | "INEQ8_c_W54", 428 | "INEQ8_d_W54", 429 | "INEQ8_f_W54", 430 | "INEQ8_g_W54", 431 | "INEQ8_h_W54", 432 | "INEQ8_i_W54", 433 | "INEQ8_j_W54", 434 | "JOBTRAIN_W54", 435 | "WORRY2a_W54", 436 | "WORRY2b_W54", 437 | "WORRY2c_W54", 438 | "WORRY2d_W54", 439 | "WORRY2e_W54" 440 | ], 441 | "W82": [ 442 | "GAP21Q10_W82", 443 | "GAP21Q13_b_W82", 444 | "GAP21Q15_a_W82", 445 | "GAP21Q15_b_W82", 446 | "GAP21Q15_c_W82", 447 | "GAP21Q15_d_W82", 448 | "GAP21Q15_e_W82", 449 | "GAP21Q15_f_W82", 450 | "GAP21Q17_W82", 451 | "GAP21Q19_a_W82", 452 
| "GAP21Q19_b_W82", 453 | "GAP21Q19_c_W82", 454 | "GAP21Q19_d_W82", 455 | "GAP21Q19_e_W82", 456 | "GAP21Q21_a_W82", 457 | "GAP21Q21_e_W82", 458 | "GAP21Q23_W82", 459 | "GAP21Q24_W82", 460 | "GAP21Q25_W82", 461 | "GAP21Q26_a_W82", 462 | "GAP21Q26_b_W82", 463 | "GAP21Q26_c_W82", 464 | "GAP21Q27_W82", 465 | "GAP21Q28_W82", 466 | "GAP21Q31_W82", 467 | "GAP21Q32_W82", 468 | "GAP21Q33_c_W82", 469 | "GAP21Q33_g_W82", 470 | "GAP21Q33_i_W82", 471 | "GAP21Q33_j_W82", 472 | "GAP21Q33_m_W82", 473 | "GAP21Q33_n_W82", 474 | "GAP21Q33_q_W82", 475 | "GAP21Q33_r_W82", 476 | "GAP21Q33_s_W82", 477 | "GAP21Q34_a_W82", 478 | "GAP21Q34_b_W82", 479 | "GAP21Q34_c_W82", 480 | "GAP21Q34_d_W82", 481 | "GAP21Q34_e_W82", 482 | "GAP21Q34_f_W82", 483 | "GAP21Q38_a_W82", 484 | "GAP21Q38_b_W82", 485 | "GAP21Q38_c_W82", 486 | "GAP21Q41_W82", 487 | "GAP21Q43_g_W82", 488 | "GAP21Q47_W82", 489 | "GAP21Q4_a_W82", 490 | "GAP21Q4_b_W82", 491 | "GAP21Q4_c_W82", 492 | "GAP21Q4_d_W82", 493 | "GAP21Q4_e_W82", 494 | "GAP21Q4_f_W82", 495 | "GAP21Q7_a_W82", 496 | "GAP21Q7_b_W82", 497 | "GAP21Q7_d_W82" 498 | ], 499 | "W92": [ 500 | "BILLION_W92", 501 | "CANDEXP_W92", 502 | "ELECT_CONF3_PRVFR_W92", 503 | "ELECT_CONF3_PRVSUP_W92", 504 | "ELECT_IMPT3_PRVFR_W92", 505 | "FREECOLL_W92", 506 | "GODMORALIMP_W92", 507 | "LEGALIMMIGAMT_W92", 508 | "POLINTOL2_a_W92", 509 | "PROG_RNEED_W92", 510 | "RACESURV52MOD_W92", 511 | "REPRSNTDEM_W92", 512 | "REPRSNTREP_W92", 513 | "SOCIETY_GUNS_W92", 514 | "SOCIETY_RELG_W92", 515 | "SOCIETY_RHIST_W92", 516 | "SOCIETY_SSM_W92", 517 | "SOCIETY_TRANS_W92", 518 | "SOCIETY_WHT_W92", 519 | "UNIMMIGCOMM_W92", 520 | "USEXCEPT_W92", 521 | "USMILSIZ_W92", 522 | "WHADVANT_W92" 523 | ] 524 | } -------------------------------------------------------------------------------- /scripts/data_generation/generate_distribution.py: -------------------------------------------------------------------------------- 1 | import ast 2 | import argparse 3 | import itertools 4 | import json 5 | import os 6 | import pathlib 7 | from concurrent.futures import ProcessPoolExecutor, as_completed 8 | from typing import List, Dict, Tuple, Any, Union, Optional 9 | 10 | import pyreadstat 11 | import pandas as pd 12 | import numpy as np 13 | 14 | from subpop.utils.survey_utils import list_normalize 15 | from subpop.utils.surveydata_utils import ( 16 | ActualSurveyData, 17 | REMOVED_WAVES, 18 | PROHIBITED_WAVES, 19 | PROHIBITED_QKEY_PREFIXES, 20 | GSS_ATTRIBUTE_TO_VARIABLE_MAP, 21 | GSS_VALUE_TO_LABEL_MAP, 22 | ) 23 | 24 | """ 25 | explanations of constants: 26 | PROHIBITED_WAVES: 27 | waves included in OpinionQA dataset, not used for SubPOP-train 28 | REMOVED_WAVES: 29 | waves missing at least one of demographic attribute information (checked by manual inspection) 30 | PROHIBITED_QKEY_PREFIXES: 31 | qkeys that are parsed as ask-all questions, but are actually demographic / ideology label 32 | as we give demographic / ideology label as steering prompts, these questions are not used. 33 | GSS_ATTRIBUTE_TO_VARIABLE_MAP: 34 | mapping from attribute (ex. race or ethnicity) to variable name in GSS dataset (ex. raceacs) 35 | GSS_VALUE_TO_LABEL_MAP: 36 | mapping from variable name (ex. less than high school) to label (ex. 
0,1,2,3,4,...,11) 37 | """ 38 | 39 | REPO_ROOT = pathlib.Path(__file__).resolve().parents[2] 40 | 41 | 42 | def process_qkey( 43 | qkey: str, 44 | refined_qkey_dict: Dict[str, Dict[str, str]], 45 | error_qkeys_list: List[str], 46 | attribute_group_pair: List[Union[Tuple[str, str], Tuple[List[str], List[str]]]], 47 | atp_waves_list: List[str], 48 | surveydata_dir: str, 49 | ): 50 | """ 51 | Process a single qkey to generate a dataframe with necessary information 52 | List of necessary information: 53 | qkey (str) : question identifier 54 | attribute (str) : demographic / ideology attribute (ex. Age) 55 | group (str) : group (ex. 18-29) 56 | responses (str) : response distribution list, saved as string 57 | refusal_rate (float) : refusal rate 58 | ordinal (str) : ordinality information, saved as string (ex. [1.0, 2.0]) 59 | question (str) : question text, refined by LLM 60 | options (str) : options list, saved as string (ex. ['Agree', 'Disagree', 'Refused]) 61 | Args: 62 | qkey : question identifier to process 63 | refined_qkey_dict : dict. of original (from raw data) and refined question (by LLM) 64 | error_qkeys_list : list of qkeys with error (missing raw data, parsed incorrectly, etc.) 65 | attribute_group_pair : list of tuples with attribute and group information 66 | atp_waves_list : list of waves that the qkeys may belong to 67 | surveydata_dir : directory path to the survey data 68 | Returns: 69 | data_to_append (List[Dict[str, Any]]) : list of dictionaries with necessary info. 70 | """ 71 | 72 | print(f"--> process_qkey: Working on {qkey}") 73 | surveydata = ActualSurveyData( 74 | wave_list=[int(qkey.split("W")[-1])], 75 | bank_qkeys=set(), 76 | query_qkeys=set(), 77 | refined_qbody_data=refined_qkey_dict, 78 | data_dir=surveydata_dir, 79 | ) 80 | 81 | question = refined_qkey_dict[qkey]["refined_qbody"] 82 | options = list(surveydata.fetch_options(qkey).values()) 83 | options = [option.strip() for option in options] 84 | question_data = [] 85 | for attribute, group in attribute_group_pair: 86 | if len(attribute) == 1: 87 | attribute = attribute[0] 88 | group = group[0] 89 | try: 90 | responses = list( 91 | surveydata.fetch_response_distribution( 92 | qkey, attribute, group, remove_refusal=False 93 | ).values() 94 | ) 95 | if responses is None: 96 | continue 97 | refusal_rate = responses[-1] 98 | responses = list_normalize(responses[:-1]) 99 | ordinal = [1.0] * len(responses) 100 | question_data.append( 101 | { 102 | "qkey": qkey, 103 | "attribute": str(attribute), 104 | "group": str(group), 105 | "responses": str(responses), 106 | "refusal_rate": refusal_rate, 107 | "ordinal": str(ordinal), 108 | "question": question, 109 | "options": str(options), 110 | } 111 | ) 112 | except Exception as e: 113 | print(f"--> process_qkey: error on {qkey} with {attribute}, {group}: {e}") 114 | continue 115 | return question_data 116 | 117 | 118 | def generate_combined_pairs(loaded_pair, n_combination): 119 | """Generate a list of combined (joint) demographics according to n_combination""" 120 | attribute_to_groups = {} 121 | for attr, group in loaded_pair: 122 | if attr not in attribute_to_groups: 123 | attribute_to_groups[attr] = [] 124 | attribute_to_groups[attr].append(group) 125 | attribute_combinations = list( 126 | itertools.combinations(attribute_to_groups.keys(), n_combination) 127 | ) 128 | combined_pairs = [] 129 | for attributes in attribute_combinations: 130 | group_combinations = itertools.product( 131 | *[attribute_to_groups[attr] for attr in attributes] 132 | ) 133 | for groups in 
group_combinations:
134 |             combined_pairs.append((list(attributes), list(groups)))
135 |     return combined_pairs
136 | 
137 | 
138 | def generate_distribution_gss(
139 |     qkey: str,
140 |     option_list: List[str],
141 |     attribute: str,
142 |     group: str,
143 |     surveydata: pd.DataFrame,
144 |     meta: Any,
145 | ) -> Optional[Dict[str, float]]:
146 |     """
147 |     Get a response distribution for a given GSS question and subpopulation.
148 |     Args:
149 |         qkey : question identifier string
150 |         attribute : demographic / ideology attribute (ex. Age)
151 |         group : group (ex. 18-29)
152 |         surveydata : survey data to fetch response distribution
153 |     Returns:
154 |         dict: mapping from each answer option to its response probability (None if the question data is missing)
155 |     """
156 | 
157 |     # get the response and weight for the entire population
158 |     try:
159 |         surveydata_q = surveydata[qkey].values
160 |         surveydata_min_label = np.nanmin(surveydata_q)
161 |     except Exception as e:
162 |         print(f"--> generate_distribution_gss : {qkey} data does not exist.")
163 |         return None
164 |     weight = surveydata["wtssnrps"].values
165 | 
166 |     # get the response and weight for the subpopulation
167 |     attribute_coded = GSS_ATTRIBUTE_TO_VARIABLE_MAP[attribute]
168 |     group_coded = GSS_VALUE_TO_LABEL_MAP[attribute][group]
169 |     if attribute == "RACE":
170 |         # due to how race/ethnicity is coded in GSS, special handling is required
171 |         subpop_index = np.array([], dtype=int)
172 |         for group_code in group_coded:
173 |             subpop_index = np.append(
174 |                 subpop_index,
175 |                 np.where(surveydata[attribute_coded + str(group_code)] == 1)[0],
176 |             )
177 |         subpop_index = np.unique(subpop_index)
178 |     else:
179 |         subpop_index = np.where(np.isin(surveydata[attribute_coded], group_coded))[0]
180 |     surveydata_q = surveydata_q[subpop_index]
181 |     weight = weight[subpop_index]
182 | 
183 |     # aggregate individual responses to get a response distribution
184 |     resp_dist = [0.0 for _ in range(len(option_list))]
185 |     for resp_choice, resp_weight in zip(surveydata_q, weight):
186 |         if isinstance(resp_choice, int) and not np.isnan(resp_weight):
187 |             resp_dist[resp_choice - surveydata_min_label] += resp_weight
188 |     sum_weights = sum(resp_dist)
189 |     resp_dist = [resp / sum_weights for resp in resp_dist]
190 |     return {option_list[i]: resp_dist[i] for i in range(len(option_list))}
191 | 
192 | 
193 | def generate_distribution_subpop_eval(args):
194 | 
195 |     # load the refined question body and option list dictionary
196 |     with open(args.refined_qkey_dict_path, "r") as f:
197 |         refined_qkey_dict = json.load(f)
198 | 
199 |     # load the survey response data.
200 |     # Note: this survey data has to be downloaded from the official website!
201 | surveydata, meta = pyreadstat.read_dta( 202 | REPO_ROOT / "data" / "subpop-eval" / "GSS2022.dta" 203 | ) 204 | 205 | # generate a list of (attribute, group) tuples 206 | attribute_group_pair: List[Tuple[str, str]] = [] 207 | demographics_data = pd.read_csv(os.path.join(args.demographics_data_path)) 208 | for idx, row in demographics_data.iterrows(): 209 | attribute = row["attribute"] 210 | group_list = ast.literal_eval(row["group"]) 211 | for group in group_list: 212 | attribute_group_pair.append((attribute, group)) 213 | attribute_group_pair = generate_combined_pairs( 214 | attribute_group_pair, args.n_combination 215 | ) 216 | del demographics_data 217 | 218 | # for each question and subpopulation, generate a response distribution 219 | question_data = [] 220 | for qkey in refined_qkey_dict.keys(): 221 | question = refined_qkey_dict[qkey]["refined_qbody"] 222 | options = refined_qkey_dict[qkey]["option_list"] 223 | ordinal = refined_qkey_dict[qkey]["ordinal"] 224 | 225 | for attribute, group in attribute_group_pair: 226 | attribute = attribute[0] if len(attribute) == 1 else attribute 227 | group = group[0] if len(group) == 1 else group 228 | responses = generate_distribution_gss( 229 | qkey=qkey, 230 | option_list=options, 231 | attribute=attribute, 232 | group=group, 233 | surveydata=surveydata, 234 | meta=meta, 235 | ) 236 | refusal_rate = responses.get("Refused", 0.0) 237 | responses = list_normalize(list(responses.values())[:-1]) 238 | question_data.append( 239 | { 240 | "qkey": qkey, 241 | "attribute": str(attribute), 242 | "group": str(group), 243 | "responses": str(responses), 244 | "refusal_rate": refusal_rate, 245 | "ordinal": str(ordinal), 246 | "question": question, 247 | "options": str(options), 248 | } 249 | ) 250 | 251 | surveydata_df = pd.DataFrame(question_data) 252 | surveydata_df.to_csv(args.output_path, index=False) 253 | 254 | 255 | def generate_distribution_subpop_train(args): 256 | """Generate response distribution for each subpopulation and quesetion in SubPOP-train""" 257 | # load the list of qkeys with error during question text refining step 258 | if os.path.exists(args.error_qkeys_list_path): 259 | with open(args.error_qkeys_list_path, "r") as f: 260 | error_qkeys_list = json.load(f) 261 | else: 262 | error_qkeys_list = [] 263 | 264 | # load the refined question body dictionary 265 | with open(args.refined_qkey_dict_path, "r") as f: 266 | refined_qkey_dict = json.load(f) 267 | refined_qkey_dict = { 268 | qkey: refined_qkey_dict[qkey] 269 | for qkey in refined_qkey_dict 270 | if ( 271 | (qkey not in error_qkeys_list) 272 | and (qkey.split("_W")[0].lower() not in PROHIBITED_QKEY_PREFIXES) 273 | and (int(qkey.split("_W")[-1]) not in PROHIBITED_WAVES + REMOVED_WAVES) 274 | ) 275 | } 276 | 277 | # generate a list of (attribute, group) tuples 278 | attribute_group_pair: List[Tuple[str, str]] = [] 279 | demographics_data = pd.read_csv(os.path.join(args.demographics_data_path)) 280 | for idx, row in demographics_data.iterrows(): 281 | attribute = row["attribute"] 282 | group_list = ast.literal_eval(row["group"]) 283 | for group in group_list: 284 | attribute_group_pair.append((attribute, group)) 285 | attribute_group_pair = generate_combined_pairs( 286 | attribute_group_pair, args.n_combination 287 | ) 288 | del demographics_data 289 | 290 | # process each qkey to generate a list of dictionaries with necessary information 291 | surveydata_list = [] 292 | with ProcessPoolExecutor(max_workers=args.n_workers) as executor: 293 | futures = { 294 | executor.submit( 295 
| process_qkey, 296 | qkey=qkey, 297 | refined_qkey_dict={qkey: refined_qkey_dict[qkey]}, 298 | error_qkeys_list=error_qkeys_list, 299 | attribute_group_pair=attribute_group_pair, 300 | atp_waves_list=[int(qkey.split("W")[-1])], 301 | surveydata_dir=pathlib.Path(args.refined_qkey_dict_path).parent.parent, 302 | ): qkey for qkey in refined_qkey_dict.keys() 303 | } 304 | for future in as_completed(futures): 305 | try: 306 | result = future.result() 307 | if result: 308 | surveydata_list.extend(result) 309 | except Exception as e: 310 | print(f"--> main: error processing qkey {futures[future]}: {e}") 311 | 312 | surveydata_df = pd.DataFrame(surveydata_list) 313 | surveydata_df.to_csv(args.output_path, index=False) 314 | 315 | 316 | def cli_args(): 317 | parser = argparse.ArgumentParser() 318 | parser.add_argument( 319 | "--dataset", 320 | type=str, 321 | default="subpop-train", 322 | help="dataset name", 323 | ) 324 | parser.add_argument( 325 | "--n_workers", 326 | type=int, 327 | default=1, 328 | help="Number of workers to use for multiprocessing", 329 | ) 330 | parser.add_argument( 331 | "--n_combination", 332 | type=int, 333 | default=1, 334 | help="Number of group combinations to consider (default to 1)", 335 | ) 336 | parser.add_argument( 337 | "--demographics_data_path", 338 | type=str, 339 | default=REPO_ROOT / "data" / "subpopulation_metadata" / "demographics_22.csv", 340 | help="Path to the demographics data", 341 | ) 342 | return parser.parse_args() 343 | 344 | 345 | if __name__ == "__main__": 346 | args = cli_args() 347 | dataset_name = args.dataset 348 | args.output_path = REPO_ROOT / "data" / dataset_name / "processed" / f"{dataset_name}.csv" 349 | args.error_qkeys_list_path = REPO_ROOT / "data" / dataset_name / "processed" / "error_qkeys_list.json" 350 | args.refined_qkey_dict_path = REPO_ROOT / "data" / dataset_name / "processed" / "refined_qkey_dict.json" 351 | 352 | if dataset_name == "subpop-train": 353 | generate_distribution_subpop_train(args) 354 | elif dataset_name == "subpop-eval": 355 | generate_distribution_subpop_eval(args) 356 | elif dataset_name == 'opinionqa': 357 | raise ValueError( 358 | f"--> main: dataset {args.dataset} is provided by OpinionQA." 359 | " Please refer to https://github.com/tatsu-lab/opinions_qa." 360 | ) 361 | else: 362 | raise ValueError(f"--> main: invalid dataset name {args.dataset}") 363 | --------------------------------------------------------------------------------