├── requirements.txt ├── tests ├── __init__.py ├── accumulation_sampler_test.py ├── pipeline_scorer_test.py ├── exponential_scorer_test.py ├── quasi_rejection_sampler_test.py ├── boolean_scorer_test.py ├── product_test.py ├── tuner_test.py ├── base_distribution_test.py └── lm_distribution_test.py ├── disco ├── tuners │ ├── loggers │ │ ├── __init__.py │ │ ├── base.py │ │ ├── wandb.py │ │ ├── console.py │ │ ├── tensorboard.py │ │ ├── neptune.py │ │ ├── json.py │ │ └── mlflow.py │ ├── losses │ │ ├── base.py │ │ ├── __init__.py │ │ ├── kl.py │ │ ├── tv.py │ │ ├── reverse_kl.py │ │ ├── chi_square.py │ │ ├── js.py │ │ ├── reverse_chi_square.py │ │ └── f_divergence.py │ ├── __init__.py │ ├── dpg_tuner.py │ ├── cdpg_tuner.py │ ├── fcdpg_tuner.py │ └── fdpg_tuner.py ├── __init__.py ├── metrics │ ├── __init__.py │ ├── base.py │ ├── tv.py │ ├── kl.py │ └── js.py ├── utils │ ├── __init__.py │ ├── observable.py │ ├── timer.py │ ├── device.py │ ├── moving_average.py │ └── helpers.py ├── samplers │ ├── __init__.py │ ├── sampler.py │ ├── accumulation_sampler.py │ └── quasi_rejection_sampler.py ├── scorers │ ├── __init__.py │ ├── boolean_scorer.py │ ├── scorer.py │ ├── pipeline_scorer.py │ ├── exponential_scorer.py │ └── positive_scorer.py └── distributions │ ├── __init__.py │ ├── distribution.py │ ├── single_context_distribution.py │ ├── context_distribution.py │ ├── dataset_context_distribution.py │ └── base_distribution.py ├── .gitignore ├── Makefile ├── make.bat ├── scripts ├── online_tuning_script.py └── offline_tuning_script.py ├── tutorials ├── data │ └── incipits.txt ├── 5.QRS_sampling.ipynb ├── 4.conditional_tuning.ipynb ├── 3.tuning_DPG.ipynb └── 1.quick_introduction.ipynb ├── setup.py ├── LICENSE └── README.MD /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | transformers 3 | datasets 4 | notebook 5 | numpy 6 | scipy 7 | spacy 8 | neptune-client 9 | wandb 
-------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | # disco 2 | # Copyright (C) 2022-present NAVER Corp. 3 | # Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license 4 | 5 | -------------------------------------------------------------------------------- /disco/tuners/loggers/__init__.py: -------------------------------------------------------------------------------- 1 | # disco 2 | # Copyright (C) 2022-present NAVER Corp. 3 | # Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license 4 | 5 | -------------------------------------------------------------------------------- /disco/__init__.py: -------------------------------------------------------------------------------- 1 | # disco 2 | # Copyright (C) 2022-present NAVER Corp. 3 | # Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license 4 | 5 | __version__ = "1.1.1" 6 | __author__ = 'Naver Labs Europe' 7 | -------------------------------------------------------------------------------- /disco/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | # disco 2 | # Copyright (C) 2022-present NAVER Corp. 
# disco
# Copyright (C) 2022-present NAVER Corp.
# Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license

# Helpers re-exported at package level.
# NOTE: the former `import imp` was dropped on purpose: nothing in this module
# used it, and the `imp` module no longer exists on Python >= 3.12, where it
# made `import disco.utils` fail with ModuleNotFoundError.
from .device import get_device
from .helpers import get_token_first_indices
from .helpers import batchify
class BaseLoss:
    """Base class of the losses package; it only sets up a notification hook."""

    def __init__(self) -> None:
        # Observable to which callbacks can be enrolled; what is dispatched
        # through it is decided by code outside this class.
        self.metric_updated = Observable()
class Sampler(ABC):
    """
    Abstract root of the samplers: keeps a target and a proposal,
    and leaves the actual sampling to the subclasses
    """

    def __init__(self, target, proposal):
        """
        Parameters
        ----------
        target: object
            stored as-is in self.target
        proposal: object
            stored as-is in self.proposal
        """
        self.target, self.proposal = target, proposal

    @abstractmethod
    def sample(self):
        """Has to be overridden by every concrete sampler."""
        ...
class BooleanScorer(PositiveScorer):
    """
    Predicate-based scoring class
    """

    def __init__(self, predicate):
        """
        Parameters
        ----------
        predicate: scoring predicate
            predicate function to be used on each sample, called as
            predicate(sample, context); it may return a plain Python bool
            or a boolean tensor
        """
        super().__init__(predicate)

        def as_float(sample, context):
            # torch.as_tensor hands tensors back untouched and wraps plain
            # bools, which have no .float() method of their own.
            return torch.as_tensor(predicate(sample, context)).float()

        self.predicate = self._broadcast(as_float)
class Observable(object):
    """Textbook observer pattern"""

    def __init__(self):
        # a set: enrolling the same callable twice keeps a single entry,
        # and the order in which observers are called is unspecified
        self.observers = set()

    def enroll(self, function):
        """Registers a callable to be invoked on every dispatch."""
        self.observers.add(function)

    def dispatch(self, *args, **kwargs):
        """Calls every enrolled observer with the given arguments."""
        # Iterate over a snapshot: an observer may enroll another one while
        # being notified, which would otherwise raise
        # "RuntimeError: Set changed size during iteration".
        for f in tuple(self.observers):
            f(*args, **kwargs)


def forward(observable1, observable2):
    """Forwards messages from an observable to another
    with the same signature"""
    def forwarder(*args, **kwargs):
        observable2.dispatch(*args, **kwargs)
    observable1.enroll(forwarder)
class KLLoss(FDivergenceLoss):
    """
    DPG loss built on the Kullback-Leibler divergence
    """

    def __init__(self, use_baseline=True, baseline_window_size=1024):
        """
        Parameters
        ----------
        use_baseline: boolean
            use a baseline to reduce variance
        baseline_window_size: int
            handed over unchanged to FDivergenceLoss
            (presumably the averaging window of the baseline; confirm there)
        """
        super().__init__(use_baseline, baseline_window_size)

    def f_prime(self, log_t):
        """
        Evaluates -1/t from log(t)

        Parameters
        ----------
        log_t: 0-dim Tensor
            The log ratio of the policy and the normalized target distribution

        Returns
        -------
        tensor with the value -exp(-log_t)
        """
        inverse_t = torch.exp(-log_t)
        return -inverse_t
class TVLoss(FDivergenceLoss):
    """
    DPG loss built on the total variation distance
    """

    def __init__(self, use_baseline=True, baseline_window_size=1024):
        """
        Parameters
        ----------
        use_baseline: boolean
            use a baseline to reduce variance
        baseline_window_size: int
            handed over unchanged to FDivergenceLoss
            (presumably the averaging window of the baseline; confirm there)
        """
        super().__init__(use_baseline, baseline_window_size)

    def f_prime(self, log_t):
        """
        Evaluates half the sign of log(t)

        Parameters
        ----------
        log_t: 0-dim Tensor
            The log ratio of the policy and the normalized target distribution

        Returns
        -------
        tensor with the value sign(log_t) / 2
        """
        half_sign = torch.sign(log_t) / 2.
        return half_sign
class ReverseKLLoss(FDivergenceLoss):
    """
    DPG loss built on the reverse Kullback-Leibler divergence
    """

    def __init__(self, use_baseline=True, baseline_window_size=1024):
        """
        Parameters
        ----------
        use_baseline: boolean
            use a baseline to reduce variance
        baseline_window_size: int
            handed over unchanged to FDivergenceLoss
            (presumably the averaging window of the baseline; confirm there)
        """
        super().__init__(use_baseline, baseline_window_size)

    def f_prime(self, log_t):
        """
        Evaluates log(t) + 1

        Parameters
        ----------
        log_t: 0-dim Tensor
            The log ratio of the policy and the normalized target distribution

        Returns
        -------
        tensor with the value log_t + 1
        """
        shifted = log_t + 1
        return shifted
class Timer:
    """A context manager for timing code execution.

    After the `with` block, `elapsed` holds the wall-clock duration in
    seconds (measured with time.perf_counter) and str(timer) renders it.
    """

    def __init__(self, name="Timer"):
        """
        Parameters
        ----------
        name: string
            label used when the timer is rendered as text
        """
        self.name = name
        self.elapsed = 0

    def __enter__(self):
        self.start_time = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # returns None, so an exception raised in the timed block propagates
        self.elapsed = time.perf_counter() - self.start_time

    def __str__(self):
        return f"{self.name}: {self.elapsed:.4f} seconds"


# Alternative using @contextmanager decorator
@contextmanager
def timer(name="Timer"):
    """A simple timer context manager using contextmanager decorator.

    Yields the underlying Timer, so that the measurement, which used to be
    computed and then thrown away, can be read once the block is over:

        with timer("step") as t:
            ...
        print(t.elapsed)

    Using it without `as` keeps working as before.
    """
    with Timer(name) as measured:
        yield measured
class JSLoss(FDivergenceLoss):
    """
    DPG loss built on the Jensen-Shannon divergence
    """

    def __init__(self, use_baseline=True, baseline_window_size=1024):
        """
        Parameters
        ----------
        use_baseline: boolean
            use a baseline to reduce variance
        baseline_window_size: int
            handed over unchanged to FDivergenceLoss
            (presumably the averaging window of the baseline; confirm there)
        """
        super().__init__(use_baseline, baseline_window_size)

    def f_prime(self, log_t):
        """
        Evaluates log(2) - softplus(-log(t))

        Parameters
        ----------
        log_t: 0-dim Tensor
            The log ratio of the policy and the normalized target distribution

        Returns
        -------
        tensor with the same dtype and device as log_t
        """
        softplus = torch.nn.functional.softplus
        # the constant is built with log_t's dtype and device
        two = torch.tensor(2.0, dtype=log_t.dtype, device=log_t.device)
        return torch.log(two) - softplus(-log_t)
def get_device(tensor):
    """Get the device of a tensor

    Parameters
    ----------
    tensor: Tensor
        tensor from which to extract the device

    Returns
    -------
    computation device (0-n or cpu)
    """

    device = tensor.get_device()
    # a negative index is reported as "cpu"
    return "cpu" if 0 > device else device


def to_same_device(*tensors, device=None):
    """Move all tensors to the same device (by default the first tensor's one)

    Parameters
    ----------
    tensors: Tensor
        tensors which we want to move to a common device
    device: device
        target device; when None, the device of the first tensor is used

    Returns
    -------
    a tuple of tensors in the same order as the arguments moved to a common device
    (an empty tuple when no tensor is given)
    """
    if not tensors:
        # nothing to move; also avoids an IndexError on tensors[0] below
        return ()
    if device is None:
        device = get_device(tensors[0])
    return tuple(tensor.to(device) for tensor in tensors)
class BaseTunerObserver(object):
    """
    No-op observer of a tuner: enrolls one callback per tuner event,
    each callback doing nothing until a subclass overrides it.
    Usable as a context manager.
    """

    def __init__(self, tuner):
        """
        Parameters
        ----------
        tuner: object
            must expose the observables parameters_updated, metric_updated,
            step_idx_updated, ministep_idx_updated and eval_samples_updated
        """
        events = ("parameters", "metric", "step_idx", "ministep_idx", "eval_samples")
        for event in events:
            observable = getattr(tuner, f"{event}_updated")
            observable.enroll(getattr(self, f"on_{event}_updated"))

    def __enter__(self):
        return self

    def __exit__(self, type, value, traceback):
        return None

    def on_parameters_updated(self, params):
        return None

    def on_metric_updated(self, name, value):
        return None

    def on_step_idx_updated(self, s):
        return None

    def on_ministep_idx_updated(self, s):
        return None

    def on_eval_samples_updated(self, context, samples, proposal_log_scores, model_log_scores, target_log_scores):
        return None
class BaseDivergence:
    """
    Base class for divergence metrics.

    Only the averaging step lives here; the per-sample values come from
    `cls.pointwise_estimates`, which this class does not define
    (presumably supplied by the concrete metrics; confirm in the subclasses).
    """

    @classmethod
    def divergence(cls, m1_log_scores, m2_log_scores, z, proposal_log_scores=None):
        """
        Computes an estimate of the divergence between 2 distributions,
        as the mean of the pointwise estimates

        Parameters
        ----------
        m1_log_scores: floats
            log-scores for samples according to network 1
        m2_log_scores: floats
            log-scores for samples according to network 2
        z: float
            partition function of network 1
        proposal_log_scores: floats
            log-scores for samples according to proposal (by default m2_log_scores)

        Returns
        -------
        divergence between m1 and m2
        """

        # all arguments are passed through unchanged to pointwise_estimates
        return torch.mean(cls.pointwise_estimates(
            m1_log_scores, m2_log_scores, z, proposal_log_scores=proposal_log_scores))
class CDPGTuner(Tuner):
    """Contextual DPG tuning class,
    relying on a ContextDistribution and KLLoss().

    The algorithm has been introduced in
    "Controlling Conditional Language Models without Catastrophic Forgetting"
    Tomasz Korbak, Hady Elsahar, Germán Kruszewski and Marc Dymetman.
    https://proceedings.mlr.press/v162/korbak22a/korbak22a.pdf
    """

    def __init__(self, *args, context_distribution=None, loss=None, **kwargs):
        """
        Parameters
        ----------
        context_distribution: distribution
            a distribution to contextualize the sampling from the proposal;
            a fresh SingleContextDistribution() when left to None
        loss: functor object
            used to compute the loss at each step;
            a fresh KLLoss() when left to None
        """

        # Defaults are built per instance. As default argument values they
        # were created once, when the class was defined, and the very same
        # objects were then shared by every tuner relying on them.
        if context_distribution is None:
            context_distribution = SingleContextDistribution()
        if loss is None:
            loss = KLLoss()

        super(CDPGTuner, self).__init__(
                *args,
                context_distribution=context_distribution,
                loss=loss,
                **kwargs
            )
15 | https://arxiv.org/abs/2302.08215 16 | """ 17 | 18 | def __init__(self, *args, context_distribution=SingleContextDistribution(), 19 | loss=JSLoss(), **kwargs): 20 | """ 21 | Parameters 22 | ---------- 23 | context_distribution: distribution 24 | a distribution to contextualize the sampling from the proposal 25 | loss: functor object 26 | used to compute of the loss at each step 27 | """ 28 | 29 | super(FCDPGTuner, self).__init__( 30 | *args, 31 | context_distribution=context_distribution, 32 | loss=loss, 33 | **kwargs 34 | ) 35 | -------------------------------------------------------------------------------- /disco/tuners/fdpg_tuner.py: -------------------------------------------------------------------------------- 1 | # disco 2 | # Copyright (C) 2022-present NAVER Corp. 3 | # Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license 4 | 5 | from .tuner import Tuner 6 | from disco.distributions.single_context_distribution import SingleContextDistribution 7 | from disco.tuners.losses import * 8 | 9 | 10 | class FDPGTuner(Tuner): 11 | """Unconditional f-DPG tuning class. The algorithm was introduced in 12 | 13 | "Aligning Language Models with Preferences through f-divergence Minimization." 14 | Dongyoung Go, Tomasz Korbak, Germán Kruszewski, Jos Rozen, Nahyeon Ryu, Marc Dymetman. 
class Scorer():
    """
    Generic scorer
    """

    def __init__(self, scoring_function):
        """
        Parameters
        ----------
        scoring_function: scoring function
            called as scoring_function(sample, context) on each sample
        """
        self.scoring_function = self._broadcast(scoring_function)

    def _broadcast(self, function):
        """Lifts a per-sample function into one scoring a whole list,
        returning the results gathered in a tensor"""
        def score_all(samples, context):
            per_sample = [function(sample, context) for sample in samples]
            return torch.tensor(np.array(per_sample))
        return score_all

    def score(self, samples, context):
        """Applies the instance's scoring function to the samples
        given the context

        Parameters
        ----------
        samples : list()
            the samples to score, as a list
        context: text
            context that the samples relate to

        Returns
        -------
        tensor of scores for the samples"""
        scoring_function = self.scoring_function
        return scoring_function(samples, context)
class PipelineScorer(PositiveScorer):
    """
    Scorer relying on the pipelines from Huggingface's transformers
    """

    def __init__(self, label, params, temperature=1.0):
        """initializes a PipelineScorer's instance

        Parameters
        ----------
        label: string
            expected positive label from the pipeline
        params: dict
            keyword arguments forwarded to transformers' pipeline()
        temperature: float
            value the log-scores are divided by
        """

        self.temperature = temperature
        self.label = label
        self.pipeline = pipeline(**params)

    def log_score(self, samples, _):
        """computes the log-scores of the samples
        from the score the pipeline gives to the instance's label

        Parameters
        ----------
        samples : list(Sample)
            list of samples to log-score, their text attribute is used
        _ : context
            ignored

        Returns
        -------
        tensor of log-probabilities, divided by the temperature"""

        texts = [sample.text for sample in samples]
        # one list of {"label": ..., "score": ...} entries per text
        results = self.pipeline(texts, return_all_scores=True)

        label_scores = []
        for result in results:
            matching = [entry["score"] for entry in result if self.label == entry["label"]]
            # first (and expected only) score for the label
            label_scores.append(matching[0])

        probabilities = torch.tensor(label_scores).float()
        return torch.log(probabilities) / self.temperature
class WandBLogger(BaseTunerObserver):
    """
    Reports a tuner's statistics to Weights & Biases
    """

    def __init__(self, tuner, project, name=None):
        """Constructor of a WandBLogger object

        Parameters
        ----------
        tuner: DPGTuner
            The tuner object whose statistics we want to report
        project: string
            The W&B project to which we want to report the statistics
        name: string
            Name given to the W&B run (optional)
        """
        super().__init__(tuner)
        run = wandb.init(project=project, name=name)
        self.run = run

    def __setitem__(self, k, v):
        """
        Report arbitrary parameter/value combinations
        """
        entry = {k: v}
        wandb.log(entry)

    def __exit__(self, *exc):
        """Finishes the W&B run"""
        self.run.finish()

    def on_parameters_updated(self, params):
        """Stores the tuner's parameters in the W&B configuration"""
        wandb.config.update(params)

    def on_metric_updated(self, name, value):
        """Logs the new value of a metric under its name"""
        entry = {name: value}
        wandb.log(entry)

    def on_eval_samples_updated(self, context, samples, proposal_log_scores, model_log_scores, target_log_scores):
        """Logs the text of (at most) the first 10 evaluation samples"""
        texts = []
        for sample in samples[:10]:
            texts.append(sample.text)
        wandb.log({"samples": texts})

    def on_step_idx_updated(self, s):
        """Logs the index of the current step"""
        entry = {"steps": s}
        wandb.log(entry)
class TV(BaseDivergence):
    """
    Total variation distance class
    """

    @classmethod
    def pointwise_estimates(cls, m1_log_scores, m2_log_scores, z, proposal_log_scores=None):
        """
        computes pointwise estimates of the TVD between 2 distributions,
        that is for each sample x with proposal q:
        1/2 * |m2(x) / q(x) - m1(x) / (z * q(x))|

        Parameters
        ----------
        m1_log_scores: floats
            log-scores for samples according to network 1
        m2_log_scores: floats
            log-scores for samples according to network 2
        z: float
            partition function of network 1, as a plain number or a tensor
        proposal_log_scores: floats
            log-scores for samples according to proposal (by default m2_log_scores)

        Returns
        -------
        tensor of pointwise estimates of the divergence between m1 and m2
        """

        device = get_device(m1_log_scores)

        # torch.log() below needs a tensor: convert any plain scalar
        # (a float, but also an int such as 1), not only Python floats
        if not torch.is_tensor(z):
            z = torch.tensor(z, device=device, dtype=m1_log_scores.dtype)

        m2_log_scores = m2_log_scores.to(device)

        if proposal_log_scores is None:
            proposal_log_scores = m2_log_scores
        else:
            proposal_log_scores = proposal_log_scores.to(device)

        normalized_m1_log_scores = m1_log_scores - torch.log(z)

        # importance weights of both distributions w.r.t. the proposal
        m2_ratios = torch.exp(m2_log_scores - proposal_log_scores)
        m1_ratios = torch.exp(normalized_m1_log_scores - proposal_log_scores)

        return 1/2 * torch.abs(m2_ratios - m1_ratios)
class ConsoleLogger(BaseTunerObserver):
    """
    Reports a tuner's progress on the console
    """

    def __init__(self, tuner):
        """Constructor of a ConsoleLogger object

        Parameters
        ----------
        tuner: tuner
            The tuner object whose progress we want to report;
            its "n_gradient_steps" parameter is read here
        """
        super().__init__(tuner)
        # proposal updates are reported too, on top of the base subscriptions
        tuner.proposal_updated.enroll(self.on_proposal_updated)
        self.step = None
        self.n = tuner.params["n_gradient_steps"]

    def __exit__(self, *exc):
        """Prints the time at which the tuning finished"""
        now = datetime.now()
        stamp = now.strftime("%H:%M:%S (%Y/%m/%d)")
        print(f"finished at {stamp}")

    def on_parameters_updated(self, params):
        """Prints each parameter with its value, one per line"""
        for k in params:
            v = params[k]
            print(f"{k}: ", v)

    def on_eval_samples_updated(self, context, samples, proposal_log_scores, model_log_scores, target_log_scores):
        """Writes the context and (at most) the first 3 samples,
        each followed by its target, proposal and model log-scores"""
        n = 3
        tqdm.write(f"Context: {context}")
        tqdm.write("Samples:")
        shown = zip(
            samples[:n],
            proposal_log_scores[:n],
            model_log_scores[:n],
            target_log_scores[:n],
        )
        entries = []
        for s, p, m, t in shown:
            entries.append(
                s.text + f"\ntarget: {t.item()} - proposal: {p.item()} - model: {m.item()}"
            )
        tqdm.write("\n".join(entries))

    def on_step_idx_updated(self, s):
        """Remembers the current step and writes the progress"""
        self.step = s
        tqdm.write(f"Step {s}/{self.n}")

    def on_proposal_updated(self, proposal, divergence_metric, divergence_target_new, divergence_target_old):
        """Writes the divergence values which led to updating the proposal"""
        message = (
            f"updating proposal according to {divergence_metric} divergence at step {self.step}: "
            f"{divergence_target_new} < {divergence_target_old}"
        )
        tqdm.write(message)
class SingleContextDistribution(Distribution):
    """
    Single context distribution class, useful to sample the
    same context that is to fall back to a fixed-context case.
    """

    def __init__(self, context=''):
        """
        Parameters
        ----------
        context: string
            unique context to return when sampling
        """

        self.context = context

    def log_score(self, contexts):
        """Computes log-probabilities of the contexts
        to match the instance's context

        Parameters
        ----------
        contexts: list(str)
            list of contexts to (log-)score

        Returns
        -------
        tensor of log-probabilities: 0 for a context equal to the
        instance's one, -inf for any other
        """

        # 0.0 rather than 0: with an integer zero, a list made only of
        # matching contexts would produce an integer tensor, while any
        # mismatch (-inf) would produce a floating point one
        return torch.tensor(
            [0.0 if self.context == context else -float("inf") for context in contexts]
        )

    def sample(self, sampling_size=32):
        """Samples multiple copies of the instance's context

        Parameters
        ----------
        sampling_size: int
            number of contexts to sample

        Returns
        -------
        tuple of (list of texts, list of log-probabilities);
        the log-probabilities are a plain list of zeros, not a tensor
        """

        return (
            [self.context] * sampling_size,
            [0] * sampling_size
        )
