├── .gitignore ├── Dockerfile ├── LICENSE ├── README.md ├── RL4LMs_logo.png ├── requirements.txt ├── rl4lms ├── __init__.py ├── algorithms │ ├── a2c │ │ ├── __init__.py │ │ └── a2c.py │ ├── common │ │ ├── algo_utils.py │ │ └── maskable │ │ │ ├── __init__.py │ │ │ ├── buffers.py │ │ │ ├── callbacks.py │ │ │ ├── distributions.py │ │ │ ├── evaluation.py │ │ │ ├── logits_processor.py │ │ │ ├── policies.py │ │ │ └── utils.py │ ├── nlpo │ │ ├── __init__.py │ │ ├── nlpo.py │ │ └── policies.py │ ├── ppo │ │ ├── __init__.py │ │ └── ppo.py │ └── trpo │ │ ├── __init__.py │ │ ├── policies.py │ │ └── trpo.py ├── core_components │ ├── __init__.py │ ├── sampler.py │ └── sweep.py ├── data_pools │ ├── __init__.py │ ├── custom_text_generation_pools.py │ ├── task_utils │ │ └── totto │ │ │ ├── eval_utils │ │ │ ├── __init__.py │ │ │ ├── mosesdecoder │ │ │ │ └── scripts │ │ │ │ │ └── tokenizer │ │ │ │ │ └── detokenizer.perl │ │ │ ├── prepare_predictions_for_eval.py │ │ │ ├── prepare_references_for_eval.py │ │ │ ├── table_to_text_utils.py │ │ │ ├── totto_bleu_eval.sh │ │ │ ├── totto_parent_eval.py │ │ │ └── totto_parent_eval.sh │ │ │ └── preprocess_utils.py │ └── text_generation_pool.py └── envs │ ├── __init__.py │ ├── common │ ├── __init__.py │ ├── action_space.py │ ├── base_env.py │ ├── observation.py │ └── reward.py │ └── text_generation │ ├── __init__.py │ ├── alg_wrappers.py │ ├── caption_metrics │ ├── __init__.py │ ├── cider.py │ └── spice │ │ ├── Readme.txt │ │ ├── __init__.py │ │ ├── get_stanford_models.sh │ │ ├── lib │ │ ├── Meteor-1.5.jar │ │ ├── SceneGraphParser-1.0.jar │ │ ├── ejml-0.23.jar │ │ ├── fst-2.47.jar │ │ ├── guava-19.0.jar │ │ ├── hamcrest-core-1.3.jar │ │ ├── jackson-core-2.5.3.jar │ │ ├── javassist-3.19.0-GA.jar │ │ ├── json-simple-1.1.1.jar │ │ ├── junit-4.12.jar │ │ ├── lmdbjni-0.4.6.jar │ │ ├── lmdbjni-linux64-0.4.6.jar │ │ ├── lmdbjni-osx64-0.4.6.jar │ │ ├── lmdbjni-win64-0.4.6.jar │ │ ├── objenesis-2.4.jar │ │ ├── slf4j-api-1.7.12.jar │ │ └── slf4j-simple-1.7.21.jar │ │ ├── spice-1.0.jar │ │ └── spice.py │ ├── env.py │ ├── evaluation_utils.py │ ├── hf_generation_utils.py │ ├── kl_controllers.py │ ├── logging_utils.py │ ├── metric.py │ ├── observation.py │ ├── policy.py │ ├── policy │ ├── __init__.py │ ├── base_policy.py │ ├── causal_policy.py │ └── seq2seq_policy.py │ ├── post_processors.py │ ├── preference_reward.py │ ├── registry.py │ ├── reward.py │ ├── summ_metrics │ ├── __init__.py │ └── summa_c.py │ ├── test_datapool.py │ ├── test_metric.py │ ├── test_reward.py │ ├── training_utils.py │ ├── utils_supervised.py │ └── warm_start.py ├── scripts ├── crowdworking_templates │ ├── likert │ │ ├── IMDB_sentiment_completion.html │ │ ├── IMDB_sentiment_completion_example_input.csv │ │ ├── commongen.html │ │ ├── commongen_example_input.csv │ │ ├── daily_dialogue.html │ │ ├── daily_dialogue_example_input.csv │ │ ├── summarization.html │ │ ├── summarization_example_input.csv │ │ ├── totto.html │ │ └── totto_example_input.csv │ └── pairwise │ │ ├── commongen_pairwise.html │ │ └── commongen_pairwise_example_input.csv ├── reward-modeling │ ├── evaluate_intent_classifier.py │ └── train_intent_classifier.py └── training │ ├── task_configs │ ├── common_gen │ │ ├── t5_nlpo.yml │ │ ├── t5_nlpo_on_supervised.yml │ │ ├── t5_ppo.yml │ │ ├── t5_ppo_on_supervised.yml │ │ └── t5_supervised.yml │ ├── dialog │ │ ├── gpt2_nlpo.yml │ │ ├── gpt2_nlpo_on_supervised.yml │ │ ├── gpt2_ppo.yml │ │ ├── gpt2_ppo_on_supervised.yml │ │ └── gpt2_supervised.yml │ ├── imdb_text_continuation │ │ ├── gpt2_a2c.yml │ │ ├── 
gpt2_nlpo.yml │ │ ├── gpt2_nlpo_on_supervised.yml │ │ ├── gpt2_ppo.yml │ │ ├── gpt2_ppo_on_supervised.yml │ │ └── gpt2_supervised.yml │ ├── iwslt2017 │ │ ├── t5_nlpo.yml │ │ ├── t5_nlpo_on_supervised.yml │ │ ├── t5_ppo.yml │ │ ├── t5_ppo_on_supervised.yml │ │ └── t5_supervised.yml │ ├── narrative_qa │ │ ├── t5_nlpo.yml │ │ ├── t5_nlpo_on_supervised.yml │ │ ├── t5_ppo.yml │ │ └── t5_ppo_on_supervised.yml │ ├── summarization │ │ ├── t5_nlpo.yml │ │ ├── t5_nlpo_on_supervised.yml │ │ ├── t5_ppo.yml │ │ ├── t5_ppo_on_supervised.yml │ │ └── t5_supervised.yml │ ├── synthetic_generate_dates │ │ └── gpt2_ppo.yml │ ├── synthetic_generate_increasing_numbers │ │ ├── bart_ppo.yml │ │ ├── blendorbot_ppo.yml │ │ ├── gpt2_a2c.yml │ │ ├── gpt2_nlpo.yml │ │ ├── gpt2_ppo.yml │ │ ├── gpt2_trpo.yml │ │ ├── t5_nlpo.yml │ │ └── t5_ppo.yml │ └── totto │ │ ├── t5_nlpo.yml │ │ ├── t5_nlpo_on_supervised.yml │ │ ├── t5_ppo.yml │ │ ├── t5_ppo_on_supervised.yml │ │ └── t5_supervised.yml │ └── train_text_generation.py └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # vscode 132 | .vscode/ 133 | .history/ 134 | .idea/ 135 | cluster_utils/ 136 | **/wandb/ 137 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM pytorch/pytorch:1.11.0-cuda11.3-cudnn8-devel 2 | 3 | ENV LC_ALL=C.UTF-8 4 | ENV LANG=C.UTF-8 5 | 6 | RUN rm /etc/apt/sources.list.d/cuda.list 7 | RUN rm /etc/apt/sources.list.d/nvidia-ml.list 8 | 9 | # install git 10 | RUN apt-get update 11 | RUN apt-get install -y git 12 | RUN apt-get install -y wget 13 | RUN apt-get install unzip 14 | 15 | # install java 16 | RUN apt-get install -y openjdk-8-jdk 17 | RUN apt-get install -y openjdk-8-jre 18 | RUN update-alternatives --config java 19 | RUN update-alternatives --config javac 20 | 21 | WORKDIR /stage/ 22 | 23 | # Copy the files to /stage 24 | COPY setup.py ./ 25 | COPY requirements.txt ./ 26 | COPY rl4lms/ ./rl4lms 27 | COPY scripts/ ./scripts 28 | 29 | # other model downloads 30 | WORKDIR /stage/rl4lms/envs/text_generation/caption_metrics/spice 31 | RUN ./get_stanford_models.sh 32 | WORKDIR /stage/ 33 | 34 | # finally install the package (with dependencies) 35 | RUN pip install -e . 
36 | 37 | # download external models (since it requires dependencies) 38 | RUN pip install markupsafe==2.0.1 39 | RUN python -c "import nltk; nltk.download('punkt')" 40 | RUN python -m spacy download en_core_web_sm 41 | 42 | 43 | -------------------------------------------------------------------------------- /RL4LMs_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/RL4LMs/97df0bd2f7406a906206c9610aea795fbf52884c/RL4LMs_logo.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | importlib-metadata<5.0 2 | spacy==3.0.6 3 | bert-score==0.3.11 4 | BLEURT @ git+https://github.com/google-research/bleurt.git@c6f2375c7c178e1480840cf27cb9e2af851394f9 5 | datasets==2.5.1 6 | gem-metrics @ git+https://github.com/GEM-benchmark/GEM-metrics.git@431a8174bd6b3637e8d6118bfad2983e39e99733 7 | gym==0.21.0 8 | jsonlines==3.0.0 9 | nltk==3.7 10 | pandas==1.3.5 11 | rich==12.0.0 12 | stable-baselines3==1.5.1a5 13 | torch==1.11.0 14 | torchvision==0.12.0 15 | tqdm==4.64.0 16 | transformers==4.18.0 17 | wandb==0.12.15 18 | jsonlines==3.0.0 19 | rouge_score==0.0.4 20 | sacrebleu==2.2.0 21 | py-rouge==1.1 22 | absl-py 23 | six -------------------------------------------------------------------------------- /rl4lms/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | aapd_data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "data/AAPD" ) -------------------------------------------------------------------------------- /rl4lms/algorithms/a2c/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/RL4LMs/97df0bd2f7406a906206c9610aea795fbf52884c/rl4lms/algorithms/a2c/__init__.py -------------------------------------------------------------------------------- /rl4lms/algorithms/common/maskable/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/RL4LMs/97df0bd2f7406a906206c9610aea795fbf52884c/rl4lms/algorithms/common/maskable/__init__.py -------------------------------------------------------------------------------- /rl4lms/algorithms/common/maskable/callbacks.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | from stable_baselines3.common.callbacks import EvalCallback 5 | from stable_baselines3.common.vec_env import sync_envs_normalization 6 | 7 | from rl4lms.algorithms.common.maskable.evaluation import evaluate_policy 8 | 9 | 10 | class MaskableEvalCallback(EvalCallback): 11 | """ 12 | Callback for evaluating an agent. Supports invalid action masking. 13 | 14 | :param eval_env: The environment used for initialization 15 | :param callback_on_new_best: Callback to trigger 16 | when there is a new best model according to the ``mean_reward`` 17 | :param n_eval_episodes: The number of episodes to test the agent 18 | :param eval_freq: Evaluate the agent every eval_freq call of the callback. 19 | :param log_path: Path to a folder where the evaluations (``evaluations.npz``) 20 | will be saved. It will be updated at each evaluation. 21 | :param best_model_save_path: Path to a folder where the best model 22 | according to performance on the eval env will be saved. 
23 | :param deterministic: Whether the evaluation should 24 | use a stochastic or deterministic actions. 25 | :param render: Whether to render or not the environment during evaluation 26 | :param verbose: 27 | :param warn: Passed to ``evaluate_policy`` (warns if ``eval_env`` has not been 28 | wrapped with a Monitor wrapper) 29 | :param use_masking: Whether or not to use invalid action masks during evaluation 30 | """ 31 | 32 | def __init__(self, *args, use_masking: bool = True, **kwargs): 33 | super().__init__(*args, **kwargs) 34 | self.use_masking = use_masking 35 | 36 | def _on_step(self) -> bool: 37 | if self.eval_freq > 0 and self.n_calls % self.eval_freq == 0: 38 | # Sync training and eval env if there is VecNormalize 39 | sync_envs_normalization(self.training_env, self.eval_env) 40 | 41 | # Reset success rate buffer 42 | self._is_success_buffer = [] 43 | 44 | # Note that evaluate_policy() has been patched to support masking 45 | episode_rewards, episode_lengths = evaluate_policy( 46 | self.model, 47 | self.eval_env, 48 | n_eval_episodes=self.n_eval_episodes, 49 | render=self.render, 50 | deterministic=self.deterministic, 51 | return_episode_rewards=True, 52 | warn=self.warn, 53 | callback=self._log_success_callback, 54 | use_masking=self.use_masking, 55 | ) 56 | 57 | if self.log_path is not None: 58 | self.evaluations_timesteps.append(self.num_timesteps) 59 | self.evaluations_results.append(episode_rewards) 60 | self.evaluations_length.append(episode_lengths) 61 | 62 | kwargs = {} 63 | # Save success log if present 64 | if len(self._is_success_buffer) > 0: 65 | self.evaluations_successes.append(self._is_success_buffer) 66 | kwargs = dict(successes=self.evaluations_successes) 67 | 68 | np.savez( 69 | self.log_path, 70 | timesteps=self.evaluations_timesteps, 71 | results=self.evaluations_results, 72 | ep_lengths=self.evaluations_length, 73 | **kwargs, 74 | ) 75 | 76 | mean_reward, std_reward = np.mean( 77 | episode_rewards), np.std(episode_rewards) 78 | mean_ep_length, std_ep_length = np.mean( 79 | episode_lengths), np.std(episode_lengths) 80 | self.last_mean_reward = mean_reward 81 | 82 | if self.verbose > 0: 83 | print( 84 | f"Eval num_timesteps={self.num_timesteps}, " f"episode_reward={mean_reward:.2f} +/- {std_reward:.2f}") 85 | print( 86 | f"Episode length: {mean_ep_length:.2f} +/- {std_ep_length:.2f}") 87 | # Add to current Logger 88 | self.logger.record("eval/mean_reward", float(mean_reward)) 89 | self.logger.record("eval/mean_ep_length", mean_ep_length) 90 | 91 | if len(self._is_success_buffer) > 0: 92 | success_rate = np.mean(self._is_success_buffer) 93 | if self.verbose > 0: 94 | print(f"Success rate: {100 * success_rate:.2f}%") 95 | self.logger.record("eval/success_rate", success_rate) 96 | 97 | # Dump log so the evaluation results are printed with the correct timestep 98 | self.logger.record("time/total timesteps", 99 | self.num_timesteps, exclude="tensorboard") 100 | self.logger.dump(self.num_timesteps) 101 | 102 | if mean_reward > self.best_mean_reward: 103 | if self.verbose > 0: 104 | print("New best mean reward!") 105 | if self.best_model_save_path is not None: 106 | self.model.save(os.path.join( 107 | self.best_model_save_path, "best_model")) 108 | self.best_mean_reward = mean_reward 109 | # Trigger callback if needed 110 | if self.callback is not None: 111 | return self._on_event() 112 | 113 | return True 114 | -------------------------------------------------------------------------------- /rl4lms/algorithms/common/maskable/evaluation.py: 
-------------------------------------------------------------------------------- 1 | import warnings 2 | from typing import Any, Callable, Dict, List, Optional, Tuple, Union 3 | 4 | import gym 5 | import numpy as np 6 | from stable_baselines3.common.monitor import Monitor 7 | from stable_baselines3.common.vec_env import DummyVecEnv, VecEnv, VecMonitor, is_vecenv_wrapped 8 | 9 | from rl4lms.algorithms.common.maskable.utils import get_action_masks, is_masking_supported 10 | from rl4lms.algorithms.nlpo import NLPO 11 | 12 | 13 | def evaluate_policy( # noqa: C901 14 | model: NLPO, 15 | env: Union[gym.Env, VecEnv], 16 | n_eval_episodes: int = 10, 17 | deterministic: bool = True, 18 | render: bool = False, 19 | callback: Optional[Callable[[Dict[str, Any], Dict[str, Any]], None]] = None, 20 | reward_threshold: Optional[float] = None, 21 | return_episode_rewards: bool = False, 22 | warn: bool = True, 23 | use_masking: bool = True, 24 | ) -> Union[Tuple[float, float], Tuple[List[float], List[int]]]: 25 | """ 26 | Runs policy for ``n_eval_episodes`` episodes and returns average reward. 27 | If a vector env is passed in, this divides the episodes to evaluate onto the 28 | different elements of the vector env. This static division of work is done to 29 | remove bias. See https://github.com/DLR-RM/stable-baselines3/issues/402 for more 30 | details and discussion. 31 | 32 | .. note:: 33 | If environment has not been wrapped with ``Monitor`` wrapper, reward and 34 | episode lengths are counted as it appears with ``env.step`` calls. If 35 | the environment contains wrappers that modify rewards or episode lengths 36 | (e.g. reward scaling, early episode reset), these will affect the evaluation 37 | results as well. You can avoid this by wrapping environment with ``Monitor`` 38 | wrapper before anything else. 39 | 40 | :param model: The RL agent you want to evaluate. 41 | :param env: The gym environment. In the case of a ``VecEnv`` 42 | this must contain only one environment. 43 | :param n_eval_episodes: Number of episode to evaluate the agent 44 | :param deterministic: Whether to use deterministic or stochastic actions 45 | :param render: Whether to render the environment or not 46 | :param callback: callback function to do additional checks, 47 | called after each step. Gets locals() and globals() passed as parameters. 48 | :param reward_threshold: Minimum expected reward per episode, 49 | this will raise an error if the performance is not met 50 | :param return_episode_rewards: If True, a list of rewards and episde lengths 51 | per episode will be returned instead of the mean. 52 | :param warn: If True (default), warns user about lack of a Monitor wrapper in the 53 | evaluation environment. 54 | :param use_masking: Whether or not to use invalid action masks during evaluation 55 | :return: Mean reward per episode, std of reward per episode. 56 | Returns ([float], [int]) when ``return_episode_rewards`` is True, first 57 | list containing per-episode rewards and second containing per-episode lengths 58 | (in number of steps). 59 | """ 60 | 61 | if use_masking and not is_masking_supported(env): 62 | raise ValueError( 63 | "Environment does not support action masking. 
Consider using ActionMasker wrapper") 64 | 65 | is_monitor_wrapped = False 66 | 67 | if not isinstance(env, VecEnv): 68 | env = DummyVecEnv([lambda: env]) 69 | 70 | is_monitor_wrapped = is_vecenv_wrapped( 71 | env, VecMonitor) or env.env_is_wrapped(Monitor)[0] 72 | 73 | if not is_monitor_wrapped and warn: 74 | warnings.warn( 75 | "Evaluation environment is not wrapped with a ``Monitor`` wrapper. " 76 | "This may result in reporting modified episode lengths and rewards, if other wrappers happen to modify these. " 77 | "Consider wrapping environment first with ``Monitor`` wrapper.", 78 | UserWarning, 79 | ) 80 | 81 | n_envs = env.num_envs 82 | episode_rewards = [] 83 | episode_lengths = [] 84 | 85 | episode_counts = np.zeros(n_envs, dtype="int") 86 | # Divides episodes among different sub environments in the vector as evenly as possible 87 | episode_count_targets = np.array( 88 | [(n_eval_episodes + i) // n_envs for i in range(n_envs)], dtype="int") 89 | 90 | current_rewards = np.zeros(n_envs) 91 | current_lengths = np.zeros(n_envs, dtype="int") 92 | observations = env.reset() 93 | states = None 94 | 95 | while (episode_counts < episode_count_targets).any(): 96 | if use_masking: 97 | action_masks = get_action_masks(env) 98 | actions, state = model.predict( 99 | observations, 100 | state=states, 101 | deterministic=deterministic, 102 | action_masks=action_masks, 103 | ) 104 | else: 105 | actions, states = model.predict( 106 | observations, state=states, deterministic=deterministic) 107 | observations, rewards, dones, infos = env.step(actions) 108 | current_rewards += rewards 109 | current_lengths += 1 110 | for i in range(n_envs): 111 | if episode_counts[i] < episode_count_targets[i]: 112 | 113 | # unpack values so that the callback can access the local variables 114 | reward = rewards[i] 115 | done = dones[i] 116 | info = infos[i] 117 | 118 | if callback is not None: 119 | callback(locals(), globals()) 120 | 121 | if dones[i]: 122 | if is_monitor_wrapped: 123 | # Atari wrapper can send a "done" signal when 124 | # the agent loses a life, but it does not correspond 125 | # to the true end of episode 126 | if "episode" in info.keys(): 127 | # Do not trust "done" with episode endings. 128 | # Monitor wrapper includes "episode" key in info if environment 129 | # has been wrapped with it. Use those rewards instead. 
130 | episode_rewards.append(info["episode"]["r"]) 131 | episode_lengths.append(info["episode"]["l"]) 132 | # Only increment at the real end of an episode 133 | episode_counts[i] += 1 134 | else: 135 | episode_rewards.append(current_rewards[i]) 136 | episode_lengths.append(current_lengths[i]) 137 | episode_counts[i] += 1 138 | current_rewards[i] = 0 139 | current_lengths[i] = 0 140 | if states is not None: 141 | states[i] *= 0 142 | 143 | if render: 144 | env.render() 145 | 146 | mean_reward = np.mean(episode_rewards) 147 | std_reward = np.std(episode_rewards) 148 | if reward_threshold is not None: 149 | assert mean_reward > reward_threshold, "Mean reward below threshold: " f"{mean_reward:.2f} < {reward_threshold:.2f}" 150 | if return_episode_rewards: 151 | return episode_rewards, episode_lengths 152 | return mean_reward, std_reward 153 | -------------------------------------------------------------------------------- /rl4lms/algorithms/common/maskable/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from stable_baselines3.common.type_aliases import GymEnv 3 | from stable_baselines3.common.vec_env import VecEnv 4 | 5 | EXPECTED_METHOD_NAME = "action_masks" 6 | 7 | 8 | def get_action_masks(env: GymEnv) -> np.ndarray: 9 | """ 10 | Checks whether gym env exposes a method returning invalid action masks 11 | 12 | :param env: the Gym environment to get masks from 13 | :return: A numpy array of the masks 14 | """ 15 | 16 | if isinstance(env, VecEnv): 17 | return np.stack(env.env_method(EXPECTED_METHOD_NAME)) 18 | else: 19 | return getattr(env, EXPECTED_METHOD_NAME)() 20 | 21 | 22 | def is_masking_supported(env: GymEnv) -> bool: 23 | """ 24 | Checks whether gym env exposes a method returning invalid action masks 25 | 26 | :param env: the Gym environment to check 27 | :return: True if the method is found, False otherwise 28 | """ 29 | 30 | if isinstance(env, VecEnv): 31 | try: 32 | # TODO: add VecEnv.has_attr() 33 | env.get_attr(EXPECTED_METHOD_NAME) 34 | return True 35 | except AttributeError: 36 | return False 37 | else: 38 | return hasattr(env, EXPECTED_METHOD_NAME) 39 | -------------------------------------------------------------------------------- /rl4lms/algorithms/nlpo/__init__.py: -------------------------------------------------------------------------------- 1 | from rl4lms.algorithms.nlpo.policies import CnnPolicy, MlpPolicy, MultiInputPolicy 2 | from rl4lms.algorithms.nlpo.nlpo import NLPO 3 | -------------------------------------------------------------------------------- /rl4lms/algorithms/nlpo/policies.py: -------------------------------------------------------------------------------- 1 | from rl4lms.algorithms.common.maskable.policies import ( 2 | MaskableActorCriticCnnPolicy, 3 | MaskableActorCriticPolicy, 4 | MaskableMultiInputActorCriticPolicy, 5 | ) 6 | 7 | MlpPolicy = MaskableActorCriticPolicy 8 | CnnPolicy = MaskableActorCriticCnnPolicy 9 | MultiInputPolicy = MaskableMultiInputActorCriticPolicy 10 | -------------------------------------------------------------------------------- /rl4lms/algorithms/ppo/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/RL4LMs/97df0bd2f7406a906206c9610aea795fbf52884c/rl4lms/algorithms/ppo/__init__.py -------------------------------------------------------------------------------- /rl4lms/algorithms/trpo/__init__.py: 
-------------------------------------------------------------------------------- 1 | from rl4lms.algorithms.trpo.policies import * 2 | from rl4lms.algorithms.trpo.trpo import TRPO 3 | -------------------------------------------------------------------------------- /rl4lms/algorithms/trpo/policies.py: -------------------------------------------------------------------------------- 1 | # This file is here just to define MlpPolicy/CnnPolicy 2 | # that work for TRPO 3 | from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, MultiInputActorCriticPolicy 4 | 5 | MlpPolicy = ActorCriticPolicy 6 | CnnPolicy = ActorCriticCnnPolicy 7 | MultiInputPolicy = MultiInputActorCriticPolicy -------------------------------------------------------------------------------- /rl4lms/core_components/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/RL4LMs/97df0bd2f7406a906206c9610aea795fbf52884c/rl4lms/core_components/__init__.py -------------------------------------------------------------------------------- /rl4lms/core_components/sampler.py: -------------------------------------------------------------------------------- 1 | from collections import deque 2 | from typing import Any, List 3 | import numpy as np 4 | 5 | 6 | class PrioritySampler: 7 | def __init__(self, max_size: int = None, priority_scale: float = 0.0): 8 | """ 9 | Creates a priority sampler 10 | 11 | Args: 12 | max_size (int): maximum size of the queue 13 | priority_scale (float): 0.0 is a pure uniform sampling, 1.0 is completely priority sampling 14 | """ 15 | self.max_size = max_size 16 | self.items = deque(maxlen=self.max_size) 17 | self.item_priorities = deque(maxlen=self.max_size) 18 | self.priority_scale = priority_scale 19 | 20 | def add(self, item: Any, priority: float): 21 | self.items.append(item) 22 | self.item_priorities.append(priority) 23 | 24 | def sample(self, size: int) -> List[Any]: 25 | min_sample_size = min(len(self.items), size) 26 | scaled_item_priorities = np.array( 27 | self.item_priorities) ** self.priority_scale 28 | sample_probs = scaled_item_priorities / np.sum(scaled_item_priorities) 29 | samples = np.random.choice( 30 | a=self.items, p=sample_probs, size=min_sample_size) 31 | return samples 32 | 33 | def update(self, item: Any, priority: float): 34 | index = self.items.index(item) 35 | del self.items[index] 36 | del self.item_priorities[index] 37 | self.add(item, priority) 38 | 39 | def get_all_samples(self) -> List[Any]: 40 | return self.items 41 | -------------------------------------------------------------------------------- /rl4lms/core_components/sweep.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Any, List 2 | import json 3 | import hashlib 4 | from itertools import product 5 | 6 | 7 | def get_dict_obj(keys: List, values: List) -> Dict: 8 | dict = {} 9 | for key, value in zip(keys, values): 10 | dict[key] = value 11 | return dict 12 | 13 | 14 | def find_products(splits_by_keys: Dict) -> List[Dict]: 15 | values = list(splits_by_keys.values()) 16 | keys = list(splits_by_keys.keys()) 17 | if len(values) == 1: 18 | dict_objs = [get_dict_obj(keys, [value]) for value in values[0]] 19 | else: 20 | product_values = product(*values) 21 | dict_objs = [get_dict_obj(keys, value) for value in product_values] 22 | return dict_objs 23 | 24 | 25 | def to_expand(obj: Any) -> bool: 26 | expand = True if isinstance(obj, dict) and obj.get( 27 | 
"expand", False) else False 28 | return expand 29 | 30 | 31 | def split_config(obj: Dict) -> List[Dict]: 32 | """ 33 | Recursively splits the given object 34 | """ 35 | if not isinstance(obj, dict): 36 | return obj 37 | 38 | # it is a dict and further split 39 | splits_by_key = {} 40 | for key, child_obj in obj.items(): 41 | if to_expand(child_obj): 42 | all_splits = [] 43 | for item in child_obj["values"]: 44 | splits = split_config(item) 45 | if isinstance(splits, list): 46 | all_splits.extend(splits) 47 | else: 48 | all_splits.append(splits) 49 | splits_by_key[key] = all_splits 50 | 51 | elif isinstance(child_obj, dict): # anoter dict, which needs to be expanded 52 | splits_by_key[key] = split_config(child_obj) 53 | else: # others which need not be expanded 54 | splits_by_key[key] = [child_obj] 55 | 56 | # here, find cartesian 57 | configs = find_products(splits_by_key) 58 | 59 | return configs 60 | 61 | 62 | def dict_hash(dictionary: Dict[str, Any]) -> str: 63 | """MD5 hash of a dictionary.""" 64 | dhash = hashlib.md5() 65 | encoded = json.dumps(dictionary, sort_keys=True).encode() 66 | dhash.update(encoded) 67 | return dhash.hexdigest() 68 | 69 | 70 | if __name__ == "__main__": 71 | config = { 72 | "param_1": { 73 | "expand": True, 74 | "values": [1, 2] 75 | }, 76 | "param_3": { 77 | "param_3_2": { 78 | "expand": False, 79 | "values": [3, 4] 80 | }, 81 | "param_3_3": 5 82 | } 83 | } 84 | 85 | configs = split_config(config) 86 | print(f" Total configs found: {len(configs)}") 87 | for config in configs: 88 | print(config) 89 | print(dict_hash(config)) 90 | -------------------------------------------------------------------------------- /rl4lms/data_pools/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/RL4LMs/97df0bd2f7406a906206c9610aea795fbf52884c/rl4lms/data_pools/__init__.py -------------------------------------------------------------------------------- /rl4lms/data_pools/task_utils/totto/eval_utils/__init__.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | import os 3 | from tempfile import TemporaryDirectory 4 | import jsonlines 5 | from typing import List 6 | import json 7 | 8 | 9 | def compute_parent(predicted_texts: List[str], 10 | raw_tables: List[dict]): 11 | 12 | with TemporaryDirectory() as temp_dir: 13 | 14 | # write tables 15 | target_path = os.path.join(temp_dir, "samples.jsonl") 16 | with jsonlines.open(target_path, "w") as writer: 17 | for table in raw_tables: 18 | writer.write(table) 19 | 20 | # write gen texts 21 | prediction_path = os.path.join(temp_dir, "predictions.txt") 22 | with open(prediction_path, "w") as fp: 23 | predicted_texts = '\n'.join(predicted_texts) 24 | fp.write(predicted_texts) 25 | 26 | cmd = ['bash', 'totto_parent_eval.sh', 27 | '-p', prediction_path, 28 | '-t', target_path, 29 | '--output_dir', temp_dir, 30 | ] 31 | subprocess.check_call(cmd, 32 | cwd=os.path.dirname(os.path.abspath(__file__)), 33 | stdout=subprocess.DEVNULL) 34 | 35 | # read the results back 36 | with open(os.path.join(temp_dir, "parent_overall.json")) as fp: 37 | parent_overall_results = json.load(fp) 38 | 39 | with open(os.path.join(temp_dir, "parent_overlap.json")) as fp: 40 | parent_overlap_results = json.load(fp) 41 | 42 | with open(os.path.join(temp_dir, "parent_non_overlap.json")) as fp: 43 | parent_non_overlap_results = json.load(fp) 44 | 45 | return parent_overall_results, parent_overlap_results, parent_non_overlap_results 
46 | 47 | 48 | def compute_bleu(predicted_texts: List[str], 49 | raw_tables: List[dict]): 50 | 51 | def _read_results(path): 52 | try: 53 | with open(path) as fp: 54 | score = json.load(fp)["score"]/100 55 | except: 56 | score = 0.0 57 | return score 58 | 59 | with TemporaryDirectory() as temp_dir: 60 | 61 | # write tables 62 | target_path = os.path.join(temp_dir, "samples.jsonl") 63 | with jsonlines.open(target_path, "w") as writer: 64 | for table in raw_tables: 65 | writer.write(table) 66 | 67 | # write gen texts 68 | prediction_path = os.path.join(temp_dir, "predictions.txt") 69 | with open(prediction_path, "w") as fp: 70 | predicted_texts = '\n'.join(predicted_texts) 71 | fp.write(predicted_texts) 72 | 73 | cmd = ['bash', 'totto_bleu_eval.sh', 74 | '-p', prediction_path, 75 | '-t', target_path, 76 | '--output_dir', temp_dir, 77 | ] 78 | subprocess.check_call(cmd, 79 | cwd=os.path.dirname(os.path.abspath(__file__)), 80 | stdout=subprocess.DEVNULL) 81 | 82 | # read the results back 83 | bleu_overall = _read_results( 84 | os.path.join(temp_dir, "bleu_overall.json")) 85 | bleu_overlap = _read_results( 86 | os.path.join(temp_dir, "bleu_overlap.json")) 87 | bleu_non_overlap = _read_results( 88 | os.path.join(temp_dir, "bleu_non_overlap.json")) 89 | return bleu_overall, bleu_overlap, bleu_non_overlap 90 | -------------------------------------------------------------------------------- /rl4lms/data_pools/task_utils/totto/eval_utils/prepare_predictions_for_eval.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Processes references for evaluation (except for tokenization).""" 16 | import json 17 | import os 18 | 19 | from absl import app 20 | from absl import flags 21 | import six 22 | 23 | FLAGS = flags.FLAGS 24 | 25 | flags.DEFINE_string("input_prediction_path", None, "Prediction txt file.") 26 | flags.DEFINE_string("input_target_path", None, "Target json file.") 27 | flags.DEFINE_string("output_dir", None, "Output directory.") 28 | 29 | 30 | def write_predictions(predictions, output_path): 31 | """Write predictions to file.""" 32 | with open(output_path, "w", encoding="utf-8") as f: 33 | for prediction in predictions: 34 | if not prediction: 35 | prediction = "" 36 | f.write(prediction.lower() + "\n") 37 | 38 | 39 | def main(_): 40 | input_prediction_path = FLAGS.input_prediction_path 41 | input_target_path = FLAGS.input_target_path 42 | output_dir = FLAGS.output_dir 43 | 44 | predictions = [] 45 | overlap_predictions = [] 46 | nonoverlap_predictions = [] 47 | with open(input_prediction_path, "r", encoding="utf-8") as input_file: 48 | for line in input_file: 49 | line = line.strip() 50 | predictions.append(line) 51 | 52 | json_examples = [] 53 | with open(input_target_path, "r", encoding="utf-8") as input_file: 54 | for line in input_file: 55 | line = six.ensure_text(line, "utf-8") 56 | json_example = json.loads(line) 57 | json_examples.append(json_example) 58 | 59 | assert len(predictions) == len(json_examples) 60 | for index, prediction in enumerate(predictions): 61 | json_example = json_examples[index] 62 | if json_example["overlap_subset"]: 63 | overlap_predictions.append(prediction) 64 | else: 65 | nonoverlap_predictions.append(prediction) 66 | 67 | print("Writing predictions.") 68 | all_output_path = os.path.join(output_dir, "predictions") 69 | overlap_output_path = os.path.join(output_dir, "overlap_predictions") 70 | nonoverlap_output_path = os.path.join(output_dir, "nonoverlap_predictions") 71 | write_predictions(predictions, all_output_path) 72 | write_predictions(overlap_predictions, overlap_output_path) 73 | write_predictions(nonoverlap_predictions, nonoverlap_output_path) 74 | 75 | 76 | if __name__ == "__main__": 77 | flags.mark_flags_as_required( 78 | ["input_prediction_path", "input_target_path", "output_dir"]) 79 | app.run(main) 80 | -------------------------------------------------------------------------------- /rl4lms/data_pools/task_utils/totto/eval_utils/table_to_text_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Utilities for tables-to-text.""" 16 | 17 | 18 | def get_highlighted_subtable(table, cell_indices): 19 | """Extract out the highlighted part of a table.""" 20 | highlighted_table = [] 21 | for (row_index, col_index) in cell_indices: 22 | cell = table[row_index][col_index] 23 | highlighted_table.append(cell) 24 | 25 | return highlighted_table 26 | 27 | 28 | def get_table_parent_format(table, table_page_title, table_section_title, 29 | table_section_text): 30 | """Convert table to format required by PARENT.""" 31 | table_parent_array = [] 32 | 33 | # Table values. 34 | for row in table: 35 | for cell in row: 36 | if cell["is_header"]: 37 | attribute = "header" 38 | else: 39 | attribute = "cell" 40 | value = cell["value"].strip() 41 | if value: 42 | value = value.replace("|", "-") 43 | entry = "%s|||%s" % (attribute, value) 44 | table_parent_array.append(entry) 45 | 46 | # Page title. 47 | if table_page_title: 48 | table_page_title = table_page_title.replace("|", "-") 49 | entry = "%s|||%s" % ("page_title", table_page_title) 50 | table_parent_array.append(entry) 51 | 52 | # Section title. 53 | if table_section_title: 54 | table_section_title = table_section_title.replace("|", "-") 55 | entry = "%s|||%s" % ("section_title", table_section_title) 56 | table_parent_array.append(entry) 57 | 58 | # Section text. 59 | if table_section_text: 60 | table_section_text = table_section_text.replace("|", "-") 61 | entry = "%s|||%s" % ("section_text", table_section_text) 62 | table_parent_array.append(entry) 63 | 64 | table_parent_str = "\t".join(table_parent_array) 65 | return table_parent_str 66 | 67 | 68 | def get_subtable_parent_format(subtable, table_page_title, table_section_title): 69 | """Convert subtable to PARENT format. Do not include section text.""" 70 | table_parent_array = [] 71 | # Table values. 72 | for cell in subtable: 73 | if cell["is_header"]: 74 | attribute = "header" 75 | else: 76 | attribute = "cell" 77 | value = cell["value"].strip() 78 | if value: 79 | value = value.replace("|", "-") 80 | entry = "%s|||%s" % (attribute, value) 81 | table_parent_array.append(entry) 82 | 83 | # Page title. 84 | if table_page_title: 85 | table_page_title = table_page_title.replace("|", "-") 86 | entry = "%s|||%s" % ("page_title", table_page_title) 87 | table_parent_array.append(entry) 88 | 89 | # Section title. 90 | if table_section_title: 91 | table_section_title = table_section_title.replace("|", "-") 92 | entry = "%s|||%s" % ("section_title", table_section_title) 93 | table_parent_array.append(entry) 94 | 95 | table_parent_str = "\t".join(table_parent_array) 96 | return table_parent_str 97 | -------------------------------------------------------------------------------- /rl4lms/data_pools/task_utils/totto/eval_utils/totto_bleu_eval.sh: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | #!/bin/bash 16 | 17 | # Prepare the needed variables. 18 | PREDICTION_PATH=unset 19 | TARGET_PATH=unset 20 | BLEURT_CKPT=unset 21 | OUTPUT_DIR="temp/" 22 | MODE="test" 23 | 24 | # Function to report 25 | usage() 26 | { 27 | echo "Usage: totto_bleu_eval.sh [ -p | --prediction_path PREDICTION/PATH.txt ] 28 | [ -t | --target_path TARGET/PATH.jsonl ] 29 | [ -b | --bleurt_ckpt BLEURT_CHECKPOINT/PATH ] 30 | [ -o | --output_dir ./dev/ ] 31 | [ -m | --mode dev/test ]" 32 | exit 2 33 | } 34 | 35 | # Parse the arguments and check for validity. 36 | PARSED_ARGUMENTS=$(getopt -a -n totto_eval -o p:t:b:o:m: --long prediction_path:,target_path:,bleurt_ckpt:,output_dir:,mode: -- "$@") 37 | VALID_ARGUMENTS=$? 38 | if [ "$VALID_ARGUMENTS" != "0" ]; then 39 | usage 40 | fi 41 | 42 | # echo "PARSED_ARGUMENTS is $PARSED_ARGUMENTS" 43 | # Sort the arguments into their respective variables. 44 | while : 45 | do 46 | case "$1" in 47 | -p | --prediction_path) PREDICTION_PATH="$2" ; shift 2 ;; 48 | -t | --target_path) TARGET_PATH="$2" ; shift 2 ;; 49 | -b | --bleurt_ckpt) BLEURT_CKPT="$2" ; shift 2 ;; 50 | -o | --output_dir) OUTPUT_DIR="$2" ; shift 2 ;; 51 | -m | --mode) MODE="$2" ; shift 2 ;; 52 | # -- denotes the end of arguments; break out of the while loop 53 | --) shift; break ;; 54 | *) shift; break ; 55 | esac 56 | done 57 | 58 | # Check the validity of the arguments (e.g., files exist and mode is valid). 59 | if [[ $PREDICTION_PATH == unset || $TARGET_PATH == unset ]] 60 | then 61 | echo "Prediction path and target path are required arguments." 62 | usage 63 | exit 2 64 | elif [[ !($MODE == "dev" || $MODE == "test") ]] 65 | then 66 | echo "Mode has to be dev or test." 67 | usage 68 | exit 2 69 | elif [[ !(-f $PREDICTION_PATH) ]] 70 | then 71 | echo "Your prediction path \"${PREDICTION_PATH}\" does not exist on your filesystem." 72 | usage 73 | exit 2 74 | elif [[ !(-f $TARGET_PATH) ]] 75 | then 76 | echo "Your target path \"${TARGET_PATH}\" does not exist on your filesystem." 77 | usage 78 | exit 2 79 | fi 80 | 81 | # Trim trailing slash (for concatenation ease later). 82 | OUTPUT_DIR=$(echo $OUTPUT_DIR | sed 's:/*$::') 83 | 84 | # All checks passed. Report the variables. 85 | echo "Running with the following variables:" 86 | echo "PREDICTION_PATH : $PREDICTION_PATH" 87 | echo "TARGET_PATH : $TARGET_PATH " 88 | echo "BLEURT_CKPT : $BLEURT_CKPT " 89 | echo "OUTPUT_DIR : $OUTPUT_DIR" 90 | echo "MODE : $MODE" 91 | 92 | if [ ! -d "${OUTPUT_DIR}" ]; then 93 | echo "Creating Output directory." 94 | mkdir "${OUTPUT_DIR}" 95 | fi 96 | 97 | # echo "Preparing references." 98 | python3 -m prepare_references_for_eval \ 99 | --input_path="${TARGET_PATH}" \ 100 | --output_dir="${OUTPUT_DIR}" \ 101 | --mode="${MODE}" 102 | ret=$? 103 | if [ $ret -ne 0 ]; then 104 | echo "Failed to run python script. Please ensure that all libraries are installed and that files are formatted correctly." 105 | exit 1 106 | fi 107 | 108 | echo "Preparing predictions." 109 | python3 -m prepare_predictions_for_eval \ 110 | --input_prediction_path="${PREDICTION_PATH}" \ 111 | --input_target_path="${TARGET_PATH}" \ 112 | --output_dir="${OUTPUT_DIR}" 113 | ret=$? 114 | if [ $ret -ne 0 ]; then 115 | echo "Failed to run python script. Please ensure that all libraries are installed and that files are formatted correctly." 116 | exit 1 117 | fi 118 | 119 | # Define all required files and detokenize. 120 | echo "Running detokenizers." 
121 | declare -a StringArray=("predictions" "overlap_predictions" "nonoverlap_predictions" 122 | "references" "overlap_references" "nonoverlap_references" 123 | "references-multi0" "references-multi1" "references-multi2" 124 | "overlap_references-multi0" "overlap_references-multi1" "overlap_references-multi2" 125 | "nonoverlap_references-multi0" "nonoverlap_references-multi1" "nonoverlap_references-multi2" 126 | "tables_parent_precision_format" "tables_parent_recall_format" 127 | "overlap_tables_parent_precision_format" "overlap_tables_parent_recall_format" 128 | "nonoverlap_tables_parent_precision_format" "nonoverlap_tables_parent_recall_format" 129 | ) 130 | 131 | for filename in "${StringArray[@]}"; 132 | do 133 | mosesdecoder/scripts/tokenizer/detokenizer.perl -q -l en -threads 8 < "${OUTPUT_DIR}/${filename}" > "${OUTPUT_DIR}/detok_${filename}" 134 | done 135 | 136 | echo "======== EVALUATE OVERALL ========" 137 | 138 | # Compute BLEU scores using sacrebleu (https://github.com/mjpost/sacrebleu) 139 | echo "Computing BLEU (overall)" 140 | cat ${OUTPUT_DIR}/detok_predictions | sacrebleu ${OUTPUT_DIR}/detok_references-multi0 ${OUTPUT_DIR}/detok_references-multi1 ${OUTPUT_DIR}/detok_references-multi2 > "${OUTPUT_DIR}/bleu_overall.json" 141 | ret=$? 142 | if [ $ret -ne 0 ]; then 143 | echo "Failed to run eval script. You may have to install PERL packages using cpanm." 144 | exit 1 145 | fi 146 | 147 | echo "======== EVALUATE OVERLAP SUBSET ========" 148 | 149 | echo "Computing BLEU (overlap subset)" 150 | cat ${OUTPUT_DIR}/detok_overlap_predictions | sacrebleu ${OUTPUT_DIR}/detok_overlap_references-multi0 ${OUTPUT_DIR}/detok_overlap_references-multi1 ${OUTPUT_DIR}/detok_overlap_references-multi2 > "${OUTPUT_DIR}/bleu_overlap.json" 151 | 152 | echo "======== EVALUATE NON-OVERLAP SUBSET ========" 153 | 154 | echo "Computing BLEU (non-overlap subset)" 155 | cat ${OUTPUT_DIR}/detok_nonoverlap_predictions | sacrebleu ${OUTPUT_DIR}/detok_nonoverlap_references-multi0 ${OUTPUT_DIR}/detok_nonoverlap_references-multi1 ${OUTPUT_DIR}/detok_nonoverlap_references-multi2 > "${OUTPUT_DIR}/bleu_non_overlap.json" 156 | -------------------------------------------------------------------------------- /rl4lms/data_pools/task_utils/totto/eval_utils/totto_parent_eval.sh: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | #!/bin/bash 16 | 17 | # Prepare the needed variables. 
18 | PREDICTION_PATH=unset 19 | TARGET_PATH=unset 20 | BLEURT_CKPT=unset 21 | OUTPUT_DIR="temp/" 22 | MODE="test" 23 | 24 | # Function to report 25 | usage() 26 | { 27 | echo "Usage: totto_parent_eval.sh [ -p | --prediction_path PREDICTION/PATH.txt ] 28 | [ -t | --target_path TARGET/PATH.jsonl ] 29 | [ -b | --bleurt_ckpt BLEURT_CHECKPOINT/PATH ] 30 | [ -o | --output_dir ./dev/ ] 31 | [ -m | --mode dev/test ]" 32 | exit 2 33 | } 34 | 35 | # Parse the arguments and check for validity. 36 | PARSED_ARGUMENTS=$(getopt -a -n totto_eval -o p:t:b:o:m: --long prediction_path:,target_path:,bleurt_ckpt:,output_dir:,mode: -- "$@") 37 | VALID_ARGUMENTS=$? 38 | if [ "$VALID_ARGUMENTS" != "0" ]; then 39 | usage 40 | fi 41 | 42 | # echo "PARSED_ARGUMENTS is $PARSED_ARGUMENTS" 43 | # Sort the arguments into their respective variables. 44 | while : 45 | do 46 | case "$1" in 47 | -p | --prediction_path) PREDICTION_PATH="$2" ; shift 2 ;; 48 | -t | --target_path) TARGET_PATH="$2" ; shift 2 ;; 49 | -b | --bleurt_ckpt) BLEURT_CKPT="$2" ; shift 2 ;; 50 | -o | --output_dir) OUTPUT_DIR="$2" ; shift 2 ;; 51 | -m | --mode) MODE="$2" ; shift 2 ;; 52 | # -- denotes the end of arguments; break out of the while loop 53 | --) shift; break ;; 54 | *) shift; break ; 55 | esac 56 | done 57 | 58 | # Check the validity of the arguments (e.g., files exist and mode is valid). 59 | if [[ $PREDICTION_PATH == unset || $TARGET_PATH == unset ]] 60 | then 61 | echo "Prediction path and target path are required arguments." 62 | usage 63 | exit 2 64 | elif [[ !($MODE == "dev" || $MODE == "test") ]] 65 | then 66 | echo "Mode has to be dev or test." 67 | usage 68 | exit 2 69 | elif [[ !(-f $PREDICTION_PATH) ]] 70 | then 71 | echo "Your prediction path \"${PREDICTION_PATH}\" does not exist on your filesystem." 72 | usage 73 | exit 2 74 | elif [[ !(-f $TARGET_PATH) ]] 75 | then 76 | echo "Your target path \"${TARGET_PATH}\" does not exist on your filesystem." 77 | usage 78 | exit 2 79 | fi 80 | 81 | # Trim trailing slash (for concatenation ease later). 82 | OUTPUT_DIR=$(echo $OUTPUT_DIR | sed 's:/*$::') 83 | 84 | # All checks passed. Report the variables. 85 | echo "Running with the following variables:" 86 | echo "PREDICTION_PATH : $PREDICTION_PATH" 87 | echo "TARGET_PATH : $TARGET_PATH " 88 | echo "BLEURT_CKPT : $BLEURT_CKPT " 89 | echo "OUTPUT_DIR : $OUTPUT_DIR" 90 | echo "MODE : $MODE" 91 | 92 | if [ ! -d "${OUTPUT_DIR}" ]; then 93 | echo "Creating Output directory." 94 | mkdir "${OUTPUT_DIR}" 95 | fi 96 | 97 | 98 | # echo "Preparing references." 99 | python3 -m prepare_references_for_eval \ 100 | --input_path="${TARGET_PATH}" \ 101 | --output_dir="${OUTPUT_DIR}" \ 102 | --mode="${MODE}" 103 | ret=$? 104 | if [ $ret -ne 0 ]; then 105 | echo "Failed to run python script. Please ensure that all libraries are installed and that files are formatted correctly." 106 | exit 1 107 | fi 108 | 109 | echo "Preparing predictions." 110 | python3 -m prepare_predictions_for_eval \ 111 | --input_prediction_path="${PREDICTION_PATH}" \ 112 | --input_target_path="${TARGET_PATH}" \ 113 | --output_dir="${OUTPUT_DIR}" 114 | ret=$? 115 | if [ $ret -ne 0 ]; then 116 | echo "Failed to run python script. Please ensure that all libraries are installed and that files are formatted correctly." 117 | exit 1 118 | fi 119 | 120 | # Define all required files and detokenize. 121 | echo "Running detokenizers." 
122 | declare -a StringArray=("predictions" "overlap_predictions" "nonoverlap_predictions" 123 | "references" "overlap_references" "nonoverlap_references" 124 | "references-multi0" "references-multi1" "references-multi2" 125 | "overlap_references-multi0" "overlap_references-multi1" "overlap_references-multi2" 126 | "nonoverlap_references-multi0" "nonoverlap_references-multi1" "nonoverlap_references-multi2" 127 | "tables_parent_precision_format" "tables_parent_recall_format" 128 | "overlap_tables_parent_precision_format" "overlap_tables_parent_recall_format" 129 | "nonoverlap_tables_parent_precision_format" "nonoverlap_tables_parent_recall_format" 130 | ) 131 | 132 | for filename in "${StringArray[@]}"; 133 | do 134 | mosesdecoder/scripts/tokenizer/detokenizer.perl -q -l en -threads 8 < "${OUTPUT_DIR}/${filename}" > "${OUTPUT_DIR}/detok_${filename}" 135 | done 136 | 137 | echo "======== EVALUATE OVERALL ========" 138 | 139 | echo "Computing PARENT (overall)" 140 | python3 -m totto_parent_eval \ 141 | --reference_path="${OUTPUT_DIR}/detok_references-multi" \ 142 | --generation_path="${OUTPUT_DIR}/detok_predictions" \ 143 | --precision_table_path="${OUTPUT_DIR}/detok_tables_parent_precision_format" \ 144 | --recall_table_path="${OUTPUT_DIR}/detok_tables_parent_recall_format"\ 145 | --result_path="${OUTPUT_DIR}/parent_overall.json" 146 | 147 | echo "======== EVALUATE OVERLAP SUBSET ========" 148 | 149 | echo "Computing PARENT (overlap subset)" 150 | python3 -m totto_parent_eval \ 151 | --reference_path="${OUTPUT_DIR}/detok_overlap_references-multi" \ 152 | --generation_path="${OUTPUT_DIR}/detok_overlap_predictions" \ 153 | --precision_table_path="${OUTPUT_DIR}/detok_overlap_tables_parent_precision_format" \ 154 | --recall_table_path="${OUTPUT_DIR}/detok_overlap_tables_parent_recall_format"\ 155 | --result_path="${OUTPUT_DIR}/parent_overlap.json" 156 | 157 | echo "======== EVALUATE NON-OVERLAP SUBSET ========" 158 | 159 | echo "Computing PARENT (non-overlap subset)" 160 | python3 -m totto_parent_eval \ 161 | --reference_path="${OUTPUT_DIR}/detok_nonoverlap_references-multi" \ 162 | --generation_path="${OUTPUT_DIR}/detok_nonoverlap_predictions" \ 163 | --precision_table_path="${OUTPUT_DIR}/detok_nonoverlap_tables_parent_precision_format" \ 164 | --recall_table_path="${OUTPUT_DIR}/detok_nonoverlap_tables_parent_recall_format"\ 165 | --result_path="${OUTPUT_DIR}/parent_non_overlap.json" -------------------------------------------------------------------------------- /rl4lms/data_pools/task_utils/totto/preprocess_utils.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/google-research/language/blob/master/language/totto/baseline_preprocessing/preprocess_data_main.py 2 | # coding=utf-8 3 | # Copyright 2018 The Google AI Language Team Authors. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | """Baseline preprocessing utilities.""" 17 | import copy 18 | 19 | 20 | def _add_adjusted_col_offsets(table): 21 | """Add adjusted column offsets to take into account multi-column cells.""" 22 | adjusted_table = [] 23 | for row in table: 24 | real_col_index = 0 25 | adjusted_row = [] 26 | for cell in row: 27 | adjusted_cell = copy.deepcopy(cell) 28 | adjusted_cell["adjusted_col_start"] = real_col_index 29 | adjusted_cell["adjusted_col_end"] = ( 30 | adjusted_cell["adjusted_col_start"] + adjusted_cell["column_span"]) 31 | real_col_index += adjusted_cell["column_span"] 32 | adjusted_row.append(adjusted_cell) 33 | adjusted_table.append(adjusted_row) 34 | return adjusted_table 35 | 36 | 37 | def _get_heuristic_row_headers(adjusted_table, row_index, col_index): 38 | """Heuristic to find row headers.""" 39 | row_headers = [] 40 | row = adjusted_table[row_index] 41 | for i in range(0, col_index): 42 | if row[i]["is_header"]: 43 | row_headers.append(row[i]) 44 | return row_headers 45 | 46 | 47 | def _get_heuristic_col_headers(adjusted_table, row_index, col_index): 48 | """Heuristic to find column headers.""" 49 | adjusted_cell = adjusted_table[row_index][col_index] 50 | adjusted_col_start = adjusted_cell["adjusted_col_start"] 51 | adjusted_col_end = adjusted_cell["adjusted_col_end"] 52 | col_headers = [] 53 | for r in range(0, row_index): 54 | row = adjusted_table[r] 55 | for cell in row: 56 | if (cell["adjusted_col_start"] < adjusted_col_end and 57 | cell["adjusted_col_end"] > adjusted_col_start): 58 | if cell["is_header"]: 59 | col_headers.append(cell) 60 | 61 | return col_headers 62 | 63 | 64 | def get_highlighted_subtable(table, cell_indices, with_heuristic_headers=False): 65 | """Extract out the highlighted part of a table.""" 66 | highlighted_table = [] 67 | 68 | adjusted_table = _add_adjusted_col_offsets(table) 69 | 70 | for (row_index, col_index) in cell_indices: 71 | cell = table[row_index][col_index] 72 | if with_heuristic_headers: 73 | row_headers = _get_heuristic_row_headers(adjusted_table, row_index, 74 | col_index) 75 | col_headers = _get_heuristic_col_headers(adjusted_table, row_index, 76 | col_index) 77 | else: 78 | row_headers = [] 79 | col_headers = [] 80 | 81 | highlighted_cell = { 82 | "cell": cell, 83 | "row_headers": row_headers, 84 | "col_headers": col_headers 85 | } 86 | highlighted_table.append(highlighted_cell) 87 | 88 | return highlighted_table 89 | 90 | 91 | def linearize_full_table(table, cell_indices, table_page_title, 92 | table_section_title): 93 | """Linearize full table with localized headers and return a string.""" 94 | table_str = "" 95 | if table_page_title: 96 | table_str += " " + table_page_title + " " 97 | if table_section_title: 98 | table_str += " " + table_section_title + " " 99 | 100 | table_str += " " 101 | adjusted_table = _add_adjusted_col_offsets(table) 102 | for r_index, row in enumerate(table): 103 | row_str = " " 104 | for c_index, col in enumerate(row): 105 | 106 | row_headers = _get_heuristic_row_headers(adjusted_table, r_index, c_index) 107 | col_headers = _get_heuristic_col_headers(adjusted_table, r_index, c_index) 108 | 109 | # Distinguish between highlighted and non-highlighted cells. 110 | if [r_index, c_index] in cell_indices: 111 | start_cell_marker = " " 112 | end_cell_marker = " " 113 | else: 114 | start_cell_marker = " " 115 | end_cell_marker = " " 116 | 117 | # The value of the cell. 118 | item_str = start_cell_marker + col["value"] + " " 119 | 120 | # All the column headers associated with this cell. 
121 | for col_header in col_headers: 122 | item_str += "<col_header> " + col_header["value"] + " </col_header> " 123 | 124 | # All the row headers associated with this cell. 125 | for row_header in row_headers: 126 | item_str += "<row_header> " + row_header["value"] + " </row_header> " 127 | 128 | item_str += end_cell_marker 129 | row_str += item_str 130 | 131 | row_str += "</row> " 132 | table_str += row_str 133 | 134 | table_str += "</table>" 135 | if cell_indices: 136 | assert "<highlighted_cell>" in table_str 137 | return table_str 138 | 139 | 140 | def linearize_subtable(subtable, table_page_title, table_section_title): 141 | """Linearize the highlighted subtable and return a string of its contents.""" 142 | table_str = "" 143 | if table_page_title: 144 | table_str += "<page_title> " + table_page_title + " </page_title> " 145 | if table_section_title: 146 | table_str += "<section_title> " + table_section_title + " </section_title> " 147 | table_str += "<table> " 148 | 149 | for item in subtable: 150 | cell = item["cell"] 151 | row_headers = item["row_headers"] 152 | col_headers = item["col_headers"] 153 | 154 | # The value of the cell. 155 | item_str = "<cell> " + cell["value"] + " " 156 | 157 | # All the column headers associated with this cell. 158 | for col_header in col_headers: 159 | item_str += "<col_header> " + col_header["value"] + " </col_header> " 160 | 161 | # All the row headers associated with this cell. 162 | for row_header in row_headers: 163 | item_str += "<row_header> " + row_header["value"] + " </row_header> " 164 | 165 | item_str += "</cell> " 166 | table_str += item_str 167 | 168 | table_str += "</table>
" 169 | return table_str 170 | -------------------------------------------------------------------------------- /rl4lms/data_pools/text_generation_pool.py: -------------------------------------------------------------------------------- 1 | import random 2 | from abc import abstractclassmethod 3 | from dataclasses import dataclass 4 | from typing import Any, List, Dict 5 | 6 | 7 | @dataclass(init=True) 8 | class Sample: 9 | id: str 10 | prompt_or_input_text: str 11 | references: List[str] 12 | meta_data: Dict[str, Any] = None 13 | 14 | 15 | class TextGenPool: 16 | def __init__(self, samples: List[Sample]): 17 | self._samples = samples 18 | 19 | def __len__(self): 20 | return len(self._samples) 21 | 22 | def __getitem__(self, ix: int) -> Sample: 23 | if ix >= len(self): 24 | raise StopIteration 25 | sample = self._samples[ix] 26 | return sample, 1.0 27 | 28 | def sample(self) -> Sample: 29 | random_sample = random.choice(self._samples) 30 | return random_sample 31 | 32 | @abstractclassmethod 33 | def prepare(cls, **args) -> 'TextGenPool': 34 | """ 35 | A factory method to instantiate data pool 36 | """ 37 | raise NotImplementedError 38 | 39 | def split(self, split_ratios: List[float]) -> List['TextGenPool']: 40 | start_ix = 0 41 | pools = [] 42 | for ratio in split_ratios: 43 | count = int(len(self) * ratio) 44 | end_ix = start_ix + count 45 | pools.append(type(self)(self._samples[start_ix: end_ix])) 46 | start_ix = end_ix 47 | return pools 48 | -------------------------------------------------------------------------------- /rl4lms/envs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/RL4LMs/97df0bd2f7406a906206c9610aea795fbf52884c/rl4lms/envs/__init__.py -------------------------------------------------------------------------------- /rl4lms/envs/common/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/RL4LMs/97df0bd2f7406a906206c9610aea795fbf52884c/rl4lms/envs/common/__init__.py -------------------------------------------------------------------------------- /rl4lms/envs/common/action_space.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | from gym.spaces.discrete import Discrete 3 | 4 | 5 | class ActionSpace(Discrete): 6 | def __init__(self, actions: List[str]): 7 | self.actions = actions 8 | self._ix_to_action = {ix: action for ix, action in enumerate(self.actions)} 9 | self._action_to_ix = {action: ix for ix, action in enumerate(self.actions)} 10 | super().__init__(len(self.actions)) 11 | 12 | def __post_init__(self): 13 | self._ix_to_action = {ix: action for ix, action in enumerate(self.actions)} 14 | self._action_to_ix = {action: ix for ix, action in enumerate(self.actions)} 15 | 16 | def action_to_ix(self, action: str) -> int: 17 | return self._action_to_ix[action] 18 | 19 | def ix_to_action(self, ix: int) -> str: 20 | return self._ix_to_action[ix] 21 | 22 | def size(self) -> int: 23 | return self.n 24 | 25 | def __repr__(self): 26 | return f"Discrete Action Space with {self.size()} actions: {self.actions}" 27 | -------------------------------------------------------------------------------- /rl4lms/envs/common/base_env.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from typing import Tuple, List, Union 3 | from rl4lms.envs.common.observation import BaseObservation, 
BaseObservationFeaturizer 4 | from rl4lms.envs.common.reward import RewardFunction 5 | from rl4lms.data_pools.base import Sample 6 | from rl4lms.envs.common.action_space import ActionSpace 7 | from gym import spaces 8 | import gym 9 | import numpy as np 10 | 11 | 12 | class BaseEnv(gym.Env): 13 | """ 14 | A base class for all the environments 15 | """ 16 | 17 | def __init__(self, max_steps: int, reward_function: RewardFunction, 18 | observation_featurizer: BaseObservationFeaturizer, return_obs_as_vector: bool = True): 19 | """ 20 | Args: 21 | max_steps (int): max steps for each episode 22 | reward_function (RewardFunction): reward function that computes scalar reward for each observation-action 23 | observation_featurizer (ObservationFeaturizer): a featurizer that vectorizes input and context of observation 24 | return_obs_vector (bool): return the observation as vector 25 | """ 26 | self.max_steps = max_steps 27 | self.reward_function = reward_function 28 | self.return_obs_as_vector = return_obs_as_vector 29 | self.set_featurizer(observation_featurizer) 30 | 31 | # Standard gym methods 32 | 33 | @abstractmethod 34 | def step(self, action: int) -> Tuple[Union[BaseObservation, np.array], int, bool, dict]: 35 | """ 36 | Takes a step with the given action and returns (next state, reward, done, info) 37 | """ 38 | raise NotImplementedError 39 | 40 | @abstractmethod 41 | def reset(self, sample: Sample = None) -> Union[BaseObservation, np.array]: 42 | """ 43 | Resets the episode and returns an observation 44 | """ 45 | raise NotImplementedError 46 | 47 | @abstractmethod 48 | def render(self): 49 | """ 50 | Renders the current state of the environment 51 | """ 52 | raise NotImplementedError 53 | 54 | @abstractmethod 55 | def close(self): 56 | raise NotImplementedError 57 | 58 | # Methods related to observation and action space infos 59 | 60 | def get_observation_dim(self) -> int: 61 | """ 62 | Gets the observation dimension 63 | """ 64 | return self.observation_featurizer.get_observation_dim() 65 | 66 | def get_action_space(self) -> ActionSpace: 67 | """ 68 | Lists all possible actions indices and its meaning 69 | 70 | Returns: 71 | ActionSpace -- an instance of action space 72 | """ 73 | return self.action_space 74 | 75 | # Additional methods for online learning and sampling 76 | 77 | @abstractmethod 78 | def add_sample(self, sample: Sample): 79 | """ 80 | Adds annotated sample for sampling/replaying 81 | """ 82 | raise NotImplementedError 83 | 84 | def get_samples(self) -> List[Sample]: 85 | """ 86 | Returns list of samples available in the environment 87 | 88 | Returns: 89 | List[Sample]: list of samples in the environment 90 | """ 91 | raise NotImplementedError 92 | 93 | def set_featurizer(self, observation_featurizer: BaseObservationFeaturizer): 94 | """ 95 | Sets the observation featurizer (can also change during run time) 96 | """ 97 | self.observation_featurizer = observation_featurizer 98 | if observation_featurizer is not None: 99 | self._set_spaces(observation_featurizer) 100 | 101 | def _set_spaces(self, observation_featurizer: BaseObservationFeaturizer): 102 | low = np.full(shape=(observation_featurizer.get_observation_dim(),), 103 | fill_value=-float('inf'), dtype=np.float32) 104 | high = np.full(shape=(observation_featurizer.get_observation_dim( 105 | ),), fill_value=float('inf'), dtype=np.float32) 106 | self.observation_space = spaces.Box(low, high, dtype=np.float32) 107 | -------------------------------------------------------------------------------- 
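The featurizer passed to BaseEnv drives the observation space: _set_spaces above builds an unbounded gym.spaces.Box whose size comes from get_observation_dim(). A minimal sketch of a conforming featurizer is shown below (FixedSizeFeaturizer and its dimension are made up for illustration); it implements the BaseObservationFeaturizer interface defined in observation.py, which follows.

import torch

from rl4lms.envs.common.observation import BaseObservation, BaseObservationFeaturizer


class FixedSizeFeaturizer(BaseObservationFeaturizer):
    """Toy featurizer that maps every observation to a fixed-size vector."""

    def __init__(self, dim: int = 16):
        self._dim = dim

    def featurize(self, observation: BaseObservation) -> torch.Tensor:
        # A real featurizer would encode the observation's input and context here.
        return torch.zeros(self._dim)

    def get_observation_dim(self) -> int:
        # Overrides the default get_input_dim() + get_context_dim() composition.
        return self._dim

Passing an instance of such a featurizer to a BaseEnv subclass would yield an observation space of Box(-inf, inf, shape=(16,)).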
/rl4lms/envs/common/observation.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from abc import ABC, abstractmethod 3 | import torch 4 | 5 | 6 | @dataclass 7 | class BaseObservation: 8 | """ 9 | Placeholder for observation data class 10 | """ 11 | pass 12 | 13 | 14 | class BaseObservationFeaturizer(ABC): 15 | 16 | @abstractmethod 17 | def featurize(self, observation: BaseObservation) -> torch.Tensor: 18 | raise NotImplementedError 19 | 20 | def get_observation_dim(self) -> int: 21 | """ 22 | Returns the observation dim 23 | """ 24 | return self.get_input_dim() + self.get_context_dim() -------------------------------------------------------------------------------- /rl4lms/envs/common/reward.py: -------------------------------------------------------------------------------- 1 | 2 | from rl4lms.envs.common.observation import BaseObservation 3 | from abc import ABC, abstractclassmethod 4 | from typing import List 5 | 6 | 7 | class RewardFunction(ABC): 8 | @abstractclassmethod 9 | def __call__(self, observation: BaseObservation, action: str, targets: List[str]) -> float: 10 | """[summary] 11 | 12 | Args: 13 | observation (Observation): current observation at t 14 | action (str): current action at t 15 | targets (List[str]): targets of the current sample 16 | 17 | Returns: 18 | - a scalar reward 19 | """ 20 | raise NotImplementedError 21 | -------------------------------------------------------------------------------- /rl4lms/envs/text_generation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/RL4LMs/97df0bd2f7406a906206c9610aea795fbf52884c/rl4lms/envs/text_generation/__init__.py -------------------------------------------------------------------------------- /rl4lms/envs/text_generation/caption_metrics/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/RL4LMs/97df0bd2f7406a906206c9610aea795fbf52884c/rl4lms/envs/text_generation/caption_metrics/__init__.py -------------------------------------------------------------------------------- /rl4lms/envs/text_generation/caption_metrics/spice/Readme.txt: -------------------------------------------------------------------------------- 1 | 2 | This zip file contains the pre-built SPICE-1.0.jar and all libraries required to run it, except for Stanford CoreNLP. 3 | 4 | Run $ ./get_stanford_models.sh or otherwise download the CoreNLP 3.6.0 code and models jar files into /lib. 5 | 6 | Instructions for using SPICE are found in spice-1.0.jar/README.md 7 | 8 | 9 | References: 10 | https://panderson.me/spice/ -------------------------------------------------------------------------------- /rl4lms/envs/text_generation/caption_metrics/spice/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/RL4LMs/97df0bd2f7406a906206c9610aea795fbf52884c/rl4lms/envs/text_generation/caption_metrics/spice/__init__.py -------------------------------------------------------------------------------- /rl4lms/envs/text_generation/caption_metrics/spice/get_stanford_models.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env sh 2 | # This script downloads the Stanford CoreNLP models. 
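# (Note: the script needs wget and unzip on PATH, and can be run from any
#  working directory because it first cd's into its own folder - see DIR below.)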
3 | 4 | # Adapted from: 5 | # https://panderson.me/spice/ 6 | 7 | CORENLP=stanford-corenlp-full-2015-12-09 8 | SPICELIB=lib 9 | 10 | DIR="$( cd "$(dirname "$0")" ; pwd -P )" 11 | cd $DIR 12 | 13 | echo $(pwd) 14 | echo "$(dirname "$0")" 15 | 16 | echo "Downloading..." 17 | 18 | wget http://nlp.stanford.edu/software/$CORENLP.zip 19 | 20 | echo "Unzipping..." 21 | 22 | unzip $CORENLP.zip -d $SPICELIB/ 23 | mv $SPICELIB/$CORENLP/stanford-corenlp-3.6.0.jar $SPICELIB/ 24 | mv $SPICELIB/$CORENLP/stanford-corenlp-3.6.0-models.jar $SPICELIB/ 25 | rm -f stanford-corenlp-full-2015-12-09.zip 26 | 27 | echo "Done." 28 | -------------------------------------------------------------------------------- /rl4lms/envs/text_generation/caption_metrics/spice/lib/Meteor-1.5.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/RL4LMs/97df0bd2f7406a906206c9610aea795fbf52884c/rl4lms/envs/text_generation/caption_metrics/spice/lib/Meteor-1.5.jar -------------------------------------------------------------------------------- /rl4lms/envs/text_generation/caption_metrics/spice/lib/SceneGraphParser-1.0.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/RL4LMs/97df0bd2f7406a906206c9610aea795fbf52884c/rl4lms/envs/text_generation/caption_metrics/spice/lib/SceneGraphParser-1.0.jar -------------------------------------------------------------------------------- /rl4lms/envs/text_generation/caption_metrics/spice/lib/ejml-0.23.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/RL4LMs/97df0bd2f7406a906206c9610aea795fbf52884c/rl4lms/envs/text_generation/caption_metrics/spice/lib/ejml-0.23.jar -------------------------------------------------------------------------------- /rl4lms/envs/text_generation/caption_metrics/spice/lib/fst-2.47.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/RL4LMs/97df0bd2f7406a906206c9610aea795fbf52884c/rl4lms/envs/text_generation/caption_metrics/spice/lib/fst-2.47.jar -------------------------------------------------------------------------------- /rl4lms/envs/text_generation/caption_metrics/spice/lib/guava-19.0.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/RL4LMs/97df0bd2f7406a906206c9610aea795fbf52884c/rl4lms/envs/text_generation/caption_metrics/spice/lib/guava-19.0.jar -------------------------------------------------------------------------------- /rl4lms/envs/text_generation/caption_metrics/spice/lib/hamcrest-core-1.3.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/RL4LMs/97df0bd2f7406a906206c9610aea795fbf52884c/rl4lms/envs/text_generation/caption_metrics/spice/lib/hamcrest-core-1.3.jar -------------------------------------------------------------------------------- /rl4lms/envs/text_generation/caption_metrics/spice/lib/jackson-core-2.5.3.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/RL4LMs/97df0bd2f7406a906206c9610aea795fbf52884c/rl4lms/envs/text_generation/caption_metrics/spice/lib/jackson-core-2.5.3.jar -------------------------------------------------------------------------------- 
/rl4lms/envs/text_generation/caption_metrics/spice/lib/javassist-3.19.0-GA.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/RL4LMs/97df0bd2f7406a906206c9610aea795fbf52884c/rl4lms/envs/text_generation/caption_metrics/spice/lib/javassist-3.19.0-GA.jar -------------------------------------------------------------------------------- /rl4lms/envs/text_generation/caption_metrics/spice/lib/json-simple-1.1.1.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/RL4LMs/97df0bd2f7406a906206c9610aea795fbf52884c/rl4lms/envs/text_generation/caption_metrics/spice/lib/json-simple-1.1.1.jar -------------------------------------------------------------------------------- /rl4lms/envs/text_generation/caption_metrics/spice/lib/junit-4.12.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/RL4LMs/97df0bd2f7406a906206c9610aea795fbf52884c/rl4lms/envs/text_generation/caption_metrics/spice/lib/junit-4.12.jar -------------------------------------------------------------------------------- /rl4lms/envs/text_generation/caption_metrics/spice/lib/lmdbjni-0.4.6.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/RL4LMs/97df0bd2f7406a906206c9610aea795fbf52884c/rl4lms/envs/text_generation/caption_metrics/spice/lib/lmdbjni-0.4.6.jar -------------------------------------------------------------------------------- /rl4lms/envs/text_generation/caption_metrics/spice/lib/lmdbjni-linux64-0.4.6.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/RL4LMs/97df0bd2f7406a906206c9610aea795fbf52884c/rl4lms/envs/text_generation/caption_metrics/spice/lib/lmdbjni-linux64-0.4.6.jar -------------------------------------------------------------------------------- /rl4lms/envs/text_generation/caption_metrics/spice/lib/lmdbjni-osx64-0.4.6.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/RL4LMs/97df0bd2f7406a906206c9610aea795fbf52884c/rl4lms/envs/text_generation/caption_metrics/spice/lib/lmdbjni-osx64-0.4.6.jar -------------------------------------------------------------------------------- /rl4lms/envs/text_generation/caption_metrics/spice/lib/lmdbjni-win64-0.4.6.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/RL4LMs/97df0bd2f7406a906206c9610aea795fbf52884c/rl4lms/envs/text_generation/caption_metrics/spice/lib/lmdbjni-win64-0.4.6.jar -------------------------------------------------------------------------------- /rl4lms/envs/text_generation/caption_metrics/spice/lib/objenesis-2.4.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/RL4LMs/97df0bd2f7406a906206c9610aea795fbf52884c/rl4lms/envs/text_generation/caption_metrics/spice/lib/objenesis-2.4.jar -------------------------------------------------------------------------------- /rl4lms/envs/text_generation/caption_metrics/spice/lib/slf4j-api-1.7.12.jar: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/allenai/RL4LMs/97df0bd2f7406a906206c9610aea795fbf52884c/rl4lms/envs/text_generation/caption_metrics/spice/lib/slf4j-api-1.7.12.jar -------------------------------------------------------------------------------- /rl4lms/envs/text_generation/caption_metrics/spice/lib/slf4j-simple-1.7.21.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/RL4LMs/97df0bd2f7406a906206c9610aea795fbf52884c/rl4lms/envs/text_generation/caption_metrics/spice/lib/slf4j-simple-1.7.21.jar -------------------------------------------------------------------------------- /rl4lms/envs/text_generation/caption_metrics/spice/spice-1.0.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/RL4LMs/97df0bd2f7406a906206c9610aea795fbf52884c/rl4lms/envs/text_generation/caption_metrics/spice/spice-1.0.jar -------------------------------------------------------------------------------- /rl4lms/envs/text_generation/caption_metrics/spice/spice.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adapted from https://github.com/INK-USC/CommonGen/tree/master/evaluation/Traditional/eval_metrics/spice 3 | """ 4 | 5 | from __future__ import division 6 | import os 7 | import subprocess 8 | import json 9 | import numpy as np 10 | import tempfile 11 | import spacy 12 | 13 | # Assumes spice.jar is in the same directory as spice.py. Change as needed. 14 | SPICE_JAR = 'spice-1.0.jar' 15 | TEMP_DIR = 'tmp' 16 | CACHE_DIR = 'cache' 17 | 18 | 19 | class Spice: 20 | """ 21 | Main Class to compute the SPICE metric 22 | """ 23 | 24 | def __init__(self) -> None: 25 | self._nlp = spacy.load("en_core_web_sm") 26 | # keep only tagger 27 | for pipe in ["tok2vec", "parser", "ner", "attribute_ruler", "lemmatizer"]: 28 | self._nlp.remove_pipe(pipe) 29 | 30 | def float_convert(self, obj): 31 | try: 32 | return float(obj) 33 | except: 34 | return np.nan 35 | 36 | def tokenize(self, dict): 37 | for key in dict: 38 | new_sentence_list = [] 39 | for sentence in dict[key]: 40 | a = '' 41 | for token in self._nlp(str(sentence)): 42 | a += token.text 43 | a += ' ' 44 | new_sentence_list.append(a.rstrip()) 45 | dict[key] = new_sentence_list 46 | 47 | return dict 48 | 49 | def compute_score(self, gts, res): 50 | 51 | # tokenize 52 | gts = self.tokenize(gts) 53 | res = self.tokenize(res) 54 | 55 | assert(sorted(gts.keys()) == sorted(res.keys())) 56 | imgIds = sorted(gts.keys()) 57 | 58 | # Prepare temp input file for the SPICE scorer 59 | input_data = [] 60 | for id in imgIds: 61 | hypo = res[id] 62 | ref = gts[id] 63 | 64 | # Sanity check. 
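# (Expected input format, enforced by the asserts below: res[id] holds exactly
#  one generated caption and gts[id] holds one or more reference captions.)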
65 | assert(type(hypo) is list) 66 | assert(len(hypo) == 1) 67 | assert(type(ref) is list) 68 | assert(len(ref) >= 1) 69 | 70 | input_data.append({ 71 | "image_id": id, 72 | "test": hypo[0], 73 | "refs": ref 74 | }) 75 | 76 | cwd = os.path.dirname(os.path.abspath(__file__)) 77 | temp_dir = os.path.join(cwd, TEMP_DIR) 78 | if not os.path.exists(temp_dir): 79 | os.makedirs(temp_dir) 80 | in_file = tempfile.NamedTemporaryFile( 81 | mode="w", delete=False, dir=temp_dir) 82 | json.dump(input_data, in_file, indent=2) 83 | in_file.close() 84 | 85 | # Start job 86 | out_file = tempfile.NamedTemporaryFile(delete=False, dir=temp_dir) 87 | out_file.close() 88 | cache_dir = os.path.join(cwd, CACHE_DIR) 89 | if not os.path.exists(cache_dir): 90 | os.makedirs(cache_dir) 91 | spice_cmd = ['java', '-jar', '-Xmx8G', SPICE_JAR, in_file.name, 92 | '-cache', cache_dir, 93 | '-out', out_file.name, 94 | '-subset', 95 | '-silent' 96 | ] 97 | subprocess.check_call(spice_cmd, 98 | cwd=os.path.dirname(os.path.abspath(__file__))) 99 | 100 | # Read and process results 101 | with open(out_file.name) as data_file: 102 | results = json.load(data_file) 103 | os.remove(in_file.name) 104 | os.remove(out_file.name) 105 | 106 | imgId_to_scores = {} 107 | spice_scores = [] 108 | individual_scores = {} 109 | for item in results: 110 | imgId_to_scores[item['image_id']] = item['scores'] 111 | spice_scores.append(self.float_convert(item['scores']['All']['f'])) 112 | individual_scores[item['image_id']] = self.float_convert( 113 | item['scores']['All']['f']) 114 | average_score = np.mean(np.array(spice_scores)) 115 | scores = [] 116 | for image_id in imgIds: 117 | # Convert none to NaN before saving scores over subcategories 118 | score_set = {} 119 | for category, score_tuple in imgId_to_scores[image_id].items(): 120 | score_set[category] = {k: self.float_convert( 121 | v) for k, v in score_tuple.items()} 122 | scores.append(score_set) 123 | 124 | return average_score, individual_scores 125 | 126 | def method(self): 127 | return "SPICE" 128 | 129 | 130 | if __name__ == "__main__": 131 | gts = {"cat#dog#boy": ["The dog is the boy's cat.", "The dog eats the cat of the boy."], 132 | "apple#tree#boy": ["A boy is picking apples from trees."]} 133 | res = {"cat#dog#boy": ["The dog is the boy's cat."], 134 | "apple#tree#boy": ["A boy is picking apples from trees and put them into bags."]} 135 | 136 | metric = Spice() 137 | print(metric.compute_score(gts, res)) 138 | -------------------------------------------------------------------------------- /rl4lms/envs/text_generation/env.py: -------------------------------------------------------------------------------- 1 | from cmath import inf 2 | from typing import Dict, Tuple, Optional, List 3 | 4 | import torch 5 | from gym import Env, spaces 6 | from gym.spaces.dict import Dict as DictSpace 7 | from gym.spaces.discrete import Discrete 8 | from rl4lms.data_pools.text_generation_pool import Sample 9 | from rl4lms.envs.text_generation.reward import BatchedRewardFunction, RewardFunction 10 | from rl4lms.envs.text_generation.observation import Observation 11 | from transformers import AutoTokenizer 12 | from rl4lms.core_components.sampler import PrioritySampler 13 | 14 | 15 | class TextGenEnv(Env): 16 | def __init__( 17 | self, 18 | tokenizer: AutoTokenizer, 19 | reward_function: RewardFunction, 20 | samples: Tuple[List[Sample], float], 21 | max_episode_length: int = 512, 22 | priority_scale: float = 0.0, 23 | max_prompt_length: Optional[int] = None, 24 | terminate_on_eos: bool = False, 25 | 
context_start_token: Optional[int] = None, 26 | prompt_truncation_side: str = "left", 27 | ): 28 | """ 29 | A generic RL environment to generate textual sequences. 30 | For eg: text generation, summarization, machine translation, text simplification 31 | Args: 32 | tokenizer (AutoTokenizer): pre-trained tokenizer 33 | reward_function (RewardFunction): reward functiom 34 | samples (Tuple[List[Sample], float]): list of samples 35 | max_episode_length (int, optional): Max steps to the model Defaults to 512. 36 | priority_scale (float, optional): weight for the priority sampler Defaults to 0.0. 37 | max_prompt_length (Optional[int], optional): maximum prompt length. Defaults to None. 38 | terminate_on_eos (bool, optional): whether to terminate on EOS. Defaults to False. 39 | context_start_token (bool, optional): start token for the context (For Encoder-Decoder models! ) 40 | prompt_truncation_side (str): truncation side for prompt text (Defaults to "left") 41 | """ 42 | self.tokenizer = tokenizer 43 | self.reward_function = reward_function 44 | self.max_steps = max_episode_length 45 | self._max_text_length = ( 46 | max_prompt_length if max_prompt_length else tokenizer.model_max_length 47 | ) 48 | self._terminate_on_eos = terminate_on_eos 49 | self._context_start_token = context_start_token 50 | self._prompt_truncation_side = prompt_truncation_side 51 | super().__init__() 52 | 53 | # set the observation and action space here 54 | self._vocab_size = tokenizer.vocab_size 55 | self.observation_space = DictSpace( 56 | { 57 | # we have to provide fixed sized inputs (padded) because sb3 support for DictObsersevation is limited 58 | # while creating rollout buffers, observations are concatenated for each key 59 | "prompt_or_input_encoded_pt": spaces.Box( 60 | low=0, high=self._vocab_size, shape=(self._max_text_length,) 61 | ), 62 | "prompt_or_input_attention_mask_pt": spaces.Box( 63 | low=0, high=1, shape=(self._max_text_length,) 64 | ), 65 | "context_encoded_pt": spaces.Box( 66 | low=0, high=self._vocab_size, shape=(self.max_steps,) 67 | ), 68 | "context_attention_mask_pt": spaces.Box( 69 | low=0, high=1, shape=(self.max_steps,) 70 | ), 71 | "input_encoded_pt": spaces.Box( 72 | low=0, 73 | high=self._vocab_size, 74 | shape=(self._max_text_length + self.max_steps,), 75 | ), 76 | "input_attention_mask_pt": spaces.Box( 77 | low=0, high=1, shape=(self._max_text_length + self.max_steps,) 78 | ), 79 | } 80 | ) 81 | self.action_space = Discrete(n=self._vocab_size) 82 | # see https://github.com/huggingface/transformers/issues/4875 : rounding up to nearest power of 2 for better GPU efficiency 83 | if 'mt5' in self.tokenizer.name_or_path: 84 | n = 250112 85 | self.action_space = Discrete(n=n) 86 | elif 't5' in self.tokenizer.name_or_path: 87 | n = 32128 88 | self.action_space = Discrete(n=n) 89 | self.sampler_for_replaying = PrioritySampler(priority_scale=priority_scale) 90 | for sample, weight in samples: 91 | self.sampler_for_replaying.add(sample, weight) 92 | 93 | # check the tokenizer and add padding tokens 94 | if self.tokenizer.pad_token is None: 95 | self.tokenizer.pad_token = self.tokenizer.eos_token 96 | self.tokenizer.padding_side = "left" # TBD: configure this 97 | self.tokenizer.truncation_side = "left" # TBD: configure this 98 | 99 | # init tracking variables 100 | self.__current_sample = None 101 | self.__current_obs = None 102 | self.__time_step = None 103 | 104 | def step(self, action: int) -> Tuple[Dict[str, torch.tensor], int, bool, dict]: 105 | self.__time_step += 1 106 | 107 | # previous 
obs 108 | previous_obs = self.__current_obs 109 | 110 | # just update the context tensor and gets the new observation 111 | self.__current_obs = self.__current_obs.update(action, self.tokenizer) 112 | 113 | # decide if the episode is finished or not 114 | done = (action == self.tokenizer.eos_token_id and self._terminate_on_eos) or ( 115 | self.__time_step == self.max_steps 116 | ) 117 | 118 | # compute reward 119 | if not isinstance(self.reward_function, BatchedRewardFunction): 120 | reward = ( 121 | None 122 | if self.reward_function is None 123 | else self.reward_function( 124 | previous_obs, 125 | action, 126 | self.__current_obs, 127 | done, 128 | self.__current_obs.meta_info, 129 | ) 130 | ) 131 | else: 132 | reward = -inf # will be overridden later 133 | 134 | # populate additional info 135 | info = { 136 | "output": self.__current_obs.context_text, 137 | "action_history": self.__current_obs.action_history, 138 | "reference_text": self.__current_obs.target_or_reference_texts, 139 | "prompt_text": self.__current_obs.prompt_or_input_text, 140 | "prev_output": previous_obs.context_text, 141 | "meta_info": previous_obs.meta_info, 142 | } 143 | 144 | return self.__current_obs.to_dict(), reward, done, info 145 | 146 | def reset(self, sample: Sample = None) -> Dict[str, torch.tensor]: 147 | """ 148 | Resets the environment and starts a new episode 149 | """ 150 | # gets a new sample if not provided 151 | if sample is None: 152 | sample = self.sampler_for_replaying.sample(size=1)[0] 153 | self.__current_sample = sample 154 | 155 | # init the observation 156 | self.__current_obs = Observation.init_from_sample( 157 | sample, 158 | self.tokenizer, 159 | self._max_text_length, 160 | self.max_steps, 161 | self._prompt_truncation_side, 162 | self._context_start_token, 163 | sample.meta_data, 164 | ) 165 | 166 | # start the time step counter 167 | self.__time_step = 0 168 | 169 | dict_observation = self.__current_obs.to_dict() 170 | return dict_observation 171 | 172 | def render(self): 173 | pass 174 | 175 | def close(self): 176 | pass 177 | 178 | def add_sample(self, sample: Sample, weight: int = 1.0): 179 | self.sampler_for_replaying.add(sample, weight) 180 | -------------------------------------------------------------------------------- /rl4lms/envs/text_generation/evaluation_utils.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List 2 | 3 | from stable_baselines3.common.policies import BasePolicy 4 | from tqdm import tqdm 5 | from transformers import AutoTokenizer 6 | 7 | from rl4lms.data_pools.custom_text_generation_pools import Sample 8 | from rl4lms.envs.text_generation.logging_utils import Tracker 9 | from rl4lms.envs.text_generation.metric import BaseMetric 10 | 11 | 12 | def get_batch(samples: List[Sample], batch_size: int): 13 | current_ix = 0 14 | n_samples = len(samples) 15 | while current_ix < n_samples: 16 | current_batch = samples[current_ix : current_ix + batch_size] 17 | yield current_batch 18 | current_ix += batch_size 19 | 20 | 21 | def evaluate_on_samples( 22 | policy: BasePolicy, 23 | tokenizer: AutoTokenizer, 24 | samples: List[Sample], 25 | batch_size: int, 26 | max_prompt_length: int, 27 | metrics: List[BaseMetric], 28 | epoch: int, 29 | split_name: str, 30 | tracker: Tracker = None, 31 | dt_control_token: str = "", 32 | gen_kwargs: Dict[str, Any] = None, 33 | ): 34 | # generate text by batch 35 | all_generated_texts = [] 36 | all_ref_texts = [] 37 | all_prompt_texts = [] 38 | all_meta_infos = [] 39 | 
n_samples = len(samples) 40 | for batch in tqdm(list(get_batch(samples, batch_size)), desc="Evaluating"): 41 | batch_generated_texts = generate_text( 42 | policy, tokenizer, batch, max_prompt_length, dt_control_token, gen_kwargs 43 | ) 44 | batch_ref_texts = [sample.references for sample in batch] 45 | batch_prompt_texts = [sample.prompt_or_input_text for sample in batch] 46 | batch_meta_infos = [sample.meta_data for sample in batch] 47 | all_generated_texts.extend(batch_generated_texts) 48 | all_ref_texts.extend(batch_ref_texts) 49 | all_prompt_texts.extend(batch_prompt_texts) 50 | all_meta_infos.extend(batch_meta_infos) 51 | 52 | # compute metrics 53 | corpus_level_metrics = {} 54 | sample_scores_by_metric = {} 55 | if metrics is not None: 56 | for metric in metrics: 57 | metric_dict = metric.compute( 58 | all_prompt_texts, 59 | all_generated_texts, 60 | all_ref_texts, 61 | all_meta_infos, 62 | policy.get_language_model(), 63 | split_name, 64 | ) 65 | 66 | for metric_key, (sample_scores, corpus_score) in metric_dict.items(): 67 | if sample_scores is None: 68 | sample_scores = ["n/a"] * n_samples 69 | corpus_level_metrics[metric_key] = corpus_score 70 | sample_scores_by_metric[metric_key] = sample_scores 71 | 72 | # aggregate sample metric scores 73 | sample_predictions_dict = [] 74 | for ix, (sample, prompt_text, generated_text, ref_texts) in enumerate( 75 | zip(samples, all_prompt_texts, all_generated_texts, all_ref_texts) 76 | ): 77 | sample_prediction = { 78 | "split_name": split_name, 79 | "sample_id": sample.id, 80 | "prompt_text": prompt_text, 81 | "generated_text": generated_text, 82 | "ref_text": "".join( 83 | [ 84 | f"" + ref_text + f"" 85 | for ref_ix, ref_text in enumerate(ref_texts) 86 | ] 87 | ), 88 | } 89 | for metric_key, sample_scores in sample_scores_by_metric.items(): 90 | sample_prediction[metric_key] = sample_scores[ix] 91 | sample_predictions_dict.append(sample_prediction) 92 | 93 | if tracker is not None: 94 | # log the entire predictions 95 | tracker.log_predictions(epoch, split_name, sample_predictions_dict) 96 | # log the corpus level scores 97 | tracker.log_metrics(epoch, split_name, corpus_level_metrics) 98 | 99 | 100 | def generate_text( 101 | policy: BasePolicy, 102 | tokenizer: AutoTokenizer, 103 | samples: List[Sample], 104 | max_prompt_length: int, 105 | dt_control_token: str, 106 | gen_kwargs: Dict[str, Any], 107 | ): 108 | prompt_texts = [ 109 | dt_control_token + sample.prompt_or_input_text for sample in samples 110 | ] 111 | generated_texts = policy.generate( 112 | tokenizer, prompt_texts, max_prompt_length, gen_kwargs=gen_kwargs 113 | ).gen_texts 114 | return generated_texts 115 | -------------------------------------------------------------------------------- /rl4lms/envs/text_generation/kl_controllers.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Dict, Any 2 | import torch 3 | 4 | 5 | class KLController: 6 | def __init__(self, kl_coeff: float, target_kl: Optional[float] = None) -> None: 7 | self._kl_coeff = kl_coeff 8 | self._target_kl = target_kl 9 | 10 | def step(self, kl_div: torch.tensor): 11 | """ 12 | Adapts the KL coeff 13 | """ 14 | if self._target_kl is not None: 15 | diff_to_target = (kl_div - self._target_kl) / self._target_kl 16 | e_t = torch.clip(diff_to_target, -0.2, 0.2).item() 17 | self._kl_coeff = self._kl_coeff * (1 + 0.1 * e_t) 18 | 19 | @property 20 | def kl_coeff(self): 21 | return self._kl_coeff 22 | 23 | def get_state_dict(self) -> Dict[str, Any]: 24 | 
state = { 25 | "target_kl": self._target_kl, 26 | "current_kl_coeff": self._kl_coeff 27 | } 28 | return state 29 | 30 | def load_from_state_dict(self, state_dict: Dict[str, Any]): 31 | self._kl_coeff = state_dict["current_kl_coeff"] 32 | self._target_kl = state_dict["target_kl"] 33 | 34 | 35 | if __name__ == "__main__": 36 | contr = KLController(kl_coeff=0.1, target_kl=0.1) 37 | 38 | contr.step(torch.tensor(-0.2)) 39 | print(contr.kl_coeff) 40 | 41 | contr.step(torch.tensor(0.3)) 42 | print(contr.kl_coeff) 43 | 44 | contr.step(torch.tensor(0.4)) 45 | print(contr.kl_coeff) 46 | 47 | state_dict = contr.get_state_dict() 48 | print(state_dict) 49 | 50 | contr._target_kl = None 51 | contr._kl_coeff = None 52 | contr.load_from_state_dict(state_dict) 53 | assert contr._target_kl == state_dict["target_kl"] 54 | assert contr._kl_coeff == state_dict["current_kl_coeff"] 55 | -------------------------------------------------------------------------------- /rl4lms/envs/text_generation/logging_utils.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from typing import Dict, Any, List 3 | import os 4 | import json 5 | import jsonlines 6 | import wandb 7 | import pandas as pd 8 | from transformers import AutoModel 9 | import logging 10 | import copy 11 | import random 12 | from rich.logging import RichHandler 13 | 14 | 15 | class Tracker: 16 | def __init__(self, 17 | base_path_to_store_results: str, 18 | run_config: Dict[str, Any], 19 | project_name: str, 20 | experiment_name: str, 21 | entity_name: str = None, 22 | wandb_log: bool = False, 23 | log_level: int = logging.DEBUG, 24 | ): 25 | self._log_level = log_level 26 | self._base_path_to_store_results = base_path_to_store_results 27 | self._config = run_config 28 | self._experiment_name = experiment_name 29 | self._project_name = project_name 30 | self._entity_name = entity_name 31 | self._wandb_log = wandb_log 32 | self._init() 33 | 34 | def _init(self): 35 | # create a folder 36 | self._run_path = os.path.join( 37 | self._base_path_to_store_results, 38 | self._project_name, 39 | self._experiment_name) 40 | os.makedirs(self._run_path, exist_ok=True) 41 | 42 | # store also the config into it 43 | config_path = os.path.join(self._run_path, "config.json") 44 | with open(config_path, "w") as fp: 45 | json.dump(self._config, fp) 46 | 47 | # init logger 48 | log_path = os.path.join(self._run_path, "log.txt") 49 | logging.basicConfig( 50 | level=self._log_level, 51 | format="%(asctime)s [%(levelname)s] %(message)s", 52 | handlers=[ 53 | logging.FileHandler(log_path), 54 | RichHandler() 55 | ] 56 | ) 57 | 58 | # init wandb 59 | if self._wandb_log: 60 | self._wandb_run = wandb.init( 61 | entity=self._entity_name, 62 | project=self._project_name, 63 | name=self._experiment_name, 64 | config=self._config 65 | ) 66 | 67 | def log_predictions(self, epoch: int, 68 | split_name: str, 69 | predictions: List[Dict]): 70 | # log them per epoch in a separate file as they can get huge 71 | prediction_file_at_epoch = os.path.join( 72 | self._run_path, f"epoch_{epoch}_{split_name}_split_predictions.json") 73 | with open(prediction_file_at_epoch, "w") as fp: 74 | json.dump(predictions, fp) 75 | 76 | # randomly display few predictions for logging 77 | predictions_ = copy.deepcopy(predictions) 78 | random.shuffle(predictions_) 79 | logging.info(f"Split {split_name} predictions") 80 | for pred in predictions_[:10]: 81 | logging.info(pred) 82 | 83 | # for wandb logging, we create a table consisting of 
predictions 84 | # we can create one table per split per epoch 85 | if self._wandb_log and len(predictions) > 0: 86 | 87 | def to_df(predictions): 88 | columns = predictions[0].keys() 89 | data_by_column = defaultdict(list) 90 | for item in predictions: 91 | for column in columns: 92 | data_by_column[column].append(item.get(column, "")) 93 | data_df = pd.DataFrame(data_by_column) 94 | return data_df 95 | 96 | predictions_as_df = to_df(predictions) 97 | predictions_table_at_epoch = wandb.Table(data=predictions_as_df) 98 | self._wandb_run.log({ 99 | f"{split_name}_predictions_at_epoch_{epoch}": predictions_table_at_epoch}) 100 | 101 | def log_metrics(self, epoch: int, 102 | split_name: str, 103 | metrics_dict: Dict[str, float]): 104 | # for each split, one file 105 | metric_file_per_split = os.path.join( 106 | self._run_path, f"{split_name}_split_metrics.jsonl") 107 | metrics_dict_ = { 108 | "epoch": epoch, 109 | "metrics": metrics_dict 110 | } 111 | with jsonlines.open(metric_file_per_split, "a") as writer: 112 | writer.write(metrics_dict_) 113 | 114 | # log to wandb 115 | if self._wandb_log: 116 | metric_dict_ = { 117 | f"{split_name}/{metric_key}": value for metric_key, value in metrics_dict.items()} 118 | metric_dict_["epoch"] = epoch 119 | wandb.log(metric_dict_) 120 | 121 | # logger 122 | logging.info(f"{split_name} metrics: {metrics_dict_}") 123 | 124 | def log_rollout_infos(self, rollout_info: Dict[str, float]): 125 | logging.info(f"Rollout Info: {rollout_info}") 126 | rollout_info_file = os.path.join( 127 | self._run_path, "rollout_info.jsonl") 128 | with jsonlines.open(rollout_info_file, mode="a") as writer: 129 | writer.write(rollout_info) 130 | 131 | # log to wandb 132 | if self._wandb_log: 133 | wandb.log(rollout_info) 134 | 135 | def log_training_infos(self, training_info: Dict[str, float]): 136 | logging.info(f"Training Info: {training_info}") 137 | training_info_file = os.path.join( 138 | self._run_path, "training_info.jsonl") 139 | with jsonlines.open(training_info_file, mode="a") as writer: 140 | writer.write(training_info) 141 | 142 | # log to wandb 143 | if self._wandb_log: 144 | wandb.log(training_info) 145 | 146 | def done(self): 147 | if self._wandb_log: 148 | wandb.finish() 149 | 150 | def save_auto_model(self, model: AutoModel): 151 | model_path = os.path.join(self._run_path, "model") 152 | model.save_pretrained(model_path) 153 | 154 | @property 155 | def checkpoint_base_path(self): 156 | return os.path.join(self._run_path, "checkpoints") 157 | 158 | def log_info(self, msg: str): 159 | logging.info(msg) 160 | 161 | 162 | if __name__ == "__main__": 163 | base_path = "/scratch/test_logs" 164 | run_config = { 165 | "param_1": 1, 166 | "param_2": 2 167 | } 168 | predictions = { 169 | "1": [{"sample_id": "1", "prompt_text": "Hello", "gen_text": "I am there"}, 170 | {"sample_id": "2", "prompt_text": "Hi", "gen_text": "there"}], 171 | "2": [{"sample_id": "1", "prompt_text": "Hello", "gen_text": "I am there"}, 172 | {"sample_id": "2", "prompt_text": "Hi", "gen_text": "there"}], 173 | "3": [{"sample_id": "1", "prompt_text": "Hello", "gen_text": "I am there"}, 174 | {"sample_id": "2", "prompt_text": "Hi", "gen_text": "there"}], 175 | } 176 | 177 | metrics = { 178 | "1": {"metric_1": 0.05, "metric_2": 0.1}, 179 | "2": {"metric_1": 0.06, "metric_2": 0.2}, 180 | "3": {"metric_1": 0.06, "metric_2": 0.3}, 181 | } 182 | 183 | rollout_infos = [ 184 | {"ep_len": 2, "ep_reward": 0.4}, 185 | {"ep_len": 3, "ep_reward": 0.5}, 186 | {"ep_len": 3, "ep_reward": 0.5}, 187 | ] 188 | 189 | 
tracker = Tracker(base_path, run_config, "Test run", True) 190 | tracker.log_predictions(1, "val", predictions["1"]) 191 | tracker.log_metrics(1, "val", metrics["1"]) 192 | tracker.log_predictions(2, "val", predictions["2"]) 193 | tracker.log_metrics(2, "val", metrics["2"]) 194 | tracker.log_predictions(3, "val", predictions["3"]) 195 | tracker.log_metrics(3, "val", metrics["3"]) 196 | tracker.log_rollout_infos(rollout_infos[0]) 197 | tracker.log_rollout_infos(rollout_infos[1]) 198 | tracker.log_rollout_infos(rollout_infos[2]) 199 | tracker.done() 200 | -------------------------------------------------------------------------------- /rl4lms/envs/text_generation/policy/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/RL4LMs/97df0bd2f7406a906206c9610aea795fbf52884c/rl4lms/envs/text_generation/policy/__init__.py -------------------------------------------------------------------------------- /rl4lms/envs/text_generation/post_processors.py: -------------------------------------------------------------------------------- 1 | from nltk.tokenize import sent_tokenize 2 | 3 | 4 | def three_sentence_summary(text): 5 | """ 6 | Returns first three sentences from the generated text 7 | """ 8 | return "\n".join(sent_tokenize(text)[:3]) 9 | -------------------------------------------------------------------------------- /rl4lms/envs/text_generation/summ_metrics/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/RL4LMs/97df0bd2f7406a906206c9610aea795fbf52884c/rl4lms/envs/text_generation/summ_metrics/__init__.py -------------------------------------------------------------------------------- /rl4lms/envs/text_generation/test_datapool.py: -------------------------------------------------------------------------------- 1 | from rl4lms.data_pools.text_generation_pool import TextGenPool, Sample 2 | 3 | 4 | class TestTextGenPool(TextGenPool): 5 | @classmethod 6 | def prepare(cls, split: str, prompt: str, n_samples=100): 7 | samples = [Sample(id=ix, 8 | prompt_or_input_text=prompt, # a dummy prompt 9 | references=[] 10 | ) for ix in range(n_samples)] 11 | pool_instance = cls(samples) 12 | return pool_instance 13 | -------------------------------------------------------------------------------- /rl4lms/envs/text_generation/test_metric.py: -------------------------------------------------------------------------------- 1 | 2 | from typing import Any, Dict, List 3 | 4 | import numpy as np 5 | from rl4lms.envs.text_generation.metric import BaseMetric 6 | from rl4lms.envs.text_generation.test_reward import (RewardIncreasingNumbers, 7 | RewardSentencesWithDates) 8 | from transformers import PreTrainedModel 9 | 10 | 11 | class IncreasingNumbersinText(BaseMetric): 12 | def __init__(self, min_tokens: int) -> None: 13 | super().__init__() 14 | self._min_tokens = min_tokens 15 | 16 | def compute(self, prompt_texts: List[str], 17 | generated_texts: List[str], 18 | reference_texts: List[List[str]], 19 | meta_infos: List[Dict[str, Any]] = None, 20 | model: PreTrainedModel = None, 21 | split_name: str = None) -> Dict[str, float]: 22 | 23 | all_rewards = [] 24 | for gen_text in generated_texts: 25 | reward = RewardIncreasingNumbers.reward_increasing_numbers_in_text( 26 | gen_text, self._min_tokens) 27 | all_rewards.append(reward) 28 | 29 | metric_dict = { 30 | "synthetic/increasing_numbers_in_text": (all_rewards, np.mean(all_rewards)) 31 | } 32 | return 
metric_dict 33 | 34 | 35 | class DateInText(BaseMetric): 36 | def compute(self, prompt_texts: List[str], 37 | generated_texts: List[str], 38 | reference_texts: List[List[str]], 39 | meta_infos: List[Dict[str, Any]] = None, 40 | model: PreTrainedModel = None, 41 | split_name: str = None) -> Dict[str, float]: 42 | 43 | all_rewards = [] 44 | for gen_text in generated_texts: 45 | reward = RewardSentencesWithDates.date_in_text( 46 | gen_text) 47 | all_rewards.append(reward) 48 | metric_dict = { 49 | "synthetic/dates_in_text": (all_rewards, np.mean(all_rewards)) 50 | } 51 | return metric_dict 52 | -------------------------------------------------------------------------------- /rl4lms/envs/text_generation/test_reward.py: -------------------------------------------------------------------------------- 1 | import re 2 | from typing import Any, Dict 3 | 4 | from rl4lms.envs.text_generation.observation import Observation 5 | from rl4lms.envs.text_generation.reward import RewardFunction 6 | 7 | 8 | class RewardIncreasingNumbers(RewardFunction): 9 | def __init__(self, 10 | min_tokens: int) -> None: 11 | super().__init__() 12 | self.min_tokens = min_tokens 13 | 14 | @staticmethod 15 | def is_number(text): 16 | try: 17 | float(text) 18 | return True 19 | except ValueError: 20 | return False 21 | 22 | @staticmethod 23 | def reward_increasing_numbers_in_text(gen_text: str, 24 | min_tokens: int): 25 | gen_tokens = gen_text.split() 26 | number_tokens = [float(token) 27 | for token in gen_tokens if RewardIncreasingNumbers.is_number(token)] 28 | if len(number_tokens) > 0: 29 | # then we check how many numbers are in the sorted order 30 | sorted_count = 1 31 | previous_token = number_tokens[0] 32 | for token in number_tokens[1:]: 33 | if token > previous_token: 34 | sorted_count += 1 35 | previous_token = token 36 | else: 37 | break 38 | return ((sorted_count)/max(len(gen_tokens), (min_tokens/2))) 39 | return 0 40 | 41 | def __call__(self, prev_observation: Observation, 42 | action: int, 43 | current_observation: Observation, 44 | done: bool, 45 | meta_info: Dict[str, Any] = None) -> float: 46 | if done: 47 | gen_text = current_observation.context_text 48 | reward = RewardIncreasingNumbers.reward_increasing_numbers_in_text( 49 | gen_text, self.min_tokens) 50 | return reward 51 | return 0 52 | 53 | 54 | class RewardSentencesWithDates: 55 | 56 | def date_in_text(text: str): 57 | match = re.search(r'\d{4}-\d{2}-\d{2}', 58 | text) 59 | if match is not None: 60 | return 1 61 | else: 62 | return 0 63 | 64 | def __call__(self, prev_observation: Observation, 65 | action: int, 66 | current_observation: Observation, 67 | done: bool, 68 | meta_info: Dict[str, Any] = None) -> float: 69 | if done: 70 | return RewardSentencesWithDates.date_in_text(current_observation.context_text) 71 | return 0 72 | -------------------------------------------------------------------------------- /rl4lms/envs/text_generation/warm_start.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Any, Dict 3 | 4 | import torch 5 | 6 | from rl4lms.envs.text_generation.logging_utils import Tracker 7 | from rl4lms.envs.text_generation.policy.base_policy import LMActorCriticPolicy 8 | 9 | ################## Policy Warm Start Mixins####################################### 10 | 11 | 12 | class ActorOnlyWarmStartMixin: 13 | def get_state_dict(self) -> Dict[str, Any]: 14 | state_dict = { 15 | "policy_model": self._policy_model.state_dict(), 16 | "optimizer": self.optimizer.state_dict() 17 | 
} 18 | return state_dict 19 | 20 | def load_from_dict(self, state_dict: dict = None): 21 | if state_dict is not None: 22 | self._policy_model.load_state_dict(state_dict["policy_model"]) 23 | self.optimizer.load_state_dict(state_dict["optimizer"]) 24 | 25 | 26 | class ActorCriticWarmStartMixin: 27 | def get_state_dict(self) -> Dict[str, Any]: 28 | state_dict = { 29 | "policy_model": self._policy_model.state_dict(), 30 | "value_model": self._value_model.state_dict(), 31 | "value_head": self._value_head.state_dict(), 32 | "optimizer": self.optimizer.state_dict() 33 | } 34 | return state_dict 35 | 36 | def load_from_dict(self, state_dict: dict = None): 37 | if state_dict is not None: 38 | self._policy_model.load_state_dict(state_dict["policy_model"]) 39 | self._value_model.load_state_dict(state_dict["value_model"]) 40 | self._value_head.load_state_dict(state_dict["value_head"]) 41 | self.optimizer.load_state_dict(state_dict["optimizer"]) 42 | 43 | 44 | class MaskableActorCriticWarmStartMixin: 45 | def get_state_dict(self) -> Dict[str, Any]: 46 | state_dict = { 47 | "policy_model": self._policy_model.state_dict(), 48 | "value_model": self._value_model.state_dict(), 49 | "value_head": self._value_head.state_dict(), 50 | "mask_model": self._mask_model.state_dict(), 51 | "optimizer": self.optimizer.state_dict() 52 | } 53 | return state_dict 54 | 55 | def load_from_dict(self, state_dict: dict = None): 56 | if state_dict is not None: 57 | self._policy_model.load_state_dict(state_dict["policy_model"]) 58 | self._value_model.load_state_dict(state_dict["value_model"]) 59 | self._value_head.load_state_dict(state_dict["value_head"]) 60 | self.optimizer.load_state_dict(state_dict["optimizer"]) 61 | 62 | 63 | ################## Algorithm Warm Start Mixins####################################### 64 | class OnPolicyWarmStartMixin: 65 | def get_state_dict(self) -> Dict[str, Any]: 66 | # just the kl controller state is sufficient for onpolicy algs 67 | state_dict = { 68 | "kl_controller_state": self._kl_controller.get_state_dict(), 69 | } 70 | return state_dict 71 | 72 | def load_from_dict(self, state_dict: Dict[str, Any]) -> Dict[str, Any]: 73 | if state_dict is not None: 74 | self._kl_controller.load_from_state_dict( 75 | state_dict["kl_controller_state"]) 76 | 77 | 78 | class OffPolicyWarmStartMixin: 79 | def get_state_dict(self) -> Dict[str, Any]: 80 | # TBD: just buffer is sufficient? or is there something else? 
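# (Off-policy algorithms keep their experience in the replay buffer, so that is
#  the only state serialized here; the on-policy mixin above only persists the
#  KL controller state instead.)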
81 | state_dict = { 82 | "replay_buffer": self.replay_buffer.get_state_dict(), 83 | } 84 | return state_dict 85 | 86 | def load_from_dict(self, state_dict: Dict[str, Any]) -> Dict[str, Any]: 87 | if state_dict is not None: 88 | self.replay_buffer.load_from_state_dict( 89 | state_dict["replay_buffer"]) 90 | 91 | 92 | ################## Trainer Warm Start Mixins####################################### 93 | class TrainerWarmStartMixin: 94 | def _get_recent_ckpt_path(self, tracker: Tracker): 95 | try: 96 | checkpoints = os.listdir(tracker.checkpoint_base_path) 97 | except: 98 | os.makedirs(tracker.checkpoint_base_path) 99 | checkpoints = os.listdir(tracker.checkpoint_base_path) 100 | 101 | if len(checkpoints) == 0: 102 | return None, None 103 | 104 | sorted_ckpts = sorted(checkpoints, reverse=True, 105 | key=lambda ckpt: int(ckpt.split("_")[1])) 106 | recent_ckpt = sorted_ckpts[0] 107 | recent_ckpt_id = int(recent_ckpt.split("_")[1]) 108 | 109 | recent_ckpt_path = os.path.join( 110 | tracker.checkpoint_base_path, f"checkpoint_{recent_ckpt_id}") 111 | return recent_ckpt_path, recent_ckpt_id 112 | 113 | def load_trainer_state(self, tracker: Tracker): 114 | recent_ckpt_path, _ = self._get_recent_ckpt_path(tracker) 115 | state_dict = None 116 | try: 117 | if recent_ckpt_path is not None: 118 | state_dict = torch.load( 119 | recent_ckpt_path, map_location=torch.device("cuda")) 120 | tracker.log_info("Model checkpoint found - Warm starting") 121 | self._policy_state_dict = state_dict["policy_state"] 122 | self._alg_state_dict = state_dict["alg_state"] 123 | self._trainer_state = state_dict["trainer_state"] 124 | 125 | tracker.log_info( 126 | f"Loaded the current trainer state from: {self._trainer_state}") 127 | else: 128 | self._policy_state_dict = None 129 | self._alg_state_dict = None 130 | self._trainer_state = { 131 | "current_iter": 0, 132 | } 133 | except Exception as e: 134 | tracker.log_info(f"Exception while doing warm start {e}") 135 | tracker.log_info( 136 | f"Checkpoint may be corrupted...skipping warm start") 137 | self._policy_state_dict = None 138 | self._alg_state_dict = None 139 | self._trainer_state = { 140 | "current_iter": 0, 141 | } 142 | 143 | def save_trainer_state(self, tracker: Tracker, 144 | policy: LMActorCriticPolicy, 145 | trainer_state: Dict[str, Any]): 146 | full_state = { 147 | "alg_state": self._alg.get_state_dict(), 148 | "policy_state": policy.get_state_dict(), 149 | "trainer_state": trainer_state 150 | } 151 | _, recent_ckpt_id = self._get_recent_ckpt_path(tracker) 152 | 153 | # hot fix - just to save only the last checkpoint (overwrite) 154 | new_ckpt_id = 0 if recent_ckpt_id is None else recent_ckpt_id + 1 155 | new_ckpt_path = os.path.join( 156 | tracker.checkpoint_base_path, f"checkpoint_{new_ckpt_id}") 157 | torch.save(full_state, new_ckpt_path, pickle_protocol=4) 158 | -------------------------------------------------------------------------------- /scripts/crowdworking_templates/likert/IMDB_sentiment_completion_example_input.csv: -------------------------------------------------------------------------------- 1 | prompt,machine_completion,algorithm 2 | "After the already disappointing ""Final Conflict"", the series hits rock-bottom with this very weak fourth entry. At least the third film tried (unsuccessfully) to continue the story of Damien, while this one simply rehashes and copies ideas from the ""Omen"" (animals are"," replaced with ghosts) series after the events of ""Ego One"". 
But this may not hold up as a complete piece, as the fourth issue was not so successful, as the fourth film does not hold up as a complete film.

""A Matter of Life and Death"" has a brilliant cinematography by Jack Cardiff, a man who knew how to work wonders with a camera. Particularly impressive is the contrast from the monochromatic tones given to the scenes played in heaven, and the colored ones when the action comes back to earth. This was quite a coup, and well ahead of its times. The black and white sequence that involves the long staircase where Peter and the Conductor are chatting has to be one of the most amazing things on any film.

Much has been said in this forum about the film, so our comment will be about the great acting Powell and Pressberger got out of the large, distinguished cast, who responded magnificently to the directors' guidance.

David Niven, is Peter,"," a gifted actor, who gets a chance to play with many different actors in this film. David is an actor from China, who has traveled to the United States from China from time to time, where he had a chance to meet many different ...",generated_text 4 | "I would have enjoyed this movie slightly more had not been for Jason (Herb) Evers constant harping on experiment. Many early reviewers of The Seven Samurai accused Toshiro Mifune of overacting. Yet, as more and more critics viewed that film they saw it as being purposefully done. Jason Evers is obviously not Toshiro Mifune, and his overacting is exactly that.

Most of the actors in this B classic were rather"," different from each other. Most of the actors are in different walks of life, and not just friends, which is interesting considering many of the characters in this film would be friends. All in all, a great film to watch in any movie ...",generated_text 5 | "Like his early masterpiece ""The Elephant Man"" Lynch proves to his detractors that he can tell a straight, simple story without losing his artistic touch. This is a true story of an elderly retired man (expertly played by Richard Farnsworth) who decides to ride a tractor across a few states to pay a final visit to his estranged brother who now stands"," in awe of him. The film is filled with humor and suspense, without a dull moment. The film is well paced and well written. As a film, you will enjoy watching the people in the film.

The ...",generated_text 6 | "This movie is some of the worst crap I have ever seen. I literally got a sharp pain in my head while watching this movie. The CGI was awful, and the story was just a waste of ink. Dean Cain's character was Mr-Super-Intuitive-I-can-figure-out-anything, except he can't seem to work his own helicopter correctly. The biggest problem"," I had with ""Dancing Star"" was that so many scenes got stuck in between, and that was really annoying.

I just hope it was ever released on big-name TV, because everyone who commented on this ...",generated_text 7 | "Caught part of GEORGE 2 on TV recently, but couldn't get myself to watch it through to the end. Just awful! I can't even remember the plot. All I know is that George and Ursula were not the George and Ursula of the first movie, which was bad enough. There's a lot of scrambling around, but the direction and editing were so shoddy and choppy, it was like watching outtakes"," , and as a result, the movie sucked. When they tried to edit the first scene into a story and it didn't make any sense, it totally blew me away. You get that, you get that.

...",generated_text 8 | "Remember the early days of Pay Per View? I do, and i can almost remember the number you had to CALL to actually rent the movie on your t.v. As a kid we always wanted to rent playboy, but this meant actually calling someone from PPV and asking to rent it. And then you get the nerve to do it and your watching four hours of soft core no angle crap. Well the reason i bring that up is because this movie too was on ppv. And i remember almost every scene", so why was i talking to everyone but i didnt. When they had a playboy playgirl on their night out call you could practically hear her sob every time she caught a glimpse of her father. Because he never wanted her around. It ...,generated_text 9 | "If you weren't there, then unfortunately this movie will be beyond compassion for you. Which as I say is a shame because although some of the acting is amateurish, it is meant to be for realism. Let's face it--in real life, we don't say things in an exacting or perfect way, even when we mean to. In this sense, it works. This, however, does not apply to our ""known"" actors in this film, notably Jodie Foster (born a natural). The fact that the other 3 girls are not accomplished only adds to the story--Jodie plays the glue that struggles to keep their friendship close, even with the obvious feeling of fatality. Meaning that no matter how close friends are, eventually there are some people that just fade away, no matter how you try.

And therein is the core of the movie. It's not about partying, it's not about sexuality, but about these 4 girls and their final time as still young girls before they have to go the world alone.

If you have ever had a friendship like that in your life, you will feel this movie--it will mean a lot to you, no matter what era it is set in, or what era you grew up in. We all knew these girls in school, or at the very least knew of them."," I don't want to go through theirstories, but suffice it to say that one of the only people who is still alive is Jean Grey, whose character is a young girl from the 50's who grew up and became a very important woman ...",generated_text 10 | "This movie tells the tender tale of a demented scientist who, after his fiance is decapitated, goes around ogling strippers so that he"," can find her in the back of a truck. He meets the couple, and they both find a place in the basement of a house that they love. The couple finds a doctor, so they go down to the cellar to have a coffee ...",generated_text 11 | -------------------------------------------------------------------------------- /scripts/crowdworking_templates/likert/commongen_example_input.csv: -------------------------------------------------------------------------------- 1 | prompt,machine_completion,algorithm 2 | generate a sentence with: coral fish island., a few fish with corals on a tropical island,generated_text 3 | generate a sentence with: coral fish island., a turtle eating coral coral with fish on island.,generated_text 4 | generate a sentence with: bathroom light wall., A bathroom with light on the wall.,generated_text 5 | generate a sentence with: bathroom light wall., A bathroom with light on the wall and window blinds in a roomy bathroom.,generated_text 6 | generate a sentence with: couple love meadow., this couple are in love with nature and a meadow.,generated_text 7 | generate a sentence with: couple love meadow., a romantic couple is in love on the meadow.,generated_text 8 | generate a sentence with: bridge lock padlock., A lock padlock on an old wooden ferry bridge.,generated_text 9 | generate a sentence with: bridge lock padlock., A blue and white car underneath a bridge with padlocks and locks in.,generated_text 10 | generate a sentence with: colour riot summer., shades of shades of blue are riotous in the sunshine during summer.,generated_text 11 | -------------------------------------------------------------------------------- /scripts/crowdworking_templates/likert/daily_dialogue_example_input.csv: -------------------------------------------------------------------------------- 1 | prompt,machine_completion,algorithm 2 | "

Person 1: when do the playoffs start?

Person 2: in a few weeks. Are you planning on watching them?

Person 1: of course! I really love watching football games on TV.

Person 2: Me, too! Do you want to come over and watch the next game at my place? I'm planning on having a few people over to watch it together.

Person 1: sure, that would be great.

","

Person 2: it would be great, too! I'll watch it in my room when I'm finished with it

",zero_shot 3 | "

Person 1: When was this break-in?

Person 2: I believe it happened last night.

Person 1: Where was the break-in?

Person 2: It happened at my house.

Person 1: Was anything stolen?

","

Person 2: No.

",base_ppo 4 | "

Person 1: I'm sorry we won't. You can wait for two weeks. If you don't get a telephone call, it means that you weren't successful.

Person 2: Then do I have the chance to get this job?

Person 1: I'm sorry but I can't make the final decision myself, and I have to discuss it with other interviewers.

Person 2: I know. No matter what the result will be, I have learned a lot from our conversation.

Person 1: Your mentality is very good and that's great.

","

Person 2: Thanks for everything, but I'm still a little unsure whether to accept a job as you know it

",supervised 5 | "

Person 1: The house warming party.

Person 2: Whose house warming party is it?

Person 1: Tom and Bill Smith. They are both working now you know? And they've bought a new house.

Person 2: Oh, they have? I didn't know. I haven't seen Tom lately.

Person 1: It's out in the suburbs.

","

Person 2: But there is a lot to see. Tom usually just drives us to the car park or the neighborhood

",supervised 6 | "

Person 1: I'm going to buy a CD.

Person 2: Why not listen to music online?

Person 1: Then I will have to pay quite a lot to listen to music on line, isn't it?

Person 2: You can audition music on line, and you can audition it off line after downloading it, so you needn't pay for it.

Person 1: I have downloaded one, but it still can't be listened.

","

Person 2: If you can't stand listening to it online, I suppose you can still download the free stuff.

",supervised 7 | "

Person 1: Didn't I?

Person 2: You know you didn't.

Person 1: Well, I'm telling you now.

Person 2: Yes, but you might have told me before.

Person 1: I didn't think you'd be interested.

","

Person 2: What did I do?

",supervised_nlpo 8 | "

Person 1: Yes, we have a great variety of Chinese antiques.

Person 2: What do you have?

Person 1: Well, we have some painting and some porcelain. Which do you like better?

Person 2: I'd like porcelain.

Person 1: Good. How about the table set?

","

Person 2: I don't love table sets.

",base_ppo 9 | "

Person 1: Accessories? Do you carry polarizing filters?

Person 2: Polarizing filter um we should! I ' m sorry, ma ' am. It looks like we ' re sold out.

Person 1: No you ' re not! There are some right here!

Person 2: Oh, well, would you look at that! Polarizing filters.

Person 1: Thanks for your help, Ralph!

","

Person 2: No worries about it, there is no trouble!

",zero_shot 10 | "

Person 1: I called HP about my computer.

Person 2: What did they say?

Person 1: They said I need a new hard drive.

Person 2: That's too bad. How much is a new one?

Person 1: It's not too much, only about $ 85.

","

Person 2: What is a new hard drive?

",base_ppo 11 | -------------------------------------------------------------------------------- /scripts/crowdworking_templates/pairwise/commongen_pairwise_example_input.csv: -------------------------------------------------------------------------------- 1 | concept,choice_A,choice_B,choice_A_id,choice_B_id 2 | shirt street walk wear,a man is wearing a shirt as he walks down the street.,a man wears a shirt as he walks down the street.,supervised,pref_model 3 | arm forehead kiss neck wrap,a man wraps his arms around his neck and kisses his forehead.,"a man kisses his arm, forehead, and wraps his neck.",supervised,pref_model 4 | bike break ride take,a man takes a break from riding his bike.,a man is riding a bike and taking a break.,pref_model,supervised 5 | cut hold saw wood,a man holds a saw to cut wood.,a man is holding a saw to cut a piece of wood.,pref_model,supervised 6 | chair head shave sit,a man shaves his head in a chair and sits down.,a man sits in a chair and shaves his head.,pref_model,supervised 7 | ear hold phone talk toddler,a toddler is holding a phone in his ear and talking.,a toddler holds his ear up and talks to a phone.,supervised,pref_model 8 | catch dog mouth toy,a dog catches a toy in his mouth.,a dog is catching a toy in his mouth.,pref_model,supervised 9 | hill jump ramp ride,a man is jumping off a ramp and riding down a hill.,a man jumps off a ramp and rides down a hill.,supervised,pref_model 10 | feed fish hand tank,a hand feeds a fish in a tank.,a man is feeding a fish in his hand in a tank.,pref_model,supervised 11 | -------------------------------------------------------------------------------- /scripts/reward-modeling/evaluate_intent_classifier.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoModel, AutoTokenizer, AutoModelForSequenceClassification 2 | from datasets import load_dataset, load_metric 3 | from transformers import AutoTokenizer 4 | import numpy as np 5 | from rl4lms.data_pools.custom_text_generation_pools import DailyDialog 6 | from datasets.arrow_dataset import Dataset 7 | import torch 8 | from sklearn.metrics import classification_report 9 | from tqdm import tqdm 10 | 11 | def get_batch(samples, batch_size: int): 12 | current_ix = 0 13 | n_samples = len(samples) 14 | while current_ix < n_samples: 15 | current_batch = samples[current_ix : current_ix + batch_size] 16 | yield current_batch 17 | current_ix += batch_size 18 | 19 | def get_dataset(datapool, label="intent"): 20 | # get the training data in text, label format 21 | texts = [] 22 | labels = [] 23 | for sample, _ in datapool: 24 | 25 | history = sample.prompt_or_input_text.split(DailyDialog.EOU_TOKEN) 26 | history = [utt for utt in history if utt != ""] 27 | last_utterance = history[-1] 28 | 29 | # just consider the utterance 30 | input_text = last_utterance + DailyDialog.EOU_TOKEN + sample.references[0] 31 | 32 | texts.append(input_text) 33 | labels.append(sample.meta_data[label][0]-1) 34 | 35 | print(np.unique(labels, return_counts=True)) 36 | 37 | dataset = Dataset.from_dict( 38 | { 39 | "text": texts, 40 | "labels": labels 41 | }, 42 | split="train" 43 | ) 44 | return dataset 45 | 46 | tokenizer = AutoTokenizer.from_pretrained("rajkumarrrk/roberta-daily-dialog-intent-classifier") 47 | model = AutoModelForSequenceClassification.from_pretrained("rajkumarrrk/roberta-daily-dialog-intent-classifier") 48 | 49 | # data pool 50 | train_dp = DailyDialog.prepare("train", 1) 51 | val_dp = DailyDialog.prepare("val", 1) 52 | 53 
| # train and val dataset 54 | ds_train = get_dataset(train_dp, "intent") 55 | ds_test = get_dataset(val_dp, "intent") 56 | 57 | 58 | all_pred_labels = [] 59 | all_target_labels = [] 60 | batches = list(get_batch(ds_test, 10)) 61 | for batch in tqdm(batches): 62 | encoded = tokenizer( 63 | batch["text"], 64 | return_tensors="pt", 65 | truncation=True, 66 | padding=True) 67 | with torch.no_grad(): 68 | outputs = model(input_ids=encoded.input_ids, 69 | attention_mask=encoded.attention_mask) 70 | pred_labels = torch.argmax(outputs.logits, dim=1).tolist() 71 | all_pred_labels.extend(pred_labels) 72 | all_target_labels.extend(batch["labels"]) 73 | 74 | print(classification_report(all_target_labels, all_pred_labels)) -------------------------------------------------------------------------------- /scripts/reward-modeling/train_intent_classifier.py: -------------------------------------------------------------------------------- 1 | from transformers import TrainingArguments, Trainer, DataCollatorWithPadding 2 | from datasets import load_dataset, load_metric 3 | from transformers import AutoModelForSequenceClassification, AutoTokenizer 4 | import numpy as np 5 | from rl4lms.data_pools.custom_text_generation_pools import DailyDialog 6 | from datasets.arrow_dataset import Dataset 7 | 8 | def get_dataset(datapool, label="intent"): 9 | # get the training data in text, label format 10 | texts = [] 11 | labels = [] 12 | for sample, _ in datapool: 13 | 14 | history = sample.prompt_or_input_text.split(DailyDialog.EOU_TOKEN) 15 | history = [utt for utt in history if utt != ""] 16 | last_utterance = history[-1] 17 | 18 | # just consider the utterance 19 | input_text = last_utterance + DailyDialog.EOU_TOKEN + sample.references[0] 20 | 21 | texts.append(input_text) 22 | labels.append(sample.meta_data[label][0]-1) 23 | 24 | print(np.unique(labels, return_counts=True)) 25 | 26 | dataset = Dataset.from_dict( 27 | { 28 | "text": texts, 29 | "labels": labels 30 | }, 31 | split="train" 32 | ) 33 | return dataset 34 | 35 | 36 | def main(): 37 | 38 | # label 39 | label = "intent" 40 | 41 | # results folder 42 | results_folder = f"./results/{label}" 43 | 44 | # data pool 45 | train_dp = DailyDialog.prepare("train", 1) 46 | val_dp = DailyDialog.prepare("val", 1) 47 | 48 | # train and val dataset 49 | ds_train = get_dataset(train_dp, label) 50 | ds_test = get_dataset(val_dp, label) 51 | 52 | model_name = "cardiffnlp/twitter-roberta-base-emotion" 53 | model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=4) 54 | tokenizer = AutoTokenizer.from_pretrained(model_name) 55 | 56 | def tokenize(examples): 57 | outputs = tokenizer(examples['text'], truncation=True) 58 | return outputs 59 | 60 | tokenized_ds_train = ds_train.map(tokenize, batched=True) 61 | tokenized_ds_test = ds_test.map(tokenize, batched=True) 62 | 63 | def compute_metrics(eval_preds): 64 | metric = load_metric("accuracy") 65 | logits, labels = eval_preds 66 | predictions = np.argmax(logits, axis=-1) 67 | return metric.compute(predictions=predictions, references=labels) 68 | 69 | training_args = TrainingArguments(num_train_epochs=10, 70 | output_dir=results_folder, 71 | # push_to_hub=True, 72 | per_device_train_batch_size=8, 73 | per_device_eval_batch_size=64, 74 | evaluation_strategy="steps", 75 | save_strategy='steps', 76 | logging_steps=20, 77 | save_total_limit=1, 78 | save_steps=100, 79 | lr_scheduler_type="constant", 80 | learning_rate=1e-6, 81 | ) 82 | 83 | data_collator = DataCollatorWithPadding(tokenizer) 84 | 85 | trainer 
= Trainer(model=model, tokenizer=tokenizer, 86 | data_collator=data_collator, 87 | args=training_args, 88 | train_dataset=tokenized_ds_train, 89 | eval_dataset=tokenized_ds_test, 90 | compute_metrics=compute_metrics) 91 | 92 | trainer.train(resume_from_checkpoint=True) 93 | 94 | if __name__ == '__main__': 95 | main() 96 | 97 | -------------------------------------------------------------------------------- /scripts/training/task_configs/common_gen/t5_nlpo.yml: -------------------------------------------------------------------------------- 1 | tokenizer: 2 | model_name: t5-base 3 | padding_side: left 4 | truncation_side: left 5 | pad_token_as_eos_token: False 6 | 7 | reward_fn: 8 | id: meteor 9 | args: 10 | shaping_fn: "common_gen_repeat_penalty" 11 | 12 | 13 | datapool: 14 | id: commongen 15 | args: 16 | concept_end_token: '.' 17 | concept_separator_token: ' ' 18 | prefix: "generate a sentence with: " 19 | 20 | env: 21 | n_envs: 10 22 | args: 23 | max_prompt_length: 15 24 | max_episode_length: 20 25 | terminate_on_eos: True 26 | context_start_token: 0 27 | 28 | alg: 29 | id: nlpo 30 | args: 31 | n_steps: 128 32 | batch_size: 64 33 | verbose: 1 34 | learning_rate: 0.000002 35 | n_epochs: 5 36 | kl_div: 37 | coeff: 0.001 38 | target_kl: 2.0 39 | policy: 40 | id: maskable_seq2seq_lm_actor_critic_policy 41 | args: 42 | model_name: t5-base 43 | apply_model_parallel: True 44 | mask_type: "learned_top_p" 45 | top_mask: 0.9 46 | target_update_iterations: 20 47 | generation_kwargs: 48 | do_sample: True 49 | top_k: 50 50 | min_length: 10 51 | max_new_tokens: 20 52 | 53 | train_evaluation: 54 | eval_batch_size: 100 55 | n_iters: 100 56 | eval_every: 10 57 | save_every: 20 58 | metrics: 59 | - id: meteor 60 | args: {} 61 | - id: rouge 62 | - id: bleu 63 | args: {} 64 | - id: bert_score 65 | args: 66 | language: en 67 | # - id: bleurt 68 | # args: 69 | # config_name: bleurt-large-512 70 | - id: diversity 71 | args: {} 72 | # - id: summaCZS 73 | # args: 74 | # granularity: sentence 75 | # use_ent: True 76 | # use_con: False 77 | # - id: summaCConv 78 | # args: 79 | # granularity: sentence 80 | generation_kwargs: 81 | do_sample: True 82 | top_k: 50 83 | min_length: 10 84 | max_new_tokens: 20 85 | -------------------------------------------------------------------------------- /scripts/training/task_configs/common_gen/t5_nlpo_on_supervised.yml: -------------------------------------------------------------------------------- 1 | tokenizer: 2 | model_name: t5-base 3 | padding_side: left 4 | truncation_side: left 5 | pad_token_as_eos_token: False 6 | 7 | reward_fn: 8 | id: meteor 9 | # values: 10 | # #- id: rouge_combined 11 | # - id: meteor 12 | # # - id: rouge 13 | # # args: 14 | # # rouge_type: "rouge1" 15 | # # - id: spider 16 | # # args: 17 | # # spice_coeff: 0.0 18 | # # cider_coeff: 1.0 19 | # # - id: spider 20 | # # args: 21 | # # spice_coeff: 1.0 22 | # # cider_coeff: 0.0 23 | # # # - id: spider 24 | # # # args: 25 | # # # spice_coeff: 0.5 26 | # # # cider_coeff: 0.5 27 | 28 | 29 | datapool: 30 | id: commongen 31 | args: 32 | concept_end_token: '.' 
33 | concept_separator_token: ' ' 34 | prefix: "generate a sentence with: " 35 | 36 | 37 | env: 38 | n_envs: 10 39 | args: 40 | max_prompt_length: 20 41 | max_episode_length: 20 42 | terminate_on_eos: True 43 | context_start_token: 0 44 | prompt_truncation_side: "right" 45 | 46 | 47 | alg: 48 | id: nlpo 49 | args: 50 | n_steps: 128 51 | batch_size: 64 52 | verbose: 1 53 | learning_rate: 0.000002 54 | n_epochs: 5 55 | ent_coef: 0.01 56 | kl_div: 57 | coeff: 0.01 58 | target_kl: 1.0 59 | policy: 60 | id: maskable_seq2seq_lm_actor_critic_policy 61 | args: 62 | model_name: rajkumarrrk/t5-common-gen 63 | apply_model_parallel: True 64 | prompt_truncation_side: "right" 65 | mask_type: "learned_top_p" 66 | top_mask: 0.9 67 | target_update_iterations: 20 68 | generation_kwargs: 69 | do_sample: True 70 | top_k: 50 71 | min_length: 5 72 | max_new_tokens: 20 73 | 74 | train_evaluation: 75 | eval_batch_size: 50 76 | n_iters: 100 77 | eval_every: 5 78 | save_every: 20 79 | metrics: 80 | - id: meteor 81 | args: {} 82 | - id: rouge 83 | - id: bleu 84 | args: {} 85 | - id: bert_score 86 | args: 87 | language: en 88 | - id: cider 89 | - id: spice 90 | - id: diversity 91 | args: {} 92 | generation_kwargs: 93 | num_beams: 5 94 | min_length: 5 95 | max_new_tokens: 20 96 | 97 | -------------------------------------------------------------------------------- /scripts/training/task_configs/common_gen/t5_ppo.yml: -------------------------------------------------------------------------------- 1 | tokenizer: 2 | model_name: t5-base 3 | padding_side: left 4 | truncation_side: left 5 | pad_token_as_eos_token: False 6 | 7 | reward_fn: 8 | id: meteor 9 | args: 10 | shaping_fn: "common_gen_repeat_penalty" 11 | # values: 12 | # - id: rouge_combined 13 | # args: 14 | # shaping_fn: "common_gen_repeat_penalty" 15 | # - id: meteor 16 | # args: 17 | # shaping_fn: "common_gen_repeat_penalty" 18 | # - id: rouge 19 | # args: 20 | # rouge_type: "rouge1" 21 | # shaping_fn: "common_gen_repeat_penalty" 22 | # - id: spider 23 | # args: 24 | # spice_coeff: 0.0 25 | # cider_coeff: 1.0 26 | # shaping_fn: "common_gen_repeat_penalty_batched" 27 | # - id: spider 28 | # args: 29 | # spice_coeff: 1.0 30 | # cider_coeff: 0.0 31 | # shaping_fn: "common_gen_repeat_penalty_batched" 32 | # - id: spider 33 | # args: 34 | # spice_coeff: 0.5 35 | # cider_coeff: 0.5 36 | # shaping_fn: "common_gen_repeat_penalty_batched" 37 | 38 | 39 | datapool: 40 | id: commongen 41 | args: 42 | concept_end_token: '.' 
43 | concept_separator_token: ' ' 44 | prefix: "generate a sentence with: " 45 | 46 | 47 | env: 48 | n_envs: 10 49 | args: 50 | max_prompt_length: 20 51 | max_episode_length: 20 52 | terminate_on_eos: True 53 | context_start_token: 0 54 | prompt_truncation_side: "right" 55 | 56 | 57 | alg: 58 | id: ppo 59 | args: 60 | n_steps: 256 61 | batch_size: 64 62 | verbose: 1 63 | learning_rate: 0.000002 64 | n_epochs: 5 65 | ent_coef: 0.01 66 | kl_div: 67 | coeff: 0.001 68 | target_kl: 2.0 69 | policy: 70 | id: seq2seq_lm_actor_critic_policy 71 | args: 72 | model_name: t5-base 73 | apply_model_parallel: True 74 | prompt_truncation_side: "right" 75 | generation_kwargs: 76 | do_sample: True 77 | top_k: 0 78 | min_length: 5 79 | max_new_tokens: 20 80 | 81 | train_evaluation: 82 | eval_batch_size: 20 83 | n_iters: 200 84 | eval_every: 20 85 | save_every: 1 86 | metrics: 87 | - id: meteor 88 | args: {} 89 | - id: rouge 90 | - id: bleu 91 | args: {} 92 | - id: bert_score 93 | args: 94 | language: en 95 | - id: cider 96 | - id: spice 97 | - id: diversity 98 | args: {} 99 | generation_kwargs: 100 | num_beams: 5 101 | min_length: 5 102 | max_new_tokens: 20 103 | 104 | -------------------------------------------------------------------------------- /scripts/training/task_configs/common_gen/t5_ppo_on_supervised.yml: -------------------------------------------------------------------------------- 1 | tokenizer: 2 | model_name: t5-base 3 | padding_side: left 4 | truncation_side: left 5 | pad_token_as_eos_token: False 6 | 7 | reward_fn: 8 | id: common_gen_preference_model 9 | args: 10 | device: "cpu" 11 | batch_size: 5 12 | model_type: "11b" 13 | 14 | # values: 15 | # - id: rouge_combined 16 | # - id: meteor 17 | # - id: rouge 18 | # args: 19 | # rouge_type: "rouge1" 20 | # - id: spider 21 | # args: 22 | # spice_coeff: 0.0 23 | # cider_coeff: 1.0 24 | # - id: spider 25 | # args: 26 | # spice_coeff: 1.0 27 | # cider_coeff: 0.0 28 | # - id: spider 29 | # args: 30 | # spice_coeff: 0.5 31 | # cider_coeff: 0.5 32 | 33 | 34 | datapool: 35 | id: commongen 36 | args: 37 | concept_end_token: '.' 
38 | concept_separator_token: ' ' 39 | prefix: "generate a sentence with: " 40 | 41 | 42 | env: 43 | n_envs: 10 44 | args: 45 | max_prompt_length: 20 46 | max_episode_length: 20 47 | terminate_on_eos: True 48 | context_start_token: 0 49 | prompt_truncation_side: "right" 50 | 51 | 52 | alg: 53 | id: ppo 54 | args: 55 | n_steps: 128 56 | batch_size: 64 57 | verbose: 1 58 | learning_rate: 0.000002 59 | n_epochs: 5 60 | ent_coef: 0.01 61 | kl_div: 62 | coeff: 0.01 63 | target_kl: 1.0 64 | policy: 65 | id: seq2seq_lm_actor_critic_policy 66 | args: 67 | model_name: rajkumarrrk/t5-common-gen 68 | apply_model_parallel: True 69 | prompt_truncation_side: "right" 70 | generation_kwargs: 71 | do_sample: True 72 | top_k: 50 73 | min_length: 5 74 | max_new_tokens: 20 75 | 76 | train_evaluation: 77 | eval_batch_size: 50 78 | n_iters: 100 79 | eval_every: 5 80 | save_every: 20 81 | metrics: 82 | - id: meteor 83 | args: {} 84 | - id: rouge 85 | - id: bleu 86 | args: {} 87 | - id: bert_score 88 | args: 89 | language: en 90 | - id: cider 91 | - id: spice 92 | - id: diversity 93 | args: {} 94 | generation_kwargs: 95 | num_beams: 5 96 | min_length: 5 97 | max_new_tokens: 20 98 | 99 | -------------------------------------------------------------------------------- /scripts/training/task_configs/common_gen/t5_supervised.yml: -------------------------------------------------------------------------------- 1 | tokenizer: 2 | model_name: t5-base 3 | padding_side: left 4 | truncation_side: right 5 | pad_token_as_eos_token: False 6 | 7 | datapool: 8 | id: commongen 9 | args: 10 | concept_end_token: '.' 11 | concept_separator_token: ' ' 12 | prefix: "generate a sentence with: " 13 | 14 | alg: 15 | id: supervised 16 | training_args: 17 | per_device_train_batch_size: 8 18 | logging_steps: 5000 19 | num_train_epochs: 4 20 | weight_decay: 0.01 21 | lr_scheduler_type: cosine 22 | learning_rate: 0.00001 23 | save_total_limit: 1 24 | model_type: seq2seq 25 | model_name: "t5-base" 26 | generation_kwargs: 27 | num_beams: 5 28 | min_length: 5 29 | max_new_tokens: 20 30 | post_processing_fn: null 31 | 32 | train_evaluation: 33 | eval_batch_size: 100 34 | metrics: 35 | - id: meteor 36 | args: {} 37 | - id: rouge 38 | - id: bleu 39 | args: {} 40 | - id: bert_score 41 | args: 42 | language: en 43 | - id: cider 44 | - id: spice 45 | - id: diversity 46 | args: {} 47 | 48 | -------------------------------------------------------------------------------- /scripts/training/task_configs/dialog/gpt2_nlpo.yml: -------------------------------------------------------------------------------- 1 | tokenizer: 2 | model_name: gpt2 3 | padding_side: left 4 | truncation_side: left 5 | pad_token_as_eos_token: True 6 | 7 | reward_fn: 8 | id: "intent_accuracy" 9 | args: 10 | intent_coeff: 0.75 11 | auto_coeff: 0.25 12 | 13 | datapool: 14 | id: "daily_dialog" 15 | args: 16 | context_size: 5 17 | 18 | env: 19 | n_envs: 10 20 | args: 21 | max_prompt_length: 128 22 | max_episode_length: 20 23 | terminate_on_eos: True 24 | 25 | alg: 26 | id: nlpo 27 | args: 28 | n_steps: 128 29 | batch_size: 64 30 | verbose: 1 31 | learning_rate: 0.000001 32 | n_epochs: 5 33 | 34 | kl_div: 35 | coeff: 0.2 36 | target_kl: 0.5 37 | policy: 38 | id: maskable_causal_lm_actor_critic_policy 39 | args: 40 | model_name: gpt2 41 | apply_model_parallel: True 42 | top_mask: 0.9 43 | min_tokens_to_keep: 100 44 | mask_type: 'learned_top_p' 45 | target_update_iterations: 20 46 | generation_kwargs: 47 | do_sample: True 48 | top_k: 20 49 | min_length: 2 50 | max_new_tokens: 20 51 | 52 | 
train_evaluation: 53 | eval_batch_size: 32 54 | n_iters: 100 55 | eval_every: 5 56 | save_every: 10 57 | metrics: 58 | - id: intent_accuracy 59 | - id: causal_perplexity 60 | args: 61 | tokenizer_id: gpt2 62 | stride: 128 63 | model_type: causal 64 | - id: diversity 65 | args: {} 66 | - id: meteor 67 | args: {} 68 | - id: rouge 69 | - id: bleu 70 | args: {} 71 | - id: bert_score 72 | args: 73 | language: en 74 | - id: sacre_bleu 75 | args: 76 | tokenize: "intl" 77 | generation_kwargs: 78 | do_sample: True 79 | top_k: 20 80 | min_length: 2 81 | max_new_tokens: 20 -------------------------------------------------------------------------------- /scripts/training/task_configs/dialog/gpt2_nlpo_on_supervised.yml: -------------------------------------------------------------------------------- 1 | tokenizer: 2 | model_name: gpt2 3 | padding_side: left 4 | truncation_side: left 5 | pad_token_as_eos_token: True 6 | 7 | reward_fn: 8 | id: "intent_accuracy" 9 | args: 10 | intent_coeff: 0.25 11 | auto_coeff: 0.75 12 | 13 | datapool: 14 | id: "daily_dialog" 15 | args: 16 | context_size: 5 17 | 18 | env: 19 | n_envs: 10 20 | args: 21 | max_prompt_length: 128 22 | max_episode_length: 20 23 | terminate_on_eos: True 24 | 25 | alg: 26 | id: nlpo 27 | args: 28 | n_steps: 128 29 | batch_size: 64 30 | verbose: 1 31 | learning_rate: 0.000001 32 | n_epochs: 5 33 | 34 | kl_div: 35 | coeff: 0.2 36 | target_kl: 0.05 37 | policy: 38 | id: maskable_causal_lm_actor_critic_policy 39 | args: 40 | model_name: rajkumarrrk/gpt2-fine-tuned-on-daily-dialog 41 | apply_model_parallel: True 42 | top_mask: 0.9 43 | min_tokens_to_keep: 100 44 | mask_type: 'learned_top_p' 45 | target_update_iterations: 20 46 | generation_kwargs: 47 | do_sample: True 48 | top_k: 20 49 | min_length: 2 50 | max_new_tokens: 20 51 | 52 | train_evaluation: 53 | eval_batch_size: 32 54 | n_iters: 50 55 | eval_every: 5 56 | save_every: 10 57 | metrics: 58 | - id: intent_accuracy 59 | - id: causal_perplexity 60 | args: 61 | tokenizer_id: gpt2 62 | stride: 128 63 | model_type: causal 64 | - id: diversity 65 | args: {} 66 | - id: meteor 67 | args: {} 68 | - id: rouge 69 | - id: bleu 70 | args: {} 71 | - id: bert_score 72 | args: 73 | language: en 74 | - id: sacre_bleu 75 | args: 76 | tokenize: "intl" 77 | generation_kwargs: 78 | do_sample: True 79 | top_k: 20 80 | min_length: 2 81 | max_new_tokens: 20 -------------------------------------------------------------------------------- /scripts/training/task_configs/dialog/gpt2_ppo.yml: -------------------------------------------------------------------------------- 1 | tokenizer: 2 | model_name: gpt2 3 | padding_side: left 4 | truncation_side: left 5 | pad_token_as_eos_token: True 6 | 7 | reward_fn: 8 | id: "intent_accuracy" 9 | args: 10 | intent_coeff: 0.75 11 | auto_coeff: 0.25 12 | 13 | 14 | datapool: 15 | id: "daily_dialog" 16 | args: 17 | context_size: 5 18 | 19 | env: 20 | n_envs: 10 21 | args: 22 | max_prompt_length: 128 23 | max_episode_length: 20 24 | terminate_on_eos: True 25 | 26 | alg: 27 | id: ppo 28 | args: 29 | n_steps: 128 30 | batch_size: 64 31 | verbose: 1 32 | learning_rate: 0.000001 33 | n_epochs: 5 34 | 35 | kl_div: 36 | coeff: 0.2 37 | target_kl: 0.5 38 | 39 | policy: 40 | id: causal_lm_actor_critic_policy 41 | args: 42 | model_name: gpt2 43 | apply_model_parallel: True 44 | generation_kwargs: 45 | do_sample: True 46 | top_k: 20 47 | min_length: 2 48 | max_new_tokens: 20 49 | 50 | train_evaluation: 51 | eval_batch_size: 32 52 | n_iters: 50 53 | eval_every: 5 54 | save_every: 10 55 | metrics: 
56 | - id: intent_accuracy 57 | - id: causal_perplexity 58 | args: 59 | tokenizer_id: gpt2 60 | stride: 128 61 | model_type: causal 62 | - id: diversity 63 | args: {} 64 | - id: meteor 65 | args: {} 66 | - id: rouge 67 | - id: bleu 68 | args: {} 69 | - id: bert_score 70 | args: 71 | language: en 72 | - id: sacre_bleu 73 | args: 74 | tokenize: "intl" 75 | generation_kwargs: 76 | do_sample: True 77 | top_k: 20 78 | min_length: 2 79 | max_new_tokens: 20 -------------------------------------------------------------------------------- /scripts/training/task_configs/dialog/gpt2_ppo_on_supervised.yml: -------------------------------------------------------------------------------- 1 | tokenizer: 2 | model_name: gpt2 3 | padding_side: left 4 | truncation_side: left 5 | pad_token_as_eos_token: True 6 | 7 | reward_fn: 8 | id: "intent_accuracy" 9 | args: 10 | intent_coeff: 0.5 11 | auto_coeff: 0.5 12 | 13 | 14 | datapool: 15 | id: "daily_dialog" 16 | args: 17 | context_size: 5 18 | 19 | env: 20 | n_envs: 10 21 | args: 22 | max_prompt_length: 128 23 | max_episode_length: 20 24 | terminate_on_eos: True 25 | 26 | alg: 27 | id: ppo 28 | args: 29 | n_steps: 128 30 | batch_size: 64 31 | verbose: 1 32 | learning_rate: 0.000001 33 | n_epochs: 5 34 | 35 | kl_div: 36 | coeff: 0.2 37 | target_kl: 0.05 38 | 39 | policy: 40 | id: causal_lm_actor_critic_policy 41 | args: 42 | model_name: rajkumarrrk/gpt2-fine-tuned-on-daily-dialog 43 | apply_model_parallel: True 44 | generation_kwargs: 45 | do_sample: True 46 | top_k: 20 47 | min_length: 2 48 | max_new_tokens: 20 49 | 50 | train_evaluation: 51 | eval_batch_size: 32 52 | n_iters: 50 53 | eval_every: 5 54 | save_every: 10 55 | metrics: 56 | - id: intent_accuracy 57 | - id: causal_perplexity 58 | args: 59 | tokenizer_id: gpt2 60 | stride: 128 61 | model_type: causal 62 | - id: diversity 63 | args: {} 64 | - id: meteor 65 | args: {} 66 | - id: rouge 67 | - id: bleu 68 | args: {} 69 | - id: bert_score 70 | args: 71 | language: en 72 | - id: sacre_bleu 73 | args: 74 | tokenize: "intl" 75 | generation_kwargs: 76 | do_sample: True 77 | top_k: 20 78 | min_length: 2 79 | max_new_tokens: 20 -------------------------------------------------------------------------------- /scripts/training/task_configs/dialog/gpt2_supervised.yml: -------------------------------------------------------------------------------- 1 | tokenizer: 2 | model_name: gpt2 3 | padding_side: left 4 | truncation_side: left 5 | pad_token_as_eos_token: True 6 | max_length: 128 7 | 8 | datapool: 9 | id: "daily_dialog" 10 | args: 11 | context_size: 5 12 | 13 | alg: 14 | id: supervised 15 | training_args: 16 | per_device_train_batch_size: 32 17 | logging_steps: 300 18 | num_train_epochs: 20 19 | lr_scheduler_type: "constant" 20 | learning_rate: 0.00001 21 | save_total_limit: 1 22 | model_type: causal 23 | model_name: gpt2 24 | generation_kwargs: 25 | do_sample: True 26 | top_k: 20 27 | min_length: 2 28 | max_new_tokens: 20 29 | post_processing_fn: null 30 | 31 | train_evaluation: 32 | eval_batch_size: 256 33 | metrics: 34 | - id: intent_accuracy 35 | - id: causal_perplexity 36 | args: 37 | tokenizer_id: gpt2 38 | stride: 128 39 | model_type: causal 40 | - id: diversity 41 | args: {} 42 | - id: meteor 43 | args: {} 44 | - id: rouge 45 | - id: bleu 46 | args: {} 47 | - id: bert_score 48 | args: 49 | language: en 50 | - id: sacre_bleu 51 | args: 52 | tokenize: "intl" 53 | 54 | -------------------------------------------------------------------------------- 
/scripts/training/task_configs/imdb_text_continuation/gpt2_a2c.yml: -------------------------------------------------------------------------------- 1 | tokenizer: 2 | model_name: lvwerra/gpt2-imdb 3 | padding_side: left 4 | truncation_side: left 5 | pad_token_as_eos_token: True 6 | 7 | reward_fn: 8 | id: learned_reward 9 | args: 10 | model_name: lvwerra/distilbert-imdb 11 | label_ix: 1 12 | include_prompt_for_eval: True 13 | 14 | datapool: 15 | id: imdb 16 | args: {} 17 | 18 | env: 19 | n_envs: 2 20 | args: 21 | max_prompt_length: 64 22 | max_episode_length: 48 23 | terminate_on_eos: True 24 | 25 | alg: 26 | id: a2c 27 | args: 28 | n_steps: 48 29 | verbose: 1 30 | learning_rate: 0.000001 31 | 32 | kl_div: 33 | coeff: 0.1 34 | target_kl: 0.5 35 | policy: 36 | id: causal_lm_actor_critic_policy 37 | args: 38 | model_name: lvwerra/gpt2-imdb 39 | apply_model_parallel: True 40 | generation_kwargs: 41 | do_sample: True 42 | min_length: 48 43 | max_new_tokens: 48 44 | 45 | train_evaluation: 46 | eval_batch_size: 256 47 | n_iters: 100 48 | eval_every: 20 49 | save_every: 1 50 | metrics: 51 | - id: learned_reward 52 | args: 53 | model_name: lvwerra/distilbert-imdb 54 | label_ix: 1 55 | batch_size: 100 56 | - id: causal_perplexity 57 | args: 58 | tokenizer_id: gpt2 59 | stride: 512 60 | model_type: causal 61 | - id: diversity 62 | args: {} -------------------------------------------------------------------------------- /scripts/training/task_configs/imdb_text_continuation/gpt2_nlpo.yml: -------------------------------------------------------------------------------- 1 | tokenizer: 2 | model_name: lvwerra/gpt2-imdb 3 | padding_side: left 4 | truncation_side: left 5 | pad_token_as_eos_token: True 6 | 7 | reward_fn: 8 | id: learned_reward 9 | args: 10 | model_name: lvwerra/distilbert-imdb 11 | label_ix: 1 12 | include_prompt_for_eval: True 13 | 14 | datapool: 15 | id: imdb 16 | args: 17 | seed: 42 18 | 19 | env: 20 | n_envs: 10 21 | args: 22 | max_prompt_length: 64 23 | max_episode_length: 48 24 | terminate_on_eos: True 25 | 26 | alg: 27 | id: nlpo 28 | args: 29 | n_steps: 128 30 | batch_size: 64 31 | verbose: 1 32 | learning_rate: 0.000001 33 | n_epochs: 5 34 | 35 | kl_div: 36 | coeff: 0.1 37 | target_kl: 0.1 38 | 39 | policy: 40 | id: maskable_causal_lm_actor_critic_policy 41 | args: 42 | model_name: lvwerra/gpt2-imdb 43 | apply_model_parallel: True 44 | top_mask: 0.9 45 | min_tokens_to_keep: 100 46 | mask_type: 'learned_top_p' 47 | target_update_iterations: 5 48 | generation_kwargs: 49 | do_sample: True 50 | min_length: 48 51 | max_new_tokens: 48 52 | 53 | train_evaluation: 54 | eval_batch_size: 64 55 | n_iters: 50 56 | eval_every: 10 57 | save_every: 10 58 | metrics: 59 | - id: learned_reward 60 | args: 61 | model_name: lvwerra/distilbert-imdb 62 | label_ix: 1 63 | batch_size: 100 64 | - id: causal_perplexity 65 | args: 66 | tokenizer_id: gpt2 67 | stride: 512 68 | model_type: causal 69 | - id: diversity 70 | args: {} -------------------------------------------------------------------------------- /scripts/training/task_configs/imdb_text_continuation/gpt2_nlpo_on_supervised.yml: -------------------------------------------------------------------------------- 1 | tokenizer: 2 | model_name: lvwerra/gpt2-imdb 3 | padding_side: left 4 | truncation_side: left 5 | pad_token_as_eos_token: True 6 | 7 | reward_fn: 8 | id: learned_reward 9 | args: 10 | model_name: lvwerra/distilbert-imdb 11 | label_ix: 1 12 | include_prompt_for_eval: True 13 | 14 | datapool: 15 | id: imdb 16 | args: 17 | seed: 42 18 | 19 | 
env: 20 | n_envs: 10 21 | args: 22 | max_prompt_length: 64 23 | max_episode_length: 48 24 | terminate_on_eos: True 25 | 26 | alg: 27 | id: nlpo 28 | args: 29 | n_steps: 128 30 | batch_size: 64 31 | verbose: 1 32 | learning_rate: 0.000001 33 | n_epochs: 5 34 | 35 | kl_div: 36 | coeff: 0.1 37 | target_kl: 0.1 38 | 39 | policy: 40 | id: maskable_causal_lm_actor_critic_policy 41 | args: 42 | model_name: rajkumarrrk/gpt2-fine-tuned-on-imdb-positive-reviews 43 | apply_model_parallel: True 44 | top_mask: 0.9 45 | min_tokens_to_keep: 100 46 | mask_type: 'learned_top_p' 47 | target_update_iterations: 5 48 | generation_kwargs: 49 | do_sample: True 50 | min_length: 48 51 | max_new_tokens: 48 52 | 53 | train_evaluation: 54 | eval_batch_size: 256 55 | n_iters: 50 56 | eval_every: 10 57 | save_every: 10 58 | metrics: 59 | - id: learned_reward 60 | args: 61 | model_name: lvwerra/distilbert-imdb 62 | label_ix: 1 63 | batch_size: 100 64 | - id: causal_perplexity 65 | args: 66 | tokenizer_id: gpt2 67 | stride: 512 68 | model_type: causal 69 | - id: diversity 70 | args: {} -------------------------------------------------------------------------------- /scripts/training/task_configs/imdb_text_continuation/gpt2_ppo.yml: -------------------------------------------------------------------------------- 1 | tokenizer: 2 | model_name: lvwerra/gpt2-imdb 3 | padding_side: left 4 | truncation_side: left 5 | pad_token_as_eos_token: True 6 | 7 | reward_fn: 8 | id: learned_reward 9 | args: 10 | model_name: lvwerra/distilbert-imdb 11 | label_ix: 1 12 | include_prompt_for_eval: True 13 | 14 | datapool: 15 | id: imdb 16 | args: 17 | seed: 42 18 | 19 | env: 20 | n_envs: 10 21 | args: 22 | max_prompt_length: 64 23 | max_episode_length: 48 24 | terminate_on_eos: True 25 | 26 | alg: 27 | id: ppo 28 | args: 29 | n_steps: 128 30 | batch_size: 64 31 | verbose: 1 32 | learning_rate: 0.000001 33 | n_epochs: 5 34 | 35 | kl_div: 36 | coeff: 0.1 37 | target_kl: 0.1 38 | 39 | policy: 40 | id: causal_lm_actor_critic_policy 41 | args: 42 | model_name: lvwerra/gpt2-imdb 43 | apply_model_parallel: True 44 | generation_kwargs: 45 | do_sample: True 46 | min_length: 48 47 | max_new_tokens: 48 48 | 49 | train_evaluation: 50 | eval_batch_size: 64 51 | n_iters: 50 52 | eval_every: 10 53 | save_every: 10 54 | metrics: 55 | - id: learned_reward 56 | args: 57 | model_name: lvwerra/distilbert-imdb 58 | label_ix: 1 59 | batch_size: 100 60 | - id: causal_perplexity 61 | args: 62 | tokenizer_id: gpt2 63 | stride: 512 64 | model_type: causal 65 | - id: diversity 66 | args: {} -------------------------------------------------------------------------------- /scripts/training/task_configs/imdb_text_continuation/gpt2_ppo_on_supervised.yml: -------------------------------------------------------------------------------- 1 | tokenizer: 2 | model_name: lvwerra/gpt2-imdb 3 | padding_side: left 4 | truncation_side: left 5 | pad_token_as_eos_token: True 6 | 7 | reward_fn: 8 | id: learned_reward 9 | args: 10 | model_name: lvwerra/distilbert-imdb 11 | label_ix: 1 12 | include_prompt_for_eval: True 13 | 14 | datapool: 15 | id: imdb 16 | args: 17 | seed: 42 18 | 19 | env: 20 | n_envs: 10 21 | args: 22 | max_prompt_length: 64 23 | max_episode_length: 48 24 | terminate_on_eos: True 25 | 26 | alg: 27 | id: ppo 28 | args: 29 | n_steps: 128 30 | batch_size: 64 31 | verbose: 1 32 | learning_rate: 0.000001 33 | n_epochs: 5 34 | 35 | kl_div: 36 | coeff: 0.1 37 | target_kl: 0.1 38 | 39 | policy: 40 | id: causal_lm_actor_critic_policy 41 | args: 42 | model_name: 
rajkumarrrk/gpt2-fine-tuned-on-imdb-positive-reviews 43 | apply_model_parallel: True 44 | generation_kwargs: 45 | do_sample: True 46 | min_length: 48 47 | max_new_tokens: 48 48 | 49 | train_evaluation: 50 | eval_batch_size: 256 51 | n_iters: 50 52 | eval_every: 10 53 | save_every: 10 54 | metrics: 55 | - id: learned_reward 56 | args: 57 | model_name: lvwerra/distilbert-imdb 58 | label_ix: 1 59 | batch_size: 100 60 | - id: causal_perplexity 61 | args: 62 | tokenizer_id: gpt2 63 | stride: 512 64 | model_type: causal 65 | - id: diversity 66 | args: {} -------------------------------------------------------------------------------- /scripts/training/task_configs/imdb_text_continuation/gpt2_supervised.yml: -------------------------------------------------------------------------------- 1 | tokenizer: 2 | model_name: gpt2 3 | padding_side: left 4 | truncation_side: left 5 | pad_token_as_eos_token: True 6 | max_length: 64 7 | 8 | datapool: 9 | id: "imdb_seq2seq" 10 | args: 11 | positive_ratio: 1.0 12 | 13 | alg: 14 | id: supervised 15 | training_args: 16 | per_device_train_batch_size: 16 17 | logging_steps: 200 18 | num_train_epochs: 10 19 | lr_scheduler_type: "constant" 20 | learning_rate: 0.00001 21 | save_total_limit: 1 22 | model_type: causal 23 | model_name: lvwerra/gpt2-imdb 24 | generation_kwargs: 25 | do_sample: True 26 | min_length: 48 27 | max_new_tokens: 48 28 | post_processing_fn: null 29 | 30 | train_evaluation: 31 | eval_batch_size: 256 32 | metrics: 33 | - id: learned_reward 34 | args: 35 | model_name: lvwerra/distilbert-imdb 36 | label_ix: 1 37 | batch_size: 100 38 | - id: causal_perplexity 39 | args: 40 | tokenizer_id: gpt2 41 | stride: 512 42 | model_type: causal 43 | use_text_from_meta_data: True 44 | - id: diversity 45 | args: {} 46 | 47 | -------------------------------------------------------------------------------- /scripts/training/task_configs/iwslt2017/t5_nlpo.yml: -------------------------------------------------------------------------------- 1 | tokenizer: 2 | model_name: t5-base 3 | padding_side: left 4 | truncation_side: left 5 | pad_token_as_eos_token: False 6 | 7 | reward_fn: 8 | id: ter 9 | # values: 10 | # - id: sacre_bleu 11 | # args: 12 | # tokenize: "intl" 13 | # - id: ter 14 | # - id: chrf 15 | # - id: bert_score 16 | # args: 17 | # language: "de" 18 | 19 | datapool: 20 | id: iwslt2017en_de 21 | args: 22 | prompt_prefix: "translate English to German: " 23 | 24 | env: 25 | n_envs: 10 26 | args: 27 | max_prompt_length: 128 28 | max_episode_length: 128 29 | terminate_on_eos: True 30 | prompt_truncation_side: "right" 31 | context_start_token: 0 32 | 33 | alg: 34 | args: 35 | batch_size: 64 36 | ent_coef: 0.0 37 | learning_rate: 0.000001 38 | n_epochs: 5 39 | n_steps: 512 40 | verbose: 1 41 | id: nlpo 42 | kl_div: 43 | coeff: 0.001 44 | target_kl: 0.2 45 | policy: 46 | args: 47 | apply_model_parallel: true 48 | generation_kwargs: 49 | do_sample: True 50 | top_k: 10 51 | max_new_tokens: 128 52 | mask_type: learned_top_p 53 | min_tokens_to_keep: 100 54 | model_name: t5-base 55 | prompt_truncation_side: right 56 | target_update_iterations: 20 57 | top_mask: 0.5 58 | id: maskable_seq2seq_lm_actor_critic_policy 59 | 60 | train_evaluation: 61 | eval_batch_size: 50 62 | n_iters: 50 63 | eval_every: 10 64 | save_every: 1 65 | metrics: 66 | - id: meteor 67 | args: {} 68 | - id: rouge 69 | - id: bleu 70 | args: {} 71 | - id: bert_score 72 | args: 73 | language: de 74 | - id: bleu 75 | args: {} 76 | - id: sacre_bleu 77 | args: 78 | tokenize: "intl" 79 | - id: ter 80 | 
args: {} 81 | - id: chrf 82 | args: {} 83 | - id: diversity 84 | args: {} 85 | generation_kwargs: 86 | num_beams: 4 87 | length_penalty: 0.6 88 | max_new_tokens: 128 89 | -------------------------------------------------------------------------------- /scripts/training/task_configs/iwslt2017/t5_nlpo_on_supervised.yml: -------------------------------------------------------------------------------- 1 | tokenizer: 2 | model_name: t5-base 3 | padding_side: left 4 | truncation_side: left 5 | pad_token_as_eos_token: False 6 | 7 | reward_fn: 8 | id: ter 9 | # values: 10 | # - id: sacre_bleu 11 | # args: 12 | # tokenize: "intl" 13 | # - id: ter 14 | # - id: chrf 15 | # - id: bert_score 16 | # args: 17 | # language: "de" 18 | 19 | datapool: 20 | id: iwslt2017en_de 21 | args: 22 | prompt_prefix: "translate English to German: " 23 | 24 | env: 25 | n_envs: 10 26 | args: 27 | max_prompt_length: 128 28 | max_episode_length: 128 29 | terminate_on_eos: True 30 | prompt_truncation_side: "right" 31 | context_start_token: 0 32 | 33 | alg: 34 | args: 35 | batch_size: 64 36 | ent_coef: 0.0 37 | learning_rate: 0.0000005 38 | n_epochs: 5 39 | n_steps: 512 40 | verbose: 1 41 | id: nlpo 42 | kl_div: 43 | coeff: 0.001 44 | target_kl: 0.2 45 | policy: 46 | args: 47 | apply_model_parallel: true 48 | generation_kwargs: 49 | do_sample: True 50 | top_k: 10 51 | max_new_tokens: 128 52 | mask_type: learned_top_p 53 | min_tokens_to_keep: 100 54 | model_name: rajkumarrrk/t5-fine-tuned-on-iwslt2017en_de 55 | prompt_truncation_side: right 56 | target_update_iterations: 20 57 | top_mask: 0.5 58 | id: maskable_seq2seq_lm_actor_critic_policy 59 | 60 | train_evaluation: 61 | eval_batch_size: 50 62 | n_iters: 50 63 | eval_every: 10 64 | save_every: 1 65 | metrics: 66 | - id: meteor 67 | args: {} 68 | - id: rouge 69 | - id: bleu 70 | args: {} 71 | - id: bert_score 72 | args: 73 | language: de 74 | - id: bleu 75 | args: {} 76 | - id: sacre_bleu 77 | args: 78 | tokenize: "intl" 79 | - id: ter 80 | args: {} 81 | - id: chrf 82 | args: {} 83 | - id: diversity 84 | args: {} 85 | generation_kwargs: 86 | num_beams: 4 87 | length_penalty: 0.6 88 | max_new_tokens: 128 89 | -------------------------------------------------------------------------------- /scripts/training/task_configs/iwslt2017/t5_ppo.yml: -------------------------------------------------------------------------------- 1 | tokenizer: 2 | model_name: t5-base 3 | padding_side: left 4 | truncation_side: left 5 | pad_token_as_eos_token: False 6 | 7 | reward_fn: 8 | id: ter 9 | # values: 10 | # - id: sacre_bleu 11 | # args: 12 | # tokenize: "intl" 13 | # - id: ter 14 | # - id: chrf 15 | # - id: bert_score 16 | # args: 17 | # language: "de" 18 | 19 | datapool: 20 | id: iwslt2017en_de 21 | args: 22 | prompt_prefix: "translate English to German: " 23 | 24 | 25 | env: 26 | n_envs: 10 27 | args: 28 | max_prompt_length: 128 29 | max_episode_length: 128 30 | terminate_on_eos: True 31 | prompt_truncation_side: "right" 32 | context_start_token: 0 33 | 34 | alg: 35 | id: ppo 36 | args: 37 | n_steps: 512 38 | batch_size: 64 39 | verbose: 1 40 | learning_rate: 0.000001 41 | n_epochs: 5 42 | ent_coef: 0.0 43 | kl_div: 44 | coeff: 0.001 45 | target_kl: 0.2 46 | policy: 47 | id: seq2seq_lm_actor_critic_policy 48 | args: 49 | model_name: t5-base 50 | apply_model_parallel: True 51 | prompt_truncation_side: "right" 52 | generation_kwargs: 53 | do_sample: True 54 | top_k: 10 55 | max_new_tokens: 128 56 | 57 | train_evaluation: 58 | eval_batch_size: 50 59 | n_iters: 50 60 | eval_every: 10 61 | 
save_every: 1 62 | metrics: 63 | - id: meteor 64 | args: {} 65 | - id: rouge 66 | - id: bleu 67 | args: {} 68 | - id: bert_score 69 | args: 70 | language: de 71 | - id: bleu 72 | args: {} 73 | - id: sacre_bleu 74 | args: 75 | tokenize: "intl" 76 | - id: ter 77 | args: {} 78 | - id: chrf 79 | args: {} 80 | - id: diversity 81 | args: {} 82 | generation_kwargs: 83 | num_beams: 4 84 | length_penalty: 0.6 85 | max_new_tokens: 128 86 | 87 | -------------------------------------------------------------------------------- /scripts/training/task_configs/iwslt2017/t5_ppo_on_supervised.yml: -------------------------------------------------------------------------------- 1 | tokenizer: 2 | model_name: t5-base 3 | padding_side: left 4 | truncation_side: left 5 | pad_token_as_eos_token: False 6 | 7 | reward_fn: 8 | id: ter 9 | # values: 10 | # - id: sacre_bleu 11 | # args: 12 | # tokenize: "intl" 13 | # # - id: ter 14 | # # - id: chrf 15 | # # - id: bert_score 16 | # # args: 17 | # # language: "de" 18 | 19 | datapool: 20 | id: iwslt2017en_de 21 | args: 22 | prompt_prefix: "translate English to German: " 23 | 24 | 25 | env: 26 | n_envs: 10 27 | args: 28 | max_prompt_length: 128 29 | max_episode_length: 128 30 | terminate_on_eos: True 31 | prompt_truncation_side: "right" 32 | context_start_token: 0 33 | 34 | alg: 35 | id: ppo 36 | args: 37 | n_steps: 512 38 | batch_size: 64 39 | verbose: 1 40 | learning_rate: 0.0000005 41 | n_epochs: 5 42 | ent_coef: 0.0 43 | kl_div: 44 | coeff: 0.001 45 | target_kl: 0.2 46 | policy: 47 | id: seq2seq_lm_actor_critic_policy 48 | args: 49 | model_name: rajkumarrrk/t5-fine-tuned-on-iwslt2017en_de 50 | apply_model_parallel: True 51 | prompt_truncation_side: "right" 52 | generation_kwargs: 53 | do_sample: True 54 | top_k: 5 55 | max_new_tokens: 128 56 | 57 | train_evaluation: 58 | eval_batch_size: 50 59 | n_iters: 50 60 | eval_every: 10 61 | save_every: 1 62 | metrics: 63 | - id: meteor 64 | args: {} 65 | - id: rouge 66 | - id: bleu 67 | args: {} 68 | - id: bert_score 69 | args: 70 | language: de 71 | - id: bleu 72 | args: {} 73 | - id: sacre_bleu 74 | args: 75 | tokenize: "intl" 76 | - id: ter 77 | args: {} 78 | - id: chrf 79 | args: {} 80 | - id: diversity 81 | args: {} 82 | generation_kwargs: 83 | num_beams: 4 84 | length_penalty: 0.6 85 | max_new_tokens: 128 86 | 87 | -------------------------------------------------------------------------------- /scripts/training/task_configs/iwslt2017/t5_supervised.yml: -------------------------------------------------------------------------------- 1 | tokenizer: 2 | model_name: t5-base 3 | padding_side: left 4 | truncation_side: left 5 | pad_token_as_eos_token: False 6 | 7 | datapool: 8 | id: iwslt2017en_de 9 | args: 10 | prompt_prefix: "translate English to German: " 11 | 12 | alg: 13 | id: supervised 14 | training_args: 15 | per_device_train_batch_size: 64 16 | logging_steps: 1000 17 | num_train_epochs: 5 18 | weight_decay: 0.1 19 | lr_scheduler_type: "constant" 20 | learning_rate: 0.00001 21 | save_total_limit: 1 22 | model_type: seq2seq 23 | model_name: "t5-base" 24 | generation_kwargs: 25 | post_processing_fn: null 26 | num_beams: 4 27 | length_penalty: 0.6 28 | max_new_tokens: 128 29 | 30 | 31 | train_evaluation: 32 | eval_batch_size: 50 33 | metrics: 34 | - id: meteor 35 | args: {} 36 | - id: rouge 37 | - id: bleu 38 | args: {} 39 | - id: bert_score 40 | args: 41 | language: en 42 | - id: bleu 43 | args: {} 44 | - id: sacre_bleu 45 | args: 46 | tokenize: "intl" 47 | - id: ter 48 | args: {} 49 | - id: chrf 50 | args: {} 51 | - 
id: diversity 52 | args: {} 53 | 54 | -------------------------------------------------------------------------------- /scripts/training/task_configs/narrative_qa/t5_nlpo.yml: -------------------------------------------------------------------------------- 1 | tokenizer: 2 | model_name: t5-base 3 | padding_side: left 4 | truncation_side: left 5 | pad_token_as_eos_token: False 6 | 7 | reward_fn: 8 | id: rouge_l_max 9 | args: 10 | max_n: 4 11 | limit_length: True 12 | length_limit: 100 13 | length_limit_type: "words" 14 | apply_avg: True 15 | apply_best: True 16 | alpha: 0.5 17 | weight_factor: 1.2 18 | stemming: True 19 | # expand: True 20 | # values: 21 | # - id: rouge_l_max 22 | # args: 23 | # max_n: 4 24 | # limit_length: True 25 | # length_limit: 100 26 | # length_limit_type: "words" 27 | # apply_avg: True 28 | # apply_best: True 29 | # alpha: 0.5 30 | # weight_factor: 1.2 31 | # stemming: True 32 | # - id: meteor 33 | # - id: bleu 34 | 35 | datapool: 36 | id: narrative_qa 37 | 38 | env: 39 | n_envs: 10 40 | args: 41 | max_prompt_length: 512 42 | max_episode_length: 50 43 | terminate_on_eos: True 44 | context_start_token: 0 45 | prompt_truncation_side: "right" 46 | 47 | alg: 48 | id: nlpo 49 | args: 50 | n_steps: 256 51 | batch_size: 64 52 | verbose: 1 53 | learning_rate: 0.000002 54 | n_epochs: 5 55 | kl_div: 56 | coeff: 0.001 57 | target_kl: 1.0 58 | policy: 59 | id: maskable_seq2seq_lm_actor_critic_policy 60 | args: 61 | model_name: t5-base 62 | apply_model_parallel: True 63 | mask_type: "learned_top_p" 64 | top_mask: 0.9 65 | target_update_iterations: 20 66 | generation_kwargs: 67 | do_sample: True 68 | top_k: 50 69 | 70 | train_evaluation: 71 | eval_batch_size: 50 72 | n_iters: 100 73 | eval_every: 10 74 | save_every: 1 75 | metrics: 76 | - id: meteor 77 | args: {} 78 | - id: rouge 79 | args: 80 | use_single_ref: False 81 | - id: bleu 82 | args: {} 83 | - id: bert_score 84 | args: 85 | language: en 86 | - id: rouge_l_max 87 | args: 88 | max_n: 4 89 | limit_length: True 90 | length_limit: 100 91 | length_limit_type: "words" 92 | apply_avg: True 93 | apply_best: True, 94 | alpha: 0.5 95 | weight_factor: 1.2 96 | stemming: True 97 | - id: diversity 98 | args: {} 99 | generation_kwargs: 100 | num_beams: 4 101 | -------------------------------------------------------------------------------- /scripts/training/task_configs/narrative_qa/t5_nlpo_on_supervised.yml: -------------------------------------------------------------------------------- 1 | tokenizer: 2 | model_name: allenai/unifiedqa-t5-base 3 | padding_side: left 4 | truncation_side: left 5 | pad_token_as_eos_token: False 6 | 7 | reward_fn: 8 | id: rouge_l_max 9 | args: 10 | max_n: 4 11 | limit_length: True 12 | length_limit: 100 13 | length_limit_type: "words" 14 | apply_avg: True 15 | apply_best: True 16 | alpha: 0.5 17 | weight_factor: 1.2 18 | stemming: True 19 | # expand: True 20 | # values: 21 | # - id: rouge_l_max 22 | # args: 23 | # max_n: 4 24 | # limit_length: True 25 | # length_limit: 100 26 | # length_limit_type: "words" 27 | # apply_avg: True 28 | # apply_best: True 29 | # alpha: 0.5 30 | # weight_factor: 1.2 31 | # stemming: True 32 | # - id: meteor 33 | # - id: bleu 34 | 35 | datapool: 36 | id: narrative_qa 37 | 38 | env: 39 | n_envs: 10 40 | args: 41 | max_prompt_length: 512 42 | max_episode_length: 50 43 | terminate_on_eos: True 44 | context_start_token: 0 45 | prompt_truncation_side: "right" 46 | 47 | alg: 48 | id: nlpo 49 | args: 50 | n_steps: 256 51 | batch_size: 64 52 | verbose: 1 53 | learning_rate: 0.0000005 
54 | n_epochs: 5 55 | kl_div: 56 | coeff: 0.001 57 | target_kl: 0.2 58 | policy: 59 | id: maskable_seq2seq_lm_actor_critic_policy 60 | args: 61 | model_name: allenai/unifiedqa-t5-base 62 | apply_model_parallel: True 63 | mask_type: "learned_top_p" 64 | top_mask: 0.9 65 | target_update_iterations: 20 66 | generation_kwargs: 67 | do_sample: True 68 | top_k: 50 69 | 70 | train_evaluation: 71 | eval_batch_size: 50 72 | n_iters: 100 73 | eval_every: 10 74 | save_every: 1 75 | metrics: 76 | - id: meteor 77 | args: {} 78 | - id: rouge 79 | args: 80 | use_single_ref: False 81 | - id: bleu 82 | args: {} 83 | - id: bert_score 84 | args: 85 | language: en 86 | - id: rouge_l_max 87 | args: 88 | max_n: 4 89 | limit_length: True 90 | length_limit: 100 91 | length_limit_type: "words" 92 | apply_avg: True 93 | apply_best: True 94 | alpha: 0.5 95 | weight_factor: 1.2 96 | stemming: True 97 | - id: diversity 98 | args: {} 99 | generation_kwargs: 100 | num_beams: 4 101 | max_new_tokens: 50 102 | -------------------------------------------------------------------------------- /scripts/training/task_configs/narrative_qa/t5_ppo.yml: -------------------------------------------------------------------------------- 1 | tokenizer: 2 | model_name: t5-base 3 | padding_side: left 4 | truncation_side: left 5 | pad_token_as_eos_token: False 6 | 7 | reward_fn: 8 | id: rouge_l_max 9 | args: 10 | max_n: 4 11 | limit_length: True 12 | length_limit: 100 13 | length_limit_type: "words" 14 | apply_avg: True 15 | apply_best: True 16 | alpha: 0.5 17 | weight_factor: 1.2 18 | stemming: True 19 | # expand: True 20 | # values: 21 | # - id: rouge_l_max 22 | # args: 23 | # max_n: 4 24 | # limit_length: True 25 | # length_limit: 100 26 | # length_limit_type: "words" 27 | # apply_avg: True 28 | # apply_best: True 29 | # alpha: 0.5 30 | # weight_factor: 1.2 31 | # stemming: True 32 | # - id: meteor 33 | # - id: bleu 34 | 35 | 36 | datapool: 37 | id: narrative_qa 38 | 39 | env: 40 | n_envs: 10 41 | args: 42 | max_prompt_length: 512 43 | max_episode_length: 50 44 | terminate_on_eos: True 45 | context_start_token: 0 46 | prompt_truncation_side: "right" 47 | 48 | 49 | alg: 50 | id: ppo 51 | args: 52 | n_steps: 256 53 | batch_size: 64 54 | verbose: 1 55 | learning_rate: 0.000002 56 | n_epochs: 5 57 | kl_div: 58 | coeff: 0.001 59 | target_kl: 1.0 60 | policy: 61 | id: seq2seq_lm_actor_critic_policy 62 | args: 63 | model_name: t5-base 64 | apply_model_parallel: True 65 | prompt_truncation_side: "right" 66 | generation_kwargs: 67 | do_sample: True 68 | top_k: 50 69 | 70 | train_evaluation: 71 | eval_batch_size: 50 72 | n_iters: 100 73 | eval_every: 10 74 | save_every: 1 75 | metrics: 76 | - id: meteor 77 | args: {} 78 | - id: rouge 79 | args: 80 | use_single_ref: False 81 | - id: bleu 82 | args: {} 83 | - id: bert_score 84 | args: 85 | language: en 86 | - id: rouge_l_max 87 | args: 88 | max_n: 4 89 | limit_length: True 90 | length_limit: 100 91 | length_limit_type: "words" 92 | apply_avg: True 93 | apply_best: True, 94 | alpha: 0.5 95 | weight_factor: 1.2 96 | stemming: True 97 | - id: diversity 98 | args: {} 99 | generation_kwargs: 100 | num_beams: 4 101 | 102 | -------------------------------------------------------------------------------- /scripts/training/task_configs/narrative_qa/t5_ppo_on_supervised.yml: -------------------------------------------------------------------------------- 1 | tokenizer: 2 | model_name: allenai/unifiedqa-t5-base 3 | padding_side: left 4 | truncation_side: left 5 | pad_token_as_eos_token: False 6 | 7 | 
reward_fn: 8 | id: rouge_l_max 9 | args: 10 | max_n: 4 11 | limit_length: True 12 | length_limit: 100 13 | length_limit_type: "words" 14 | apply_avg: True 15 | apply_best: True 16 | alpha: 0.5 17 | weight_factor: 1.2 18 | stemming: True 19 | # expand: True 20 | # values: 21 | # - id: rouge_l_max 22 | # args: 23 | # max_n: 4 24 | # limit_length: True 25 | # length_limit: 100 26 | # length_limit_type: "words" 27 | # apply_avg: True 28 | # apply_best: True 29 | # alpha: 0.5 30 | # weight_factor: 1.2 31 | # stemming: True 32 | # - id: rouge_combined 33 | 34 | 35 | datapool: 36 | id: narrative_qa 37 | 38 | env: 39 | n_envs: 10 40 | args: 41 | max_prompt_length: 512 42 | max_episode_length: 50 43 | terminate_on_eos: True 44 | context_start_token: 0 45 | prompt_truncation_side: "right" 46 | 47 | 48 | alg: 49 | id: ppo 50 | args: 51 | n_steps: 512 52 | batch_size: 64 53 | verbose: 1 54 | learning_rate: 0.0000005 55 | n_epochs: 5 56 | kl_div: 57 | coeff: 0.001 58 | target_kl: 0.2 59 | policy: 60 | id: seq2seq_lm_actor_critic_policy 61 | args: 62 | model_name: allenai/unifiedqa-t5-base 63 | apply_model_parallel: True 64 | prompt_truncation_side: "right" 65 | generation_kwargs: 66 | do_sample: True 67 | top_k: 50 68 | max_new_tokens: 50 69 | 70 | train_evaluation: 71 | eval_batch_size: 50 72 | n_iters: 100 73 | eval_every: 10 74 | save_every: 1 75 | metrics: 76 | - id: meteor 77 | args: {} 78 | - id: rouge 79 | args: 80 | use_single_ref: False 81 | - id: bleu 82 | args: {} 83 | - id: bert_score 84 | args: 85 | language: en 86 | - id: rouge_l_max 87 | args: 88 | max_n: 4 89 | limit_length: True 90 | length_limit: 100 91 | length_limit_type: "words" 92 | apply_avg: True 93 | apply_best: True 94 | alpha: 0.5 95 | weight_factor: 1.2 96 | stemming: True 97 | - id: diversity 98 | args: {} 99 | generation_kwargs: 100 | num_beams: 4 101 | max_new_tokens: 50 102 | 103 | -------------------------------------------------------------------------------- /scripts/training/task_configs/summarization/t5_nlpo.yml: -------------------------------------------------------------------------------- 1 | tokenizer: 2 | model_name: t5-base 3 | padding_side: left 4 | truncation_side: left 5 | pad_token_as_eos_token: False 6 | 7 | reward_fn: 8 | id: meteor 9 | 10 | datapool: 11 | id: cnn_daily_mail 12 | args: 13 | prompt_prefix: "Summarize: " 14 | 15 | env: 16 | n_envs: 10 17 | args: 18 | max_prompt_length: 512 19 | max_episode_length: 100 20 | terminate_on_eos: True 21 | prompt_truncation_side: "right" 22 | context_start_token: 0 23 | 24 | alg: 25 | id: nlpo 26 | args: 27 | n_steps: 512 28 | batch_size: 64 29 | verbose: 1 30 | learning_rate: 0.000002 31 | n_epochs: 5 32 | ent_coef: 0.0 33 | kl_div: 34 | coeff: 0.001 35 | target_kl: 0.2 36 | policy: 37 | id: maskable_seq2seq_lm_actor_critic_policy 38 | args: 39 | model_name: t5-base 40 | apply_model_parallel: True 41 | prompt_truncation_side: "right" 42 | min_tokens_to_keep: 100 43 | top_mask: 0.9 44 | mask_type: learned_top_p 45 | target_update_iterations: 20 46 | generation_kwargs: 47 | do_sample: True 48 | top_k: 100 49 | min_length: 50 50 | max_new_tokens: 100 51 | 52 | train_evaluation: 53 | eval_batch_size: 100 54 | n_iters: 100 55 | eval_every: 10 56 | save_every: 1 57 | metrics: 58 | - id: meteor 59 | args: {} 60 | - id: rouge 61 | - id: bleu 62 | args: {} 63 | - id: bert_score 64 | args: 65 | language: en 66 | # - id: bleurt 67 | # args: 68 | # config_name: bleurt-large-512 69 | - id: diversity 70 | args: {} 71 | # - id: summaCZS 72 | # args: 73 | # granularity: 
sentence 74 | # use_ent: True 75 | # use_con: False 76 | # - id: summaCConv 77 | # args: 78 | # granularity: sentence 79 | generation_kwargs: 80 | do_sample: True 81 | top_k: 0 82 | temperature: 0.7 83 | min_length: 50 84 | max_new_tokens: 100 85 | -------------------------------------------------------------------------------- /scripts/training/task_configs/summarization/t5_nlpo_on_supervised.yml: -------------------------------------------------------------------------------- 1 | tokenizer: 2 | model_name: t5-base 3 | padding_side: left 4 | truncation_side: left 5 | pad_token_as_eos_token: False 6 | 7 | reward_fn: 8 | id: meteor 9 | 10 | datapool: 11 | id: cnn_daily_mail 12 | args: 13 | prompt_prefix: "Summarize: " 14 | 15 | env: 16 | n_envs: 10 17 | args: 18 | max_prompt_length: 512 19 | max_episode_length: 100 20 | terminate_on_eos: True 21 | prompt_truncation_side: "right" 22 | context_start_token: 0 23 | 24 | alg: 25 | id: nlpo 26 | args: 27 | n_steps: 512 28 | batch_size: 64 29 | verbose: 1 30 | learning_rate: 0.000002 31 | n_epochs: 5 32 | ent_coef: 0.0 33 | kl_div: 34 | coeff: 0.01 35 | target_kl: 0.2 36 | policy: 37 | id: maskable_seq2seq_lm_actor_critic_policy 38 | args: 39 | model_name: rajkumarrrk/t5-base-fine-tuned-on-cnn-dm 40 | apply_model_parallel: True 41 | prompt_truncation_side: "right" 42 | min_tokens_to_keep: 100 43 | top_mask: 0.9 44 | mask_type: "learned_top_p" 45 | target_update_iterations: 30 46 | generation_kwargs: 47 | do_sample: True 48 | top_k: 100 49 | min_length: 50 50 | max_new_tokens: 100 51 | 52 | train_evaluation: 53 | eval_batch_size: 100 54 | n_iters: 50 55 | eval_every: 10 56 | save_every: 1 57 | metrics: 58 | - id: meteor 59 | args: {} 60 | - id: rouge 61 | - id: bleu 62 | args: {} 63 | - id: bert_score 64 | args: 65 | language: en 66 | # - id: bleurt 67 | # args: 68 | # config_name: bleurt-large-512 69 | - id: diversity 70 | args: {} 71 | # - id: summaCZS 72 | # args: 73 | # granularity: sentence 74 | # use_ent: True 75 | # use_con: False 76 | # - id: summaCConv 77 | # args: 78 | # granularity: sentence 79 | generation_kwargs: 80 | do_sample: True 81 | top_k: 0 82 | temperature: 0.7 83 | min_length: 50 84 | max_new_tokens: 100 -------------------------------------------------------------------------------- /scripts/training/task_configs/summarization/t5_ppo.yml: -------------------------------------------------------------------------------- 1 | tokenizer: 2 | model_name: t5-base 3 | padding_side: left 4 | truncation_side: left 5 | pad_token_as_eos_token: False 6 | 7 | reward_fn: 8 | id: rouge 9 | args: 10 | rouge_type: "rouge1" 11 | 12 | datapool: 13 | id: cnn_daily_mail 14 | args: 15 | prompt_prefix: "Summarize: " 16 | 17 | 18 | env: 19 | n_envs: 10 20 | args: 21 | max_prompt_length: 512 22 | max_episode_length: 100 23 | terminate_on_eos: True 24 | prompt_truncation_side: "right" 25 | context_start_token: 0 26 | 27 | alg: 28 | id: ppo 29 | args: 30 | n_steps: 512 31 | batch_size: 64 32 | verbose: 1 33 | learning_rate: 0.000002 34 | n_epochs: 5 35 | ent_coef: 0.0 36 | kl_div: 37 | coeff: 0.001 38 | target_kl: 0.2 39 | policy: 40 | id: seq2seq_lm_actor_critic_policy 41 | args: 42 | model_name: t5-base 43 | apply_model_parallel: True 44 | prompt_truncation_side: "right" 45 | generation_kwargs: 46 | do_sample: True 47 | top_k: 50 48 | min_length: 50 49 | max_new_tokens: 100 50 | 51 | train_evaluation: 52 | eval_batch_size: 100 53 | n_iters: 100 54 | eval_every: 10 55 | save_every: 1 56 | metrics: 57 | - id: meteor 58 | args: {} 59 | - id: rouge 60 | - 
id: bleu 61 | args: {} 62 | - id: bert_score 63 | args: 64 | language: en 65 | # - id: bleurt 66 | # args: 67 | # config_name: bleurt-large-512 68 | - id: diversity 69 | args: {} 70 | # - id: summaCZS 71 | # args: 72 | # granularity: sentence 73 | # use_ent: True 74 | # use_con: False 75 | # - id: summaCConv 76 | # args: 77 | # granularity: sentence 78 | generation_kwargs: 79 | do_sample: True 80 | top_k: 0 81 | temperature: 0.7 82 | min_length: 50 83 | max_new_tokens: 100 84 | 85 | -------------------------------------------------------------------------------- /scripts/training/task_configs/summarization/t5_ppo_on_supervised.yml: -------------------------------------------------------------------------------- 1 | tokenizer: 2 | model_name: t5-base 3 | padding_side: left 4 | truncation_side: left 5 | pad_token_as_eos_token: False 6 | 7 | reward_fn: 8 | id: meteor 9 | 10 | datapool: 11 | id: cnn_daily_mail 12 | args: 13 | prompt_prefix: "Summarize: " 14 | 15 | 16 | env: 17 | n_envs: 10 18 | args: 19 | max_prompt_length: 512 20 | max_episode_length: 100 21 | terminate_on_eos: True 22 | prompt_truncation_side: "right" 23 | context_start_token: 0 24 | 25 | alg: 26 | id: ppo 27 | args: 28 | n_steps: 512 29 | batch_size: 64 30 | verbose: 1 31 | learning_rate: 0.000002 32 | n_epochs: 5 33 | ent_coef: 0.0 34 | kl_div: 35 | coeff: 0.01 36 | target_kl: 0.2 37 | policy: 38 | id: seq2seq_lm_actor_critic_policy 39 | args: 40 | model_name: rajkumarrrk/t5-base-fine-tuned-on-cnn-dm 41 | apply_model_parallel: True 42 | prompt_truncation_side: "right" 43 | generation_kwargs: 44 | do_sample: True 45 | top_k: 100 46 | min_length: 50 47 | max_new_tokens: 100 48 | 49 | train_evaluation: 50 | eval_batch_size: 100 51 | n_iters: 50 52 | eval_every: 10 53 | save_every: 1 54 | metrics: 55 | - id: meteor 56 | args: {} 57 | - id: rouge 58 | - id: bleu 59 | args: {} 60 | - id: bert_score 61 | args: 62 | language: en 63 | # - id: bleurt 64 | # args: 65 | # config_name: bleurt-large-512 66 | - id: diversity 67 | args: {} 68 | # - id: summaCZS 69 | # args: 70 | # granularity: sentence 71 | # use_ent: True 72 | # use_con: False 73 | # - id: summaCConv 74 | # args: 75 | # granularity: sentence 76 | generation_kwargs: 77 | do_sample: True 78 | top_k: 0 79 | temperature: 0.7 80 | min_length: 50 81 | max_new_tokens: 100 82 | 83 | -------------------------------------------------------------------------------- /scripts/training/task_configs/summarization/t5_supervised.yml: -------------------------------------------------------------------------------- 1 | tokenizer: 2 | model_name: t5-base 3 | padding_side: left 4 | truncation_side: right 5 | pad_token_as_eos_token: False 6 | 7 | datapool: 8 | id: cnn_daily_mail 9 | args: 10 | prompt_prefix: "Summarize: " 11 | 12 | alg: 13 | id: supervised 14 | training_args: 15 | per_device_train_batch_size: 16 16 | logging_steps: 5000 17 | num_train_epochs: 2 18 | weight_decay: 0.1 19 | lr_scheduler_type: cosine 20 | learning_rate: 0.0001 21 | save_total_limit: 1 22 | model_type: seq2seq 23 | model_name: t5-base 24 | generation_kwargs: 25 | do_sample: True 26 | top_k: 0 27 | temperature: 0.7 28 | min_length: 50 29 | max_new_tokens: 100 30 | post_processing_fn: null 31 | 32 | train_evaluation: 33 | eval_batch_size: 100 34 | metrics: 35 | - id: meteor 36 | args: {} 37 | - id: rouge 38 | - id: bleu 39 | args: {} 40 | - id: bert_score 41 | args: 42 | language: en 43 | # - id: bleurt 44 | # args: 45 | # config_name: bleurt-large-512 46 | - id: diversity 47 | args: {} 48 | # - id: summaCZS 49 | # 
args: 50 | # granularity: sentence 51 | # use_ent: True 52 | # use_con: False 53 | # - id: summaCConv 54 | # args: 55 | # granularity: sentence 56 | 57 | -------------------------------------------------------------------------------- /scripts/training/task_configs/synthetic_generate_dates/gpt2_ppo.yml: -------------------------------------------------------------------------------- 1 | tokenizer: 2 | model_name: gpt2 3 | padding_side: left 4 | truncation_side: left 5 | pad_token_as_eos_token: True 6 | 7 | reward_fn: 8 | id: sentences_with_dates 9 | args: {} 10 | 11 | datapool: 12 | id: dummy_pool 13 | args: 14 | n_samples: 50 15 | prompt: '<|endoftext|>' 16 | 17 | env: 18 | n_envs: 10 19 | args: 20 | max_prompt_length: 5 21 | max_episode_length: 30 22 | terminate_on_eos: True 23 | 24 | alg: 25 | id: ppo 26 | args: 27 | n_steps: 128 28 | batch_size: 64 29 | verbose: 1 30 | learning_rate: 0.00001 31 | n_epochs: 5 32 | ent_coef: 0.001 33 | clip_range: 0.2 34 | kl_div: 35 | coeff: 0.02 36 | target_kl: 2 37 | policy: 38 | id: causal_lm_actor_critic_policy 39 | args: 40 | model_name: gpt2 41 | apply_model_parallel: True 42 | generation_kwargs: 43 | do_sample: True 44 | top_k: 0 45 | max_new_tokens: 30 # this must align with env's max steps 46 | 47 | train_evaluation: 48 | eval_batch_size: 256 49 | n_iters: 100 50 | eval_every: 5 51 | metrics: 52 | - id: dates 53 | args: {} -------------------------------------------------------------------------------- /scripts/training/task_configs/synthetic_generate_increasing_numbers/bart_ppo.yml: -------------------------------------------------------------------------------- 1 | tokenizer: 2 | model_name: facebook/bart-base 3 | padding_side: left 4 | truncation_side: left 5 | pad_token_as_eos_token: False 6 | 7 | reward_fn: 8 | id: increasing_numbers 9 | args: 10 | min_tokens: 20 11 | 12 | datapool: 13 | id: dummy_pool 14 | args: 15 | n_samples: 50 16 | prompt: 'Generate some numbers 1 2' 17 | env: 18 | n_envs: 10 19 | args: 20 | max_prompt_length: 5 21 | max_episode_length: 20 22 | terminate_on_eos: True 23 | context_start_token: 2 # this is decoder start token 24 | 25 | alg: 26 | id: ppo 27 | args: 28 | n_steps: 128 29 | batch_size: 128 30 | verbose: 1 31 | learning_rate: 0.0000001 32 | ent_coef: 0.0 33 | n_epochs: 5 34 | kl_div: 35 | coeff: 0.0001 36 | target_kl: 3 37 | policy: 38 | id: seq2seq_lm_actor_critic_policy 39 | args: 40 | model_name: facebook/bart-base 41 | apply_model_parallel: False 42 | generation_kwargs: 43 | do_sample: True 44 | min_length: 20 45 | top_k: 50 46 | max_new_tokens: 20 # this must align with env's max steps 47 | num_beams: 1 48 | 49 | train_evaluation: 50 | eval_batch_size: 256 51 | n_iters: 100 52 | eval_every: 10 53 | save_every: 10 54 | metrics: 55 | - id: increasing_numbers 56 | args: 57 | min_tokens: 20 -------------------------------------------------------------------------------- /scripts/training/task_configs/synthetic_generate_increasing_numbers/blendorbot_ppo.yml: -------------------------------------------------------------------------------- 1 | tokenizer: 2 | model_name: facebook/blenderbot-400M-distill 3 | padding_side: left 4 | truncation_side: left 5 | pad_token_as_eos_token: False 6 | 7 | reward_fn: 8 | id: increasing_numbers 9 | args: 10 | min_tokens: 20 11 | 12 | datapool: 13 | id: dummy_pool 14 | args: 15 | n_samples: 50 16 | prompt: 'Let us generate some numbers' 17 | env: 18 | n_envs: 10 19 | args: 20 | max_prompt_length: 5 21 | max_episode_length: 20 22 | terminate_on_eos: True 23 | 
context_start_token: 1 # this is decoder start token 24 | 25 | alg: 26 | id: ppo 27 | args: 28 | n_steps: 128 29 | batch_size: 64 30 | verbose: 1 31 | learning_rate: 0.000001 32 | ent_coef: 0.001 33 | n_epochs: 5 34 | kl_div: 35 | coeff: 0.00001 36 | target_kl: 3 37 | policy: 38 | id: seq2seq_lm_actor_critic_policy 39 | args: 40 | model_name: facebook/blenderbot-400M-distill 41 | apply_model_parallel: False 42 | generation_kwargs: 43 | do_sample: True 44 | min_length: 20 45 | top_k: 200 46 | max_new_tokens: 20 # this must align with env's max steps 47 | num_beams: 1 48 | 49 | train_evaluation: 50 | eval_batch_size: 256 51 | n_iters: 100 52 | eval_every: 10 53 | save_every: 10 54 | metrics: 55 | - id: increasing_numbers 56 | args: 57 | min_tokens: 20 -------------------------------------------------------------------------------- /scripts/training/task_configs/synthetic_generate_increasing_numbers/gpt2_a2c.yml: -------------------------------------------------------------------------------- 1 | tokenizer: 2 | model_name: gpt2 3 | padding_side: left 4 | truncation_side: left 5 | pad_token_as_eos_token: True 6 | 7 | reward_fn: 8 | id: increasing_numbers 9 | args: 10 | min_tokens: 20 11 | 12 | datapool: 13 | id: dummy_pool 14 | args: 15 | n_samples: 50 16 | prompt: '<|endoftext|>' 17 | 18 | env: 19 | n_envs: 10 20 | args: 21 | max_prompt_length: 5 22 | max_episode_length: 20 23 | terminate_on_eos: True 24 | 25 | alg: 26 | id: a2c 27 | args: 28 | n_steps: 20 29 | verbose: 1 30 | learning_rate: 0.00001 31 | ent_coef: 0.001 32 | kl_div: 33 | coeff: 0.02 34 | target_kl: 2 35 | policy: 36 | id: causal_lm_actor_critic_policy 37 | args: 38 | model_name: gpt2 39 | apply_model_parallel: True 40 | generation_kwargs: 41 | do_sample: True 42 | max_new_tokens: 20 #this must align with env's max steps 43 | 44 | train_evaluation: 45 | eval_batch_size: 256 46 | n_iters: 500 47 | eval_every: 20 48 | save_every: 20 49 | metrics: 50 | - id: increasing_numbers 51 | args: 52 | min_tokens: 20 -------------------------------------------------------------------------------- /scripts/training/task_configs/synthetic_generate_increasing_numbers/gpt2_nlpo.yml: -------------------------------------------------------------------------------- 1 | tokenizer: 2 | model_name: gpt2 3 | padding_side: left 4 | truncation_side: left 5 | pad_token_as_eos_token: True 6 | 7 | reward_fn: 8 | id: increasing_numbers 9 | args: 10 | min_tokens: 20 11 | 12 | datapool: 13 | id: dummy_pool 14 | args: 15 | n_samples: 50 16 | prompt: '<|endoftext|>' 17 | env: 18 | n_envs: 10 19 | args: 20 | max_prompt_length: 5 21 | max_episode_length: 20 22 | terminate_on_eos: True 23 | context_start_token: 0 24 | 25 | alg: 26 | id: nlpo 27 | args: 28 | n_steps: 128 29 | batch_size: 64 30 | verbose: 1 31 | learning_rate: 0.00001 32 | n_epochs: 5 33 | ent_coef: 0.0 34 | gae_lambda: 0.9 35 | vf_coef: 0.1 36 | kl_div: 37 | coeff: 0.02 38 | target_kl: 2 39 | policy: 40 | id: maskable_causal_lm_actor_critic_policy 41 | args: 42 | model_name: gpt2 43 | apply_model_parallel: True 44 | top_mask: 0.9 45 | mask_type: 'learned_top_p' 46 | target_update_iterations: 100 47 | generation_kwargs: 48 | do_sample: True 49 | top_k: 0 50 | min_length: 20 51 | max_new_tokens: 20 # this must align with env's max steps 52 | 53 | train_evaluation: 54 | eval_batch_size: 256 55 | n_iters: 200 56 | eval_every: 10 57 | save_every: 10 58 | metrics: 59 | - id: increasing_numbers 60 | args: 61 | min_tokens: 20 -------------------------------------------------------------------------------- 
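Several of the synthetic configs in this directory repeat the comment that generation's max_new_tokens "must align with env's max steps". Below is a minimal standalone checker for that invariant; it is not part of RL4LMs, assumes only PyYAML, and simply reads the config layout used throughout these files (env.args.max_episode_length vs. alg.policy.args.generation_kwargs.max_new_tokens):

# check_config_alignment.py -- illustrative helper, not part of the RL4LMs repo
import sys
import yaml

def check_alignment(config_path: str) -> None:
    with open(config_path, "r") as fp:
        config = yaml.safe_load(fp)

    # episode length enforced by the text-generation environment
    max_episode_length = config["env"]["args"]["max_episode_length"]

    # tokens the policy is allowed to generate per rollout
    gen_kwargs = config["alg"]["policy"]["args"].get("generation_kwargs", {})
    max_new_tokens = gen_kwargs.get("max_new_tokens")

    if max_new_tokens is not None and max_new_tokens != max_episode_length:
        print(f"{config_path}: max_new_tokens={max_new_tokens} "
              f"!= env max_episode_length={max_episode_length}")
    else:
        print(f"{config_path}: OK")

if __name__ == "__main__":
    for path in sys.argv[1:]:
        check_alignment(path)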
/scripts/training/task_configs/synthetic_generate_increasing_numbers/gpt2_ppo.yml: -------------------------------------------------------------------------------- 1 | tokenizer: 2 | model_name: gpt2 3 | padding_side: left 4 | truncation_side: left 5 | pad_token_as_eos_token: True 6 | 7 | reward_fn: 8 | id: increasing_numbers 9 | args: 10 | min_tokens: 20 11 | 12 | datapool: 13 | id: dummy_pool 14 | args: 15 | n_samples: 50 16 | prompt: '<|endoftext|>' 17 | 18 | env: 19 | n_envs: 10 20 | args: 21 | max_prompt_length: 5 22 | max_episode_length: 20 23 | terminate_on_eos: True 24 | 25 | alg: 26 | id: ppo 27 | args: 28 | n_steps: 128 29 | batch_size: 64 30 | verbose: 1 31 | learning_rate: 0.000001 32 | n_epochs: 5 33 | ent_coef: 0.001 34 | kl_div: 35 | coeff: 0.02 36 | target_kl: 2 37 | policy: 38 | id: causal_lm_actor_critic_policy 39 | args: 40 | model_name: gpt2 41 | apply_model_parallel: True 42 | generation_kwargs: 43 | do_sample: True 44 | max_new_tokens: 20 #this must align with env's max steps 45 | 46 | train_evaluation: 47 | eval_batch_size: 256 48 | n_iters: 100 49 | eval_every: 5 50 | save_every: 20 51 | metrics: 52 | - id: increasing_numbers 53 | args: 54 | min_tokens: 20 -------------------------------------------------------------------------------- /scripts/training/task_configs/synthetic_generate_increasing_numbers/gpt2_trpo.yml: -------------------------------------------------------------------------------- 1 | tokenizer: 2 | model_name: gpt2 3 | padding_side: left 4 | truncation_side: left 5 | pad_token_as_eos_token: True 6 | 7 | reward_fn: 8 | id: increasing_numbers 9 | args: 10 | min_tokens: 20 11 | 12 | datapool: 13 | id: dummy_pool 14 | args: 15 | n_samples: 50 16 | prompt: '<|endoftext|>' 17 | 18 | env: 19 | n_envs: 2 20 | args: 21 | max_prompt_length: 5 22 | max_episode_length: 20 23 | terminate_on_eos: True 24 | 25 | alg: 26 | id: trpo 27 | args: 28 | n_steps: 20 29 | verbose: 1 30 | learning_rate: 0.00001 31 | kl_div: 32 | coeff: 0.01 33 | target_kl: 2 34 | policy: 35 | id: causal_lm_actor_critic_policy 36 | args: 37 | model_name: gpt2 38 | apply_model_parallel: True 39 | generation_kwargs: 40 | do_sample: True 41 | max_new_tokens: 20 #this must align with env's max steps 42 | 43 | train_evaluation: 44 | eval_batch_size: 256 45 | n_iters: 1000 46 | eval_every: 5 47 | save_every: 20 48 | metrics: 49 | - id: increasing_numbers 50 | args: 51 | min_tokens: 20 -------------------------------------------------------------------------------- /scripts/training/task_configs/synthetic_generate_increasing_numbers/t5_nlpo.yml: -------------------------------------------------------------------------------- 1 | tokenizer: 2 | model_name: t5-base 3 | padding_side: left 4 | truncation_side: left 5 | pad_token_as_eos_token: False 6 | 7 | reward_fn: 8 | id: increasing_numbers 9 | args: 10 | min_tokens: 20 11 | 12 | 13 | datapool: 14 | id: dummy_pool 15 | args: 16 | n_samples: 50 17 | prompt: 'Let us generate some numbers' 18 | 19 | env: 20 | n_envs: 10 21 | args: 22 | max_prompt_length: 5 23 | max_episode_length: 20 24 | terminate_on_eos: True 25 | context_start_token: 0 26 | 27 | alg: 28 | id: nlpo 29 | args: 30 | n_steps: 128 31 | batch_size: 64 32 | verbose: 1 33 | learning_rate: 0.000002 34 | n_epochs: 5 35 | kl_div: 36 | coeff: 0.001 37 | target_kl: 2.0 38 | policy: 39 | id: maskable_seq2seq_lm_actor_critic_policy 40 | args: 41 | model_name: t5-base 42 | apply_model_parallel: True 43 | mask_type: "learned_top_p" 44 | top_mask: 0.9 45 | target_update_iterations: 20 46 | 
generation_kwargs: 47 | do_sample: True 48 | min_length: 20 49 | top_k: 200 50 | max_new_tokens: 20 # this must align with env's max steps 51 | 52 | train_evaluation: 53 | eval_batch_size: 256 54 | n_iters: 100 55 | eval_every: 10 56 | save_every: 10 57 | metrics: 58 | - id: increasing_numbers 59 | args: 60 | min_tokens: 20 61 | -------------------------------------------------------------------------------- /scripts/training/task_configs/synthetic_generate_increasing_numbers/t5_ppo.yml: -------------------------------------------------------------------------------- 1 | tokenizer: 2 | model_name: t5-base 3 | padding_side: left 4 | truncation_side: left 5 | pad_token_as_eos_token: False 6 | 7 | reward_fn: 8 | id: increasing_numbers 9 | args: 10 | min_tokens: 20 11 | 12 | datapool: 13 | id: dummy_pool 14 | args: 15 | n_samples: 50 16 | prompt: 'Let us generate some numbers' 17 | env: 18 | n_envs: 10 19 | args: 20 | max_prompt_length: 5 21 | max_episode_length: 20 22 | terminate_on_eos: True 23 | context_start_token: 0 24 | 25 | alg: 26 | id: ppo 27 | args: 28 | n_steps: 128 29 | batch_size: 64 30 | verbose: 1 31 | learning_rate: 0.000001 32 | ent_coef: 0.001 33 | n_epochs: 5 34 | kl_div: 35 | coeff: 0.00001 36 | target_kl: 3 37 | policy: 38 | id: seq2seq_lm_actor_critic_policy 39 | args: 40 | model_name: t5-base 41 | apply_model_parallel: True 42 | generation_kwargs: 43 | do_sample: True 44 | min_length: 20 45 | top_k: 200 46 | max_new_tokens: 20 # this must align with env's max steps 47 | 48 | train_evaluation: 49 | eval_batch_size: 256 50 | n_iters: 100 51 | eval_every: 10 52 | save_every: 10 53 | metrics: 54 | - id: increasing_numbers 55 | args: 56 | min_tokens: 20 -------------------------------------------------------------------------------- /scripts/training/task_configs/totto/t5_nlpo.yml: -------------------------------------------------------------------------------- 1 | tokenizer: 2 | model_name: t5-base 3 | padding_side: left 4 | truncation_side: right 5 | pad_token_as_eos_token: False 6 | 7 | reward_fn: 8 | id: meteor 9 | # - id: parent 10 | # - id: meteor 11 | # - id: meteor 12 | # args: 13 | # shaping_fn: "parent" 14 | # - id: bleu 15 | # - id: sacre_bleu 16 | 17 | datapool: 18 | id: totto 19 | args: 20 | representation: 'subtable' 21 | 22 | env: 23 | n_envs: 10 24 | args: 25 | max_prompt_length: 512 26 | max_episode_length: 50 27 | terminate_on_eos: True 28 | context_start_token: 0 29 | 30 | alg: 31 | id: nlpo 32 | args: 33 | n_steps: 256 34 | batch_size: 64 35 | verbose: 1 36 | learning_rate: 0.000002 37 | n_epochs: 5 38 | kl_div: 39 | coeff: 0.001 40 | target_kl: 2.0 41 | policy: 42 | id: maskable_seq2seq_lm_actor_critic_policy 43 | args: 44 | model_name: t5-base 45 | apply_model_parallel: True 46 | mask_type: "learned_top_p" 47 | top_mask: 0.9 48 | target_update_iterations: 20 49 | generation_kwargs: 50 | do_sample: True 51 | top_k: 0 52 | min_length: 10 53 | max_new_tokens: 50 54 | 55 | train_evaluation: 56 | eval_batch_size: 100 57 | n_iters: 100 58 | eval_every: 20 59 | save_every: 1 60 | metrics: 61 | - id: meteor 62 | args: {} 63 | - id: parent_totto 64 | args: {} 65 | - id: rouge 66 | args: 67 | use_single_ref: False 68 | - id: bleu_totto 69 | args: {} 70 | - id: bert_score 71 | args: 72 | language: en 73 | # - id: bleurt 74 | # args: 75 | # config_name: bleurt-large-512 76 | - id: diversity 77 | args: {} 78 | # - id: summaCZS 79 | # args: 80 | # granularity: sentence 81 | # use_ent: True 82 | # use_con: False 83 | # - id: summaCConv 84 | # args: 85 | # granularity: 
sentence 86 | generation_kwargs: 87 | num_beams: 5 88 | min_length: 10 89 | max_new_tokens: 50 90 | -------------------------------------------------------------------------------- /scripts/training/task_configs/totto/t5_nlpo_on_supervised.yml: -------------------------------------------------------------------------------- 1 | tokenizer: 2 | model_name: t5-base 3 | padding_side: left 4 | truncation_side: right 5 | pad_token_as_eos_token: False 6 | 7 | reward_fn: 8 | id: meteor 9 | # - id: parent 10 | # - id: meteor 11 | # - id: meteor 12 | # args: 13 | # shaping_fn: "parent" 14 | # - id: bleu 15 | # - id: sacre_bleu 16 | 17 | datapool: 18 | id: totto 19 | args: 20 | representation: 'subtable' 21 | 22 | env: 23 | n_envs: 10 24 | args: 25 | max_prompt_length: 512 26 | max_episode_length: 50 27 | terminate_on_eos: True 28 | context_start_token: 0 29 | 30 | alg: 31 | id: nlpo 32 | args: 33 | n_steps: 256 34 | batch_size: 64 35 | verbose: 1 36 | learning_rate: 0.0000005 37 | n_epochs: 5 38 | kl_div: 39 | coeff: 0.01 40 | target_kl: 0.2 41 | policy: 42 | id: maskable_seq2seq_lm_actor_critic_policy 43 | args: 44 | model_name: rajkumarrrk/t5-base-fine-tuned-on-totto 45 | apply_model_parallel: True 46 | mask_type: "learned_top_p" 47 | top_mask: 0.9 48 | target_update_iterations: 20 49 | generation_kwargs: 50 | do_sample: True 51 | top_k: 0 52 | min_length: 10 53 | max_new_tokens: 50 54 | 55 | train_evaluation: 56 | eval_batch_size: 100 57 | n_iters: 100 58 | eval_every: 20 59 | save_every: 1 60 | metrics: 61 | - id: meteor 62 | args: {} 63 | - id: parent_totto 64 | args: {} 65 | - id: rouge 66 | args: 67 | use_single_ref: False 68 | - id: bleu_totto 69 | args: {} 70 | - id: bert_score 71 | args: 72 | language: en 73 | # - id: bleurt 74 | # args: 75 | # config_name: bleurt-large-512 76 | - id: diversity 77 | args: {} 78 | # - id: summaCZS 79 | # args: 80 | # granularity: sentence 81 | # use_ent: True 82 | # use_con: False 83 | # - id: summaCConv 84 | # args: 85 | # granularity: sentence 86 | generation_kwargs: 87 | num_beams: 5 88 | min_length: 10 89 | max_new_tokens: 50 90 | -------------------------------------------------------------------------------- /scripts/training/task_configs/totto/t5_ppo.yml: -------------------------------------------------------------------------------- 1 | tokenizer: 2 | model_name: t5-base 3 | padding_side: left 4 | truncation_side: right 5 | pad_token_as_eos_token: False 6 | 7 | reward_fn: 8 | id: meteor 9 | # - id: parent 10 | # - id: meteor 11 | # - id: meteor 12 | # args: 13 | # shaping_fn: "parent" 14 | # - id: bleu 15 | # - id: sacre_bleu 16 | 17 | datapool: 18 | id: totto 19 | args: 20 | representation: 'subtable' 21 | 22 | 23 | env: 24 | n_envs: 10 25 | args: 26 | max_prompt_length: 512 27 | max_episode_length: 50 28 | terminate_on_eos: True 29 | context_start_token: 0 30 | 31 | 32 | alg: 33 | id: ppo 34 | args: 35 | n_steps: 256 36 | batch_size: 64 37 | verbose: 1 38 | learning_rate: 0.000002 39 | n_epochs: 5 40 | kl_div: 41 | coeff: 0.001 42 | target_kl: 2.0 43 | policy: 44 | id: seq2seq_lm_actor_critic_policy 45 | args: 46 | model_name: t5-base 47 | apply_model_parallel: True 48 | generation_kwargs: 49 | do_sample: True 50 | top_k: 0 51 | min_length: 10 52 | max_new_tokens: 50 53 | 54 | train_evaluation: 55 | eval_batch_size: 100 56 | n_iters: 100 57 | eval_every: 20 58 | save_every: 1 59 | metrics: 60 | - id: meteor 61 | args: {} 62 | - id: parent_totto 63 | args: {} 64 | - id: rouge 65 | args: 66 | use_single_ref: False 67 | - id: bleu_totto 68 | args: 
{} 69 | - id: bert_score 70 | args: 71 | language: en 72 | # - id: bleurt 73 | # args: 74 | # config_name: bleurt-large-512 75 | - id: diversity 76 | args: {} 77 | # - id: summaCZS 78 | # args: 79 | # granularity: sentence 80 | # use_ent: True 81 | # use_con: False 82 | # - id: summaCConv 83 | # args: 84 | # granularity: sentence 85 | generation_kwargs: 86 | num_beams: 5 87 | min_length: 10 88 | max_new_tokens: 50 89 | 90 | -------------------------------------------------------------------------------- /scripts/training/task_configs/totto/t5_ppo_on_supervised.yml: -------------------------------------------------------------------------------- 1 | tokenizer: 2 | model_name: t5-base 3 | padding_side: left 4 | truncation_side: right 5 | pad_token_as_eos_token: False 6 | 7 | reward_fn: 8 | id: meteor 9 | # values: 10 | # - id: parent 11 | # - id: meteor 12 | # - id: meteor 13 | # args: 14 | # shaping_fn: "parent" 15 | # - id: bleu 16 | # - id: sacre_bleu 17 | 18 | datapool: 19 | id: totto 20 | args: 21 | representation: 'subtable' 22 | 23 | 24 | env: 25 | n_envs: 10 26 | args: 27 | max_prompt_length: 512 28 | max_episode_length: 50 29 | terminate_on_eos: True 30 | context_start_token: 0 31 | 32 | 33 | alg: 34 | id: ppo 35 | args: 36 | n_steps: 256 37 | batch_size: 64 38 | verbose: 1 39 | learning_rate: 0.0000005 40 | n_epochs: 5 41 | kl_div: 42 | coeff: 0.01 43 | target_kl: 0.2 44 | policy: 45 | id: seq2seq_lm_actor_critic_policy 46 | args: 47 | model_name: rajkumarrrk/t5-base-fine-tuned-on-totto 48 | apply_model_parallel: True 49 | generation_kwargs: 50 | do_sample: True 51 | top_k: 50 52 | min_length: 10 53 | max_new_tokens: 50 54 | 55 | train_evaluation: 56 | eval_batch_size: 100 57 | n_iters: 100 58 | eval_every: 20 59 | save_every: 1 60 | metrics: 61 | - id: meteor 62 | args: {} 63 | - id: parent_totto 64 | args: {} 65 | - id: rouge 66 | args: 67 | use_single_ref: False 68 | - id: bleu_totto 69 | args: {} 70 | - id: bert_score 71 | args: 72 | language: en 73 | # - id: bleurt 74 | # args: 75 | # config_name: bleurt-large-512 76 | - id: diversity 77 | args: {} 78 | # - id: summaCZS 79 | # args: 80 | # granularity: sentence 81 | # use_ent: True 82 | # use_con: False 83 | # - id: summaCConv 84 | # args: 85 | # granularity: sentence 86 | generation_kwargs: 87 | num_beams: 5 88 | min_length: 10 89 | max_new_tokens: 50 90 | 91 | -------------------------------------------------------------------------------- /scripts/training/task_configs/totto/t5_supervised.yml: -------------------------------------------------------------------------------- 1 | tokenizer: 2 | model_name: t5-base 3 | padding_side: left 4 | truncation_side: right 5 | pad_token_as_eos_token: False 6 | 7 | datapool: 8 | id: totto 9 | args: 10 | representation: 'subtable' 11 | 12 | alg: 13 | id: supervised 14 | training_args: 15 | per_device_train_batch_size: 8 16 | logging_steps: 20000 17 | num_train_epochs: 5 18 | weight_decay: 0.1 19 | lr_scheduler_type: constant_with_warmup 20 | learning_rate: 0.0001 21 | save_total_limit: 1 22 | model_type: seq2seq 23 | model_name: "t5-base" 24 | generation_kwargs: 25 | do_sample: True 26 | num_beams: 10 27 | min_length: 10 28 | max_new_tokens: 50 29 | post_processing_fn: null 30 | 31 | train_evaluation: 32 | eval_batch_size: 100 33 | metrics: 34 | - id: meteor 35 | args: {} 36 | - id: parent_totto 37 | args: {} 38 | - id: rouge 39 | args: 40 | use_single_ref: False 41 | - id: bleu_totto 42 | args: {} 43 | - id: bert_score 44 | args: 45 | language: en 46 | # - id: bleurt 47 | # args: 48 | # 
config_name: bleurt-large-512 49 | - id: diversity 50 | args: {} 51 | # - id: summaCZS 52 | # args: 53 | # granularity: sentence 54 | # use_ent: True 55 | # use_con: False 56 | # - id: summaCConv 57 | # args: 58 | # granularity: sentence 59 | 60 | -------------------------------------------------------------------------------- /scripts/training/train_text_generation.py: -------------------------------------------------------------------------------- 1 | import os 2 | from argparse import ArgumentParser 3 | 4 | import yaml 5 | 6 | from rl4lms.envs.text_generation.logging_utils import Tracker 7 | from rl4lms.envs.text_generation.training_utils import ( 8 | OnPolicyTrainer, 9 | SupervisedTrainer, 10 | ) 11 | 12 | 13 | def main( 14 | config_path: str, 15 | project_name: str, 16 | experiment_name: str, 17 | base_path_to_store_results: str, 18 | entity_name: str, 19 | log_to_wandb: bool, 20 | ): 21 | 22 | # load the config file 23 | with open(config_path, "r") as fp: 24 | config = yaml.safe_load(fp) 25 | 26 | # load tracker 27 | tracker = Tracker( 28 | base_path_to_store_results, 29 | config, 30 | project_name, 31 | experiment_name, 32 | entity_name, 33 | log_to_wandb, 34 | ) 35 | 36 | # instantiate the trainer here 37 | if "supervised" in config["alg"]["id"]: 38 | trainer = SupervisedTrainer( 39 | tokenizer_config=config["tokenizer"], 40 | datapool_config=config["datapool"], 41 | alg_config=config["alg"], 42 | train_eval_config=config["train_evaluation"], 43 | tracker=tracker, 44 | ) 45 | else: 46 | trainer = OnPolicyTrainer( 47 | tokenizer_config=config["tokenizer"], 48 | datapool_config=config["datapool"], 49 | reward_config=config["reward_fn"], 50 | env_config=config["env"], 51 | on_policy_alg_config=config["alg"], 52 | train_eval_config=config["train_evaluation"], 53 | tracker=tracker, 54 | ) 55 | trainer.train_and_eval() 56 | 57 | 58 | if __name__ == "__main__": 59 | parser = ArgumentParser(description="Fine-tune LM to generate controlled text") 60 | parser.add_argument("--config_path", type=str, help="path to the config file") 61 | parser.add_argument( 62 | "--project_name", type=str, help="WANDB project name", default="rl4lm_exps" 63 | ) 64 | parser.add_argument( 65 | "--experiment_name", 66 | type=str, 67 | help="WANDB experiment name", 68 | default="rl4lm_experiment", 69 | ) 70 | parser.add_argument( 71 | "--entity_name", type=str, help="WANDB entity name", default=None 72 | ) 73 | parser.add_argument( 74 | "--base_path_to_store_results", 75 | type=str, 76 | help="Base path to store experiment results", 77 | default=os.getcwd(), 78 | ) 79 | parser.add_argument( 80 | "--log_to_wandb", action="store_true", help="Whether to use wandb logging" 81 | ) 82 | args = parser.parse_args() 83 | 84 | main( 85 | args.config_path, 86 | args.project_name, 87 | args.experiment_name, 88 | args.base_path_to_store_results, 89 | args.entity_name, 90 | args.log_to_wandb, 91 | ) 92 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | with open("requirements.txt") as fp: 4 | requirements = fp.read().splitlines() 5 | 6 | setup( 7 | name="rl4lms", 8 | version="0.2.2", 9 | description="A library for training language models (LM) using RL", 10 | author="Rajkumar Ramamurthy, Prithviraj Ammanabrolu", 11 | packages=find_packages(), 12 | python_requires=">=3.7", 13 | install_requires=requirements, 14 | url="https://github.com/allenai/RL4LMs", 
15 | ) 16 | --------------------------------------------------------------------------------
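To close the loop between the task configs and the entry point above: training is normally launched from the command line, e.g. python scripts/training/train_text_generation.py --config_path scripts/training/task_configs/summarization/t5_ppo.yml --log_to_wandb. The snippet below is a hedged sketch of doing the same thing programmatically, using the main() signature shown in train_text_generation.py; the import assumes the script's directory is on sys.path (it is a script, not an installed module), and the config path and experiment name are examples only.

import os

# assumption: scripts/training is on sys.path so the script can be imported directly
from train_text_generation import main

main(
    config_path="scripts/training/task_configs/summarization/t5_ppo.yml",  # any config from this tree
    project_name="rl4lm_exps",
    experiment_name="t5_ppo_summarization",  # example name, not fixed
    base_path_to_store_results=os.getcwd(),
    entity_name=None,
    log_to_wandb=False,
)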