├── simplex-diffusion-main ├── sdlm │ ├── .lock │ ├── __init__.py │ ├── data │ │ ├── __init__.py │ │ ├── split_data.py │ │ ├── gpt_eval_split.py │ │ ├── postprocessors.py │ │ ├── process_data.py │ │ ├── preprocessors.py │ │ ├── data_collator.py │ │ └── data_utils.py │ ├── metrics │ │ ├── __init__.py │ │ ├── repetition.py │ │ ├── metrics.py │ │ └── perplexity.py │ ├── inference │ │ ├── __init__.py │ │ └── inference_utils.py │ ├── models │ │ ├── gpt2 │ │ │ ├── __init__.py │ │ │ └── configuration_gpt2.py │ │ ├── h3 │ │ │ ├── __init__.py │ │ │ ├── ops │ │ │ │ ├── __init__.py │ │ │ │ ├── vandermonde.py │ │ │ │ ├── toeplitz.py │ │ │ │ ├── fftconv.py │ │ │ │ └── krylov.py │ │ │ ├── ssm │ │ │ │ ├── __init__.py │ │ │ │ ├── ssm_utils.py │ │ │ │ ├── ss_kernel_shift.py │ │ │ │ ├── dplr.py │ │ │ │ ├── ss_kernel.py │ │ │ │ ├── h3.py │ │ │ │ ├── hippo.py │ │ │ │ └── ss_kernel_diag.py │ │ │ ├── utils │ │ │ │ └── utils.py │ │ │ └── configuration_h3.py │ │ ├── longformer │ │ │ ├── __init__.py │ │ │ └── configuration_longformer.py │ │ ├── roberta │ │ │ ├── __init__.py │ │ │ └── configuration_roberta.py │ │ ├── xlm_roberta │ │ │ ├── __init__.py │ │ │ ├── configuration_xlm_roberta.py │ │ │ └── modeling_xlm_roberta.py │ │ ├── __init__.py │ │ └── utils.py │ ├── schedulers │ │ ├── __init__.py │ │ └── scheduling_simplex_ddpm.py │ ├── utils.py │ └── run_mlm.py ├── .gitmodules ├── configs │ └── models │ │ └── cnndm │ │ ├── roberta.json │ │ ├── roberta_train_extend.json │ │ ├── roberta_train.json │ │ └── roberta_control.json ├── commands_for_sum │ └── run_sum.sh ├── .gitignore └── env.yml ├── LICENSE └── README.md /simplex-diffusion-main/sdlm/.lock: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /simplex-diffusion-main/sdlm/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /simplex-diffusion-main/sdlm/data/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /simplex-diffusion-main/sdlm/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /simplex-diffusion-main/sdlm/inference/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /simplex-diffusion-main/sdlm/models/gpt2/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /simplex-diffusion-main/sdlm/models/h3/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /simplex-diffusion-main/sdlm/models/h3/ops/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /simplex-diffusion-main/sdlm/models/h3/ssm/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- 
/simplex-diffusion-main/sdlm/models/longformer/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /simplex-diffusion-main/sdlm/models/roberta/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /simplex-diffusion-main/sdlm/models/xlm_roberta/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /simplex-diffusion-main/sdlm/schedulers/__init__.py: -------------------------------------------------------------------------------- 1 | from .scheduling_simplex_ddpm import SimplexDDPMScheduler 2 | -------------------------------------------------------------------------------- /simplex-diffusion-main/.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "flash-attention"] 2 | path = flash-attention 3 | url = git@github.com:HazyResearch/flash-attention.git 4 | -------------------------------------------------------------------------------- /simplex-diffusion-main/sdlm/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .roberta.configuration_roberta import RobertaDiffusionConfig 2 | from .roberta.modeling_roberta import RobertaForDiffusionLM 3 | from .utils import load_model 4 | from .xlm_roberta.configuration_xlm_roberta import XLMRobertaDiffusionConfig 5 | from .xlm_roberta.modeling_xlm_roberta import XLMRobertaForDiffusionLM 6 | 7 | __all__ = ( 8 | "RobertaDiffusionConfig", 9 | "RobertaForDiffusionLM", 10 | "XLMRobertaDiffusionConfig", 11 | "XLMRobertaForDiffusionLM", 12 | "load_model", 13 | ) 14 | -------------------------------------------------------------------------------- /simplex-diffusion-main/sdlm/data/split_data.py: -------------------------------------------------------------------------------- 1 | from datasets import DatasetDict, load_from_disk 2 | 3 | tokenized_data_path = ( 4 | "/home/lily/jt856/documents/simplex-diffusion/processed_data/openwebtext_50" 5 | ) 6 | output_dir = ( 7 | "/home/lily/jt856/documents/simplex-diffusion/processed_data/openwebtext_50_split" 8 | ) 9 | seed = 42 10 | validation_split_ratio = 0.001 11 | 12 | tokenized_datasets = load_from_disk(tokenized_data_path) 13 | train_testvalid = tokenized_datasets["train"].train_test_split( 14 | test_size=validation_split_ratio, shuffle=True, seed=seed 15 | ) 16 | tokenized_datasets = DatasetDict( 17 | {"train": train_testvalid["train"], "validation": train_testvalid["test"]} 18 | ) 19 | tokenized_datasets.save_to_disk(output_dir) 20 | -------------------------------------------------------------------------------- /simplex-diffusion-main/sdlm/data/gpt_eval_split.py: -------------------------------------------------------------------------------- 1 | from datasets import DatasetDict, load_from_disk 2 | 3 | tokenized_data_path = ( 4 | "/home/lily/jt856/documents/simplex-diffusion/processed_data/openwebtext_256_split" 5 | ) 6 | output_dir = "/home/lily/jt856/documents/simplex-diffusion/processed_data/openwebtext_256_split_gpt_eval" 7 | seed = 42 8 | tokenized_datasets = load_from_disk(tokenized_data_path) 9 | validation_split_ratio = 0.1414827391058291 10 | train_testvalid = tokenized_datasets["validation"].train_test_split( 11 | 
test_size=validation_split_ratio, shuffle=True, seed=seed 12 | ) 13 | tokenized_datasets = DatasetDict( 14 | {"train": tokenized_datasets["train"], "validation": train_testvalid["test"]} 15 | ) 16 | tokenized_datasets.save_to_disk(output_dir) 17 | -------------------------------------------------------------------------------- /simplex-diffusion-main/sdlm/models/h3/ssm/ssm_utils.py: -------------------------------------------------------------------------------- 1 | # TD: [2023-01-05]: Extracted the OptimModule class from 2 | # https://github.com/HazyResearch/state-spaces/blob/06dbbdfd0876501a7f12bf3262121badbc7658af/src/models/sequence/ss/kernel.py 3 | 4 | import torch.nn as nn 5 | 6 | 7 | class OptimModule(nn.Module): 8 | """Interface for Module that allows registering buffers/parameters with configurable optimizer hyperparameters""" 9 | 10 | def register(self, name, tensor, lr=None): 11 | """Register a tensor with a configurable learning rate and 0 weight decay""" 12 | 13 | if lr == 0.0: 14 | self.register_buffer(name, tensor) 15 | else: 16 | self.register_parameter(name, nn.Parameter(tensor)) 17 | 18 | optim = {"weight_decay": 0.0} 19 | if lr is not None: 20 | optim["lr"] = lr 21 | setattr(getattr(self, name), "_optim", optim) 22 | -------------------------------------------------------------------------------- /simplex-diffusion-main/sdlm/models/xlm_roberta/configuration_xlm_roberta.py: -------------------------------------------------------------------------------- 1 | """Adapted XLM Roberta configuration for diffusion models.""" 2 | 3 | from typing import Optional 4 | 5 | from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRobertaConfig 6 | 7 | 8 | class XLMRobertaDiffusionConfig(XLMRobertaConfig): 9 | def __init__( 10 | self, 11 | self_condition: Optional[str] = None, 12 | self_condition_zeros_after_softmax: bool = False, 13 | deepmind_conditional: bool = False, 14 | classifier_free_simplex_inputs: bool = False, 15 | self_condition_mlp_projection=False, 16 | **kwargs 17 | ): 18 | super().__init__(**kwargs) 19 | self.self_condition = self_condition 20 | self.self_condition_zeros_after_softmax = self_condition_zeros_after_softmax 21 | self.deepmind_conditional = deepmind_conditional 22 | self.classifier_free_simplex_inputs = classifier_free_simplex_inputs 23 | self.self_condition_mlp_projection = self_condition_mlp_projection 24 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Yuhan Liu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /simplex-diffusion-main/sdlm/models/gpt2/configuration_gpt2.py: -------------------------------------------------------------------------------- 1 | """Adapted GPT-2 configuration for diffusion models.""" 2 | 3 | from typing import Optional 4 | 5 | from transformers.models.gpt2.configuration_gpt2 import GPT2Config 6 | 7 | 8 | class GPT2DiffusionConfig(GPT2Config): 9 | def __init__( 10 | self, 11 | self_condition: Optional[str] = None, 12 | self_condition_zeros_after_softmax: bool = False, 13 | deepmind_conditional: bool = False, 14 | classifier_free_simplex_inputs: bool = False, 15 | classifier_free_uncond_input: str = "empty_token", 16 | self_condition_mlp_projection=False, 17 | self_condition_mix_before_weights=False, 18 | self_condition_mix_logits_before_weights=False, 19 | empty_token_be_mask=False, 20 | **kwargs, 21 | ): 22 | super().__init__(**kwargs) 23 | self.self_condition = self_condition 24 | self.self_condition_zeros_after_softmax = self_condition_zeros_after_softmax 25 | self.deepmind_conditional = deepmind_conditional 26 | self.classifier_free_simplex_inputs = classifier_free_simplex_inputs 27 | self.classifier_free_uncond_input = classifier_free_uncond_input 28 | self.self_condition_mlp_projection = self_condition_mlp_projection 29 | self.self_condition_mix_before_weights = self_condition_mix_before_weights 30 | self.self_condition_mix_logits_before_weights = ( 31 | self_condition_mix_logits_before_weights 32 | ) 33 | self.empty_token_be_mask = empty_token_be_mask 34 | -------------------------------------------------------------------------------- /simplex-diffusion-main/sdlm/models/roberta/configuration_roberta.py: -------------------------------------------------------------------------------- 1 | """Adapted Roberta configuration for diffusion models.""" 2 | 3 | from typing import Optional 4 | 5 | from transformers.models.roberta.configuration_roberta import RobertaConfig 6 | 7 | 8 | class RobertaDiffusionConfig(RobertaConfig): 9 | def __init__( 10 | self, 11 | self_condition: Optional[str] = None, 12 | self_condition_zeros_after_softmax: bool = False, 13 | deepmind_conditional: bool = False, 14 | classifier_free_simplex_inputs: bool = False, 15 | classifier_free_uncond_input: str = "empty_token", 16 | self_condition_mlp_projection=False, 17 | self_condition_mix_before_weights=False, 18 | self_condition_mix_logits_before_weights=False, 19 | empty_token_be_mask=False, 20 | **kwargs, 21 | ): 22 | super().__init__(**kwargs) 23 | self.self_condition = self_condition 24 | self.self_condition_zeros_after_softmax = self_condition_zeros_after_softmax 25 | self.deepmind_conditional = deepmind_conditional 26 | self.classifier_free_simplex_inputs = classifier_free_simplex_inputs 27 | self.classifier_free_uncond_input = classifier_free_uncond_input 28 | self.self_condition_mlp_projection = self_condition_mlp_projection 29 | self.self_condition_mix_before_weights = self_condition_mix_before_weights 30 | self.self_condition_mix_logits_before_weights = ( 31 | self_condition_mix_logits_before_weights 32 | ) 33 | self.empty_token_be_mask = empty_token_be_mask 34 | 
-------------------------------------------------------------------------------- /simplex-diffusion-main/sdlm/models/longformer/configuration_longformer.py: -------------------------------------------------------------------------------- 1 | """Adapted Roberta configuration for diffusion models.""" 2 | 3 | from typing import Optional 4 | 5 | from transformers.models.longformer.configuration_longformer import LongformerConfig 6 | 7 | 8 | class LongformerDiffusionConfig(LongformerConfig): 9 | def __init__( 10 | self, 11 | self_condition: Optional[str] = None, 12 | self_condition_zeros_after_softmax: bool = False, 13 | deepmind_conditional: bool = False, 14 | classifier_free_simplex_inputs: bool = False, 15 | classifier_free_uncond_input: str = "empty_token", 16 | self_condition_mlp_projection=False, 17 | self_condition_mix_before_weights=False, 18 | self_condition_mix_logits_before_weights=False, 19 | empty_token_be_mask=False, 20 | **kwargs, 21 | ): 22 | super().__init__(**kwargs) 23 | self.self_condition = self_condition 24 | self.self_condition_zeros_after_softmax = self_condition_zeros_after_softmax 25 | self.deepmind_conditional = deepmind_conditional 26 | self.classifier_free_simplex_inputs = classifier_free_simplex_inputs 27 | self.classifier_free_uncond_input = classifier_free_uncond_input 28 | self.self_condition_mlp_projection = self_condition_mlp_projection 29 | self.self_condition_mix_before_weights = self_condition_mix_before_weights 30 | self.self_condition_mix_logits_before_weights = ( 31 | self_condition_mix_logits_before_weights 32 | ) 33 | self.empty_token_be_mask = empty_token_be_mask 34 | -------------------------------------------------------------------------------- /simplex-diffusion-main/configs/models/cnndm/roberta.json: -------------------------------------------------------------------------------- 1 | { 2 | "output_dir": "/home2/yhliu/simplex-diffusion-main/raw/cnndm/roberta100/out-125000", 3 | "model_name_or_path": "/home2/yhliu/simplex-diffusion-main/raw/cnndm/roberta100/checkpoint-125000", 4 | "dataset_name":"/home2/yhliu/controlpoli/ssd-lm/cnn_dm_500_raw", 5 | "cache_dir":"/home2/yhliu", 6 | "preprocessing_num_workers":16, 7 | "overwrite_cache":false, 8 | "per_device_train_batch_size": 8, 9 | "per_device_eval_batch_size": 16, 10 | "do_train": false, 11 | "do_eval": false, 12 | "do_predict":true, 13 | "evaluation_strategy": "no", 14 | "eval_steps": 1000, 15 | "report_to": "tensorboard", 16 | "overwrite_output_dir": false, 17 | "max_seq_length": 512, 18 | "max_target_length":100, 19 | "max_source_length":412, 20 | "val_max_target_length":100, 21 | "skip_special_tokens":true, 22 | "max_eval_samples": 48, 23 | "max_predict_samples": 500, 24 | "simplex_value": 5.0, 25 | "num_diffusion_steps": 5000, 26 | "num_inference_diffusion_steps": 1000, 27 | "lr_scheduler_type": "linear", 28 | "learning_rate": 1e-4, 29 | "pad_to_max_length": true, 30 | "beta_schedule": "squaredcos_improved_ddpm", 31 | "weight_decay": 0.0, 32 | "top_p": 0.95, 33 | "max_steps": 200000, 34 | "gradient_accumulation_steps": 1, 35 | "warmup_steps": 2000, 36 | "logging_steps": 50, 37 | "save_steps": 1000, 38 | "conditional_generation": "ul2", 39 | "save_total_limit": 1, 40 | "tokenized_data_path": "/home2/yhliu/simplex-diffusion-main/raw/cnndm/roberta", 41 | "metric_for_best_model": "pred_texts_from_logits_masked_rouge1" 42 | 43 | 44 | } 45 | -------------------------------------------------------------------------------- /simplex-diffusion-main/configs/models/cnndm/roberta_train_extend.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "output_dir": "/home2/yhliu/simplex-diffusion-main/raw/cnndm/roberta_extend", 3 | "model_name_or_path": "roberta-base", 4 | "dataset_name":"cnn_dailymail", 5 | "dataset_config_name":"3.0.0", 6 | "cache_dir":"/home2/yhliu", 7 | "preprocessing_num_workers":16, 8 | "overwrite_cache":false, 9 | "per_device_train_batch_size": 8, 10 | "per_device_eval_batch_size": 8, 11 | "do_train": true, 12 | "do_eval": true, 13 | "do_predict":false, 14 | "evaluation_strategy": "steps", 15 | "eval_steps": 1000, 16 | "max_seq_length": 1024, 17 | "max_target_length":120, 18 | "max_source_length":904, 19 | "val_max_target_length":50, 20 | "skip_special_tokens":true, 21 | "max_eval_samples": 48, 22 | "max_predict_samples": 48, 23 | "simplex_value": 5.0, 24 | "num_diffusion_steps": 5000, 25 | "num_inference_diffusion_steps": 1000, 26 | "lr_scheduler_type": "linear", 27 | "learning_rate": 3e-5, 28 | "pad_to_max_length": true, 29 | "beta_schedule": "squaredcos_improved_ddpm", 30 | "weight_decay": 0.0, 31 | "top_p": 0.99, 32 | "max_steps": 200000, 33 | "gradient_accumulation_steps": 4, 34 | "warmup_steps": 2000, 35 | "logging_steps": 50, 36 | "save_steps": 1000, 37 | "conditional_generation": "ul2", 38 | "save_total_limit": 3, 39 | "tokenized_data_path": "/home2/yhliu/simplex-diffusion-main/raw/cnndm/roberta", 40 | "report_to":"tensorboard", 41 | "load_best_model_at_end": true, 42 | "overwrite_output_dir":false, 43 | "metric_for_best_model": "pred_texts_from_logits_masked_rouge1", 44 | "self_condition":"logits_mean" 45 | 46 | } 47 | -------------------------------------------------------------------------------- /simplex-diffusion-main/sdlm/models/h3/utils/utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from pytorch_lightning.utilities import rank_zero_only 4 | 5 | 6 | # Copied from https://docs.python.org/3/howto/logging-cookbook.html#using-a-context-manager-for-selective-logging 7 | class LoggingContext: 8 | def __init__(self, logger, level=None, handler=None, close=True): 9 | self.logger = logger 10 | self.level = level 11 | self.handler = handler 12 | self.close = close 13 | 14 | def __enter__(self): 15 | if self.level is not None: 16 | self.old_level = self.logger.level 17 | self.logger.setLevel(self.level) 18 | if self.handler: 19 | self.logger.addHandler(self.handler) 20 | 21 | def __exit__(self, et, ev, tb): 22 | if self.level is not None: 23 | self.logger.setLevel(self.old_level) 24 | if self.handler: 25 | self.logger.removeHandler(self.handler) 26 | if self.handler and self.close: 27 | self.handler.close() 28 | # implicit return of None => don't swallow exceptions 29 | 30 | 31 | def get_logger(name=__name__) -> logging.Logger: 32 | """Initializes multi-GPU-friendly python logger.""" 33 | 34 | logger = logging.getLogger(name) 35 | 36 | # this ensures all logging levels get marked with the rank zero decorator 37 | # otherwise logs would get multiplied for each GPU process in multi-GPU setup 38 | for level in ( 39 | "debug", 40 | "info", 41 | "warning", 42 | "error", 43 | "exception", 44 | "fatal", 45 | "critical", 46 | ): 47 | setattr(logger, level, rank_zero_only(getattr(logger, level))) 48 | 49 | return logger 50 | -------------------------------------------------------------------------------- /simplex-diffusion-main/configs/models/cnndm/roberta_train.json: -------------------------------------------------------------------------------- 1 | { 2 | 
"output_dir": "/home2/yhliu/simplex-diffusion-main/raw/cnndm/roberta100", 3 | "model_name_or_path": "roberta-base", 4 | "dataset_name":"cnn_dailymail", 5 | "dataset_config_name":"3.0.0", 6 | "cache_dir":"/home2/yhliu", 7 | "preprocessing_num_workers":16, 8 | "overwrite_cache":false, 9 | "per_device_train_batch_size": 8, 10 | "per_device_eval_batch_size": 8, 11 | "do_train": true, 12 | "do_eval": true, 13 | "do_predict":false, 14 | "evaluation_strategy": "steps", 15 | "eval_steps": 1000, 16 | "max_seq_length": 512, 17 | "max_target_length":100, 18 | "max_source_length":412, 19 | "val_max_target_length":50, 20 | "skip_special_tokens":true, 21 | "max_eval_samples": 48, 22 | "max_predict_samples": 48, 23 | "simplex_value": 5.0, 24 | "num_diffusion_steps": 5000, 25 | "num_inference_diffusion_steps": 1000, 26 | "lr_scheduler_type": "linear", 27 | "learning_rate": 3e-5, 28 | "pad_to_max_length": true, 29 | "beta_schedule": "squaredcos_improved_ddpm", 30 | "weight_decay": 0.0, 31 | "top_p": 0.99, 32 | "max_steps": 200000, 33 | "gradient_accumulation_steps": 4, 34 | "warmup_steps": 2000, 35 | "logging_steps": 50, 36 | "save_steps": 1000, 37 | "conditional_generation": "ul2", 38 | "save_total_limit": 3, 39 | "tokenized_data_path": "/home2/yhliu/simplex-diffusion-main/raw/cnndm/roberta", 40 | "report_to":"tensorboard", 41 | "load_best_model_at_end": true, 42 | "overwrite_output_dir":false, 43 | "metric_for_best_model": "pred_texts_from_logits_masked_rouge1", 44 | "self_condition_mix_before_weight":true, 45 | "self_condition":"logits_mean" 46 | 47 | } 48 | -------------------------------------------------------------------------------- /simplex-diffusion-main/configs/models/cnndm/roberta_control.json: -------------------------------------------------------------------------------- 1 | { 2 | "output_dir": "/home2/yhliu/simplex-diffusion-main/raw/cnndm/biased/", 3 | "model_name_or_path": "/home2/yhliu/simplex-diffusion-main/raw/cnndm/roberta100/checkpoint-126000", 4 | "dataset_name":"/home2/yhliu/simplex-diffusion-main/biased_datasets/right", 5 | "cache_dir":"/home2/yhliu", 6 | "preprocessing_num_workers":16, 7 | "overwrite_cache":false, 8 | "per_device_train_batch_size": 8, 9 | "per_device_eval_batch_size": 8, 10 | "do_train": false, 11 | "do_eval": false, 12 | "do_predict":true, 13 | "evaluation_strategy": "no", 14 | "eval_steps": 1000, 15 | "report_to": "tensorboard", 16 | "overwrite_output_dir": false, 17 | "max_seq_length": 512, 18 | "max_target_length":100, 19 | "max_source_length":412, 20 | "val_max_target_length":100, 21 | "skip_special_tokens":true, 22 | "max_eval_samples": 48, 23 | "max_predict_samples": 500, 24 | "simplex_value": 5.0, 25 | "num_diffusion_steps": 5000, 26 | "num_inference_diffusion_steps": 1000, 27 | "lr_scheduler_type": "linear", 28 | "learning_rate": 1e-4, 29 | "pad_to_max_length": true, 30 | "beta_schedule": "squaredcos_improved_ddpm", 31 | "weight_decay": 0.0, 32 | "top_p": 0.95, 33 | "max_steps": 200000, 34 | "gradient_accumulation_steps": 1, 35 | "warmup_steps": 2000, 36 | "logging_steps": 50, 37 | "save_steps": 1000, 38 | "conditional_generation": "ul2", 39 | "save_total_limit": 1, 40 | "tokenized_data_path": "/home2/yhliu/simplex-diffusion-main/raw/cnndm/roberta", 41 | "metric_for_best_model": "pred_texts_from_logits_masked_rouge1", 42 | "if_control": true, 43 | "ctr_model_name":"/home2/yhliu/controlpoli/ssd-lm/controlling for political bias/POLITICS/POLITICS_model", 44 | "ctr_opt_label_idx":2 45 | 46 | 47 | 48 | } 49 | 
-------------------------------------------------------------------------------- /simplex-diffusion-main/sdlm/metrics/repetition.py: -------------------------------------------------------------------------------- 1 | """Computes the repetition metric. Adapted from: https://raw.githubusercontent.com/ari-holtzman/degen/master/metrics/repetition.py""" 2 | import pdb 3 | 4 | 5 | def repetition(tokenized_texts, tokenizer): 6 | """ 7 | Args: 8 | tokenized_texts: (List[List[int]]) generated input tokenized texts. 9 | 10 | Computes the repetition metric https://arxiv.org/pdf/1904.09751.pdf showing how each 11 | example is repeating itself, specifically the phrase the generation is repeating 12 | and how many times it is repeated. 13 | """ 14 | SEP = tokenizer.encode(tokenizer.bos_token)[0] 15 | repetition_stats = [] 16 | max_n = 90 17 | num_examples = len(tokenized_texts) 18 | n_repeated_examples = 0 19 | for tokenized_text in tokenized_texts: 20 | if tokenized_text[-1] == SEP: 21 | tokenized_text.pop(-1) 22 | rev_gen = list(reversed(tokenized_text)) 23 | last_n_repeats = [0] * max_n 24 | for n in range(1, max_n + 1): 25 | n_repeat = 1 26 | while ( 27 | len(rev_gen[n * n_repeat : n * (n_repeat + 1)]) == n 28 | and rev_gen[n * n_repeat : n * (n_repeat + 1)] == rev_gen[:n] 29 | ): 30 | n_repeat += 1 31 | last_n_repeats[n - 1] = n_repeat 32 | max_repeated_n = max(range(max_n), key=lambda x: last_n_repeats[x]) 33 | if last_n_repeats[max_repeated_n] > 1 and (max_repeated_n + 1 >= 3 or last_n_repeats[max_repeated_n] > 50): 34 | repetition_stats.append( 35 | { 36 | "repeated_phrase": list(reversed(rev_gen[: max_repeated_n + 1])), 37 | "repeated_times": last_n_repeats[max_repeated_n], 38 | "repeated_phrase_length": max_repeated_n + 1, 39 | } 40 | ) 41 | n_repeated_examples += 1 42 | else: 43 | repetition_stats.append({}) 44 | 45 | return {"repetition": n_repeated_examples * 1.0 / num_examples} # , "repetition_stats": repetition_stats} 46 | -------------------------------------------------------------------------------- /simplex-diffusion-main/sdlm/data/postprocessors.py: -------------------------------------------------------------------------------- 1 | import nltk # Here to have a nice missing dependency error message early on 2 | from transformers.utils import is_offline_mode 3 | from filelock import FileLock 4 | 5 | try: 6 | nltk.data.find("tokenizers/punkt") 7 | except (LookupError, OSError): 8 | if is_offline_mode(): 9 | raise LookupError("Offline mode: run this script without TRANSFORMERS_OFFLINE first to download nltk data files") 10 | with FileLock(".lock") as lock: 11 | nltk.download("punkt", quiet=True) 12 | 13 | def string_to_float(string, default=-1.): 14 | """Converts string to float, using default when conversion not possible.""" 15 | try: 16 | return float(string) 17 | except ValueError: 18 | return default 19 | 20 | 21 | def string_to_int(string, default=-1): 22 | """Converts string to int, using default when conversion not possible.""" 23 | try: 24 | return int(string) 25 | except ValueError: 26 | return default 27 | 28 | 29 | def get_post_processor(task): 30 | """Returns post processor required to apply on the predictions/targets 31 | before computing metrics for each task.""" 32 | if task == "stsb": 33 | return string_to_float 34 | elif task in ["qqp", "cola", "mrpc"]: 35 | return string_to_int 36 | else: 37 | return None 38 | 39 | 40 | def postprocess_text_for_metric(metric, preds, labels=None, sources=None): 41 | if metric == "sari": 42 | assert sources is not None 43 | preds = 
[pred.strip() for pred in preds] 44 | labels = [label.strip() for label in labels] 45 | sources = [source.strip() for source in sources] 46 | return preds, labels, sources 47 | elif metric == "rouge": 48 | preds = [pred.strip() for pred in preds] 49 | labels = [label.strip() for label in labels] 50 | # rougeLSum expects newline after each sentence 51 | preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds] 52 | labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels] 53 | return preds, labels 54 | elif metric == "bleu": 55 | preds = [pred.strip() for pred in preds] 56 | labels = [[label.strip()] for label in labels] 57 | return preds, labels 58 | elif metric in ["bertscore", "bertscore_them"]: 59 | preds = [pred.strip() for pred in preds] 60 | labels = [label.strip() for label in labels] 61 | return preds, labels 62 | elif metric in ["dist"]: 63 | preds = [pred.strip() for pred in preds] 64 | return preds 65 | else: 66 | raise NotImplementedError -------------------------------------------------------------------------------- /simplex-diffusion-main/sdlm/models/h3/configuration_h3.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from transformers.models.gpt2 import GPT2Config 4 | 5 | 6 | class H3Config(GPT2Config): 7 | def __init__( 8 | self, 9 | # n_layer: int, 10 | # vocab_size: int, 11 | # max_position_embeddings=0, 12 | # d_inner: int = 3072, 13 | d_model: int = 1, 14 | n_head: int = 1, 15 | rotary_emb_dim: int = 0, 16 | attn_layer_idx=None, 17 | resid_dropout: float = 0.0, 18 | embed_dropout: float = 0.1, 19 | layer_norm_epsilon: float = 1e-5, 20 | initializer_cfg=None, 21 | fused_mlp=False, 22 | fused_dropout_add_ln=False, 23 | residual_in_fp32=False, 24 | pad_vocab_size_multiple: int = 1, 25 | **kwargs, 26 | ): 27 | super().__init__(**kwargs) 28 | # h3 29 | self.d_model = d_model 30 | self.d_inner = d_model * 4 31 | self.ssm_cfg = {"mode": "diag", "measure": "diag-lin"} 32 | self.attn_layer_idx = attn_layer_idx 33 | self.attn_cfg = {"num_heads": n_head, "causal": False} 34 | if rotary_emb_dim: 35 | self.attn_cfg["rotary_emb_dim"] = rotary_emb_dim 36 | self.resid_dropout = resid_dropout 37 | self.embed_dropout = embed_dropout 38 | self.layer_norm_epsilon = layer_norm_epsilon 39 | self.initializer_cfg = initializer_cfg 40 | self.fused_mlp = fused_mlp 41 | self.fused_dropout_add_ln = fused_dropout_add_ln 42 | self.residual_in_fp32 = residual_in_fp32 43 | self.pad_vocab_size_multiple = pad_vocab_size_multiple 44 | 45 | 46 | class H3DiffusionConfig(H3Config): 47 | def __init__( 48 | self, 49 | self_condition: Optional[str] = None, 50 | self_condition_zeros_after_softmax: bool = False, 51 | deepmind_conditional: bool = False, 52 | classifier_free_simplex_inputs: bool = False, 53 | classifier_free_uncond_input: str = "empty_token", 54 | self_condition_mlp_projection=False, 55 | self_condition_mix_before_weights=False, 56 | self_condition_mix_logits_before_weights=False, 57 | empty_token_be_mask=False, 58 | **kwargs, 59 | ): 60 | super().__init__(**kwargs) 61 | self.self_condition = self_condition 62 | self.self_condition_zeros_after_softmax = self_condition_zeros_after_softmax 63 | self.deepmind_conditional = deepmind_conditional 64 | self.classifier_free_simplex_inputs = classifier_free_simplex_inputs 65 | self.classifier_free_uncond_input = classifier_free_uncond_input 66 | self.self_condition_mlp_projection = self_condition_mlp_projection 67 | self.self_condition_mix_before_weights = 
self_condition_mix_before_weights 68 | self.self_condition_mix_logits_before_weights = ( 69 | self_condition_mix_logits_before_weights 70 | ) 71 | self.empty_token_be_mask = empty_token_be_mask 72 | # PAD 73 | self.vocab_size += 1 74 | -------------------------------------------------------------------------------- /simplex-diffusion-main/sdlm/models/h3/ssm/ss_kernel_shift.py: -------------------------------------------------------------------------------- 1 | # TD: [2023-01-05]: Extracted the SSKernelDiag class from 2 | # https://github.com/HazyResearch/state-spaces/blob/06dbbdfd0876501a7f12bf3262121badbc7658af/src/models/sequence/ss/kernel.py 3 | # We make a small change to use the log_vandermonde CUDA code. 4 | 5 | """SSKernelDiag is the S4D kernel, a simpler algorithm for computing the kernel for the case of diagonal state matrices A. 6 | """ 7 | import math 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | from einops import rearrange, repeat 13 | from opt_einsum import contract 14 | 15 | from .ssm_utils import OptimModule 16 | 17 | 18 | class SSKernelShift(OptimModule): 19 | def __init__(self, B, C, L=None, lr=None, **kwargs): 20 | """ 21 | B: (H, d), real 22 | C: (channel, H, d), real 23 | """ 24 | super().__init__() 25 | self.L = L 26 | self.N = B.size(-1) 27 | self.H = B.shape[0] 28 | 29 | # Register parameters 30 | if lr is None or isinstance(lr, float): 31 | lr_dict = {} 32 | else: 33 | lr_dict, lr = lr, None 34 | self.register("B", B, lr_dict.get("B", lr)) 35 | self.C = nn.Parameter(C) 36 | 37 | def forward(self, state=None, rate=1.0, L=None): 38 | if L is None: 39 | L = self.L 40 | # This class doesn't support variable length functionalities, since it's a discrete SSM 41 | assert rate == 1.0 and L is not None 42 | 43 | # Augment B with state 44 | B = self.B 45 | if state is not None: 46 | B = rearrange( 47 | torch.cat([rearrange(B, "h n -> 1 h n"), state], dim=-3), 48 | "bp1 h n -> bp1 1 h n", 49 | ) # (1 + B, 1, H, N) 50 | B_f = torch.fft.rfft(B, n=2 * self.N) 51 | C_f = torch.fft.rfft(self.C, n=2 * self.N) 52 | k = torch.fft.irfft(B_f.conj() * C_f, n=2 * self.N)[..., : min(self.N, L)] 53 | # If self.N < L, need to pad with zeros to reach length L 54 | if self.N < L: 55 | k = F.pad(k, (0, L - self.N)) 56 | k = k.float() # Otherwise it could be dtype half 57 | if state is not None: 58 | k, k_state = k[0], k[1:] 59 | else: 60 | k_state = None 61 | return k, k_state 62 | 63 | def _setup_step(self): 64 | # Just here to conform to the interface, eventually we should refactor out 65 | pass 66 | 67 | def default_state(self, *batch_shape): 68 | return torch.zeros( 69 | *batch_shape, self.H, self.N, dtype=self.C.dtype, device=self.C.device 70 | ) 71 | 72 | def step(self, u, state): 73 | """u: (B, H), state: (B, H, N)""" 74 | next_state = F.pad(state, (1, -1)) + contract("h n, b h -> b h n", self.B, u) 75 | y = contract("c h n, b h n -> b c h", self.C, next_state) 76 | return y, next_state 77 | 78 | def forward_state(self, u, state): 79 | """u: (B, H, L), state: (B, H, N)""" 80 | L = u.shape[-1] 81 | B_f = torch.fft.rfft(self.B, n=2 * self.N) 82 | u_f = torch.fft.rfft( 83 | u[..., -self.N :].flip(-1).to(dtype=self.B.dtype), n=2 * self.N 84 | ) 85 | v = torch.fft.irfft(B_f * u_f, n=2 * self.N)[..., : self.N] 86 | if L < self.N: 87 | next_state = F.pad(state, (L, -L)) + v 88 | else: 89 | next_state = v 90 | return next_state 91 | -------------------------------------------------------------------------------- 
/simplex-diffusion-main/commands_for_sum/run_sum.sh: -------------------------------------------------------------------------------- 1 | #finetune summarization task from the checkpoint of roberta-based tuned with ul2-variable on the length=256. 2 | #cd "/data3/whr/lyh/ControllingPoliticalBias/simplex-diffusion-main/" 3 | cd "path/to/simplex-diffusion-main" 4 | model_name="roberta-base" 5 | model_path="model/" 6 | learning_rate=3e-5 7 | max_steps=200000 #120000 8 | CUDA_VISIBLE_DEVICES=1 9 | datasetname="Sampled_Datasets/cnn_dm_500_raw" 10 | dataset_config_name="3.0.0" 11 | cache_dir="./" 12 | preprocessing_num_workers=16 13 | overwrite_cache=false 14 | per_device_train_batch_size=8 15 | per_device_eval_batch_size=16 16 | do_train=false 17 | do_eval=false 18 | do_predict=true 19 | evaluation_strategy="no" 20 | eval_steps=1000 21 | report_to="tensorboard" 22 | overwrite_output_dir=false 23 | max_seq_length=512 24 | max_target_length=100 25 | max_source_length=412 26 | val_max_target_length=100 27 | skip_special_tokens=true 28 | max_eval_samples=100 29 | max_predict_samples=500 30 | simplex_value=5 31 | num_diffusion_steps=1000 32 | lr_scheduler_type="linear" 33 | pad_to_max_length=true 34 | beta_schedule="squaredcos_improved_ddpm" 35 | weight_decay=0.0 36 | warmup_steps=2000 37 | max_steps=200000 38 | gradient_accumulation_steps=1 39 | logging_steps=50 40 | save_steps=20000 41 | conditional_generation="ul2" 42 | save_total_limit=1 43 | tokenized_data_path="raw/xsum/roberta" 44 | metric_for_best_model="rouge1" 45 | if_control=true 46 | ctr_model_name='POLITICS_model' 47 | ctr_opt_label_idx=0 48 | output_dir="out/" 49 | decode_ctr_lr=1000 50 | 51 | 52 | CUDA_VISIBLE_DEVICES='0' python sdlm/run_summarization.py \ 53 | --model_name_or_path ${model_path} \ 54 | --output_dir ${output_dir} \ 55 | --dataset_name ${datasetname} \ 56 | --dataset_config_name ${dataset_config_name} \ 57 | --cache_dir ${cache_dir} \ 58 | --preprocessing_num_workers ${preprocessing_num_workers} \ 59 | --overwrite_cache ${overwrite_cache} \ 60 | --per_device_train_batch_size ${per_device_train_batch_size} \ 61 | --per_device_eval_batch_size ${per_device_eval_batch_size} \ 62 | --do_train ${do_train} \ 63 | --do_eval ${do_eval} \ 64 | --do_predict ${do_predict} \ 65 | --evaluation_strategy ${evaluation_strategy} \ 66 | --eval_steps ${eval_steps} \ 67 | --report_to ${report_to} \ 68 | --overwrite_output_dir ${overwrite_output_dir} \ 69 | --max_seq_length ${max_seq_length} \ 70 | --max_target_length ${max_target_length} \ 71 | --max_source_length ${max_source_length} \ 72 | --val_max_target_length ${val_max_target_length} \ 73 | --skip_special_tokens ${skip_special_tokens} \ 74 | --max_eval_samples ${max_eval_samples} \ 75 | --max_predict_samples ${max_predict_samples} \ 76 | --simplex_value ${simplex_value} \ 77 | --num_diffusion_steps ${num_diffusion_steps} \ 78 | --lr_scheduler_type ${lr_scheduler_type} \ 79 | --pad_to_max_length ${pad_to_max_length} \ 80 | --beta_schedule ${beta_schedule} \ 81 | --weight_decay ${weight_decay} \ 82 | --warmup_steps ${warmup_steps} \ 83 | --max_steps ${max_steps} \ 84 | --gradient_accumulation_steps ${gradient_accumulation_steps} \ 85 | --logging_steps ${logging_steps} \ 86 | --save_steps ${save_steps} \ 87 | --conditional_generation ${conditional_generation} \ 88 | --save_total_limit ${save_total_limit} \ 89 | --tokenized_data_path ${tokenized_data_path} \ 90 | --metric_for_best_model ${metric_for_best_model} \ 91 | --if_control ${if_control} \ 92 | --ctr_model_name ${ctr_model_name} \ 
93 | --ctr_opt_label_idx ${ctr_opt_label_idx} \ 94 | --decode_ctr_lr ${decode_ctr_lr} \ 95 | --self_condition "logits_mean" \ 96 | --self_condition_mix_before_weights true 97 | 98 | -------------------------------------------------------------------------------- /simplex-diffusion-main/.gitignore: -------------------------------------------------------------------------------- 1 | # custom 2 | out/ 3 | model/ 4 | POLITICS_model/ 5 | Sampled_Datasets/ 6 | 7 | # Byte-compiled / optimized / DLL files 8 | __pycache__/ 9 | *.py[cod] 10 | *$py.class 11 | 12 | # C extensions 13 | *.so 14 | 15 | # Distribution / packaging 16 | .Python 17 | build/ 18 | develop-eggs/ 19 | dist/ 20 | downloads/ 21 | eggs/ 22 | .eggs/ 23 | lib/ 24 | lib64/ 25 | parts/ 26 | sdist/ 27 | var/ 28 | wheels/ 29 | share/python-wheels/ 30 | *.egg-info/ 31 | .installed.cfg 32 | *.egg 33 | MANIFEST 34 | 35 | # PyInstaller 36 | # Usually these files are written by a python script from a template 37 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 38 | *.manifest 39 | *.spec 40 | 41 | # Installer logs 42 | pip-log.txt 43 | pip-delete-this-directory.txt 44 | 45 | # Unit test / coverage reports 46 | htmlcov/ 47 | .tox/ 48 | .nox/ 49 | .coverage 50 | .coverage.* 51 | .cache 52 | nosetests.xml 53 | coverage.xml 54 | *.cover 55 | *.py,cover 56 | .hypothesis/ 57 | .pytest_cache/ 58 | cover/ 59 | 60 | # Translations 61 | *.mo 62 | *.pot 63 | 64 | # Django stuff: 65 | *.log 66 | local_settings.py 67 | db.sqlite3 68 | db.sqlite3-journal 69 | 70 | # Flask stuff: 71 | instance/ 72 | .webassets-cache 73 | 74 | # Scrapy stuff: 75 | .scrapy 76 | 77 | # Sphinx documentation 78 | docs/_build/ 79 | 80 | # PyBuilder 81 | .pybuilder/ 82 | target/ 83 | 84 | # Jupyter Notebook 85 | .ipynb_checkpoints 86 | 87 | # IPython 88 | profile_default/ 89 | ipython_config.py 90 | 91 | # pyenv 92 | # For a library or package, you might want to ignore these files since the code is 93 | # intended to run in multiple environments; otherwise, check them in: 94 | # .python-version 95 | 96 | # pipenv 97 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 98 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 99 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 100 | # install all needed dependencies. 101 | #Pipfile.lock 102 | 103 | # poetry 104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 105 | # This is especially recommended for binary packages to ensure reproducibility, and is more 106 | # commonly ignored for libraries. 107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 108 | #poetry.lock 109 | 110 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 111 | __pypackages__/ 112 | 113 | # Celery stuff 114 | celerybeat-schedule 115 | celerybeat.pid 116 | 117 | # SageMath parsed files 118 | *.sage.py 119 | 120 | # Environments 121 | .env 122 | .venv 123 | env/ 124 | venv/ 125 | ENV/ 126 | env.bak/ 127 | venv.bak/ 128 | 129 | # Spyder project settings 130 | .spyderproject 131 | .spyproject 132 | 133 | # Rope project settings 134 | .ropeproject 135 | 136 | # mkdocs documentation 137 | /site 138 | 139 | # mypy 140 | .mypy_cache/ 141 | .dmypy.json 142 | dmypy.json 143 | 144 | # Pyre type checker 145 | .pyre/ 146 | 147 | # pytype static type analyzer 148 | .pytype/ 149 | 150 | # Cython debug symbols 151 | cython_debug/ 152 | 153 | # PyCharm 154 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 155 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 156 | # and can be added to the global gitignore or merged into this file. For a more nuclear 157 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 158 | #.idea/ 159 | -------------------------------------------------------------------------------- /simplex-diffusion-main/env.yml: -------------------------------------------------------------------------------- 1 | name: preserve 2 | channels: 3 | - conda-forge 4 | - defaults 5 | dependencies: 6 | - pip=23.1.2 7 | - pip: 8 | - absl-py==1.4.0 9 | - accelerate==0.16.0 10 | - aiohttp==3.8.4 11 | - aiosignal==1.3.1 12 | - antlr4-python3-runtime==4.9.3 13 | - appdirs==1.4.4 14 | - async-timeout==4.0.2 15 | - attrs==22.2.0 16 | - beautifulsoup4==4.12.2 17 | - bert-score==0.3.13 18 | - cachetools==5.3.0 19 | - certifi==2022.12.7 20 | - cfgv==3.3.1 21 | - charset-normalizer==3.0.1 22 | - click==8.1.3 23 | - colorama==0.4.6 24 | - consoleprinter==95 25 | - contourpy==1.0.7 26 | - cycler==0.11.0 27 | - datasets==2.10.0 28 | - diffusers==0.13.1 29 | - dill==0.3.6 30 | - distlib==0.3.6 31 | - docker-pycreds==0.4.0 32 | - einops==0.6.0 33 | - et-xmlfile==1.1.0 34 | - evaluate==0.4.0 35 | - faiss-cpu==1.7.3 36 | - filelock==3.9.0 37 | - fonttools==4.38.0 38 | - frozenlist==1.3.3 39 | - fsspec==2023.1.0 40 | - future==0.18.3 41 | - gdown==4.7.1 42 | - gitdb==4.0.10 43 | - gitpython==3.1.32 44 | - google-auth==2.16.1 45 | - google-auth-oauthlib==0.4.6 46 | - grpcio==1.51.3 47 | - huggingface-hub==0.12.1 48 | - identify==2.5.18 49 | - idna==3.4 50 | - importlib-metadata==6.0.0 51 | - importlib-resources==5.12.0 52 | - joblib==1.2.0 53 | - keopscore==2.1.1 54 | - kiwisolver==1.4.4 55 | - lightning-utilities==0.7.1 56 | - lxml==4.9.2 57 | - markdown==3.4.1 58 | - markupsafe==2.1.2 59 | - matplotlib==3.7.0 60 | - mauve-text==0.3.0 61 | - multidict==6.0.4 62 | - multiprocess==0.70.14 63 | - nltk==3.8.1 64 | - nodeenv==1.7.0 65 | - numpy==1.24.2 66 | - oauthlib==3.2.2 67 | - omegaconf==2.3.0 68 | - openai==0.28.0 69 | - openpyxl==3.1.2 70 | - opt-einsum==3.3.0 71 | - packaging==23.0 72 | - pandas==1.5.3 73 | - pathtools==0.1.2 74 | - pillow==9.4.0 75 | - platformdirs==3.0.0 76 | - portalocker==2.7.0 77 | - pre-commit==3.0.4 78 | - protobuf==3.20.0 79 | - psutil==5.9.4 80 | - pyarrow==11.0.0 81 | - pyasn1==0.4.8 82 | - pyasn1-modules==0.2.8 83 | - pybind11==2.10.3 84 | - pykeops==2.1.1 85 | - pyparsing==3.0.9 86 | - pysocks==1.7.1 87 | - pytorch-lightning==1.9.3 88 | - pytz==2022.7.1 89 | - pyyaml==6.0 90 | - regex==2022.10.31 91 | - requests==2.28.2 92 | - requests-oauthlib==1.3.1 93 | - responses==0.18.0 94 | - 
rouge-score==0.1.2 95 | - rsa==4.9 96 | - sacrebleu==2.3.1 97 | - scikit-learn==1.2.1 98 | - scipy==1.10.1 99 | - seaborn==0.12.2 100 | - sentencepiece==0.1.99 101 | - sentry-sdk==1.28.1 102 | - setproctitle==1.3.2 103 | - smmap==5.0.0 104 | - soupsieve==2.4.1 105 | - tabulate==0.9.0 106 | - tenacity==8.2.3 107 | - tensorboard==2.12.0 108 | - tensorboard-data-server==0.7.0 109 | - tensorboard-plugin-wit==1.8.1 110 | - termcolor==2.3.0 111 | - terminaltables==3.1.10 112 | - threadpoolctl==3.1.0 113 | - tiktoken==0.4.0 114 | - tokenizers==0.13.2 115 | - torch==1.13.0 116 | - torchaudio==0.13.0 117 | - torchmetrics==1.0.1 118 | - torchvision==0.14.0 119 | - tqdm==4.64.1 120 | - transformers==4.27.1 121 | - typing-extensions==4.5.0 122 | - tzdata==2023.3 123 | - ujson==5.8.0 124 | - urllib3==1.26.14 125 | - virtualenv==20.19.0 126 | - wandb==0.15.7 127 | - werkzeug==2.2.3 128 | - xxhash==3.2.0 129 | - yarl==1.8.2 130 | - zipp==3.14.0 131 | 132 | -------------------------------------------------------------------------------- /simplex-diffusion-main/sdlm/data/process_data.py: -------------------------------------------------------------------------------- 1 | """Tokenize the dataset and saves the output.""" 2 | import logging 3 | import os 4 | import sys 5 | 6 | import datasets 7 | from accelerate import Accelerator 8 | from datasets import DatasetDict, load_dataset 9 | from transformers import AutoTokenizer, HfArgumentParser, set_seed 10 | from transformers.utils.versions import require_version 11 | import sys 12 | sys.path.append("/home2/yhliu/simplex-diffusion-main/sdlm") 13 | from arguments import DataTrainingArguments, ModelArguments, TrainingArguments 14 | from data.data_utils import tokenize_data 15 | 16 | require_version("datasets>=1.8.0") 17 | logger = logging.getLogger(__name__) 18 | 19 | 20 | def main(): 21 | parser = HfArgumentParser( 22 | (ModelArguments, DataTrainingArguments, TrainingArguments) 23 | ) 24 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): 25 | # If we pass only one argument to the script and it's the path to a json file, 26 | # let's parse it to get our arguments. 27 | model_args, data_args, training_args = parser.parse_json_file( 28 | json_file=os.path.abspath(sys.argv[1]) 29 | ) 30 | else: 31 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 32 | 33 | accelerator = Accelerator() 34 | 35 | # Setup logging 36 | logging.basicConfig( 37 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 38 | datefmt="%m/%d/%Y %H:%M:%S", 39 | handlers=[logging.StreamHandler(sys.stdout)], 40 | ) 41 | 42 | log_level = training_args.get_process_log_level() 43 | logger.setLevel(log_level) 44 | datasets.utils.logging.set_verbosity(log_level) 45 | 46 | # Log on each process the small summary: 47 | logger.warning( 48 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" 49 | + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" 50 | ) 51 | # Set the verbosity to info of the Transformers logger (on main process only): 52 | logger.info(f"Training/evaluation parameters {training_args}") 53 | 54 | # Set seed before initializing model. 55 | set_seed(training_args.seed) 56 | 57 | assert data_args.dataset_name is not None 58 | # Downloading and loading a dataset from the hub. 
59 | raw_datasets = datasets.DatasetDict() 60 | raw_datasets["train"] = load_dataset( 61 | data_args.dataset_name, 62 | data_args.dataset_config_name, 63 | cache_dir=model_args.cache_dir, 64 | use_auth_token=True if model_args.use_auth_token else None, 65 | split="train", 66 | verification_mode=data_args.verification_mode, 67 | ) 68 | 69 | tokenizer_kwargs = { 70 | "cache_dir": model_args.cache_dir, 71 | "use_fast": model_args.use_fast_tokenizer, 72 | "revision": model_args.model_revision, 73 | "use_auth_token": True if model_args.use_auth_token else None, 74 | } 75 | if model_args.model_name_or_path: 76 | tokenizer = AutoTokenizer.from_pretrained( 77 | model_args.model_name_or_path, **tokenizer_kwargs 78 | ) 79 | else: 80 | raise ValueError( 81 | "You are instantiating a new tokenizer from scratch. This is not supported by this script." 82 | "You can do it from another script, save it, and load it from here, using --tokenizer_name." 83 | ) 84 | 85 | tokenized_datasets = tokenize_data(data_args, tokenizer, raw_datasets, accelerator) 86 | 87 | train_testvalid = tokenized_datasets["train"].train_test_split( 88 | test_size=data_args.validation_split_ratio, 89 | shuffle=True, 90 | seed=training_args.seed, 91 | ) 92 | tokenized_datasets = DatasetDict( 93 | {"train": train_testvalid["train"], "validation": train_testvalid["test"]} 94 | ) 95 | 96 | with training_args.main_process_first(): 97 | tokenized_datasets.save_to_disk(training_args.output_dir) 98 | logger.info(f"The processed data are written in {training_args.output_dir}") 99 | 100 | 101 | if __name__ == "__main__": 102 | main() 103 | -------------------------------------------------------------------------------- /simplex-diffusion-main/sdlm/models/utils.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoTokenizer 2 | 3 | 4 | from .longformer.configuration_longformer import LongformerDiffusionConfig 5 | from .longformer.modeling_longformer import LongformerForDiffusionLM 6 | from .roberta.configuration_roberta import RobertaDiffusionConfig 7 | from .roberta.modeling_roberta import RobertaForDiffusionLM 8 | from .xlm_roberta.configuration_xlm_roberta import XLMRobertaDiffusionConfig 9 | from .xlm_roberta.modeling_xlm_roberta import XLMRobertaForDiffusionLM 10 | try: 11 | from .h3.configuration_h3 import H3DiffusionConfig 12 | from .h3.modeling_h3 import H3ForDiffusionLM 13 | except ModuleNotFoundError: 14 | H3ForDiffusionLM, H3DiffusionConfig = None, None 15 | pass # probably due to no flash attention, which is fine 16 | 17 | def model_config_helper(model_name_or_path): 18 | if "roberta" in model_name_or_path: 19 | return RobertaDiffusionConfig, RobertaForDiffusionLM 20 | if "longformer" in model_name_or_path: 21 | return LongformerDiffusionConfig, LongformerForDiffusionLM 22 | if "gpt2" in model_name_or_path: 23 | return H3DiffusionConfig, H3ForDiffusionLM 24 | if "xlm" in model_name_or_path: 25 | return XLMRobertaDiffusionConfig, XLMRobertaForDiffusionLM 26 | raise ValueError 27 | 28 | 29 | def load_model(model_args, diffusion_args, logger): 30 | config_kwargs = { 31 | "cache_dir": model_args.cache_dir, 32 | "revision": model_args.model_revision, 33 | "use_auth_token": True if model_args.use_auth_token else None, 34 | } 35 | cfg_cls, model_cls = model_config_helper(model_args.model_name_or_path) 36 | config = cfg_cls.from_pretrained( 37 | model_args.model_name_or_path, 38 | self_condition=diffusion_args.self_condition, 39 | 
self_condition_zeros_after_softmax=diffusion_args.self_condition_zeros_after_softmax, 40 | deepmind_conditional=diffusion_args.deepmind_conditional, 41 | classifier_free_simplex_inputs=diffusion_args.classifier_free_simplex_inputs, 42 | classifier_free_uncond_input=diffusion_args.classifier_free_uncond_input, 43 | self_condition_mlp_projection=diffusion_args.self_condition_mlp_projection, 44 | self_condition_mix_before_weights=diffusion_args.self_condition_mix_before_weights, 45 | self_condition_mix_logits_before_weights=diffusion_args.self_condition_mix_logits_before_weights, 46 | empty_token_be_mask=diffusion_args.empty_token_be_mask, 47 | d_model=model_args.d_model, 48 | n_head=model_args.n_head, 49 | attn_layer_idx=model_args.attn_layer_idx, 50 | attention_window=model_args.attention_window, 51 | **config_kwargs, 52 | ) 53 | tokenizer_kwargs = { 54 | "cache_dir": model_args.cache_dir, 55 | "use_fast": model_args.use_fast_tokenizer, 56 | "revision": model_args.model_revision, 57 | "use_auth_token": True if model_args.use_auth_token else None, 58 | } 59 | if model_args.tokenizer_name: 60 | tokenizer = AutoTokenizer.from_pretrained( 61 | model_args.tokenizer_name, **tokenizer_kwargs 62 | ) 63 | elif model_args.model_name_or_path: 64 | tokenizer = AutoTokenizer.from_pretrained( 65 | model_args.model_name_or_path, **tokenizer_kwargs 66 | ) 67 | else: 68 | raise ValueError( 69 | "You are instantiating a new tokenizer from scratch. This is not supported by this script." 70 | "You can do it from another script, save it, and load it from here, using --tokenizer_name." 71 | ) 72 | if not tokenizer.pad_token_id: 73 | tokenizer.add_special_tokens({"pad_token": "[PAD]"}) 74 | 75 | if model_args.model_name_or_path: 76 | model = model_cls.from_pretrained( 77 | model_args.model_name_or_path, 78 | from_tf=bool(".ckpt" in model_args.model_name_or_path), 79 | config=config, 80 | cache_dir=model_args.cache_dir, 81 | revision=model_args.model_revision, 82 | use_auth_token=True if model_args.use_auth_token else None, 83 | ) 84 | else: 85 | logger.info("Training new model from scratch") 86 | model = model_cls.from_config(config) 87 | 88 | # We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch 89 | # on a small vocab and want a smaller embedding size, remove this test. 90 | vocab_size = model.get_input_embeddings().weight.shape[0] 91 | if len(tokenizer) > vocab_size: 92 | model.resize_token_embeddings(len(tokenizer)) 93 | 94 | return tokenizer, model 95 | -------------------------------------------------------------------------------- /simplex-diffusion-main/sdlm/models/h3/ssm/dplr.py: -------------------------------------------------------------------------------- 1 | # Copied from https://github.com/HazyResearch/state-spaces/blob/06dbbdfd0876501a7f12bf3262121badbc7658af/src/models/sequence/ss/dplr.py 2 | 3 | """Initializations of structured state space models""" 4 | import math 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from einops import rearrange, repeat 10 | 11 | from . 
import hippo 12 | 13 | 14 | def dplr( 15 | scaling="linear", 16 | N=64, 17 | rank=1, 18 | H=1, 19 | dtype=torch.float, 20 | real_scale=1.0, 21 | imag_scale=1.0, 22 | random_real=False, 23 | random_imag=False, 24 | normalize=False, 25 | diagonal=True, 26 | random_B=False, 27 | ): 28 | assert dtype == torch.float or dtype == torch.double 29 | dtype = torch.cfloat if dtype == torch.float else torch.cdouble 30 | 31 | pi = torch.tensor(math.pi) 32 | if random_real: 33 | real_part = torch.rand(H, N // 2) 34 | else: 35 | real_part = 0.5 * torch.ones(H, N // 2) 36 | if random_imag: 37 | imag_part = N // 2 * torch.rand(H, N // 2) 38 | else: 39 | imag_part = repeat(torch.arange(N // 2), "n -> h n", h=H) 40 | 41 | real_part = real_scale * real_part 42 | if scaling == "random": 43 | imag_part = torch.randn(H, N // 2) 44 | elif scaling == "real": 45 | imag_part = 0 * imag_part 46 | real_part = 1 + repeat(torch.arange(N // 2), "n -> h n", h=H) 47 | elif scaling in ["linear", "lin"]: 48 | imag_part = pi * imag_part 49 | elif scaling in [ 50 | "inverse", 51 | "inv", 52 | ]: # Based on asymptotics of the default HiPPO matrix 53 | imag_part = 1 / pi * N * (N / (1 + 2 * imag_part) - 1) 54 | elif scaling in ["inverse2", "inv2"]: 55 | imag_part = 1 / pi * N * (N / (1 + imag_part) - 1) 56 | elif scaling in ["quadratic", "quad"]: 57 | imag_part = 1 / pi * (1 + 2 * imag_part) ** 2 58 | elif scaling in ["legs", "hippo"]: 59 | w, _, _, _ = hippo.nplr("legsd", N) 60 | imag_part = w.imag 61 | 62 | else: 63 | raise NotImplementedError 64 | imag_part = imag_scale * imag_part 65 | w = -real_part + 1j * imag_part 66 | 67 | # Initialize B 68 | if random_B: 69 | B = torch.randn(H, N // 2, dtype=dtype) 70 | else: 71 | B = torch.ones(H, N // 2, dtype=dtype) 72 | 73 | if normalize: 74 | norm = ( 75 | -B / w 76 | ) # (H, N) # Result if you integrate the kernel with constant 1 function 77 | zeta = 2 * torch.sum( 78 | torch.abs(norm) ** 2, dim=-1, keepdim=True 79 | ) # Variance with a random C vector 80 | B = B / zeta**0.5 81 | 82 | P = torch.randn(rank, H, N // 2, dtype=dtype) 83 | if diagonal: 84 | P = P * 0.0 85 | V = torch.eye(N, dtype=dtype)[:, : N // 2] # Only used in testing 86 | V = repeat(V, "n m -> h n m", h=H) 87 | 88 | return w, P, B, V 89 | 90 | 91 | def ssm(measure, N, R, H, **ssm_args): 92 | """Dispatcher to create single SSM initialization 93 | N: state size 94 | R: rank (for DPLR parameterization) 95 | H: number of independent SSM copies 96 | """ 97 | 98 | if measure == "dplr": 99 | w, P, B, V = dplr(N=N, rank=R, H=H, **ssm_args) 100 | elif measure.startswith("diag"): 101 | args = measure.split("-") 102 | assert args[0] == "diag" and len(args) > 1 103 | scaling = args[1] 104 | w, P, B, V = dplr(scaling=scaling, N=N, rank=R, H=H, diagonal=True, **ssm_args) 105 | else: 106 | w, P, B, V = hippo.nplr(measure, N, R, **ssm_args) 107 | w = repeat(w, "n -> s n", s=H) 108 | P = repeat(P, "r n -> r s n", s=H) 109 | B = repeat(B, "n -> s n", s=H) 110 | V = repeat(V, "n m -> s n m", s=H) 111 | return w, P, B, V 112 | 113 | 114 | combinations = { 115 | "hippo": ["legs", "fourier"], 116 | "diag": ["diag-inv", "diag-lin"], 117 | "all": ["legs", "fourier", "diag-inv", "diag-lin"], 118 | } 119 | 120 | 121 | def combination(measures, N, R, S, **ssm_args): 122 | if isinstance(measures, str): 123 | measures = combinations[measures] if measures in combinations else [measures] 124 | 125 | assert ( 126 | S % len(measures) == 0 127 | ), f"{S} independent trainable SSM copies must be multiple of {len(measures)} different measures" 128 | 
w, P, B, V = zip( 129 | *[ssm(measure, N, R, S // len(measures), **ssm_args) for measure in measures] 130 | ) 131 | w = torch.cat(w, dim=0) # (S N) 132 | P = torch.cat(P, dim=1) # (R S N) 133 | B = torch.cat(B, dim=0) # (S N) 134 | V = torch.cat(V, dim=0) # (S N N) 135 | return w, P, B, V 136 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Code for paper "P$^3$SUM: Preserving Author’s Perspective in News Summarizationwith Diffusion Language Models" accepted at NAACL@2024! 2 | 3 | ## Acknowledgement 4 | The diffusion part heavily relies on [TESS](https://github.com/allenai/tess-diffusion) and [SSD-LM](https://github.com/xhan77/ssd-lm). We extend our sincere gratitude to the authors for generously sharing the code in advance and for their exceptional work. Please check their remarkable papers as well: 5 | 1. [TESS: Text-to-Text Self-Conditioned Simplex Diffusion](https://arxiv.org/abs/2305.08379) 6 | 2. [SSD-LM: Semi-autoregressive Simplex-based Diffusion Language Model for Text Generation and Modular Control](https://arxiv.org/abs/2210.17432) 7 | 8 | ## Installation 9 | 10 | ```sh 11 | conda create -n preserve python=3.8.5 12 | conda activate preserve 13 | conda env update --file env.yml 14 | ``` 15 | ## Run 16 | First fill the path and other hyperparameters in ```simplex-diffusion-main/commands_for_sum/run_sum.sh```: 17 | ```sh 18 | cd "path/to/simplex-diffusion-main" 19 | model_name="roberta-base" 20 | model_path="model/" #path to the finetuned diffusion model for summarization task 21 | learning_rate=3e-5 22 | max_steps=200000 #120000 23 | CUDA_VISIBLE_DEVICES=1 24 | datasetname="Sampled_Datasets/cnn_dm_500_raw" #path to the input datasets 25 | dataset_config_name="3.0.0" 26 | cache_dir="./" 27 | preprocessing_num_workers=16 28 | overwrite_cache=false 29 | per_device_train_batch_size=8 30 | per_device_eval_batch_size=16 31 | do_train=false #whether train the model 32 | do_eval=false 33 | do_predict=true #wether test the model(generate summaries) 34 | evaluation_strategy="no" 35 | eval_steps=1000 36 | report_to="tensorboard" 37 | overwrite_output_dir=false 38 | max_seq_length=512 39 | max_target_length=100 # recommend to be close to the avg. length of gold summary 40 | max_source_length=412 #num of tokens in news context 41 | val_max_target_length=100 42 | skip_special_tokens=true 43 | max_eval_samples=100 44 | max_predict_samples=500 45 | simplex_value=5 46 | num_diffusion_steps=1000 47 | lr_scheduler_type="linear" 48 | pad_to_max_length=true 49 | beta_schedule="squaredcos_improved_ddpm" 50 | weight_decay=0.0 51 | warmup_steps=2000 52 | max_steps=200000 53 | gradient_accumulation_steps=1 54 | logging_steps=50 55 | save_steps=20000 56 | conditional_generation="ul2" 57 | save_total_limit=1 58 | tokenized_data_path="raw/xsum/roberta" 59 | metric_for_best_model="rouge1" 60 | if_control=true 61 | ctr_model_name='POLITICS_model' #path to the off-the-shelf classifier 62 | #ctr_model_name=None 63 | ctr_opt_label_idx=0 #political leaning of the input news context 64 | output_dir="out/" #path to save generated summaries 65 | decode_ctr_lr=1000 66 | ``` 67 | ## Resources 68 | We have made the [Sampled_Datasets](https://drive.google.com/file/d/1qIYjVl9wI-BYYO9C1XRgOsgEyFm-Xd7o/view?usp=sharing) available online for your convenience. For the off-the-shelf classifier, we recommend reaching out to the authors of the [POLITICS paper](https://arxiv.org/abs/2205.00619). 
If you need the weights of the diffusion model for the summarization task, please feel free to email Yuhan Liu at lyh6560@stu.xjtu.edu.cn. 69 | 70 | ## Cite 71 | ``` 72 | @inproceedings{liu-etal-2024-p3sum, 73 | title = "{P}$^3${S}um: Preserving Author{'}s Perspective in News Summarization with Diffusion Language Models", 74 | author = "Liu, Yuhan and 75 | Feng, Shangbin and 76 | Han, Xiaochuang and 77 | Balachandran, Vidhisha and 78 | Park, Chan Young and 79 | Kumar, Sachin and 80 | Tsvetkov, Yulia", 81 | editor = "Duh, Kevin and 82 | Gomez, Helena and 83 | Bethard, Steven", 84 | booktitle = "Proceedings of the 2024 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies (Volume 1: Long Papers)", 85 | month = jun, 86 | year = "2024", 87 | address = "Mexico City, Mexico", 88 | publisher = "Association for Computational Linguistics", 89 | url = "https://aclanthology.org/2024.naacl-long.119", 90 | pages = "2154--2173", 91 | abstract = "In this work, we take a first step towards designing summarization systems that are faithful to the author{'}s intent, not only the semantic content of the article. Focusing on a case study of preserving political perspectives in news summarization, we find that existing approaches alter the political opinions and stances of news articles in more than 50{\%} of summaries, misrepresenting the intent and perspectives of the news authors. We thus propose P$^3$Sum, a diffusion model-based summarization approach controlled by political perspective classifiers. In P$^3$Sum, the political leaning of a generated summary is iteratively evaluated at each decoding step, and any drift from the article{'}s original stance incurs a loss back-propagated to the embedding layers, steering the political stance of the summary at inference time. Extensive experiments on three news summarization datasets demonstrate that P$^3$Sum outperforms state-of-the-art summarization systems and large language models by up to 13.7{\%} in terms of the success rate of stance preservation, with competitive performance on standard metrics of summarization quality.
Our findings present a first analysis of preservation of pragmatic features in summarization, highlight the lacunae in existing summarization models{---}that even state-of-the-art models often struggle to preserve author{'}s intents{---}and develop new summarization systems that are more faithful to author{'}s perspectives.", 92 | } 93 | 94 | ``` -------------------------------------------------------------------------------- /simplex-diffusion-main/sdlm/utils.py: -------------------------------------------------------------------------------- 1 | """Defines the utilities used during the training/infernece of diffusion language models.""" 2 | import torch.nn.functional as F 3 | import os 4 | import re 5 | import pdb 6 | from pathlib import Path 7 | from transformers.utils import logging 8 | import shutil 9 | import numpy as np 10 | from typing import Callable, Iterable, List 11 | import torch 12 | 13 | logger = logging.get_logger(__name__) 14 | 15 | 16 | def join_texts(prefixes, sentences): 17 | """Joins prefixes to setences.""" 18 | return [f"{prefix}{sentence}" for prefix, sentence in zip(prefixes, sentences)] 19 | 20 | 21 | def convert_to_simplex(token_ids, simplex_value, vocab_size): 22 | return 2 * simplex_value * F.one_hot(token_ids, vocab_size) - simplex_value 23 | 24 | 25 | def scale(inputs, scale_value): 26 | return inputs / scale_value 27 | 28 | 29 | def get_last_checkpoint(folder, prefix_checkpoint_dir="step"): 30 | re_checkpoint = re.compile(r"^" + prefix_checkpoint_dir + r"\_(\d+)$") 31 | content = os.listdir(folder) 32 | checkpoints = [ 33 | path for path in content if re_checkpoint.search(path) is not None and os.path.isdir(os.path.join(folder, path)) 34 | ] 35 | if len(checkpoints) == 0: 36 | return 37 | return os.path.join(folder, max(checkpoints, key=lambda x: int(re_checkpoint.search(x).groups()[0]))) 38 | 39 | 40 | def remove_checkpoints(output_dir, checkpoint_prefix="step"): 41 | checkpoints = [str(x) for x in Path(output_dir).glob(f"{checkpoint_prefix}_*") if os.path.isdir(x)] 42 | for checkpoint in checkpoints: 43 | logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit") 44 | shutil.rmtree(checkpoint) 45 | 46 | 47 | def get_norm_stats(model): 48 | # Gradient norm of word embeddings and lm_head. 49 | input_embed_grad_norm = 0 50 | if model.roberta.embeddings.word_embeddings.weight.grad is not None: 51 | input_embed_grad_norm = model.roberta.embeddings.word_embeddings.weight.grad.detach().data.norm(2).item() 52 | 53 | output_embed_grad_norm = 0.0 54 | if model.lm_head.decoder.weight.grad is not None: 55 | output_embed_grad_norm = model.lm_head.decoder.weight.grad.detach().data.norm(2).item() 56 | 57 | """ 58 | total_grad_norm = 0.0 59 | for p in model.parameters(): 60 | grad_norm = 0.0 61 | if p.grad is not None: 62 | grad_norm = p.grad.detach().data.norm(2).item() 63 | total_grad_norm += grad_norm ** 2 64 | total_grad_norm = total_grad_norm ** 0.5 65 | 66 | # Norms of word embeddings and lm_head. 
67 | input_embed_norm = model.roberta.embeddings.word_embeddings.weight.detach().data.norm(2).item() 68 | output_embed_norm = model.lm_head.decoder.weight.detach().data.norm(2).item() 69 | total_param_norm = 0.0 70 | for p in model.parameters(): 71 | param_norm = p.detach().data.norm(2) 72 | total_param_norm += param_norm.item() ** 2 73 | total_param_norm = total_param_norm ** 0.5 74 | """ 75 | return { 76 | "input_embed_grad_norm": input_embed_grad_norm, 77 | "output_embed_grad_norm": output_embed_grad_norm, 78 | # "total_grad_norm": total_grad_norm, 79 | # "input_embed_norm": input_embed_norm, 80 | # "output_embed_norm": output_embed_norm, 81 | # "total_param_norm": total_param_norm 82 | } 83 | 84 | 85 | def self_condition_preds(self_condition, logits, logits_projection=None): 86 | if self_condition in ["logits", "logits_addition", "logits_mean", "logits_max", "logits_multiply"]: 87 | previous_pred = logits.detach() 88 | elif self_condition in ["logits_with_projection", "logits_with_projection_addition"]: 89 | previous_pred = logits_projection(logits.detach()) 90 | else: 91 | assert NotImplementedError(f"{self_condition} is not implemented.") 92 | return previous_pred 93 | 94 | def mix_values_based_on_self_condition(self_condition_type, value_1, value_2): 95 | if self_condition_type in ["logits_with_projection_addition", "logits_addition"]: 96 | mixed_values = value_1 + value_2 97 | elif self_condition_type == "logits_mean": 98 | mixed_values = (value_1 + value_2) / 2.0 99 | elif self_condition_type == "logits_max": 100 | mixed_values = torch.max(value_1, value_2) 101 | elif self_condition_type == "logits_multiply": 102 | mixed_values = value_1 * value_2 103 | else: 104 | assert NotImplementedError 105 | return mixed_values 106 | 107 | def round_stsb_target(label): 108 | """STSB maps two sentences to a floating point number between 1 and 5 109 | representing their semantic similarity. Since we are treating all tasks as 110 | text-to-text tasks we need to convert this floating point number to a string. 111 | The vast majority of the similarity score labels in STSB are in the set 112 | [0, 0.2, 0.4, ..., 4.8, 5.0]. So, we first round the number to the closest 113 | entry in this set, and then we convert the result to a string (literally e.g. 114 | "3.4"). This converts STSB roughly into a 26-class classification dataset. 115 | Args: 116 | label: original label. 117 | Returns: 118 | A preprocessed label. 119 | """ 120 | return np.round((label * 5) / 5, decimals=1) 121 | 122 | 123 | def lmap(f: Callable, x: Iterable) -> List: 124 | """list(map(f, x))""" 125 | return list(map(f, x)) 126 | 127 | 128 | def pad_data(data_list, tokenizer): 129 | return tokenizer.pad({"input_ids": data_list}, padding=True)["input_ids"] 130 | -------------------------------------------------------------------------------- /simplex-diffusion-main/sdlm/metrics/metrics.py: -------------------------------------------------------------------------------- 1 | """Implements the metrics for evaluation of the diffusion models.""" 2 | from mauve import compute_mauve 3 | from nltk.util import ngrams 4 | import numpy as np 5 | from collections import Counter 6 | from scipy import stats 7 | import operator 8 | import math 9 | import scipy 10 | import sklearn 11 | import pdb 12 | 13 | MAX_TEXT_LENGTH = 256 14 | 15 | 16 | def mauve(predictions, references, featurize_model_name="gpt2-large", length=MAX_TEXT_LENGTH): 17 | """Computes MAUVE scores between two lists of generated text and reference text. 
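    MAUVE embeds both sets of texts with the `featurize_model_name` model, quantizes the
    embeddings, and summarizes the divergence frontier between the two distributions into a
    single score in (0, 1]; higher values mean the generations are distributionally closer
    to the references.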
18 | Args: 19 | predictions (list of str) of predictions. 20 | reference (list of str) of references. 21 | """ 22 | results = compute_mauve( 23 | p_text=references, # human-text. 24 | q_text=predictions, # machine-text. 25 | max_text_length=length, 26 | featurize_model_name=featurize_model_name, 27 | verbose=False, 28 | # These are the tricks to make `mauve` run faster if #examples > 5K. 29 | # See https://github.com/krishnap25/mauve#best-practices-for-mauve 30 | # num_buckets=500 if len(predictions) > 5000 else "auto", 31 | # kmeans_num_redo=1, 32 | ) 33 | return {"muave": results.mauve} 34 | 35 | 36 | def distinct_n_grams(texts): 37 | """Computes the average distinct n-grams of the generated texts. 38 | Args: 39 | texts (list of str): representing the generated texts. 40 | """ 41 | dist_1, dist_2, dist_3, dist_4 = [], [], [], [] 42 | for text in texts: 43 | total_words = len(text.split()) 44 | unigrams = set(ngrams(text.split(), 1)) 45 | bigrams = set(ngrams(text.split(), 2)) 46 | trigrams = set(ngrams(text.split(), 3)) 47 | fourgrams = set(ngrams(text.split(), 4)) 48 | if total_words == 0: 49 | dist_1.append(0) 50 | dist_2.append(0) 51 | dist_3.append(0) 52 | dist_4.append(0) 53 | else: 54 | dist_1.append(len(unigrams) / total_words) 55 | dist_2.append(len(bigrams) / total_words) 56 | dist_3.append(len(trigrams) / total_words) 57 | dist_4.append(len(fourgrams) / total_words) 58 | return {"dist-1": np.nanmean(dist_1), "dist-2": np.nanmean(dist_2), "dist-3": np.nanmean(dist_3), "dist-4": np.nanmean(dist_4)} 59 | 60 | 61 | def zipf(tokenized_texts, N=5000): 62 | """Computes the Zipf coefficient. 63 | 64 | Args: 65 | tokenized_texts (List[List[int]]) tokenized texts. 66 | Adapted from https://github.com/ari-holtzman/degen/blob/master/metrics/zipf.py 67 | """ 68 | cnt = Counter() 69 | for tokenized_text in tokenized_texts: 70 | cnt.update(tokenized_text) 71 | 72 | xs = np.arange(1, min(len(cnt), N) + 1) 73 | ys = np.array(sorted(cnt.values(), key=operator.neg)[:N]) 74 | a, b, r, p, std = stats.linregress(np.log(xs), np.log(ys)) 75 | # Note that zipf_minus_a is the reported number. 76 | return {"zipf_minus_a": -a, "zipf_minus_r": -r, "zipf_p": p} 77 | 78 | 79 | def accuracy(predictions, targets) -> dict: 80 | """Computes the average accuracy.""" 81 | return {"accuracy": 100 * ((np.array(predictions) == np.array(targets)).mean())} 82 | 83 | 84 | def pearson_corrcoef(predictions, targets) -> dict: 85 | """Computes Pearson correlation coefficient.""" 86 | pearson_corrcoef = 100 * scipy.stats.pearsonr(targets, predictions)[0] 87 | 88 | # Note that if all the predictions will be the same, spearman 89 | # correlation is nan, to gaurad against this, we check the output 90 | # and return 0 in this case. 91 | if math.isnan(pearson_corrcoef): 92 | pearson_corrcoef = 0 93 | return {"pearson": pearson_corrcoef} 94 | 95 | 96 | def spearman_corrcoef(predictions, targets) -> dict: 97 | """Computes Spearman correlation coefficient.""" 98 | spearman_corrcoef = 100 * scipy.stats.spearmanr(targets, predictions)[0] 99 | 100 | # Note that if all the predictions will be the same, spearman 101 | # correlation is nan, to gaurad against this, we check the output 102 | # and return 0 in this case. 103 | if math.isnan(spearman_corrcoef): 104 | spearman_corrcoef = 0 105 | return {"spearmanr": spearman_corrcoef} 106 | 107 | 108 | def f1_score_with_invalid(predictions, targets) -> dict: 109 | """Computes F1 score, with any prediction != 0 or 1 is counted as incorrect. 
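    This is useful in text-to-text setups where a decoded prediction may not parse into a
    valid binary label.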
110 | Args: 111 | targets: list of targets, either 0 or 1 112 | predictions: list of predictions, any integer value 113 | Returns: 114 | F1 score, where any prediction != 0 or 1 is counted as wrong. 115 | """ 116 | targets, predictions = np.asarray(targets), np.asarray(predictions) 117 | # Get indices of invalid predictions. 118 | invalid_idx_mask = np.logical_and(predictions != 0, predictions != 1) 119 | # For any prediction != 0 or 1, we set the prediction to the opposite of its corresponding target. 120 | predictions[invalid_idx_mask] = 1 - targets[invalid_idx_mask] 121 | return {"f1": 100 * sklearn.metrics.f1_score(targets, predictions)} 122 | 123 | 124 | # TODO: maybe gaurd against invalid values https://stackoverflow.com/questions/56865344/how-do-i-calculate-the-matthews-correlation-coefficient-in-tensorflow 125 | def matthews_corrcoef(predictions, targets) -> dict: 126 | """Computes the Matthews correlation coefficient.""" 127 | return {"matthews_correlation": 100 * sklearn.metrics.matthews_corrcoef(targets, predictions)} 128 | 129 | 130 | def get_glue_metrics(task): 131 | GLUE_TASKS_TO_METRICS = { 132 | "mrpc": [f1_score_with_invalid, accuracy], 133 | "cola": [matthews_corrcoef], 134 | "sst2": [accuracy], 135 | "stsb": [pearson_corrcoef, spearman_corrcoef], 136 | "qqp": [f1_score_with_invalid, accuracy], 137 | "mnli": [accuracy], 138 | "qnli": [accuracy], 139 | "rte": [accuracy], 140 | "wnli": [accuracy], 141 | } 142 | return GLUE_TASKS_TO_METRICS[task] 143 | -------------------------------------------------------------------------------- /simplex-diffusion-main/sdlm/models/h3/ops/vandermonde.py: -------------------------------------------------------------------------------- 1 | # TD [2023-01-05]: Copied from https://github.com/HazyResearch/state-spaces/blob/06dbbdfd0876501a7f12bf3262121badbc7658af/src/models/functional/vandermonde.py 2 | # We add the interface to the log vandermonde CUDA code 3 | 4 | """pykeops implementations of the Vandermonde matrix multiplication kernel used in the S4D kernel.""" 5 | import math 6 | import os 7 | 8 | import torch 9 | from einops import rearrange, repeat 10 | from opt_einsum import contract 11 | 12 | try: 13 | import pykeops 14 | from pykeops.torch import Genred, LazyTensor 15 | except: 16 | pass 17 | 18 | try: 19 | from cauchy_mult import vand_log_mult_sym_bwd, vand_log_mult_sym_fwd 20 | except: 21 | vand_log_mult_sym_fwd, vand_log_mult_sym_bwd = None, None 22 | 23 | 24 | _conj = lambda x: torch.cat([x, x.conj()], dim=-1) 25 | 26 | 27 | def _broadcast_dims(*tensors): 28 | max_dim = max([len(tensor.shape) for tensor in tensors]) 29 | tensors = [ 30 | tensor.view((1,) * (max_dim - len(tensor.shape)) + tensor.shape) 31 | for tensor in tensors 32 | ] 33 | return tensors 34 | 35 | 36 | def _c2r(x): 37 | return torch.view_as_real(x) 38 | 39 | 40 | def _r2c(x): 41 | return torch.view_as_complex(x) 42 | 43 | 44 | def vandermonde_naive(v, x, L, conj=True): 45 | """ 46 | v: (..., N) 47 | x: (..., N) 48 | returns: (..., L) \sum v x^l 49 | """ 50 | if conj: 51 | x = _conj(x) 52 | v = _conj(v) 53 | vandermonde_matrix = x.unsqueeze(-1) ** torch.arange(L).to(x) # (... N L) 54 | vandermonde_prod = torch.sum( 55 | v.unsqueeze(-1) * vandermonde_matrix, dim=-2 56 | ) # (... L) 57 | return vandermonde_prod 58 | 59 | 60 | def log_vandermonde_naive(v, x, L, conj=True): 61 | """ 62 | v: (..., N) 63 | x: (..., N) 64 | returns: (..., L) \sum v x^l 65 | """ 66 | vandermonde_matrix = torch.exp(x.unsqueeze(-1) * torch.arange(L).to(x)) # (... 
N L) 67 | vandermonde_prod = contract( 68 | "... n, ... n l -> ... l", v, vandermonde_matrix 69 | ) # (... L) 70 | if conj: 71 | return 2 * vandermonde_prod.real 72 | else: 73 | return vandermonde_prod 74 | 75 | 76 | def log_vandermonde_lazy(v, x, L, conj=True): 77 | if conj: 78 | v = _conj(v) 79 | x = _conj(x) 80 | l = torch.arange(L).to(x) 81 | v, x, l = _broadcast_dims(v, x, l) 82 | v_l = LazyTensor(rearrange(v, "... N -> ... N 1 1")) 83 | x_l = LazyTensor(rearrange(x, "... N -> ... N 1 1")) 84 | l_l = LazyTensor(rearrange(l, "... L -> ... 1 L 1")) 85 | # exp 86 | vand = (x_l * l_l).exp() 87 | s = (v_l * vand).sum(dim=len(v_l.shape) - 2) 88 | return s.squeeze(-1) 89 | 90 | 91 | def log_vandermonde(v, x, L, conj=True): 92 | expr = "ComplexMult(v, ComplexExp(ComplexMult(x, l)))" 93 | vandermonde_mult = Genred( 94 | expr, 95 | [ 96 | "v = Vj(2)", 97 | "x = Vj(2)", 98 | "l = Vi(2)", 99 | ], 100 | reduction_op="Sum", 101 | axis=1, 102 | ) 103 | 104 | l = torch.arange(L).to(x) 105 | v, x, l = _broadcast_dims(v, x, l) 106 | v = _c2r(v) 107 | x = _c2r(x) 108 | l = _c2r(l) 109 | 110 | r = vandermonde_mult(v, x, l, backend="GPU") 111 | if conj: 112 | return 2 * _r2c(r).real 113 | else: 114 | return _r2c(r) 115 | 116 | 117 | def log_vandermonde_transpose_naive(u, v, x, L): 118 | vandermonde_matrix = torch.exp(x.unsqueeze(-1) * torch.arange(L).to(x)) # (... N L) 119 | vandermonde_prod = contract( 120 | "... l, ... n, ... n l -> ... n", u.to(x), v.to(x), vandermonde_matrix 121 | ) # (... L) 122 | return vandermonde_prod 123 | 124 | 125 | def log_vandermonde_transpose(u, v, x, L): 126 | """ 127 | u: ... H L 128 | v: ... H N 129 | x: ... H N 130 | Returns: ... H N 131 | 132 | V = Vandermonde(a, L) : (H N L) 133 | contract_L(V * u * v) 134 | """ 135 | expr = "ComplexMult(ComplexMult(v, u), ComplexExp(ComplexMult(x, l)))" 136 | vandermonde_mult = Genred( 137 | expr, 138 | [ 139 | "u = Vj(2)", 140 | "v = Vi(2)", 141 | "x = Vi(2)", 142 | "l = Vj(2)", 143 | ], 144 | reduction_op="Sum", 145 | axis=1, 146 | ) 147 | 148 | l = torch.arange(L).to(x) 149 | u, v, x, l = _broadcast_dims(u, v, x, l) 150 | u = _c2r(u) 151 | v = _c2r(v) 152 | x = _c2r(x) 153 | l = _c2r(l) 154 | 155 | r = vandermonde_mult(u, v, x, l, backend="GPU") 156 | return _r2c(r) 157 | 158 | 159 | def _log_vandermonde_matmul(x, L): 160 | vandermonde_matrix = torch.exp(x.unsqueeze(-1) * torch.arange(L).to(x)) # (... 
N L) 161 | return vandermonde_matrix 162 | 163 | 164 | def log_vandermonde_matmul(v, K): 165 | prod = contract("...n, ...nl -> ...l", v, K) 166 | return 2 * prod.real 167 | 168 | 169 | class LogVandMultiplySymmetric(torch.autograd.Function): 170 | @staticmethod 171 | def forward(ctx, v, x, L): 172 | batch, N = v.shape 173 | supported_N_values = [1 << log_n for log_n in [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]] 174 | if not N in supported_N_values: 175 | raise NotImplementedError(f"Only support N values in {supported_N_values}") 176 | max_L_value = 32 * 1024 * 64 * 1024 177 | if L > max_L_value: 178 | raise NotImplementedError(f"Only support L values <= {max_L_value}") 179 | if not v.is_cuda and x.is_cuda: 180 | raise NotImplementedError(f"Only support CUDA tensors") 181 | ctx.save_for_backward(v, x) 182 | return vand_log_mult_sym_fwd(v, x, L) 183 | 184 | @staticmethod 185 | def backward(ctx, dout): 186 | v, x = ctx.saved_tensors 187 | dv, dx = vand_log_mult_sym_bwd(v, x, dout) 188 | return dv, dx, None 189 | 190 | 191 | if vand_log_mult_sym_fwd and vand_log_mult_sym_bwd is not None: 192 | log_vandermonde_fast = LogVandMultiplySymmetric.apply 193 | else: 194 | log_vandermonde_fast = None 195 | -------------------------------------------------------------------------------- /simplex-diffusion-main/sdlm/models/h3/ops/toeplitz.py: -------------------------------------------------------------------------------- 1 | # Downloaded from https://github.com/HazyResearch/state-spaces/blob/06dbbdfd0876501a7f12bf3262121badbc7658af/src/models/functional/toeplitz.py 2 | """ Utilities for computing convolutions. 3 | There are 3 equivalent views: 4 | 1. causal convolution 5 | 2. multiplication of (lower) triangular Toeplitz matrices 6 | 3. polynomial multiplication (mod x^N) 7 | """ 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | 14 | def construct_toeplitz(v, f=0.0): 15 | """Explicit construction of Krylov matrix [v A @ v A^2 @ v ... A^{n-1} @ v] 16 | where A = Z_f. This uses vectorized indexing and cumprod so it's much 17 | faster than using the Krylov function. 18 | Parameters: 19 | v: the starting vector of size n or (rank, n). 20 | f: real number 21 | Returns: 22 | K: Krylov matrix of size (n, n) or (rank, n, n). 
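    Example (with the default f=0.0): v = [1., 2., 3.] produces the lower-triangular Toeplitz
    matrix [[1, 0, 0], [2, 1, 0], [3, 2, 1]], i.e. each column is v shifted down by one.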
23 | """ 24 | n = v.shape[-1] 25 | a = torch.arange(n, device=v.device) 26 | b = -a 27 | indices = a[:, None] + b[None] 28 | K = v[..., indices] 29 | K[..., indices < 0] *= f 30 | return K 31 | 32 | 33 | def triangular_toeplitz_multiply_(u, v, sum=None): 34 | n = u.shape[-1] 35 | u_expand = F.pad(u, (0, n)) 36 | v_expand = F.pad(v, (0, n)) 37 | u_f = torch.fft.rfft(u_expand, n=2 * n, dim=-1) 38 | v_f = torch.fft.rfft(v_expand, n=2 * n, dim=-1) 39 | uv_f = u_f * v_f 40 | if sum is not None: 41 | uv_f = uv_f.sum(dim=sum) 42 | output = torch.fft.irfft(uv_f, n=2 * n, dim=-1)[..., :n] 43 | return output 44 | 45 | 46 | def triangular_toeplitz_multiply_padded_(u, v): 47 | """Same as triangular_toeplitz_multiply but inputs and output assume to be 0-padded already.""" 48 | n = u.shape[-1] 49 | assert n % 2 == 0 50 | u_f = torch.fft.rfft(u, n=n, dim=-1) 51 | v_f = torch.fft.rfft(v, n=n, dim=-1) 52 | uv_f = u_f * v_f 53 | output = torch.fft.irfft(uv_f, n=n, dim=-1) 54 | output[..., n:] = 0 55 | return output 56 | 57 | 58 | class TriangularToeplitzMult(torch.autograd.Function): 59 | @staticmethod 60 | def forward(ctx, u, v): 61 | ctx.save_for_backward(u, v) 62 | return triangular_toeplitz_multiply_(u, v) 63 | 64 | @staticmethod 65 | def backward(ctx, grad): 66 | u, v = ctx.saved_tensors 67 | d_u = triangular_toeplitz_multiply_(grad.flip(-1), v).flip(-1) 68 | d_v = triangular_toeplitz_multiply_(grad.flip(-1), u).flip(-1) 69 | return d_u, d_v 70 | 71 | 72 | class TriangularToeplitzMultFast(torch.autograd.Function): 73 | @staticmethod 74 | def forward(ctx, u, v): 75 | n = u.shape[-1] 76 | u_expand = F.pad(u, (0, n)) 77 | v_expand = F.pad(v, (0, n)) 78 | u_f = torch.fft.rfft(u_expand, n=2 * n, dim=-1) 79 | v_f = torch.fft.rfft(v_expand, n=2 * n, dim=-1) 80 | 81 | ctx.save_for_backward(u_f, v_f) 82 | 83 | uv_f = u_f * v_f 84 | output = torch.fft.irfft(uv_f, n=2 * n, dim=-1)[..., :n] 85 | return output 86 | 87 | @staticmethod 88 | def backward(ctx, grad): 89 | u_f, v_f = ctx.saved_tensors 90 | n = grad.shape[-1] 91 | g_expand = F.pad(grad.flip(-1), (0, n)) 92 | g_f = torch.fft.rfft(g_expand, n=2 * n, dim=-1) 93 | gu_f = g_f * u_f 94 | gv_f = g_f * v_f 95 | d_u = torch.fft.irfft(gv_f, n=2 * n, dim=-1)[..., :n] 96 | d_v = torch.fft.irfft(gu_f, n=2 * n, dim=-1)[..., :n] 97 | d_u = d_u.flip(-1) 98 | d_v = d_v.flip(-1) 99 | return d_u, d_v 100 | 101 | 102 | class TriangularToeplitzMultPadded(torch.autograd.Function): 103 | @staticmethod 104 | def forward(ctx, u, v): 105 | ctx.save_for_backward(u, v) 106 | output = triangular_toeplitz_multiply_(u, v) 107 | return output 108 | 109 | @staticmethod 110 | def backward(ctx, grad): 111 | u, v = ctx.saved_tensors 112 | d_u = triangular_toeplitz_multiply_padded_(grad.flip(-1), v).flip(-1) 113 | d_v = triangular_toeplitz_multiply_padded_(grad.flip(-1), u).flip(-1) 114 | return d_u, d_v 115 | 116 | 117 | class TriangularToeplitzMultPaddedFast(torch.autograd.Function): 118 | """Trade off speed (20-25% faster) for more memory (20-25%)""" 119 | 120 | @staticmethod 121 | def forward(ctx, u, v): 122 | n = u.shape[-1] 123 | u_f = torch.fft.rfft(u, n=n, dim=-1) 124 | v_f = torch.fft.rfft(v, n=n, dim=-1) 125 | 126 | ctx.save_for_backward(u_f, v_f) 127 | 128 | uv_f = u_f * v_f 129 | output = torch.fft.irfft(uv_f, n=n, dim=-1) 130 | output[..., n // 2 :].zero_() 131 | return output 132 | 133 | @staticmethod 134 | def backward(ctx, grad): 135 | u_f, v_f = ctx.saved_tensors 136 | n = grad.shape[-1] 137 | g_expand = F.pad(grad[..., : n // 2].flip(-1), (0, n // 2)) 138 | g_f = 
torch.fft.rfft(g_expand, n=n, dim=-1) 139 | gu_f = g_f * u_f 140 | gv_f = g_f * v_f 141 | d_u = torch.fft.irfft(gv_f, n=n, dim=-1) 142 | d_v = torch.fft.irfft(gu_f, n=n, dim=-1) 143 | d_u[..., n // 2 :].zero_() 144 | d_v[..., n // 2 :].zero_() 145 | d_u[..., : n // 2] = d_u[..., : n // 2].flip(-1) # TODO 146 | d_v[..., : n // 2] = d_v[..., : n // 2].flip(-1) # TODO 147 | return d_u, d_v 148 | 149 | 150 | # triangular_toeplitz_multiply = triangular_toeplitz_multiply_ 151 | triangular_toeplitz_multiply = TriangularToeplitzMult.apply 152 | triangular_toeplitz_multiply_fast = TriangularToeplitzMultFast.apply 153 | triangular_toeplitz_multiply_padded = TriangularToeplitzMultPadded.apply 154 | triangular_toeplitz_multiply_padded_fast = TriangularToeplitzMultPaddedFast.apply 155 | 156 | 157 | def causal_convolution(u, v, fast=True, pad=False): 158 | if not pad and not fast: 159 | return triangular_toeplitz_multiply(u, v) 160 | if not pad and fast: 161 | return triangular_toeplitz_multiply_fast(u, v) 162 | if pad and not fast: 163 | return triangular_toeplitz_multiply_padded(u, v) 164 | if pad and fast: 165 | return triangular_toeplitz_multiply_padded_fast(u, v) 166 | -------------------------------------------------------------------------------- /simplex-diffusion-main/sdlm/models/h3/ops/fftconv.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from einops import rearrange 6 | from fftconv import fftconv_bwd, fftconv_fwd 7 | 8 | 9 | @torch.jit.script 10 | def _mul_sum(y, q): 11 | return (y * q).sum(dim=1) 12 | 13 | 14 | # reference convolution with residual connection 15 | def fftconv_ref(u, k, D, dropout_mask, gelu=True, k_rev=None): 16 | seqlen = u.shape[-1] 17 | fft_size = 2 * seqlen 18 | k_f = torch.fft.rfft(k, n=fft_size) / fft_size 19 | if k_rev is not None: 20 | k_rev_f = torch.fft.rfft(k_rev, n=fft_size) / fft_size 21 | k_f = k_f + k_rev_f.conj() 22 | u_f = torch.fft.rfft(u.to(dtype=k.dtype), n=fft_size) 23 | y = torch.fft.irfft(u_f * k_f, n=fft_size, norm="forward")[..., :seqlen] 24 | out = y + u * D.unsqueeze(-1) 25 | if gelu: 26 | out = F.gelu(out) 27 | if dropout_mask is not None: 28 | return (out * rearrange(dropout_mask, "b H -> b H 1")).to(dtype=u.dtype) 29 | else: 30 | return out.to(dtype=u.dtype) 31 | 32 | 33 | # reference H3 forward pass 34 | def fftconv_h3_ref(k, ssm_kernel, D, q, v, head_dim=1, ssm_kernel_rev=None): 35 | seqlen = k.shape[-1] 36 | fft_size = 2 * seqlen 37 | kv = rearrange(k, "b (h d1) l -> b d1 1 h l", d1=head_dim) * rearrange( 38 | v, "b (h d2) l -> b 1 d2 h l", d2=head_dim 39 | ) # b d1 d2 h l 40 | kv_f = torch.fft.rfft(kv.to(dtype=ssm_kernel.dtype), n=fft_size) / fft_size 41 | ssm_kernel_f = torch.fft.rfft(ssm_kernel, n=fft_size) # h L+1 42 | if ssm_kernel_rev is not None: 43 | ssm_kernel_rev_f = torch.fft.rfft(ssm_kernel_rev, n=fft_size) # h L+1 44 | ssm_kernel_f = ssm_kernel_f + ssm_kernel_rev_f.conj() 45 | y = torch.fft.irfft(kv_f * ssm_kernel_f, n=fft_size, norm="forward")[ 46 | ..., :seqlen 47 | ] # b d1 d2 h l 48 | out = y + kv * D.unsqueeze(-1) # b d1 d2 h l 49 | q = rearrange(q, "b (h d1) l -> b d1 1 h l", d1=head_dim) 50 | if head_dim > 1: 51 | out = _mul_sum(out, q) 52 | return rearrange(out, "b d2 h l -> b (h d2) l").to(dtype=k.dtype) 53 | else: 54 | return rearrange(out * q, "b 1 1 h l -> b h l").to(dtype=k.dtype) 55 | 56 | 57 | class FFTConvFunc(torch.autograd.Function): 58 | @staticmethod 59 | def forward( 60 | ctx, 61 | u, 62 | k, 
63 | D, 64 | dropout_mask=None, 65 | gelu=True, 66 | force_fp16_output=False, 67 | output_hbl_layout=False, 68 | v=None, 69 | head_dim=1, 70 | q=None, 71 | fftfp16=False, 72 | k_rev=None, 73 | ): 74 | seqlen = u.shape[-1] 75 | fft_size = max(2 * 2 ** int(math.ceil(math.log2(seqlen))), 16) 76 | k_f = torch.fft.rfft(k, n=fft_size) 77 | if k_rev is not None: 78 | k_f = k_f + torch.fft.rfft(k_rev, n=fft_size).conj() 79 | if u.stride(-1) != 1: 80 | u = u.contiguous() 81 | k_f = k_f.contiguous() 82 | D = D.contiguous() 83 | if v is not None and v.stride(-1) != 1: 84 | v = v.contiguous() 85 | if q is not None and q.stride(-1) != 1: 86 | q = q.contiguous() 87 | if dropout_mask is not None: 88 | dropout_mask = dropout_mask.contiguous() 89 | ctx.save_for_backward(u, k_f, D, dropout_mask, v, q) 90 | ctx.output_hbl_layout = output_hbl_layout 91 | ctx.head_dim = head_dim 92 | ctx.gelu = gelu 93 | ctx.fftfp16 = fftfp16 94 | ctx.has_k_rev = k_rev is not None 95 | out = fftconv_fwd( 96 | u, 97 | k_f, 98 | D, 99 | v, 100 | head_dim, 101 | q, 102 | dropout_mask, 103 | gelu, 104 | False, 105 | False, 106 | fft_size, 107 | force_fp16_output, 108 | output_hbl_layout, 109 | fftfp16, 110 | ) 111 | return out 112 | 113 | @staticmethod 114 | def backward(ctx, dout): 115 | if ctx.output_hbl_layout: 116 | dout = rearrange( 117 | rearrange(dout, "b h l -> h b l").contiguous(), "h b l -> b h l" 118 | ) 119 | else: 120 | dout = dout.contiguous() 121 | u, k_f, D, dropout_mask, v, q = ctx.saved_tensors 122 | seqlen = u.shape[-1] 123 | fft_size = max(2 * 2 ** int(math.ceil(math.log2(seqlen))), 16) 124 | du, dk_f, dD, dv, dq = fftconv_bwd( 125 | dout, 126 | u, 127 | k_f, 128 | D, 129 | v, 130 | ctx.head_dim, 131 | q, 132 | dropout_mask, 133 | ctx.gelu, 134 | False, 135 | False, 136 | fft_size, 137 | ctx.output_hbl_layout, 138 | ctx.fftfp16, 139 | ) 140 | dk = torch.fft.irfft(dk_f, n=fft_size, norm="forward")[..., :seqlen] 141 | dk_rev = ( 142 | None 143 | if not ctx.has_k_rev 144 | else torch.fft.irfft(dk_f.conj(), n=fft_size, norm="forward")[..., :seqlen] 145 | ) 146 | if v is not None: 147 | dv = dv.to( 148 | dtype=v.dtype 149 | ) # We do atomicAdd in fp32 so might need to convert to fp16 150 | return ( 151 | du, 152 | dk, 153 | dD, 154 | None, 155 | None, 156 | None, 157 | None, 158 | dv if v is not None else None, 159 | None, 160 | dq if q is not None else None, 161 | None, 162 | dk_rev, 163 | ) 164 | 165 | 166 | def fftconv_func( 167 | u, 168 | k, 169 | D, 170 | dropout_mask=None, 171 | gelu=True, 172 | force_fp16_output=False, 173 | output_hbl_layout=False, 174 | v=None, 175 | head_dim=1, 176 | q=None, 177 | fftfp16=False, 178 | k_rev=None, 179 | ): 180 | return FFTConvFunc.apply( 181 | u, 182 | k, 183 | D, 184 | dropout_mask, 185 | gelu, 186 | force_fp16_output, 187 | output_hbl_layout, 188 | v, 189 | head_dim, 190 | q, 191 | fftfp16, 192 | k_rev, 193 | ) 194 | -------------------------------------------------------------------------------- /simplex-diffusion-main/sdlm/models/h3/ops/krylov.py: -------------------------------------------------------------------------------- 1 | # Downloaded from https://github.com/HazyResearch/state-spaces/blob/06dbbdfd0876501a7f12bf3262121badbc7658af/src/models/functional/krylov.py 2 | """ Compute a Krylov function efficiently. 
(S4 renames the Krylov function to a "state space kernel") 3 | A : (N, N) 4 | b : (N,) 5 | c : (N,) 6 | Return: [c^T A^i b for i in [L]] 7 | """ 8 | 9 | import torch 10 | import torch.nn.functional as F 11 | from einops import rearrange, repeat 12 | 13 | from .toeplitz import causal_convolution 14 | 15 | 16 | def krylov_sequential(L, A, b, c=None): 17 | """Constant matrix A 18 | A : (..., N, N) 19 | b : (..., N) 20 | c : (..., N) 21 | Returns 22 | if c: 23 | x : (..., L) 24 | x[i, l] = c[i] @ A^l @ b[i] 25 | else: 26 | x : (..., N, L) 27 | x[i, l] = A^l @ b[i] 28 | """ 29 | 30 | # Check which of dim b and c is smaller to save memory 31 | if c is not None and c.numel() < b.numel(): 32 | return krylov_sequential(L, A.transpose(-1, -2), c, b) 33 | 34 | b_ = b 35 | x = [] 36 | for _ in range(L): 37 | if c is not None: 38 | x_ = torch.sum( 39 | c * b_, dim=-1 40 | ) # (...) # could be faster with matmul or einsum? 41 | else: 42 | x_ = b_ 43 | x.append(x_) 44 | b_ = (A @ b_.unsqueeze(-1)).squeeze(-1) 45 | 46 | x = torch.stack(x, dim=-1) 47 | return x 48 | 49 | 50 | def krylov(L, A, b, c=None, return_power=False): 51 | """ 52 | Compute the Krylov matrix (b, Ab, A^2b, ...) using the squaring trick. 53 | If return_power=True, return A^{L-1} as well 54 | """ 55 | # TODO There is an edge case if L=1 where output doesn't get broadcasted, which might be an issue if caller is expecting broadcasting semantics... can deal with it if it arises 56 | 57 | x = b.unsqueeze(-1) # (..., N, 1) 58 | A_ = A 59 | 60 | AL = None 61 | if return_power: 62 | AL = torch.eye(A.shape[-1], dtype=A.dtype, device=A.device) 63 | _L = L - 1 64 | 65 | done = L == 1 66 | # loop invariant: _L represents how many indices left to compute 67 | while not done: 68 | if return_power: 69 | if _L % 2 == 1: 70 | AL = A_ @ AL 71 | _L //= 2 72 | 73 | # Save memory on last iteration 74 | l = x.shape[-1] 75 | if L - l <= l: 76 | done = True 77 | _x = x[..., : L - l] 78 | else: 79 | _x = x 80 | 81 | _x = A_ @ _x 82 | x = torch.cat( 83 | [x, _x], dim=-1 84 | ) # there might be a more efficient way of ordering axes 85 | if not done: 86 | A_ = A_ @ A_ 87 | 88 | assert x.shape[-1] == L 89 | 90 | if c is not None: 91 | x = torch.einsum("...nl, ...n -> ...l", x, c) 92 | x = x.contiguous() # WOW!! 
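    # At this point x is either the Krylov basis (..., N, L) or, if c was given, the length-L
    # sequence c^T A^l b for l = 0, ..., L-1 (i.e. the SSM convolution kernel).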
93 | if return_power: 94 | return x, AL 95 | else: 96 | return x 97 | 98 | 99 | @torch.no_grad() 100 | def power(L, A, v=None): 101 | """Compute A^L and the scan sum_i A^i v_i 102 | A: (..., N, N) 103 | v: (..., N, L) 104 | """ 105 | 106 | I = torch.eye(A.shape[-1]).to(A) # , dtype=A.dtype, device=A.device) 107 | 108 | powers = [A] 109 | l = 1 110 | while True: 111 | if L % 2 == 1: 112 | I = powers[-1] @ I 113 | L //= 2 114 | if L == 0: 115 | break 116 | l *= 2 117 | if v is None: 118 | powers = [powers[-1] @ powers[-1]] 119 | else: 120 | powers.append(powers[-1] @ powers[-1]) 121 | 122 | if v is None: 123 | return I 124 | 125 | # Invariants: 126 | # powers[-1] := A^l 127 | # l := largest po2 at most L 128 | 129 | # Note that an alternative divide and conquer to compute the reduction is possible and can be embedded into the above loop without caching intermediate powers of A 130 | # We do this reverse divide-and-conquer for efficiency reasons: 131 | # 1) it involves fewer padding steps for non-po2 L 132 | # 2) it involves more contiguous arrays 133 | 134 | # Take care of edge case for non-po2 arrays 135 | # Note that this initial step is a no-op for the case of power of 2 (l == L) 136 | k = v.size(-1) - l 137 | v_ = powers.pop() @ v[..., l:] 138 | v = v[..., :l] 139 | v[..., :k] = v[..., :k] + v_ 140 | 141 | # Handle reduction for power of 2 142 | while v.size(-1) > 1: 143 | v = rearrange(v, "... (z l) -> ... z l", z=2) 144 | v = v[..., 0, :] + powers.pop() @ v[..., 1, :] 145 | return I, v.squeeze(-1) 146 | 147 | 148 | def krylov_toeplitz(L, A, b, c=None): 149 | """Specializes to lower triangular Toeplitz matrix A represented by its diagonals 150 | A : (..., N) 151 | b : (..., N) 152 | c : (..., N) 153 | Returns 154 | x : (..., N, L) 155 | x[i, l] = A^l @ b[i] 156 | """ 157 | x = b.unsqueeze(0) # (1, ..., N) 158 | A_ = A 159 | while x.shape[0] < L: 160 | xx = causal_convolution(A_, x) 161 | x = torch.cat( 162 | [x, xx], dim=0 163 | ) # there might be a more efficient way of ordering axes 164 | A_ = causal_convolution(A_, A_) 165 | x = x[:L, ...] # (L, ..., N) 166 | if c is not None: 167 | x = torch.einsum("l...n, ...n -> ...l", x, c) 168 | else: 169 | x = rearrange(x, "l ... n -> ... n l") 170 | x = x.contiguous() 171 | return x 172 | 173 | 174 | def krylov_toeplitz_(L, A, b, c=None): 175 | """Padded version of krylov_toeplitz that saves some fft's 176 | TODO currently not faster than original version, not sure why 177 | """ 178 | N = A.shape[-1] 179 | 180 | x = b.unsqueeze(0) # (1, ..., N) 181 | x = F.pad(x, (0, N)) 182 | A = F.pad(A, (0, N)) 183 | done = L == 1 184 | while not done: 185 | l = x.shape[0] 186 | # Save memory on last iteration 187 | if L - l <= l: 188 | done = True 189 | _x = x[: L - l] 190 | else: 191 | _x = x 192 | Af = torch.fft.rfft(A, n=2 * N, dim=-1) 193 | xf = torch.fft.rfft(_x, n=2 * N, dim=-1) 194 | xf_ = Af * xf 195 | x_ = torch.fft.irfft(xf_, n=2 * N, dim=-1) 196 | x_[..., N:] = 0 197 | x = torch.cat( 198 | [x, x_], dim=0 199 | ) # there might be a more efficient way of ordering axes 200 | if not done: 201 | A = torch.fft.irfft(Af * Af, n=2 * N, dim=-1) 202 | A[..., N:] = 0 203 | x = x[:L, ..., :N] # (L, ..., N) 204 | if c is not None: 205 | x = torch.einsum("l...n, ...n -> ...l", x, c) 206 | else: 207 | x = rearrange(x, "l ... n -> ... 
n l") 208 | x = x.contiguous() 209 | return x 210 | -------------------------------------------------------------------------------- /simplex-diffusion-main/sdlm/metrics/perplexity.py: -------------------------------------------------------------------------------- 1 | """Perplexity Metric. This file is adapted from: https://huggingface.co/spaces/evaluate-measurement/perplexity/blob/main/perplexity.py""" 2 | 3 | import numpy as np 4 | import torch 5 | from torch.nn import CrossEntropyLoss 6 | from evaluate import logging 7 | import pdb 8 | 9 | 10 | def perplexity( 11 | texts, model, tokenizer, batch_size: int = 16, add_start_token: bool = True, max_length=None, only_return_loss=False 12 | ): 13 | """Perplexity (PPL) can be used for evaluating to what extent a dataset is similar to the distribution of text that 14 | a given model was trained on. It is defined as the exponentiated average negative log-likelihood of a sequence, 15 | calculated with exponent base `e`. 16 | 17 | For more information, see https://huggingface.co/docs/transformers/perplexity 18 | 19 | Args: 20 | texts (list of str): List of text strings. 21 | model: model used for calculating Perplexity 22 | NOTE: Perplexity can only be calculated for causal language models. 23 | This includes models such as gpt2, causal variations of bert, 24 | causal versions of t5, and more (the full list can be found 25 | in the AutoModelForCausalLM documentation here: 26 | https://huggingface.co/docs/transformers/master/en/model_doc/auto#transformers.AutoModelForCausalLM ) 27 | tokenizer: the corresponding tokenizer for the given model. 28 | batch_size (int): the batch size to run texts through the model. Defaults to 16. 29 | add_start_token (bool): whether to add the start token to the texts, 30 | so the perplexity can include the probability of the first word. Defaults to True. 31 | Returns: 32 | perplexity: dictionary containing the perplexity scores for the texts 33 | in the input list, as well as the mean perplexity. If one of the input texts is 34 | longer than the max input length of the model, then it is truncated to the 35 | max length for the perplexity computation. 36 | """ 37 | device = model.device 38 | # if batch_size > 1 (which generally leads to padding being required), and 39 | # if there is not an already assigned pad_token, assign an existing 40 | # special token to also be the padding token 41 | if tokenizer.pad_token is None and batch_size > 1: 42 | existing_special_tokens = list(tokenizer.special_tokens_map_extended.values()) 43 | # check that the model already has at least one special token defined 44 | assert ( 45 | len(existing_special_tokens) > 0 46 | ), "If batch_size > 1, model must have at least one special token to use for padding. Please use a different model or set batch_size=1." 47 | # assign one of the special tokens to also be the pad token 48 | tokenizer.add_special_tokens({"pad_token": existing_special_tokens[0]}) 49 | 50 | if add_start_token and max_length: 51 | # leave room for token to be added: 52 | assert ( 53 | tokenizer.bos_token is not None 54 | ), "Input model must already have a BOS token if using add_start_token=True. 
Please use a different model, or set add_start_token=False" 55 | max_tokenized_len = max_length - 1 56 | else: 57 | max_tokenized_len = max_length 58 | 59 | encodings = tokenizer( 60 | texts, 61 | add_special_tokens=False, 62 | padding=True, 63 | truncation=True if max_tokenized_len else False, 64 | max_length=max_tokenized_len, 65 | return_tensors="pt", 66 | return_attention_mask=True, 67 | ).to(device) 68 | encoded_texts = encodings["input_ids"] 69 | attn_masks = encodings["attention_mask"] 70 | 71 | # check that each input is long enough: 72 | if add_start_token: 73 | assert torch.all(torch.ge(attn_masks.sum(1), 1)), "Each input text must be at least one token long." 74 | else: 75 | assert torch.all( 76 | torch.ge(attn_masks.sum(1), 2) 77 | ), "When add_start_token=False, each input text must be at least two tokens long. Run with add_start_token=True if inputting strings of only one token, and remove all empty input strings." 78 | 79 | ppls = [] 80 | loss_fct = CrossEntropyLoss(reduction="none") 81 | if only_return_loss: 82 | all_losses, all_lengths = [], [] 83 | for start_index in logging.tqdm(range(0, len(encoded_texts), batch_size)): 84 | end_index = min(start_index + batch_size, len(encoded_texts)) 85 | encoded_batch = encoded_texts[start_index:end_index] 86 | attn_mask = attn_masks[start_index:end_index] 87 | 88 | if add_start_token: 89 | bos_tokens_tensor = torch.tensor([[tokenizer.bos_token_id]] * encoded_batch.size(dim=0)).to(device) 90 | encoded_batch = torch.cat([bos_tokens_tensor, encoded_batch], dim=1) 91 | attn_mask = torch.cat([torch.ones(bos_tokens_tensor.size(), dtype=torch.int64).to(device), attn_mask], dim=1) 92 | 93 | labels = encoded_batch 94 | 95 | with torch.no_grad(): 96 | out_logits = model(encoded_batch, attention_mask=attn_mask).logits 97 | 98 | shift_logits = out_logits[..., :-1, :].contiguous() 99 | shift_labels = labels[..., 1:].contiguous() 100 | shift_attention_mask_batch = attn_mask[..., 1:].contiguous() 101 | 102 | loss = (loss_fct(shift_logits.transpose(1, 2), shift_labels) * shift_attention_mask_batch).sum(1) 103 | lengths = shift_attention_mask_batch.sum(1) 104 | if only_return_loss: 105 | all_losses.append(loss) 106 | all_lengths.append(lengths) 107 | else: 108 | perplexity_batch = torch.exp(loss / lengths) 109 | ppls += perplexity_batch.tolist() 110 | 111 | if only_return_loss: 112 | return all_losses, all_lengths 113 | else: 114 | return {"perplexities": ppls, "mean_perplexity": np.mean(ppls)} 115 | 116 | 117 | def conditional_perplexity( 118 | texts, prefixes, model, tokenizer, batch_size: int = 16, add_start_token: bool = True, max_length=None 119 | ): 120 | """Computes the conditional perplexity for the case of prefix language modeling.""" 121 | full_texts = [f"{prefix}{text}" for prefix,text in zip(prefixes, texts)] 122 | loss, lengths = perplexity(full_texts, model, tokenizer, batch_size, add_start_token, max_length, only_return_loss=True) 123 | prefix_loss, prefix_lengths = perplexity( 124 | prefixes, model, tokenizer, batch_size, add_start_token, max_length, only_return_loss=True 125 | ) 126 | # Computing the perplexity over the whole examples. 
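    # What is computed below, in short: per example, the continuation NLL is
    # loss(prefix + text) - loss(prefix), normalized by the number of continuation tokens,
    # so the reported perplexity is exp((nll_full - nll_prefix) / (len_full - len_prefix)).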
127 | ppls = [] 128 | total_nlls = 0 129 | total_tokens = 0 130 | for i in range(len(loss)): 131 | perplexity_batch = torch.exp((loss[i] - prefix_loss[i]) / (lengths[i] - prefix_lengths[i])) 132 | ppls.extend(perplexity_batch.tolist()) 133 | total_nlls += torch.sum(loss[i] - prefix_loss[i]).item() 134 | total_tokens += torch.sum(lengths[i] - prefix_lengths[i]).item() 135 | return {"perplexities": ppls, "mean_perplexity": np.nanmean(ppls), "mean_perplexity_total": np.exp(total_nlls/total_tokens)} 136 | -------------------------------------------------------------------------------- /simplex-diffusion-main/sdlm/data/preprocessors.py: -------------------------------------------------------------------------------- 1 | """Implements data preprocessings including the T5 preprocessing.""" 2 | import numpy as np 3 | import itertools 4 | import pdb 5 | import torch 6 | 7 | 8 | # TODO: here the max perhaps needs to be also the half-length. 9 | def gpt_span_mask(length, pad_length, use_half_length_as_prefix_size, eval_context_size): 10 | """Given the length and pad_length for an input generates a prefix (GPT-style) mask.""" 11 | # Start of the sequence is not masked, so we consider length-1. 12 | # TODO: we need an assert for length not be smaller than a value. 13 | if not use_half_length_as_prefix_size: 14 | # high should be higher than low, otherwise we set prefix_size=1. 15 | prefix_size = np.random.randint(low=1, high=int((length - 1) / 2)) if length >= 5 else 1 16 | else: 17 | # If eval_context_size is set, we consider it, otherwise we use half of the given length as 18 | # context. Note that since the start token is also masked, we deduct one from the given 19 | # context size. 20 | prefix_size = eval_context_size - 1 if eval_context_size is not None else int((length - 1) / 2) 21 | # The start token is not masked. 22 | return [False] + [False] * prefix_size + [True] * (length - prefix_size - 1) + [False] * pad_length 23 | 24 | 25 | def gpt_span_mask_batch(batch, use_half_length_as_prefix_size=False, eval_context_size=None): 26 | lengths = [len(feature["input_ids"]) for feature in batch] 27 | max_length = max(lengths) 28 | masks = [ 29 | gpt_span_mask(length, max_length - length, use_half_length_as_prefix_size, eval_context_size) for length in lengths 30 | ] 31 | return torch.tensor(masks) 32 | 33 | 34 | def t5_random_spans_mask(length, mask_ratio, mean_mask_span_length=3.0, rng=None, pad_length=None): 35 | """Noise mask consisting of random spans of mask tokens. 36 | 37 | The number of mask tokens and the number of mask spans and non-mask spans 38 | are determined deterministically as follows: 39 | num_mask_tokens = round(length * mask_ratio) 40 | num_nonmask_spans = num_mask_spans = round( 41 | num_mask_tokens / mean_mask_span_length) 42 | Spans alternate between non-mask and mask, beginning with non-mask. 43 | Subject to the above restrictions, all masks are equally likely. 44 | Note that this function do not mask start/end of sequence. 45 | Args: 46 | length: an int32 scalar (length of the incoming token sequence) 47 | mask_ratio: a float - approximate ratio of output mask (between 0 and 1). 48 | mean_mask_span_length: Average mask length. 
49 | rng = a np.random.default_rng() instance or None 50 | Returns: 51 | a boolean list of shape [length] 52 | adapted from https://github.com/google-research/text-to-text-transfer-transformer/blob/master/t5/data/preprocessors.py#L2704 53 | and https://github.com/allenai/contrastive_pretraining/blob/95fe35d3257402c7df362c3e0f746a40d9fba8f0/cpt/data.py#L288 54 | """ 55 | # By default, we do not maks start and end of sequence. 56 | # TODO: we need to put assert for this! 57 | # NOTE: this only works if we use line_by_line which we do not. So I had to remove it. 58 | # length -= 2 59 | orig_length = length 60 | # Increase length to avoid degeneracy. 61 | length = max(length, 2) 62 | 63 | # Compute number of mask tokens and mask spans. 64 | num_mask_tokens = int(length * mask_ratio) 65 | # Avoid degeneracy by ensuring positive numbers of mask and nonmask tokens. 66 | num_mask_tokens = min(max(num_mask_tokens, 1), length - 1) 67 | num_mask_spans = int(num_mask_tokens / mean_mask_span_length) 68 | # Avoid degeneracy by ensuring positive number of mask spans. 69 | num_mask_spans = max(num_mask_spans, 1) 70 | num_nonmask_tokens = length - num_mask_tokens 71 | mask_span_lengths = _random_segmentation(num_mask_tokens, num_mask_spans, rng=rng) 72 | nonmask_span_lengths = _random_segmentation(num_nonmask_tokens, num_mask_spans, rng=rng) 73 | mask = list( 74 | itertools.chain.from_iterable( 75 | [[False] * nonmask_span_lengths[k] + [True] * mask_span_lengths[k] for k in range(num_mask_spans)] 76 | ) 77 | )[:orig_length] 78 | # Start and end of the sequence mask are set to False. Again since this is not line_by_line, we 79 | # remove this. 80 | # mask = [False] + mask + [False] 81 | if pad_length is not None: 82 | mask += [False for _ in range(pad_length)] 83 | return mask 84 | 85 | 86 | def t5_random_spans_mask_batch(batch, mask_ratio, mean_mask_span_length=3.0, rng=None): 87 | """Given not padded inputs, generates the T5 mask for each input.""" 88 | lengths = [len(feature["input_ids"]) for feature in batch] 89 | max_length = max(lengths) 90 | masks = [t5_random_spans_mask(length, mask_ratio, mean_mask_span_length, rng, max_length - length) for length in lengths] 91 | return torch.tensor(masks) 92 | 93 | 94 | def _random_segmentation(num_items, num_segments, rng=None): 95 | """Partition a sequence of items randomly into non-empty segments. 96 | Args: 97 | num_items: an integer scalar > 0 98 | num_segments: an integer scalar in [1, num_items] 99 | rng = a np.random.default_rng() instance or None 100 | Returns: 101 | a list with shape [num_segments] containing positive integers that add up to num_items. 102 | forked from: https://github.com/allenai/contrastive_pretraining/blob/95fe35d3257402c7df362c3e0f746a40d9fba8f0/cpt/data.py#L265 103 | """ 104 | first_in_segment = np.arange(num_items - 1) < num_segments - 1 105 | rng = rng or np.random.default_rng() 106 | rng.shuffle(first_in_segment) 107 | # The first position always starts a segment. 108 | # first_in_segment is boolean array for every position after the first that signals whether this location is the start of a new segment. 
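    # Illustrative example: num_items=5, num_segments=2 and a shuffled first_in_segment of
    # [False, True, False, False] gives segment_id = [0, 1, 1, 1] and segment_length = [2, 3]
    # (the extra leading count comes from the implicit first position).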
109 | segment_id = np.cumsum(first_in_segment) 110 | segment_length = [0] * num_segments 111 | segment_length[0] = 1 # first the first missing first in segment 112 | for k in range(num_items - 1): 113 | segment_length[segment_id[k]] += 1 114 | return segment_length 115 | 116 | 117 | def insert_extra_paddings(rng, token_ids, pad_token_id, padding_ratio): 118 | """Inserts padding tokens with the ratio of `padding_ratio` into the token_ids.""" 119 | # TODO: we need to assert to have start/end of sequence tokens. 120 | # We do not add the padding in the start and end of sequence. 121 | length = len(token_ids) - 2 122 | num_padding_tokens = int(length * padding_ratio) 123 | if num_padding_tokens == 0: 124 | # In this case, the rate of padding tokens was not enough to add extra tokens. 125 | return token_ids 126 | length = length + num_padding_tokens 127 | # We do not modify the start token. 128 | all_ids = np.arange(1, length + 1) 129 | # This is without shuffling. 130 | # original_ids = np.arange(1, length+1) 131 | rng = rng or np.random.default_rng() 132 | rng.shuffle(all_ids) 133 | # padding tokens positions. 134 | padding_ids = np.array(all_ids)[:num_padding_tokens] + 1 135 | token_ids_extended = [] 136 | current_id = 0 137 | for i in range(length + 2): 138 | if i not in padding_ids: 139 | token_ids_extended.append(pad_token_id) 140 | else: 141 | token_ids_extended.append(token_ids[current_id]) 142 | current_id += 1 143 | return token_ids_extended 144 | """ 145 | # Other tokens positions, we do not change the start and end of sequence tokens. 146 | other_tokens_ids = [0]+[x for x in original_ids if x not in padding_ids]+[length+1] 147 | # Considers the start and end of sequence tokens in the final length. 148 | token_ids_extended = np.full((length+2), pad_token_id, dtype=int) 149 | token_ids_extended[other_tokens_ids] = token_ids 150 | return token_ids_extended.tolist() 151 | """ 152 | -------------------------------------------------------------------------------- /simplex-diffusion-main/sdlm/schedulers/scheduling_simplex_ddpm.py: -------------------------------------------------------------------------------- 1 | """DDPM scheduler for the simplex diffusion model.""" 2 | 3 | from diffusers import DDPMScheduler 4 | from diffusers.schedulers.scheduling_ddpm import DDPMSchedulerOutput 5 | from dataclasses import dataclass 6 | from typing import Union, Tuple, Optional 7 | import torch 8 | import numpy as np 9 | from diffusers.configuration_utils import register_to_config 10 | from diffusers.utils import BaseOutput 11 | import math 12 | import pdb 13 | 14 | 15 | @dataclass 16 | class SimplexDDPMSchedulerOutput(BaseOutput): 17 | """ 18 | Output class for the scheduler's step function output. 19 | Args: 20 | prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): 21 | Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the 22 | denoising loop. 23 | projected_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, vocab_size)`): 24 | The projected logits sample (x_{0}) based on the model output from the current timestep. 
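    Note that in this text diffusion setting `prev_sample` has the same
    (batch_size, sequence_length, vocab_size) shape as `projected_logits`, despite the
    image-style shape quoted above.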
25 | """ 26 | 27 | prev_sample: torch.FloatTensor 28 | projected_logits: Optional[torch.FloatTensor] = None 29 | 30 | 31 | def betas_for_alpha_bar(num_diffusion_timesteps, device, max_beta=0.999, improved_ddpm=False): 32 | """ 33 | Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of 34 | (1-beta) over time from t = [0,1]. 35 | Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up 36 | to that part of the diffusion process. 37 | Args: 38 | num_diffusion_timesteps (`int`): the number of betas to produce. 39 | max_beta (`float`): the maximum beta to use; use values lower than 1 to 40 | prevent singularities. 41 | Returns: 42 | betas (`np.ndarray`): the betas used by the scheduler to step the model outputs 43 | """ 44 | 45 | def default_alpha_bar(time_step): 46 | return math.cos((time_step + 1e-4) / (1 + 1e-4) * math.pi / 2) ** 2 47 | 48 | if improved_ddpm: 49 | # Implements eqn. 17 in https://arxiv.org/pdf/2102.09672.pdf. 50 | alpha_bar = lambda x: (default_alpha_bar(x) / default_alpha_bar(0.0)) 51 | alphas_cumprod = [] 52 | else: 53 | alpha_bar = default_alpha_bar 54 | betas = [] 55 | for i in range(num_diffusion_timesteps): 56 | t1 = i / num_diffusion_timesteps 57 | t2 = (i + 1) / num_diffusion_timesteps 58 | alpha_bar_t1 = alpha_bar(t1) 59 | betas.append(min(1 - alpha_bar(t2) / alpha_bar_t1, max_beta)) 60 | if improved_ddpm: 61 | alphas_cumprod.append(alpha_bar_t1) 62 | # TODO(rabeeh): maybe this cause memory issue. 63 | betas = torch.tensor(betas, dtype=torch.float32, device=device) 64 | if improved_ddpm: 65 | return betas, torch.tensor(alphas_cumprod, dtype=torch.torch.float32, device=device) 66 | return betas 67 | 68 | 69 | class SimplexDDPMScheduler(DDPMScheduler): 70 | @register_to_config 71 | def __init__( 72 | self, 73 | device, 74 | simplex_value: float, 75 | num_train_timesteps: int = 1000, 76 | num_inference_timesteps: int = 1000, 77 | beta_start: float = 0.0001, 78 | beta_end: float = 0.02, 79 | beta_schedule: str = "linear", 80 | trained_betas: Optional[np.ndarray] = None, 81 | variance_type: str = "fixed_small", 82 | clip_sample: bool = False, 83 | ): 84 | if trained_betas is not None: 85 | self.betas = torch.from_numpy(trained_betas) 86 | elif beta_schedule == "linear": 87 | self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32, device=device) 88 | elif beta_schedule == "scaled_linear": 89 | # this schedule is very specific to the latent diffusion model. 
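            # In short: betas are linearly interpolated in sqrt-space between
            # beta_start**0.5 and beta_end**0.5 and then squared, which yields smaller
            # intermediate betas than the plain "linear" schedule.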
90 | self.betas = ( 91 | torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32, device=device) 92 | ** 2 93 | ) 94 | elif beta_schedule == "squaredcos_cap_v2": 95 | # Glide cosine schedule 96 | self.betas = betas_for_alpha_bar(num_train_timesteps, device=device) 97 | elif beta_schedule == "squaredcos_improved_ddpm": 98 | self.betas, self.alphas_cumprod = betas_for_alpha_bar(num_train_timesteps, device=device, improved_ddpm=True) 99 | elif beta_schedule == "sigmoid": 100 | # GeoDiff sigmoid schedule 101 | betas = torch.linspace(-6, 6, num_train_timesteps, device=device) 102 | self.betas = torch.sigmoid(betas) * (beta_end - beta_start) + beta_start 103 | else: 104 | raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") 105 | 106 | if beta_schedule == "squaredcos_improved_ddpm": 107 | self.alphas = None 108 | else: 109 | self.alphas = 1.0 - self.betas 110 | self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) 111 | 112 | self.one = torch.tensor(1.0, device=device) 113 | 114 | # standard deviation of the initial noise distribution 115 | self.init_noise_sigma = 1.0 116 | 117 | # setable values 118 | self.num_inference_steps = None 119 | # TODO(rabeeh): if memory issue, we can not add this to GPU and convert them iteratively. 120 | self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy()).to(device=device) 121 | 122 | self.variance_type = variance_type 123 | 124 | def step( 125 | self, 126 | projected_logits: torch.FloatTensor, 127 | timestep: int, 128 | noise: torch.FloatTensor, 129 | generator=None, 130 | ) -> Union[DDPMSchedulerOutput, Tuple]: 131 | """ 132 | Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion 133 | process from the learned model outputs (most often the predicted noise). 134 | Args: 135 | projected_logits (`torch.FloatTensor`): projected logits from the diffusion model. 136 | timestep (`int`): current discrete timestep in the diffusion chain. 137 | noise (`torch.FloatTensor`): a random noise with simplex_value standard deviation. 138 | generator: random number generator. 139 | Returns: 140 | [`~schedulers.scheduling_utils.DDPMSchedulerOutput`] resulted values. 141 | """ 142 | t = timestep 143 | 144 | # 1. compute alphas, betas 145 | alpha_prod_t_prev = self.alphas_cumprod[t - 1] if t > 0 else self.one 146 | 147 | # 3. Clip "predicted x_0" 148 | if self.config.clip_sample: 149 | projected_logits = torch.clamp(projected_logits, -1, 1) 150 | 151 | # See algorithm 2 in Figure 3 in https://arxiv.org/pdf/2210.17432.pdf. 
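# Rather than the usual DDPM posterior, the previous sample is obtained by re-noising the
# projected x_0 estimate directly to the noise level of timestep t-1:
#   x_{t-1} = sqrt(alpha_bar_{t-1}) * projected_logits + sqrt(1 - alpha_bar_{t-1}) * noise,
# where `noise` is passed in by the caller with standard deviation `simplex_value`.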
152 | predicted_logits_coeff = alpha_prod_t_prev ** (0.5) 153 | noise_coeff = (1 - alpha_prod_t_prev) ** (0.5) 154 | pred_prev_sample = predicted_logits_coeff * projected_logits + noise_coeff * noise 155 | 156 | return SimplexDDPMSchedulerOutput(prev_sample=pred_prev_sample, projected_logits=projected_logits) 157 | 158 | def add_noise( 159 | self, 160 | original_samples: torch.FloatTensor, 161 | noise: torch.FloatTensor, 162 | timesteps: torch.IntTensor, 163 | ) -> torch.FloatTensor: 164 | # Make sure alphas_cumprod and timestep have same device and dtype as original_samples 165 | # self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) 166 | # timesteps = timesteps.to(original_samples.device) 167 | 168 | alphas_cumprod_timesteps = self.alphas_cumprod[timesteps].view(-1, 1, 1) 169 | sqrt_alpha_prod = alphas_cumprod_timesteps**0.5 170 | sqrt_one_minus_alpha_prod = (1 - alphas_cumprod_timesteps) ** 0.5 171 | noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise 172 | return noisy_samples 173 | -------------------------------------------------------------------------------- /simplex-diffusion-main/sdlm/models/h3/ssm/ss_kernel.py: -------------------------------------------------------------------------------- 1 | # TD: [2023-01-05]: Extracted the SSKernel class from 2 | # https://github.com/HazyResearch/state-spaces/blob/06dbbdfd0876501a7f12bf3262121badbc7658af/src/models/sequence/ss/kernel.py 3 | # We add option to use the shift kernel, and remove the option of SSKernelNPLR 4 | 5 | """SSM convolution kernels. 6 | SSKernel wraps different kernels with common options and handles the initialization. 7 | """ 8 | 9 | import math 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | from einops import rearrange, repeat 15 | from opt_einsum import contract 16 | 17 | from ..ops.krylov import power 18 | from ..utils.utils import get_logger 19 | from .dplr import combination 20 | from .ss_kernel_diag import EMAKernel, SSKernelDiag 21 | from .ss_kernel_shift import SSKernelShift 22 | 23 | log = get_logger(__name__) 24 | 25 | 26 | _conj = lambda x: torch.cat([x, x.conj()], dim=-1) 27 | 28 | 29 | class SSKernel(nn.Module): 30 | """Wrapper around SSKernel parameterizations. 31 | 32 | The SSKernel is expected to support the interface 33 | forward() 34 | default_state() 35 | _setup_step() 36 | step() 37 | """ 38 | 39 | def __init__( 40 | self, 41 | H, 42 | N=64, 43 | L=None, 44 | measure="diag-lin", 45 | rank=1, 46 | channels=1, 47 | dt_min=0.001, 48 | dt_max=0.1, 49 | deterministic=False, 50 | lr=None, 51 | mode="diag", 52 | n_ssm=None, 53 | verbose=False, 54 | measure_args={}, 55 | **kernel_args, 56 | ): 57 | """State Space Kernel which computes the convolution kernel $\\bar{K}$ 58 | 59 | H: Number of independent SSM copies; controls the size of the model. Also called d_model in the config. 60 | N: State size (dimensionality of parameters A, B, C). Also called d_state in the config. Generally shouldn't need to be adjusted and doens't affect speed much. 61 | L: Maximum length of convolution kernel, if known. Should work in the majority of cases even if not known. 62 | measure: Options for initialization of (A, B). For NPLR mode, recommendations are "legs", "fout", "hippo" (combination of both). For Diag mode, recommendations are "diag-inv", "diag-lin", "diag-legs", and "diag" (combination of diag-inv and diag-lin) 63 | rank: Rank of low-rank correction for NPLR mode. 
Needs to be increased for measure "legt" 64 | channels: C channels turns the SSM from a 1-dim to C-dim map; can think of it having C separate "heads" per SSM. This was partly a feature to make it easier to implement bidirectionality; it is recommended to set channels=1 and adjust H to control parameters instead 65 | dt_min, dt_max: min and max values for the step size dt (\Delta) 66 | mode: Which kernel algorithm to use. 'nplr' is the full S4 model; 'diag' is the simpler S4D; 'slow' is a dense version for testing 67 | n_ssm: Number of independent trainable (A, B) SSMs, e.g. n_ssm=1 means all A/B parameters are tied across the H different instantiations of C. n_ssm=None means all H SSMs are completely independent. Generally, changing this option can save parameters but doesn't affect performance or speed much. This parameter must divide H 68 | lr: Passing in a number (e.g. 0.001) sets attributes of SSM parameers (A, B, dt). A custom optimizer hook is needed to configure the optimizer to set the learning rates appropriately for these parameters. 69 | """ 70 | super().__init__() 71 | self.N = N 72 | self.H = H 73 | dtype, cdtype = torch.float, torch.cfloat 74 | self.channels = channels 75 | self.n_ssm = n_ssm if n_ssm is not None else H 76 | self.mode = mode 77 | self.verbose = verbose 78 | self.kernel_args = kernel_args 79 | 80 | # Generate dt 81 | if deterministic: 82 | log_dt = torch.exp(torch.linspace(math.log(dt_min), math.log(dt_max), H)) 83 | else: 84 | log_dt = torch.rand(self.H, dtype=dtype) * ( 85 | math.log(dt_max) - math.log(dt_min) 86 | ) + math.log(dt_min) 87 | 88 | # Compute the preprocessed representation 89 | if mode == "ema": 90 | self.kernel = EMAKernel(H, N=N, channels=channels, **kernel_args) 91 | else: 92 | w, P, B, V = combination(measure, self.N, rank, self.n_ssm, **measure_args) 93 | 94 | # Broadcast C to have H channels 95 | if deterministic: 96 | C = torch.zeros(channels, self.n_ssm, self.N, dtype=cdtype) 97 | C[:, :, :1] = 1.0 98 | C = contract("hmn, chn -> chm", V.conj().transpose(-1, -2), C) # V^* C 99 | C = ( 100 | repeat(C, "c t n -> c (v t) n", v=self.n_ssm // C.size(-2)) 101 | .clone() 102 | .contiguous() 103 | ) 104 | else: 105 | C = torch.randn(channels, self.H, self.N // 2, dtype=cdtype) 106 | 107 | # Broadcast other parameters to have n_ssm copies 108 | assert ( 109 | self.n_ssm % B.size(-2) == 0 110 | and self.n_ssm % P.size(-2) == 0 111 | and self.n_ssm % w.size(-2) == 0 112 | ) 113 | # Broadcast tensors to n_ssm copies 114 | # These will be the parameters, so make sure tensors are materialized and contiguous 115 | B = ( 116 | repeat(B, "t n -> (v t) n", v=self.n_ssm // B.size(-2)) 117 | .clone() 118 | .contiguous() 119 | ) 120 | P = ( 121 | repeat(P, "r t n -> r (v t) n", v=self.n_ssm // P.size(-2)) 122 | .clone() 123 | .contiguous() 124 | ) 125 | w = ( 126 | repeat(w, "t n -> (v t) n", v=self.n_ssm // w.size(-2)) 127 | .clone() 128 | .contiguous() 129 | ) 130 | 131 | if mode == "diag": 132 | if not measure.startswith("diag"): 133 | log.warning( 134 | "Diagonal kernel (S4D) activated but initialization is not intended for S4D. Set `measure` to 'diag-lin', 'diag-inv', or 'diag-legs' for the main variants, or 'diag' for a combination of S4D-Lin and S4D-Inv." 
135 | ) 136 | C = C * repeat(B, "t n -> (v t) n", v=H // self.n_ssm) 137 | self.kernel = SSKernelDiag( 138 | w, 139 | B, 140 | C, 141 | log_dt, 142 | L=L, 143 | lr=lr, 144 | **kernel_args, 145 | ) 146 | elif mode == "shift": 147 | # Initializing B to be e_1 148 | B = torch.zeros(self.H, self.N) 149 | B[..., 0] = 1.0 150 | # Match torch.Conv1d init 151 | C = torch.randn(self.H, self.channels, self.N) 152 | nn.init.kaiming_uniform_(C, a=math.sqrt(5)) 153 | C = rearrange(C, "h c n -> c h n") 154 | self.kernel = SSKernelShift(B, C, L=L, lr=lr, **kernel_args) 155 | else: 156 | raise NotImplementedError(f"{mode=} is not valid") 157 | 158 | def forward(self, state=None, L=None, rate=None): 159 | return self.kernel(state=state, L=L, rate=rate) 160 | 161 | @torch.no_grad() 162 | def forward_state(self, u, state): 163 | """Forward the state through a sequence, i.e. computes the state after passing chunk through SSM 164 | 165 | state: (B, H, N) 166 | u: (B, H, L) 167 | 168 | Returns: (B, H, N) 169 | """ 170 | 171 | if hasattr(self.kernel, "forward_state"): 172 | return self.kernel.forward_state(u, state) 173 | 174 | dA, dB = self.kernel._setup_state() # Construct dA, dB matrices 175 | # dA, dB = self.kernel.dA, self.kernel.dB # (H N N) (H N) 176 | 177 | conj = state.size(-1) != dA.size(-1) 178 | if conj: 179 | state = _conj(state) 180 | 181 | v = contract( 182 | "h n, b h l -> b h n l", dB, u.flip(-1) 183 | ) # dB.unsqueeze(-1) * u.flip(-1).unsqueeze(-2) 184 | AL, v = power(u.size(-1), dA, v) 185 | next_state = contract("h m n, b h n -> b h m", AL, state) 186 | next_state = next_state + v 187 | 188 | if conj: 189 | next_state = next_state[..., : next_state.size(-1) // 2] 190 | return next_state 191 | 192 | def _setup_step(self, **kwargs): 193 | # This method is intended to be private so that setting up an S4 module with 194 | # ``` 195 | # if hasattr(module, 'setup_step'): module.setup_step() 196 | # ``` 197 | # will not trigger this method multiple times 198 | self.kernel._setup_step(**kwargs) 199 | 200 | def step(self, u, state, **kwargs): 201 | y, state = self.kernel.step(u, state, **kwargs) 202 | return y, state 203 | 204 | def default_state(self, *args, **kwargs): 205 | return self.kernel.default_state(*args, **kwargs) 206 | -------------------------------------------------------------------------------- /simplex-diffusion-main/sdlm/inference/inference_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | import pdb 5 | from utils import convert_to_simplex, join_texts 6 | from metrics.perplexity import perplexity, conditional_perplexity 7 | from metrics.metrics import distinct_n_grams, mauve, zipf 8 | from metrics.repetition import repetition 9 | 10 | 11 | def sample_logits(sampling_type, logits, top_p, temperature): 12 | # top-p (nucleus) sampling. 13 | if sampling_type == "top_p": 14 | logits = logits / temperature 15 | probs = F.softmax(logits, dim=-1) 16 | if top_p is not None: 17 | sorted_probs, sorted_indices = torch.sort(probs, dim=-1, descending=True) 18 | cumsum_probs = torch.cumsum(sorted_probs, dim=-1) 19 | 20 | # Remove tokens with cumulative probability above the threshold. 21 | sorted_indices_to_keep = cumsum_probs < top_p 22 | 23 | # Shift the indices to the right to keep also the first token below the threshold. 
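# The shift below guarantees that at least one token survives the top-p filter, and the
# `scatter` maps the keep-mask from sorted order back to vocabulary order before the
# filtered logits are sampled.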
24 | sorted_indices_to_keep[..., 1:] = sorted_indices_to_keep[..., :-1].clone() 25 | sorted_indices_to_keep[..., 0] = 1 26 | 27 | indices_to_keep = sorted_indices_to_keep.scatter(dim=2, index=sorted_indices, src=sorted_indices_to_keep) 28 | filtered_logits = logits.masked_fill(indices_to_keep == 0, -float("Inf")) 29 | 30 | # sample from the filtered distribution. 31 | token_ids = torch.distributions.categorical.Categorical(logits=filtered_logits).sample() 32 | else: 33 | token_ids=torch.argmax(probs, dim=-1) 34 | else: 35 | assert NotImplementedError 36 | return token_ids 37 | 38 | 39 | def remove_first_occurrence(string, char): 40 | # We do not strip as we need the spaces as well. 41 | if char in string: 42 | idx = string.index(char) 43 | string = string[idx + len(char) :] 44 | return string 45 | 46 | 47 | def keep_till_first_occurrence(string, chars): 48 | """Given a list of characters, trim the text after the first occurance between them.""" 49 | idxs = [string.index(char) for char in chars if char in string] 50 | if len(idxs): 51 | min_idx = np.min(idxs) 52 | string = string[:min_idx] 53 | return string 54 | 55 | 56 | def process_text(texts): 57 | # TODO(rabeeh): for now we only cover roberta case. 58 | texts = [keep_till_first_occurrence(text, [""]) for text in texts] 59 | texts = [remove_first_occurrence(text, "") for text in texts] 60 | return texts 61 | 62 | 63 | def split_into_masked_and_unmasked(token_ids, span_mask, return_masked=None): 64 | """Given an span_mask, splits the given token_ids into masked and unmasked parts. 65 | 66 | If return_masked is set, only returns the masked parts, if this is set to False, 67 | only returns the unmasked parts, and If set to None, returns both parts. 68 | """ 69 | 70 | def update_spans(span, masked, unmasked, mask): 71 | # TODO: this needs to be here for previous version of the codes. 72 | # span = torch.stack(span) 73 | masked.append(span) if mask else unmasked.append(span) 74 | 75 | masked = [] 76 | unmasked = [] 77 | prev_mask = span_mask[0] 78 | span = [] 79 | for _, (token_id, mask) in enumerate(zip(token_ids, span_mask)): 80 | if mask == prev_mask: 81 | span.append(token_id) 82 | else: 83 | # Adds the previous span. 84 | update_spans(span, masked, unmasked, prev_mask) 85 | prev_mask = mask 86 | span = [token_id] 87 | # Adds the last span. 88 | update_spans(span, masked, unmasked, prev_mask) 89 | 90 | if return_masked is None: 91 | return masked, unmasked 92 | 93 | return masked if return_masked else unmasked 94 | 95 | 96 | def concatenate_alternatively(longer, shorter, mark=""): 97 | """Given two lists of strings, concatenates them alternatively. 98 | 99 | We assume that the concatenated string should starts from elements in the longer 100 | list (which has one extra element). The shorter text can optionally be embraced with 101 | a `mark` text on both sides. 102 | """ 103 | concatenated_str = "" 104 | for l, s in zip(longer, shorter): 105 | concatenated_str += l + " " + mark + s + mark + " " 106 | if len(longer) == len(shorter) + 1: 107 | return concatenated_str + longer[-1] 108 | elif len(longer) == len(shorter): 109 | return concatenated_str[:-1] 110 | else: 111 | raise ValueError 112 | 113 | 114 | def aggregate_list(x): 115 | str = "" 116 | if len(x) == 0: 117 | return str 118 | for l in x: 119 | str += l + " " 120 | return str[:-1] 121 | 122 | 123 | def logits_projection(logits, sampling_type, top_p, simplex_value, temperature): 124 | # TODO(rabeeh): huggingface has different sampling, like constrastive one. 
125 | # also there are more variant in diffusion-lm. 126 | token_ids = sample_logits(sampling_type, logits, top_p, temperature) 127 | return convert_to_simplex(token_ids, simplex_value, vocab_size=logits.shape[2]) 128 | 129 | 130 | def filter_empty(texts): 131 | """Filters empty texts and return the remained texts and the their indices.""" 132 | list_of_tuples = [(text, i) for i, text in enumerate(texts) if text != ""] 133 | if len(list_of_tuples) == 0: 134 | return [], [] 135 | non_empty_texts, remained_inds = list(zip(*list_of_tuples)) 136 | return list(non_empty_texts), list(remained_inds) 137 | 138 | 139 | def predict_conditional_generated(span_masks, input_ids, tokenizer, predicted_token_ids, prefix_name, skip_special_tokens): 140 | masked = list( 141 | map(lambda x, y: split_into_masked_and_unmasked(x, y, return_masked=True), predicted_token_ids, span_masks) 142 | ) 143 | unmasked = list(map(lambda x, y: split_into_masked_and_unmasked(x, y, return_masked=False), input_ids, span_masks)) 144 | pred_masked_texts = [tokenizer.batch_decode(x, skip_special_tokens=skip_special_tokens) for x in masked] 145 | pred_unmasked_texts = [tokenizer.batch_decode(x, skip_special_tokens=skip_special_tokens) for x in unmasked] 146 | pred_texts = list(map(lambda x, y: concatenate_alternatively(x, y), pred_unmasked_texts, pred_masked_texts)) 147 | pred_texts_marked = list( 148 | map(lambda x, y: concatenate_alternatively(x, y, mark="***"), pred_unmasked_texts, pred_masked_texts) 149 | ) 150 | aggregated_masked_texts = list(map(lambda x: aggregate_list(x), pred_masked_texts)) 151 | predicted_tokens = [np.array(item).tolist() for submasked in masked for item in submasked] 152 | return { 153 | # prefix_name: pred_texts, 154 | prefix_name + "_marked": pred_texts_marked, 155 | prefix_name + "_masked": aggregated_masked_texts, 156 | prefix_name + "_masked_tokens": predicted_tokens, 157 | } 158 | 159 | 160 | def evaluate_generation( 161 | results, 162 | data_args, 163 | causal_model, 164 | causal_tokenizer, 165 | is_conditional_generation, 166 | prefix_lm_eval=False, 167 | skip_special_tokens=True, 168 | eval_for_all_metrics=False, 169 | ): 170 | metrics = {} 171 | # In case of prefix_lm since the generated text is unified, we can evaluate only the masked parts. 172 | if prefix_lm_eval: 173 | gold_text_key = "gold_texts_masked" 174 | # In case of gpt2, we only have the key of `generated_texts_masked`. 175 | keys = ( 176 | ["generated_texts_masked"] 177 | if "generated_texts_masked" in results 178 | else ["pred_texts_from_simplex_masked", "pred_texts_from_logits_masked"] 179 | ) 180 | else: 181 | keys = ["pred_texts_from_simplex", "pred_texts_from_logits"] 182 | gold_text_key = "gold_texts" 183 | 184 | if is_conditional_generation: 185 | gold_texts = results[gold_text_key] 186 | if not skip_special_tokens: 187 | gold_texts = process_text(gold_texts) 188 | if "prefixes" in results: 189 | prefixes = results["prefixes"] 190 | else: 191 | prefixes = None 192 | 193 | for key in keys: 194 | key_metrics = {} 195 | texts = results[key] 196 | if not skip_special_tokens: 197 | texts = process_text(texts) 198 | 199 | non_empty_texts, remained_indices = filter_empty(texts) 200 | if len(non_empty_texts) == 0: 201 | continue 202 | 203 | # Perplexity measured by a causal model. 
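# Without prefixes (unconditional generation) we report plain perplexity of the generated
# text; with prefixes, `conditional_perplexity` scores each continuation given its prefix
# and also returns an overall `total_perplexity`.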
204 | if prefixes is None: 205 | key_metrics.update({"perplexity": perplexity(non_empty_texts, causal_model, causal_tokenizer)["mean_perplexity"]}) 206 | else: 207 | non_empty_prefixes = [prefix for i, prefix in enumerate(prefixes) if i in remained_indices ] 208 | perplexity_results = conditional_perplexity(non_empty_texts, non_empty_prefixes, causal_model, causal_tokenizer) 209 | key_metrics.update({"perplexity": perplexity_results["mean_perplexity"], "total_perplexity":perplexity_results["mean_perplexity_total"]}) 210 | 211 | # Dist-1,2,3 measurements. 212 | key_metrics.update(distinct_n_grams(texts)) 213 | 214 | # Metrics requiring the gold text. 215 | if is_conditional_generation and eval_for_all_metrics: 216 | # Note that we need to pass both context and predicted texts to this metric. 217 | # remained_gold_texts = [text for i, text in enumerate(gold_texts) if i in remained_indices] 218 | # remained_prefixes = [text for i, text in enumerate(prefixes) if i in remained_indices] 219 | texts_with_context = join_texts(prefixes, texts) 220 | gold_with_context = join_texts(prefixes, gold_texts) 221 | length = data_args.max_seq_length - data_args.truncation_length 222 | key_metrics.update(mauve(predictions=texts_with_context, references=gold_with_context, length=length)) 223 | 224 | if key + "_tokens" in results and eval_for_all_metrics: 225 | key_metrics.update(repetition(results[key + "_tokens"], causal_tokenizer)) 226 | key_metrics.update(zipf(results[key + "_tokens"])) 227 | 228 | # Adds the metrics. 229 | key_metrics = {f"{key}_{k}": v for k, v in key_metrics.items()} 230 | metrics.update(key_metrics) 231 | 232 | return metrics 233 | -------------------------------------------------------------------------------- /simplex-diffusion-main/sdlm/run_mlm.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import sys 4 | 5 | import datasets 6 | import transformers 7 | from datasets import load_from_disk, Dataset 8 | from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed 9 | from transformers.trainer_callback import TrainerState 10 | from transformers.trainer_utils import get_last_checkpoint 11 | from transformers.utils import check_min_version 12 | from transformers.utils.versions import require_version 13 | 14 | from arguments import get_args 15 | from data.data_collator import SpanInfillingDataCollator 16 | from data.data_utils import load_data, tokenize_data_new 17 | from inference.inference_utils import evaluate_generation 18 | from models import load_model 19 | from schedulers import SimplexDDPMScheduler 20 | from trainer import DiffusionTrainer 21 | 22 | 23 | # Will error if the minimal version of Transformers is not installed. Remove at your own risks. 24 | check_min_version("4.25.0") 25 | 26 | require_version( 27 | "datasets>=2.0.0", 28 | "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt", 29 | ) 30 | 31 | logger = logging.getLogger(__name__) 32 | 33 | 34 | def get_compute_metrics(data_args, training_args, model_args): 35 | # Causal language model. 
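# Loaded from `model_args.autoregressive_eval_model`; this causal LM is only used inside
# `evaluate_generation` to score the generated text (e.g. perplexity) and is never trained here.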
36 | causal_model = AutoModelForCausalLM.from_pretrained( 37 | model_args.autoregressive_eval_model 38 | ) 39 | causal_model = causal_model.to(training_args.device) 40 | causal_tokenizer = AutoTokenizer.from_pretrained( 41 | model_args.autoregressive_eval_model 42 | ) 43 | 44 | is_conditional_generation = data_args.conditional_generation is not None 45 | prefix_lm_eval = data_args.conditional_generation in [ 46 | "prefix_lm", 47 | "ul2", 48 | "ul2_with_unconditional", 49 | "ul2_variable", 50 | ] 51 | compute_metrics = lambda results: evaluate_generation( # noqa: E731 52 | results, 53 | data_args, 54 | causal_model, 55 | causal_tokenizer, 56 | is_conditional_generation, 57 | prefix_lm_eval=prefix_lm_eval, 58 | skip_special_tokens=data_args.skip_special_tokens, 59 | eval_for_all_metrics=training_args.eval_for_all_metrics, 60 | ) 61 | return compute_metrics 62 | 63 | 64 | def main(): 65 | # parse args 66 | model_args, data_args, training_args, diffusion_args = get_args() 67 | 68 | # Setup logging 69 | logging.basicConfig( 70 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 71 | datefmt="%m/%d/%Y %H:%M:%S", 72 | handlers=[logging.StreamHandler(sys.stdout)], 73 | ) 74 | 75 | log_level = training_args.get_process_log_level() 76 | logger.setLevel(log_level) 77 | datasets.utils.logging.set_verbosity(log_level) 78 | transformers.utils.logging.set_verbosity(log_level) 79 | transformers.utils.logging.enable_default_handler() 80 | transformers.utils.logging.enable_explicit_format() 81 | 82 | # Log on each process the small summary: 83 | logger.warning( 84 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" 85 | + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" 86 | ) 87 | # Set the verbosity to info of the Transformers logger (on main process only): 88 | logger.info(f"Training/evaluation parameters {training_args}") 89 | 90 | # Detecting last checkpoint. 91 | last_checkpoint = None 92 | if ( 93 | os.path.isdir(training_args.output_dir) 94 | and training_args.do_train 95 | and not training_args.overwrite_output_dir 96 | ): 97 | last_checkpoint = get_last_checkpoint(training_args.output_dir) 98 | if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: 99 | raise ValueError( 100 | f"Output directory ({training_args.output_dir}) already exists and is not empty. " 101 | "Use --overwrite_output_dir to overcome." 102 | ) 103 | elif ( 104 | last_checkpoint is not None and training_args.resume_from_checkpoint is None 105 | ): 106 | logger.info( 107 | f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " 108 | "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." 109 | ) 110 | 111 | # Set seed before initializing model. 
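# `set_seed` seeds the Python, NumPy and PyTorch RNGs so that masking, noise sampling and
# data shuffling are reproducible across runs.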
112 | set_seed(training_args.seed) 113 | 114 | # load model 115 | tokenizer, model = load_model(model_args, diffusion_args, logger) 116 | 117 | # init schedulers 118 | noise_scheduler = SimplexDDPMScheduler( 119 | num_train_timesteps=diffusion_args.num_diffusion_steps, 120 | beta_schedule=diffusion_args.beta_schedule, 121 | simplex_value=diffusion_args.simplex_value, 122 | clip_sample=diffusion_args.clip_sample, 123 | device=training_args.device, 124 | ) 125 | inference_noise_scheduler = SimplexDDPMScheduler( 126 | num_train_timesteps=diffusion_args.num_inference_diffusion_steps, 127 | beta_schedule=diffusion_args.beta_schedule, 128 | simplex_value=diffusion_args.simplex_value, 129 | clip_sample=diffusion_args.clip_sample, 130 | device=training_args.device, 131 | ) 132 | 133 | if data_args.tokenized_data_path: 134 | tokenized_datasets = load_from_disk(data_args.tokenized_data_path) 135 | else: 136 | raw_datasets = load_data(data_args, model_args) 137 | tokenized_datasets = tokenize_data_new( 138 | data_args, tokenizer, raw_datasets, training_args 139 | ) 140 | 141 | if training_args.do_train: 142 | if "train" not in tokenized_datasets: 143 | raise ValueError("--do_train requires a train dataset") 144 | train_dataset = tokenized_datasets["train"] 145 | if data_args.max_train_samples is not None: 146 | max_train_samples = min(len(train_dataset), data_args.max_train_samples) 147 | train_dataset = train_dataset.select(range(max_train_samples)) 148 | 149 | if training_args.do_eval: 150 | if "validation" not in tokenized_datasets: 151 | raise ValueError("--do_eval requires a validation dataset") 152 | eval_dataset = tokenized_datasets["validation"] 153 | # convert eval dataset to regular dataset 154 | if isinstance(eval_dataset, datasets.IterableDataset): 155 | def iterable_generator(): 156 | for x in eval_dataset: 157 | yield x 158 | eval_dataset = Dataset.from_generator(iterable_generator) 159 | if data_args.max_eval_samples is not None: 160 | max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples) 161 | eval_dataset = eval_dataset.select(range(max_eval_samples)) 162 | 163 | def preprocess_logits_for_metrics(logits): 164 | return logits.argmax(dim=-1) 165 | 166 | # Data collator 167 | # TODO: fix lambda max_seq_length, extra_padding_ratio: 168 | pad_to_multiple_of_8 = ( 169 | data_args.line_by_line 170 | and training_args.fp16 171 | and not data_args.pad_to_max_length 172 | ) 173 | data_collator = lambda mode: SpanInfillingDataCollator( # noqa: E731 174 | mode=mode, 175 | data_args=data_args, 176 | tokenizer=tokenizer, 177 | max_length=data_args.max_seq_length, 178 | seed=training_args.seed, 179 | pad_to_multiple_of=8 if pad_to_multiple_of_8 else None, 180 | eval_context_size=data_args.eval_context_size, 181 | ) 182 | 183 | if training_args.do_eval: 184 | compute_metrics = get_compute_metrics(data_args, training_args, model_args) 185 | 186 | if data_args.shuffle: 187 | train_dataset = train_dataset.shuffle(seed=training_args.seed) 188 | 189 | # Initialize our Trainer 190 | trainer = DiffusionTrainer( 191 | model=model, 192 | args=training_args, 193 | train_dataset=train_dataset if training_args.do_train else None, 194 | eval_dataset=eval_dataset if training_args.do_eval else None, 195 | tokenizer=tokenizer, 196 | data_collator=data_collator, 197 | compute_metrics=compute_metrics 198 | if training_args.do_eval and not training_args.without_compute_metrics 199 | else None, 200 | preprocess_logits_for_metrics=preprocess_logits_for_metrics 201 | if training_args.do_eval 202 | else 
None, 203 | noise_scheduler=noise_scheduler, 204 | diffusion_args=diffusion_args, 205 | data_args=data_args, 206 | inference_noise_scheduler=inference_noise_scheduler, 207 | ) 208 | 209 | # Training 210 | if training_args.do_train: 211 | checkpoint = None 212 | if training_args.resume_from_checkpoint is not None: 213 | checkpoint = training_args.resume_from_checkpoint 214 | elif last_checkpoint is not None: 215 | checkpoint = last_checkpoint 216 | train_result = trainer.train(resume_from_checkpoint=checkpoint) 217 | trainer.save_model() # Saves the tokenizer too for easy upload 218 | metrics = train_result.metrics 219 | 220 | trainer.log_metrics("train", metrics) 221 | trainer.save_metrics("train", metrics) 222 | trainer.save_state() 223 | 224 | # Evaluation 225 | #model_args.model_name_or_path = 226 | if training_args.do_eval: 227 | if training_args.load_states_in_eval_from_model_path: 228 | trainer._load_from_checkpoint(model_args.model_name_or_path) 229 | trainer.state = TrainerState.load_from_json( 230 | os.path.join(model_args.model_name_or_path, "trainer_state.json") 231 | ) 232 | trainer._load_rng_state(model_args.model_name_or_path) 233 | 234 | # np.save("weights.npy", model.vocab_to_hidden_dim_embed.weight.data.numpy()) 235 | 236 | logger.info("*** Evaluate ***") 237 | metrics = trainer.evaluate() 238 | max_eval_samples = ( 239 | data_args.max_eval_samples 240 | if data_args.max_eval_samples is not None 241 | else len(eval_dataset) 242 | ) 243 | metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset)) 244 | trainer.log_metrics("eval", metrics) 245 | trainer.save_metrics("eval", metrics) 246 | 247 | 248 | if __name__ == "__main__": 249 | main() 250 | -------------------------------------------------------------------------------- /simplex-diffusion-main/sdlm/models/xlm_roberta/modeling_xlm_roberta.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple, Union 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.nn import CrossEntropyLoss 7 | from transformers.modeling_outputs import MaskedLMOutput 8 | from transformers.models.xlm_roberta.modeling_xlm_roberta import ( 9 | XLMRobertaLMHead, 10 | XLMRobertaModel, 11 | XLMRobertaPreTrainedModel, 12 | ) 13 | from transformers.utils import logging 14 | 15 | logger = logging.get_logger(__name__) 16 | 17 | 18 | class XLMRobertaForDiffusionLM(XLMRobertaPreTrainedModel): 19 | _keys_to_ignore_on_save = [r"lm_head.decoder.weight", r"lm_head.decoder.bias"] 20 | _keys_to_ignore_on_load_missing = [ 21 | r"position_ids", 22 | r"lm_head.decoder.weight", 23 | r"lm_head.decoder.bias", 24 | ] 25 | _keys_to_ignore_on_load_unexpected = [r"pooler"] 26 | 27 | def __init__(self, config): 28 | super().__init__(config) 29 | 30 | if config.is_decoder: 31 | logger.warning( 32 | "If you want to use `XLMRobertaForMaskedLM` make sure `config.is_decoder=False` for " 33 | "bi-directional self-attention." 
34 | ) 35 | 36 | self.roberta = XLMRobertaModel(config, add_pooling_layer=False) 37 | self.lm_head = XLMRobertaLMHead(config) 38 | 39 | # The LM head weights require special treatment only when they are tied with the word embeddings 40 | self.update_keys_to_ignore(config, ["lm_head.decoder.weight"]) 41 | 42 | self.vocab_to_hidden_dim_embed = nn.Linear( 43 | config.vocab_size, config.hidden_size, bias=False 44 | ) 45 | self.timestep_embed = nn.Linear(1, config.hidden_size, bias=True) 46 | 47 | if self.config.self_condition is not None and self.config.deepmind_conditional: 48 | # In this case, this is self-conditioning with conditional generation as done in DeepMind paper. 49 | # See Figure 3 in https://arxiv.org/pdf/2211.15089.pdf. 50 | # Here we concat masked word embeddings, noisy embeddings, mask, and self-conditioning inputs 51 | # and project them to the hidden_size. 52 | self.project_to_hidden_size = nn.Linear( 53 | config.hidden_size * 4, config.hidden_size, bias=False 54 | ) 55 | elif ( 56 | self.config.self_condition is not None 57 | and not self.config.self_condition # noqa: E713 58 | in [ 59 | "logits_addition", 60 | "logits_with_projection_addition", 61 | ] 62 | ): 63 | self.project_to_hidden_size = nn.Linear( 64 | config.hidden_size * 2, config.hidden_size, bias=False 65 | ) 66 | 67 | # Initialize weights and apply final processing 68 | self.post_init() 69 | 70 | def post_init(self): 71 | super().post_init() 72 | self.vocab_to_hidden_dim_embed.weight.data = ( 73 | self.get_input_embeddings().weight.data.T 74 | ) 75 | 76 | def get_output_embeddings(self): 77 | return self.lm_head.decoder 78 | 79 | def set_output_embeddings(self, new_embeddings): 80 | self.lm_head.decoder = new_embeddings 81 | 82 | def forward( 83 | self, 84 | timesteps: torch.FloatTensor, 85 | simplex: torch.FloatTensor, 86 | input_ids: Optional[torch.LongTensor] = None, 87 | span_mask: Optional[torch.FloatTensor] = None, 88 | attention_mask: Optional[torch.FloatTensor] = None, 89 | token_type_ids: Optional[torch.LongTensor] = None, 90 | position_ids: Optional[torch.LongTensor] = None, 91 | head_mask: Optional[torch.FloatTensor] = None, 92 | inputs_embeds: Optional[torch.FloatTensor] = None, 93 | encoder_hidden_states: Optional[torch.FloatTensor] = None, 94 | encoder_attention_mask: Optional[torch.FloatTensor] = None, 95 | labels: Optional[torch.LongTensor] = None, 96 | output_attentions: Optional[bool] = None, 97 | output_hidden_states: Optional[bool] = None, 98 | return_dict: Optional[bool] = None, 99 | previous_pred: Optional[torch.FloatTensor] = None, 100 | classifier_free_guidance: bool = False, 101 | unconditional_simplex: torch.FloatTensor = None, 102 | ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]: 103 | r""" 104 | labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): 105 | Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., 106 | config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the 107 | loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` 108 | kwargs (`Dict[str, any]`, optional, defaults to *{}*): 109 | Used to hide legacy arguments that have been deprecated. 
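timesteps (`torch.FloatTensor` of shape `(batch_size,)`):
    Current diffusion timesteps; embedded through `timestep_embed` and added to every position's embedding.
simplex (`torch.FloatTensor` of shape `(batch_size, sequence_length, vocab_size)`):
    Noisy simplex over the vocabulary; it is softmaxed and mapped to the hidden size with
    `vocab_to_hidden_dim_embed` to build the input embeddings.
span_mask (boolean mask of shape `(batch_size, sequence_length)`, *optional*):
    True for positions being denoised; False positions keep their original word embeddings and are
    ignored in the loss.
previous_pred (`torch.FloatTensor` of shape `(batch_size, sequence_length, vocab_size)`, *optional*):
    Logits from the previous denoising step, used when self-conditioning is enabled.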
110 | """ 111 | return_dict = ( 112 | return_dict if return_dict is not None else self.config.use_return_dict 113 | ) 114 | inputs_probs = F.softmax(simplex, dim=-1) 115 | seq_length = inputs_probs.shape[1] 116 | inputs_embeds = self.vocab_to_hidden_dim_embed(inputs_probs) 117 | 118 | if classifier_free_guidance: 119 | unconditional_probs = F.softmax(unconditional_simplex, dim=-1) 120 | uncond_inputs_embeds = self.vocab_to_hidden_dim_embed(unconditional_probs) 121 | 122 | if self.config.self_condition is not None: 123 | if self.config.self_condition_zeros_after_softmax and previous_pred is None: 124 | previous_pred_probs = torch.zeros_like(simplex, device=simplex.device) 125 | else: 126 | if previous_pred is None: 127 | previous_pred = torch.zeros_like(simplex, device=simplex.device) 128 | previous_pred_probs = F.softmax(previous_pred, dim=-1) 129 | previous_pred = self.vocab_to_hidden_dim_embed(previous_pred_probs) 130 | if not self.config.deepmind_conditional: 131 | if self.config.self_condition in [ 132 | "logits_with_projection_addition", 133 | "logits_addition", 134 | ]: 135 | inputs_embeds = inputs_embeds + previous_pred 136 | elif self.config.self_condition in ["logits", "logits_with_projection"]: 137 | inputs_embeds = self.project_to_hidden_size( 138 | torch.cat([inputs_embeds, previous_pred], axis=-1) 139 | ) 140 | else: 141 | raise NotImplementedError 142 | 143 | if span_mask is not None: 144 | # Original word embeddings without noise. 145 | inputs_word_embeds = self.get_input_embeddings()(input_ids) 146 | 147 | if self.config.self_condition is not None and self.config.deepmind_conditional: 148 | inputs_embeds = torch.where( 149 | span_mask.unsqueeze(-1), inputs_embeds, torch.zeros_like(previous_pred) 150 | ) 151 | previous_pred = torch.where( 152 | span_mask.unsqueeze(-1), previous_pred, torch.zeros_like(previous_pred) 153 | ) 154 | inputs_word_embeds = torch.where( 155 | span_mask.unsqueeze(-1), 156 | torch.zeros_like(inputs_word_embeds), 157 | inputs_word_embeds, 158 | ) 159 | tiled_mask = span_mask.unsqueeze(-1).repeat(1, 1, self.config.hidden_size) 160 | inputs_embeds = self.project_to_hidden_size( 161 | torch.cat( 162 | [inputs_embeds, inputs_word_embeds, previous_pred, tiled_mask], 163 | axis=-1, 164 | ) 165 | ) 166 | 167 | # TODO: remove conversion. 168 | timesteps_embed = self.timestep_embed(timesteps.view(-1, 1).float()) 169 | inputs_embeds = inputs_embeds + timesteps_embed.unsqueeze(1).repeat( 170 | 1, seq_length, 1 171 | ) 172 | 173 | if span_mask is not None and not self.config.deepmind_conditional: 174 | # For the unmasked tokens, we only compute their original word embeddings. 175 | # Note that this also sets the self-conditioned inputs wich we are conditioning on 176 | # to their original word embeddings values. 177 | inputs_embeds = torch.where( 178 | span_mask.unsqueeze(-1), inputs_embeds, inputs_word_embeds 179 | ) 180 | # TODO: we need to fix classifier-free guidance for the case of deepmind_conditional. 
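# For classifier-free guidance the unconditional inputs are concatenated with the conditional
# ones along the batch dimension, so a single forward pass returns logits for both; the caller
# is then expected to split and mix them (e.g. uncond + guidance_scale * (cond - uncond)).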
181 | if classifier_free_guidance: 182 | inputs_embeds = torch.cat([uncond_inputs_embeds, inputs_embeds]) 183 | 184 | outputs = self.roberta( 185 | input_ids=None, 186 | attention_mask=None, # attention_mask, 187 | token_type_ids=token_type_ids, 188 | position_ids=position_ids, 189 | head_mask=head_mask, 190 | inputs_embeds=inputs_embeds, 191 | encoder_hidden_states=encoder_hidden_states, 192 | encoder_attention_mask=encoder_attention_mask, 193 | output_attentions=output_attentions, 194 | output_hidden_states=output_hidden_states, 195 | return_dict=return_dict, 196 | ) 197 | sequence_output = outputs[0] 198 | prediction_scores = self.lm_head(sequence_output) 199 | 200 | masked_lm_loss = None 201 | # In case of classifier-free guidance, since the number of output logits and input token ids do not match 202 | # we do not compute the loss. 203 | if input_ids is not None and not classifier_free_guidance: 204 | loss_fct = CrossEntropyLoss() 205 | labels = ( 206 | torch.where(span_mask, input_ids, -100) 207 | if span_mask is not None 208 | else input_ids 209 | ) 210 | masked_lm_loss = loss_fct( 211 | prediction_scores.view(-1, self.config.vocab_size), labels.view(-1) 212 | ) 213 | 214 | if not return_dict: 215 | output = (prediction_scores,) + outputs[2:] 216 | return ( 217 | ((masked_lm_loss,) + output) if masked_lm_loss is not None else output 218 | ) 219 | 220 | return MaskedLMOutput( 221 | loss=masked_lm_loss, 222 | logits=prediction_scores, 223 | hidden_states=outputs.last_hidden_state, 224 | attentions=outputs.attentions, 225 | ) 226 | -------------------------------------------------------------------------------- /simplex-diffusion-main/sdlm/models/h3/ssm/h3.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from einops import rearrange 5 | 6 | from .ss_kernel import SSKernel 7 | 8 | try: 9 | from ..ops.fftconv import fftconv_func 10 | except ImportError: 11 | fftconv_func = None 12 | 13 | 14 | @torch.jit.script 15 | def mul_sum(q, y): 16 | return (q * y).sum(dim=1) 17 | 18 | 19 | class H3(nn.Module): 20 | def __init__( 21 | self, 22 | d_model, 23 | d_state=64, 24 | l_max=None, 25 | head_dim=1, 26 | use_fast_fftconv=False, 27 | dropout=0.0, # Just to absorb the kwarg 28 | layer_idx=None, 29 | device=None, 30 | dtype=None, 31 | # SSM Kernel arguments 32 | **kernel_args, 33 | ): 34 | """ 35 | d_state: the dimension of the state, also denoted by N 36 | l_max: the maximum kernel length, also denoted by L. Set l_max=None to always use a global kernel 37 | 38 | See the class .kernel.SSKernel for the kernel constructor which accepts kernel_args. 
Relevant options that are worth considering and tuning include "mode" + "measure", "dt_min", "dt_max", "lr" 39 | 40 | Other options are all experimental and should not need to be configured 41 | """ 42 | factory_kwargs = {"device": device, "dtype": dtype} 43 | super().__init__() 44 | self.d_model = d_model 45 | self.head_dim = head_dim 46 | assert d_model % head_dim == 0 47 | self.H = d_model // head_dim 48 | self.N = d_state 49 | self.L = l_max 50 | self.layer_idx = layer_idx 51 | self.use_fast_fftconv = use_fast_fftconv 52 | if self.use_fast_fftconv: 53 | assert fftconv_func is not None, "Need to install fftconv" 54 | 55 | self.q_proj = nn.Linear(self.d_model, self.d_model, **factory_kwargs) 56 | self.k_proj = nn.Linear(self.d_model, self.d_model, **factory_kwargs) 57 | self.v_proj = nn.Linear(self.d_model, self.d_model, **factory_kwargs) 58 | 59 | # TODO: SSKernel doesn't take device argument yet 60 | self.ssm_k_kernel = SSKernel( 61 | self.d_model, 62 | N=d_state, 63 | L=self.L, 64 | mode="shift", 65 | lr=kernel_args.get("lr", None), 66 | ) 67 | self.ssm_k_D = nn.Parameter(torch.randn(self.d_model)) 68 | # S4D Kernel 69 | self.kernel = SSKernel(self.H, N=self.N, L=self.L, channels=1, **kernel_args) 70 | self.D = nn.Parameter(torch.randn(self.H, **factory_kwargs)) 71 | 72 | # Pointwise 73 | # position-wise output transform to mix features 74 | # Don't use FusedDense since the layout is H first 75 | self.output_linear = nn.Linear(self.d_model, self.d_model) 76 | 77 | def forward(self, u, inference_params=None): 78 | """ 79 | u: (B L H) 80 | 81 | Returns: same shape as u 82 | """ 83 | L_og = u.size(-2) 84 | if self.use_fast_fftconv and L_og % 2 != 0: 85 | u = F.pad(u, (0, 0, 0, 1)) 86 | L = u.size(-2) 87 | 88 | use_fast_fftconv = self.use_fast_fftconv and inference_params is None 89 | 90 | state_k, state = None, None 91 | if inference_params is not None: 92 | assert self.layer_idx is not None 93 | if self.layer_idx not in inference_params.key_value_memory_dict: 94 | batch_shape = (u.shape[0] * self.head_dim * self.head_dim,) 95 | state_k = self.ssm_k_kernel.default_state(*batch_shape) 96 | state = self.kernel.default_state(*batch_shape) 97 | inference_params.key_value_memory_dict[self.layer_idx] = ( 98 | state_k, 99 | state, 100 | ) 101 | else: 102 | state_k, state = inference_params.key_value_memory_dict[self.layer_idx] 103 | if inference_params.sequence_len_offset == 0: 104 | self.ssm_k_kernel._setup_step() 105 | self.kernel._setup_step() 106 | 107 | if inference_params is not None and inference_params.sequence_len_offset > 0: 108 | y, next_state_k, next_state = self.step(u, state_k, state) 109 | inference_params.key_value_memory_dict[self.layer_idx][0].copy_( 110 | next_state_k 111 | ) 112 | inference_params.key_value_memory_dict[self.layer_idx][1].copy_(next_state) 113 | return y 114 | 115 | # Compute SS Kernel 116 | L_kernel = L if self.L is None else min(L, self.L) 117 | ssm_kernel, k_state = self.kernel( 118 | L=L_kernel, state=state, rate=1.0 119 | ) # (C H L) (B C H L) 120 | ssm_kernel = rearrange(ssm_kernel, "1 h l -> h l") 121 | 122 | u = rearrange(u, "b l h -> (b l) h") 123 | dtype = ( 124 | self.q_proj.weight.dtype 125 | if not torch.is_autocast_enabled() 126 | else torch.get_autocast_gpu_dtype() 127 | ) 128 | q = self.q_proj.weight @ u.T + self.q_proj.bias.to(dtype).unsqueeze(-1) 129 | k = self.k_proj.weight @ u.T + self.k_proj.bias.to(dtype).unsqueeze(-1) 130 | v = self.v_proj.weight @ u.T + self.v_proj.bias.to(dtype).unsqueeze(-1) 131 | q, k, v = [rearrange(x, "h (b l) -> 
b h l", l=L) for x in [q, k, v]] 132 | 133 | k_og = k 134 | ssm_k_kernel, _ = self.ssm_k_kernel( 135 | L=L_kernel, state=state_k, rate=1.0 136 | ) # (C H L) (B C H L) 137 | ssm_k_kernel = rearrange(ssm_k_kernel, "1 h l -> h l") 138 | if not use_fast_fftconv: 139 | fft_size = L_kernel + L 140 | ssm_k_kernel_f = torch.fft.rfft(ssm_k_kernel, n=fft_size) # (H 2L) 141 | k_f = torch.fft.rfft(k.to(ssm_kernel.dtype), n=fft_size) # (B H 2L) 142 | shift_k_out = torch.fft.irfft(ssm_k_kernel_f * k_f, n=fft_size)[..., :L] 143 | k = shift_k_out + rearrange(self.ssm_k_D, "h -> h 1") * k 144 | else: 145 | dropout_mask = None 146 | # No GeLU after the SSM 147 | # We want output_hbl=True so that k has the same layout as q and v for the next 148 | # fftconv 149 | k = fftconv_func( 150 | k, ssm_k_kernel, self.ssm_k_D, dropout_mask, False, False, True 151 | ) 152 | # This line below looks like it doesn't do anything, but it gets the stride right 153 | # for the case batch_size=1. In that case k has stride (L, L, 1), but q and v has 154 | # stride (H * L, L, 1). The two strides are equivalent because batch_size=1, but 155 | # the C++ code doesn't like that. 156 | k = rearrange(rearrange(k, "b h l -> h b l"), "h b l -> b h l") 157 | 158 | if not use_fast_fftconv: 159 | fft_size = L_kernel + L 160 | # kv = k * v 161 | kv = rearrange(k, "b (h d1) l -> b d1 1 h l", d1=self.head_dim) * rearrange( 162 | v, "b (h d2) l -> b 1 d2 h l", d2=self.head_dim 163 | ) # b d1 d2 h l 164 | kv_f = torch.fft.rfft(kv.to(dtype=ssm_kernel.dtype), n=fft_size) / fft_size 165 | ssm_kernel_f = torch.fft.rfft(ssm_kernel, n=fft_size) # h L+1 166 | y = torch.fft.irfft(kv_f * ssm_kernel_f, n=fft_size, norm="forward")[ 167 | ..., :L 168 | ] # b d1 d2 h l 169 | y = y + kv * self.D.unsqueeze(-1) # b d1 d2 h l 170 | q = rearrange(q, "b (h d1) l -> b d1 1 h l", d1=self.head_dim) 171 | # einsum is way slower than multiply and then sum. 172 | if self.head_dim > 1: 173 | y = mul_sum(y, q) 174 | y = rearrange(y, "b d h l -> b (d h) l") 175 | else: 176 | y = rearrange(y * q, "b 1 1 h l -> b h l") 177 | else: 178 | dropout_mask = None 179 | # No GeLU after the SSM 180 | # Set output_hbl_layout=True since we'll be doing a matmul right after 181 | y = fftconv_func( 182 | k, 183 | ssm_kernel, 184 | self.D, 185 | dropout_mask, 186 | False, 187 | torch.is_autocast_enabled(), 188 | True, 189 | v, 190 | self.head_dim, 191 | q, 192 | ) 193 | 194 | y = rearrange(y, "b h l -> b l h") 195 | 196 | if state is not None: 197 | assert inference_params is not None 198 | # TODO: This doesn't ever happen? 
199 | # if inference_params.sequence_len_offset > 0: 200 | # y = y + k_state 201 | inference_params.key_value_memory_dict[self.layer_idx][0].copy_( 202 | self.ssm_k_kernel.forward_state(k_og, state_k) 203 | ) 204 | inference_params.key_value_memory_dict[self.layer_idx][1].copy_( 205 | self.kernel.forward_state( 206 | rearrange(kv, "b d1 d2 h l -> (b d1 d2) h l"), state 207 | ) 208 | ) 209 | 210 | # y could be in fp32 because of the SSMs 211 | if not torch.is_autocast_enabled(): 212 | y = y.to(dtype=self.output_linear.weight.dtype) 213 | y = self.output_linear(y) 214 | if L_og < L: 215 | y = y[:, :L_og, :] 216 | 217 | return y 218 | 219 | def step(self, u, state_k, state): 220 | q, k, v = self.q_proj(u), self.k_proj(u), self.v_proj(u) 221 | shift_k, next_state_k = self.ssm_k_kernel.step( 222 | rearrange(k, "b 1 h -> b h"), state_k 223 | ) 224 | k = shift_k + k * self.ssm_k_D 225 | # kv = k * v 226 | kv = rearrange(k, "b 1 (h d1) -> b d1 1 h", d1=self.head_dim) * rearrange( 227 | v, "b 1 (h d2) -> b 1 d2 h", d2=self.head_dim 228 | ) # b d1 d2 h 229 | y, next_state = self.kernel.step( 230 | rearrange(kv, "b d1 d2 h -> (b d1 d2) h"), state 231 | ) 232 | y = ( 233 | rearrange( 234 | y, "(b d1 d2) 1 h -> b d1 d2 h", d1=self.head_dim, d2=self.head_dim 235 | ) 236 | + kv * self.D 237 | ) 238 | q = rearrange(q, "b 1 (h d1) -> b d1 1 h", d1=self.head_dim) 239 | if self.head_dim > 1: 240 | y = mul_sum(y, q) 241 | y = rearrange(y, "b d h l -> b (d h) l") 242 | else: 243 | y = rearrange(y * q, "b 1 1 h -> b 1 h") 244 | # y could be in fp32 because of the SSMs 245 | if not torch.is_autocast_enabled(): 246 | y = y.to(dtype=self.output_linear.weight.dtype) 247 | return self.output_linear(y), next_state_k, next_state 248 | -------------------------------------------------------------------------------- /simplex-diffusion-main/sdlm/models/h3/ssm/hippo.py: -------------------------------------------------------------------------------- 1 | # Copied from https://github.com/HazyResearch/state-spaces/blob/06dbbdfd0876501a7f12bf3262121badbc7658af/src/models/hippo/hippo.py 2 | 3 | """ Definitions of A and B matrices for various HiPPO operators. """ 4 | 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from einops import rearrange, repeat 10 | from opt_einsum import contract 11 | from scipy import special as ss 12 | 13 | 14 | def embed_c2r(A): 15 | A = rearrange(A, "... m n -> ... 
m () n ()") 16 | A = np.pad(A, ((0, 0), (0, 1), (0, 0), (0, 1))) + np.pad( 17 | A, ((0, 0), (1, 0), (0, 0), (1, 0)) 18 | ) 19 | return rearrange(A, "m x n y -> (m x) (n y)") 20 | 21 | 22 | # TODO take in 'torch' option to return torch instead of numpy, and converts the shape of B from (N, 1) to (N) 23 | def transition(measure, N, **measure_args): 24 | """A, B transition matrices for different measures 25 | measure: the type of measure 26 | legt - Legendre (translated) 27 | legs - Legendre (scaled) 28 | glagt - generalized Laguerre (translated) 29 | lagt, tlagt - previous versions of (tilted) Laguerre with slightly different normalization 30 | """ 31 | # Laguerre (translated) 32 | if measure == "lagt": 33 | b = measure_args.get("beta", 1.0) 34 | A = np.eye(N) / 2 - np.tril(np.ones((N, N))) 35 | B = b * np.ones((N, 1)) 36 | # Generalized Laguerre 37 | # alpha 0, beta small is most stable (limits to the 'lagt' measure) 38 | # alpha 0, beta 1 has transition matrix A = [lower triangular 1] 39 | elif measure == "glagt": 40 | alpha = measure_args.get("alpha", 0.0) 41 | beta = measure_args.get("beta", 0.01) 42 | A = -np.eye(N) * (1 + beta) / 2 - np.tril(np.ones((N, N)), -1) 43 | B = ss.binom(alpha + np.arange(N), np.arange(N))[:, None] 44 | 45 | L = np.exp( 46 | 0.5 * (ss.gammaln(np.arange(N) + alpha + 1) - ss.gammaln(np.arange(N) + 1)) 47 | ) 48 | A = (1.0 / L[:, None]) * A * L[None, :] 49 | B = ( 50 | (1.0 / L[:, None]) 51 | * B 52 | * np.exp(-0.5 * ss.gammaln(1 - alpha)) 53 | * beta ** ((1 - alpha) / 2) 54 | ) 55 | # Legendre (translated) 56 | elif measure == "legt": 57 | Q = np.arange(N, dtype=np.float64) 58 | R = (2 * Q + 1) ** 0.5 59 | j, i = np.meshgrid(Q, Q) 60 | A = R[:, None] * np.where(i < j, (-1.0) ** (i - j), 1) * R[None, :] 61 | B = R[:, None] 62 | A = -A 63 | 64 | # Halve again for timescale correctness 65 | A *= 0.5 66 | B *= 0.5 67 | # LMU: equivalent to LegT up to normalization 68 | elif measure == "lmu": 69 | Q = np.arange(N, dtype=np.float64) 70 | R = (2 * Q + 1)[:, None] # / theta 71 | j, i = np.meshgrid(Q, Q) 72 | A = np.where(i < j, -1, (-1.0) ** (i - j + 1)) * R 73 | B = (-1.0) ** Q[:, None] * R 74 | # Legendre (scaled) 75 | elif measure == "legs": 76 | q = np.arange(N, dtype=np.float64) 77 | col, row = np.meshgrid(q, q) 78 | r = 2 * q + 1 79 | M = -(np.where(row >= col, r, 0) - np.diag(q)) 80 | T = np.sqrt(np.diag(2 * q + 1)) 81 | A = T @ M @ np.linalg.inv(T) 82 | B = np.diag(T)[:, None] 83 | B = ( 84 | B.copy() 85 | ) # Otherwise "UserWarning: given NumPY array is not writeable..." after torch.as_tensor(B) 86 | elif measure == "legsd": 87 | q = np.arange(N, dtype=np.float64) 88 | col, row = np.meshgrid(q, q) 89 | r = 2 * q + 1 90 | M = -(np.where(row >= col, r, 0) - np.diag(q)) 91 | T = np.sqrt(np.diag(2 * q + 1)) 92 | A = T @ M @ np.linalg.inv(T) 93 | B = np.diag(T)[:, None] 94 | B = ( 95 | B.copy() 96 | ) # Otherwise "UserWarning: given NumPY array is not writeable..." 
after torch.as_tensor(B) 97 | A += 0.5 * B * B[None, :, 0] 98 | B = B / 2.0 99 | elif measure in ["fourier_diag", "foud"]: 100 | freqs = np.arange(N // 2) 101 | d = np.stack([freqs, np.zeros(N // 2)], axis=-1).reshape(-1)[:-1] 102 | A = 2 * np.pi * (-np.diag(d, 1) + np.diag(d, -1)) 103 | A = A - 0.5 * np.eye(N) 104 | B = np.zeros(N) 105 | B[0::2] = 2**0.5 106 | B[0] = 1 107 | B = B[:, None] 108 | elif measure in ["fourier", "fout"]: 109 | freqs = np.arange(N // 2) 110 | d = np.stack([np.zeros(N // 2), freqs], axis=-1).reshape(-1)[1:] 111 | A = np.pi * (-np.diag(d, 1) + np.diag(d, -1)) 112 | B = np.zeros(N) 113 | B[0::2] = 2**0.5 114 | B[0] = 1 115 | 116 | # Subtract off rank correction - this corresponds to the other endpoint u(t-1) in this case 117 | A = A - B[:, None] * B[None, :] 118 | B = B[:, None] 119 | elif measure == "fourier_decay": 120 | freqs = np.arange(N // 2) 121 | d = np.stack([np.zeros(N // 2), freqs], axis=-1).reshape(-1)[1:] 122 | A = np.pi * (-np.diag(d, 1) + np.diag(d, -1)) 123 | B = np.zeros(N) 124 | B[0::2] = 2**0.5 125 | B[0] = 1 126 | 127 | # Subtract off rank correction - this corresponds to the other endpoint u(t-1) in this case 128 | A = A - 0.5 * B[:, None] * B[None, :] 129 | B = 0.5 * B[:, None] 130 | elif measure == "fourier2": # Double everything: orthonormal on [0, 1] 131 | freqs = 2 * np.arange(N // 2) 132 | d = np.stack([np.zeros(N // 2), freqs], axis=-1).reshape(-1)[1:] 133 | A = np.pi * (-np.diag(d, 1) + np.diag(d, -1)) 134 | B = np.zeros(N) 135 | B[0::2] = 2**0.5 136 | B[0] = 1 137 | 138 | # Subtract off rank correction - this corresponds to the other endpoint u(t-1) in this case 139 | A = A - B[:, None] * B[None, :] * 2 140 | B = B[:, None] * 2 141 | elif measure == "random": 142 | A = np.random.randn(N, N) / N 143 | B = np.random.randn(N, 1) 144 | elif measure == "diagonal": 145 | A = -np.diag(np.exp(np.random.randn(N))) 146 | B = np.random.randn(N, 1) 147 | else: 148 | raise NotImplementedError 149 | 150 | return A, B 151 | 152 | 153 | def rank_correction(measure, N, rank=1, dtype=torch.float): 154 | """Return low-rank matrix L such that A + L is normal""" 155 | 156 | if measure == "legs": 157 | assert rank >= 1 158 | P = torch.sqrt(0.5 + torch.arange(N, dtype=dtype)).unsqueeze(0) # (1 N) 159 | elif measure == "legt": 160 | assert rank >= 2 161 | P = torch.sqrt(1 + 2 * torch.arange(N, dtype=dtype)) # (N) 162 | P0 = P.clone() 163 | P0[0::2] = 0.0 164 | P1 = P.clone() 165 | P1[1::2] = 0.0 166 | P = torch.stack([P0, P1], dim=0) # (2 N) 167 | P *= 2 ** ( 168 | -0.5 169 | ) # Halve the rank correct just like the original matrix was halved 170 | elif measure == "lagt": 171 | assert rank >= 1 172 | P = 0.5**0.5 * torch.ones(1, N, dtype=dtype) 173 | elif measure in ["fourier", "fout"]: 174 | P = torch.zeros(N) 175 | P[0::2] = 2**0.5 176 | P[0] = 1 177 | P = P.unsqueeze(0) 178 | elif measure == "fourier_decay": 179 | P = torch.zeros(N) 180 | P[0::2] = 2**0.5 181 | P[0] = 1 182 | P = P.unsqueeze(0) 183 | P = P / 2**0.5 184 | elif measure == "fourier2": 185 | P = torch.zeros(N) 186 | P[0::2] = 2**0.5 187 | P[0] = 1 188 | P = 2**0.5 * P.unsqueeze(0) 189 | elif measure in ["fourier_diag", "foud", "legsd"]: 190 | P = torch.zeros(1, N, dtype=dtype) 191 | else: 192 | raise NotImplementedError 193 | 194 | d = P.size(0) 195 | if rank > d: 196 | P = torch.cat([P, torch.zeros(rank - d, N, dtype=dtype)], dim=0) # (rank N) 197 | return P 198 | 199 | 200 | def initial_C(measure, N, dtype=torch.float): 201 | """Return C that captures the other endpoint in the HiPPO 
approximation""" 202 | 203 | if measure == "legt": 204 | C = (torch.arange(N, dtype=dtype) * 2 + 1) ** 0.5 * (-1) ** torch.arange(N) 205 | elif measure == "fourier": 206 | C = torch.zeros(N) 207 | C[0::2] = 2**0.5 208 | C[0] = 1 209 | else: 210 | C = torch.zeros(N, dtype=dtype) # (N) 211 | 212 | return C 213 | 214 | 215 | def nplr(measure, N, rank=1, dtype=torch.float, diagonalize_precision=True): 216 | """Return w, p, q, V, B such that 217 | (w - p q^*, B) is unitarily equivalent to the original HiPPO A, B by the matrix V 218 | i.e. A = V[w - p q^*]V^*, B = V B 219 | """ 220 | assert dtype == torch.float or dtype == torch.double 221 | cdtype = torch.cfloat if dtype == torch.float else torch.cdouble 222 | 223 | A, B = transition(measure, N) 224 | A = torch.as_tensor(A, dtype=dtype) # (N, N) 225 | B = torch.as_tensor(B, dtype=dtype)[:, 0] # (N,) 226 | 227 | P = rank_correction(measure, N, rank=rank, dtype=dtype) # (r N) 228 | AP = A + torch.sum(P.unsqueeze(-2) * P.unsqueeze(-1), dim=-3) 229 | 230 | # We require AP to be nearly skew-symmetric 231 | _A = AP + AP.transpose(-1, -2) 232 | if ( 233 | err := torch.sum((_A - _A[0, 0] * torch.eye(N)) ** 2) / N 234 | ) > 1e-5: # if not torch.allclose(_A - _A[0,0]*torch.eye(N), torch.zeros(N, N), atol=1e-5): 235 | print("WARNING: HiPPO matrix not skew symmetric", err) 236 | 237 | # Take advantage of identity + skew-symmetric form to calculate real and imaginary parts separately 238 | # Imaginary part can use eigh instead of eig 239 | w_re = torch.mean(torch.diagonal(AP), -1, keepdim=True) 240 | 241 | # Diagonalize in double precision 242 | if diagonalize_precision: 243 | AP = AP.to(torch.double) 244 | # w, V = torch.linalg.eig(AP) # (..., N) (..., N, N) 245 | w_im, V = torch.linalg.eigh(AP * -1j) # (..., N) (..., N, N) 246 | if diagonalize_precision: 247 | w_im, V = w_im.to(cdtype), V.to(cdtype) 248 | w = w_re + 1j * w_im 249 | # Check: V w V^{-1} = A 250 | # print("check", V @ torch.diag_embed(w) @ V.conj().transpose(-1, -2)) 251 | 252 | # Only keep half of each conjugate pair 253 | _, idx = torch.sort(w.imag) 254 | w_sorted = w[idx] 255 | V_sorted = V[:, idx] 256 | 257 | # There is an edge case when eigenvalues can be 0, which requires some machinery to handle 258 | # We use a huge hack here: Assume only one pair is 0, and that it is the first row/column of A (only happens in Fourier case) 259 | V = V_sorted[:, : N // 2] 260 | w = w_sorted[: N // 2] 261 | assert w[-2].abs() > 1e-4, "Only 1 zero eigenvalue allowed in diagonal part of A" 262 | if w[-1].abs() < 1e-4: 263 | V[:, -1] = 0.0 264 | V[0, -1] = 2**-0.5 265 | V[1, -1] = 2**-0.5 * 1j 266 | 267 | _AP = V @ torch.diag_embed(w) @ V.conj().transpose(-1, -2) 268 | if (err := torch.sum((2 * _AP.real - AP) ** 2) / N) > 1e-5: 269 | print( 270 | "Warning: Diagonalization of A matrix not numerically precise - error", err 271 | ) 272 | # print("check", V @ torch.diag_embed(w) @ V.conj().transpose(-1, -2)) 273 | 274 | V_inv = V.conj().transpose(-1, -2) 275 | 276 | # C = initial_C(measure, N, dtype=dtype) 277 | B = contract("ij, j -> i", V_inv, B.to(V)) # V^* B 278 | # C = contract('ij, j -> i', V_inv, C.to(V)) # V^* C 279 | P = contract("ij, ...j -> ...i", V_inv, P.to(V)) # V^* P 280 | 281 | # return w, P, B, C, V 282 | return w, P, B, V 283 | -------------------------------------------------------------------------------- /simplex-diffusion-main/sdlm/data/data_collator.py: -------------------------------------------------------------------------------- 1 | import random 2 | from dataclasses import 
dataclass 3 | from enum import Enum 4 | from random import choices 5 | from typing import Any, Dict, List, Optional, Union 6 | 7 | import numpy as np 8 | import torch 9 | from transformers.tokenization_utils_base import PreTrainedTokenizerBase 10 | from transformers.utils import PaddingStrategy 11 | 12 | from data.preprocessors import ( 13 | gpt_span_mask_batch, 14 | insert_extra_paddings, 15 | t5_random_spans_mask_batch, 16 | ) 17 | 18 | 19 | class Objective(Enum): 20 | # Prefix language modeling like GPT style pretraining. 21 | prefix = 1 22 | # T5 objective with a range of 2 to 5 tokens as the span length, which masks about 15% of input tokens. 23 | t5 = 2 24 | # Aggressive denoising where approximately 50% of the input sequence is masked. 25 | aggressive_t5 = 3 26 | # Unconditional generation case. 27 | unconditional = 4 28 | 29 | 30 | # TODO: automate this one. 31 | # TODO: these are for sequence length of 100, adapt for 200. 32 | OBJECTIVE_SETTINGS = { 33 | Objective.t5: [ 34 | {"mask_ratio": 0.15, "mean_mask_span_length": 8}, 35 | {"mask_ratio": 0.15, "mean_mask_span_length": 3}, 36 | ], 37 | Objective.aggressive_t5: [ 38 | {"mask_ratio": 0.5, "mean_mask_span_length": 8}, 39 | {"mask_ratio": 0.5, "mean_mask_span_length": 3}, 40 | {"mask_ratio": 0.5, "mean_mask_span_length": 48}, 41 | ], 42 | } 43 | 44 | 45 | @dataclass 46 | class SpanInfillingDataCollator: 47 | """ 48 | Data collator that will dynamically pad the inputs received. 49 | Args: 50 | tokenizer ([`PreTrainedTokenizer`] or [`PreTrainedTokenizerFast`]): 51 | The tokenizer used for encoding the data. 52 | padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`): 53 | Select a strategy to pad the returned sequences (according to the model's padding side and padding index) 54 | among: 55 | - `True` or `'longest'` (default): Pad to the longest sequence in the batch (or no padding if only a single 56 | sequence is provided). 57 | - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum 58 | acceptable input length for the model if that argument is not provided. 59 | - `False` or `'do_not_pad'`: No padding (i.e., can output a batch with sequences of different lengths). 60 | max_length (`int`, *optional*): 61 | Maximum length of the returned list and optionally padding length (see above). 62 | pad_to_multiple_of (`int`, *optional*): 63 | If set will pad the sequence to a multiple of the provided value. 64 | This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >= 65 | 7.5 (Volta). 66 | return_tensors (`str`): 67 | The type of Tensor to return. Allowable values are "np", "pt" and "tf".
68 | """ 69 | 70 | def __init__( 71 | self, 72 | mode, 73 | data_args, 74 | tokenizer: PreTrainedTokenizerBase, 75 | padding: Union[bool, str, PaddingStrategy] = True, 76 | max_length: Optional[int] = None, 77 | pad_to_multiple_of: Optional[int] = None, 78 | return_tensors: str = "pt", 79 | seed: int = 42, 80 | eval_context_size: int = None, 81 | ): 82 | self.tokenizer = tokenizer 83 | self.padding = padding 84 | self.max_length = max_length 85 | self.pad_to_multiple_of = pad_to_multiple_of 86 | self.return_tensors = return_tensors 87 | self.conditional_generation = data_args.conditional_generation 88 | self.extra_padding_ratio = data_args.extra_padding_ratio 89 | self.rng = np.random.default_rng(seed) 90 | self.eval_context_size = eval_context_size 91 | self.mode = mode 92 | if self.conditional_generation == "ul2_with_unconditional" and mode == "train": 93 | self.mask_generator = {} 94 | self.mask_generator[ 95 | Objective.t5 96 | ] = lambda batch, setting: t5_random_spans_mask_batch( 97 | batch, **setting, rng=self.rng 98 | ) 99 | self.mask_generator[ 100 | Objective.aggressive_t5 101 | ] = lambda batch, setting: t5_random_spans_mask_batch( 102 | batch, **setting, rng=self.rng 103 | ) 104 | self.mask_generator[Objective.prefix] = lambda batch: gpt_span_mask_batch( 105 | batch 106 | ) 107 | self.mask_generator[Objective.unconditional] = lambda batch: None 108 | elif self.conditional_generation == "span_infilling": 109 | self.mask_generator = lambda batch: t5_random_spans_mask_batch( 110 | batch, data_args.mask_ratio, data_args.mean_mask_span_length, self.rng 111 | ) 112 | elif self.conditional_generation == "prefix_lm": 113 | self.mask_generator = lambda batch: gpt_span_mask_batch( 114 | batch, 115 | use_half_length_as_prefix_size=(mode == "eval"), 116 | eval_context_size=eval_context_size, 117 | ) 118 | elif self.conditional_generation == "ul2" and mode == "train": 119 | self.mask_generator = {} 120 | self.mask_generator[ 121 | Objective.t5 122 | ] = lambda batch, setting: t5_random_spans_mask_batch( 123 | batch, **setting, rng=self.rng 124 | ) 125 | self.mask_generator[ 126 | Objective.aggressive_t5 127 | ] = lambda batch, setting: t5_random_spans_mask_batch( 128 | batch, **setting, rng=self.rng 129 | ) 130 | self.mask_generator[Objective.prefix] = lambda batch: gpt_span_mask_batch( 131 | batch 132 | ) 133 | elif self.conditional_generation == "ul2_variable" and mode == "train": 134 | self.mask_generator = {} 135 | self.mask_generator[ 136 | Objective.t5 137 | ] = lambda batch, mask_ratio, mean_mask_span_length: t5_random_spans_mask_batch( 138 | batch, 139 | mask_ratio=mask_ratio, 140 | mean_mask_span_length=mean_mask_span_length, 141 | rng=self.rng, 142 | ) 143 | self.mask_generator[Objective.prefix] = lambda batch: gpt_span_mask_batch( 144 | batch 145 | ) 146 | 147 | def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]: 148 | 149 | if self.extra_padding_ratio: 150 | # Inserting random tokens uniformly, we do not modify start and end of 151 | # sequence tokens. 
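            # Illustrative note (hypothetical numbers; the exact behavior is defined in
            # data.preprocessors.insert_extra_paddings): with extra_padding_ratio=0.1, a
            # 10-token example gains roughly one extra pad_token_id spliced in at a random
            # interior position, while the start and end of sequence tokens stay in place.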
152 | for i in range(len(features)): 153 | features[i]["input_ids"] = insert_extra_paddings( 154 | self.rng, 155 | features[i]["input_ids"], 156 | self.tokenizer.pad_token_id, 157 | self.extra_padding_ratio, 158 | ) 159 | 160 | masks = {} 161 | if self.conditional_generation in ["span_infilling", "prefix_lm"]: 162 | masks = {"span_mask": self.mask_generator(features)} 163 | elif ( 164 | self.conditional_generation == "ul2_with_unconditional" 165 | and self.mode == "train" 166 | ): 167 | objectives = [ 168 | Objective.unconditional, 169 | Objective.t5, 170 | Objective.prefix, 171 | Objective.aggressive_t5, 172 | ] 173 | weights = [0.25, 0.25, 0.25, 0.25] 174 | objective = choices(objectives, weights)[0] 175 | if objective in [Objective.t5, Objective.aggressive_t5]: 176 | setting = choices(OBJECTIVE_SETTINGS[objective])[0] 177 | masks = {"span_mask": self.mask_generator[objective](features, setting)} 178 | else: 179 | masks = {"span_mask": self.mask_generator[objective](features)} 180 | elif self.conditional_generation == "ul2" and self.mode == "train": 181 | objectives = [Objective.t5, Objective.prefix, Objective.aggressive_t5] 182 | weights = [0.25, 0.25, 0.25] 183 | objective = choices(objectives, weights)[0] 184 | if objective in [Objective.t5, Objective.aggressive_t5]: 185 | setting = choices(OBJECTIVE_SETTINGS[objective])[0] 186 | masks = {"span_mask": self.mask_generator[objective](features, setting)} 187 | else: 188 | masks = {"span_mask": self.mask_generator[objective](features)} 189 | elif self.conditional_generation == "ul2_variable" and self.mode == "train": 190 | objectives = [Objective.t5, Objective.prefix] 191 | weights = [0.5, 0.5] 192 | objective = choices(objectives, weights)[0] 193 | if objective == objective.t5: 194 | # Here we assume the length is the same for all data in a batch. 195 | length = len(features[0]["input_ids"]) 196 | min_ratio = 1.0 / length 197 | mask_ratio = random.uniform(min_ratio, 0.5) 198 | mean_mask_span_length = int(random.uniform(1, mask_ratio * length)) 199 | masks = { 200 | "span_mask": self.mask_generator[objective]( 201 | features, mask_ratio, mean_mask_span_length 202 | ) 203 | } 204 | else: 205 | masks = {"span_mask": self.mask_generator[objective](features)} 206 | elif self.mode == "eval" and self.conditional_generation in [ 207 | "ul2", 208 | "ul2_with_unconditional", 209 | "ul2_variable", 210 | ]: 211 | masks = { 212 | "span_mask": gpt_span_mask_batch( 213 | features, 214 | use_half_length_as_prefix_size=True, 215 | eval_context_size=self.eval_context_size, 216 | ) 217 | } 218 | batch = self.tokenizer.pad( 219 | features, 220 | padding=self.padding, 221 | max_length=self.max_length, 222 | pad_to_multiple_of=self.pad_to_multiple_of, 223 | return_tensors=self.return_tensors, 224 | ) 225 | if "attention_mask" in batch: 226 | del batch["attention_mask"] 227 | return {**batch, **masks} 228 | 229 | 230 | @dataclass 231 | class DataCollatorForSeq2Seq: 232 | """ 233 | Data collator that will dynamically pad the inputs received, as well as the labels. 234 | Args: 235 | tokenizer ([`PreTrainedTokenizer`] or [`PreTrainedTokenizerFast`]): 236 | The tokenizer used for encoding the data. 237 | padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`): 238 | Select a strategy to pad the returned sequences (according to the model's padding side and padding index) 239 | among: 240 | - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single sequence 241 | is provided). 
242 | - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum 243 | acceptable input length for the model if that argument is not provided. 244 | - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different 245 | lengths). 246 | max_length (`int`, *optional*): 247 | Maximum length of the returned list and optionally padding length (see above). 248 | pad_to_multiple_of (`int`, *optional*): 249 | If set will pad the sequence to a multiple of the provided value. 250 | This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >= 251 | 7.5 (Volta). 252 | """ 253 | 254 | tokenizer: PreTrainedTokenizerBase 255 | padding: Union[bool, str, PaddingStrategy] = True 256 | max_length: Optional[int] = None 257 | pad_to_multiple_of: Optional[int] = None 258 | 259 | def __call__(self, features): 260 | input_ids = [feature["input_ids"] for feature in features] 261 | labels = [feature["labels"] for feature in features] 262 | input_target = [input + target for input, target in zip(input_ids, labels)] 263 | features = self.tokenizer.pad( 264 | {"input_ids": input_target}, 265 | padding=self.padding, 266 | max_length=self.max_length, 267 | pad_to_multiple_of=self.pad_to_multiple_of, 268 | return_tensors="pt", 269 | ) 270 | # The diffusion model does not consume an attention mask; drop it and mark padded positions via span_mask below. 271 | features.pop('attention_mask') 272 | batch_length = features["input_ids"].shape[1] 273 | masks = [ 274 | len(input) * [False] + (batch_length - len(input)) * [True] 275 | for input in input_ids 276 | ] 277 | features["span_mask"] = torch.tensor(masks) 278 | return features 279 | -------------------------------------------------------------------------------- /simplex-diffusion-main/sdlm/models/h3/ssm/ss_kernel_diag.py: -------------------------------------------------------------------------------- 1 | # TD: [2023-01-05]: Extracted the SSKernelDiag class from 2 | # https://github.com/HazyResearch/state-spaces/blob/06dbbdfd0876501a7f12bf3262121badbc7658af/src/models/sequence/ss/kernel.py 3 | # We make a small change to use the log_vandermonde CUDA code. 4 | 5 | """SSKernelDiag is the S4D kernel, a simpler algorithm for computing the kernel for the case of diagonal state matrices A. 6 | """ 7 | import math 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | from einops import rearrange, repeat 13 | from opt_einsum import contract 14 | 15 | from ..utils.utils import get_logger 16 | from .ssm_utils import OptimModule 17 | 18 | log = get_logger(__name__) 19 | 20 | # This could be None if the CUDA import fails 21 | from ..ops.vandermonde import log_vandermonde_fast 22 | 23 | try: 24 | import pykeops 25 | 26 | from ..ops.vandermonde import log_vandermonde, log_vandermonde_transpose 27 | 28 | has_pykeops = True 29 | log.info("Pykeops installation found.") 30 | except ImportError: 31 | has_pykeops = False 32 | from ..ops.vandermonde import log_vandermonde_naive as log_vandermonde 33 | from ..ops.vandermonde import ( 34 | log_vandermonde_transpose_naive as log_vandermonde_transpose, 35 | ) 36 | 37 | log.warning( 38 | "Falling back on slow Vandermonde kernel. Install pykeops for improved memory efficiency."
39 | ) 40 | 41 | 42 | _c2r = torch.view_as_real 43 | _r2c = torch.view_as_complex 44 | 45 | if tuple(map(int, torch.__version__.split(".")[:2])) >= (1, 10): 46 | _resolve_conj = lambda x: x.conj().resolve_conj() 47 | else: 48 | _resolve_conj = lambda x: x.conj() 49 | 50 | 51 | class SSKernelDiag(OptimModule): 52 | """Version using (complex) diagonal state matrix (S4D)""" 53 | 54 | def __init__( 55 | self, 56 | A, 57 | B, 58 | C, 59 | log_dt, 60 | L=None, 61 | disc="bilinear", 62 | real_type="exp", 63 | lr=None, 64 | bandlimit=None, 65 | force_real=False, 66 | ): 67 | 68 | super().__init__() 69 | self.L = L 70 | self.disc = disc 71 | self.bandlimit = bandlimit 72 | self.real_type = real_type 73 | self.force_real = force_real 74 | 75 | # Rank of low-rank correction 76 | assert A.size(-1) == C.size(-1) 77 | self.H = log_dt.size(-1) 78 | self.N = A.size(-1) 79 | assert A.size(-2) == B.size(-2) # Number of independent SSMs trained 80 | assert self.H % A.size(-2) == 0 81 | self.n_ssm = A.size(-2) 82 | self.repeat = self.H // A.size(0) 83 | 84 | self.channels = C.shape[0] 85 | self.C = nn.Parameter(_c2r(_resolve_conj(C))) 86 | 87 | # Register parameters 88 | if lr is None or isinstance(lr, float): 89 | lr_dict = {} 90 | else: 91 | lr_dict, lr = lr, None 92 | 93 | self.register("log_dt", log_dt, lr_dict.get("dt", lr)) 94 | self.register("B", _c2r(B), lr_dict.get("B", lr)) 95 | self.register("inv_A_real", self._A_init(A.real), lr_dict.get("A", lr)) 96 | self.register("A_imag", A.imag, lr_dict.get("A", lr)) 97 | 98 | def _A_init(self, A_real): 99 | A_real = torch.clamp(A_real, max=-1e-4) 100 | if self.real_type == "none": 101 | return -A_real 102 | elif self.real_type == "exp": 103 | return torch.log(-A_real) # Some of the HiPPO methods have real part 0 104 | elif self.real_type == "relu": 105 | return -A_real 106 | elif self.real_type == "sigmoid": 107 | return torch.logit(-A_real) 108 | elif self.real_type == "softplus": 109 | return torch.log(torch.exp(-A_real) - 1) 110 | else: 111 | raise NotImplementedError 112 | 113 | def _A(self): 114 | # Get the internal A (diagonal) parameter 115 | if self.real_type == "none": 116 | A_real = -self.inv_A_real 117 | elif self.real_type == "exp": 118 | A_real = -torch.exp(self.inv_A_real) 119 | elif self.real_type == "relu": 120 | # JAX version seems to NaN if you allow 0's, although this code was fine without it 121 | A_real = -F.relu(self.inv_A_real) - 1e-4 122 | elif self.real_type == "sigmoid": 123 | A_real = -F.sigmoid(self.inv_A_real) 124 | elif self.real_type == "softplus": 125 | A_real = -F.softplus(self.inv_A_real) 126 | else: 127 | raise NotImplementedError 128 | A = A_real + 1j * self.A_imag 129 | return A 130 | 131 | def forward(self, L, state=None, rate=1.0, u=None): 132 | """ 133 | state: (B, H, N) initial state 134 | rate: sampling rate factor 135 | L: target length 136 | returns: 137 | (C, H, L) convolution kernel (generally C=1) 138 | (B, H, L) output from initial state 139 | """ 140 | 141 | dt = torch.exp(self.log_dt) * rate # (H) 142 | C = _r2c(self.C) # (C H N) 143 | A = self._A() # (H N) 144 | 145 | B = _r2c(self.B) 146 | B = repeat(B, "t n -> 1 (v t) n", v=self.repeat) 147 | 148 | # Force A to be real valued, so the whole kernel can be interpreted as a "multi-head EMA" 149 | if self.force_real: 150 | A = A.real + 0j 151 | 152 | if self.bandlimit is not None: 153 | freqs = dt[:, None] / rate * A.imag.abs() / (2 * math.pi) # (H, N) 154 | mask = torch.where(freqs < self.bandlimit * 0.5, 1, 0) 155 | C = C * mask 156 | 157 | # Incorporate dt
into A 158 | A = repeat(A, "t n -> (v t) n", v=self.repeat) 159 | dtA = A * dt.unsqueeze(-1) # (H N) 160 | 161 | # Augment B with state 162 | if state is not None: 163 | s = state / dt.unsqueeze(-1) 164 | if self.disc == "bilinear": 165 | s = s * (1.0 + dtA / 2) 166 | elif self.disc == "zoh": 167 | s = s * dtA * dtA.exp() / (dtA.exp() - 1.0) 168 | B = torch.cat([s, B], dim=-3) # (1+B H N) 169 | 170 | C = (B[:, None, :, :] * C).view(-1, self.H, self.N) 171 | if self.disc == "zoh": 172 | # Power up 173 | C = C * (torch.exp(dtA) - 1.0) / A 174 | # TODO (TD): make it work for C.shape[0] > 1 175 | if log_vandermonde_fast is not None and C.shape[0] == 1: 176 | K = log_vandermonde_fast(C.squeeze(0), dtA, L).unsqueeze(0) # (H L) 177 | else: 178 | K = log_vandermonde(C, dtA, L) # (H L) 179 | elif self.disc == "bilinear": 180 | C = C * (1.0 - dtA / 2).reciprocal() * dt.unsqueeze(-1) # or * dtA / A 181 | dA = (1.0 + dtA / 2) / (1.0 - dtA / 2) 182 | if log_vandermonde_fast is not None: 183 | dA_log = repeat(dA.log(), "h d -> (c h) d", c=C.shape[0]) 184 | K = rearrange( 185 | log_vandermonde_fast(rearrange(C, "c h d -> (c h) d"), dA_log, L), 186 | "(c h) d -> c h d", 187 | c=C.shape[0], 188 | ) 189 | else: 190 | K = log_vandermonde(C, dA.log(), L) 191 | elif self.disc == "dss": 192 | # Implementation from DSS meant for case when real eigenvalues can be positive 193 | P = dtA.unsqueeze(-1) * torch.arange(L, device=C.device) # [H N L] 194 | A_gt_0 = A.real > 0 # [N] 195 | if A_gt_0.any(): 196 | with torch.no_grad(): 197 | P_max = dtA * (A_gt_0 * (L - 1)) # [H N] 198 | P = P - P_max.unsqueeze(-1) # [H N L] 199 | S = P.exp() # [H N L] 200 | 201 | dtA_neg = dtA * (1 - 2 * A_gt_0) # [H N] 202 | num = dtA_neg.exp() - 1 # [H N] 203 | den = (dtA_neg * L).exp() - 1 # [H N] 204 | 205 | # Inline reciprocal function for DSS logic 206 | x = den * A 207 | x_conj = _resolve_conj(x) 208 | r = x_conj / (x * x_conj + 1e-7) 209 | 210 | C = C * num * r # [C H N] 211 | K = contract("chn,hnl->chl", C, S).float() 212 | else: 213 | assert False, f"{self.disc} not supported" 214 | 215 | K = K.view(-1, self.channels, self.H, L) # (1+B C H L) 216 | if state is not None: 217 | K_state = K[:-1, :, :, :] # (B C H L) 218 | else: 219 | K_state = None 220 | K = K[-1, :, :, :] # (C H L) 221 | return K, K_state 222 | 223 | def _setup_step(self): 224 | # These methods are organized like this to be compatible with the NPLR kernel interface 225 | dt = torch.exp(self.log_dt) # (H) 226 | B = _r2c(self.B) # (H N) 227 | C = _r2c(self.C) # (C H N) 228 | self.dC = C 229 | A = self._A() # (H N) 230 | 231 | A = repeat(A, "t n -> (v t) n", v=self.repeat) 232 | B = repeat(B, "t n -> (v t) n", v=self.repeat) 233 | 234 | # Incorporate dt into A 235 | dtA = A * dt.unsqueeze(-1) # (H N) 236 | if self.disc == "zoh": 237 | self.dA = torch.exp(dtA) # (H N) 238 | self.dB = B * (torch.exp(dtA) - 1.0) / A # (C H N) 239 | elif self.disc == "bilinear": 240 | self.dA = (1.0 + dtA / 2) / (1.0 - dtA / 2) 241 | self.dB = ( 242 | B * (1.0 - dtA / 2).reciprocal() * dt.unsqueeze(-1) 243 | ) # or * dtA / A 244 | 245 | def default_state(self, *batch_shape): 246 | C = _r2c(self.C) 247 | state = torch.zeros( 248 | *batch_shape, self.H, self.N, dtype=C.dtype, device=C.device 249 | ) 250 | return state 251 | 252 | def step(self, u, state): 253 | next_state = contract("h n, b h n -> b h n", self.dA, state) + contract( 254 | "h n, b h -> b h n", self.dB, u 255 | ) 256 | y = contract("c h n, b h n -> b c h", self.dC, next_state) 257 | return 2 * y.real, next_state 258 | 259 | 
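    # The step() method above applies one tick of the discretized recurrence
    #   x_{k+1} = dA * x_k + dB * u_k,   y_k = 2 * Re(dC x_{k+1}),
    # where the factor of 2 restores the contribution of the discarded conjugate
    # eigenvalue pairs. forward_state() below instead advances an initial state
    # through a whole input sequence in closed form via a Vandermonde product,
    # avoiding an explicit loop over step().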
def forward_state(self, u, state): 260 | self._setup_step() 261 | AL = self.dA ** u.size(-1) 262 | u = u.flip(-1).to(self.dA).contiguous() # (B H L) 263 | v = log_vandermonde_transpose(u, self.dB, self.dA.log(), u.size(-1)) 264 | next_state = AL * state + v 265 | return next_state 266 | 267 | 268 | class EMAKernel(OptimModule): 269 | """Translation of Mega's MultiHeadEMA. 270 | This is a minimal implementation of the convolution kernel part of the module. 271 | This module, together with the main S4 block in src.models.sequence.ss.s4 272 | (which is really just a fft-conv wrapper around any convolution kernel, 273 | such as this one), should be exactly equivalent to using the original Mega 274 | EMA module in src.models.sequence.ss.ema. 275 | Two additional flags have been provided to resolve discrepancies in parameter 276 | count between S4(D) and EMA 277 | - `dt_tie` makes the shape of the step size \Delta (H, 1) instead of (H, N) 278 | - `efficient_bidirectional` ties the A/B/dt parameters for the conv kernels 279 | in both forwards and backwards directions. This should have exactly the same 280 | speed, slightly more parameter efficiency, and unchanged performance. 281 | """ 282 | 283 | def __init__( 284 | self, 285 | H, 286 | N=2, 287 | channels=1, 288 | l_max=None, 289 | dt_tie=False, 290 | efficient_bidirectional=False, 291 | ): 292 | super().__init__() 293 | 294 | self.H = H 295 | self.N = N 296 | self.channels = channels 297 | self.l_max = l_max 298 | self.scale = math.sqrt(1.0 / self.N) 299 | 300 | # Exactly match the parameter count of S4(D) when bidirectional is on 301 | self.efficient_bidirectional = efficient_bidirectional 302 | if self.efficient_bidirectional: 303 | H_C = H * channels 304 | else: 305 | H *= channels 306 | H_C = H 307 | 308 | self.delta = nn.Parameter(torch.Tensor(H, 1 if dt_tie else N, 1)) 309 | self.alpha = nn.Parameter(torch.Tensor(H, N, 1)) 310 | self.beta = nn.Parameter(torch.Tensor(H, N, 1)) 311 | self.gamma = nn.Parameter(torch.Tensor(H_C, N)) 312 | # self.omega = nn.Parameter(torch.Tensor(H)) # D skip connection handled by outside class 313 | 314 | self.reset_parameters() 315 | 316 | def reset_parameters(self): 317 | with torch.no_grad(): 318 | nn.init.normal_(self.delta, mean=0.0, std=0.2) 319 | nn.init.normal_(self.alpha, mean=0.0, std=0.2) 320 | # Mega comment: beta [1, -1, 1, -1, ...] seems more stable.
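            # The lines below build that alternating pattern: start from a column of ones,
            # flip every odd index to -1 (e.g. N=4 gives [1, -1, 1, -1]), and initialize
            # beta as small Gaussian noise centered on this pattern.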
321 | val = torch.ones(self.N, 1) 322 | if self.N > 1: 323 | idx = torch.tensor(list(range(1, self.N, 2))) 324 | val.index_fill_(0, idx, -1.0) 325 | self.beta.normal_(mean=0.0, std=0.02).add_(val) 326 | nn.init.normal_(self.gamma, mean=0.0, std=1.0) 327 | # nn.init.normal_(self.omega, mean=0.0, std=1.0) 328 | 329 | def coeffs(self): # Same as discretize 330 | p = torch.sigmoid(self.delta) # (H N 1) 331 | alpha = torch.sigmoid(self.alpha) 332 | q = 1.0 - p * alpha 333 | return p, q 334 | 335 | def forward(self, L=None, state=None, rate=1.0): 336 | L = L if self.l_max is None else min(self.l_max, L) 337 | p, q = self.coeffs() # (H N 1) 338 | vander = torch.arange(L).to(p).view(1, 1, L) * torch.log(q) # (H N L) 339 | kernel = (p * self.beta) * torch.exp(vander) 340 | if self.efficient_bidirectional: 341 | C = rearrange(self.gamma * self.scale, "(c h) n -> c h n", c=self.channels) 342 | kernel = torch.einsum("dnl,cdn->cdl", kernel, C) 343 | # kernel = rearrange(kernel, 'c d l -> (c d) l') 344 | else: 345 | kernel = torch.einsum("dnl,dn->dl", kernel, self.gamma * self.scale) 346 | kernel = rearrange(kernel, "(c h) l -> c h l", c=self.channels) 347 | 348 | kernel = kernel[..., :L] 349 | # kernel = rearrange(kernel, '(c h) l -> c h l', c=self.channels) 350 | return kernel, None # k_state 351 | -------------------------------------------------------------------------------- /simplex-diffusion-main/sdlm/data/data_utils.py: -------------------------------------------------------------------------------- 1 | from itertools import chain 2 | from datasets import load_dataset,load_from_disk 3 | import logging 4 | import torch 5 | from datasets import DatasetDict, IterableDataset 6 | 7 | SMALL_GLUE_DATA = ["cola", "wnli", "rte", "mrpc", "stsb"] 8 | LARGE_GLUE_DATA = ["qnli", "qqp", "sst2"] 9 | 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | 14 | def load_data(data_args, model_args): 15 | if data_args.dataset_name is not None: 16 | try: 17 | raw_datasets = load_dataset( 18 | data_args.dataset_name, 19 | data_args.dataset_config_name, 20 | cache_dir=model_args.cache_dir, 21 | use_auth_token=True if model_args.use_auth_token else None, 22 | streaming=data_args.streaming 23 | ) 24 | except: 25 | raw_datasets = load_from_disk(data_args.dataset_name) 26 | else: 27 | data_files = {} 28 | if data_args.train_file is not None: 29 | data_files["train"] = data_args.train_file 30 | if data_args.validation_file is not None: 31 | data_files["validation"] = data_args.validation_file 32 | extension = data_args.train_file.split(".")[-1] 33 | if extension == "txt": 34 | extension = "text" 35 | raw_datasets = load_dataset( 36 | extension, 37 | data_files=data_files, 38 | cache_dir=model_args.cache_dir, 39 | use_auth_token=True if model_args.use_auth_token else None, 40 | streaming=data_args.streaming 41 | ) 42 | return raw_datasets 43 | 44 | 45 | def tokenize_data_new(data_args, tokenizer, raw_datasets, training_args): 46 | # Preprocessing the datasets. 47 | # First we tokenize all the texts. 48 | if training_args.do_train: 49 | column_names = raw_datasets["train"].column_names 50 | else: 51 | column_names = raw_datasets["validation"].column_names 52 | if column_names is None: 53 | text_column_name = "text" 54 | else: 55 | text_column_name = "text" if "text" in column_names else column_names[0] 56 | 57 | # just want the text! 
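    # select_columns keeps only the text column, so labels and other metadata columns are
    # not carried through the tokenization and grouping steps below.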
58 | raw_datasets = raw_datasets.select_columns([text_column_name]) 59 | 60 | if data_args.max_seq_length is None: 61 | max_seq_length = tokenizer.model_max_length 62 | if max_seq_length > 1024: 63 | logger.warning( 64 | f"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). " 65 | "Picking 1024 instead. You can change that default value by passing --max_seq_length xxx." 66 | ) 67 | max_seq_length = 1024 68 | else: 69 | if data_args.max_seq_length > tokenizer.model_max_length: 70 | logger.warning( 71 | f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the" 72 | f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}." 73 | ) 74 | max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length) 75 | 76 | if data_args.line_by_line: 77 | # When using line_by_line, we just tokenize each nonempty line. 78 | padding = "max_length" if data_args.pad_to_max_length else False 79 | 80 | def tokenize_function(examples): 81 | # Remove empty lines 82 | examples[text_column_name] = [ 83 | line for line in examples[text_column_name] if len(line) > 0 and not line.isspace() 84 | ] 85 | return tokenizer( 86 | examples[text_column_name], 87 | padding=padding, 88 | truncation=True, 89 | max_length=max_seq_length, 90 | # We use this option because DataCollatorForLanguageModeling (see below) is more efficient when it 91 | # receives the `special_tokens_mask`. 92 | return_special_tokens_mask=True, 93 | ) 94 | 95 | with training_args.main_process_first(desc="dataset map tokenization"): 96 | tokenized_datasets = raw_datasets.map( 97 | tokenize_function, 98 | batched=True, 99 | num_proc=data_args.preprocessing_num_workers, 100 | remove_columns=[text_column_name], 101 | load_from_cache_file=not data_args.overwrite_cache, 102 | desc="Running tokenizer on dataset line_by_line", 103 | ) 104 | else: 105 | # Otherwise, we tokenize every text, then concatenate them together before splitting them in smaller parts. 106 | # We use `return_special_tokens_mask=True` because DataCollatorForLanguageModeling (see below) is more 107 | # efficient when it receives the `special_tokens_mask`. 108 | def tokenize_function(examples): 109 | return tokenizer(examples[text_column_name], return_special_tokens_mask=True) 110 | 111 | with training_args.main_process_first(desc="dataset map tokenization"): 112 | if not data_args.streaming: 113 | tokenized_datasets = raw_datasets.map( 114 | tokenize_function, 115 | batched=True, 116 | num_proc=data_args.preprocessing_num_workers, 117 | remove_columns=column_names, 118 | load_from_cache_file=not data_args.overwrite_cache, 119 | desc="Running tokenizer on every text in dataset", 120 | ) 121 | else: 122 | tokenized_datasets = raw_datasets.map( 123 | tokenize_function, 124 | batched=True, 125 | remove_columns=[text_column_name], 126 | ) 127 | 128 | # Main data processing function that will concatenate all texts from our dataset and generate chunks of 129 | # max_seq_length. 130 | def group_texts(examples): 131 | # Concatenate all texts. 132 | concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()} 133 | total_length = len(concatenated_examples[list(examples.keys())[0]]) 134 | # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can 135 | # customize this part to your needs. 
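        # Worked example (hypothetical numbers): with max_seq_length=512 and a concatenated
        # batch of 2305 tokens, total_length is rounded down to 4 * 512 = 2048, yielding
        # 4 chunks per key and dropping the trailing 257 tokens.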
136 | if total_length >= max_seq_length: 137 | total_length = (total_length // max_seq_length) * max_seq_length 138 | # Split by chunks of max_len. 139 | result = { 140 | k: [t[i : i + max_seq_length] for i in range(0, total_length, max_seq_length)] 141 | for k, t in concatenated_examples.items() 142 | } 143 | return result 144 | 145 | # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a 146 | # remainder for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value 147 | # might be slower to preprocess. 148 | # 149 | # To speed up this part, we use multiprocessing. See the documentation of the map method for more information: 150 | # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map 151 | 152 | with training_args.main_process_first(desc="grouping texts together"): 153 | if not data_args.streaming: 154 | tokenized_datasets = tokenized_datasets.map( 155 | group_texts, 156 | batched=True, 157 | num_proc=data_args.preprocessing_num_workers, 158 | load_from_cache_file=not data_args.overwrite_cache, 159 | desc=f"Grouping texts in chunks of {max_seq_length}", 160 | ) 161 | else: 162 | tokenized_datasets = tokenized_datasets.map( 163 | group_texts, 164 | batched=True, 165 | ) 166 | return tokenized_datasets 167 | 168 | 169 | # TODO: we need to remove this one and update process_data.py. 170 | def tokenize_data(data_args, tokenizer, raw_datasets, accelerator): 171 | if data_args.max_seq_length is None: 172 | max_seq_length = tokenizer.model_max_length 173 | if max_seq_length > 1024: 174 | logger.warning( 175 | f"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). " 176 | "Picking 1024 instead. You can change that default value by passing --max_seq_length xxx." 177 | ) 178 | max_seq_length = 1024 179 | else: 180 | if data_args.max_seq_length > tokenizer.model_max_length: 181 | logger.warning( 182 | f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the" 183 | f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}." 184 | ) 185 | max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length) 186 | 187 | # Preprocessing the datasets. 188 | # First we tokenize all the texts. 189 | column_names = raw_datasets["train"].column_names 190 | text_column_name = "text" if "text" in column_names else column_names[0] 191 | 192 | if data_args.line_by_line: 193 | # When using line_by_line, we just tokenize each nonempty line. 194 | padding = "max_length" if data_args.pad_to_max_length else False 195 | 196 | def tokenize_function(examples): 197 | # Remove empty lines 198 | examples[text_column_name] = [ 199 | line for line in examples[text_column_name] if len(line) > 0 and not line.isspace() 200 | ] 201 | return tokenizer(examples[text_column_name], padding=padding, truncation=True, max_length=max_seq_length) 202 | 203 | with accelerator.main_process_first(): 204 | tokenized_datasets = raw_datasets.map( 205 | tokenize_function, 206 | batched=True, 207 | num_proc=data_args.preprocessing_num_workers, 208 | remove_columns=[text_column_name], 209 | load_from_cache_file=not data_args.overwrite_cache, 210 | desc="Running tokenizer on dataset line_by_line", 211 | ) 212 | else: 213 | # Otherwise, we tokenize every text, then concatenate them together before splitting them in smaller parts. 
214 | def tokenize_function(examples): 215 | return tokenizer(examples[text_column_name]) 216 | 217 | with accelerator.main_process_first(): 218 | tokenized_datasets = raw_datasets.map( 219 | tokenize_function, 220 | batched=True, 221 | num_proc=data_args.preprocessing_num_workers, 222 | remove_columns=column_names, 223 | load_from_cache_file=not data_args.overwrite_cache, 224 | desc="Running tokenizer on every text in dataset", 225 | ) 226 | 227 | # Main data processing function that will concatenate all texts from our dataset and generate chunks of 228 | # max_seq_length. 229 | def group_texts(examples): 230 | # Concatenate all texts. 231 | concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()} 232 | total_length = len(concatenated_examples[list(examples.keys())[0]]) 233 | # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can 234 | # customize this part to your needs. 235 | if total_length >= max_seq_length: 236 | total_length = (total_length // max_seq_length) * max_seq_length 237 | # Split by chunks of max_len. 238 | result = { 239 | k: [t[i : i + max_seq_length] for i in range(0, total_length, max_seq_length)] 240 | for k, t in concatenated_examples.items() 241 | } 242 | return result 243 | 244 | # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a 245 | # remainder for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value 246 | # might be slower to preprocess. 247 | # 248 | # To speed up this part, we use multiprocessing. See the documentation of the map method for more information: 249 | # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map 250 | 251 | with accelerator.main_process_first(): 252 | tokenized_datasets = tokenized_datasets.map( 253 | group_texts, 254 | batched=True, 255 | num_proc=data_args.preprocessing_num_workers, 256 | load_from_cache_file=not data_args.overwrite_cache, 257 | desc=f"Grouping texts in chunks of {max_seq_length}", 258 | ) 259 | return tokenized_datasets 260 | 261 | 262 | def split_data_to_train_validation(data_args, data, seed): 263 | total_size = len(data["train"]) 264 | validation_size = int(total_size * data_args.validation_split_ratio) 265 | train_size = total_size - validation_size 266 | 267 | # TODO(rabeeh): we need to do this for the other ones as well and think how to do it cleanly. 268 | if data_args.max_train_samples is not None: 269 | train_size = min(train_size, data_args.max_train_samples) 270 | if data_args.max_eval_samples is not None: 271 | validation_size = min(validation_size, data_args.max_eval_samples) 272 | 273 | remaining_size = total_size - train_size - validation_size 274 | train, validation, _ = torch.utils.data.random_split( 275 | data["train"], [train_size, validation_size, remaining_size], generator=torch.Generator().manual_seed(seed) 276 | ) 277 | data["train"], data["validation"] = train, validation 278 | assert len(data["train"]) == train_size and len(data["validation"]) == validation_size 279 | return data 280 | 281 | 282 | def split_glue(raw_datasets, dataset_name, seed): 283 | """Since glue test sets are not public, splits the data splits to form test sets. 284 | 285 | For large datasets (#samples > 10K), divides training set into 1K as validation and 286 | rest as train, using original validation as test. 
Otherwise, divides validation set 287 | to half (half for validation and half for test).""" 288 | if dataset_name == "mnli": 289 | raw_datasets = DatasetDict( 290 | { 291 | "test": raw_datasets["validation_matched"], 292 | "validation": raw_datasets["validation_mismatched"], 293 | "train": raw_datasets["train"], 294 | } 295 | ) 296 | elif dataset_name in SMALL_GLUE_DATA: 297 | # Splits the validation set into half for validation and half for test. 298 | splits = raw_datasets["validation"].train_test_split(test_size=0.5, shuffle=True, seed=seed) 299 | raw_datasets = DatasetDict({"validation": splits["train"], "test": splits["test"], "train": raw_datasets["train"]}) 300 | elif dataset_name in LARGE_GLUE_DATA: 301 | # Splits the training set into 1K as validation, rest as train. 302 | test_size = 1000 / len(raw_datasets["train"]) 303 | splits = raw_datasets["train"].train_test_split(test_size=test_size, shuffle=True, seed=seed) 304 | raw_datasets = DatasetDict( 305 | {"train": splits["train"], "validation": splits["test"], "test": raw_datasets["validation"]} 306 | ) 307 | else: 308 | raise NotImplementedError 309 | return raw_datasets 310 | --------------------------------------------------------------------------------
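A minimal usage sketch for the GLUE splitting helper above (illustrative only; the import path, task name, and seed are assumptions rather than values taken from this repository):

from datasets import load_dataset

from data.data_utils import split_glue

# RTE is one of the small GLUE tasks, so split_glue halves its public validation set
# into new validation and test splits and leaves the training split untouched.
raw_datasets = load_dataset("glue", "rte")
raw_datasets = split_glue(raw_datasets, dataset_name="rte", seed=42)
print({split: len(dataset) for split, dataset in raw_datasets.items()})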