import os
import re
from datetime import datetime

import torch

from disco.distributions import LMDistribution
from disco.scorers import BooleanScorer
from disco.tuners import DPGTuner
from disco.tuners.loggers.console import ConsoleLogger


word = "amazing"
incipit = "It was a cold and stormy night"
path = "models"
gpt = "gpt2" # or a larger gpt2-medium or another CLM from Transformers
dev0, dev1 = "cpu", "cpu" # "0", "1" to use GPUs
n_gradient_steps = 10 # 1000 or more for actual tuning
divergence_evaluation_interval = 2**2 # 2**4 for actual tuning?

# the tuned model is saved in this folder at the very end: make sure it
# exists right away, rather than failing once the whole tuning is done
os.makedirs(path, exist_ok=True)

# compiled once; the word is escaped so that it is always matched literally
word_pattern = re.compile(rf"\b{re.escape(word)}\b")


def has_word(s, c):
    """Tells if the sample's text contains the word; the context is ignored."""
    return bool(word_pattern.search(s.text))


# target: the base LM constrained to sequences containing the word
base = LMDistribution(model=gpt, device=dev0)
word_scorer = BooleanScorer(has_word)
target = base * word_scorer

# model to tune, hence not frozen
model = LMDistribution(model=gpt, freeze=False, device=dev1)

tuner = DPGTuner(model, target,
        context = incipit,
        features = [(word, word_scorer)],
        n_gradient_steps=n_gradient_steps,
        n_samples_per_context=2**8,
        sampling_size=2**5,
        scoring_size=2**5,
        divergence_evaluation_interval=divergence_evaluation_interval,
        n_kl_samples=2**10)
console_logger = ConsoleLogger(tuner)
tuner.tune()

# proportion of samples from the tuned model which contain the word
samples, _ = model.sample(context=incipit)
print("rate after tuning is:")
print(sum([has_word(s, incipit) for s in samples]) / len(samples))

stamp = datetime.now().strftime("%Y_%m_%d-%H:%M")
torch.save(model, os.path.join(path, f"{word}.{stamp}.pt"))
import os
import re
from datetime import datetime

import torch

from disco.distributions import LMDistribution
from disco.scorers import BooleanScorer
from disco.tuners import DPGTuner
from disco.tuners.loggers.console import ConsoleLogger


word = "amazing"
incipit = "It was a cold and stormy night"
path = "models"
gpt = "gpt2" # or a larger gpt2-medium or another CLM from Transformers
dev0, dev1 = "cpu", "cpu" # "0", "1" to use GPUs
n_gradient_steps = 10 # 1000 or more for actual tuning
divergence_evaluation_interval = 2**2 # 2**4 for actual tuning?

# the tuned model is saved in this folder at the very end: make sure it
# exists right away, rather than failing once the whole tuning is done
os.makedirs(path, exist_ok=True)

# compiled once; the word is escaped so that it is always matched literally
word_pattern = re.compile(rf"\b{re.escape(word)}\b")


def has_word(s, c):
    """Tells if the sample's text contains the word; the context is ignored."""
    return bool(word_pattern.search(s.text))


# target: the base LM constrained to sequences containing the word
base = LMDistribution(model=gpt, device=dev0)
word_scorer = BooleanScorer(has_word)
target = base * word_scorer

# separate proposal handed to the tuner, as its third argument
proposal = LMDistribution(model=gpt, device=dev0)

# model to tune, hence not frozen
model = LMDistribution(model=gpt, freeze=False, device=dev1)

tuner = DPGTuner(model, target, proposal,
        context = incipit,
        features = [(word, word_scorer)],
        n_gradient_steps=n_gradient_steps,
        n_samples_per_context=2**8,
        sampling_size=2**5,
        scoring_size=2**5,
        divergence_evaluation_interval=divergence_evaluation_interval,
        n_kl_samples=2**10)
console_logger = ConsoleLogger(tuner)
tuner.tune()

# proportion of samples from the tuned model which contain the word
samples, _ = model.sample(context=incipit)
print("rate after tuning is:")
print(sum([has_word(s, incipit) for s in samples]) / len(samples))

stamp = datetime.now().strftime("%Y_%m_%d-%H:%M")
torch.save(model, os.path.join(path, f"{word}.{stamp}.pt"))
class KL(BaseDivergence):
    """
    Kullback-Leibler divergence class
    """

    @classmethod
    def pointwise_estimates(cls, m1_log_scores, m2_log_scores, z, proposal_log_scores=None):
        """
        computes pointwise estimates of the KL divergence between 2 distributions,
        that is for each sample x with proposal q:
        -log(z) + (1 / z) * m1(x) / q(x) * (log m1(x) - log m2(x))

        Parameters
        ----------
        m1_log_scores: floats
            log-scores for samples according to network 1
        m2_log_scores: floats
            log-scores for samples according to network 2
        z: float
            partition function of network 1, as a plain number or a tensor
        proposal_log_scores: floats
            log-scores for samples according to proposal (by default m2_log_scores)

        Returns
        -------
        tensor of pointwise estimates of the divergence between m1 and m2
        """

        device = get_device(m1_log_scores)

        # torch.log() below needs a tensor: convert any plain scalar
        # (a float, but also an int such as 1), not only Python floats
        if not torch.is_tensor(z):
            z = torch.tensor(z, device=device, dtype=m1_log_scores.dtype)

        m2_log_scores = m2_log_scores.to(device)

        if proposal_log_scores is None:
            proposal_log_scores = m2_log_scores
        else:
            proposal_log_scores = proposal_log_scores.to(device)

        importance_ratio = torch.exp(m1_log_scores - proposal_log_scores)

        unnormalized_pointwise_estimates = importance_ratio * (m1_log_scores - m2_log_scores)
        # undefined products (e.g. a null ratio times an infinite
        # log-difference, for a sample with a -inf m1 log-score)
        # are counted as contributing nothing
        unnormalized_pointwise_estimates[
            torch.isnan(unnormalized_pointwise_estimates)] = 0

        return -1 * torch.log(z) + (1 / z) * unnormalized_pointwise_estimates
class TensorBoardLogger(BaseTunerObserver):
    """
    Reports a tuner's statistics to TensorBoard
    """

    def __init__(self, tuner, **kwargs):
        """Constructor of a TensorBoardLogger object

        Parameters
        ----------
        tuner: DPGTuner
            The tuner object whose statistics we want to report
        kwargs: dict
            keyword arguments forwarded to the SummaryWriter
        """
        super().__init__(tuner)
        self.writer = SummaryWriter(**kwargs)
        # next index to use on the x-axis, per reported name
        self.x_counter = defaultdict(int)

    def _next_x(self, key):
        """Returns the current x index for the key, and moves on to the next"""
        x = self.x_counter[key]
        self.x_counter[key] = x + 1
        return x

    def __setitem__(self, k, v):
        """
        Report arbitrary parameter/value combinations:
        strings as text, anything else as a scalar
        """
        if not isinstance(v, str):
            self.writer.add_scalar(k, v)
            return
        self.writer.add_text(k, v)

    def __exit__(self, *exc):
        """Closes the writer"""
        self.writer.close()

    def on_parameters_updated(self, params):
        """Reports the parameters as hyperparameters, without any metric"""
        self.writer.add_hparams(params, {})

    def on_metric_updated(self, name, value):
        """Reports the metric's value at the next x index for its name"""
        x = self._next_x(name)
        self.writer.add_scalar(name, value, x)

    def on_eval_samples_updated(self, context, samples, proposal_log_scores, model_log_scores, target_log_scores):
        """Reports (at most) the first 10 samples as text, each prefixed
        by the context, all at the same x index"""
        x = self._next_x('samples')
        for sample in samples[:10]:
            text = context + sample.text
            self.writer.add_text('samples', text, x)

    def on_step_idx_updated(self, s):
        """Steps are not reported"""
        pass

    def on_ministep_idx_updated(self, s):
        """Ministeps are not reported"""
        pass
class JS(BaseDivergence):
    """
    Jensen-Shannon divergence class.
    """

    @classmethod
    def pointwise_estimates(cls, m1_log_scores, m2_log_scores, z, proposal_log_scores=None):
        """
        Computes pointwise estimates of the JS divergence between 2 distributions,
        as the average of the KL divergences from each of them
        to their mixture m = (m1 / z + m2) / 2

        Parameters
        ----------
        m1_log_scores: floats
            log-scores for samples according to network 1
        m2_log_scores: floats
            log-scores for samples according to network 2
        z: float
            partition function of network 1, as a plain number or a tensor
        proposal_log_scores: floats
            log-scores for samples according to proposal (by default m2_log_scores)

        Returns
        -------
        tensor of pointwise estimates of the divergence between m1 and m2
        """

        device = get_device(m1_log_scores)

        # torch.log() below needs a tensor: convert any plain scalar
        # (a float, but also an int such as 1), not only Python floats
        if not torch.is_tensor(z):
            z = torch.tensor(z, device=device, dtype=m1_log_scores.dtype)

        m2_log_scores = m2_log_scores.to(device)

        # the default has to be resolved here: left to None, each KL call
        # below would fall back to *its* second argument, the mixture,
        # instead of the documented m2_log_scores
        if proposal_log_scores is None:
            proposal_log_scores = m2_log_scores

        normalized_m1_log_scores = m1_log_scores - torch.log(z)

        # log of the mixture, factoring the largest term out
        # of the exponentials for numerical stability
        max_log_scores = torch.max(normalized_m1_log_scores, m2_log_scores)

        m_log_scores = max_log_scores + torch.log(
            ((normalized_m1_log_scores - max_log_scores).double().exp()
             + (m2_log_scores - max_log_scores).double().exp()) / 2
        ).float()

        # both terms are already normalized, hence a partition function of 1
        divergence = KL.pointwise_estimates(normalized_m1_log_scores, m_log_scores, torch.as_tensor(1), proposal_log_scores) / 2 + \
                KL.pointwise_estimates(m2_log_scores, m_log_scores, torch.as_tensor(1), proposal_log_scores) / 2
        return divergence
def get_proxies():
    """Builds a dictionary of proxies from the environment

    For each of http and https, the lowercase variable is looked up
    first, then the uppercase one; empty or missing values are skipped.

    Returns
    -------
    dictionary with a "http" and/or a "https" entry, possibly empty
    """
    candidates = (
        ("http", "http_proxy", "HTTP_PROXY"),
        ("https", "https_proxy", "HTTPS_PROXY"),
    )
    proxies = {}
    for scheme, lower_name, upper_name in candidates:
        value = os.getenv(lower_name) or os.getenv(upper_name)
        if value:
            proxies[scheme] = value
    return proxies

class NeptuneLogger(BaseTunerObserver):
    """
    Reports a tuner's statistics to Neptune
    """

    def __init__(self, tuner, project, name=None, api_token=None, **kwargs):
        """Constructor of a NeptuneLogger object

        Parameters
        ----------
        tuner: DPGTuner
            The tuner object whose statistics we want to report
        project: string
            The Neptune project to which we want to report the statistics
        name: string
            Name given to the Neptune run (optional)
        api_token: string
            The Neptune API token
        kwargs: dict
            further keyword arguments for the Neptune run; without
            a 'proxies' entry, proxies are read from the environment
        """
        super().__init__(tuner)
        if 'proxies' not in kwargs:
            kwargs['proxies'] = get_proxies()
        run = neptune.init_run(project=project,
                name=name,
                api_token=api_token,
                **kwargs)
        self.run = run

    def __setitem__(self, k, v):
        """
        Report arbitrary parameter/value combinations
        """
        self.run[k] = v

    def __exit__(self, *exc):
        """Stops the Neptune run"""
        self.run.stop()

    def on_parameters_updated(self, params):
        """Stores the tuner's parameters in the run"""
        self.run["parameters"] = params

    def on_metric_updated(self, name, value):
        """Logs the new value of a metric under its name"""
        series = self.run[name]
        series.log(value)

    def on_eval_samples_updated(self, context, samples, proposal_log_scores, model_log_scores, target_log_scores):
        """Logs the text of (at most) the first 10 evaluation samples"""
        texts = []
        for sample in samples[:10]:
            texts.append(sample.text)
        self.run["samples"].log(texts)

    def on_step_idx_updated(self, s):
        """Logs the index of the current step"""
        series = self.run["steps"]
        series.log(s)
-------------------------------------------------------------------------------- 1 | # disco 2 | # Copyright (C) 2022-present NAVER Corp. 3 | # Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license 4 | 5 | import torch 6 | import numpy as np 7 | 8 | from .positive_scorer import PositiveScorer 9 | from disco.utils.device import get_device 10 | 11 | 12 | class ExponentialScorer(PositiveScorer): 13 | """Exponential scorer to add distributional constraints 14 | when building an EBM. 15 | """ 16 | 17 | def __init__(self, features, coefficients): 18 | """ 19 | Parameters 20 | ---------- 21 | features: list(Scorer) 22 | scoring features 23 | coefficients: list(float) 24 | features' coefficients 25 | """ 26 | 27 | if not len(features) == len(coefficients): 28 | raise ValueError("there should be as many as many coefficients as there are features.") 29 | 30 | self.features = features 31 | if type(coefficients) in [list, np.ndarray]: 32 | self.coefficients = torch.tensor(coefficients) 33 | else: 34 | self.coefficients = coefficients 35 | 36 | if torch.Tensor != type(self.coefficients): 37 | raise TypeError("coefficients should come in a tensor, or a tensorable structure.") 38 | 39 | def log_score(self, samples, context): 40 | """Log-scores the samples given the context 41 | using the instance's features and their coefficients 42 | 43 | Parameters 44 | ---------- 45 | samples : list(str) 46 | list of samples to log-score 47 | context: text 48 | context used for the samples 49 | 50 | Returns 51 | ------- 52 | tensor of log-scores""" 53 | 54 | device = get_device(self.coefficients) 55 | 56 | feature_log_scores = torch.stack( 57 | ([feature.score(samples, context).to(device) for feature in self.features]) 58 | ) # [n_features, n_samples] 59 | weighted_log_scores = self.coefficients.repeat(len(samples), 1) * feature_log_scores.t() 60 | 61 | return weighted_log_scores.sum(dim=1) 62 | 63 | def __str__(self): 64 | return f"ExponentialScorer({self.features}, 
class ContextDistribution(Distribution):
    """
    Context distribution class, fetching the contexts from a text file.
    It can be used as a template for other context distributions.
    """

    def __init__(self, path="contexts.txt"):
        """
        Parameters
        ----------
        path: string
            path to context file, with one context per line

        Raises
        ------
        AssertionError
            if the file cannot be read or is empty
        """

        try:
            with open(path) as f:
                # NOTE(review): readlines() keeps the end of line character
                # of each line, so that it is part of the contexts - confirm
                # this is what's expected downstream
                self.contexts = f.readlines()
        except IOError:
            self.contexts = list()

        # raised explicitly rather than through an assert statement,
        # so that the check still takes place when running python -O
        if not self.contexts:
            raise AssertionError("there's an issue with the context file provided.")

    def log_score(self, contexts):
        """Computes log-probabilities of the contexts, the probability
        of a context being its relative frequency among the instance's contexts

        Parameters
        ----------
        contexts: list(str)
            list of contexts to (log-)score

        Returns
        -------
        tensor of logprobabilities, -inf for unknown contexts

        Raises
        ------
        AssertionError
            if there are no contexts to score
        """
        from collections import Counter

        if not contexts:
            raise AssertionError("there needs to be contexts to (log-)score.")

        # frequencies are relative to the contexts known to the instance,
        # not to the number of contexts which happen to be scored
        n_known_contexts = len(self.contexts)
        # counted in a single pass, instead of going through
        # the whole list again for each context to score
        counts = Counter(self.contexts)

        return torch.tensor(
            [np.log(counts[context] / n_known_contexts) if context in counts
                else -float("inf")
                for context in contexts
            ]
        )

    def sample(self, sampling_size=32):
        """Samples random elements from the list of contexts,
        without replacement

        Parameters
        ----------
        sampling_size: int
            number of contexts to sample

        Returns
        -------
        tuple of (list of texts, tensor of logprobs)

        Raises
        ------
        AssertionError
            if there are less contexts than the sampling size
        """

        if len(self.contexts) < sampling_size:
            raise AssertionError("the contexts does not have enough elements to sample.")

        contexts = sample(self.contexts, sampling_size)
        return (contexts, self.log_score(contexts))
def rain(s, c):
    """Tells if the sample's text mentions rain; the context is ignored."""
    return "rain" in s.text


def city(s, c):
    """Tells if the sample's text mentions a city; the context is ignored."""
    return "city" in s.text


class ExponentialScorerTest(unittest.TestCase):
    """Tests for the ExponentialScorer, built from two boolean features."""

    def test_features_and_coefficients_match(self):
        """A scorer keeps as many coefficients as it has features."""
        scorer = ExponentialScorer([BooleanScorer(rain), BooleanScorer(city)], [0.5, 0.25])
        self.assertTrue(hasattr(scorer, "features"),
            "the exponential scorer should have a features attribute.")
        self.assertTrue(hasattr(scorer, "coefficients"),
            "the exponential scorer should have a coefficients attribute.")
        self.assertEqual(len(scorer.features), len(scorer.coefficients),
            "the length of both features and coefficients list should be equal.")

    def test_features_and_coefficients_mismatch(self):
        """Differing numbers of features and coefficients are rejected."""
        with self.assertRaises(ValueError):
            ExponentialScorer(
                [BooleanScorer(rain)],
                [0.5, 0.25]
            )

    def test_coefficients_as_tensor_like(self):
        """Coefficients which can't make a tensor are rejected."""
        with self.assertRaises(TypeError):
            ExponentialScorer(
                [BooleanScorer(rain)],
                0.5
            )
        with self.assertRaises(TypeError):
            ExponentialScorer(
                [BooleanScorer(rain), BooleanScorer(city)],
                {"rain": 0.5, "city": 0.25}
            )

    def test_score(self):
        """Log-scores are the sums of the coefficients of the matching features."""
        scorer = ExponentialScorer([BooleanScorer(rain), BooleanScorer(city)], [0.5, 0.25])
        # one text per combination of features: rain, city, both, none;
        # mind the comma after each item, two adjacent string literals
        # are silently concatenated into a single text
        texts = [
                "I'm singing in the rain.",
                "What is the city but the people?",
                "The rain that fell on the city runs down the dark gutters and empties into the sea without even soaking the ground",
                "Every drop in the ocean counts."
            ]
        expected_log_scores = [0.5, 0.25, 0.75, 0]

        from disco.distributions.lm_distribution import TextSample
        samples = [TextSample(list(), t) for t in texts] # fake samples without the tokenizations

        scores = scorer.score(samples, None)
        self.assertEqual(len(samples), len(scores),
            "there should be a score for each sample.")
        log_scores = scorer.log_score(samples, None)
        self.assertEqual(len(samples), len(log_scores),
            "there should be a (log-)score for each sample.")
        # zip() stops at the shortest sequence: make sure nothing is left unchecked
        self.assertEqual(len(expected_log_scores), len(log_scores),
            "there should be an expected (log-)score for each sample.")
        for e, s in zip(expected_log_scores, log_scores):
            self.assertEqual(e, s,
                "the exponential scorer should (log-)score correctly."
            )
class QuasiRejectionSamplerTest(unittest.TestCase):
    """Tests for the quasi-rejection sampler and its estimator."""

    def test_sample(self):
        """Samples from a constrained target should all respect the constraint."""
        word = "night"
        prefix = "It was a cold and stormy"

        def has_word(s, c):
            return word in s.text

        target = LMDistribution() * BooleanScorer(has_word)
        proposal = LMDistribution()

        sampler = QuasiRejectionSampler(target, proposal)
        samples, _ = sampler.sample(sampling_size=32, context=prefix)

        respected = [word in s.text for s in samples]
        self.assertTrue(all(respected),
            "sampled sequences should respect the constraints.")

        self.assertGreater(sampler.get_acceptance_rate(), 0.,
            "that sampling should probably return at least a sample.")

    def test_estimator(self):
        """Estimates on two close Poisson distributions should match known values."""
        proposal = PoissonDistribution(10)
        target = PoissonDistribution(11)
        estimator = QuasiRejectionSamplerEstimator(target, proposal, n_estimation_samples=int(10e4))

        # (beta, expected value) pairs, checked in this very order
        expected_acceptance_rates = [(0.5, 0.999), (1, 0.877), (2, 0.5)]
        for beta, expected in expected_acceptance_rates:
            ar = estimator.acceptance_rate_at_beta(beta)
            self.assertAlmostEqual(ar, expected, places=2)

        expected_tvds = [(0.5, 0.12), (1, 0.075), (2, 0.0041)]
        for beta, expected in expected_tvds:
            tvd = estimator.divergence_at_beta(beta, divergence=TV)
            self.assertAlmostEqual(tvd, expected, places=2)

        expected_kls = [(1, 0.021), (0.5, 0.046), (2, 0.0)]
        for beta, expected in expected_kls:
            kl = estimator.divergence_at_beta(beta, divergence=KL)
            self.assertAlmostEqual(kl, expected, places=2)

class PoissonDistribution:
    """Minimal Poisson distribution, with the sample and log_score
    methods used by the estimator under test."""

    def __init__(self, lam):
        """
        Parameters
        ----------
        lam: float
            rate of the Poisson distribution
        """
        self.lam = lam

    def sample(self, sampling_size=1, context=None):
        """Draws samples, the context being ignored

        Returns
        -------
        tuple of (list of samples, tensor of their log-probabilities)
        """
        draws = stats.poisson.rvs(self.lam, size=sampling_size)
        samples = list(draws)
        log_scores = self.log_score(samples)
        return samples, log_scores

    def log_score(self, x, context=None):
        """Returns the tensor of log-probabilities of x, the context being ignored"""
        log_pmf = stats.poisson.logpmf(x, self.lam)
        return torch.tensor(log_pmf)

if __name__ == '__main__':
    unittest.main()
== s for s in bf.score(items, None)), 19 | "a truism should always score 1.") 20 | self.assertEqual(len(items), len(bf.log_score(items, None)), 21 | "a boolean feature should (log-)score all items.") 22 | self.assertTrue(all(0 == s for s in bf.log_score(items, None)), 23 | "a truism should always (log-)score 0.") 24 | 25 | def test_true_and_false(self): 26 | bf = BooleanScorer(lambda x, _: 2 < len(x)) 27 | scores = bf.score(["a", "abc"], None) 28 | self.assertEqual(0, scores[0], 29 | "False should result in 0.") 30 | self.assertEqual(1, scores[1], 31 | "True should result in 1.") 32 | log_scores = bf.log_score(["a", "abc"], None) 33 | self.assertEqual(-np.Inf, log_scores[0], 34 | "False should result in an infinite negative in log space.") 35 | self.assertEqual(0, log_scores[1], 36 | "True should result in a zero in log space.") 37 | 38 | def test_lambda_predicate(self): 39 | bf = BooleanScorer(lambda x, _: 2 < len(x)) 40 | samples = ["", "a", "ab", "abc", "abcd"] 41 | scores = bf.score(samples, None) 42 | expected_scores = [0, 0, 0, 1, 1] 43 | self.assertTrue(all(e == s for e, s in zip(expected_scores, scores)), 44 | "a predicate expressed via a lambda should score correctly the items.") 45 | log_scores = bf.log_score(samples, None) 46 | expected_log_scores = [-np.Inf, -np.Inf, -np.Inf, 0, 0] 47 | self.assertTrue(all(e == s for e, s in zip(expected_log_scores, log_scores)), 48 | "a predicate expressed via a lambda should (log-)score correctly the items.") 49 | 50 | def test_fn_predicate(self): 51 | def longer_than_2(x, _): 52 | return 2 < len(x) 53 | bf = BooleanScorer(longer_than_2) 54 | samples = ["", "a", "ab", "abc", "abcd"] 55 | scores = bf.score(samples, None) 56 | expected_scores = [0, 0, 0, 1, 1] 57 | self.assertTrue(all(e == s for e, s in zip(expected_scores, scores)), 58 | "a predicate expressed via a regular function should score correctly the items.") 59 | log_scores = bf.log_score(samples, None) 60 | expected_log_scores = [-np.Inf, -np.Inf, -np.Inf, 
0, 0] 61 | self.assertTrue(all(e == s for e, s in zip(expected_log_scores, log_scores)), 62 | "a predicate expressed via a regular function should (log-)score correctly the items.") 63 | 64 | if __name__ == '__main__': 65 | unittest.main() -------------------------------------------------------------------------------- /disco/tuners/losses/f_divergence.py: -------------------------------------------------------------------------------- 1 | # disco 2 | # Copyright (C) 2022-present NAVER Corp. 3 | # Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license 4 | 5 | import torch 6 | from .base import BaseLoss 7 | from disco.utils.moving_average import WindowedMovingAverage 8 | 9 | class FDivergenceLoss(BaseLoss): 10 | """ 11 | Kullback-Leibler divergence loss for DPG 12 | """ 13 | def __init__(self, use_baseline=True, baseline_window_size=1024): 14 | """ 15 | Parameters 16 | ---------- 17 | use_baseline: boolean 18 | use a baseline to reduce variance 19 | """ 20 | super(FDivergenceLoss, self).__init__() 21 | if use_baseline: 22 | self.baseline = WindowedMovingAverage(baseline_window_size) 23 | else: 24 | self.baseline = None 25 | 26 | def __call__(self, proposal_log_probs, target_log_scores, policy_log_probs, z): 27 | """ 28 | Computes the KL loss on a given minibatch of samples 29 | ∇ loss = π(x) / q(x) * f'(π(x) / p(x))) * ∇ log π(x) 30 | 31 | Parameters 32 | ---------- 33 | samples: list of items 34 | samples from the proposal network 35 | context: text 36 | context for the samples 37 | proposal_log_scores: 1-dim Tensor 38 | log-probabilities for the samples according to the proposal 39 | target_log_scores: 1-dim Tensor 40 | log-probabilities for the samples according to the target unnormalized distribution (EBM) 41 | policy_log_scores: 1-dim Tensor 42 | log-probabilities for the samples according to the policy network 43 | z: 0-dim Tensor 44 | estimation of the partition function of the target distribution 45 | 46 | Returns 47 | ------- 48 | mean loss 
across the minibatch 49 | """ 50 | target_log_probs = target_log_scores - torch.log(z) 51 | 52 | log_t = policy_log_probs.detach() - target_log_probs 53 | 54 | # implemented in derived class depending on the desired f 55 | f_prime = self.f_prime(log_t) 56 | 57 | pseudo_reward = -f_prime 58 | 59 | for r in pseudo_reward: 60 | self.metric_updated.dispatch('pseudo_reward', r.item()) 61 | 62 | if self.baseline is not None: 63 | if self.baseline.value is not None: 64 | advantage = pseudo_reward - self.baseline.value 65 | else: 66 | advantage = pseudo_reward 67 | 68 | self.baseline.update(pseudo_reward) 69 | 70 | self.metric_updated.dispatch("baseline", self.baseline.value) 71 | else: 72 | advantage = pseudo_reward 73 | 74 | 75 | for a in advantage: 76 | self.metric_updated.dispatch('advantage', a.item()) 77 | 78 | importance_ratios = (policy_log_probs.detach() - proposal_log_probs).exp() 79 | 80 | for ir in importance_ratios: 81 | self.metric_updated.dispatch('importance_ratios', ir.item()) 82 | 83 | loss = (importance_ratios * (-advantage) * policy_log_probs).mean() 84 | 85 | return loss -------------------------------------------------------------------------------- /tests/product_test.py: -------------------------------------------------------------------------------- 1 | # disco 2 | # Copyright (C) 2022-present NAVER Corp. 
class ProductFeatureTest(unittest.TestCase):
    """
    Checks that the log-score of a Product of scorers
    is the sum of the log-scores of its factors.
    """

    @staticmethod
    def _random_integers():
        """Draws 100 integers uniformly from the inclusive range [0, 5]."""
        # np.random.random_integers is deprecated; randint excludes its
        # upper bound, hence 6 to keep the same inclusive range.
        return np.random.randint(0, 6, 100)

    def test_single_feature(self):
        """A product of one scorer scores like that scorer."""
        integers = self._random_integers()
        bf = BooleanScorer(lambda x, _: 5 < x)
        bf_log_scores = bf.log_score(integers, None)
        product = Product(bf)
        pr_log_scores = product.log_score(integers, None)
        self.assertTrue(all(b == p for b, p in zip(bf_log_scores, pr_log_scores)),
            "a product of a single feature should behave like that feature.")

    def test_two_features(self):
        """A product of two scorers sums their log-scores."""
        integers = self._random_integers()
        bf1 = BooleanScorer(lambda x, _: 5 < x)
        bf1_log_scores = bf1.log_score(integers, None)
        bf2 = BooleanScorer(lambda x, _: 8 > x)
        bf2_log_scores = bf2.log_score(integers, None)
        product = Product(bf1, bf2)
        pr_log_scores = product.log_score(integers, None)
        self.assertTrue(all((b1 + b2) == p for b1, b2, p in zip(bf1_log_scores, bf2_log_scores, pr_log_scores)),
            "a product of two features should sum their respective log-scores.")

    def test_three_features(self):
        """A product of three scorers sums their log-scores."""
        integers = self._random_integers()
        bf1 = BooleanScorer(lambda x, _: 2 < x)
        bf1_log_scores = bf1.log_score(integers, None)
        bf2 = BooleanScorer(lambda x, _: 8 > x)
        bf2_log_scores = bf2.log_score(integers, None)
        bf3 = BooleanScorer(lambda x, _: 0 == x % 2)
        bf3_log_scores = bf3.log_score(integers, None)
        product = Product(bf1, bf2, bf3)
        pr_log_scores = product.log_score(integers, None)
        self.assertTrue(all((b1 + b2 + b3) == p for b1, b2, b3, p in zip(bf1_log_scores, bf2_log_scores, bf3_log_scores, pr_log_scores)),
            "a product of three features should sum their respective log-scores.")

    def test_product_is_commutative(self):
        """Swapping the factors of a product does not change its log-scores."""
        integers = self._random_integers()
        bf1 = BooleanScorer(lambda x, _: 5 < x)
        bf2 = BooleanScorer(lambda x, _: 8 > x)
        pr12 = Product(bf1, bf2)
        # factors deliberately swapped: building Product(bf1, bf2) twice
        # would compare a product with itself and prove nothing
        pr21 = Product(bf2, bf1)
        self.assertTrue(all(p12 == p21 for p12, p21 in zip(pr12.log_score(integers, None), pr21.log_score(integers, None))),
            "a product of two features should be commutative.")

    def test_product_has_a_sugar(self):
        """The * operator between two scorers builds the same product."""
        integers = self._random_integers()
        bf1 = BooleanScorer(lambda x, _: 5 < x)
        bf1_log_scores = bf1.log_score(integers, None)
        bf2 = BooleanScorer(lambda x, _: 8 > x)
        bf2_log_scores = bf2.log_score(integers, None)
        product = bf1 * bf2
        pr_log_scores = product.log_score(integers, None)
        self.assertTrue(all((b1 + b2) == p for b1, b2, p in zip(bf1_log_scores, bf2_log_scores, pr_log_scores)),
            "a product also works when defined through its syntactic sugar.")
class JSONLogger(BaseTunerObserver):
    """
    Reports DPGTuner statistics to a JSON file
    """

    def __init__(self, tuner, project, name, path=None,
            save_steps=1, store_eval_samples=False, **kwargs):
        """Constructor of a JSONLogger object

        Parameters
        ----------
        tuner: DPGTuner
            The tuner object whose statistics we want to report
        project: string
            The subfolder where to store the logs
        name: string
            The filename to which we want to report the statistics
        path: string/Path
            The path where we want to store the logs; when omitted,
            the DISCO_SAVE_PATH environment variable is used
        save_steps: integer
            Number of gradient steps every which to write the json data to disk
        store_eval_samples: boolean
            Whether or not to store the samples in the json file
        kwargs: dict
            Accepted and ignored

        Raises
        ------
        KeyError
            if path is omitted and DISCO_SAVE_PATH is not set
        """
        super(JSONLogger, self).__init__(tuner)
        if path is None:
            # resolved here rather than in the signature, so that merely
            # importing this module does not require the variable to be set
            try:
                path = os.environ['DISCO_SAVE_PATH']
            except KeyError:
                raise KeyError("JSONLogger needs either a `path` argument "
                        "or the DISCO_SAVE_PATH environment variable") from None
        self.filename = Path(path) / project / f"{name}.json"
        self.filename.parent.mkdir(parents=True, exist_ok=True)
        self.data = defaultdict(list)
        self.save_steps = save_steps
        self.store_eval_samples = store_eval_samples

    @staticmethod
    def _to_serializable(v):
        """Converts the values json cannot dump (paths, tensors) to plain Python objects"""
        if isinstance(v, Path):
            return str(v)
        if isinstance(v, torch.Tensor):
            # item() only accepts single-element tensors
            return v.item() if 1 == v.numel() else v.tolist()
        return v

    def __exit__(self, *exc):
        self.save()

    def save(self):
        """Writes all the data collected so far to the json file, overwriting it"""
        with open(self.filename, 'w') as fout:
            json.dump(self.data, fout)

    def __setitem__(self, k, v):
        """
        Report arbitrary parameter/value combinations
        """
        self.data[k] = self._to_serializable(v)

    def on_parameters_updated(self, params):
        self.data["parameters"] = dict(params)

    def on_metric_updated(self, name, value):
        # self.data is a defaultdict(list): a first report creates the list
        self.data[name].append(self._to_serializable(value))

    def on_eval_samples_updated(self, context, samples, proposal_log_scores, model_log_scores, target_log_scores):
        if not self.store_eval_samples:
            return
        self.data["samples"].append([s.text for s in samples])
        self.data["samples_ids"].append([s.token_ids.tolist() for s in samples])
        self.data["proposal_scores"].append(proposal_log_scores.tolist())
        self.data["target_scores"].append(target_log_scores.tolist())
        self.data["model_scores"].append(model_log_scores.tolist())

    def on_step_idx_updated(self, s):
        self.data["steps"] = s
        # a non-positive save_steps disables the periodic saves
        if self.save_steps > 0 and (s % self.save_steps) == 0:
            self.save()

    def on_ministep_idx_updated(self, s):
        self.data["ministeps"] = s
16 | Everything was blue then, my kid, blue shadows and blue smoke. 17 | It was a bright cold day in April, and the clocks were striking thirteen. 18 | If you really want to hear about it, the first thing you’ll probably want to know is where I was born, and what my lousy childhood was like, and how my parents were occupied and all before they had me, and all that David Copperfield kind of crap, but I don’t feel like going into it, if you want to know the truth. 19 | I am an invisible man. 20 | I am a sick man... I am a spiteful man. I am an unattractive man. 21 | In this respect, I did not differ from millions of other fictitious characters created by other fictitious writers. 22 | You better not never tell nobody but God. 23 | I was born twice. 24 | There was me, that is Alex, and my three droogs, that is Pete, Georgie, and Dim, and we sat in the Korova Milkbar. 25 | It was a queer, sultry summer, the summer they electrocuted the Rosenbergs, and I didn’t know anybody named Rosenberg and wasn’t elected but I was sore as hell. 26 | The last time I saw my father, he was washing his hands in the bathroom sink of a Greyhound bus. 27 | There's a bluebird in my heart that wants to get out but I pour whiskey on him. 28 | You don’t have to live forever, you just have to live. 29 | We the animals crouched in our cages, waiting. 30 | It was probably the color of the sea that decided me. 31 | All children, except one, grow up. 32 | You’ll have to come a long way with me, dear reader, before you hear that again. 33 | Whether I shall turn out to be the hero of my own life, or whether that station will be held by anybody else, these pages must show. 34 | It was the best of times, it was the worst of times, it was the age of wisdom, it was the age of foolishness. 35 | In an old house in Paris that was covered with vines lived twelve little girls in two straight lines. 36 | There was a boy called Eustace Clarence Scrubb, and he almost deserved it. 
# Packaging metadata for the `disco-generation` distribution.
setup(
    name='disco-generation',
    # NOTE(review): disco/__init__.py declares __version__ = "1.1.1" while this
    # says 1.1.2 -- confirm which of the two is the intended release number.
    version='1.1.2',
    description='A toolkit for distributional control of generative models',
    url='https://github.com/naver/disco',
    author='Naver Labs Europe', author_email='jos.rozen@naverlabs.com',
    # NOTE(review): the LICENSE file at the root of the repository holds a NAVER
    # non-commercial license text, not CC BY-NC-SA 4.0 -- confirm this field.
    license='Creative Commons Attribution-NonCommercial-ShareAlike 4.0',
    # Markdown text, as declared by long_description_content_type below.
    long_description="""The 🕺🏽 **disco** toolkit allows to control the properties of the generations by language models and other generative systems to match human preferences while avoiding catastrophic forgetting.

To achieve this in **disco**, we first represent in what ways we want to update original model as a target distribution and then, generate samples from this new distribution through a combination of learning or monte-carlo methods, as follows.

**Step 1: We express how the target distribution *should* be**

To have a handle on the generative model, we define some feature over the generated samples. It can be anything we can compute. For example, on a language model it can be as simple as whether the generated text contains a certain word or as complex as the compilability of some generated piece of code. Importantly, there is no need for the feature to be differentiable.
Then, we can express our preferences on the target distribution by defining the target *moments* of this feature. For example, we might want to ask that a certain word appears 50% of the time when sampling from the model; or that 100% of the generated code is compilable. The resulting target distribution is expressed as an energy-based model or EBM, which is an unnormalized probability distribution that respects the desired moments while avoiding catastrophic forgetting, as a result of having minimal KL divergence to the original model.
This representation of the target distribution can *score* samples, but cannot directly be used to *generate* them.

**Step 2: We generate samples from the target distribution**

To generate samples from the target distribution, if not perfectly, we can tune a model to approximate it. The resulting model can generate samples directly from a close approximation of the target distribution. Furthermore, it can be used jointly with Quasi-Rejection Sampling (QRS), a Monte Carlo sampling technique that allows the generation of samples that are even more representative of the target distribution.
Alternatively, it is then possible to use decoding methods such as nucleus sampling, top-k sampling, or beam search, which would return samples from a further updated target distribution.""",
    long_description_content_type='text/markdown',
    # Only the `disco` package and its subpackages are shipped (tests, scripts
    # and tutorials are not matched by this filter).
    packages=find_packages(include=['disco', 'disco.*']),
    python_requires='>=3.8',
    # Same packages as requirements.txt, plus a minimum version for transformers.
    install_requires=['torch', 'transformers>=4.49',
                        'numpy', 'scipy',
                        'datasets', 'spacy',
                        'notebook',
                        'neptune-client', 'wandb'],
    # NOTE(review): python_requires accepts any Python >= 3.8 whereas the
    # classifiers below stop at 3.10 -- confirm the supported versions.
    classifiers=[
        'Development Status :: 5 - Production/Stable',
        'Intended Audience :: Science/Research',
        'License :: Free for non-commercial use',
        'Operating System :: POSIX :: Linux',
        'Operating System :: MacOS',
        'Operating System :: Microsoft :: Windows',
        'Programming Language :: Python :: 3',
        'Programming Language :: Python :: 3.8',
        'Programming Language :: Python :: 3.9',
        'Programming Language :: Python :: 3.10',
    ],
)
4 | 5 | Non-Commercial License 6 | 7 | Subject to any LICENSE EXCEPTIONS, NAVER Corporation (“NAVER”) hereby grants you a non-exclusive, non-sublicensable, non-transferable license to use the Materials, subject to the following conditions: 8 | 9 | (1) SCOPE OF USE: The Materials are used solely for non-commercial purposes (“Purpose”). You may not use the Materials or derivatives thereof for any commercial purpose (i.e., primarily intended for or directed towards commercial advantage or monetary compensation). You may not distribute the Materials or derivatives thereof under different terms and conditions as this License. 10 | 11 | (2) COPYRIGHT: The above copyright notice and this License along with the disclaimer below shall be retained in all copies and derivatives. 12 | 13 | (3) TERM: The License automatically terminates without notice if you fail to comply with its terms or the Purpose no longer exists. You may terminate this License at any time by ceasing to use the Materials. Upon termination you agree to delete any and all copies of the Materials and derivatives. The license to any of your Contributions under (4) will survive termination. 14 | 15 | (4) CONTRIBUTIONS: If you contribute to the project by providing feedback (“Contributions”) by, for example, making comments or a pull request, you agree to grant, and hereby grant, NAVER, without any restrictions or limitations, a non-exclusive, perpetual, irrevocable, royalty-free, paid-up, assignable and sub-licensable license, to reproduce, publicly perform or display, install, use, modify, adapt, prepare derivative works of, post, distribute, make and have made, sell and transfer your Contributions, and derivative works thereof, for any purpose. This grant by you does not change your rights to use your Contributions. Your Contributions may be used to update the Materials at Naver’s discretion. 
def _as_flat_list(new):
    """
    Returns the reported value(s) as a flat list of numbers.

    Parameters
    ----------
    new: torch.Tensor, np.ndarray, iterable, or number
        the value(s) to convert; a single number yields a one-item list
    """
    if isinstance(new, torch.Tensor):
        # detached so that no autograd graph is involved in the conversion
        return new.detach().flatten().tolist()
    if isinstance(new, np.ndarray):
        return new.flatten().tolist()
    if hasattr(new, '__iter__'):
        return list(new)
    return [new]

def average(moving_averages):
    """
    Computes the weighted mean of a collection of moving averages.

    Parameters
    ----------
    moving_averages: dict
        maps arbitrary keys to objects exposing `value` and `weight`
        attributes, such as MovingAverage instances

    Returns
    -------
    the mean of the values weighted by their weights, or 0 when the
    collection is empty or holds no observation yet
    """
    weighted_total = 0
    total_weight = 0
    for _, ma in moving_averages.items():
        # an average that was never updated (or was reset) has no value
        # and a null weight: it cannot contribute to the mean
        if ma.value is None or not ma.weight:
            continue
        weighted_total += ma.value * ma.weight
        total_weight += ma.weight
    if not total_weight:
        return 0
    return weighted_total / total_weight

class WindowedMovingAverage:
    """
    Keeps a moving average of a quantity over a fixed-size window.
    """
    def __init__(self, window_size=1000):
        """
        Parameters
        ----------
        window_size: int
            The number of samples to average over.

        Raises
        ------
        ValueError
            if window_size is not positive
        """
        if window_size <= 0:
            raise ValueError("window_size must be a positive integer.")
        self.window_size = window_size
        self.buffer = []
        # Mean of the values currently in the buffer; None until a value is reported.
        self.value = None

    def update(self, new):
        """
        Adds new values to the buffer and updates the average

        Parameters
        ----------
        new: torch.Tensor, np.ndarray, list, or float
            A collection of new pointwise estimates to add to the buffer.
        """
        self.buffer.extend(_as_flat_list(new))

        # crop to window size, keeping the most recent values
        self.buffer = self.buffer[-self.window_size:]

        # nothing was ever reported: there is no mean to compute yet
        if not self.buffer:
            return

        self.value = sum(self.buffer) / len(self.buffer)

class MovingAverage:
    """
    Keeps a moving average of a quantity.
    The average is initialized with the first value reported.
    """
    def __init__(self):
        """
        Initializes the moving average.
        """
        self.value = None
        self.weight = 0

    def update(self, new):
        """
        Adds new values to the average.

        Parameters
        ----------
        new: torch.Tensor, np.ndarray, list, or float
            A collection of new pointwise estimates to add.
        """
        new_values = _as_flat_list(new)

        if not new_values:
            return

        new_weight = len(new_values)
        new_mean = sum(new_values) / new_weight

        if self.value is None:
            # first update: initialize the moving average
            self.value = new_mean
            self.weight = new_weight
        else:
            # merge the mean of the new values, weighted by their number
            total_weight = self.weight + new_weight
            self.value = (self.value * self.weight + new_mean * new_weight) / total_weight
            self.weight = total_weight

    def reset(self):
        """Forgets everything reported so far."""
        self.weight = 0
        self.value = None
class MLFlowLogger(BaseTunerObserver):
    """
    Reports DPGTuner statistics to MLFlow
    """

    def __init__(self, tuner, project, name=None):
        """Constructor of a MLFlowLogger object

        Parameters
        ----------
        tuner: DPGTuner
            The tuner object whose statistics we want to report
        project: string
            The MLFlow experiment to which we want to report the statistics
        name: string
            The name of the MLFlow run to which we want to report the statistics
        """
        super(MLFlowLogger, self).__init__(tuner)
        mlflow.set_experiment(project)
        self.run = mlflow.start_run(run_name=name, log_system_metrics=True)
        self.step = None
        self.last_step_eval_samples_reported = None
        self._reset_stats()

    def _reset_stats(self):
        """Forgets the metric values collected for the current step"""
        self.stats = defaultdict(list)

    def _report_step_stats(self, step):
        """
        Sends to MLFlow the metrics collected since the last report:
        as is when reported once, as min/max/mean otherwise.
        """
        metrics = {}
        for k, vals in self.stats.items():
            if len(vals) == 1:
                metrics[k] = vals[0]
            else:
                try:
                    metrics[f"{k}/min"] = np.min(vals)
                    metrics[f"{k}/max"] = np.max(vals)
                    metrics[f"{k}/mean"] = np.mean(vals)
                except TypeError as err:
                    raise RuntimeError(f"TypeError while processing {k} = {vals}") from err
        if metrics:
            mlflow.log_metrics(metrics, step=step)

    def __setitem__(self, k, v):
        """
        Report arbitrary parameter/value combinations
        """
        mlflow.log_param(k, v)

    def __exit__(self, *exc):
        # flush the metrics of the last step before closing the run
        self._report_step_stats(self.step)
        self.run.__exit__(*exc)

    def on_metric_updated(self, name, value):
        if isinstance(value, torch.Tensor):
            value = value.item()
        self.stats[name].append(value)

    def on_parameters_updated(self, params):
        mlflow.log_params(params)

    def on_eval_samples_updated(self, context, samples, proposal_log_scores, model_log_scores, target_log_scores):
        """
        Logs evaluation samples and their scores as an MLflow artifact (JSON format).
        """
        if self.last_step_eval_samples_reported is not None and self.last_step_eval_samples_reported == self.step:
            # we already reported samples for this step
            return
        self.last_step_eval_samples_reported = self.step
        data = {
            "context": context,
            "samples": [
                {
                    "text": getattr(s, "text", str(s)),
                    "proposal_log_score": float(proposal_log_scores[i].item()),
                    "model_log_score": float(model_log_scores[i].item()),
                    "target_log_score": float(target_log_scores[i].item())
                }
                for i, s in enumerate(samples)
            ]
        }

        # samples may be reported before the first step: None cannot be
        # formatted with a numeric format specification
        label = "initial" if self.step is None else f"{self.step:03}"

        # Write to a temporary JSON file and log
        with tempfile.TemporaryDirectory() as tmpdir:
            tmpfile = Path(tmpdir, f"{label}.json")
            # the file must be closed, hence fully written, before being logged
            with open(tmpfile, 'w') as fout:
                json.dump(data, fout, indent=2)
            mlflow.log_artifact(tmpfile, artifact_path="samples")

    def on_step_idx_updated(self, s):
        # a new step begins: report what was collected during the previous one
        if self.step is not None:
            self._report_step_stats(self.step)
        self._reset_stats()
        self.step = s
        mlflow.log_metric("steps", s, step=self.step)
3 | # Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license 4 | 5 | import torch 6 | import tqdm.autonotebook as tqdm 7 | 8 | def batchify(func, batch, samples=list(), **args): 9 | all = [] 10 | with tqdm.tqdm(total=len(samples), desc=func.__name__, position=1, leave=False) as pbar: 11 | for i in range(len(samples)//batch + 1): 12 | subsamples = samples[i * batch:(i+1) * batch] 13 | if subsamples: 14 | all.append(func(subsamples, **args)) 15 | pbar.update(batch) 16 | return torch.cat(all) 17 | 18 | def score_in_chunks_batched(model, samples_nested, contexts, chunk_size): 19 | """ 20 | Scores samples from multiple contexts in chunks to manage memory, generalized 21 | to handle any chunk size. 22 | 23 | If the `chunk_size` is smaller than the number of contexts, contexts are 24 | processed sequentially to avoid memory overflow. 25 | 26 | Parameters: 27 | model: The distribution model with a `log_score_batch` method. 28 | samples_nested (list of lists): Samples, structured as [context][sample_index]. 29 | contexts (list of str): The list of contexts. 30 | chunk_size (int): The maximum number of (sample, context) pairs to score in a single pass. 31 | 32 | Returns: 33 | torch.Tensor: A tensor of scores with shape (num_contexts, total_samples). 34 | """ 35 | num_contexts = len(contexts) 36 | # Handle empty inputs to prevent errors 37 | if num_contexts == 0: 38 | return torch.empty(0, 0) 39 | 40 | num_samples_per_context = len(samples_nested[0]) 41 | if num_samples_per_context == 0: 42 | return torch.empty(num_contexts, 0) 43 | 44 | desc = f"Scoring ({model.__class__.__name__})" 45 | 46 | # --- Case 1: `chunk_size` is large enough for all contexts at once --- 47 | if chunk_size >= num_contexts: 48 | all_scores_chunks = [] 49 | # Calculate how many samples PER CONTEXT can fit in one batch. 
def get_token_first_indices(x, token):
    """Find the first occurrence of a token in a 2D token array.

    Parameters
    ----------
    x: 2D int tensor
        list of token sequences
    token: int
        token to search

    Returns
    ------
    1D tensor containing the position of the first occurrence of the token or -1 if not found
    """
    if 0 == x.shape[-1]:
        # zero-length sequences cannot contain the token; build the
        # result on the same device as the input
        return torch.full((x.shape[0],), -1, dtype=torch.long, device=x.device)

    mask = x == token
    # along a row, torch.max reports the index of the first maximal value,
    # i.e. the first True when there is one
    found, first_indices = torch.max(mask, dim=1)
    first_indices[~found] = -1
    return first_indices
43 | ] 44 | }, 45 | { 46 | "cell_type": "code", 47 | "execution_count": null, 48 | "metadata": {}, 49 | "outputs": [], 50 | "source": [ 51 | "import re\n", 52 | "from disco.scorers import BooleanScorer\n", 53 | "from disco.distributions import LMDistribution\n", 54 | "from disco.distributions.single_context_distribution import SingleContextDistribution\n", 55 | "\n", 56 | "is_amazing = lambda s, c: bool(re.search(r\"\\bamazing\\b\", s.text))\n", 57 | "amazing_scorer = BooleanScorer(is_amazing)\n", 58 | "base = LMDistribution()\n", 59 | "\n", 60 | "incipit = \"It was a cold and stormy night\"\n", 61 | "\n", 62 | "target = base.constrain([amazing_scorer], [1/2],\n", 63 | " n_samples=2**10,\n", 64 | " context_distribution=SingleContextDistribution(incipit))" 65 | ] 66 | }, 67 | { 68 | "attachments": {}, 69 | "cell_type": "markdown", 70 | "metadata": {}, 71 | "source": [ 72 | "## Sampling" 73 | ] 74 | }, 75 | { 76 | "attachments": {}, 77 | "cell_type": "markdown", 78 | "metadata": {}, 79 | "source": [ 80 | "We then instantiate a proposal model to sample from." 81 | ] 82 | }, 83 | { 84 | "cell_type": "code", 85 | "execution_count": null, 86 | "metadata": {}, 87 | "outputs": [], 88 | "source": [ 89 | "proposal = LMDistribution()" 90 | ] 91 | }, 92 | { 93 | "attachments": {}, 94 | "cell_type": "markdown", 95 | "metadata": {}, 96 | "source": [ 97 | "Note that we can peek inside the EBM to look at the computed coefficients." 98 | ] 99 | }, 100 | { 101 | "cell_type": "code", 102 | "execution_count": null, 103 | "metadata": {}, 104 | "outputs": [], 105 | "source": [ 106 | "print(target.scorers[1].coefficients)" 107 | ] 108 | }, 109 | { 110 | "attachments": {}, 111 | "cell_type": "markdown", 112 | "metadata": {}, 113 | "source": [ 114 | "It's now just a matter of instantiating a sampler."
115 | ] 116 | }, 117 | { 118 | "cell_type": "code", 119 | "execution_count": null, 120 | "metadata": {}, 121 | "outputs": [], 122 | "source": [ 123 | "from disco.samplers import QuasiRejectionSampler" 124 | ] 125 | }, 126 | { 127 | "cell_type": "code", 128 | "execution_count": null, 129 | "metadata": {}, 130 | "outputs": [], 131 | "source": [ 132 | "sampler = QuasiRejectionSampler(target, proposal, beta=400)" 133 | ] 134 | }, 135 | { 136 | "attachments": {}, 137 | "cell_type": "markdown", 138 | "metadata": {}, 139 | "source": [ 140 | "How are we doing?" 141 | ] 142 | }, 143 | { 144 | "cell_type": "code", 145 | "execution_count": null, 146 | "metadata": {}, 147 | "outputs": [], 148 | "source": [ 149 | "from disco.samplers import AccumulationSampler" 150 | ] 151 | }, 152 | { 153 | "cell_type": "code", 154 | "execution_count": null, 155 | "metadata": {}, 156 | "outputs": [], 157 | "source": [ 158 | "samples, log_scores = sampler.sample(sampling_size=2**10, context=incipit)" 159 | ] 160 | }, 161 | { 162 | "cell_type": "code", 163 | "execution_count": null, 164 | "metadata": {}, 165 | "outputs": [], 166 | "source": [ 167 | "sum([is_amazing(s, _) for s in samples]) / len(samples)" 168 | ] 169 | }, 170 | { 171 | "cell_type": "code", 172 | "execution_count": null, 173 | "metadata": {}, 174 | "outputs": [], 175 | "source": [] 176 | } 177 | ], 178 | "metadata": { 179 | "kernelspec": { 180 | "display_name": "Python 3.10.6 (conda)", 181 | "language": "python", 182 | "name": "python3" 183 | }, 184 | "language_info": { 185 | "codemirror_mode": { 186 | "name": "ipython", 187 | "version": 3 188 | }, 189 | "file_extension": ".py", 190 | "mimetype": "text/x-python", 191 | "name": "python", 192 | "nbconvert_exporter": "python", 193 | "pygments_lexer": "ipython3", 194 | "version": "3.10.6 | packaged by conda-forge | (main, Aug 22 2022, 20:38:29) [Clang 13.0.1 ]" 195 | }, 196 | "vscode": { 197 | "interpreter": { 198 | "hash": 
class BaseDistribution(Distribution):
    """
    Base distribution class, which can be used
    to build an EBM.
    """

    def clone(self):
        """Returns a deep copy of the distribution."""
        return copy.deepcopy(self)

    def constrain(self,
            features, moments=None,
            proposal=None, context_distribution=SingleContextDistribution(''), n_contexts_per_step=1,
            n_samples=2**9, iterations=1000, learning_rate=0.05, tolerance=1e-5, sampling_size=2**5
        ):
        """
        Constrains features to the base according to their moments,
        so producing an EBM

        Without moments, or when every feature is a BooleanScorer with a
        moment equal to one, the features are composed directly with the
        base in a product. Otherwise the coefficients of an exponential
        scorer are fitted by gradient steps on self-normalized importance
        sampling estimates of the moments.

        Parameters
        ----------
        features: list(feature)
            multiple features to constrain
        moments: list(float)
            moments for the features. There should be as many moments as there are features
        proposal: distribution
            distribution to sample from, if different from self
        context_distribution: distribution
            to contextualize the sampling and scoring
        n_contexts_per_step:
            number of contexts sampled from the context distribution
        n_samples: int
            number of samples to use to fit the coefficients
        iterations: int
            maximum number of gradient steps when fitting the coefficients
        learning_rate: float
            multipliers of the delta used when fitting the coefficients
        tolerance: float
            accepted difference between the targets and moments
        sampling_size:
            size of the batch when sampling samples

        Returns
        -------
        product of the base with the features, or with
        an exponential scorer with fitted coefficients

        Raises
        ------
        TypeError
            when features or moments are not lists, or differ in length
        """

        if not isinstance(features, list):
            raise TypeError("features should be passed as a list.")

        if not moments:
            return Product(self, *features)

        if not isinstance(moments, list):
            raise TypeError("moments should be passed as a list.")
        if not len(features) == len(moments):
            raise TypeError("there should be as many as many moments as there are features.")

        # pointwise constraints need no fitting
        if all(type(f) is BooleanScorer for f in features)\
                and all(1.0 == float(m) for m in moments):
            return Product(self, *features)

        if not proposal:
            proposal = self

        context_samples, context_log_scores = context_distribution.sample(n_contexts_per_step)

        # samples and scores are computed once per context, and reused at every iteration
        proposal_samples = dict()
        proposal_log_scores = dict()
        joint_log_scores = dict()
        feature_scores = dict()
        for (context, log_score) in zip(context_samples, context_log_scores):
            accumulator = AccumulationSampler(proposal, total_size=n_samples)
            proposal_samples[context], proposal_log_scores[context] = accumulator.sample(
                sampling_size=sampling_size, context=context
            )
            device = get_device(proposal_log_scores[context])
            reference_log_scores = batchify(
                self.log_score, sampling_size, samples=proposal_samples[context], context=context
            ).to(device)
            joint_log_scores[context] = torch.tensor(log_score).repeat(n_samples).to(device) + reference_log_scores
            feature_scores[context] = torch.stack(
                ([f.score(proposal_samples[context], context).to(device) for f in features])
            )

        coefficients = torch.tensor(0.0).repeat(len(features)).to(device)
        targets = torch.tensor(moments).to(device)
        with trange(iterations, desc='fitting exponential scorer') as t:
            for i in t:
                scorer = ExponentialScorer(features, coefficients)
                numerator = torch.tensor(0.0).repeat(len(features)).to(device)
                denominator = torch.tensor(0.0).repeat(len(features)).to(device)
                for context in context_samples:
                    target_log_scores = joint_log_scores[context] + scorer.log_score(
                        proposal_samples[context], context
                    ).to(device)
                    importance_ratios = torch.exp(target_log_scores - proposal_log_scores[context])
                    numerator += (importance_ratios * feature_scores[context]).sum(dim=1)
                    denominator += importance_ratios.sum()
                # kept apart from the `moments` argument, which holds the targets
                estimated_moments = numerator / denominator
                grad_coefficients = estimated_moments - targets
                err = grad_coefficients.abs().max().item()
                t.set_postfix(err=err)
                if tolerance > err:
                    # converged early: shrink the bar to the iterations actually run
                    t.total = i
                    t.refresh()
                    break
                coefficients -= learning_rate * grad_coefficients

        return self * ExponentialScorer(features, coefficients)
class TunerTest(unittest.TestCase):
    """Runs each tuner for a single gradient step and checks the reported metrics."""

    def _observe_tuning(self, tuner):
        """Tunes under a MockObserver and returns the metrics it recorded."""
        with MockObserver(tuner) as obs:
            tuner.tune()
        return obs.observations

    def _amazing_target(self, moment):
        """Builds the "amazing" boolean scorer and a base LM constrained on it."""
        base = LMDistribution()
        scorer = BooleanScorer(lambda s, c: "amazing" in s.text)
        return scorer, base.constrain([scorer], [moment])

    def test_rkl(self):
        self._test_loss_rlhf(ReverseKLLoss)

    def test_tv(self):
        self._test_loss_rlhf(TVLoss)

    def test_js(self):
        self._test_loss_rlhf(JSLoss)

    def test_kl(self):
        self._test_loss_rlhf(KLLoss)

    def _test_loss_rlhf(self, loss_cls):
        base = LMDistribution()
        model = LMDistribution(freeze=False)
        reward = Scorer(lambda s, c: 1 if "amazing" in s.text else 0)
        rlhf_target = base * ExponentialScorer([reward], [0.1])
        tuner = Tuner(
            model, rlhf_target,
            loss=loss_cls(),
            n_gradient_steps=1,
            n_contexts_per_step=1,
            n_samples_per_context=32,
            scoring_size=32,
        )
        self.assertIn('loss', self._observe_tuning(tuner))

    def test_dpg(self):
        _, target = self._amazing_target(0.5)
        model = LMDistribution(freeze=False)
        tuner = DPGTuner(
            model, target,
            n_gradient_steps=1,
            n_samples_per_context=32,
            scoring_size=32,
        )
        self.assertIn('loss', self._observe_tuning(tuner))

    def test_cdpg(self):
        _, target = self._amazing_target(0.5)
        model = LMDistribution(freeze=False)
        context_distribution = SingleContextDistribution("Speaking about today's dinner, it was")
        tuner = CDPGTuner(
            model, target,
            context_distribution=context_distribution,
            n_gradient_steps=1,
            n_contexts_per_step=1,
            n_samples_per_context=1,
            scoring_size=1,
        )
        self.assertIn('loss', self._observe_tuning(tuner))

    def test_fdpg(self):
        _, target = self._amazing_target(0.5)
        model = LMDistribution(freeze=False)
        tuner = FDPGTuner(
            model, target,
            n_gradient_steps=1,
            n_samples_per_context=1,
            scoring_size=1,
        )
        self.assertIn('loss', self._observe_tuning(tuner))

    def test_fcdpg(self):
        _, target = self._amazing_target(0.5)
        model = LMDistribution(freeze=False)
        context_distribution = SingleContextDistribution("Speaking about today's dinner, it was")
        tuner = FCDPGTuner(
            model, target,
            context_distribution=context_distribution,
            n_gradient_steps=1,
            n_contexts_per_step=1,
            n_samples_per_context=1,
            scoring_size=1,
        )
        self.assertIn('loss', self._observe_tuning(tuner))

    def test_moments(self):
        scorer, target = self._amazing_target(1)
        model = LMDistribution(freeze=False)
        context_distribution = SingleContextDistribution("The movie was absolutely")

        # with features declared, their moments are reported as metrics
        tuner = FCDPGTuner(
            model, target,
            context_distribution=context_distribution,
            n_gradient_steps=1,
            n_contexts_per_step=1,
            n_samples_per_context=128,
            divergence_evaluation_interval=1,
            scoring_size=32,
            features=[('amazing', scorer)],
        )
        observations = self._observe_tuning(tuner)
        self.assertIn('amazing_proposal', observations)
        self.assertIn('amazing_target', observations)
        proposal_moment = observations['amazing_proposal'].item()
        target_moment = observations['amazing_target'].item()
        self.assertAlmostEqual(proposal_moment, 0.1, 1)
        self.assertAlmostEqual(target_moment, 0.1, 1)
        self.assertAlmostEqual(proposal_moment, target_moment, 4)

        # without features, no moment is reported
        tuner = FCDPGTuner(
            model, target,
            context_distribution=context_distribution,
            n_gradient_steps=1,
            n_contexts_per_step=1,
            n_samples_per_context=1,
            scoring_size=1,
        )
        observations = self._observe_tuning(tuner)
        self.assertNotIn('amazing_proposal', observations)
        self.assertNotIn('amazing_target', observations)
class PositiveScorer(Scorer):
    """
    Scorer, but limited to positive values
    """
    def __init__(self, scoring_function):
        super().__init__(scoring_function)

    def __mul__(self, ot):
        """enables the use of the multiplication sign (*)
        to compose positive scorers"""

        return Product(self, ot)

    def log_score(self, samples, context):
        """returns the log-scores for the samples
        given the context by taking the log of their scores

        Parameters
        ----------
        samples : list(Sample)
            list of samples to log-score
        context: text
            context that the samples relate to

        Returns
        -------
        tensor of log-scores for the samples"""

        return torch.log(self.score(samples, context=context))

    def log_score_batch(self, samples, contexts):
        """
        Returns the batched log-scores for samples against multiple contexts,
        by taking the log of the scores from `score_batch`.

        Parameters
        ----------
        samples : list of lists
            Samples grouped by context: `samples[i]` holds the samples
            for `contexts[i]`.
        contexts : list of str
            A list of contextual text strings.

        Returns
        -------
        torch.Tensor
            A tensor of log-scores, with the shape returned by `score_batch`.
        """
        scores = self.score_batch(samples, contexts=contexts)
        return torch.log(scores)

    def score(self, samples, context):
        """relies on the instance's scoring function
        to compute the scores of the samples given the context

        Parameters
        ----------
        samples : list(Sample)
            list of samples to score
        context: text
            context that the samples relate to

        Returns
        -------
        tensor of scores for the samples"""

        return self.scoring_function(samples, context)

    def score_batch(self, samples, contexts):
        """
        Computes scores for a batch of samples against multiple contexts.

        This default implementation works by iterating through the contexts,
        picking the corresponding group of samples, and calling the
        single-context `scoring_function` for each.

        Parameters
        ----------
        samples : list of lists
            Samples grouped by context: `samples[i]` holds the samples
            for `contexts[i]`.
        contexts : list of str
            A list of contextual text strings.

        Returns
        -------
        torch.Tensor
            The per-context scores stacked along a new first dimension;
            an empty `(num_contexts, 0)` tensor when nothing was scored.
        """
        all_scores = []
        for i, context in enumerate(contexts):
            context_samples = samples[i]

            # NOTE: a context without samples contributes no row, so rows
            # only line up with `contexts` when every context has samples
            if not context_samples:
                continue

            # Call the existing single-context scoring function
            context_scores = self.scoring_function(context_samples, context)
            all_scores.append(context_scores)

        if not all_scores:
            return torch.empty(len(contexts), 0)

        # assumes the scoring function returns tensors of equal shape
        # for every context, as torch.stack requires
        return torch.stack(all_scores, dim=0)
162 | 163 | Returns 164 | ------- 165 | torch.Tensor 166 | A tensor of the summed log-scores with shape `(num_contexts, n_samples_per_context)`. 167 | """ 168 | try: 169 | device = self.scorers[0].device 170 | except AttributeError: 171 | device = "cpu" 172 | 173 | # Call the batched method on each scorer 174 | log_scores = [ 175 | s.log_score_batch(samples, contexts=contexts).to(device) 176 | for s in self.scorers 177 | ] 178 | 179 | # Sum the resulting tensors element-wise 180 | return reduce(lambda x, y: x + y, log_scores) 181 | 182 | def __str__(self): 183 | scorers_str = ", ".join((str(scorer) for scorer in self.scorers)) 184 | return f"Product({scorers_str})" -------------------------------------------------------------------------------- /disco/samplers/accumulation_sampler.py: -------------------------------------------------------------------------------- 1 | # disco 2 | # Copyright (C) 2022-present NAVER Corp. 3 | # Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license 4 | 5 | from . 
class AccumulationSampler(Sampler):
    """
    Utility class to accumulate samples, up to a total size
    """

    def __init__(self, distribution, total_size=512):
        """
        Parameters
        ----------
        distribution: distribution
            distribution to sample from
        total_size: int
            total number of samples
        """

        self.distribution = distribution
        self.total_size = total_size

    def sample(self, sampling_size=32, context=""):
        """accumulates batches of samples from the distribution

        Parameters
        ----------
        sampling_size: int
            number of requested samples per individual sampling
        context: text
            contextual text for which to sample

        Returns
        -------
        a tuple of accumulated samples and scores
        """
        with trange(
                self.total_size,
                desc=f"sampling from {type(self.distribution).__name__}",
                position=1,
                leave=False
            ) as t:
            remaining = self.total_size
            samples, log_scores = list(), torch.empty([0])
            while remaining > 0:
                more_samples, more_log_scores = self.distribution.sample(context=context, sampling_size=sampling_size)
                # never keep more than what is still missing
                length = min(remaining, len(more_samples))
                more_samples, more_log_scores = more_samples[:length], more_log_scores[:length]
                if samples:
                    samples = samples + more_samples
                    log_scores = torch.cat((log_scores, more_log_scores))
                else:
                    # the first batch replaces the empty placeholders as is
                    samples, log_scores = more_samples, more_log_scores
                remaining -= len(more_samples)
                t.update(len(more_samples))

        return (samples, log_scores)

    def sample_batch(self, contexts, sampling_size=32, output_scores=True):
        """
        Accumulates batches of samples for a list of contexts simultaneously.

        This method repeatedly calls the distribution's `sample_batch` method
        until `total_size` samples are requested for each context.

        Parameters
        ----------
        contexts: list of str
            A list of contextual texts for which to sample.
        sampling_size: int
            The number of samples to request for each context per batch call.
        output_scores: bool
            Whether to also collect and return the log scores.

        Returns
        -------
        tuple of (list of lists, torch.Tensor), or list of lists
            - A list of lists, where the outer list corresponds to the input
              contexts and each inner list contains the accumulated sample objects.
            - When `output_scores` is set, a tensor stacking the concatenated
              log scores of every context along the first dimension.

        Raises
        ------
        ValueError
            when `contexts` is not a non-empty list
        """
        if not isinstance(contexts, list) or not contexts:
            raise ValueError("contexts must be a non-empty list of strings.")

        num_contexts = len(contexts)

        # Initialize containers for accumulating results for each context
        accumulated_samples = [[] for _ in range(num_contexts)]
        if output_scores:
            # Store score tensors from each batch call in a list for each context
            accumulated_scores_parts = [[] for _ in range(num_contexts)]

        with trange(
                self.total_size * num_contexts,
                desc=f"Batch sampling for {num_contexts} contexts from {type(self.distribution).__name__}",
                position=1,
                leave=False
            ) as t:
            collected_count = 0
            while collected_count < self.total_size:
                # Determine how many more samples are needed to reach the goal.
                # This prevents over-sampling in the final iteration.
                remaining = self.total_size - collected_count
                current_batch_size = min(remaining, sampling_size)

                # Call the batched sampling method on the distribution
                ret = self.distribution.sample_batch(
                    contexts=contexts,
                    sampling_size=current_batch_size,
                    output_scores=output_scores
                )
                if output_scores:
                    more_samples_nested, more_log_scores = ret
                else:
                    more_samples_nested = ret

                # Both results are indexed by context: hand the i-th group
                # of samples, and scores, to the i-th accumulator.
                for i in range(num_contexts):
                    accumulated_samples[i].extend(more_samples_nested[i])
                    if output_scores:
                        accumulated_scores_parts[i].append(more_log_scores[i])

                collected_count += current_batch_size
                t.update(current_batch_size * num_contexts)

        if not output_scores:
            return accumulated_samples

        # Finalize the scores by concatenating the collected tensor parts for each context
        # and then stacking them into a single tensor with one row per context.
        final_scores = torch.stack(
            [torch.cat(parts) for parts in accumulated_scores_parts],
            dim=0
        )
        return (accumulated_samples, final_scores)
class DummyDistribution(BaseDistribution):
    """
    Uniform distribution over the integers in [low, high), used as a test
    double. The context, when given, is the requested upper bound; the
    range always spans at least 8 integers.
    """

    def __init__(self, low=0):
        self.low = low

    def _max(self, high):
        """Returns the effective upper bound, at least low + 8."""
        return max(self.low + 8, high)

    def log_score(self, numbers, context=None):
        """Returns the log-probabilities of the numbers, -inf outside the range.

        Parameters
        ----------
        numbers: list(int) or tensor
            integers to score
        context: int
            requested upper bound, 64 when None
        """
        context = 64 if context is None else context
        high = self._max(context)

        # as_tensor accepts lists and tensors alike, without the
        # copy-construct warning of torch.tensor(tensor)
        numbers = torch.as_tensor(numbers)
        probability = 1 / (high - self.low)
        probabilities = torch.where(
            torch.logical_and(self.low <= numbers, numbers < high), probability, 0.0
        )
        return torch.log(probabilities)

    def sample(self, sampling_size=1, context=None):
        """Samples integers uniformly from the range.

        Parameters
        ----------
        sampling_size: int
            number of integers to sample
        context: int
            requested upper bound, 64 when None

        Returns
        -------
        a tuple of the sampled integers, as a list, and their log-scores
        """
        context = 64 if context is None else context
        high = self._max(context)

        numbers = torch.randint(self.low, high, (sampling_size,))

        return (
            numbers.tolist(),
            # score under the context the numbers were sampled for
            self.log_score(numbers, context),
        )
class BaseDistributionTest(unittest.TestCase):
    """Checks how BaseDistribution.constrain(...) validates its arguments
    and which kind of scorers it composes into the returned product."""

    def test_constrain_features_should_passed_as_a_list(self):
        # a bare scorer, not wrapped in a list, must be rejected
        reference = DummyDistribution()
        with self.assertRaises(TypeError) as raised:
            reference.constrain(even)
        self.assertEqual(str(raised.exception), "features should be passed as a list.")

    def test_constrain_moments_should_passed_as_a_list_when_any(self):
        # a bare moment, not wrapped in a list, must be rejected
        reference = DummyDistribution()
        with self.assertRaises(TypeError) as raised:
            reference.constrain([even], 0.5)
        self.assertEqual(str(raised.exception), "moments should be passed as a list.")

    def test_constrain_features_should_match_moments(self):
        # two features but a single moment
        reference = DummyDistribution()
        with self.assertRaises(TypeError) as raised:
            reference.constrain([even, divisible_by_3], [0.5])
        self.assertEqual(str(raised.exception), "there should be as many as many moments as there are features.")

    def test_constrain_pointwisely_with_a_feature(self):
        target = DummyDistribution().constrain([even])
        self.assertEqual(Product, type(target),
            "constrain(...) should return a Product.")
        odd_log_scores = target.log_score(odds, None)
        self.assertTrue(all(-np.inf == s for s in odd_log_scores),
            "a base distribution should be constrainable with a single pointwise feature to build an EBM.")
        even_log_scores = target.log_score(evens, None)
        self.assertTrue(all(-np.inf < s < 0 for s in even_log_scores),
            "a base distribution should be constrainable with a single pointwise feature to build an EBM.")

    def test_constrain_pointwisely_with_multiple_features(self):
        target = DummyDistribution().constrain([even, divisible_by_3])
        log_scores = target.log_score(first_20_integers, None)
        # only 0, 6, 12 and 18 are both even and divisible by 3
        n_supported = sum(1 for s in log_scores if -np.inf < s < 0)
        self.assertEqual(4, n_supported,
            "a base distribution should be constrainable with multiple pointwise features to build an EBM.")

    def test_constrain_is_like_sugar_when_no_moment(self):
        reference = DummyDistribution()
        target_constrain = reference.constrain([even, divisible_by_3])
        target_sugar = reference * even * divisible_by_3
        some_integers, _ = reference.sample(20, 16)
        from_constrain = target_constrain.log_score(some_integers, None)
        from_sugar = target_sugar.log_score(some_integers, None)
        self.assertTrue(
            all(c == s for c, s in zip(from_constrain, from_sugar)),
            "an EBM built from constrain(…) should be equivalent to the one built with the product sign sugar.")

    def test_constrain_distributionally_with_a_feature(self):
        reference = DummyDistribution()
        contexts = SingleContextDistribution(None)
        target = reference.constrain([even], [0.8], context_distribution=contexts)
        self.assertEqual(Product, type(target),
            "constrain(...) should return a Product.")

    def test_constrain_distributionally_with_multiple_features(self):
        reference = DummyDistribution()
        contexts = SingleContextDistribution(None)
        target = reference.constrain([even, divisible_by_3], [0.8, 0.5], context_distribution=contexts)
        self.assertEqual(Product, type(target),
            "constrain(...) should return a Product.")

    def test_constrain_without_moments(self):
        bs = BooleanScorer(lambda x, c: True)
        distribution = BaseDistribution(lambda x, c: random.random())
        last_scorer = distribution.constrain([bs]).scorers[-1]
        self.assertEqual(BooleanScorer, type(last_scorer),
            "when no moment is specified a feature should be composed directly in a product.")
        # the first scorer of the product is skipped, only the features are checked
        features = distribution.constrain([bs, bs, bs]).scorers[1:]
        for s in features:
            self.assertEqual(BooleanScorer, type(s),
                "when no moments are specified multiple features should appear as is in a product.")

    def test_constrain_with_boolean_scorers_and_moments_set_to_one(self):
        bs = BooleanScorer(lambda x, c: True)
        distribution = BaseDistribution(lambda x, c: 1)
        last_scorer = distribution.constrain([bs], [1]).scorers[-1]
        self.assertEqual(BooleanScorer, type(last_scorer),
            "when only a boolean scorer is specified, and its moment is one, this feature should appear as is in a product.")
        bs = BooleanScorer(lambda x, c: True)
        # one can be spelled as an int, a float or a string
        features = distribution.constrain([bs, bs, bs], [1, 1.0, "1"]).scorers[1:]
        for s in features:
            self.assertEqual(BooleanScorer, type(s),
                "when only boolean scorer are specified, and their moments are one, these feature should appear as is in a product.")

    def test_constrain_with_boolean_scorers_but_with_moments_not_set_to_one_returns_an_exponential_scorer(self):
        bs = BooleanScorer(lambda x, c: True)
        distribution = DummyDistribution()
        target = distribution.constrain([bs, bs], [1, 0.5],
            context_distribution=SingleContextDistribution(None))
        self.assertEqual(ExponentialScorer, type(target.scorers[-1]),
            "when only boolean scorers are specified, but their moments are not one, features should be composed in an exponential scorer.")

    def test_constrain_returns_an_exponential_scorer(self):
        bs = BooleanScorer(lambda x, c: True)
        ps = PositiveScorer(lambda x, c: random.random())
        distribution = DummyDistribution()
        target = distribution.constrain([bs, ps], [0.2, 0.5],
            context_distribution=SingleContextDistribution(None))
        self.assertEqual(ExponentialScorer, type(target.scorers[-1]),
            "features should be composed by default in an exponential scorer.")
class QuasiRejectionSampler(Sampler):
    """
    Quasi Rejection-Sampling class

    Draws samples from a proposal and keeps each one with probability
    min(1, target / (beta * proposal)), while counting the drawn and the
    accepted samples to report an acceptance rate.
    """

    def __init__(self, target, proposal, beta=1):
        """
        Parameters
        ----------
        target: distribution
            Energy-based model to (log-)score the samples
        proposal: distribution
            distribution to generate the samples
        beta: float
            coefficient to control the sampling
        """

        super(QuasiRejectionSampler, self).__init__(target, proposal)
        self.beta = beta
        self.n_samples = 0
        self.n_accepted_samples = 0

    def sample(self, sampling_size=32, context=''):
        """Generates samples according to the QRS algorithm

        Parameters
        ----------
        sampling_size: int
            number of requested samples when sampling
        context: text
            contextual text for which to sample

        Returns
        -------
        tuple of accepted samples and their log-scores under the proposal;
        there can be fewer than sampling_size of them, possibly none
        """

        samples, proposal_log_scores = self.proposal.sample(sampling_size=sampling_size, context=context)
        self.n_samples += len(samples)

        device = get_device(proposal_log_scores)

        target_log_scores = self.target.log_score(samples=samples, context=context).to(device)

        # acceptance probabilities: min(1, P(x) / (beta * q(x)))
        rs = torch.clamp(
            torch.exp(target_log_scores - proposal_log_scores) / self.beta,
            min=0.0, max=1.0
        )

        us = torch.rand(len(rs)).to(device)
        # computed once, then reused to filter both samples and log-scores
        accepted = us < rs
        accepted_samples = [x for k, x in zip(accepted.tolist(), samples) if k]
        self.n_accepted_samples += len(accepted_samples)
        # mask indexing keeps dtype and device; detach() returns a tensor
        # without gradient history, as the rebuilt tensor used to be
        accepted_log_scores = proposal_log_scores[accepted].detach()

        return accepted_samples, accepted_log_scores

    def get_acceptance_rate(self):
        """Computes the currently observed acceptance rate, that is the number
        of accepted samples over the total sampled ones

        Returns
        -------
        acceptance rate as float between 0 and 1, 0.0 when nothing
        has been sampled yet"""

        if 0 == self.n_samples:
            # nothing drawn yet: avoid a ZeroDivisionError
            return 0.0

        return self.n_accepted_samples / self.n_samples
class QuasiRejectionSamplerEstimator:
    """
    Provides routines to compute estimates of useful metrics related to
    QuasiRejectionSampler (QRS)

    Samples are drawn once from the proposal, at instantiation, and scored
    by the target; every estimate is then computed from these stored
    samples and log-scores.
    """

    def __init__(self, target, proposal, n_estimation_samples=10000,
            sampling_size=32, context=''):
        """
        Parameters
        ----------
        target: distribution
            Energy-based model to (log-)score the samples
        proposal: distribution
            distribution to generate the samples
        n_estimation_samples: integer
            number of samples to use for computing estimates
        sampling_size: integer
            number of samples that are concurrently obtained from the proposal
        context: text
            context to condition the proposal and target
        """
        # kept so that features can later be scored on the stored samples,
        # with the same batch size and context
        self.sampling_size = sampling_size
        self.context = context

        sampler = AccumulationSampler(proposal, total_size=n_estimation_samples)

        self.samples, self.proposal_log_scores = \
            sampler.sample(sampling_size=sampling_size, context=context)

        self.target_log_scores = batchify(target.log_score,
                sampling_size, samples=self.samples, context=context)

    def acceptance_rate_at_beta(self, beta):
        """
        Estimate the acceptance rate that QRS has for a given beta parameter.

        Parameters
        ----------
        beta: float
            the value of beta for which we want the a.r. estimated

        Returns
        -------
        the estimated acceptance rate for the given beta
        """
        vals = self._compute_intermediary_values(beta)

        return vals['qrs_Z'].item() / beta

    def divergence_at_beta(self, beta, divergence=KL):
        """
        Estimate the divergence to the target distribution that QRS has
        for a given beta parameter.

        Parameters
        ----------
        beta: float
            the value of beta for which we want the divergence estimated
        divergence: divergence
            the divergence that we want to compute (see :py:mod:`metrics`)

        Returns
        -------
        the estimated value of the chosen divergence for the given beta
        """
        vals = self._compute_intermediary_values(beta)

        return divergence.divergence(
                vals['target_log_scores'],
                vals['normalized_qrs_log_scores'],
                vals['Z'],
                vals['proposal_log_scores']).item()

    def feature_moment_at_beta(self, beta, feature):
        """
        Estimate the first moment (expected value) of a feature when using
        QRS with a given beta parameter.

        Parameters
        ----------
        beta: float
            the value of beta for which we want the moment estimated
        feature: scorer
            the feature whose moment we want to compute

        Returns
        -------
        the estimated feature moment at the given beta
        """
        vals = self._compute_intermediary_values(beta)

        # score the samples stored at instantiation, with the batch size and
        # context they were obtained with
        feature_scores = batchify(feature.score,
                self.sampling_size, samples=self.samples, context=self.context)

        return torch.mean(vals['qrs_importance_ratios'] * feature_scores).item()

    def _compute_intermediary_values(self, beta):
        """
        Computes intermediary values used in QRS estimates

        Parameters
        ----------
        beta: float
            the value of beta for which to compute the values

        Returns
        -------
        dictionary with the target and proposal log-scores, the QRS
        log-scores, raw and normalized, the importance ratios of the target
        and of QRS, and the partition function estimates Z and qrs_Z
        """
        with torch.no_grad():
            target_importance_ratios = torch.exp(self.target_log_scores -
                    self.proposal_log_scores)
            Z = torch.mean(target_importance_ratios)
            # a non-positive beta is handled as log(beta) = -inf
            log_beta = torch.log(torch.tensor(beta)) if beta > 0 else float("-inf")
            # unnormalized QRS score: min(P(x), beta * q(x)), in log-space
            qrs_log_scores = torch.minimum(self.target_log_scores,
                    self.proposal_log_scores + log_beta)
            qrs_Z = torch.mean(torch.exp(qrs_log_scores - self.proposal_log_scores))
            normalized_qrs_log_scores = qrs_log_scores - torch.log(qrs_Z)
            qrs_importance_ratios = torch.exp(normalized_qrs_log_scores -
                    self.proposal_log_scores)

        return {
                'target_log_scores': self.target_log_scores,
                'proposal_log_scores': self.proposal_log_scores,
                'qrs_log_scores': qrs_log_scores,
                'normalized_qrs_log_scores': normalized_qrs_log_scores,
                'target_importance_ratios': target_importance_ratios,
                'Z': Z,
                'qrs_Z': qrs_Z,
                'qrs_importance_ratios': qrs_importance_ratios}
51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": null, 56 | "metadata": {}, 57 | "outputs": [], 58 | "source": [ 59 | "import spacy" 60 | ] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "execution_count": null, 65 | "metadata": {}, 66 | "outputs": [], 67 | "source": [ 68 | "nlp = spacy.load(\"en_core_web_sm\")" 69 | ] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "execution_count": null, 74 | "metadata": {}, 75 | "outputs": [], 76 | "source": [ 77 | "def organizations(text):\n", 78 | " \"\"\"returns a set of organizations from a text\"\"\"\n", 79 | " doc = nlp(text)\n", 80 | " return set(ent.text for ent in doc.ents if \"ORG\" == ent.label_)" 81 | ] 82 | }, 83 | { 84 | "attachments": {}, 85 | "cell_type": "markdown", 86 | "metadata": {}, 87 | "source": [ 88 | "Now that we can obtain a set of organizations from a text, we can build a scorer: we want to make sure that a sample only includes the organizations mentioned in the context, that we're going to summarize —in other words we don't want to have hallucinated organizations." 89 | ] 90 | }, 91 | { 92 | "cell_type": "code", 93 | "execution_count": null, 94 | "metadata": {}, 95 | "outputs": [], 96 | "source": [ 97 | "from disco.scorers.boolean_scorer import BooleanScorer" 98 | ] 99 | }, 100 | { 101 | "cell_type": "code", 102 | "execution_count": null, 103 | "metadata": {}, 104 | "outputs": [], 105 | "source": [ 106 | "organization_scorer = BooleanScorer(lambda s, c: all({o in organizations(c) for o in organizations(s.text)}))" 107 | ] 108 | }, 109 | { 110 | "attachments": {}, 111 | "cell_type": "markdown", 112 | "metadata": {}, 113 | "source": [ 114 | "For this task, we're going to use a powerful seq2seq model from Transformers, in its \"base\" version: T5." 
115 | ] 116 | }, 117 | { 118 | "cell_type": "code", 119 | "execution_count": null, 120 | "metadata": {}, 121 | "outputs": [], 122 | "source": [ 123 | "from disco.distributions import LMDistribution\n", 124 | "from transformers import AutoModelForSeq2SeqLM" 125 | ] 126 | }, 127 | { 128 | "cell_type": "code", 129 | "execution_count": null, 130 | "metadata": {}, 131 | "outputs": [], 132 | "source": [ 133 | "base = LMDistribution(model=\"t5-base\", tokenizer=\"t5-base\", auto=AutoModelForSeq2SeqLM)" 134 | ] 135 | }, 136 | { 137 | "attachments": {}, 138 | "cell_type": "markdown", 139 | "metadata": {}, 140 | "source": [ 141 | "And we simply state that we want all samples to respect our preferences." 142 | ] 143 | }, 144 | { 145 | "cell_type": "code", 146 | "execution_count": null, 147 | "metadata": {}, 148 | "outputs": [], 149 | "source": [ 150 | "target = base * organization_scorer" 151 | ] 152 | }, 153 | { 154 | "attachments": {}, 155 | "cell_type": "markdown", 156 | "metadata": {}, 157 | "source": [ 158 | "## Tuning, Conditionally" 159 | ] 160 | }, 161 | { 162 | "attachments": {}, 163 | "cell_type": "markdown", 164 | "metadata": {}, 165 | "source": [ 166 | "We now want to tune a model in order to approximate this target distribution. For this we will need many contexts: we can use a DatasetContextDistribution to rely on a dataset from Hugging Face's Datasets repository, the CNN / Dailymail dataset. Let's see how this works." 
167 | ] 168 | }, 169 | { 170 | "cell_type": "code", 171 | "execution_count": null, 172 | "metadata": {}, 173 | "outputs": [], 174 | "source": [ 175 | "from disco.distributions.dataset_context_distribution import DatasetContextDistribution" 176 | ] 177 | }, 178 | { 179 | "cell_type": "code", 180 | "execution_count": null, 181 | "metadata": {}, 182 | "outputs": [], 183 | "source": [ 184 | "dataset = DatasetContextDistribution(dataset=\"cnn_dailymail\", subset=\"1.0.0\", split=\"train\", key=\"article\")" 185 | ] 186 | }, 187 | { 188 | "attachments": {}, 189 | "cell_type": "markdown", 190 | "metadata": {}, 191 | "source": [ 192 | "Out of curiosity, we can sample a few articles and extract a set of organizations from the first one by doing:" 193 | ] 194 | }, 195 | { 196 | "cell_type": "code", 197 | "execution_count": null, 198 | "metadata": {}, 199 | "outputs": [], 200 | "source": [ 201 | "articles, log_scores = dataset.sample(sampling_size=2**3)" 202 | ] 203 | }, 204 | { 205 | "cell_type": "code", 206 | "execution_count": null, 207 | "metadata": {}, 208 | "outputs": [], 209 | "source": [ 210 | "articles[0]" 211 | ] 212 | }, 213 | { 214 | "cell_type": "code", 215 | "execution_count": null, 216 | "metadata": {}, 217 | "outputs": [], 218 | "source": [ 219 | "organizations(articles[0])" 220 | ] 221 | }, 222 | { 223 | "attachments": {}, 224 | "cell_type": "markdown", 225 | "metadata": {}, 226 | "source": [ 227 | "We're using the online scheme, sampling directly from the model we'll be tuning —it's also very possible to rely on the offline scheme, see the [Tuning notebook](./3.tuning_DPG.ipynb)." 
228 | ] 229 | }, 230 | { 231 | "cell_type": "code", 232 | "execution_count": null, 233 | "metadata": {}, 234 | "outputs": [], 235 | "source": [ 236 | "model = LMDistribution(model=\"t5-base\", tokenizer=\"t5-base\", auto=AutoModelForSeq2SeqLM, length=256, freeze=False, )" 237 | ] 238 | }, 239 | { 240 | "attachments": {}, 241 | "cell_type": "markdown", 242 | "metadata": {}, 243 | "source": [ 244 | "\n", 245 | "\n", 246 | "We can now instantiate a tuner. We're going:\n", 247 | " * to tune model to approximate target getting our samples from the model itself;\n", 248 | " * to use a context distribution to fetch articles from the CNN / Dailymail dataset —all prepended with the task prefix \"summarize: \" to control T5." 249 | ] 250 | }, 251 | { 252 | "cell_type": "code", 253 | "execution_count": null, 254 | "metadata": {}, 255 | "outputs": [], 256 | "source": [ 257 | "from disco.tuners import CDPGTuner" 258 | ] 259 | }, 260 | { 261 | "cell_type": "code", 262 | "execution_count": null, 263 | "metadata": {}, 264 | "outputs": [], 265 | "source": [ 266 | "tuner = CDPGTuner(model, target,\n", 267 | " context_distribution=DatasetContextDistribution(\n", 268 | " dataset=\"cnn_dailymail\", subset=\"1.0.0\", split=\"train\", key=\"article\", prefix=\"summarize: \"),\n", 269 | " n_gradient_steps=1000,\n", 270 | " n_samples_per_context=2**8,\n", 271 | " sampling_size=2**5,\n", 272 | " scoring_size=2**5)" 273 | ] 274 | }, 275 | { 276 | "attachments": {}, 277 | "cell_type": "markdown", 278 | "metadata": {}, 279 | "source": [ 280 | "Of course we want to monitor the progress so we use a logger." 
281 | ] 282 | }, 283 | { 284 | "cell_type": "code", 285 | "execution_count": null, 286 | "metadata": {}, 287 | "outputs": [], 288 | "source": [ 289 | "from disco.tuners.loggers.console import ConsoleLogger" 290 | ] 291 | }, 292 | { 293 | "cell_type": "code", 294 | "execution_count": null, 295 | "metadata": {}, 296 | "outputs": [], 297 | "source": [ 298 | "ConsoleLogger(tuner)" 299 | ] 300 | }, 301 | { 302 | "attachments": {}, 303 | "cell_type": "markdown", 304 | "metadata": {}, 305 | "source": [ 306 | "Note that to instead / also use a `NeptuneLogger` we can simply uncomment the following cell, assuming we've actually setup to use the service." 307 | ] 308 | }, 309 | { 310 | "cell_type": "code", 311 | "execution_count": null, 312 | "metadata": {}, 313 | "outputs": [], 314 | "source": [ 315 | "# from disco.tuners.loggers.neptune import NeptuneLogger\n", 316 | "# import os\n", 317 | "# NEPTUNE_API_TOKEN = os.environ[\"NEPTUNE_API_TOKEN\"]\n", 318 | "# NeptuneLogger(tuner,\n", 319 | "# project=\"disco\", api_token=NEPTUNE_API_TOKEN\n", 320 | "# )" 321 | ] 322 | }, 323 | { 324 | "attachments": {}, 325 | "cell_type": "markdown", 326 | "metadata": {}, 327 | "source": [ 328 | "Let's dance!" 
329 | ] 330 | }, 331 | { 332 | "cell_type": "code", 333 | "execution_count": null, 334 | "metadata": {}, 335 | "outputs": [], 336 | "source": [ 337 | "tuner.tune()" 338 | ] 339 | }, 340 | { 341 | "cell_type": "code", 342 | "execution_count": null, 343 | "metadata": {}, 344 | "outputs": [], 345 | "source": [] 346 | } 347 | ], 348 | "metadata": { 349 | "kernelspec": { 350 | "display_name": "Python 3.10.6 (conda)", 351 | "language": "python", 352 | "name": "python3" 353 | }, 354 | "language_info": { 355 | "codemirror_mode": { 356 | "name": "ipython", 357 | "version": 3 358 | }, 359 | "file_extension": ".py", 360 | "mimetype": "text/x-python", 361 | "name": "python", 362 | "nbconvert_exporter": "python", 363 | "pygments_lexer": "ipython3", 364 | "version": "3.10.6 | packaged by conda-forge | (main, Aug 22 2022, 20:38:29) [Clang 13.0.1 ]" 365 | }, 366 | "vscode": { 367 | "interpreter": { 368 | "hash": "babb4baf4e80bd80b9852210fc5469c0783907e52a560ed7247caef52808358d" 369 | } 370 | } 371 | }, 372 | "nbformat": 4, 373 | "nbformat_minor": 2 374 | } 375 | -------------------------------------------------------------------------------- /tutorials/3.tuning_DPG.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "attachments": {}, 5 | "cell_type": "markdown", 6 | "metadata": {}, 7 | "source": [ 8 | "disco \n", 9 | "Copyright (C) 2022-present NAVER Corp. \n", 10 | "Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license " 11 | ] 12 | }, 13 | { 14 | "cell_type": "markdown", 15 | "metadata": {}, 16 | "source": [ 17 | "# Tuning with DPG" 18 | ] 19 | }, 20 | { 21 | "cell_type": "markdown", 22 | "metadata": {}, 23 | "source": [ 24 | "Once we have expressed our preferences on the generated sequences, through an Energy-Based Model (EBM), we cannot directly sample from it. What we can do is approximate it by fine-tuning a model. 
\n", 25 | "Let's first see the case of classic, unconditional, i.e. with a fixed context, DPG." 26 | ] 27 | }, 28 | { 29 | "cell_type": "markdown", 30 | "metadata": {}, 31 | "source": [ 32 | "## Expressing Preferences" 33 | ] 34 | }, 35 | { 36 | "cell_type": "markdown", 37 | "metadata": {}, 38 | "source": [ 39 | "Let's stick with our _amazing_ use case: we want the word `amazing` to appear in our samples —see the [Expressing Preference](./2.expressing_preferences.ipynb) notebook for the detailed explanations." 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": null, 45 | "metadata": {}, 46 | "outputs": [], 47 | "source": [ 48 | "from disco.scorers import BooleanScorer" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": null, 54 | "metadata": {}, 55 | "outputs": [], 56 | "source": [ 57 | "import re\n", 58 | "\n", 59 | "is_amazing = lambda s, c: bool(re.search(r\"\\bamazing\\b\", s.text))\n", 60 | "amazing_scorer = BooleanScorer(is_amazing)" 61 | ] 62 | }, 63 | { 64 | "cell_type": "code", 65 | "execution_count": null, 66 | "metadata": {}, 67 | "outputs": [], 68 | "source": [ 69 | "from disco.distributions import LMDistribution" 70 | ] 71 | }, 72 | { 73 | "cell_type": "markdown", 74 | "metadata": {}, 75 | "source": [ 76 | "for a pointwise constraint:" 77 | ] 78 | }, 79 | { 80 | "cell_type": "code", 81 | "execution_count": null, 82 | "metadata": {}, 83 | "outputs": [], 84 | "source": [ 85 | "base = LMDistribution()\n", 86 | "pw_target = base * amazing_scorer" 87 | ] 88 | }, 89 | { 90 | "cell_type": "markdown", 91 | "metadata": {}, 92 | "source": [ 93 | "for a distributional one:" 94 | ] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "execution_count": null, 99 | "metadata": {}, 100 | "outputs": [], 101 | "source": [ 102 | "from disco.distributions.single_context_distribution import SingleContextDistribution" 103 | ] 104 | }, 105 | { 106 | "cell_type": "code", 107 | "execution_count": null, 108 | "metadata": {}, 109 | "outputs": [], 110 | 
"source": [ 111 | "incipit = \"It was a cold and stormy night\"" 112 | ] 113 | }, 114 | { 115 | "cell_type": "code", 116 | "execution_count": null, 117 | "metadata": {}, 118 | "outputs": [], 119 | "source": [ 120 | "dc_target = base.constrain([amazing_scorer], [1/2],\n", 121 | " n_samples=2**10,\n", 122 | " context_distribution=SingleContextDistribution(incipit))" 123 | ] 124 | }, 125 | { 126 | "cell_type": "markdown", 127 | "metadata": {}, 128 | "source": [ 129 | "## Tuning" 130 | ] 131 | }, 132 | { 133 | "cell_type": "markdown", 134 | "metadata": {}, 135 | "source": [ 136 | "We then instantiate the model we want to tune —we'll tune the \"network\" inside the distribution." 137 | ] 138 | }, 139 | { 140 | "cell_type": "code", 141 | "execution_count": null, 142 | "metadata": {}, 143 | "outputs": [], 144 | "source": [ 145 | "model = LMDistribution(freeze=False)" 146 | ] 147 | }, 148 | { 149 | "cell_type": "markdown", 150 | "metadata": {}, 151 | "source": [ 152 | "Let's check the initial rate for our constraint." 
153 | ] 154 | }, 155 | { 156 | "cell_type": "code", 157 | "execution_count": null, 158 | "metadata": {}, 159 | "outputs": [], 160 | "source": [ 161 | "from disco.samplers import AccumulationSampler" 162 | ] 163 | }, 164 | { 165 | "cell_type": "code", 166 | "execution_count": null, 167 | "metadata": {}, 168 | "outputs": [], 169 | "source": [ 170 | "sampler = AccumulationSampler(model, total_size=2**9)\n", 171 | "samples, log_scores = sampler.sample(context=incipit)" 172 | ] 173 | }, 174 | { 175 | "cell_type": "code", 176 | "execution_count": null, 177 | "metadata": {}, 178 | "outputs": [], 179 | "source": [ 180 | "sum([is_amazing(s, _) for s in samples]) / len(samples)" 181 | ] 182 | }, 183 | { 184 | "cell_type": "markdown", 185 | "metadata": {}, 186 | "source": [ 187 | "### Offline" 188 | ] 189 | }, 190 | { 191 | "cell_type": "markdown", 192 | "metadata": {}, 193 | "source": [ 194 | "In the offline scheme, we use a companion proposal distribution to sample from, and update that proposal, eventually, during the tuning." 195 | ] 196 | }, 197 | { 198 | "cell_type": "code", 199 | "execution_count": null, 200 | "metadata": {}, 201 | "outputs": [], 202 | "source": [ 203 | "proposal = LMDistribution()" 204 | ] 205 | }, 206 | { 207 | "cell_type": "markdown", 208 | "metadata": {}, 209 | "source": [ 210 | "We can now instantiate a tuner. We're going:\n", 211 | " * to tune model to approximate dc_target getting our samples from proposal;\n", 212 | " * to use a fixed incipit for the context;\n", 213 | " * to check the divergence every `divergence_evaluation_interval` gradient steps, when we'll also eventually update the proposal." 
214 | ] 215 | }, 216 | { 217 | "cell_type": "code", 218 | "execution_count": null, 219 | "metadata": {}, 220 | "outputs": [], 221 | "source": [ 222 | "from disco.tuners import DPGTuner" 223 | ] 224 | }, 225 | { 226 | "cell_type": "code", 227 | "execution_count": null, 228 | "metadata": {}, 229 | "outputs": [], 230 | "source": [ 231 | "tuner = DPGTuner(model, dc_target, proposal,\n", 232 | " context=incipit,\n", 233 | " n_gradient_steps=1000,\n", 234 | " n_samples_per_context=2**8,\n", 235 | " sampling_size=2**5,\n", 236 | " scoring_size=2**5,\n", 237 | " divergence_evaluation_interval=2**2,\n", 238 | " n_kl_samples=2**10)" 239 | ] 240 | }, 241 | { 242 | "cell_type": "markdown", 243 | "metadata": {}, 244 | "source": [ 245 | "There are loggers we can use to monitor the tuning. They are built on the observer pattern so it's easy to add more specific ones —although beyond the simple `ConsoleLogger` disco provides loggers for Neptune, Weights & Biases, ..." 246 | ] 247 | }, 248 | { 249 | "cell_type": "code", 250 | "execution_count": null, 251 | "metadata": {}, 252 | "outputs": [], 253 | "source": [ 254 | "from disco.tuners.loggers.console import ConsoleLogger" 255 | ] 256 | }, 257 | { 258 | "cell_type": "code", 259 | "execution_count": null, 260 | "metadata": {}, 261 | "outputs": [], 262 | "source": [ 263 | "ConsoleLogger(tuner)" 264 | ] 265 | }, 266 | { 267 | "cell_type": "markdown", 268 | "metadata": {}, 269 | "source": [ 270 | "Let's dance!" 271 | ] 272 | }, 273 | { 274 | "cell_type": "code", 275 | "execution_count": null, 276 | "metadata": { 277 | "scrolled": false 278 | }, 279 | "outputs": [], 280 | "source": [ 281 | "tuner.tune()" 282 | ] 283 | }, 284 | { 285 | "cell_type": "markdown", 286 | "metadata": {}, 287 | "source": [ 288 | "Are we doing better?" 
289 | ] 290 | }, 291 | { 292 | "cell_type": "code", 293 | "execution_count": null, 294 | "metadata": {}, 295 | "outputs": [], 296 | "source": [ 297 | "sampler = AccumulationSampler(model, total_size=512)\n", 298 | "samples, log_scores = sampler.sample(context=incipit)" 299 | ] 300 | }, 301 | { 302 | "cell_type": "code", 303 | "execution_count": null, 304 | "metadata": {}, 305 | "outputs": [], 306 | "source": [ 307 | "sum([is_amazing(s, _) for s in samples]) / len(samples)" 308 | ] 309 | }, 310 | { 311 | "cell_type": "markdown", 312 | "metadata": {}, 313 | "source": [ 314 | "### Online tuning" 315 | ] 316 | }, 317 | { 318 | "cell_type": "markdown", 319 | "metadata": {}, 320 | "source": [ 321 | "In the online scheme, the model being tuned is also the one providing the samples, so we don't need a proposal." 322 | ] 323 | }, 324 | { 325 | "cell_type": "code", 326 | "execution_count": null, 327 | "metadata": {}, 328 | "outputs": [], 329 | "source": [ 330 | "model = LMDistribution(freeze=False)" 331 | ] 332 | }, 333 | { 334 | "attachments": {}, 335 | "cell_type": "markdown", 336 | "metadata": {}, 337 | "source": [ 338 | "_Note that, for an actual tuning, you might want to move the networks to GPUs first, for example with:_\n", 339 | "```\n", 340 | "model.to(\"cuda\")\n", 341 | "dc_target.scorers[0].to(\"cuda\")\n", 342 | "```" 343 | ] 344 | }, 345 | { 346 | "cell_type": "code", 347 | "execution_count": null, 348 | "metadata": {}, 349 | "outputs": [], 350 | "source": [ 351 | "tuner = DPGTuner(model, dc_target,\n", 352 | " context=incipit,\n", 353 | " n_gradient_steps=100,\n", 354 | " n_samples_per_context=2**8,\n", 355 | " sampling_size=2**5,\n", 356 | " scoring_size=2**5,\n", 357 | " divergence_evaluation_interval=1)" 358 | ] 359 | }, 360 | { 361 | "attachments": {}, 362 | "cell_type": "markdown", 363 | "metadata": {}, 364 | "source": [ 365 | "_Again, for an actual tuning, you might want to initiate logging, for example with:_\n", 366 | "```\n", 367 | "from 
disco.tuners.loggers.wandb import WandBLogger\n", 368 | "logger = WandBLogger(tuner, \"my_project\", \"my_run\")\n", 369 | "```" 370 | ] 371 | }, 372 | { 373 | "cell_type": "code", 374 | "execution_count": null, 375 | "metadata": {}, 376 | "outputs": [], 377 | "source": [ 378 | "tuner.tune()" 379 | ] 380 | }, 381 | { 382 | "cell_type": "code", 383 | "execution_count": null, 384 | "metadata": {}, 385 | "outputs": [], 386 | "source": [ 387 | "sampler = AccumulationSampler(model, total_size=2**9)\n", 388 | "samples, log_scores = sampler.sample(context=incipit)" 389 | ] 390 | }, 391 | { 392 | "cell_type": "code", 393 | "execution_count": null, 394 | "metadata": {}, 395 | "outputs": [], 396 | "source": [ 397 | "sum([is_amazing(s, _) for s in samples]) / len(samples)" 398 | ] 399 | }, 400 | { 401 | "cell_type": "code", 402 | "execution_count": null, 403 | "metadata": {}, 404 | "outputs": [], 405 | "source": [] 406 | } 407 | ], 408 | "metadata": { 409 | "kernelspec": { 410 | "display_name": "Python 3", 411 | "language": "python", 412 | "name": "python3" 413 | }, 414 | "language_info": { 415 | "codemirror_mode": { 416 | "name": "ipython", 417 | "version": 3 418 | }, 419 | "file_extension": ".py", 420 | "mimetype": "text/x-python", 421 | "name": "python", 422 | "nbconvert_exporter": "python", 423 | "pygments_lexer": "ipython3", 424 | "version": "3.10.6 | packaged by conda-forge | (main, Aug 22 2022, 20:38:29) [Clang 13.0.1 ]" 425 | }, 426 | "vscode": { 427 | "interpreter": { 428 | "hash": "babb4baf4e80bd80b9852210fc5469c0783907e52a560ed7247caef52808358d" 429 | } 430 | } 431 | }, 432 | "nbformat": 4, 433 | "nbformat_minor": 2 434 | } 435 | -------------------------------------------------------------------------------- /tests/lm_distribution_test.py: -------------------------------------------------------------------------------- 1 | # disco 2 | # Copyright (C) 2022-present NAVER Corp. 
3 | # Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license 4 | 5 | import unittest 6 | import random, torch 7 | import numpy as np 8 | from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM 9 | 10 | from disco.distributions import LMDistribution 11 | from disco.distributions.lm_distribution import TextSample 12 | 13 | prefix = "It was a cold and stormy night" 14 | 15 | class LMDistributionTest(unittest.TestCase): 16 | 17 | def setUp(self): 18 | test_models = [("gpt2", AutoModelForCausalLM), 19 | ("facebook/bart-base", AutoModelForSeq2SeqLM)] 20 | self.test_distributions = {model: LMDistribution(model, auto=auto) 21 | for (model, auto) in test_models} 22 | 23 | def test_instantiate_a_default_distribution(self): 24 | distribution = LMDistribution() 25 | self.assertTrue(hasattr(distribution, "length"), 26 | "the distribution should have a length attribute.") 27 | self.assertEqual(40, distribution.length, 28 | "the default length should be 40.") 29 | self.assertTrue(hasattr(distribution, "params"), 30 | "the distribution should have a params attribute.") 31 | self.assertTrue(hasattr(distribution, "scorable"), 32 | "the distribution should have a scorable attribute.") 33 | self.assertEqual(True, distribution.scorable, 34 | "the distribution should be scorable by default.") 35 | 36 | def test_sample_a_continuation_from_a_default_distribution(self): 37 | distribution = LMDistribution() 38 | samples, log_scores = distribution.sample(context=prefix) 39 | self.assertTrue(isinstance(samples, list), "the samples should be returned as a list.") 40 | from disco.distributions.lm_distribution import TextSample 41 | for sample in samples: 42 | self.assertTrue(isinstance(sample, TextSample), "each text should be a textSample.") 43 | self.assertTrue(isinstance(log_scores, torch.Tensor), "the log-scores should be returned as a tensor.") 44 | for log_score in log_scores: 45 | self.assertTrue(isinstance(log_score.item(), float), "each log-score should be a 
float.") 46 | 47 | def test_sample_multiple_continuations_for_a_prefix(self): 48 | distribution = LMDistribution() 49 | sampling_size = 8 50 | samples, log_scores = distribution.sample(sampling_size=sampling_size) 51 | self.assertEqual(sampling_size, len(samples), 52 | "the number of returned samples should be {}.".format(sampling_size)) 53 | self.assertEqual(sampling_size, len(log_scores), 54 | "the number of returned log_scores should be {}.".format(sampling_size)) 55 | 56 | def test_sample_a_continuation_with_temperature_close_to_zero(self): 57 | distribution = LMDistribution(temperature=0.001) 58 | samples, log_scores = distribution.sample(sampling_size=2, context=prefix) 59 | self.assertEqual(samples[0].text, samples[1].text, 60 | "samples should not vary when temperature is (almost) zero.") 61 | self.assertEqual(log_scores[0], log_scores[1], 62 | "log_scores should not vary when temperature is (almost) zero.") 63 | 64 | def test_score_sequences(self): 65 | prompt = "It was a cold and stormy night" 66 | distribution = LMDistribution() 67 | texts = [ 68 | " and the streets were empty and quiet.", 69 | "; the rain fell in torrents" 70 | ] 71 | tokenized_texts = distribution.tokenizer(texts, return_tensors="pt", add_special_tokens=True, padding=True) 72 | from disco.distributions.lm_distribution import TextSample 73 | samples = [TextSample(s, t) for (s, t) in zip(tokenized_texts["input_ids"], texts)] 74 | log_scores = distribution.log_score(samples, context=prefix) 75 | self.assertTrue(isinstance(log_scores, torch.Tensor), "the log-scores should be returned as a tensor.") 76 | self.assertEqual(len(samples), len(log_scores), "there should be as many log-scores as there are sequences.") 77 | for log_score in log_scores: 78 | self.assertTrue(isinstance(log_score.item(), float), "each log-score should be a float.") 79 | self.assertLess(log_score, 0.0, "each log-score should be negative.") 80 | 81 | def test_score_sampled_sequences(self): 82 | for model, distribution 
in self.test_distributions.items(): 83 | prompt = "It was a cold and stormy night" 84 | distribution = LMDistribution() 85 | samples, log_scores = distribution.sample(context=prefix) 86 | samples_log_scores_again = distribution.log_score(samples, context=prefix) 87 | self.assertLess((log_scores - samples_log_scores_again).abs().max(), 88 | 1e-3, "log-scores given at sampling time and the scoring of " 89 | "the same samples should match") 90 | 91 | def test_score_with_a_non_default_top_k_specified(self): 92 | distribution = LMDistribution(top_k=20) 93 | samples, _ = distribution.sample(context=prefix) 94 | self.assertRaises(AssertionError, distribution.log_score, samples, context=prefix) 95 | 96 | def test_score_with_a_non_default_top_p_specified(self): 97 | distribution = LMDistribution(top_p=0.92) 98 | samples, _ = distribution.sample(context=prefix) 99 | self.assertRaises(AssertionError, distribution.log_score, samples, context=prefix) 100 | 101 | def test_score_with_a_non_default_typical_p_specified(self): 102 | distribution = LMDistribution(typical_p=0.9) 103 | samples, _ = distribution.sample(context=prefix) 104 | self.assertRaises(AssertionError, distribution.log_score, samples, context=prefix) 105 | 106 | def test_score_with_a_non_default_temperature_specified(self): 107 | distribution = LMDistribution(temperature=0.7) 108 | samples, _ = distribution.sample(context=prefix) 109 | self.assertRaises(AssertionError, distribution.log_score, samples, context=prefix) 110 | 111 | def test_score_sequences_different_length(self): 112 | distribution = LMDistribution() 113 | texts = [ 114 | " and the streets were empty and quiet.", 115 | " in Brest." 
116 | ] 117 | 118 | from disco.distributions.lm_distribution import TextSample 119 | samples = [ 120 | TextSample(distribution.tokenizer(t, return_tensors="pt", add_special_tokens=True)["input_ids"], t)\ 121 | for t in texts 122 | ] 123 | self.assertRaises(AssertionError, distribution.log_score, samples, context=prefix) 124 | 125 | def test_ignore_padding_at_score(self): 126 | for model, distribution in self.test_distributions.items(): 127 | text = "The streets were empty and quiet." 128 | context = "This is some context." if distribution.network.config.is_encoder_decoder else "" 129 | 130 | from disco.distributions.lm_distribution import TextSample 131 | sample = TextSample(distribution.tokenizer(text, return_tensors="pt", add_special_tokens=False)["input_ids"][0], text) 132 | eos_token_id = distribution.tokenizer.eos_token_id 133 | pad_token_id = distribution.tokenizer.pad_token_id 134 | 135 | eos_token = distribution.tokenizer.eos_token 136 | pad_token = distribution.tokenizer.pad_token 137 | completed_sample = TextSample( 138 | token_ids=torch.cat((sample.token_ids, torch.tensor([eos_token_id])), dim=0), 139 | text=text + eos_token) 140 | 141 | padded_sample = TextSample( 142 | token_ids=torch.cat((sample.token_ids, torch.tensor([eos_token_id, pad_token_id])), dim=0), 143 | text=text + pad_token + pad_token) 144 | 145 | log_scores = distribution.log_score([sample], context=context).item() 146 | completed_log_scores = distribution.log_score([completed_sample], context=context).item() 147 | padded_log_scores = distribution.log_score([padded_sample], context=context).item() 148 | 149 | self.assertNotEqual(0, log_scores) 150 | if completed_log_scores != float("-inf") and completed_log_scores != float("-inf"): 151 | self.assertNotAlmostEqual(log_scores, completed_log_scores, 4, 152 | f"The {model} scores must be different with and without and eos token.") 153 | self.assertAlmostEqual(completed_log_scores, padded_log_scores, 4, 154 | f"The {model} scores must not be 
different with and without a pad token after an eos token.") 155 | 156 | def test_ignore_padding_at_sample(self): 157 | distribution = LMDistribution() 158 | seed = 0 159 | torch.manual_seed(seed) 160 | np.random.seed(seed) 161 | random.seed(seed) 162 | samples, log_scores = distribution.sample() 163 | padded_sequence_id = 16 # found manually 164 | padded_sequence = samples[padded_sequence_id] 165 | padded_sequence_log_score = log_scores[padded_sequence_id] 166 | pad_token_id = distribution.tokenizer.pad_token_id 167 | self.assertTrue((padded_sequence.token_ids[1:] == pad_token_id).all()) 168 | padded_sequence_relog_score = distribution.log_score([padded_sequence]) 169 | self.assertLess(torch.abs(padded_sequence_log_score - padded_sequence_relog_score), 1e-2) 170 | 171 | def test_empty_sequence_score(self): 172 | for model, distribution in self.test_distributions.items(): 173 | pad_token_id = distribution.tokenizer.pad_token_id 174 | empty_sequence1_tok = [pad_token_id]*1 175 | empty_sequence1_str = distribution.tokenizer.decode(empty_sequence1_tok) 176 | empty_sequence1 = TextSample(token_ids=torch.tensor(empty_sequence1_tok), 177 | text=empty_sequence1_str) 178 | empty_sequence2_tok = [pad_token_id]*20 179 | empty_sequence2_str = distribution.tokenizer.decode(empty_sequence2_tok) 180 | empty_sequence2 = TextSample(token_ids=torch.tensor(empty_sequence2_tok), 181 | text=empty_sequence2_str) 182 | if distribution.network.config.is_encoder_decoder: 183 | log_score1 = distribution.log_score([empty_sequence1], context='yada yada') 184 | log_score2 = distribution.log_score([empty_sequence2], context='yada yada') 185 | else: 186 | log_score1 = distribution.log_score([empty_sequence1]) 187 | log_score2 = distribution.log_score([empty_sequence2]) 188 | self.assertAlmostEqual(log_score1.item(), log_score2.item(), 4, 189 | f"The {model} score of all pad tokens should correspond to a single one.") 190 | 191 | def test_scoring_consistent_with_loss(self): 192 | for model, 
distribution in self.test_distributions.items(): 193 | text = "The streets were empty and quiet." 194 | context = "This is an example." if distribution.network.config.is_encoder_decoder else "" 195 | 196 | from disco.distributions.lm_distribution import TextSample 197 | sample = TextSample(distribution.tokenizer(text, return_tensors="pt", add_special_tokens=True)["input_ids"][0], text) 198 | log_score = distribution.log_score([sample], context=context) 199 | if distribution.network.config.is_encoder_decoder: 200 | ctxt = distribution.tokenizer(context, return_tensors="pt", add_special_tokens=True) 201 | loss = distribution.network(input_ids=ctxt.input_ids, 202 | labels=sample.token_ids.unsqueeze(0)).loss 203 | else: 204 | loss = distribution.network(sample.token_ids, labels=sample.token_ids).loss 205 | self.assertLess(torch.abs(-log_score / len(sample.token_ids) - loss), 1e-2, 206 | f'The {model} score is not consistent with the loss reported by the forward function.') 207 | 208 | def test_sample_empty_sequence(self): 209 | distribution = LMDistribution("gpt2") 210 | pad_token_id = distribution.tokenizer.pad_token_id 211 | epsilon = 0.05 212 | # increase the likelihood of sampling pad_token_id 213 | distribution.network.lm_head.weight.data[pad_token_id, :] += epsilon 214 | seed = 1 215 | torch.manual_seed(seed) 216 | np.random.seed(seed) 217 | random.seed(seed) 218 | samples, log_scores = distribution.sample() 219 | self.assertTrue(any((s.token_ids == pad_token_id).all() for s in samples)) 220 | new_log_scores = distribution.log_score(samples) 221 | self.assertLess(torch.abs(log_scores - new_log_scores).max(), 1e-4) 222 | 223 | def test_freeze_parameters_by_default(self): 224 | distribution = LMDistribution() 225 | self.assertTrue(all([not p.requires_grad for p in distribution.network.parameters()])) 226 | 227 | def test_unfreeze_parameters_on_demand(self): 228 | distribution = LMDistribution() 229 | distribution.freeze(False) 230 | 
self.assertTrue(all([p.requires_grad for p in distribution.network.parameters()])) 231 | 232 | def test_unfreeze_all_parameters_with_parameter(self): 233 | distribution = LMDistribution(freeze=False) 234 | self.assertTrue(all([p.requires_grad for p in distribution.network.parameters()])) 235 | 236 | def test_freeze_all_parameters_on_demand(self): 237 | distribution = LMDistribution(freeze=False) 238 | distribution.freeze() 239 | self.assertTrue(all([not p.requires_grad for p in distribution.network.parameters()])) 240 | 241 | def test_clone(self): 242 | distribution = LMDistribution() 243 | distribution2 = distribution.clone() 244 | self.assertEqual(distribution.network.config, distribution2.network.config) 245 | self.assertNotEqual(distribution.network, distribution2.network) 246 | 247 | def test_bart_consistency(self): 248 | model = LMDistribution( 249 | model="facebook/bart-large-cnn", 250 | auto=AutoModelForSeq2SeqLM, 251 | length=128) 252 | 253 | context="""Self-trained autonomous agents developed using machine learning are showing great promise in a variety of control settings, 254 | perhaps most remarkably in applications involving autonomous vehicles. The main challenge associated with self-learned agents in the form 255 | of deep neural networks, is their black-box nature: it is impossible for humans to interpret deep neural networks. Therefore, humans cannot 256 | directly interpret the actions of deep neural network based agents, or foresee their robustness in different scenarios. In this work, 257 | we demonstrate a method for probing which concepts self-learning agents internalise in the course of their training. 
For demonstration, 258 | we use a chess playing agent in a fast and light environment developed specifically to be suitable for research groups without access to 259 | enormous computational resources or machine learning models.""" 260 | torch.manual_seed(1) 261 | samples, log_scores = model.sample(context=context, sampling_size=1, sum=False) 262 | ls_log_scores = model.log_score(samples, context=context, sum=False) 263 | sample_idx = 0 264 | self.assertAlmostEqual(log_scores.sum().item(), ls_log_scores.sum().item(), 4) 265 | 266 | if __name__ == '__main__': 267 | unittest.main() 268 | -------------------------------------------------------------------------------- /README.MD: -------------------------------------------------------------------------------- 1 | # 🕺🏽 disco: A Toolkit for Distributional Control of Generative Models 2 | 3 | The 🕺🏽 **disco** toolkit allows to control language models and other generative systems to match human preferences while avoiding catastrophic forgetting. 4 | 5 | To achieve this, **disco** decouples the problem of expressing _what_ properties the model should have from _how_ to actually get the desired model as separate steps. 6 | 7 | **Step 1: ⚓ We express how the target distribution *should* be** 8 | 9 | First, we define some feature over the generated samples that matters to us. It can be anything we can compute. For example, on a language model it can be as simple as whether the generated text contains a certain word or as complex as the compilability of some generated piece of code. Importantly, there is no need for the feature to be differentiable. 10 | 11 | Then, we can express our preferences on the target distribution by deciding how prevalent the feature should be. For example, we might want to ask that a certain word appears 50% of the time when sampling from the model; or that 100% of the generated code is compilable. 
The resulting target distribution is expressed as an energy-based model or EBM, which is an unnormalized probability distribution that respects the desired moments while avoiding catastrophic forgetting, as a result of having minimal KL divergence to the original model. 12 | 13 | The resulting representation of the target distribution can *score* samples, but cannot directly be used to *generate* them. 14 | 15 | **Step 2: 🎯 Approximate the target distribution** 16 | 17 | To generate samples from the target distribution we can tune a model to approximate it. We do this by minimizing the divergence to the target distribution. While techniques such as reinforcement learning from human feedback (RLHF) are restricted to using one kind of divergence only (specifically, reverse KL divergence), **disco** is more general, allowing the use of the full class of f-divergences, including both forward and reverse KL divergence, Jensen-Shannon, and total variation distance. 18 | 19 | **Step 3: 💬 Generate content that matches the preferences** 20 | 21 | The resulting model can generate samples directly from a close approximation of the target distribution. Furthermore, it can be used jointly with Quasi-Rejection Sampling (QRS), a Monte Carlo sampling technique that allows the generation of samples that are even more representative of the target distribution. 22 | Alternatively, it is then possible to use decoding methods such as nucleus sampling, top-k sampling, or beam search, which would return samples from a further updated target distribution. 23 | 24 | See the references below for more theoretical and technical details. 
25 | 26 | ## Installation 27 | 28 | ### Standard installation 29 | 30 | The easiest way to install **disco** is to rely on pip, asking for the ```disco-generation``` package: 31 | 32 | ``` 33 | pip install disco-generation 34 | ``` 35 | 36 | Note that the toolkit: 37 | - depends on PyTorch, 38 | - uses HuggingFace's Transformers library for the generic handling of language models, as well as the generation of samples from them. 39 | 40 | ### Toolkit Developers 41 | 42 | If we plan to extend the toolkit we will need to clone and to install it as a local package. 43 | From the toolkit top folder, once we've git-cloned the repository and activated our development environment, we simply do: 44 | ``` 45 | pip install -e . 46 | ``` 47 | 48 | ## Quick introduction 49 | 50 | 51 | ### Distributions 52 | 53 | The generative model that we want to tune must be wrapped by a `Distribution` object. For example, for a (causal or seq2seq) language model compatible with the 🤗 Hugging Face interface use an `LMDistribution`. 54 | 55 | A valid `Distribution` must have the following two methods: 56 | - `.sample(context)` that given an optional `context` on which the distribution can be conditioned, returns a list of samples from the underlying distribution and a tensor with their corresponding log-probabilities; 57 | - `.log_score(samples, context)` that given a list of samples and the `context` on which to condition the distribution, returns their corresponding log-probabilities. 58 | 59 | ```python 60 | from disco.distributions import LMDistribution 61 | distribution = LMDistribution() 62 | 63 | incipit = "It was a cold and stormy night" 64 | samples, log_scores = distribution.sample(context=incipit) 65 | 66 | distribution.log_score(samples, context=incipit) 67 | ``` 68 | 69 | `LMDistribution` generates samples of the `TextSample` type, which are named tuples with both `text` and `token_ids` fields.
70 | 71 | From now on, after this initial example, imports will be skipped for clarity. 72 | 73 | ### Features 74 | 75 | Features are represented by an object with the method 76 | 77 | - `.score(samples, context)` which given a list of samples and an optional context returns a tensor of real-valued scores. 78 | 79 | A convenient way to define one is using the `Scorer` class, which accepts a function or a lambda abstraction that takes a sample and a context, and vectorizes it. For example, we can compute the effective length of a GPT-2 text sample by finding the eos token: 80 | 81 | ```python 82 | sequence_length = Scorer(lambda s, c: s.text.index("<|endoftext|>")) 83 | ``` 84 | 85 | Where `s` is the sample (assumed to be a `TextSample`) and `c` is an optional context. 86 | 87 | #### Boolean Features 88 | 89 | An important class of features are *boolean* features. While general features can only be used to define *distributional* constraints, boolean features can also be used to define *pointwise* constraints, see below. To define one, we can use the `BooleanScorer` helper class, which takes a function as an argument. For example, we can score the presence of the string "amazing", as follows: 90 | 91 | ```python 92 | amazing = BooleanScorer(lambda s, c: "amazing" in s.text) 93 | ``` 94 | 95 | The ```False```/```True``` results from the lambda are cast to `0.0`/`1.0` float values so that they can be used in the EBM definition. 96 | 97 | `BooleanScorer` belongs to the more general `PositiveScorer` class, which can be used to construct EBMs. The main properties of a `PositiveScorer` are that first, it returns positive scores, and second that it provides the method: 98 | 99 | - `.log_score(samples, context)` that given a list of samples and the `context` on which to condition the distribution, returns their corresponding log-probabilities. 100 | 101 | As a consequence, we can see that a ```Distribution``` is also a ```PositiveScorer``` that is able to sample as well.
102 | 103 | 104 | ### Controlling Generation 105 | 106 | #### Expressing preferences through an EBM 107 | 108 | We express preferences over the distribution by defining target moments for specific features. This results in a target distribution that matches the desired moments while minimizing the KL divergence to the original distribution. In other words, it incorporates the preferences while avoiding catastrophic forgetting. This distribution is represented as an EBM, which can be used to score samples (in other words, it is a `PositiveScorer`) but cannot be used to sample; we'll see how to sample below. 109 | 110 | We can express either *pointwise* or *distributional* constraints on a distribution and compose them at will. The former expresses a (boolean) property that must apply to *all* sequences, whereas the latter represents properties at the distributional level. 111 | 112 | To obtain the target distribution that incorporates our constraints, we use the `constrain` method of the corresponding `Distribution`. This method takes a list of features and their corresponding target moments. 113 | 114 | For example, we can define an EBM with a *pointwise* constraint requiring that all our samples must include "amazing" by setting the target moment to `1` on a `BooleanScorer`: 115 | 116 | ```python 117 | target_ebm = base.constrain([amazing], [1]) 118 | ``` 119 | 120 | Or we can ask for a _distributional_ constraint requiring that _half_ of the samples include "amazing": 121 | 122 | ```python 123 | target_ebm = base.constrain([amazing], [1/2]) 124 | ``` 125 | 126 | 127 | #### Approximating the target EBM 128 | 129 | 130 | Given an EBM target distribution, we now want to train a model to approximate it so that we can use it to generate samples. In the _unconditional_ case, namely when there is a single fixed context used in generation, we can use a `Tuner`, more specifically a ```DPGTuner```, as follows.
131 | 132 | 133 | ```python 134 | target_ebm = base.constrain([amazing], [1]) 135 | 136 | model = LMDistribution(freeze=False) 137 | incipit = "It was a cold and stormy night" 138 | 139 | tuner = DPGTuner(model, target_ebm, context=incipit) 140 | tuner.tune() 141 | ``` 142 | 143 | And we can sample _amazing_ sequences from the tuned model. 144 | ```python 145 | samples, log_scores = model.sample(context=incipit) 146 | for s in samples: 147 | print(incipit + s.text) 148 | ``` 149 | 150 | ##### Tuning parameters 151 | 152 | Important parameters of the `Tuner` include: 153 | 154 | - `n_gradient_steps`: number of total gradient steps in the full tuning process; 155 | - `n_samples_per_context`: total number of samples used in performing a gradient step (aka batch size); 156 | - `scoring_size`: number of samples sent in a batch to the `.score` function. This parameter affects training speed or helps solve GPU memory errors, but does not affect final results; 157 | - `sampling_size`: number of samples obtained from a single call to the `.sample` function. This parameter affects training speed or helps solve GPU memory errors, but does not affect final results; 158 | - `features`: list of pairs (`name`, `feature`) so that the `feature` moments will be computed by importance sampling (and reported using the key given by `name`); 159 | - `track_divergence_from_base`: set to True to track the reverse KL divergence from the original model (this requires an additional round of samples' scoring). 160 | 161 | #### Logging 162 | 163 | The Tuner reports a number of metrics that are useful to monitor the training progress. A number of `Logger` classes are provided to keep track of these metrics.
Basic logging is provided through the console, as follows: 164 | 165 | ```python 166 | console_logger = ConsoleLogger(tuner) 167 | ``` 168 | 169 | However, more detailed statistics can be kept through the JSON/WandB/Neptune loggers: 170 | 171 | ```python 172 | project = "example_project" 173 | name = "run_01" 174 | json_logger = JSONLogger(tuner, project, name) 175 | neptune_logger = NeptuneLogger(tuner, project, name) 176 | wandb_logger = WandBLogger(tuner, project, name) 177 | ``` 178 | 179 | where `project` and `name` refer to the project and run name, respectively. 180 | 181 | ##### Logged Metrics 182 | 183 | Loggers store a number of metrics about the training process. Here we list a few of the most relevant ones: 184 | 185 | - `kl_target_model` and `kl_target_proposal`: estimates of the forward KL divergence to the target EBM from the tuned model and the proposal distribution, respectively. In the case of using online training, the two are equivalent with the only caveat that `kl_target_model` is computed —this is the metric being optimized, and not the value reported as `loss`; 186 | - `kl_model_base` and `kl_proposal_base`: estimates of the reverse KL divergence to the original model of the tuned model and the proposal distribution, respectively —only reported if `track_divergence_from_base` is set to True; 187 | - Feature moments: Estimate of the features' moments for those features specified with the `features` parameter at the Tuner's construction time. 188 | 189 | ### Controlled Conditional Generation 190 | 191 | The _conditional_ case is superficially very similar, with an extra step needed to instantiate a `ContextDistribution`, which allows sampling contexts that can then be used to condition the model. Furthermore, we use the more general ```CDPGTuner``` class.
192 | 193 | Assuming we have a file of incipits, one per line, in a `data/incipits.txt` file, we could do: 194 | 195 | ```python 196 | target_ebm = base.constrain([amazing], [1]) 197 | 198 | model = LMDistribution(freeze=False) 199 | 200 | tuner = CDPGTuner(model, target_ebm, 201 | context_distribution=ContextDistribution("data/incipits.txt"), 202 | n_contexts_per_step=2**3) 203 | tuner.tune() 204 | ``` 205 | 206 | Note that while we have used a decoder-only model here for illustrative purposes, the real power of the CDPGTuner is that it allows controlling _seq2seq models_ such as those used in NMT, summarization, etc... Please refer to the dedicated [tutorial notebook](tutorials/4.conditional_tuning.ipynb) for an example of how to control an actual conditional model. 207 | 208 | 209 | ### Improving the approximation through minimizing other f-divergences 210 | 211 | The Tuner classes train the model by minimizing its divergence from the distribution induced by the target EBM. While in the original DPG and CDPG algorithms this divergence was always the KL divergence, [recent work](https://arxiv.org/abs/2302.08215) has generalized them to the wider class of [f-divergences](https://en.wikipedia.org/wiki/F-divergence). To pick the loss to minimize, use the corresponding `FDPGTuner` or `FCDPGTuner` class, depending on whether you are in the unconditional or conditional case, and choose the divergence to minimize through the `loss` parameter. Some of the possible losses are: 212 | 213 | - `KLLoss()`: KL divergence 214 | - `ReverseKLLoss()`: KL divergence reversing the order of the arguments 215 | - `TVLoss()`: Total Variation Distance 216 | - `JSLoss()`: Jensen-Shannon divergence. 217 | 218 | Each of these divergences strikes a different balance between level of alignment and diversity. KL exhibits lower alignment than other losses, but higher diversity. On the other hand, ReverseKL tends to produce high alignment at the cost of lower diversity.
Jensen-Shannon strikes a good balance between the two, making it a good default choice. 219 | 220 | As an example, 221 | 222 | ```python 223 | target_ebm = base.constrain([amazing], [1]) 224 | 225 | model = LMDistribution(freeze=False) 226 | 227 | tuner = FCDPGTuner(model, target_ebm, loss=JSLoss(), 228 | context_distribution=ContextDistribution("data/incipits.txt"), 229 | n_contexts_per_step=2**3) 230 | tuner.tune() 231 | ``` 232 | 233 | ### Reinforcement Learning from Human Feedback (RLHF) 234 | 235 | RLHF is a popular paradigm for aligning language models to preferences. While RLHF is commonly known in the form of a reward maximization algorithm, [recent work](https://arxiv.org/abs/2206.00761) has shown that it is equivalent to a distribution approximation problem which can be easily handled by **disco**. Specifically, given a reward function `r(x)` and the regularization parameter `beta`, the following code optimizes the same objective as RLHF: 236 | 237 | ```python 238 | target_ebm = base * ExponentialScorer([r], [1./beta]) 239 | 240 | model = base.clone() 241 | 242 | 243 | tuner = FCDPGTuner(model, target_ebm, loss=ReverseKLLoss()) 244 | ``` 245 | 246 | In other words, RLHF optimizes the *reverse* KL to the above-defined target EBM. Interestingly, this opens new opportunities as the divergence to be minimized could be now chosen to be any other, as explored [here](https://arxiv.org/abs/2302.08215). 247 | 248 | 249 | #### Monte-Carlo sampling to improve the approximation 250 | 251 | After the tuning is done, `model` is now a better approximation to the target EBM, but it is not guaranteed to perfectly match this distribution. 
While further training can improve the situation, another alternative is using [quasi-rejection sampling (QRS)](https://disco.europe.naverlabs.com/QRS/), a Monte-Carlo sampling technique that allows to trade-off sampling efficiency for a higher fidelity to the target distribution —a higher value of `beta` yields a better fidelity although at a higher computational cost. 252 | 253 | ```python 254 | beta=0.5 255 | sampler = QuasiRejectionSampler(target_ebm, model, beta=beta) 256 | samples, log_scores = sampler.sample(sampling_size=2**7) 257 | ``` 258 | 259 | #### In summary 260 | 261 | To put some of this (distributional constraint, tuning in the unconditional case and using QRS) together: 262 | 263 | ```python 264 | base = LMDistribution() 265 | target_ebm = base.constrain([amazing], [1/2]) 266 | 267 | model = LMDistribution(freeze=False) 268 | 269 | tuner = DPGTuner(model, target_ebm) 270 | tuner.tune() 271 | 272 | beta=0.5 273 | sampler = QuasiRejectionSampler(target_ebm, model, beta=beta) 274 | samples, log_scores = sampler.sample(context=incipit, sampling_size=2**7) 275 | ``` 276 | 277 | ### Going further 278 | 279 | A few things to keep in mind while reading the following paragraphs showing the principles of **disco**: 280 | 1. this is only an introduction, skipping some details and relying on toyish use cases; 281 | 1. the notebooks in the tutorials folder go in more depth, on more use cases; 282 | 1. the focus here and in most notebooks is on natural language, but again the toolkit can be used to control distributions over sequences such as code or chess moves, or even other data types, as long as they respect the basic assumptions of a disco `Distribution` object. 
283 | 284 | ## References 285 | 286 | The **disco** toolkit implements the theoretical framework presented in the following works: 287 | - A Distributional Approach to Controlled Text Generation, Khalifa et al., 2021, , ICLR; 288 | - An approximate sampler for energy-based models with divergence diagnostics, Eikema et al., 2022, , TMLR; 289 | - Energy-Based Models for Code Generation under Compilability Constraints, Korbak et al., 2021, , ACL (Workshop on Natural Language Processing for Programming); 290 | - Controlling Conditional Language Models without Catastrophic Forgetting, Korbak et al., 2022, , ICML; 291 | - On Reinforcement Learning and Distribution Matching for Fine-Tuning Language Models with no Catastrophic Forgetting, Korbak et al., 2022, , NeurIPS; 292 | - Aligning Language Models with Preferences through f-divergence Minimization, Go et al., 2023, https://arxiv.org/abs/2302.08215, ICML. 293 | 294 | To cite **disco**, please use: 295 | ``` 296 | @inproceedings{kruszewski-etal-2023-disco, 297 | title = "disco: a toolkit for Distributional Control of Generative Models", 298 | author = "Kruszewski, Germ{\'a}n and 299 | Rozen, Jos and 300 | Dymetman, Marc", 301 | booktitle = "Proceedings of the 61st Annual Meeting of the Association for Computational Linguistics (Volume 3: System Demonstrations)", 302 | month = jul, 303 | year = "2023", 304 | address = "Toronto, Canada", 305 | publisher = "Association for Computational Linguistics", 306 | url = "https://aclanthology.org/2023.acl-demo.14", 307 | pages = "144--160", 308 | abstract = "Pre-trained language models and other generative models have revolutionized NLP and beyond. However, these models tend to reproduce undesirable biases present in their training data. Also, they may overlook patterns that are important but challenging to capture. To address these limitations, researchers have introduced distributional control techniques. 
These techniques, not limited to language, allow controlling the prevalence (i.e. expectations) of any features of interest in the model{'}s outputs. Despite their potential, the widespread adoption of these techniques has been hindered by the difficulty in adapting the complex, disconnected code. Here, we present disco, an open-source Python library that brings these techniques to the broader public", 309 | } 310 | ``` 311 | 312 | ## License 313 | 314 | See [LICENSE](LICENSE) file. 315 | -------------------------------------------------------------------------------- /tutorials/1.quick_introduction.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "attachments": {}, 5 | "cell_type": "markdown", 6 | "metadata": {}, 7 | "source": [ 8 | "disco \n", 9 | "Copyright (C) 2022-present NAVER Corp. \n", 10 | "Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license " 11 | ] 12 | }, 13 | { 14 | "attachments": {}, 15 | "cell_type": "markdown", 16 | "metadata": {}, 17 | "source": [ 18 | "# Quick introduction" 19 | ] 20 | }, 21 | { 22 | "attachments": {}, 23 | "cell_type": "markdown", 24 | "metadata": {}, 25 | "source": [ 26 | "Note that this is only an introduction, from the [README](../README.MD), skipping some details and relying on toyish use cases: \n", 27 | "see the other notebooks in the tutorial folder for more depth, and more use cases." 28 | ] 29 | }, 30 | { 31 | "attachments": {}, 32 | "cell_type": "markdown", 33 | "metadata": {}, 34 | "source": [ 35 | "## Distributions" 36 | ] 37 | }, 38 | { 39 | "attachments": {}, 40 | "cell_type": "markdown", 41 | "metadata": {}, 42 | "source": [ 43 | "The generative model that you want to tune must be wrapped by a `Distribution` object. For example, for a (causal or seq2seq) language model compatible with the 🤗 Hugging Face interface use an `LMDistribution`." 
44 | ] 45 | }, 46 | { 47 | "attachments": {}, 48 | "cell_type": "markdown", 49 | "metadata": {}, 50 | "source": [ 51 | "A valid `Distribution` must have the following two methods:\n", 52 | "- `.sample(context)` that given an optional `context` on which the distribution can be conditioned, returns a list of samples from the underlying distribution and a tensor with their corresponding log-probabilities. \n", 53 | "- `.log_score(samples, context)` that given a list of samples and the `context` on which to condition the distribution, returns their corresponding log-probabilities." 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": null, 59 | "metadata": {}, 60 | "outputs": [], 61 | "source": [ 62 | "from disco.distributions import LMDistribution\n", 63 | "\n", 64 | "distribution = LMDistribution()" 65 | ] 66 | }, 67 | { 68 | "cell_type": "code", 69 | "execution_count": null, 70 | "metadata": {}, 71 | "outputs": [], 72 | "source": [ 73 | "incipit = \"It was a cold and stormy night\"\n", 74 | "samples, log_scores = distribution.sample(context=incipit)" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": null, 80 | "metadata": {}, 81 | "outputs": [], 82 | "source": [ 83 | "distribution.log_score(samples, context=incipit)" 84 | ] 85 | }, 86 | { 87 | "attachments": {}, 88 | "cell_type": "markdown", 89 | "metadata": {}, 90 | "source": [ 91 | "`LMDistribution` generates samples of the `TextSample` type, which are named tuples with both a `text` and a `token_ids` field." 92 | ] 93 | }, 94 | { 95 | "attachments": {}, 96 | "cell_type": "markdown", 97 | "metadata": {}, 98 | "source": [ 99 | "### Features" 100 | ] 101 | }, 102 | { 103 | "attachments": {}, 104 | "cell_type": "markdown", 105 | "metadata": {}, 106 | "source": [ 107 | "Features are represented by an object with the method\n", 108 | "- `.score(samples, context)` which given a list of samples and an optional context returns a tensor of real-valued scores."
109 | ] 110 | }, 111 | { 112 | "attachments": {}, 113 | "cell_type": "markdown", 114 | "metadata": {}, 115 | "source": [ 116 | "A convenient way to define one is using the `Scorer` class, which accepts a function, or a lambda definition, that takes a sample and a context, and vectorizes it. For example, we can compute the effective length of a GPT-2 text sample by finding the eos token:" 117 | ] 118 | }, 119 | { 120 | "cell_type": "code", 121 | "execution_count": null, 122 | "metadata": {}, 123 | "outputs": [], 124 | "source": [ 125 | "from disco.scorers.scorer import Scorer\n", 126 | "\n", 127 | "sequence_length = Scorer(lambda s, c: s.text.index(\"<|endoftext|>\"))" 128 | ] 129 | }, 130 | { 131 | "attachments": {}, 132 | "cell_type": "markdown", 133 | "metadata": {}, 134 | "source": [ 135 | "where `s` is the sample (assumed to be a `TextSample`) and `c` is an optional context." 136 | ] 137 | }, 138 | { 139 | "attachments": {}, 140 | "cell_type": "markdown", 141 | "metadata": {}, 142 | "source": [ 143 | "#### Boolean Features" 144 | ] 145 | }, 146 | { 147 | "attachments": {}, 148 | "cell_type": "markdown", 149 | "metadata": {}, 150 | "source": [ 151 | "An important class of features are *boolean* features. While general features can only be used to define *distributional* constraints, boolean features can also be used to define *pointwise* constraints (see below). To define one, we can use the `BooleanScorer` helper class, which takes a function as an argument. 
\n", 152 | "For example, we can score the presence of the string \"amazing\", as follows:" 153 | ] 154 | }, 155 | { 156 | "cell_type": "code", 157 | "execution_count": null, 158 | "metadata": {}, 159 | "outputs": [], 160 | "source": [ 161 | "from disco.scorers.boolean_scorer import BooleanScorer" 162 | ] 163 | }, 164 | { 165 | "cell_type": "code", 166 | "execution_count": null, 167 | "metadata": {}, 168 | "outputs": [], 169 | "source": [ 170 | "amazing = BooleanScorer(lambda s, c: \"amazing\" in s.text)" 171 | ] 172 | }, 173 | { 174 | "attachments": {}, 175 | "cell_type": "markdown", 176 | "metadata": {}, 177 | "source": [ 178 | "The ```False```/```True``` results from the lambda are cast to `0.0`/`1.0` float values so that they can be used in the EBM definition. " 179 | ] 180 | }, 181 | { 182 | "attachments": {}, 183 | "cell_type": "markdown", 184 | "metadata": {}, 185 | "source": [ 186 | "`BooleanScorer` belongs to the more general class of `PositiveScorer`s, which can be used to construct EBMs. The main properties of a `PositiveScorer` are that, first, it returns positive scores, and second, it provides the method\n", 187 | " \n", 188 | " - `.log_score(samples, context)` that given a list of samples and the `context` on which to condition the distribution, returns their corresponding log-probabilities." 189 | ] 190 | }, 191 | { 192 | "attachments": {}, 193 | "cell_type": "markdown", 194 | "metadata": {}, 195 | "source": [ 196 | "As a consequence, we can see that a ```Distribution``` is also a ```PositiveScorer``` that is able to sample as well."
197 | ] 198 | }, 199 | { 200 | "attachments": {}, 201 | "cell_type": "markdown", 202 | "metadata": {}, 203 | "source": [ 204 | "## Controlling Generation" 205 | ] 206 | }, 207 | { 208 | "attachments": {}, 209 | "cell_type": "markdown", 210 | "metadata": {}, 211 | "source": [ 212 | "### Expressing preferences in an EBM" 213 | ] 214 | }, 215 | { 216 | "attachments": {}, 217 | "cell_type": "markdown", 218 | "metadata": {}, 219 | "source": [ 220 | "We express preferences over the distribution by defining target moments for specific features. This results in a target distribution that matches the desired moments while minimizing the KL divergence to the original distribution. In other words, it incorporates the preferences while avoiding catastrophic forgetting. This distribution is represented as an EBM, which can be used to score samples (in other words, it is a `PositiveScorer`) but cannot be used to sample directly; we'll see how to sample below." 221 | ] 222 | }, 223 | { 224 | "attachments": {}, 225 | "cell_type": "markdown", 226 | "metadata": {}, 227 | "source": [ 228 | "We can express either *pointwise* or *distributional* constraints on a distribution and compose them at will. The former expresses a (boolean) property that must apply to *all* sequences, whereas the latter represents properties at the distributional level. " 229 | ] 230 | }, 231 | { 232 | "attachments": {}, 233 | "cell_type": "markdown", 234 | "metadata": {}, 235 | "source": [ 236 | "To obtain the target distribution that incorporates our constraints, we use the `constrain` method of the corresponding `Distribution`. This method takes a list of features and their corresponding target moments."
237 | ] 238 | }, 239 | { 240 | "attachments": {}, 241 | "cell_type": "markdown", 242 | "metadata": {}, 243 | "source": [ 244 | "For example, we can define an EBM with a *pointwise* constraint requiring that all our samples must include \"amazing\" by setting the target moment to `1` on a `BooleanScorer`:" 245 | ] 246 | }, 247 | { 248 | "cell_type": "code", 249 | "execution_count": null, 250 | "metadata": {}, 251 | "outputs": [], 252 | "source": [ 253 | "from disco.distributions.lm_distribution import LMDistribution\n", 254 | "\n", 255 | "base = LMDistribution()\n", 256 | "ebm = base.constrain([amazing], [1])\n", 257 | "# ebm = base.constrain([amazing]) # would also work\n", 258 | "# ebm = base * amazing # as well, using a product notation " 259 | ] 260 | }, 261 | { 262 | "attachments": {}, 263 | "cell_type": "markdown", 264 | "metadata": {}, 265 | "source": [ 266 | "Or we can ask for a _distributional_ constraint requiring that _half_ of the samples include \"amazing\":" 267 | ] 268 | }, 269 | { 270 | "cell_type": "code", 271 | "execution_count": null, 272 | "metadata": {}, 273 | "outputs": [], 274 | "source": [ 275 | "import os\n", 276 | "# disabling parallelism to avoid deadlocks (see warning from HF's tokenizer)\n", 277 | "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"" 278 | ] 279 | }, 280 | { 281 | "cell_type": "code", 282 | "execution_count": null, 283 | "metadata": {}, 284 | "outputs": [], 285 | "source": [ 286 | "ebm = base.constrain([amazing], [1/2])" 287 | ] 288 | }, 289 | { 290 | "attachments": {}, 291 | "cell_type": "markdown", 292 | "metadata": {}, 293 | "source": [ 294 | "### Approximating the target EBM" 295 | ] 296 | }, 297 | { 298 | "attachments": {}, 299 | "cell_type": "markdown", 300 | "metadata": {}, 301 | "source": [ 302 | "Given an EBM target distribution, we now want to train a model to approximate it so that we can use it to generate samples. 
In the _unconditional_ case, namely when there is a single fixed context used in generation, then we can use a `Tuner`, more specifically a ```DPGTuner```, as follows." 303 | ] 304 | }, 305 | { 306 | "cell_type": "code", 307 | "execution_count": null, 308 | "metadata": {}, 309 | "outputs": [], 310 | "source": [ 311 | "from disco.tuners.dpg_tuner import DPGTuner" 312 | ] 313 | }, 314 | { 315 | "cell_type": "code", 316 | "execution_count": null, 317 | "metadata": {}, 318 | "outputs": [], 319 | "source": [ 320 | "target_ebm = base.constrain([amazing], [1])\n", 321 | "\n", 322 | "model = LMDistribution(freeze=False)\n", 323 | "incipit = \"It was a cold and stormy night\"\n", 324 | "\n", 325 | "tuner = DPGTuner(model, target_ebm, context=incipit, n_gradient_steps=4)\n", 326 | "tuner.tune()" 327 | ] 328 | }, 329 | { 330 | "attachments": {}, 331 | "cell_type": "markdown", 332 | "metadata": {}, 333 | "source": [ 334 | "And we can sample _amazing_ sequences from the tuned model." 335 | ] 336 | }, 337 | { 338 | "cell_type": "code", 339 | "execution_count": null, 340 | "metadata": {}, 341 | "outputs": [], 342 | "source": [ 343 | "samples, log_scores = model.sample(context=incipit)\n", 344 | "for s in samples:\n", 345 | " print(incipit + s.text)" 346 | ] 347 | }, 348 | { 349 | "attachments": {}, 350 | "cell_type": "markdown", 351 | "metadata": {}, 352 | "source": [ 353 | "#### Tuning parameters" 354 | ] 355 | }, 356 | { 357 | "attachments": {}, 358 | "cell_type": "markdown", 359 | "metadata": {}, 360 | "source": [ 361 | "Important parameters of the `Tuner` include:\n", 362 | "- `n_gradient_steps`: number of total gradient steps in the full tuning process;\n", 363 | "- `n_samples_per_context`: total number of samples used in performing a gradient step (aka batch size);\n", 364 | "- `scoring_size`: number of samples sent in a batch to the `.score` function. 
This parameter affects training speed or helps solve GPU memory errors, but does not affect final results;\n", 365 | "- `sampling_size`: number of samples obtained from a single call to the `.sample` function. This parameter affects training speed or helps solve GPU memory errors, but does not affect final results;\n", 366 | "- `features`: list of pairs (`name`, `feature`) so that the `feature` moments will be computed by importance sampling (and reported using the key given by `name`);\n", 367 | "- `track_divergence_from_base`: set to True to track the reverse KL divergence from the original model —this requires an additional round of samples' scoring." 368 | ] 369 | }, 370 | { 371 | "attachments": {}, 372 | "cell_type": "markdown", 373 | "metadata": {}, 374 | "source": [ 375 | "### Logging" 376 | ] 377 | }, 378 | { 379 | "attachments": {}, 380 | "cell_type": "markdown", 381 | "metadata": {}, 382 | "source": [ 383 | "The Tuner reports a number of metrics that are useful to monitor the training progress. A number of `Logger` classes are provided to keep track of these metrics. 
Basic logging is provided through the console, as follows:" 384 | ] 385 | }, 386 | { 387 | "cell_type": "code", 388 | "execution_count": null, 389 | "metadata": {}, 390 | "outputs": [], 391 | "source": [ 392 | "from disco.tuners.loggers.console import ConsoleLogger" 393 | ] 394 | }, 395 | { 396 | "attachments": {}, 397 | "cell_type": "markdown", 398 | "metadata": {}, 399 | "source": [ 400 | "`console_logger = ConsoleLogger(tuner)`" 401 | ] 402 | }, 403 | { 404 | "attachments": {}, 405 | "cell_type": "markdown", 406 | "metadata": {}, 407 | "source": [ 408 | "However, more detailed statistics can be kept through the JSON/WandB/Neptune loggers:" 409 | ] 410 | }, 411 | { 412 | "cell_type": "code", 413 | "execution_count": null, 414 | "metadata": {}, 415 | "outputs": [], 416 | "source": [ 417 | "from disco.tuners.loggers.json import JSONLogger\n", 418 | "from disco.tuners.loggers.neptune import NeptuneLogger\n", 419 | "from disco.tuners.loggers.wandb import WandBLogger\n" 420 | ] 421 | }, 422 | { 423 | "cell_type": "code", 424 | "execution_count": null, 425 | "metadata": {}, 426 | "outputs": [], 427 | "source": [ 428 | "project = \"example_project\"\n", 429 | "name = \"run_01\"\n", 430 | "json_logger = JSONLogger(tuner, project, name)\n", 431 | "neptune_logger = NeptuneLogger(tuner, project, name)\n", 432 | "wandb_logger = WandBLogger(tuner, project, name)" 433 | ] 434 | }, 435 | { 436 | "attachments": {}, 437 | "cell_type": "markdown", 438 | "metadata": {}, 439 | "source": [ 440 | "where `project` and `name` refer to the project and run name, respectively." 441 | ] 442 | }, 443 | { 444 | "attachments": {}, 445 | "cell_type": "markdown", 446 | "metadata": {}, 447 | "source": [ 448 | "#### Logged Metrics" 449 | ] 450 | }, 451 | { 452 | "attachments": {}, 453 | "cell_type": "markdown", 454 | "metadata": {}, 455 | "source": [ 456 | "Loggers store a number of metrics about the training process. 
Here we list a few of the most relevant ones:\n", 457 | "\n", 458 | "- `kl_target_model` and `kl_target_proposal`: estimates of the forward KL divergence to the target EBM from the tuned model and the proposal distribution, respectively. In the case of using online training, the two are equivalent with the only caveat that `kl_target_model` is computed —this is the metric being optimized, and not the value reported as `loss`;\n", 459 | "- `kl_model_base` and `kl_proposal_base`: estimates of the reverse KL divergence to the original model of the tuned model and the proposal distribution, respectively —only reported if `track_divergence_from_base` is set to True;\n", 460 | "- Feature moments: estimates of the features' moments for those features specified with the `features` parameter at the Tuner's construction time." 461 | ] 462 | }, 463 | { 464 | "attachments": {}, 465 | "cell_type": "markdown", 466 | "metadata": {}, 467 | "source": [ 468 | "## Controlled Conditional Generation" 469 | ] 470 | }, 471 | { 472 | "attachments": {}, 473 | "cell_type": "markdown", 474 | "metadata": {}, 475 | "source": [ 476 | "The _conditional_ case is superficially very similar, with an extra step needed to instantiate a `ContextDistribution`, which allows sampling contexts that can then be used to condition the model. Furthermore, we use the more general ```CDPGTuner``` class."
477 | ] 478 | }, 479 | { 480 | "attachments": {}, 481 | "cell_type": "markdown", 482 | "metadata": {}, 483 | "source": [ 484 | "Assuming we have a file of incipits, one per line, in a `data/incipits.txt` file, we could do:" 485 | ] 486 | }, 487 | { 488 | "cell_type": "code", 489 | "execution_count": null, 490 | "metadata": {}, 491 | "outputs": [], 492 | "source": [ 493 | "from disco.tuners.cdpg_tuner import CDPGTuner\n", 494 | "from disco.distributions.context_distribution import ContextDistribution" 495 | ] 496 | }, 497 | { 498 | "cell_type": "code", 499 | "execution_count": null, 500 | "metadata": {}, 501 | "outputs": [], 502 | "source": [ 503 | "target_ebm = base.constrain([amazing], [1])\n", 504 | "\n", 505 | "model = LMDistribution(freeze=False)\n", 506 | "\n", 507 | "tuner = CDPGTuner(model, target_ebm, n_gradient_steps=4, # use a much higher value for actual tuning\n", 508 | " context_distribution=ContextDistribution(\"data/incipits.txt\"), n_contexts_per_step=2**3)\n", 509 | "tuner.tune()" 510 | ] 511 | }, 512 | { 513 | "attachments": {}, 514 | "cell_type": "markdown", 515 | "metadata": {}, 516 | "source": [ 517 | "Note that while we have used a decoder-only model here for illustrative purposes, the real power of the CDPGTuner is that it allows controlling _seq2seq models_ such as those used in NMT, summarization, etc. Please refer to the dedicated [tutorial notebook](4.conditional_tuning.ipynb) for an example of how to control an actual conditional model." 518 | ] 519 | }, 520 | { 521 | "attachments": {}, 522 | "cell_type": "markdown", 523 | "metadata": {}, 524 | "source": [ 525 | "### Monte-Carlo sampling to improve the approximation" 526 | ] 527 | }, 528 | { 529 | "attachments": {}, 530 | "cell_type": "markdown", 531 | "metadata": {}, 532 | "source": [ 533 | "After the tuning is done, `model` is now a better approximation to the target EBM, but it is not guaranteed to perfectly match this distribution. 
While further training can improve the situation, another alternative is using [quasi-rejection sampling (QRS)](https://disco.europe.naverlabs.com/QRS/), a Monte-Carlo sampling technique that allows trading off sampling efficiency for a higher fidelity to the target distribution —a higher value of `beta` yields a better fidelity although at a higher computational cost." 534 | ] 535 | }, 536 | { 537 | "cell_type": "code", 538 | "execution_count": null, 539 | "metadata": {}, 540 | "outputs": [], 541 | "source": [ 542 | "from disco.samplers.quasi_rejection_sampler import QuasiRejectionSampler" 543 | ] 544 | }, 545 | { 546 | "cell_type": "code", 547 | "execution_count": null, 548 | "metadata": {}, 549 | "outputs": [], 550 | "source": [ 551 | "beta=0.5\n", 552 | "sampler = QuasiRejectionSampler(target_ebm, model, beta=beta)\n", 553 | "samples, log_scores = sampler.sample(sampling_size=2**7)" 554 | ] 555 | }, 556 | { 557 | "attachments": {}, 558 | "cell_type": "markdown", 559 | "metadata": {}, 560 | "source": [ 561 | "### In summary" 562 | ] 563 | }, 564 | { 565 | "attachments": {}, 566 | "cell_type": "markdown", 567 | "metadata": {}, 568 | "source": [ 569 | "To put some of this (distributional constraint, tuning in the unconditional case and using QRS) together:" 570 | ] 571 | }, 572 | { 573 | "cell_type": "code", 574 | "execution_count": null, 575 | "metadata": {}, 576 | "outputs": [], 577 | "source": [ 578 | "base = LMDistribution()\n", 579 | "target_ebm = base.constrain([amazing], [1/2])\n", 580 | "\n", 581 | "model = LMDistribution()\n", 582 | "\n", 583 | "tuner = DPGTuner(model, target_ebm)\n", 584 | "tuner.tune()\n", 585 | "\n", 586 | "beta=0.5\n", 587 | "sampler = QuasiRejectionSampler(target_ebm, model, beta=beta)\n", 588 | "samples, log_scores = sampler.sample(context=incipit, sampling_size=2**7)" 589 | ] 590 | }, 591 | { 592 | "cell_type": "code", 593 | "execution_count": null, 594 | "metadata": {}, 595 | "outputs": [], 596 | "source": [] 597 | } 598 | ], 599
| "metadata": { 600 | "kernelspec": { 601 | "display_name": "Python 3.10.6 (conda)", 602 | "language": "python", 603 | "name": "python3" 604 | }, 605 | "language_info": { 606 | "codemirror_mode": { 607 | "name": "ipython", 608 | "version": 3 609 | }, 610 | "file_extension": ".py", 611 | "mimetype": "text/x-python", 612 | "name": "python", 613 | "nbconvert_exporter": "python", 614 | "pygments_lexer": "ipython3", 615 | "version": "3.10.6 | packaged by conda-forge | (main, Aug 22 2022, 20:38:29) [Clang 13.0.1 ]" 616 | }, 617 | "orig_nbformat": 4, 618 | "vscode": { 619 | "interpreter": { 620 | "hash": "babb4baf4e80bd80b9852210fc5469c0783907e52a560ed7247caef52808358d" 621 | } 622 | } 623 | }, 624 | "nbformat": 4, 625 | "nbformat_minor": 2 626 | } 627 | --------------------------------------------------------------------------------