├── requirements.txt
├── tests
├── __init__.py
├── accumulation_sampler_test.py
├── pipeline_scorer_test.py
├── exponential_scorer_test.py
├── quasi_rejection_sampler_test.py
├── boolean_scorer_test.py
├── product_test.py
├── tuner_test.py
├── base_distribution_test.py
└── lm_distribution_test.py
├── disco
├── tuners
│ ├── loggers
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── wandb.py
│ │ ├── console.py
│ │ ├── tensorboard.py
│ │ ├── neptune.py
│ │ ├── json.py
│ │ └── mlflow.py
│ ├── losses
│ │ ├── base.py
│ │ ├── __init__.py
│ │ ├── kl.py
│ │ ├── tv.py
│ │ ├── reverse_kl.py
│ │ ├── chi_square.py
│ │ ├── js.py
│ │ ├── reverse_chi_square.py
│ │ └── f_divergence.py
│ ├── __init__.py
│ ├── dpg_tuner.py
│ ├── cdpg_tuner.py
│ ├── fcdpg_tuner.py
│ └── fdpg_tuner.py
├── __init__.py
├── metrics
│ ├── __init__.py
│ ├── base.py
│ ├── tv.py
│ ├── kl.py
│ └── js.py
├── utils
│ ├── __init__.py
│ ├── observable.py
│ ├── timer.py
│ ├── device.py
│ ├── moving_average.py
│ └── helpers.py
├── samplers
│ ├── __init__.py
│ ├── sampler.py
│ ├── accumulation_sampler.py
│ └── quasi_rejection_sampler.py
├── scorers
│ ├── __init__.py
│ ├── boolean_scorer.py
│ ├── scorer.py
│ ├── pipeline_scorer.py
│ ├── exponential_scorer.py
│ └── positive_scorer.py
└── distributions
│ ├── __init__.py
│ ├── distribution.py
│ ├── single_context_distribution.py
│ ├── context_distribution.py
│ ├── dataset_context_distribution.py
│ └── base_distribution.py
├── .gitignore
├── Makefile
├── make.bat
├── scripts
├── online_tuning_script.py
└── offline_tuning_script.py
├── tutorials
├── data
│ └── incipits.txt
├── 5.QRS_sampling.ipynb
├── 4.conditional_tuning.ipynb
├── 3.tuning_DPG.ipynb
└── 1.quick_introduction.ipynb
├── setup.py
├── LICENSE
└── README.MD
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch
2 | transformers
3 | datasets
4 | notebook
5 | numpy
6 | scipy
7 | spacy
8 | neptune-client
9 | wandb
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
1 | # disco
2 | # Copyright (C) 2022-present NAVER Corp.
3 | # Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license
4 |
5 |
--------------------------------------------------------------------------------
/disco/tuners/loggers/__init__.py:
--------------------------------------------------------------------------------
1 | # disco
2 | # Copyright (C) 2022-present NAVER Corp.
3 | # Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license
4 |
5 |
--------------------------------------------------------------------------------
/disco/__init__.py:
--------------------------------------------------------------------------------
1 | # disco
2 | # Copyright (C) 2022-present NAVER Corp.
3 | # Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license
4 |
5 | __version__ = "1.1.1"
6 | __author__ = 'Naver Labs Europe'
7 |
--------------------------------------------------------------------------------
/disco/metrics/__init__.py:
--------------------------------------------------------------------------------
1 | # disco
2 | # Copyright (C) 2022-present NAVER Corp.
3 | # Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license
4 |
5 | from .kl import KL
6 | from .tv import TV
7 | from .js import JS
8 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | *.py[cod]
3 | *$py.class
4 |
5 | *.pt
6 |
7 | *.egg-info/
8 | dist/
9 |
10 | .vscode
11 | .ipynb_checkpoints/
12 |
13 | .env*
14 | env*/
15 |
16 | explorations/
17 | models/
18 |
19 | .DS_Store
20 | .neptune
--------------------------------------------------------------------------------
/disco/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # disco
2 | # Copyright (C) 2022-present NAVER Corp.
3 | # Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license
4 |
# NOTE: a stray, unused ``import imp`` used to sit here. The ``imp`` module
# was removed from the standard library in Python 3.12, so keeping it made
# this whole package fail at import time on recent interpreters.
from .device import get_device
from .helpers import get_token_first_indices
from .helpers import batchify
--------------------------------------------------------------------------------
/disco/samplers/__init__.py:
--------------------------------------------------------------------------------
1 | # disco
2 | # Copyright (C) 2022-present NAVER Corp.
3 | # Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license
4 |
5 | from .sampler import Sampler
6 | from .quasi_rejection_sampler import QuasiRejectionSampler
7 | from .accumulation_sampler import AccumulationSampler
--------------------------------------------------------------------------------
/disco/tuners/losses/base.py:
--------------------------------------------------------------------------------
1 | # disco
2 | # Copyright (C) 2022-present NAVER Corp.
3 | # Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license
4 |
5 | from disco.utils.observable import Observable
6 |
class BaseLoss:
    """
    Root class of the tuner losses.

    Each instance exposes a ``metric_updated`` observable.
    """

    def __init__(self):
        # a fresh Observable per instance, so enrollments are not shared
        self.metric_updated = Observable()
10 |
11 |
--------------------------------------------------------------------------------
/disco/tuners/__init__.py:
--------------------------------------------------------------------------------
1 | # disco
2 | # Copyright (C) 2022-present NAVER Corp.
3 | # Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license
4 |
5 | from .tuner import Tuner
6 | from .dpg_tuner import DPGTuner
7 | from .cdpg_tuner import CDPGTuner
8 | from .fdpg_tuner import FDPGTuner
9 | from .fcdpg_tuner import FCDPGTuner
10 |
--------------------------------------------------------------------------------
/disco/tuners/losses/__init__.py:
--------------------------------------------------------------------------------
1 | # disco
2 | # Copyright (C) 2022-present NAVER Corp.
3 | # Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license
4 |
5 | from .kl import KLLoss
6 | from .js import JSLoss
7 | from .tv import TVLoss
8 | from .reverse_kl import ReverseKLLoss
9 | from .chi_square import ChiSquaredLoss
10 | from .reverse_chi_square import ReverseChiSquaredLoss
11 |
--------------------------------------------------------------------------------
/disco/scorers/__init__.py:
--------------------------------------------------------------------------------
1 | # disco
2 | # Copyright (C) 2022-present NAVER Corp.
3 | # Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license
4 |
5 | from .scorer import Scorer
6 | from .positive_scorer import PositiveScorer
7 | from .positive_scorer import Product
8 | from .exponential_scorer import ExponentialScorer
9 | from .boolean_scorer import BooleanScorer
10 | from .pipeline_scorer import PipelineScorer
--------------------------------------------------------------------------------
/disco/samplers/sampler.py:
--------------------------------------------------------------------------------
1 | # disco
2 | # Copyright (C) 2022-present NAVER Corp.
3 | # Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license
4 |
5 | from abc import ABC, abstractmethod
6 |
class Sampler(ABC):
    """
    Abstract root of every sampler: keeps a target and a proposal,
    and requires subclasses to implement sample().
    """

    def __init__(self, target, proposal):
        """
        Parameters
        ----------
        target: object
            stored as-is in the ``target`` attribute
        proposal: object
            stored as-is in the ``proposal`` attribute
        """
        self.target, self.proposal = target, proposal

    @abstractmethod
    def sample(self):
        """To be implemented by concrete samplers; returns None here."""
--------------------------------------------------------------------------------
/disco/distributions/__init__.py:
--------------------------------------------------------------------------------
1 | # disco
2 | # Copyright (C) 2022-present NAVER Corp.
3 | # Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license
4 |
5 | from .distribution import Distribution
6 | from .base_distribution import BaseDistribution
7 | from .lm_distribution import LMDistribution
8 | from .single_context_distribution import SingleContextDistribution
9 | from .context_distribution import ContextDistribution
10 | from .dataset_context_distribution import DatasetContextDistribution
--------------------------------------------------------------------------------
/disco/distributions/distribution.py:
--------------------------------------------------------------------------------
1 | # disco
2 | # Copyright (C) 2022-present NAVER Corp.
3 | # Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license
4 |
5 | from abc import abstractmethod
6 | import torch
7 |
8 | from disco.scorers.positive_scorer import PositiveScorer
9 |
10 |
class Distribution(PositiveScorer):
    """
    Abstract distribution: a PositiveScorer that, in addition
    to scoring, is able to produce samples.
    """

    @abstractmethod
    def sample(self, context):
        """Produces samples from the distribution for the given context.

        Parameters
        ----------
        context: object
            context that the samples must relate to
        """
--------------------------------------------------------------------------------
/disco/scorers/boolean_scorer.py:
--------------------------------------------------------------------------------
1 | # disco
2 | # Copyright (C) 2022-present NAVER Corp.
3 | # Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license
4 |
5 | import torch
6 | import numpy as np
7 |
8 | from .positive_scorer import PositiveScorer
9 |
10 |
class BooleanScorer(PositiveScorer):
    """
    Predicate-based scoring class
    """

    def __init__(self, predicate):
        """
        Parameters
        ----------
        predicate: scoring predicate
            predicate function to be used on each sample,
            called as predicate(sample, context)
        """
        super().__init__(predicate)

        # The float conversion is applied once to the broadcast result and
        # not to each individual outcome: a predicate returning a plain
        # Python bool has no .float() method.
        # NOTE(review): assumes self._broadcast returns a tensor, as
        # Scorer._broadcast does -- confirm in PositiveScorer.
        broadcasted_predicate = self._broadcast(predicate)

        def predicate_as_float(xs, context):
            return broadcasted_predicate(xs, context).float()

        self.predicate = predicate_as_float
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line, and also
5 | # from the environment for the first two.
6 | SPHINXOPTS ?=
7 | SPHINXBUILD ?= sphinx-build
8 | SOURCEDIR = doc
9 | BUILDDIR = docbuild
10 |
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 |
15 | .PHONY: help Makefile
16 |
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile
20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
21 |
--------------------------------------------------------------------------------
/disco/utils/observable.py:
--------------------------------------------------------------------------------
1 | # disco
2 | # Copyright (C) 2022-present NAVER Corp.
3 | # Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license
4 |
class Observable(object):
    """Minimal observer pattern: callables enroll, then get every dispatch"""

    def __init__(self):
        # a set: enrolling the same callable twice registers it once
        self.observers = set()

    def enroll(self, function):
        """Registers a callable to be invoked on each dispatch

        Parameters
        ----------
        function: callable
            observer to add to the set of observers
        """
        self.observers.add(function)

    def dispatch(self, *args, **kwargs):
        """Calls every enrolled observer with the given arguments"""
        for callback in self.observers:
            callback(*args, **kwargs)
17 |
def forward(observable1, observable2):
    """Relays every message dispatched by observable1 to observable2

    Parameters
    ----------
    observable1: Observable
        source, on which a relaying observer gets enrolled
    observable2: Observable
        destination, whose dispatch() receives the same arguments
    """
    def relay(*args, **kwargs):
        observable2.dispatch(*args, **kwargs)

    observable1.enroll(relay)
24 |
25 |
--------------------------------------------------------------------------------
/tests/accumulation_sampler_test.py:
--------------------------------------------------------------------------------
1 | # disco
2 | # Copyright (C) 2022-present NAVER Corp.
3 | # Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license
4 |
5 | import unittest
6 |
7 | from disco.distributions import LMDistribution
8 | from disco.samplers import AccumulationSampler
9 |
class AccumulationSamplerTest(unittest.TestCase):
    """Checks the number of samples returned by an AccumulationSampler."""

    def test_clm_sample(self):
        total_size = 2**8
        sampler = AccumulationSampler(LMDistribution(), total_size=total_size)

        samples, _ = sampler.sample(
            context="It was a cold and stormy night",
            sampling_size=2**6,
        )

        # the sampler must return exactly total_size samples
        self.assertEqual(
            total_size,
            len(samples),
            "there should as many sampled sequences as requested from the CLM.",
        )


if __name__ == '__main__':
    unittest.main()
--------------------------------------------------------------------------------
/disco/tuners/losses/kl.py:
--------------------------------------------------------------------------------
1 | # disco
2 | # Copyright (C) 2022-present NAVER Corp.
3 | # Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license
4 |
5 | import torch
6 | from .f_divergence import FDivergenceLoss
7 |
class KLLoss(FDivergenceLoss):
    """
    Kullback-Leibler divergence loss for DPG
    """

    def __init__(self, use_baseline=True, baseline_window_size=1024):
        """
        Parameters
        ----------
        use_baseline: boolean
            use a baseline to reduce variance
        baseline_window_size: int
            forwarded unchanged to FDivergenceLoss
        """
        super().__init__(use_baseline, baseline_window_size)

    def f_prime(self, log_t):
        """Returns -exp(-log_t)

        Parameters
        ----------
        log_t: 0-dim Tensor
            The log ratio of the policy and the normalized target distribution
        """
        inverse_ratio = torch.exp(-log_t)
        return -inverse_ratio
--------------------------------------------------------------------------------
/disco/tuners/losses/tv.py:
--------------------------------------------------------------------------------
1 | # disco
2 | # Copyright (C) 2022-present NAVER Corp.
3 | # Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license
4 |
5 | from .f_divergence import FDivergenceLoss
6 | import torch
7 |
class TVLoss(FDivergenceLoss):
    """
    Total variation distance loss for DPG
    """

    def __init__(self, use_baseline=True, baseline_window_size=1024):
        """
        Parameters
        ----------
        use_baseline: boolean
            use a baseline to reduce variance
        baseline_window_size: int
            forwarded unchanged to FDivergenceLoss
        """
        super().__init__(use_baseline, baseline_window_size)

    def f_prime(self, log_t):
        """Returns half of the sign of log_t

        Parameters
        ----------
        log_t: 0-dim Tensor
            The log ratio of the policy and the normalized target distribution
        """
        direction = torch.sign(log_t)
        return direction / 2.
--------------------------------------------------------------------------------
/make.bat:
--------------------------------------------------------------------------------
1 | @ECHO OFF
2 |
3 | pushd %~dp0
4 |
5 | REM Command file for Sphinx documentation
6 |
7 | if "%SPHINXBUILD%" == "" (
8 | set SPHINXBUILD=sphinx-build
9 | )
10 | set SOURCEDIR=doc
11 | set BUILDDIR=build
12 |
13 | if "%1" == "" goto help
14 |
15 | %SPHINXBUILD% >NUL 2>NUL
16 | if errorlevel 9009 (
17 | echo.
18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
19 | echo.installed, then set the SPHINXBUILD environment variable to point
20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you
21 | echo.may add the Sphinx directory to PATH.
22 | echo.
23 | echo.If you don't have Sphinx installed, grab it from
24 | echo.https://www.sphinx-doc.org/
25 | exit /b 1
26 | )
27 |
28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
29 | goto end
30 |
31 | :help
32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
33 |
34 | :end
35 | popd
36 |
--------------------------------------------------------------------------------
/disco/tuners/losses/reverse_kl.py:
--------------------------------------------------------------------------------
1 | # disco
2 | # Copyright (C) 2022-present NAVER Corp.
3 | # Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license
4 |
5 | import torch
6 | from .f_divergence import FDivergenceLoss
7 |
class ReverseKLLoss(FDivergenceLoss):
    """
    Reverse Kullback-Leibler divergence loss for DPG
    """

    def __init__(self, use_baseline=True, baseline_window_size=1024):
        """
        Parameters
        ----------
        use_baseline: boolean
            use a baseline to reduce variance
        baseline_window_size: int
            forwarded unchanged to FDivergenceLoss
        """
        super().__init__(use_baseline, baseline_window_size)

    def f_prime(self, log_t):
        """Returns log_t + 1

        Parameters
        ----------
        log_t: 0-dim Tensor
            The log ratio of the policy and the normalized target distribution
        """
        shifted = log_t + 1
        return shifted
--------------------------------------------------------------------------------
/disco/tuners/losses/chi_square.py:
--------------------------------------------------------------------------------
1 | # disco
2 | # Copyright (C) 2022-present NAVER Corp.
3 | # Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license
4 |
5 | import torch
6 | from .f_divergence import FDivergenceLoss
7 |
class ChiSquaredLoss(FDivergenceLoss):
    """
    Chi squared divergence χ²(p || π) loss for DPG
    """

    def __init__(self, use_baseline=True, baseline_window_size=1024):
        """
        Parameters
        ----------
        use_baseline: boolean
            use a baseline to reduce variance
        baseline_window_size: int
            forwarded unchanged to FDivergenceLoss
        """
        super().__init__(use_baseline, baseline_window_size)

    def f_prime(self, log_t):
        """Returns 1 - exp(-2 * log_t)

        Parameters
        ----------
        log_t: 0-dim Tensor
            The log ratio of the policy and the normalized target distribution
        """
        squared_inverse_ratio = torch.exp(-2 * log_t)
        return 1.0 - squared_inverse_ratio
--------------------------------------------------------------------------------
/disco/utils/timer.py:
--------------------------------------------------------------------------------
1 | import time
2 | from contextlib import contextmanager
3 |
class Timer:
    """Context manager measuring the wall-clock duration of its block.

    The duration of the latest completed block is available in
    ``elapsed`` (in seconds, 0 before any block has completed).
    """

    def __init__(self, name="Timer"):
        """
        Parameters
        ----------
        name: string
            label used when the timer is rendered as a string
        """
        self.name = name
        self.elapsed = 0

    def __enter__(self):
        """Records the starting instant and returns the timer itself."""
        self.start_time = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Stores the elapsed time; exceptions are left to propagate."""
        self.elapsed = time.perf_counter() - self.start_time

    def __str__(self):
        return "{}: {:.4f} seconds".format(self.name, self.elapsed)
21 |
22 |
23 | # Alternative using @contextmanager decorator
@contextmanager
def timer(name="Timer"):
    """A simple timer context manager using contextmanager decorator.

    The measurement used to be computed and then thrown away; it is now
    exposed through the object bound by ``with timer(...) as t``.

    Parameters
    ----------
    name: string
        label stored on the yielded object

    Yields
    ------
    an object with a ``name`` attribute and an ``elapsed`` attribute;
    ``elapsed`` is 0.0 inside the block and holds the duration in seconds
    once the block is left, even when it is left through an exception.
    """
    # local import: keeps this function self-contained within the module
    from types import SimpleNamespace

    measure = SimpleNamespace(name=name, elapsed=0.0)
    start_time = time.perf_counter()
    try:
        yield measure
    finally:
        # runs on normal exit and on exception alike
        measure.elapsed = time.perf_counter() - start_time
--------------------------------------------------------------------------------
/disco/tuners/losses/js.py:
--------------------------------------------------------------------------------
1 | # disco
2 | # Copyright (C) 2022-present NAVER Corp.
3 | # Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license
4 |
5 | from .f_divergence import FDivergenceLoss
6 | import torch
7 |
class JSLoss(FDivergenceLoss):
    """
    Jensen-Shannon divergence loss for DPG
    """

    def __init__(self, use_baseline=True, baseline_window_size=1024):
        """
        Parameters
        ----------
        use_baseline: boolean
            use a baseline to reduce variance
        baseline_window_size: int
            forwarded unchanged to FDivergenceLoss
        """
        super().__init__(use_baseline, baseline_window_size)

    def f_prime(self, log_t):
        """Returns log(2) - softplus(-log_t)

        Parameters
        ----------
        log_t: 0-dim Tensor
            The log ratio of the policy and the normalized target distribution
        """
        # the constant is built with log_t's dtype and on log_t's device
        two = torch.tensor(2.0, dtype=log_t.dtype, device=log_t.device)
        softplus_term = torch.nn.functional.softplus(-log_t)
        return torch.log(two) - softplus_term
--------------------------------------------------------------------------------
/disco/utils/device.py:
--------------------------------------------------------------------------------
1 | # disco
2 | # Copyright (C) 2022-present NAVER Corp.
3 | # Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license
4 |
def get_device(tensor):
    """Get the device of a tensor

    Parameters
    ----------
    tensor: Tensor
        tensor from which to extract the device

    Returns
    -------
    "cpu" when tensor.get_device() is negative,
    otherwise the index it returned
    """
    index = tensor.get_device()
    if index < 0:
        return "cpu"
    return index
20 |
21 |
def to_same_device(*tensors, device=None):
    """Move all tensors to the same device (by default the first tensor's one)

    Parameters
    ----------
    tensors: Tensor
        tensors which we want to move to a common device
    device: device, optional
        target device; when None, the device of the first tensor is used

    Returns
    -------
    a tuple of tensors in the same order as the arguments moved to a common device
    (an empty tuple when no tensor is given)
    """
    # Without tensors there is nothing to move, and no first tensor
    # to take a default device from.
    if not tensors:
        return ()
    if device is None:
        device = get_device(tensors[0])
    return tuple(tensor.to(device) for tensor in tensors)
37 |
--------------------------------------------------------------------------------
/disco/tuners/losses/reverse_chi_square.py:
--------------------------------------------------------------------------------
1 | # disco
2 | # Copyright (C) 2022-present NAVER Corp.
3 | # Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license
4 |
5 | import torch
6 | from .f_divergence import FDivergenceLoss
7 |
class ReverseChiSquaredLoss(FDivergenceLoss):
    """
    Reverse Chi Squared divergence χ²(π || p) loss for DPG
    """

    def __init__(self, use_baseline=True, baseline_window_size=1024):
        """
        Parameters
        ----------
        use_baseline: boolean
            use a baseline to reduce variance
        baseline_window_size: int
            forwarded unchanged to FDivergenceLoss
        """
        super().__init__(use_baseline, baseline_window_size)

    def f_prime(self, log_t):
        """
        Computes the f' term for the reverse (Neyman) chi-square divergence,
        that is 2 * (exp(log_t) - 1).

        Parameters
        ----------
        log_t: 0-dim Tensor
            The log ratio of the policy and the normalized target distribution
        """
        ratio_minus_one = torch.exp(log_t) - 1
        return ratio_minus_one * 2
--------------------------------------------------------------------------------
/disco/tuners/loggers/base.py:
--------------------------------------------------------------------------------
1 | # disco
2 | # Copyright (C) 2022-present NAVER Corp.
3 | # Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license
4 |
class BaseTunerObserver(object):
    """
    Base class of the tuner observers: enrolls one no-op handler per
    tuner observable; subclasses override the handlers they care about.
    Usable as a context manager (entering returns the observer).
    """

    # names of the tuner observables; each one is bound to the
    # handler method named "on_" + <observable name>
    _EVENTS = (
        "parameters_updated",
        "metric_updated",
        "step_idx_updated",
        "ministep_idx_updated",
        "eval_samples_updated",
    )

    def __init__(self, tuner):
        """
        Parameters
        ----------
        tuner: object
            exposes the observables listed in _EVENTS, each with an enroll() method
        """
        for event in self._EVENTS:
            getattr(tuner, event).enroll(getattr(self, "on_" + event))

    def __enter__(self):
        return self

    def __exit__(self, type, value, traceback):
        # nothing to release; exceptions are not suppressed
        return None

    def on_parameters_updated(self, params):
        """No-op handler for parameters_updated."""

    def on_metric_updated(self, name, value):
        """No-op handler for metric_updated."""

    def on_step_idx_updated(self, s):
        """No-op handler for step_idx_updated."""

    def on_ministep_idx_updated(self, s):
        """No-op handler for ministep_idx_updated."""

    def on_eval_samples_updated(self, context, samples, proposal_log_scores, model_log_scores, target_log_scores):
        """No-op handler for eval_samples_updated."""
33 |
--------------------------------------------------------------------------------
/disco/tuners/dpg_tuner.py:
--------------------------------------------------------------------------------
1 | # disco
2 | # Copyright (C) 2022-present NAVER Corp.
3 | # Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license
4 |
5 | from .cdpg_tuner import CDPGTuner
6 | from disco.distributions.single_context_distribution import SingleContextDistribution
7 |
class DPGTuner(CDPGTuner):
    """
    DPG tuning class,
    i.e. CDPG restricted to one single, fixed, context.

    The algorithm has been introduced in
    "Distributional Reinforcement Learning for Energy-Based Sequential Models"
    Tetiana Parshakova, Jean-Marc Andreoli, Marc Dymetman
    https://arxiv.org/abs/1912.08517
    """

    def __init__(self, *args, context="", **kwargs):
        """
        Parameters
        ----------
        context: text
            a single textual sequence to contextualize the sampling from the proposal
        """
        # the fixed context is wrapped in a distribution, and a single
        # context per step is requested from CDPGTuner
        fixed_context = SingleContextDistribution(context)
        super().__init__(
            *args,
            context_distribution=fixed_context,
            n_contexts_per_step=1,
            **kwargs
        )
33 |
--------------------------------------------------------------------------------
/disco/metrics/base.py:
--------------------------------------------------------------------------------
1 | # disco
2 | # Copyright (C) 2022-present NAVER Corp.
3 | # Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license
4 |
5 | import torch
6 |
7 | from disco.utils.device import get_device
8 |
class BaseDivergence:
    """
    Base class of the divergence metrics: averages the pointwise
    estimates supplied by ``cls.pointwise_estimates``, which is not
    defined here and is expected from subclasses.
    """

    @classmethod
    def divergence(cls, m1_log_scores, m2_log_scores, z, proposal_log_scores=None):
        """
        Computes an estimate of the divergence between 2 distributions,
        as the mean of the pointwise estimates

        Parameters
        ----------
        m1_log_scores: floats
            log-scores for samples according to network 1
        m2_log_scores: floats
            log-scores for samples according to network 2
        z: float
            partition function of network 1
        proposal_log_scores: floats
            log-scores for samples according to proposal (by default m2_log_scores)

        Returns
        -------
        divergence between m1 and m2
        """
        estimates = cls.pointwise_estimates(
            m1_log_scores,
            m2_log_scores,
            z,
            proposal_log_scores=proposal_log_scores,
        )
        return torch.mean(estimates)
37 |
--------------------------------------------------------------------------------
/tests/pipeline_scorer_test.py:
--------------------------------------------------------------------------------
1 | # disco
2 | # Copyright (C) 2022-present NAVER Corp.
3 | # Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license
4 |
5 | import unittest
6 |
7 | import numpy as np
8 |
9 | from disco.scorers import PipelineScorer
10 |
class PipelineScorerTest(unittest.TestCase):
    """Checks a PipelineScorer built on a sentiment-analysis pipeline."""

    def test_detection(self):
        from disco.distributions.lm_distribution import TextSample

        scorer = PipelineScorer(
            'POSITIVE',
            {
                "task": "sentiment-analysis",
                "model": "siebert/sentiment-roberta-large-english",
            },
        )

        # one clearly positive text, then one clearly negative text
        texts = [
            "This is so interesting: I love it!",
            "How can you be so depressing? I don't want to go there anymore."
        ]
        expected_sentiment = [1, -1]

        # fake samples without the tokenizations
        samples = [TextSample(list(), text) for text in texts]
        log_scores = scorer.log_score(samples, None)

        # a score above 0.5 counts as positive, below as negative
        matches = (
            expected == np.sign(np.exp(log_score) - 0.5)
            for expected, log_score in zip(expected_sentiment, log_scores)
        )
        self.assertTrue(all(matches),
                "a sentiment should be correctly classified.")


if __name__ == '__main__':
    unittest.main()
36 |
--------------------------------------------------------------------------------
/disco/tuners/cdpg_tuner.py:
--------------------------------------------------------------------------------
1 | # disco
2 | # Copyright (C) 2022-present NAVER Corp.
3 | # Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license
4 |
5 | from .tuner import Tuner
6 | from disco.distributions.single_context_distribution import SingleContextDistribution
7 | from disco.tuners.losses import *
8 |
9 |
class CDPGTuner(Tuner):
    """Contextual DPG tuning class,
    relying on a ContextDistribution and KLLoss().

    The algorithm has been introduced in
    "Controlling Conditional Language Models without Catastrophic Forgetting"
    Tomasz Korbak, Hady Elsahar, Germán Kruszewski and Marc Dymetman.
    https://proceedings.mlr.press/v162/korbak22a/korbak22a.pdf
    """

    def __init__(self, *args, context_distribution=None, loss=None, **kwargs):
        """
        Parameters
        ----------
        context_distribution: distribution
            a distribution to contextualize the sampling from the proposal;
            a new SingleContextDistribution() is built when omitted
        loss: functor object
            used to compute the loss at each step;
            a new KLLoss() is built when omitted
        """

        # Defaults are created here, per tuner: values written in the
        # signature are evaluated once, at definition time, so a single
        # distribution and a single loss (with its observable) would be
        # shared by every tuner relying on the defaults.
        if context_distribution is None:
            context_distribution = SingleContextDistribution()
        if loss is None:
            loss = KLLoss()

        super(CDPGTuner, self).__init__(
            *args,
            context_distribution=context_distribution,
            loss=loss,
            **kwargs
        )
35 |
--------------------------------------------------------------------------------
/disco/tuners/fcdpg_tuner.py:
--------------------------------------------------------------------------------
1 | # disco
2 | # Copyright (C) 2022-present NAVER Corp.
3 | # Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license
4 |
5 | from .tuner import Tuner
6 | from disco.distributions.single_context_distribution import SingleContextDistribution
7 | from disco.tuners.losses import *
8 |
9 |
class FCDPGTuner(Tuner):
    """Contextual f-DPG tuning class. The algorithm was introduced in

    "Aligning Language Models with Preferences through f-divergence Minimization."
    Dongyoung Go, Tomasz Korbak, Germán Kruszewski, Jos Rozen, Nahyeon Ryu, Marc Dymetman.
    https://arxiv.org/abs/2302.08215
    """

    def __init__(self, *args, context_distribution=None, loss=None, **kwargs):
        """
        Parameters
        ----------
        context_distribution: distribution
            a distribution to contextualize the sampling from the proposal;
            a new SingleContextDistribution() is built when omitted
        loss: functor object
            used to compute the loss at each step;
            a new JSLoss() is built when omitted
        """

        # Defaults are created here, per tuner: values written in the
        # signature are evaluated once, at definition time, so a single
        # distribution and a single loss (with its observable) would be
        # shared by every tuner relying on the defaults.
        if context_distribution is None:
            context_distribution = SingleContextDistribution()
        if loss is None:
            loss = JSLoss()

        super(FCDPGTuner, self).__init__(
            *args,
            context_distribution=context_distribution,
            loss=loss,
            **kwargs
        )
35 |
--------------------------------------------------------------------------------
/disco/tuners/fdpg_tuner.py:
--------------------------------------------------------------------------------
1 | # disco
2 | # Copyright (C) 2022-present NAVER Corp.
3 | # Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license
4 |
5 | from .tuner import Tuner
6 | from disco.distributions.single_context_distribution import SingleContextDistribution
7 | from disco.tuners.losses import *
8 |
9 |
class FDPGTuner(Tuner):
    """Unconditional f-DPG tuning class. The algorithm was introduced in

    "Aligning Language Models with Preferences through f-divergence Minimization."
    Dongyoung Go, Tomasz Korbak, Germán Kruszewski, Jos Rozen, Nahyeon Ryu, Marc Dymetman.
    https://arxiv.org/abs/2302.08215
    """

    def __init__(self, *args, context='', loss=None, **kwargs):
        """
        Parameters
        ----------
        context: text
            a single textual sequence to contextualize the sampling from the proposal
        loss: functor object
            used to compute the loss at each step;
            a new JSLoss() is built when omitted
        """

        # The default loss is created here, per tuner: a value written in
        # the signature is evaluated once, at definition time, so a single
        # loss (with its observable) would be shared by every tuner
        # relying on the default.
        if loss is None:
            loss = JSLoss()

        super(FDPGTuner, self).__init__(
            *args,
            context_distribution=SingleContextDistribution(context),
            n_contexts_per_step=1,
            loss=loss,
            **kwargs
        )
36 |
--------------------------------------------------------------------------------
/disco/scorers/scorer.py:
--------------------------------------------------------------------------------
1 | # disco
2 | # Copyright (C) 2022-present NAVER Corp.
3 | # Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license
4 |
5 | import torch
6 | import numpy as np
7 |
8 |
class Scorer():
    """
    Generic scorer, applying a per-sample function to lists of samples
    """

    def __init__(self, scoring_function):
        """
        Parameters
        ----------
        scoring_function: scoring function
            function to be used on each sample,
            called as scoring_function(sample, context)
        """
        self.scoring_function = self._broadcast(scoring_function)

    def _broadcast(self, function):
        """Lifts a per-sample function into one working on a list of
        samples, and returning the results gathered in a tensor"""
        def broadcasted_function(xs, context):
            values = [function(x, context) for x in xs]
            # goes through a numpy array before building the tensor
            return torch.tensor(np.array(values))

        return broadcasted_function

    def score(self, samples, context):
        """Computes the scores of the samples given the context,
        using the instance's scoring function

        Parameters
        ----------
        samples : list()
            the samples to score, as a list
        context: text
            context that the samples relate to

        Returns
        -------
        tensor of scores for the samples"""
        scores = self.scoring_function(samples, context)
        return scores
--------------------------------------------------------------------------------
/disco/scorers/pipeline_scorer.py:
--------------------------------------------------------------------------------
1 | # disco
2 | # Copyright (C) 2022-present NAVER Corp.
3 | # Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license
4 |
5 | import torch
6 | from transformers import pipeline
7 |
8 | from .positive_scorer import PositiveScorer
9 |
10 |
class PipelineScorer(PositiveScorer):
    """
    Scorer class relying on the pipelines from Huggingface's transformers
    """

    def __init__(self, label, params, temperature=1.0):
        """initializes a PipelineScorer's instance

        Parameters
        ----------
        label: string
            expected positive label from the pipeline
        params: dict
            keyword arguments used to build the pipeline
        temperature: float
            value the log-scores are divided by
        """

        self.label = label
        self.pipeline = pipeline(**params)
        self.temperature = temperature

    def log_score(self, samples, _):
        """computes the log-scores of the samples
        from the score the pipeline gives to the instance's label

        Parameters
        ----------
        samples : list(Sample)
            list of samples to log-score
        _ :
            unused (context)

        Returns
        -------
        tensor of log-probabilities, divided by the temperature"""

        texts = [sample.text for sample in samples]
        predictions = self.pipeline(texts, return_all_scores=True)

        label_scores = []
        for prediction in predictions:
            # each prediction lists a score per label, only ours matters
            matching = [entry["score"] for entry in prediction if self.label == entry["label"]]
            label_scores.append(matching[0])

        return torch.log(torch.tensor(label_scores).float()) / self.temperature
48 |
--------------------------------------------------------------------------------
/disco/tuners/loggers/wandb.py:
--------------------------------------------------------------------------------
1 | # disco
2 | # Copyright (C) 2022-present NAVER Corp.
3 | # Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license
4 |
5 | import wandb
6 | import os
7 | from .base import BaseTunerObserver
8 |
class WandBLogger(BaseTunerObserver):
    """
    Reports DPGTuner statistics to Weights & Biases
    """

    def __init__(self, tuner, project, name=None):
        """Constructor of a WandBLogger object

        Parameters
        ----------
        tuner: DPGTuner
            The tuner object whose statistics we want to report
        project: string
            The W&B project to which we want to report the statistics
        name: string
            The name passed to wandb.init() for the run
        """
        super(WandBLogger, self).__init__(tuner)
        run = wandb.init(project=project, name=name)
        self.run = run

    def __setitem__(self, k, v):
        """
        Report arbitrary parameter/value combinations
        """
        entry = {k: v}
        wandb.log(entry)

    def __exit__(self, *exc):
        """Finishes the W&B run"""
        self.run.finish()

    def on_parameters_updated(self, params):
        """Stores the parameters in the W&B configuration"""
        wandb.config.update(params)

    def on_metric_updated(self, name, value):
        """Logs the value of the named metric"""
        entry = {name: value}
        wandb.log(entry)

    def on_eval_samples_updated(self, context, samples, proposal_log_scores, model_log_scores, target_log_scores):
        """Logs the texts of the first 10 samples"""
        texts = [sample.text for sample in samples[:10]]
        wandb.log({"samples": texts})

    def on_step_idx_updated(self, s):
        """Logs the index of the current step"""
        wandb.log({"steps": s})
--------------------------------------------------------------------------------
/disco/metrics/tv.py:
--------------------------------------------------------------------------------
1 | # disco
2 | # Copyright (C) 2022-present NAVER Corp.
3 | # Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license
4 |
5 | import torch
6 |
7 | from disco.utils.device import get_device
8 | from .base import BaseDivergence
9 |
10 |
class TV(BaseDivergence):
    """
    Total variation distance class
    """

    @classmethod
    def pointwise_estimates(cls, m1_log_scores, m2_log_scores, z, proposal_log_scores=None):
        """
        computes the TVD between 2 distributions

        Parameters
        ----------
        m1_log_scores: floats
            log-scores for samples according to network 1
        m2_log_scores: floats
            log-scores for samples according to network 2
        z: float
            partition function of network 1,
            as a tensor or any scalar convertible to one
        proposal_log_scores: floats
            log-scores for samples according to proposal (by default m2_log_scores)

        Returns
        -------
        divergence between m1 and m2
        """

        device = get_device(m1_log_scores)

        # torch.log() below needs a tensor: convert any plain scalar
        # (float, but also int...), leaving tensors untouched
        if not torch.is_tensor(z):
            z = torch.tensor(z, device=device, dtype=m1_log_scores.dtype)

        m2_log_scores = m2_log_scores.to(device)

        if proposal_log_scores is None:
            proposal_log_scores = m2_log_scores
        else:
            proposal_log_scores = proposal_log_scores.to(device)

        normalized_m1_log_scores = m1_log_scores - torch.log(z)

        # importance ratios of both distributions w.r.t. the proposal
        m2_ratio = torch.exp(m2_log_scores - proposal_log_scores)
        m1_ratio = torch.exp(normalized_m1_log_scores - proposal_log_scores)

        return 1/2 * torch.abs(m2_ratio - m1_ratio)
50 |
--------------------------------------------------------------------------------
/disco/tuners/loggers/console.py:
--------------------------------------------------------------------------------
1 | # disco
2 | # Copyright (C) 2022-present NAVER Corp.
3 | # Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license
4 |
5 | from datetime import datetime
6 | from .base import BaseTunerObserver
7 | from tqdm.autonotebook import tqdm
8 |
9 |
class ConsoleLogger(BaseTunerObserver):
    """
    Reports DPGTuner statistics to the console
    """

    def __init__(self, tuner):
        """Constructor of a ConsoleLogger object

        Parameters
        ----------
        tuner: DPGTuner
            The tuner object whose statistics we want to report
        """
        super(ConsoleLogger, self).__init__(tuner)
        tuner.proposal_updated.enroll(self.on_proposal_updated)
        self.n = tuner.params["n_gradient_steps"]
        self.step = None

    def __exit__(self, *exc):
        """Prints the time at which the logger is exited"""
        now = datetime.now()
        stamp = now.strftime("%H:%M:%S (%Y/%m/%d)")
        print(f"finished at {stamp}")

    def on_parameters_updated(self, params):
        """Prints each parameter with its value"""
        for k in params:
            print(f"{k}: ", params[k])

    def on_eval_samples_updated(self, context, samples, proposal_log_scores, model_log_scores, target_log_scores):
        """Writes the context and the first samples with their log-scores"""
        tqdm.write(f"Context: {context}")
        tqdm.write("Samples:")
        shown = 3  # only the first samples are displayed
        rows = zip(
            samples[:shown],
            proposal_log_scores[:shown],
            model_log_scores[:shown],
            target_log_scores[:shown]
        )
        lines = [
            s.text + f"\ntarget: {t.item()} - proposal: {p.item()} - model: {m.item()}"
            for (s, p, m, t) in rows
        ]
        tqdm.write("\n".join(lines))

    def on_step_idx_updated(self, s):
        """Records and writes the index of the current step"""
        self.step = s
        tqdm.write(f"Step {s}/{self.n}")

    def on_proposal_updated(self, proposal, divergence_metric, divergence_target_new, divergence_target_old):
        """Writes the divergences reported along the update of the proposal"""
        message = (
            f"updating proposal according to {divergence_metric} divergence at step {self.step}: "
            f"{divergence_target_new} < {divergence_target_old}"
        )
        tqdm.write(message)
--------------------------------------------------------------------------------
/disco/distributions/single_context_distribution.py:
--------------------------------------------------------------------------------
1 | # disco
2 | # Copyright (C) 2022-present NAVER Corp.
3 | # Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license
4 |
5 | import torch
6 |
7 | from .distribution import Distribution
8 |
class SingleContextDistribution(Distribution):
    """
    Single context distribution class, always sampling the
    same context, to fall back to a fixed-context case.
    """

    def __init__(self, context=''):
        """
        Parameters
        ----------
        context: string
            unique context to return when sampling
        """

        self.context = context

    def log_score(self, contexts):
        """Computes log-probabilities of the contexts:
        0 when matching the instance's context, -infinity otherwise

        Parameters
        ----------
        contexts: list(str)
            list of contexts to (log-)score

        Returns
        -------
        tensor of log-probabilities
        """

        log_probabilities = []
        for candidate in contexts:
            if self.context == candidate:
                log_probabilities.append(0)
            else:
                log_probabilities.append(-float("inf"))

        return torch.tensor(log_probabilities)

    def sample(self, sampling_size=32):
        """Samples multiple copies of the instance's context

        Parameters
        ----------
        sampling_size: int
            number of contexts to sample

        Returns
        -------
        tuple of (list of texts, list of log-probabilities, all zeros)
        """

        copies = [self.context] * sampling_size
        log_probabilities = [0] * sampling_size

        return (copies, log_probabilities)
--------------------------------------------------------------------------------
/scripts/online_tuning_script.py:
--------------------------------------------------------------------------------
1 | # disco
2 | # Copyright (C) 2022-present NAVER Corp.
3 | # Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license
4 |
5 | import os, re
6 | from datetime import datetime
7 | import torch
8 |
9 | from disco.distributions import LMDistribution
10 | from disco.scorers import BooleanScorer
11 | from disco.tuners import DPGTuner
12 | from disco.tuners.loggers.console import ConsoleLogger
13 |
14 |
# settings of the experiment
word = "amazing"
incipit = "It was a cold and stormy night"
path = "models"
gpt = "gpt2" # or a larger gpt2-medium or another CLM from Transformers
dev0, dev1 = "cpu", "cpu" # "0", "1" to use GPUs
n_gradient_steps = 10 # 1000 or more for actual tuning
divergence_evaluation_interval = 2**2 # 2**4 for actual tuning?

# target: the base model constrained to sequences including the word
base = LMDistribution(model=gpt, device=dev0)


def has_word(s, c):
    """Tells if the sample's text includes the word, whatever the context"""
    return bool(re.search(f"\\b{word}\\b", s.text))


word_scorer = BooleanScorer(has_word)
target = base * word_scorer

# model to tune, hence not frozen
model = LMDistribution(model=gpt, freeze=False, device=dev1)

tuner = DPGTuner(model, target,
        context = incipit,
        features = [(word, word_scorer)],
        n_gradient_steps=n_gradient_steps,
        n_samples_per_context=2**8,
        sampling_size=2**5,
        scoring_size=2**5,
        divergence_evaluation_interval=divergence_evaluation_interval,
        n_kl_samples=2**10)
console_logger = ConsoleLogger(tuner)
tuner.tune()

samples, _ = model.sample(context=incipit)
print("rate after tuning is:")
# the context is what has_word expects as second argument, not the log-scores
print(sum([has_word(s, incipit) for s in samples]) / len(samples))

stamp = datetime.now().strftime("%Y_%m_%d-%H:%M")
# torch.save() does not create missing folders: make sure
# the tuned model is not lost after the whole tuning
os.makedirs(path, exist_ok=True)
torch.save(model, os.path.join(path, f"{word}.{stamp}.pt"))
48 |
--------------------------------------------------------------------------------
/scripts/offline_tuning_script.py:
--------------------------------------------------------------------------------
1 | # disco
2 | # Copyright (C) 2022-present NAVER Corp.
3 | # Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license
4 |
5 | import os, re
6 | from datetime import datetime
7 | import torch
8 |
9 | from disco.distributions import LMDistribution
10 | from disco.scorers import BooleanScorer
11 | from disco.tuners import DPGTuner
12 | from disco.tuners.loggers.console import ConsoleLogger
13 |
14 |
# settings of the experiment
word = "amazing"
incipit = "It was a cold and stormy night"
path = "models"
gpt = "gpt2" # or a larger gpt2-medium or another CLM from Transformers
dev0, dev1 = "cpu", "cpu" # "0", "1" to use GPUs
n_gradient_steps = 10 # 1000 or more for actual tuning
divergence_evaluation_interval = 2**2 # 2**4 for actual tuning?

# target: the base model constrained to sequences including the word
base = LMDistribution(model=gpt, device=dev0)


def has_word(s, c):
    """Tells if the sample's text includes the word, whatever the context"""
    return bool(re.search(f"\\b{word}\\b", s.text))


word_scorer = BooleanScorer(has_word)
target = base * word_scorer

# separate proposal handed to the tuner to sample from
proposal = LMDistribution(model=gpt, device=dev0)

# model to tune, hence not frozen
model = LMDistribution(model=gpt, freeze=False, device=dev1)

tuner = DPGTuner(model, target, proposal,
        context = incipit,
        features = [(word, word_scorer)],
        n_gradient_steps=n_gradient_steps,
        n_samples_per_context=2**8,
        sampling_size=2**5,
        scoring_size=2**5,
        divergence_evaluation_interval=divergence_evaluation_interval,
        n_kl_samples=2**10)
console_logger = ConsoleLogger(tuner)
tuner.tune()

samples, _ = model.sample(context=incipit)
print("rate after tuning is:")
# the context is what has_word expects as second argument, not the log-scores
print(sum([has_word(s, incipit) for s in samples]) / len(samples))

stamp = datetime.now().strftime("%Y_%m_%d-%H:%M")
# torch.save() does not create missing folders: make sure
# the tuned model is not lost after the whole tuning
os.makedirs(path, exist_ok=True)
torch.save(model, os.path.join(path, f"{word}.{stamp}.pt"))
50 |
--------------------------------------------------------------------------------
/disco/metrics/kl.py:
--------------------------------------------------------------------------------
1 | # disco
2 | # Copyright (C) 2022-present NAVER Corp.
3 | # Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license
4 |
5 | import torch
6 |
7 | from disco.utils.device import get_device
8 | from .base import BaseDivergence
9 |
10 |
class KL(BaseDivergence):
    """
    Kullback-Leibler divergence class
    """

    @classmethod
    def pointwise_estimates(cls, m1_log_scores, m2_log_scores, z, proposal_log_scores=None):
        """
        computes the KL divergence between 2 distributions

        Parameters
        ----------
        m1_log_scores: floats
            log-scores for samples according to network 1
        m2_log_scores: floats
            log-scores for samples according to network 2
        z: float
            partition function of network 1,
            as a tensor or any scalar convertible to one
        proposal_log_scores: floats
            log-scores for samples according to proposal (by default m2_log_scores)

        Returns
        -------
        divergence between m1 and m2
        """

        device = get_device(m1_log_scores)

        # torch.log() below needs a tensor: convert any plain scalar
        # (float, but also int...), leaving tensors untouched
        if not torch.is_tensor(z):
            z = torch.tensor(z, device=device, dtype=m1_log_scores.dtype)

        m2_log_scores = m2_log_scores.to(device)

        if proposal_log_scores is None:
            proposal_log_scores = m2_log_scores
        else:
            proposal_log_scores = proposal_log_scores.to(device)

        importance_ratio = torch.exp(m1_log_scores - proposal_log_scores)

        unnormalized_pointwise_estimates = importance_ratio * (m1_log_scores - m2_log_scores)
        # undefined estimates (NaN) are counted as null contributions
        unnormalized_pointwise_estimates[
            torch.isnan(unnormalized_pointwise_estimates)] = 0

        return -1 * torch.log(z) + (1 / z) * unnormalized_pointwise_estimates
56 |
--------------------------------------------------------------------------------
/disco/tuners/loggers/tensorboard.py:
--------------------------------------------------------------------------------
1 | # disco
2 | # Copyright (C) 2022-present NAVER Corp.
3 | # Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license
4 |
5 | import os
6 | from torch.utils.tensorboard import SummaryWriter
7 | from .base import BaseTunerObserver
8 | from collections import defaultdict
9 |
class TensorBoardLogger(BaseTunerObserver):
    """
    Reports DPGTuner statistics to TensorBoard
    """

    def __init__(self, tuner, **kwargs):
        """Constructor of a TensorBoardLogger object

        Parameters
        ----------
        tuner: DPGTuner
            The tuner object whose statistics we want to report
        kwargs:
            keyword arguments used to build the SummaryWriter
        """
        super(TensorBoardLogger, self).__init__(tuner)
        self.writer = SummaryWriter(**kwargs)
        # number of reports made so far under each name
        self.x_counter = defaultdict(int)

    def _next_index(self, name):
        """Returns the current counter for the name, then increments it"""
        index = self.x_counter[name]
        self.x_counter[name] = index + 1
        return index

    def __setitem__(self, k, v):
        """
        Report arbitrary parameter/value combinations
        """
        report = self.writer.add_text if isinstance(v, str) else self.writer.add_scalar
        report(k, v)

    def __exit__(self, *exc):
        """Closes the writer"""
        self.writer.close()

    def on_parameters_updated(self, params):
        """Reports the parameters as hyperparameters, without metrics"""
        self.writer.add_hparams(params, {})

    def on_metric_updated(self, name, value):
        """Reports the value of the metric at its next index"""
        x = self._next_index(name)
        self.writer.add_scalar(name, value, x)

    def on_eval_samples_updated(self, context, samples, proposal_log_scores, model_log_scores, target_log_scores):
        """Reports the first 10 samples, prefixed with the context"""
        x = self._next_index('samples')
        for sample in samples[:10]:
            self.writer.add_text('samples', context + sample.text, x)

    def on_step_idx_updated(self, s):
        """Steps are not reported"""
        pass

    def on_ministep_idx_updated(self, s):
        """Ministeps are not reported"""
        pass
59 |
--------------------------------------------------------------------------------
/disco/metrics/js.py:
--------------------------------------------------------------------------------
1 | # disco
2 | # Copyright (C) 2022-present NAVER Corp.
3 | # Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license
4 |
5 | import torch
6 | import logging
7 | from disco.utils.device import get_device
8 | from .base import BaseDivergence
9 | from .kl import KL
10 |
11 |
12 | logger = logging.getLogger(__name__)
13 |
class JS(BaseDivergence):
    """
    Jensen-Shannon divergence class.
    """

    @classmethod
    def pointwise_estimates(cls, m1_log_scores, m2_log_scores, z, proposal_log_scores=None):
        """
        Computes the JS divergence between 2 distributions,
        as the average of their KL divergences to their mixture

        Parameters
        ----------
        m1_log_scores: floats
            log-scores for samples according to network 1
        m2_log_scores: floats
            log-scores for samples according to network 2
        z: float
            partition function of network 1,
            as a tensor or any scalar convertible to one
        proposal_log_scores: floats
            log-scores for samples according to proposal (by default m2_log_scores)

        Returns
        -------
        divergence between m1 and m2
        """

        device = get_device(m1_log_scores)

        # torch.log() below needs a tensor: convert any plain scalar
        # (float, but also int...), leaving tensors untouched
        if not torch.is_tensor(z):
            z = torch.tensor(z, device=device, dtype=m1_log_scores.dtype)

        m2_log_scores = m2_log_scores.to(device)
        normalized_m1_log_scores = m1_log_scores - torch.log(z)

        # the default has to be set here: left to None, KL would
        # fall back to its own second argument, that is the mixture
        if proposal_log_scores is None:
            proposal_log_scores = m2_log_scores

        # log of the mixture (m1 + m2) / 2, shifting by the max
        # before exponentiating in double precision
        max_log_scores = torch.max(normalized_m1_log_scores, m2_log_scores)
        shifted_mixture = (
            (normalized_m1_log_scores - max_log_scores).double().exp()
            + (m2_log_scores - max_log_scores).double().exp()
        ) / 2
        m_log_scores = max_log_scores + torch.log(shifted_mixture).float()

        # both distributions are normalized at this point, hence z = 1
        divergence = KL.pointwise_estimates(normalized_m1_log_scores, m_log_scores, torch.as_tensor(1), proposal_log_scores) / 2 + \
                KL.pointwise_estimates(m2_log_scores, m_log_scores, torch.as_tensor(1), proposal_log_scores) / 2
        return divergence
55 |
--------------------------------------------------------------------------------
/disco/tuners/loggers/neptune.py:
--------------------------------------------------------------------------------
1 | # disco
2 | # Copyright (C) 2022-present NAVER Corp.
3 | # Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license
4 |
5 | import neptune
6 | import os
7 | from .base import BaseTunerObserver
8 |
def get_proxies():
    """Collects the HTTP(S) proxies set in the environment

    For each scheme, the lowercase variable (e.g. http_proxy) is
    looked up first, then the uppercase one (e.g. HTTP_PROXY).

    Returns
    -------
    dict mapping "http" and/or "https" to their proxy, when set
    """
    proxies = {}
    for scheme in ("http", "https"):
        variable = f"{scheme}_proxy"
        proxy = os.getenv(variable) or os.getenv(variable.upper())
        if proxy:
            proxies[scheme] = proxy
    return proxies
18 |
class NeptuneLogger(BaseTunerObserver):
    """
    Reports DPGTuner statistics to Neptune
    """

    def __init__(self, tuner, project, name=None, api_token=None, **kwargs):
        """Constructor of a NeptuneLogger object

        Parameters
        ----------
        tuner: DPGTuner
            The tuner object whose statistics we want to report
        project: string
            The Neptune project to which we want to report the statistics
        name: string
            The name passed to neptune.init_run() for the run
        api_token: string
            The Neptune API token
        kwargs:
            further keyword arguments for neptune.init_run();
            proxies default to those found in the environment
        """
        super(NeptuneLogger, self).__init__(tuner)
        if 'proxies' not in kwargs:
            kwargs['proxies'] = get_proxies()
        run = neptune.init_run(
            project=project,
            name=name,
            api_token=api_token,
            **kwargs
        )
        self.run = run

    def __setitem__(self, k, v):
        """
        Report arbitrary parameter/value combinations
        """
        self.run[k] = v

    def __exit__(self, *exc):
        """Stops the Neptune run"""
        self.run.stop()

    def on_parameters_updated(self, params):
        """Stores the parameters in the run"""
        self.run["parameters"] = params

    def on_metric_updated(self, name, value):
        """Logs the value of the named metric"""
        series = self.run[name]
        series.log(value)

    def on_eval_samples_updated(self, context, samples, proposal_log_scores, model_log_scores, target_log_scores):
        """Logs the texts of the first 10 samples"""
        texts = [sample.text for sample in samples[:10]]
        self.run["samples"].log(texts)

    def on_step_idx_updated(self, s):
        """Logs the index of the current step"""
        self.run["steps"].log(s)
--------------------------------------------------------------------------------
/disco/scorers/exponential_scorer.py:
--------------------------------------------------------------------------------
1 | # disco
2 | # Copyright (C) 2022-present NAVER Corp.
3 | # Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license
4 |
5 | import torch
6 | import numpy as np
7 |
8 | from .positive_scorer import PositiveScorer
9 | from disco.utils.device import get_device
10 |
11 |
class ExponentialScorer(PositiveScorer):
    """Exponential scorer to add distributional constraints
    when building an EBM.
    """

    def __init__(self, features, coefficients):
        """
        Parameters
        ----------
        features: list(Scorer)
            scoring features
        coefficients: list(float)
            features' coefficients, as a list, a tuple,
            a numpy array or a tensor

        Raises
        ------
        ValueError
            when there are not as many coefficients as features
        TypeError
            when the coefficients are not in a tensor(able) structure
        """

        if not len(features) == len(coefficients):
            raise ValueError("there should be as many coefficients as there are features.")

        self.features = features
        # isinstance() rather than exact types, so that tuples
        # and subclasses (of arrays, of tensors) are accepted too
        if isinstance(coefficients, (list, tuple, np.ndarray)):
            self.coefficients = torch.tensor(coefficients)
        else:
            self.coefficients = coefficients

        if not isinstance(self.coefficients, torch.Tensor):
            raise TypeError("coefficients should come in a tensor, or a tensorable structure.")

    def log_score(self, samples, context):
        """Log-scores the samples given the context
        using the instance's features and their coefficients

        Parameters
        ----------
        samples : list(str)
            list of samples to log-score
        context: text
            context used for the samples

        Returns
        -------
        tensor of log-scores"""

        device = get_device(self.coefficients)

        feature_scores = torch.stack(
            [feature.score(samples, context).to(device) for feature in self.features]
        ) # [n_features, n_samples]

        # the coefficients are broadcast over the samples
        weighted_scores = self.coefficients * feature_scores.t() # [n_samples, n_features]

        return weighted_scores.sum(dim=1)

    def __str__(self):
        return f"ExponentialScorer({self.features}, {self.coefficients})"
65 |
--------------------------------------------------------------------------------
/disco/distributions/context_distribution.py:
--------------------------------------------------------------------------------
1 | # disco
2 | # Copyright (C) 2022-present NAVER Corp.
3 | # Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license
4 |
5 | import torch
6 | import numpy as np
7 | from random import sample
8 |
9 | from .distribution import Distribution
10 |
class ContextDistribution(Distribution):
    """
    Context distribution class, fetching the contexts from a text file.
    It can be used as a template for other context distributions.
    """

    def __init__(self, path="contexts.txt"):
        """
        Parameters
        ----------
        path: string
            path to context file, read as one context per line
        """

        try:
            with open(path) as f:
                self.contexts = f.readlines()
        except IOError:
            # a file that cannot be read is handled like an empty one
            self.contexts = list()

        assert self.contexts, "there's an issue with the context file provided."

    def log_score(self, contexts):
        """Computes log-probabilities of the contexts,
        from their frequencies among the instance's contexts

        Parameters
        ----------
        contexts: list(str)
            list of contexts to (log-)score

        Returns
        -------
        tensor of logprobabilities
        """

        assert contexts, "there needs to be contexts to (log-)score."

        # frequencies are relative to the known contexts,
        # not to the contexts being scored
        n_known_contexts = len(self.contexts)

        log_probabilities = []
        for context in contexts:
            occurrences = self.contexts.count(context)
            if occurrences:
                log_probabilities.append(np.log(occurrences / n_known_contexts))
            else:
                log_probabilities.append(-float("inf"))

        return torch.tensor(log_probabilities)

    def sample(self, sampling_size=32):
        """Samples random elements from the list of contexts

        Parameters
        ----------
        sampling_size: int
            number of contexts to sample

        Returns
        -------
        tuple of (list of texts, tensor of logprobs)
        """

        assert len(self.contexts) >= sampling_size, "the contexts does not have enough elements to sample."

        contexts = sample(self.contexts, sampling_size)
        return (contexts, self.log_score(contexts))
--------------------------------------------------------------------------------
/disco/distributions/dataset_context_distribution.py:
--------------------------------------------------------------------------------
1 | # disco
2 | # Copyright (C) 2022-present NAVER Corp.
3 | # Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license
4 |
5 | import torch
6 | import numpy as np
7 | from random import sample
8 | from datasets import load_dataset
9 |
10 | from .distribution import Distribution
11 |
class DatasetContextDistribution(Distribution):
    """
    Context distribution class, fetching the contexts
    from a dataset of Hugging Face's Datasets.
    """

    def __init__(self, dataset="", subset="", split="train", key="text", prefix=""):
        """
        Parameters
        ----------
        dataset: string
            name of dataset in Hugging Face's Datasets
        subset: string
            reference of subset in dataset
        split: string
            reference of split in dataset/subset
        key: string
            key to use on row to pick the relevant part
        prefix: text
            text prepended to each context
        """

        try:
            loaded = load_dataset(dataset, subset, split=split)
        except IOError:
            # a dataset that cannot be loaded is handled like an empty one
            loaded = list()
        self.dataset = loaded

        assert self.dataset, "there's an issue with the parameters of the dataset."

        self.key = key
        self.prefix = prefix

    def log_score(self, contexts):
        """Computes plausible log-probabilities of the contexts.
        Note that there's no check that the context are part of the dataset,
        hence the plausible qualifier.

        Parameters
        ----------
        contexts: list(str)
            list of contexts to (log-)score

        Returns
        -------
        tensor of logprobabilities
        """

        assert contexts, "there needs to be contexts to (log-)score."

        # every context is given the probability of a row of the dataset
        row_probability = 1 / self.dataset.num_rows
        probabilities = torch.full((len(contexts), ), row_probability)

        return torch.log(probabilities)

    def sample(self, sampling_size=32):
        """Samples random elements from the dataset

        Parameters
        ----------
        sampling_size: int
            number of contexts to sample

        Returns
        -------
        tuple of (list of texts, tensor of logprobs)
        """

        assert self.dataset.num_rows >= sampling_size, "the dataset does not have enough elements to sample."

        indices = sample(range(self.dataset.num_rows), sampling_size)
        contexts = []
        for row in self.dataset.select(indices):
            contexts.append(self.prefix + row[self.key])

        return (contexts, self.log_score(contexts))
--------------------------------------------------------------------------------
/tests/exponential_scorer_test.py:
--------------------------------------------------------------------------------
1 | # disco
2 | # Copyright (C) 2022-present NAVER Corp.
3 | # Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license
4 |
5 | import unittest
6 |
7 | import numpy as np
8 |
9 | from disco.scorers import ExponentialScorer
10 | from disco.scorers import BooleanScorer
11 |
12 |
13 | rain = lambda s, c: "rain" in s.text
14 | city = lambda s, c: "city" in s.text
15 |
class ExponentialScorerTest(unittest.TestCase):
    """Tests of the construction and of the scoring of ExponentialScorer"""

    def test_features_and_coefficients_match(self):
        scorer = ExponentialScorer([BooleanScorer(rain), BooleanScorer(city)], [0.5, 0.25])
        self.assertTrue(hasattr(scorer, "features"),
            "the exponential scorer should have a features attribute.")
        self.assertTrue(hasattr(scorer, "coefficients"),
            "the exponential scorer should have a coefficients attribute.")
        self.assertEqual(len(scorer.features), len(scorer.coefficients),
            "the length of both features and coefficients list should be equal.")

    def test_features_and_coefficients_mismatch(self):
        with self.assertRaises(ValueError):
            ExponentialScorer(
                [BooleanScorer(rain)],
                [0.5, 0.25]
            )

    def test_coefficients_as_tensor_like(self):
        with self.assertRaises(TypeError):
            ExponentialScorer(
                [BooleanScorer(rain)],
                0.5
            )
        with self.assertRaises(TypeError):
            ExponentialScorer(
                [BooleanScorer(rain), BooleanScorer(city)],
                {"rain": 0.5, "city": 0.25}
            )

    def test_score(self):
        scorer = ExponentialScorer([BooleanScorer(rain), BooleanScorer(city)], [0.5, 0.25])
        # one text per case: rain only, city only, both, none
        # (mind the commas, adjacent strings get silently concatenated)
        texts = [
                "I'm singing in the rain.",
                "What is the city but the people?",
                "The rain that fell on the city runs down the dark gutters and empties into the sea without even soaking the ground",
                "Every drop in the ocean counts."
            ]
        expected_log_scores = [0.5, 0.25, 0.75, 0]

        from disco.distributions.lm_distribution import TextSample
        samples = [TextSample(list(), t) for t in texts] # fake samples without the tokenizations

        scores = scorer.score(samples, None)
        self.assertEqual(len(samples), len(scores),
            "there should be a score for each sample.")
        log_scores = scorer.log_score(samples, None)
        self.assertEqual(len(samples), len(log_scores),
            "there should be a (log-)score for each sample.")
        # zip() would silently skip the expectations without a matching sample
        self.assertEqual(len(expected_log_scores), len(log_scores),
            "there should be an expected (log-)score for each sample.")
        for e, s in zip(expected_log_scores, log_scores):
            self.assertEqual(e, s,
                "the exponential scorer should (log-)score correctly."
            )
68 |
69 |
70 | if __name__ == '__main__':
71 | unittest.main()
--------------------------------------------------------------------------------
/tests/quasi_rejection_sampler_test.py:
--------------------------------------------------------------------------------
1 | # disco
2 | # Copyright (C) 2022-present NAVER Corp.
3 | # Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license
4 |
5 | import unittest
6 |
7 | from disco.distributions import LMDistribution
8 | from disco.scorers import BooleanScorer
9 | from disco.samplers import QuasiRejectionSampler
10 | from disco.samplers.quasi_rejection_sampler import QuasiRejectionSamplerEstimator
11 | from disco.metrics import KL, TV
12 | from scipy import stats
13 | import torch
14 |
class QuasiRejectionSamplerTest(unittest.TestCase):
    """Tests of the quasi-rejection sampler and of its estimator"""

    def test_sample(self):
        prefix = "It was a cold and stormy"
        word = "night"
        target = LMDistribution() * BooleanScorer(lambda s, c: word in s.text)
        proposal = LMDistribution()

        sampler = QuasiRejectionSampler(target, proposal)
        samples, _ = sampler.sample(sampling_size=32, context=prefix)

        respecting = [word in s.text for s in samples]
        self.assertTrue(all(respecting),
            "sampled sequences should respect the constraints.")

        self.assertGreater(sampler.get_acceptance_rate(), 0.,
            "that sampling should probably return at least a sample.")

    def test_estimator(self):
        proposal = PoissonDistribution(10)
        target = PoissonDistribution(11)
        estimator = QuasiRejectionSamplerEstimator(target, proposal, n_estimation_samples=int(10e4))

        # (beta, expected value) pairs, checked in that order
        expected_acceptance_rates = [(0.5, 0.999), (1, 0.877), (2, 0.5)]
        for beta, expected in expected_acceptance_rates:
            ar = estimator.acceptance_rate_at_beta(beta)
            self.assertAlmostEqual(ar, expected, places=2)

        expected_tvds = [(0.5, 0.12), (1, 0.075), (2, 0.0041)]
        for beta, expected in expected_tvds:
            tvd = estimator.divergence_at_beta(beta, divergence=TV)
            self.assertAlmostEqual(tvd, expected, places=2)

        expected_kls = [(1, 0.021), (0.5, 0.046), (2, 0.0)]
        for beta, expected in expected_kls:
            kl = estimator.divergence_at_beta(beta, divergence=KL)
            self.assertAlmostEqual(kl, expected, places=2)
62 |
class PoissonDistribution:
    """Poisson distribution exposing the sample() and log_score()
    methods, standing in for a distribution in the tests"""

    def __init__(self, lam):
        """
        Parameters
        ----------
        lam: float
            rate of the Poisson distribution
        """
        self.lam = lam

    def sample(self, sampling_size=1, context=None):
        """Draws samples, returned as a list along with
        the tensor of their log-probabilities; context is unused"""
        draws = stats.poisson.rvs(self.lam, size=sampling_size)
        samples = list(draws)
        log_scores = self.log_score(samples)
        return samples, log_scores

    def log_score(self, x, context=None):
        """Returns the tensor of the log-probabilities of x; context is unused"""
        log_pmf = stats.poisson.logpmf(x, self.lam)
        return torch.tensor(log_pmf)
74 |
75 | if __name__ == '__main__':
76 | unittest.main()
77 |
--------------------------------------------------------------------------------
/tests/boolean_scorer_test.py:
--------------------------------------------------------------------------------
1 | # disco
2 | # Copyright (C) 2022-present NAVER Corp.
3 | # Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license
4 |
5 | import unittest
6 |
7 | import numpy as np
8 |
9 | from disco.scorers import BooleanScorer
10 |
class BooleanScorerTest(unittest.TestCase):
    """
    Unit tests for the scores and log-scores returned by BooleanScorer.
    """

    def test_truism(self):
        """A predicate that always holds scores 1, i.e. 0 in log space."""
        bf = BooleanScorer(lambda x, _: True)
        items = [1, 1.2345, "blah"]
        self.assertEqual(len(items), len(bf.score(items, None)),
                "a boolean feature should score all items.")
        self.assertTrue(all(1. == s for s in bf.score(items, None)),
                "a truism should always score 1.")
        self.assertEqual(len(items), len(bf.log_score(items, None)),
                "a boolean feature should (log-)score all items.")
        self.assertTrue(all(0 == s for s in bf.log_score(items, None)),
                "a truism should always (log-)score 0.")

    def test_true_and_false(self):
        """False maps to 0 (-inf in log space), True maps to 1 (0 in log space)."""
        bf = BooleanScorer(lambda x, _: 2 < len(x))
        scores = bf.score(["a", "abc"], None)
        self.assertEqual(0, scores[0],
                "False should result in 0.")
        self.assertEqual(1, scores[1],
                "True should result in 1.")
        log_scores = bf.log_score(["a", "abc"], None)
        # np.inf rather than the np.Inf alias, which NumPy 2.0 removed
        self.assertEqual(-np.inf, log_scores[0],
                "False should result in an infinite negative in log space.")
        self.assertEqual(0, log_scores[1],
                "True should result in a zero in log space.")

    def test_lambda_predicate(self):
        """A predicate given as a lambda is applied to every sample."""
        bf = BooleanScorer(lambda x, _: 2 < len(x))
        samples = ["", "a", "ab", "abc", "abcd"]
        scores = bf.score(samples, None)
        expected_scores = [0, 0, 0, 1, 1]
        self.assertTrue(all(e == s for e, s in zip(expected_scores, scores)),
                "a predicate expressed via a lambda should score correctly the items.")
        log_scores = bf.log_score(samples, None)
        expected_log_scores = [-np.inf, -np.inf, -np.inf, 0, 0]
        self.assertTrue(all(e == s for e, s in zip(expected_log_scores, log_scores)),
                "a predicate expressed via a lambda should (log-)score correctly the items.")

    def test_fn_predicate(self):
        """A predicate given as a regular function is applied to every sample."""
        def longer_than_2(x, _):
            return 2 < len(x)
        bf = BooleanScorer(longer_than_2)
        samples = ["", "a", "ab", "abc", "abcd"]
        scores = bf.score(samples, None)
        expected_scores = [0, 0, 0, 1, 1]
        self.assertTrue(all(e == s for e, s in zip(expected_scores, scores)),
                "a predicate expressed via a regular function should score correctly the items.")
        log_scores = bf.log_score(samples, None)
        expected_log_scores = [-np.inf, -np.inf, -np.inf, 0, 0]
        self.assertTrue(all(e == s for e, s in zip(expected_log_scores, log_scores)),
                "a predicate expressed via a regular function should (log-)score correctly the items.")
63 |
# Run this module's test cases when the file is executed directly.
if __name__ == '__main__':
    unittest.main()
--------------------------------------------------------------------------------
/disco/tuners/losses/f_divergence.py:
--------------------------------------------------------------------------------
1 | # disco
2 | # Copyright (C) 2022-present NAVER Corp.
3 | # Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license
4 |
5 | import torch
6 | from .base import BaseLoss
7 | from disco.utils.moving_average import WindowedMovingAverage
8 |
class FDivergenceLoss(BaseLoss):
    """
    Generic f-divergence loss for DPG.
    The derivative f' of the chosen f is obtained through `self.f_prime`,
    which derived classes implement.
    """
    def __init__(self, use_baseline=True, baseline_window_size=1024):
        """
        Parameters
        ----------
        use_baseline: boolean
            use a baseline to reduce variance
        baseline_window_size: int
            window size given to the WindowedMovingAverage holding the baseline
        """
        super(FDivergenceLoss, self).__init__()
        self.baseline = WindowedMovingAverage(baseline_window_size) if use_baseline else None

    def __call__(self, proposal_log_probs, target_log_scores, policy_log_probs, z):
        """
        Computes the f-divergence loss on a given minibatch of samples
        ∇ loss = π(x) / q(x) * f'(π(x) / p(x)) * ∇ log π(x)
        where `f_prime` receives the ratio in log space, log π(x) - log p(x).

        Parameters
        ----------
        proposal_log_probs: 1-dim Tensor
            log-probabilities for the samples according to the proposal
        target_log_scores: 1-dim Tensor
            log-probabilities for the samples according to the target unnormalized distribution (EBM)
        policy_log_probs: 1-dim Tensor
            log-probabilities for the samples according to the policy network
        z: 0-dim Tensor
            estimation of the partition function of the target distribution

        Returns
        -------
        mean loss across the minibatch
        """
        dispatch = self.metric_updated.dispatch

        # log of the ratio between the (detached) policy and the normalized target
        normalized_target = target_log_scores - torch.log(z)
        log_ratio = policy_log_probs.detach() - normalized_target

        # f' is implemented in the derived class depending on the desired f
        pseudo_reward = -self.f_prime(log_ratio)
        for reward in pseudo_reward:
            dispatch('pseudo_reward', reward.item())

        advantage = pseudo_reward
        if self.baseline is not None:
            # the baseline is subtracted before this minibatch is added to it
            if self.baseline.value is not None:
                advantage = pseudo_reward - self.baseline.value
            self.baseline.update(pseudo_reward)
            dispatch("baseline", self.baseline.value)

        for adv in advantage:
            dispatch('advantage', adv.item())

        importance_ratios = torch.exp(policy_log_probs.detach() - proposal_log_probs)
        for ratio in importance_ratios:
            dispatch('importance_ratios', ratio.item())

        # only this last factor carries a gradient w.r.t. the policy
        weighted = importance_ratios * (-advantage) * policy_log_probs
        return weighted.mean()
--------------------------------------------------------------------------------
/tests/product_test.py:
--------------------------------------------------------------------------------
1 | # disco
2 | # Copyright (C) 2022-present NAVER Corp.
3 | # Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license
4 |
5 | import unittest
6 |
7 | import numpy as np
8 |
9 | from disco.scorers import Product
10 | from disco.scorers import BooleanScorer
11 |
class ProductFeatureTest(unittest.TestCase):
    """
    Unit tests for the combination of scorers through Product and
    through its `*` syntactic sugar.
    """

    @staticmethod
    def _random_integers():
        """
        Draws 100 integers in [0, 10].

        The range goes past 5 and 8 so that the predicates used below
        (`5 < x`, `8 > x`, ...) hold for some items and fail for others.
        `randint` replaces the deprecated `random_integers`; its upper
        bound is exclusive.
        """
        return np.random.randint(0, 11, 100)

    def test_single_feature(self):
        """A product of one scorer returns that scorer's log-scores."""
        integers = self._random_integers()
        bf = BooleanScorer(lambda x, _: 5 < x)
        bf_log_scores = bf.log_score(integers, None)
        product = Product(bf)
        pr_log_scores = product.log_score(integers, None)
        self.assertTrue(all(b == p for b, p in zip(bf_log_scores, pr_log_scores)),
                "a product of a single feature should behave like that feature.")

    def test_two_features(self):
        """A product of two scorers sums their log-scores."""
        integers = self._random_integers()
        bf1 = BooleanScorer(lambda x, _: 5 < x)
        bf1_log_scores = bf1.log_score(integers, None)
        bf2 = BooleanScorer(lambda x, _: 8 > x)
        bf2_log_scores = bf2.log_score(integers, None)
        product = Product(bf1, bf2)
        pr_log_scores = product.log_score(integers, None)
        self.assertTrue(all((b1 + b2) == p for b1, b2, p in zip(bf1_log_scores, bf2_log_scores, pr_log_scores)),
                "a product of two features should sum their respective log-scores.")

    def test_three_features(self):
        """A product of three scorers sums their log-scores."""
        integers = self._random_integers()
        bf1 = BooleanScorer(lambda x, _: 2 < x)
        bf1_log_scores = bf1.log_score(integers, None)
        bf2 = BooleanScorer(lambda x, _: 8 > x)
        bf2_log_scores = bf2.log_score(integers, None)
        bf3 = BooleanScorer(lambda x, _: 0 == x % 2)
        bf3_log_scores = bf3.log_score(integers, None)
        product = Product(bf1, bf2, bf3)
        pr_log_scores = product.log_score(integers, None)
        self.assertTrue(all((b1 + b2 + b3) == p for b1, b2, b3, p in zip(bf1_log_scores, bf2_log_scores, bf3_log_scores, pr_log_scores)),
                "a product of three features should sum their respective log-scores.")

    def test_product_is_commutative(self):
        """Swapping the operands of a product leaves the log-scores unchanged."""
        integers = self._random_integers()
        bf1 = BooleanScorer(lambda x, _: 5 < x)
        bf2 = BooleanScorer(lambda x, _: 8 > x)
        pr12 = Product(bf1, bf2)
        # operands deliberately swapped: building both products in the
        # same order would not exercise commutativity at all
        pr21 = Product(bf2, bf1)
        self.assertTrue(all(p12 == p21 for p12, p21 in zip(pr12.log_score(integers, None), pr21.log_score(integers, None))),
                "a product of two features should be commutative.")

    def test_product_has_a_sugar(self):
        """`bf1 * bf2` sums the log-scores like an explicit Product."""
        integers = self._random_integers()
        bf1 = BooleanScorer(lambda x, _: 5 < x)
        bf1_log_scores = bf1.log_score(integers, None)
        bf2 = BooleanScorer(lambda x, _: 8 > x)
        bf2_log_scores = bf2.log_score(integers, None)
        product = bf1 * bf2
        pr_log_scores = product.log_score(integers, None)
        self.assertTrue(all((b1 + b2) == p for b1, b2, p in zip(bf1_log_scores, bf2_log_scores, pr_log_scores)),
                "a product also works when defined through its syntactic sugar.")
66 |
67 |
# Run this module's test cases when the file is executed directly.
if __name__ == '__main__':
    unittest.main()
--------------------------------------------------------------------------------
/disco/tuners/loggers/json.py:
--------------------------------------------------------------------------------
1 | # disco
2 | # Copyright (C) 2022-present NAVER Corp.
3 | # Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license
4 |
5 | import os
6 | from .base import BaseTunerObserver
7 | import json
8 | from pathlib import Path
9 | import torch
10 | import logging
11 | from collections import defaultdict
12 |
13 | logger = logging.getLogger(__name__)
14 |
class JSONLogger(BaseTunerObserver):
    """
    Reports DPGTuner statistics to a JSON file
    """

    def __init__(self, tuner, project, name, path=None,
            save_steps=1, store_eval_samples=False, **kwargs):
        """Constructor of a JSONLogger object

        Parameters
        ----------
        tuner: DPGTuner
            The tuner object whose statistics we want to report
        project: string
            The subfolder where to store the logs
        name: string
            The filename to which we want to report the statistics
        path: string/Path
            The path where we want to store the logs; when omitted, the
            DISCO_SAVE_PATH environment variable is used
        save_steps: integer
            Number of gradient steps every which to write the json data to disk
        store_eval_samples: boolean
            Whether or not to store the samples in the json file
        **kwargs:
            accepted and ignored

        Raises
        ------
        KeyError
            if no path is given and DISCO_SAVE_PATH is not set
        """
        super(JSONLogger, self).__init__(tuner)
        if path is None:
            # Read the environment when the logger is built, not in the
            # signature: a default evaluated at definition time makes the
            # mere import of this module fail when the variable is unset.
            try:
                path = os.environ['DISCO_SAVE_PATH']
            except KeyError:
                raise KeyError(
                    "JSONLogger needs a `path` argument or the "
                    "DISCO_SAVE_PATH environment variable") from None
        self.filename = Path(path) / project / f"{name}.json"
        self.filename.parent.mkdir(parents=True, exist_ok=True)
        self.data = defaultdict(list)
        self.save_steps = save_steps
        self.store_eval_samples = store_eval_samples

    def __exit__(self, *exc):
        """Writes the collected data one last time."""
        self.save()

    def save(self):
        """Overwrites the json file with everything collected so far."""
        with open(self.filename, 'w') as fout:
            json.dump(self.data, fout)

    def __setitem__(self, k, v):
        """
        Report arbitrary parameter/value combinations
        """
        if isinstance(v, Path):
            v = str(v) # avoid serialization error
        elif isinstance(v, torch.Tensor):
            v = v.item()
        self.data[k] = v

    def on_parameters_updated(self, params):
        """Stores a dict copy of the parameters under "parameters"."""
        self.data["parameters"] = dict(params)

    def on_metric_updated(self, name, value):
        """Appends the value (tensors are converted with .item()) to the metric's history."""
        if isinstance(value, torch.Tensor):
            value = value.item()
        # self.data is a defaultdict(list): a new metric starts from an empty list
        self.data[name].append(value)

    def on_eval_samples_updated(self, context, samples, proposal_log_scores, model_log_scores, target_log_scores):
        """
        Appends the texts, token ids and scores of the evaluation samples,
        but only when store_eval_samples was requested.
        """
        if not self.store_eval_samples:
            return
        self.data["samples"].append([s.text for s in samples])
        self.data["samples_ids"].append([s.token_ids.tolist() for s in samples])
        self.data["proposal_scores"].append(proposal_log_scores.tolist())
        self.data["target_scores"].append(target_log_scores.tolist())
        self.data["model_scores"].append(model_log_scores.tolist())

    def on_step_idx_updated(self, s):
        """Records the step index and saves to disk every save_steps steps."""
        self.data["steps"] = s
        if self.save_steps > 0 and (s % self.save_steps) == 0:
            self.save()

    def on_ministep_idx_updated(self, s):
        """Records the ministep index."""
        self.data["ministeps"] = s
89 |
--------------------------------------------------------------------------------
/tutorials/data/incipits.txt:
--------------------------------------------------------------------------------
1 | It was the best of times, it was the worst of times.
2 | All happy families are alike; each unhappy family is unhappy in its own way.
3 | Alas, poor Yorick! I knew him, Horatio.
4 | Call me Ishmael.
5 | It was a dark and stormy night.
6 | I'm going to kill myself, said Ambrose Bierce.
7 | Mrs. Dalloway said she would buy the flowers herself.
8 | A sound came out of the darkness — a frightful, indescribable sound.
9 | So we beat on, boats against the current, borne back ceaselessly into the past.
10 | You don’t know about me without you have read a book by the name of The Adventures of Tom Sawyer; but that ain’t no matter.
11 | “Haleakala,” said the oldest Hawaiian on the island, “is the house of the sun.”
12 | All this happened, more or less.
13 | It was a pleasure to burn.
14 | There was a man named Guy Montag, and fire was his job.
15 | We didn't start the fire.
16 | Everything was blue then, my kid, blue shadows and blue smoke.
17 | It was a bright cold day in April, and the clocks were striking thirteen.
18 | If you really want to hear about it, the first thing you’ll probably want to know is where I was born, and what my lousy childhood was like, and how my parents were occupied and all before they had me, and all that David Copperfield kind of crap, but I don’t feel like going into it, if you want to know the truth.
19 | I am an invisible man.
20 | I am a sick man... I am a spiteful man. I am an unattractive man.
21 | In this respect, I did not differ from millions of other fictitious characters created by other fictitious writers.
22 | You better not never tell nobody but God.
23 | I was born twice.
24 | There was me, that is Alex, and my three droogs, that is Pete, Georgie, and Dim, and we sat in the Korova Milkbar.
25 | It was a queer, sultry summer, the summer they electrocuted the Rosenbergs, and I didn’t know anybody named Rosenberg and wasn’t elected but I was sore as hell.
26 | The last time I saw my father, he was washing his hands in the bathroom sink of a Greyhound bus.
27 | There's a bluebird in my heart that wants to get out but I pour whiskey on him.
28 | You don’t have to live forever, you just have to live.
29 | We the animals crouched in our cages, waiting.
30 | It was probably the color of the sea that decided me.
31 | All children, except one, grow up.
32 | You’ll have to come a long way with me, dear reader, before you hear that again.
33 | Whether I shall turn out to be the hero of my own life, or whether that station will be held by anybody else, these pages must show.
34 | It was the best of times, it was the worst of times, it was the age of wisdom, it was the age of foolishness.
35 | In an old house in Paris that was covered with vines lived twelve little girls in two straight lines.
36 | There was a boy called Eustace Clarence Scrubb, and he almost deserved it.
37 | My father’s family name being Pirrip, and my Christian name Philip, my infant tongue could make of both names nothing longer or more explicit than Pip.
38 | It was a dark and stormy night; the rain fell in torrents, except at occasional intervals, when it was checked by a violent gust of wind which swept up the streets.
39 | “You fancy yourself obedient, don’t you, slave?” She looked up into the cold Hardwick eyes.
40 | As Gregor Samsa awoke one morning from uneasy dreams he found himself transformed in his bed into a gigantic insect.
41 | Marley was dead: to begin with.
42 | There was no possibility of taking a walk that day.
43 | It was a queer, sultry summer, the summer they electrocuted the Rosenbergs, and I didn’t know anybody named Rosenberg
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | # disco
2 | # Copyright (C) 2022-present NAVER Corp.
3 | # Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license
4 |
5 | from setuptools import setup, find_packages
6 |
# Distribution metadata for the `disco-generation` package.
setup(
    name='disco-generation',
    version='1.1.2',
    description='A toolkit for distributional control of generative models',
    url='https://github.com/naver/disco',
    author='Naver Labs Europe', author_email='jos.rozen@naverlabs.com',
    license='Creative Commons Attribution-NonCommercial-ShareAlike 4.0',
    # Markdown text (see long_description_content_type below)
    long_description="""The 🕺🏽 **disco** toolkit allows to control the properties of the generations by language models and other generative systems to match human preferences while avoiding catastrophic forgetting.

To achieve this in **disco**, we first represent in what ways we want to update original model as a target distribution and then, generate samples from this new distribution through a combination of learning or monte-carlo methods, as follows.

**Step 1: We express how the target distribution *should* be**

To have a handle on the generative model, we define some feature over the generated samples. It can be anything we can compute. For example, on a language model it can be as simple as whether the generated text contains a certain word or as complex as the compilability of some generated piece of code. Importantly, there is no need for the feature to be differentiable.
Then, we can express our preferences on the target distribution by defining the target *moments* of this feature. For example, we might want to ask that a certain word appears 50% of the time when sampling from the model; or that 100% of the generated code is compilable. The resulting target distribution is expressed as an energy-based model or EBM, which is an unnormalized probability distribution that respects the desired moments while avoiding catastrophic forgetting, as a result of having minimal KL divergence to the original model.
This representation of the target distribution can *score* samples, but cannot directly be used to *generate* them.

**Step 2: We generate samples from the target distribution**

To generate samples from the target distribution, if not perfectly, we can tune a model to approximate it. The resulting model can generate samples directly from a close approximation of the target distribution. Furthermore, it can be used jointly with Quasi-Rejection Sampling (QRS), a Monte Carlo sampling technique that allows the generation of samples that are even more representative of the target distribution.
Alternatively, it is then possible to use decoding methods such as nucleus sampling, top-k sampling, or beam search, which would return samples from a further updated target distribution.""",
    long_description_content_type='text/markdown',
    # only the `disco` package and its subpackages are shipped
    packages=find_packages(include=['disco', 'disco.*']),
    # NOTE(review): confirm that every dependency below (e.g. transformers>=4.49)
    # still supports the oldest Python version advertised here
    python_requires='>=3.8',
    install_requires=['torch', 'transformers>=4.49',
        'numpy', 'scipy',
        'datasets', 'spacy',
        'notebook',
        'neptune-client', 'wandb'],
    classifiers=[
        'Development Status :: 5 - Production/Stable',
        'Intended Audience :: Science/Research',
        'License :: Free for non-commercial use',
        'Operating System :: POSIX :: Linux',
        'Operating System :: MacOS',
        'Operating System :: Microsoft :: Windows',
        'Programming Language :: Python :: 3',
        'Programming Language :: Python :: 3.8',
        'Programming Language :: Python :: 3.9',
        'Programming Language :: Python :: 3.10',
    ],
)
49 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | disco, Copyright (C) 2025 NAVER Corporation. All Rights Reserved.
2 |
3 | You must agree to the terms of this license in order to install and use the “Materials” associated with this project, which may include source code, executable code, models, model checkpoints, and data, together with any documentation and any updates provided at Naver’s discretion. By exercising any rights to the Materials, you accept and agree to be bound by the terms of this license. If you are entering into this license on behalf of a company or other entity, you represent that you are the employee or agent of such company (or other entity) and you have the authority to enter into this license on behalf of such company (or other entity). The Materials are protected by copyright and other intellectual property laws and is licensed, not sold.
4 |
5 | Non-Commercial License
6 |
7 | Subject to any LICENSE EXCEPTIONS, NAVER Corporation (“NAVER”) hereby grants you a non-exclusive, non-sublicensable, non-transferable license to use the Materials, subject to the following conditions:
8 |
9 | (1) SCOPE OF USE: The Materials are used solely for non-commercial purposes (“Purpose”). You may not use the Materials or derivatives thereof for any commercial purpose (i.e., primarily intended for or directed towards commercial advantage or monetary compensation). You may not distribute the Materials or derivatives thereof under different terms and conditions as this License.
10 |
11 | (2) COPYRIGHT: The above copyright notice and this License along with the disclaimer below shall be retained in all copies and derivatives.
12 |
13 | (3) TERM: The License automatically terminates without notice if you fail to comply with its terms or the Purpose no longer exists. You may terminate this License at any time by ceasing to use the Materials. Upon termination you agree to delete any and all copies of the Materials and derivatives. The license to any of your Contributions under (4) will survive termination.
14 |
15 | (4) CONTRIBUTIONS: If you contribute to the project by providing feedback (“Contributions”) by, for example, making comments or a pull request, you agree to grant, and hereby grant, NAVER, without any restrictions or limitations, a non-exclusive, perpetual, irrevocable, royalty-free, paid-up, assignable and sub-licensable license, to reproduce, publicly perform or display, install, use, modify, adapt, prepare derivative works of, post, distribute, make and have made, sell and transfer your Contributions, and derivative works thereof, for any purpose. This grant by you does not change your rights to use your Contributions. Your Contributions may be used to update the Materials at Naver’s discretion.
16 |
17 | (5) NO IMPLIED LICENSE: Except as otherwise expressly stated in this License, nothing herein shall be construed to grant you any license, by implication, estoppel, or otherwise, to any intellectual property of NAVER, including trademarks, copyrights, patents, or trade secrets.
18 |
19 | (6) LIMITATION OF LIABILITY: THE MATERIALS ARE PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL NAVER BE LIABLE FOR ANY CLAIM, DAMAGES (INCLUDING, BUT NOT LIMITED TO LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT (INCLUDING NEGLIGENCE), STRICT LIABILITY OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE MATERIALS OR THE USE OR OTHER DEALINGS IN THE MATERIALS, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
20 |
21 | LICENSE EXCEPTIONS: If the Materials include subcomponents or dependencies with separate copyright notices or license terms, they will be set forth in the file NOTICE.txt.
22 |
23 |
--------------------------------------------------------------------------------
/disco/utils/moving_average.py:
--------------------------------------------------------------------------------
1 | # disco
2 | # Copyright (C) 2022-present NAVER Corp.
3 | # Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license
4 |
5 | import numpy as np
6 | import torch
7 |
def average(moving_averages):
    """
    Computes the weighted mean of a collection of moving averages.

    Parameters
    ----------
    moving_averages: dict
        maps arbitrary keys to objects exposing `value` and `weight`
        attributes (e.g. MovingAverage instances)

    Returns
    -------
    the sum of value * weight divided by the sum of the weights, or 0
    when there is nothing to average (no entries, or only entries that
    were never updated)
    """
    weighted_sum = 0
    total_weight = 0
    for ma in moving_averages.values():
        # an average that was never updated has value None and weight 0:
        # it contributes nothing and must not be multiplied
        if ma.value is None or not ma.weight:
            continue
        weighted_sum = weighted_sum + ma.value * ma.weight
        total_weight = total_weight + ma.weight
    if not total_weight:
        return 0
    return weighted_sum / total_weight
16 |
class WindowedMovingAverage:
    """
    Keeps a moving average of a quantity over a fixed-size window.
    """
    def __init__(self, window_size=1000):
        """
        Parameters
        ----------
        window_size: int
            The number of samples to average over.

        Raises
        ------
        ValueError
            if window_size is not positive
        """
        if window_size <= 0:
            raise ValueError("window_size must be a positive integer.")
        self.window_size = window_size
        self.buffer = []
        # mean of the values currently in the buffer; None until a first value is received
        self.value = None

    def update(self, new):
        """
        Adds new values to the buffer and updates the average

        Parameters
        ----------
        new: torch.Tensor, np.ndarray, list, or float
            A collection of new pointwise estimates to add to the buffer.
        """
        # Ensure 'new' is a flat list of numbers
        if isinstance(new, torch.Tensor):
            new_values = new.detach().flatten().tolist()
        elif isinstance(new, np.ndarray):
            new_values = new.flatten().tolist()
        elif hasattr(new, '__iter__'):
            new_values = list(new)
        else:
            new_values = [new] # Treat a single number as a list

        if not new_values:
            # nothing to add: the state is left as is, which also avoids
            # dividing by the length of a still empty buffer
            return

        self.buffer.extend(new_values)

        # keep only the most recent window_size values
        self.buffer = self.buffer[-self.window_size:]

        self.value = sum(self.buffer) / len(self.buffer)

    def reset(self):
        """Empties the buffer and forgets the average."""
        self.buffer = []
        self.value = None
60 |
class MovingAverage:
    """
    Running mean of all the values reported so far.
    `value` stays None until a first non-empty update is received.
    """
    def __init__(self):
        """
        Starts without any value and with a null weight.
        """
        self.weight = 0
        self.value = None

    def update(self, new):
        """
        Folds new values into the running mean.

        Parameters
        ----------
        new: torch.Tensor, np.ndarray, list, or float
            A collection of new pointwise estimates to add.
        """
        # turn the input into a flat list of numbers
        if isinstance(new, torch.Tensor):
            batch = new.detach().flatten().tolist()
        elif isinstance(new, np.ndarray):
            batch = new.flatten().tolist()
        elif hasattr(new, '__iter__'):
            batch = list(new)
        else:
            batch = [new]

        batch_weight = len(batch)
        if batch_weight == 0:
            return
        batch_mean = sum(batch) / batch_weight

        if self.value is None:
            # first values reported: they define the average
            self.value, self.weight = batch_mean, batch_weight
            return

        # otherwise combine the two means, each weighted by its number of values
        combined_weight = self.weight + batch_weight
        self.value = (self.value * self.weight + batch_mean * batch_weight) / combined_weight
        self.weight = combined_weight

    def reset(self):
        """Forgets everything reported so far."""
        self.value, self.weight = None, 0
--------------------------------------------------------------------------------
/disco/tuners/loggers/mlflow.py:
--------------------------------------------------------------------------------
1 | # disco
2 | # Copyright (C) 2022-present NAVER Corp.
3 | # Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license
4 |
5 | import mlflow
6 | import tempfile
7 | import json
8 | from pathlib import Path
9 | from .base import BaseTunerObserver
10 | from collections import defaultdict
11 | import torch
12 | import numpy as np
13 |
class MLFlowLogger(BaseTunerObserver):
    """
    Reports DPGTuner statistics to MLFlow
    """

    def __init__(self, tuner, project, name=None):
        """Constructor of a MLFlowLogger object

        Parameters
        ----------
        tuner: DPGTuner
            The tuner object whose statistics we want to report
        project: string
            The experiment to report to, passed to mlflow.set_experiment
        name: string
            The name of the run, passed as run_name to mlflow.start_run
        """
        super(MLFlowLogger, self).__init__(tuner)
        mlflow.set_experiment(project)
        self.run = mlflow.start_run(run_name=name, log_system_metrics=True)
        self.step = None
        self.last_step_eval_samples_reported = None
        self._reset_stats()

    def _reset_stats(self):
        """Forgets the metric values accumulated for the current step."""
        self.stats = defaultdict(list)

    def _report_step_stats(self, step):
        """
        Logs the metrics accumulated since the last reset: a metric seen
        once is logged as is, a metric seen several times is logged as
        its min, max and mean.
        """
        metrics = {}
        for k, vals in self.stats.items():
            if len(vals) == 1:
                metrics[k] = vals[0]
            else:
                try:
                    metrics[f"{k}/min"] = np.min(vals)
                    metrics[f"{k}/max"] = np.max(vals)
                    metrics[f"{k}/mean"] = np.mean(vals)
                except TypeError as err:
                    raise RuntimeError(f"TypeError while processing {k} = {vals}") from err
        mlflow.log_metrics(metrics, step=step)

    def __setitem__(self, k, v):
        """
        Report arbitrary parameter/value combinations
        """
        mlflow.log_param(k, v)

    def __exit__(self, *exc):
        """Reports the statistics of the last step and closes the run."""
        self._report_step_stats(self.step)
        self.run.__exit__(*exc)

    def on_metric_updated(self, name, value):
        """Accumulates the value (tensors are converted with .item()) until the step is reported."""
        if isinstance(value, torch.Tensor):
            value = value.item()
        self.stats[name].append(value)

    def on_parameters_updated(self, params):
        """Logs the parameters to MLFlow."""
        mlflow.log_params(params)

    def on_eval_samples_updated(self, context, samples, proposal_log_scores, model_log_scores, target_log_scores):
        """
        Logs evaluation samples and their scores as an MLflow artifact (JSON format).
        Only the first set of samples received for a given step is logged.
        """
        if self.last_step_eval_samples_reported is not None and self.last_step_eval_samples_reported == self.step:
            # we already reported samples for this step
            return
        self.last_step_eval_samples_reported = self.step
        # Build dict
        data = {
            "context": context,
            "samples": [
                {
                    "text": getattr(s, "text", str(s)),
                    "proposal_log_score": float(proposal_log_scores[i].item()),
                    "model_log_score": float(model_log_scores[i].item()),
                    "target_log_score": float(target_log_scores[i].item())
                }
                for i, s in enumerate(samples)
            ]
        }

        # samples may arrive before any step index was received, and
        # None cannot be formatted with a width
        step_label = "initial" if self.step is None else f"{self.step:03}"

        # Write to a temporary JSON file and log
        with tempfile.TemporaryDirectory() as tmpdir:
            tmpfile = Path(tmpdir, f"{step_label}.json")
            # the file must be closed, hence flushed, before MLFlow reads it
            with open(tmpfile, 'w') as fout:
                json.dump(data, fout, indent=2)
            mlflow.log_artifact(tmpfile, artifact_path="samples")

    def on_step_idx_updated(self, s):
        """Reports the statistics of the step that just ended, then moves to step `s`."""
        if self.step is not None:
            self._report_step_stats(self.step)
        self._reset_stats()
        self.step = s
        mlflow.log_metric("steps", s, step=self.step)
--------------------------------------------------------------------------------
/disco/utils/helpers.py:
--------------------------------------------------------------------------------
1 | # disco
2 | # Copyright (C) 2022-present NAVER Corp.
3 | # Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license
4 |
5 | import torch
6 | import tqdm.autonotebook as tqdm
7 |
def batchify(func, batch, samples=(), **args):
    """
    Applies a function to the samples batch by batch and concatenates
    the results.

    Parameters
    ----------
    func: callable
        called as func(subsamples, **args) and expected to return a tensor
    batch: int
        maximum number of samples handed to func in one call
    samples: sequence
        the samples to process; must support len() and slicing
    **args:
        extra keyword arguments forwarded to func

    Returns
    -------
    the results of all the calls concatenated with torch.cat, or an
    empty tensor when there are no samples

    Raises
    ------
    ValueError
        if batch is not positive
    """
    if batch <= 0:
        raise ValueError("batch must be a positive integer.")
    n_samples = len(samples)
    # partial objects and other callables may not have a __name__
    desc = getattr(func, "__name__", type(func).__name__)
    results = []
    with tqdm.tqdm(total=n_samples, desc=desc, position=1, leave=False) as pbar:
        for start in range(0, n_samples, batch):
            subsamples = samples[start:start + batch]
            results.append(func(subsamples, **args))
            # advance by the real size of the chunk so that the bar ends on its total
            pbar.update(len(subsamples))
    if not results:
        # torch.cat refuses an empty list
        return torch.empty(0)
    return torch.cat(results)
17 |
def score_in_chunks_batched(model, samples_nested, contexts, chunk_size):
    """
    Scores samples from multiple contexts in chunks to bound the number of
    (sample, context) pairs scored in a single pass.

    When `chunk_size` is at least the number of contexts, every pass scores
    a slice of `chunk_size // num_contexts` samples for all contexts at once.
    Otherwise contexts are processed one after the other, each by slices of
    `chunk_size` samples.

    Parameters:
        model: object exposing `log_score_batch(samples=..., contexts=...)`.
        samples_nested (list of lists): samples, structured as [context][sample_index].
            The number of samples per context is read from the first
            sublist; the other sublists are expected to have the same length.
        contexts (list of str): the list of contexts.
        chunk_size (int): the maximum number of (sample, context) pairs to score in a single pass.

    Returns:
        torch.Tensor: the concatenated scores, one row per context and one
        column per sample; an empty tensor when there is no context or no sample.

    Raises:
        ValueError: if there is something to score and `chunk_size` is not positive.
    """
    num_contexts = len(contexts)
    # Handle empty inputs to prevent errors
    if num_contexts == 0:
        return torch.empty(0, 0)

    num_samples_per_context = len(samples_nested[0])
    if num_samples_per_context == 0:
        return torch.empty(num_contexts, 0)

    if chunk_size < 1:
        # otherwise range() either fails or yields nothing, ending in an obscure error
        raise ValueError(f"chunk_size must be a positive integer, got {chunk_size}.")

    desc = f"Scoring ({model.__class__.__name__})"

    # --- Case 1: `chunk_size` is large enough for all contexts at once ---
    if chunk_size >= num_contexts:
        # at least 1 here, since chunk_size >= num_contexts >= 1
        samples_per_context_chunk = chunk_size // num_contexts

        all_scores_chunks = []
        for start in tqdm.trange(0, num_samples_per_context, samples_per_context_chunk, desc=desc, leave=False):
            stop = start + samples_per_context_chunk
            chunk_samples_nested = [sublist[start:stop] for sublist in samples_nested]
            # one call scores this slice for every context
            all_scores_chunks.append(
                model.log_score_batch(samples=chunk_samples_nested, contexts=contexts)
            )

        return torch.cat(all_scores_chunks, dim=1)

    # --- Case 2: `chunk_size` is smaller than num_contexts ---
    all_context_scores = []
    for i in tqdm.trange(num_contexts, desc=desc, leave=False):
        current_context = [contexts[i]]
        current_samples_list = samples_nested[i]

        # a single context is scored at a time, so the whole `chunk_size` goes to its samples
        scores_for_one_context = []
        for j in range(0, num_samples_per_context, chunk_size):
            sample_chunk = current_samples_list[j:j + chunk_size]
            if not sample_chunk:
                # this context holds fewer samples than the first one
                continue
            scores_for_one_context.append(
                model.log_score_batch(samples=[sample_chunk], contexts=current_context)
            )

        if scores_for_one_context:
            all_context_scores.append(torch.cat(scores_for_one_context, dim=1))

    # Finally, concatenate the results from all contexts
    return torch.cat(all_context_scores, dim=0)
97 |
def get_token_first_indices(x, token):
    """Find the first occurrence of a token in a 2D token tensor.

    Parameters
    ----------
    x: 2D int tensor
        list of token sequences, one per row
    token: int
        token to search

    Returns
    ------
    1D tensor containing, for each row, the position of the first
    occurrence of the token or -1 if not found
    """
    if x.shape[-1] == 0:
        # nothing to search in: answer "not found" for every row,
        # on the same device as the input
        return torch.full((x.shape[0],), -1, dtype=torch.long, device=x.device)

    mask = x == token
    # the maximum of a boolean row is True iff the token occurs in it,
    # and the index reported with it is the one of the first True
    found, first_indices = torch.max(mask, dim=1)
    first_indices[~found] = -1
    return first_indices
119 |
--------------------------------------------------------------------------------
/tutorials/5.QRS_sampling.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "attachments": {},
5 | "cell_type": "markdown",
6 | "metadata": {},
7 | "source": [
8 | "disco \n",
9 | "Copyright (C) 2022-present NAVER Corp. \n",
10 | "Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license "
11 | ]
12 | },
13 | {
14 | "attachments": {},
15 | "cell_type": "markdown",
16 | "metadata": {},
17 | "source": [
18 | "# Quasi-Rejection Sampling"
19 | ]
20 | },
21 | {
22 | "attachments": {},
23 | "cell_type": "markdown",
24 | "metadata": {},
25 | "source": [
26 | "Another way to approximate our EBM is to use Quasi-Rejection Sampling."
27 | ]
28 | },
29 | {
30 | "attachments": {},
31 | "cell_type": "markdown",
32 | "metadata": {},
33 | "source": [
34 | "## Expressing Preferences"
35 | ]
36 | },
37 | {
38 | "attachments": {},
39 | "cell_type": "markdown",
40 | "metadata": {},
41 | "source": [
42 | "This part is similar to the previous notebooks on expressing preferences and tuning with DPG. Here we're going to define a distributional constraint: we want half of our examples to include the word amazing."
43 | ]
44 | },
45 | {
46 | "cell_type": "code",
47 | "execution_count": null,
48 | "metadata": {},
49 | "outputs": [],
50 | "source": [
51 | "import re\n",
52 | "from disco.scorers import BooleanScorer\n",
53 | "from disco.distributions import LMDistribution\n",
54 | "from disco.distributions.single_context_distribution import SingleContextDistribution\n",
55 | "\n",
56 | "is_amazing = lambda s, c: bool(re.search(r\"\\bamazing\\b\", s.text))\n",
57 | "amazing_scorer = BooleanScorer(is_amazing)\n",
58 | "base = LMDistribution()\n",
59 | "\n",
60 | "incipit = \"It was a cold and stormy night\"\n",
61 | "\n",
62 | "target = base.constrain([amazing_scorer], [1/2],\n",
63 | " n_samples=2**10,\n",
64 | " context_distribution=SingleContextDistribution(incipit))"
65 | ]
66 | },
67 | {
68 | "attachments": {},
69 | "cell_type": "markdown",
70 | "metadata": {},
71 | "source": [
72 | "## Sampling"
73 | ]
74 | },
75 | {
76 | "attachments": {},
77 | "cell_type": "markdown",
78 | "metadata": {},
79 | "source": [
80 | "We then instantiate a proposal model to sample from."
81 | ]
82 | },
83 | {
84 | "cell_type": "code",
85 | "execution_count": null,
86 | "metadata": {},
87 | "outputs": [],
88 | "source": [
89 | "proposal = LMDistribution()"
90 | ]
91 | },
92 | {
93 | "attachments": {},
94 | "cell_type": "markdown",
95 | "metadata": {},
96 | "source": [
97 | "Note that we can peek inside the EBM to look at the computed coefficient."
98 | ]
99 | },
100 | {
101 | "cell_type": "code",
102 | "execution_count": null,
103 | "metadata": {},
104 | "outputs": [],
105 | "source": [
106 | "print(target.scorers[1].coefficients)"
107 | ]
108 | },
109 | {
110 | "attachments": {},
111 | "cell_type": "markdown",
112 | "metadata": {},
113 | "source": [
114 | "It's now just a matter of instantiating a sampler."
115 | ]
116 | },
117 | {
118 | "cell_type": "code",
119 | "execution_count": null,
120 | "metadata": {},
121 | "outputs": [],
122 | "source": [
123 | "from disco.samplers import QuasiRejectionSampler"
124 | ]
125 | },
126 | {
127 | "cell_type": "code",
128 | "execution_count": null,
129 | "metadata": {},
130 | "outputs": [],
131 | "source": [
132 | "sampler = QuasiRejectionSampler(target, proposal, beta=400)"
133 | ]
134 | },
135 | {
136 | "attachments": {},
137 | "cell_type": "markdown",
138 | "metadata": {},
139 | "source": [
140 | "How are we doing?"
141 | ]
142 | },
143 | {
144 | "cell_type": "code",
145 | "execution_count": null,
146 | "metadata": {},
147 | "outputs": [],
148 | "source": [
149 | "from disco.samplers import AccumulationSampler"
150 | ]
151 | },
152 | {
153 | "cell_type": "code",
154 | "execution_count": null,
155 | "metadata": {},
156 | "outputs": [],
157 | "source": [
158 | "samples, log_scores = sampler.sample(sampling_size=2**10, context=incipit)"
159 | ]
160 | },
161 | {
162 | "cell_type": "code",
163 | "execution_count": null,
164 | "metadata": {},
165 | "outputs": [],
166 | "source": [
167 | "sum([is_amazing(s, _) for s in samples]) / len(samples)"
168 | ]
169 | },
170 | {
171 | "cell_type": "code",
172 | "execution_count": null,
173 | "metadata": {},
174 | "outputs": [],
175 | "source": []
176 | }
177 | ],
178 | "metadata": {
179 | "kernelspec": {
180 | "display_name": "Python 3.10.6 (conda)",
181 | "language": "python",
182 | "name": "python3"
183 | },
184 | "language_info": {
185 | "codemirror_mode": {
186 | "name": "ipython",
187 | "version": 3
188 | },
189 | "file_extension": ".py",
190 | "mimetype": "text/x-python",
191 | "name": "python",
192 | "nbconvert_exporter": "python",
193 | "pygments_lexer": "ipython3",
194 | "version": "3.10.6 | packaged by conda-forge | (main, Aug 22 2022, 20:38:29) [Clang 13.0.1 ]"
195 | },
196 | "vscode": {
197 | "interpreter": {
198 | "hash": "babb4baf4e80bd80b9852210fc5469c0783907e52a560ed7247caef52808358d"
199 | }
200 | }
201 | },
202 | "nbformat": 4,
203 | "nbformat_minor": 2
204 | }
205 |
--------------------------------------------------------------------------------
/disco/distributions/base_distribution.py:
--------------------------------------------------------------------------------
1 | # disco
2 | # Copyright (C) 2022-present NAVER Corp.
3 | # Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license
4 |
5 | import torch
6 | from tqdm.autonotebook import trange
7 | import copy
8 |
9 | from disco.scorers.positive_scorer import Product
10 | from disco.scorers.exponential_scorer import ExponentialScorer
11 | from disco.scorers.boolean_scorer import BooleanScorer
12 | from .distribution import Distribution
13 | from .single_context_distribution import SingleContextDistribution
14 | from disco.samplers.accumulation_sampler import AccumulationSampler
15 | from disco.utils.device import get_device
16 | from disco.utils.helpers import batchify
17 | from disco.utils.moving_average import MovingAverage
18 |
19 |
class BaseDistribution(Distribution):
    """
    Base distribution class, which can be used
    to build an EBM.
    """

    def clone(self):
        """returns a deep copy of the distribution"""
        return copy.deepcopy(self)

    def constrain(self,
            features, moments=None,
            proposal=None, context_distribution=SingleContextDistribution(''), n_contexts_per_step=1,
            n_samples=2**9, iterations=1000, learning_rate=0.05, tolerance=1e-5, sampling_size=2**5
        ):
        """
        Constrains features to the base according to their moments,
        so producing an EBM

        Without moments, or when every feature is a BooleanScorer
        with a moment of 1, the features are applied pointwise.
        Otherwise coefficients are fitted so that the
        importance-weighted means of the features over samples from
        the proposal match the requested moments.

        Parameters
        ----------
        features: list(feature)
            multiple features to constrain
        moments: list(float)
            moments for the features. There should be as many moments as there are features
        proposal: distribution
            distribution to sample from, if different from self
        context_distribution: distribution
            to contextualize the sampling and scoring
        n_contexts_per_step:
            number of contexts sampled from the context distribution
        n_samples: int
            number of samples, per context, to use to fit the coefficients
        iterations: int
            maximum number of updates of the coefficients
        learning_rate: float
            multipliers of the delta used when fitting the coefficients
        tolerance: float
            accepted difference between the targets and moments
        sampling_size:
            size of the batch when sampling and scoring samples

        Returns
        -------
        Product(self, *features) in the pointwise case, else
        self * ExponentialScorer(features, coefficients)
        with the fitted coefficients
        """

        if not isinstance(features, list):
            raise TypeError("features should be passed as a list.")

        if not moments:
            return Product(self, *features)

        if not isinstance(moments, list):
            raise TypeError("moments should be passed as a list.")
        if len(features) != len(moments):
            # wording kept verbatim: the unit tests compare against this exact message
            raise TypeError("there should be as many as many moments as there are features.")

        # exact type match on purpose: only plain boolean features
        # with a moment of 1 reduce to a pointwise product
        if all(type(f) is BooleanScorer for f in features)\
                and all(1.0 == float(m) for m in moments):
            return Product(self, *features)

        if not proposal:
            proposal = self

        context_samples, context_log_scores = context_distribution.sample(n_contexts_per_step)

        # everything below is computed once per context and reused at every iteration
        proposal_samples = dict()
        proposal_log_scores = dict()
        joint_log_scores = dict()
        feature_scores = dict()
        for (context, log_score) in zip(context_samples, context_log_scores):
            accumulator = AccumulationSampler(proposal, total_size=n_samples)
            proposal_samples[context], proposal_log_scores[context] = accumulator.sample(
                    sampling_size=sampling_size, context=context
                )
            device = get_device(proposal_log_scores[context])
            reference_log_scores = batchify(
                    self.log_score, sampling_size, samples=proposal_samples[context], context=context
                ).to(device)
            joint_log_scores[context] = torch.tensor(log_score).repeat(n_samples).to(device) + reference_log_scores
            feature_scores[context] = torch.stack(
                    ([f.score(proposal_samples[context], context).to(device) for f in features])
                )

        coefficients = torch.tensor(0.0).repeat(len(features)).to(device)
        targets = torch.tensor(moments).to(device)
        with trange(iterations, desc='fitting exponential scorer') as t:
            for i in t:
                scorer = ExponentialScorer(features, coefficients)
                numerator = torch.tensor(0.0).repeat(len(features)).to(device)
                denominator = torch.tensor(0.0).repeat(len(features)).to(device)
                for context in context_samples:
                    target_log_scores = joint_log_scores[context] + scorer.log_score(
                            proposal_samples[context], context
                        ).to(device)
                    importance_ratios = torch.exp(target_log_scores - proposal_log_scores[context])
                    numerator += (importance_ratios * feature_scores[context]).sum(dim=1)
                    denominator += importance_ratios.sum()
                # self-normalized importance sampling estimate of the feature moments
                estimated_moments = numerator / denominator
                grad_coefficients = estimated_moments - targets
                err = grad_coefficients.abs().max().item()
                t.set_postfix(err=err)
                if tolerance > err:
                    # shrink the bar to the iterations actually run
                    t.total = i
                    t.refresh()
                    break
                coefficients -= learning_rate * grad_coefficients

        return self * ExponentialScorer(features, coefficients)
127 |
--------------------------------------------------------------------------------
/tests/tuner_test.py:
--------------------------------------------------------------------------------
1 | # disco
2 | # Copyright (C) 2022-present NAVER Corp.
3 | # Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license
4 |
5 | import unittest
6 |
7 | from disco.distributions import LMDistribution
8 | from disco.distributions.single_context_distribution import SingleContextDistribution
9 | from disco.scorers import ExponentialScorer, Scorer, BooleanScorer
10 | from disco.tuners import Tuner, DPGTuner, FDPGTuner, CDPGTuner, FCDPGTuner
11 | from disco.tuners.losses import *
12 | from disco.tuners.loggers.base import BaseTunerObserver
13 |
class TunerTest(unittest.TestCase):
    """Smoke tests running a single gradient step of each tuner and
    checking the metrics received by a MockObserver."""

    def test_rkl(self):
        self._test_loss_rlhf(ReverseKLLoss)

    def test_tv(self):
        self._test_loss_rlhf(TVLoss)

    def test_js(self):
        self._test_loss_rlhf(JSLoss)

    def test_kl(self):
        self._test_loss_rlhf(KLLoss)

    def _test_loss_rlhf(self, loss_cls):
        # shared scenario: tune towards base * ExponentialScorer with the given loss
        base = LMDistribution()
        model = LMDistribution(freeze=False)
        reward = Scorer(lambda s, c: 1 if "amazing" in s.text else 0)
        rlhf_target = base * ExponentialScorer([reward], [0.1])
        loss = loss_cls()
        tuner = Tuner(model, rlhf_target, loss=loss, n_gradient_steps=1,
                n_contexts_per_step=1,
                n_samples_per_context=32,
                scoring_size=32)
        with MockObserver(tuner) as obs:
            tuner.tune()
        self.assertIn('loss', obs.observations)

    def test_dpg(self):
        base = LMDistribution()
        scorer = BooleanScorer(lambda s, c: "amazing" in s.text)
        target = base.constrain([scorer], [0.5])
        model = LMDistribution(freeze=False)
        tuner = DPGTuner(model, target, n_gradient_steps=1,
                n_samples_per_context=32,
                scoring_size=32)
        with MockObserver(tuner) as obs:
            tuner.tune()
        self.assertIn('loss', obs.observations)

    def test_cdpg(self):
        base = LMDistribution()
        scorer = BooleanScorer(lambda s, c: "amazing" in s.text)
        target = base.constrain([scorer], [0.5])
        model = LMDistribution(freeze=False)
        context_distribution = SingleContextDistribution("Speaking about today's dinner, it was")
        tuner = CDPGTuner(model, target,
                context_distribution=context_distribution,
                n_gradient_steps=1,
                n_contexts_per_step=1,
                n_samples_per_context=1,
                scoring_size=1)
        with MockObserver(tuner) as obs:
            tuner.tune()
        self.assertIn('loss', obs.observations)

    def test_fdpg(self):
        base = LMDistribution()
        scorer = BooleanScorer(lambda s, c: "amazing" in s.text)
        target = base.constrain([scorer], [0.5])
        model = LMDistribution(freeze=False)
        tuner = FDPGTuner(model, target, n_gradient_steps=1,
                n_samples_per_context=1,
                scoring_size=1)
        with MockObserver(tuner) as obs:
            tuner.tune()
        self.assertIn('loss', obs.observations)

    def test_fcdpg(self):
        base = LMDistribution()
        scorer = BooleanScorer(lambda s, c: "amazing" in s.text)
        target = base.constrain([scorer], [0.5])
        model = LMDistribution(freeze=False)
        context_distribution = SingleContextDistribution("Speaking about today's dinner, it was")
        tuner = FCDPGTuner(model, target,
                context_distribution=context_distribution,
                n_gradient_steps=1,
                n_contexts_per_step=1,
                n_samples_per_context=1,
                scoring_size=1)
        with MockObserver(tuner) as obs:
            tuner.tune()
        self.assertIn('loss', obs.observations)

    def test_moments(self):
        base = LMDistribution()
        scorer = BooleanScorer(lambda s, c: "amazing" in s.text)
        target = base.constrain([scorer], [1])
        model = LMDistribution(freeze=False)
        context_distribution = SingleContextDistribution("The movie was absolutely")

        # with features declared, their moments are expected among the metrics
        tuner = FCDPGTuner(model, target,
                context_distribution=context_distribution,
                n_gradient_steps=1,
                n_contexts_per_step=1,
                n_samples_per_context=128,
                divergence_evaluation_interval=1,
                scoring_size=32,
                features=[('amazing', scorer)])
        with MockObserver(tuner) as obs:
            tuner.tune()
        self.assertIn('amazing_proposal', obs.observations)
        self.assertIn('amazing_target', obs.observations)
        self.assertAlmostEqual(obs.observations['amazing_proposal'].item(), 0.1, 1)
        self.assertAlmostEqual(obs.observations['amazing_target'].item(), 0.1, 1)
        self.assertAlmostEqual(obs.observations['amazing_proposal'].item(),
            obs.observations['amazing_target'].item(), 4)

        # without features, no moment should be reported
        tuner = FCDPGTuner(model, target,
                context_distribution=context_distribution,
                n_gradient_steps=1,
                n_contexts_per_step=1,
                n_samples_per_context=1,
                scoring_size=1)
        with MockObserver(tuner) as obs:
            tuner.tune()
        self.assertNotIn('amazing_proposal', obs.observations)
        self.assertNotIn('amazing_target', obs.observations)
130 |
class MockObserver(BaseTunerObserver):
    """Observer keeping, per key, the last value handed to on_metric_updated."""

    def __init__(self, tuner):
        super().__init__(tuner)
        # key -> most recent value
        self.observations = dict()

    def on_metric_updated(self, k, v):
        self.observations.update({k: v})
137 |
--------------------------------------------------------------------------------
/disco/scorers/positive_scorer.py:
--------------------------------------------------------------------------------
1 | # disco
2 | # Copyright (C) 2022-present NAVER Corp.
3 | # Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license
4 |
5 | import torch
6 | import numpy as np
7 | from functools import reduce
8 |
9 | from .scorer import Scorer
10 |
11 |
class PositiveScorer(Scorer):
    """
    Scorer, but limited to positive values
    """
    def __init__(self, scoring_function):
        super().__init__(scoring_function)

    def __mul__(self, ot):
        """enables the use of the multiplication sign (*)
        to compose positive scorers"""

        return Product(self, ot)

    def log_score(self, samples, context):
        """returns the log-scores for the samples
        given the context by taking the log of their scores

        Parameters
        ----------
        samples : list(Sample)
            list of samples to log-score
        context: text
            context that the samples relate to

        Returns
        -------
        tensor of log-scores for the samples"""

        return torch.log(self.score(samples, context=context))

    def log_score_batch(self, samples, contexts):
        """returns the log-scores for the samples of multiple contexts
        by taking the log of their batched scores

        Parameters
        ----------
        samples : list(list(Sample))
            samples to log-score, samples[i] holding those of contexts[i]
        contexts : list(text)
            contexts that the samples relate to

        Returns
        -------
        tensor of log-scores, with one row per scored context"""

        scores = self.score_batch(samples, contexts=contexts)
        return torch.log(scores)

    def score(self, samples, context):
        """relies on the instance's scoring function
        to compute the scores of the samples given the context

        Parameters
        ----------
        samples : list(Sample)
            list of samples to score
        context: text
            context that the samples relate to

        Returns
        -------
        tensor of scores for the samples"""

        return self.scoring_function(samples, context)

    def score_batch(self, samples, contexts):
        """computes the scores for the samples of multiple contexts,
        by calling the instance's scoring function
        once per context on that context's samples

        Parameters
        ----------
        samples : list(list(Sample))
            samples to score, samples[i] holding those of contexts[i]
        contexts : list(text)
            contexts that the samples relate to

        Returns
        -------
        tensor stacking one row of scores per context;
        contexts without samples are skipped, and a tensor
        of shape (len(contexts), 0) is returned when nothing was scored"""

        all_scores = []
        for i, context in enumerate(contexts):
            context_samples = samples[i]

            if not context_samples:
                continue

            # Call the existing single-context scoring function
            all_scores.append(self.scoring_function(context_samples, context))

        if not all_scores:
            return torch.empty(len(contexts), 0)

        # Stack the per-context results along a new leading dimension
        return torch.stack(all_scores, dim=0)
118 |
119 |
class Product(PositiveScorer):
    """
    Utility class to compose scorers on the product of their scores
    """

    def __init__(self, *scorers):
        self.scorers = scorers

    def log_score(self, samples, context):
        """adds up the log-scores of the individual scorers,
        which amounts to the log of the product of their scores

        Parameters
        ----------
        samples : list(Sample)
            list of samples to log-score
        context: text
            context used for the samples

        Returns:
        --------
        tensor of log-scores for the samples
        """
        # results are gathered on the device of the first scorer, if it exposes one
        device = getattr(self.scorers[0], "device", "cpu")

        parts = [
            scorer.log_score(samples, context=context).to(device)
            for scorer in self.scorers
        ]
        total = parts[0]
        for part in parts[1:]:
            total = total + part
        return total

    def log_score_batch(self, samples, contexts):
        """adds up the batched log-scores of the individual scorers,
        which amounts to the log of the product of their scores

        Parameters
        ----------
        samples : list
            samples to log-score, forwarded as is to
            the log_score_batch method of every scorer
        contexts : list(text)
            contexts used for the samples

        Returns
        -------
        tensor of the element-wise summed log-scores
        """
        # results are gathered on the device of the first scorer, if it exposes one
        device = getattr(self.scorers[0], "device", "cpu")

        parts = [
            scorer.log_score_batch(samples, contexts=contexts).to(device)
            for scorer in self.scorers
        ]
        total = parts[0]
        for part in parts[1:]:
            total = total + part
        return total

    def __str__(self):
        return "Product(" + ", ".join(map(str, self.scorers)) + ")"
--------------------------------------------------------------------------------
/disco/samplers/accumulation_sampler.py:
--------------------------------------------------------------------------------
1 | # disco
2 | # Copyright (C) 2022-present NAVER Corp.
3 | # Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license
4 |
5 | from . import Sampler
6 | import torch
7 | from tqdm.autonotebook import trange
8 |
9 |
class AccumulationSampler(Sampler):
    """
    Utility class to accumulate samples, up to a total size
    """

    def __init__(self, distribution, total_size=512):
        """
        Parameters
        ----------
        distribution: distribution
            distribution to sample from
        total_size: int
            total number of samples
        """

        self.distribution = distribution
        self.total_size = total_size

    def sample(self, sampling_size=32, context=""):
        """accumulates batches of samples from the distribution

        Parameters
        ----------
        sampling_size: int
            number of requested samples per individual sampling
        context: text
            contextual text for which to sample

        Returns
        -------
        a tuple of accumulated samples and scores
        """
        with trange(
                self.total_size,
                desc=f"sampling from {type(self.distribution).__name__}",
                position=1,
                leave=False
                ) as t:
            remaining = self.total_size
            samples, log_scores = list(), torch.empty([0])
            while remaining > 0:
                more_samples, more_log_scores = self.distribution.sample(context=context, sampling_size=sampling_size)
                # keep no more than what is still missing; a batch may also come
                # back smaller than requested, in which case we simply go on sampling
                length = min(remaining, len(more_samples))
                more_samples, more_log_scores = more_samples[:length], more_log_scores[:length]
                if samples:
                    samples = samples + more_samples
                    log_scores = torch.cat((log_scores, more_log_scores))
                else:
                    # first non-empty batch: adopt its scores as they are,
                    # rather than concatenating to the initial placeholder
                    samples, log_scores = more_samples, more_log_scores
                remaining -= length
                t.update(length)

        return (samples, log_scores)

    def sample_batch(self, contexts, sampling_size=32, output_scores=True):
        """
        Accumulates batches of samples for a list of contexts simultaneously.

        This method repeatedly calls the distribution's `sample_batch` method
        until `total_size` samples have been requested for each context.

        Parameters
        ----------
        contexts: list of str
            A non-empty list of contextual texts for which to sample.
        sampling_size: int
            The number of samples to request for each context per batch call.
        output_scores: bool
            Whether log scores are requested from the distribution
            and returned along with the samples.

        Returns
        -------
        tuple of (list of lists, torch.Tensor) if output_scores, else the list of lists alone
            - A list of lists, where the outer list corresponds to the input
              contexts and each inner list accumulates the samples of a context.
            - A tensor stacking, one row per context, the concatenated log scores.

        Raises
        ------
        ValueError
            if contexts is not a non-empty list
        """
        if not isinstance(contexts, list) or not contexts:
            raise ValueError("contexts must be a non-empty list of strings.")

        num_contexts = len(contexts)

        # one accumulator per context, for the samples and for the score tensors
        accumulated_samples = [[] for _ in range(num_contexts)]
        accumulated_scores_parts = [[] for _ in range(num_contexts)]

        with trange(
            self.total_size * num_contexts,
            desc=f"Batch sampling for {num_contexts} contexts from {type(self.distribution).__name__}",
            position=1,
            leave=False
        ) as t:
            collected_count = 0
            while collected_count < self.total_size:
                # never request more than what is still missing
                remaining = self.total_size - collected_count
                current_batch_size = min(remaining, sampling_size)

                ret = self.distribution.sample_batch(
                    contexts=contexts,
                    sampling_size=current_batch_size,
                    output_scores=output_scores
                )
                if output_scores:
                    more_samples_nested, more_log_scores = ret
                else:
                    more_samples_nested, more_log_scores = ret, None

                # both results are indexed by context: dispatch them to their accumulators
                for i in range(num_contexts):
                    accumulated_samples[i].extend(more_samples_nested[i])
                    if output_scores:
                        accumulated_scores_parts[i].append(more_log_scores[i])

                collected_count += current_batch_size
                t.update(current_batch_size * num_contexts)

        if not output_scores:
            return accumulated_samples

        # concatenate the parts of each context, then stack the contexts as rows
        final_scores = torch.stack(
            [torch.cat(parts) for parts in accumulated_scores_parts],
            dim=0
        )
        return (accumulated_samples, final_scores)
148 |
--------------------------------------------------------------------------------
/tests/base_distribution_test.py:
--------------------------------------------------------------------------------
1 | # disco
2 | # Copyright (C) 2022-present NAVER Corp.
3 | # Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license
4 |
5 | import unittest
6 | import torch
7 | import numpy as np
8 | import random
9 |
10 | from disco.distributions import BaseDistribution
11 | from disco.distributions.single_context_distribution import SingleContextDistribution
12 | from disco.scorers import BooleanScorer
13 | from disco.scorers.positive_scorer import PositiveScorer
14 | from disco.scorers.positive_scorer import Product
15 | from disco.scorers.exponential_scorer import ExponentialScorer
16 |
17 |
class DummyDistribution(BaseDistribution):
    """Uniform distribution over the integers of [low, high),
    where high is driven by the context (64 when the context is None)."""

    def __init__(self, low=0):
        self.low = low

    def _max(self, high):
        # the range always spans at least 8 integers
        return max(self.low + 8, high)

    def log_score(self, numbers, context=None):
        """returns the log-probabilities of the numbers given the context,
        that is -inf for those outside of the range"""
        context = 64 if context is None else context
        high = self._max(context)

        # as_tensor: accepts lists as well as tensors, without re-wrapping the latter
        numbers = torch.as_tensor(numbers)
        probability = 1 / (high - self.low)
        probabilities = torch.where(
                torch.logical_and(self.low <= numbers, numbers < high), probability, 0.0
            )
        return torch.log(probabilities)

    def sample(self, sampling_size=1, context=None):
        """draws sampling_size integers from the range given by the context,
        and returns them as a list along with their log-scores"""
        context = 64 if context is None else context
        high = self._max(context)

        numbers = torch.randint(self.low, high, (sampling_size,))

        return (
                numbers.tolist(),
                # score with the context the numbers were drawn for
                self.log_score(numbers, context),
            )
47 |
# boolean features on integers; the context argument is ignored
even = BooleanScorer(lambda x, _: x % 2 == 0)
divisible_by_3 = BooleanScorer(lambda x, _: x % 3 == 0)

# the integers from 0 to 19, odd ones first
odds = list(range(1, 20, 2))
evens = list(range(0, 20, 2))
first_20_integers = odds + evens
53 |
54 | class BaseDistributionTest(unittest.TestCase):
55 |
def test_constrain_features_should_passed_as_a_list(self):
    # a bare feature, not wrapped in a list, must be rejected
    distribution = DummyDistribution()
    with self.assertRaises(TypeError) as raised:
        distribution.constrain(even)
    message = str(raised.exception)
    self.assertEqual(message, "features should be passed as a list.")
62 |
def test_constrain_moments_should_passed_as_a_list_when_any(self):
    # a bare moment, not wrapped in a list, must be rejected
    distribution = DummyDistribution()
    with self.assertRaises(TypeError) as raised:
        distribution.constrain([even], 0.5)
    message = str(raised.exception)
    self.assertEqual(message, "moments should be passed as a list.")
69 |
def test_constrain_features_should_match_moments(self):
    # two features for a single moment must be rejected
    distribution = DummyDistribution()
    with self.assertRaises(TypeError) as raised:
        distribution.constrain([even, divisible_by_3], [0.5])
    message = str(raised.exception)
    self.assertEqual(message, "there should be as many as many moments as there are features.")
76 |
77 | def test_constrain_pointwisely_with_a_feature(self):
78 | reference = DummyDistribution()
79 | target = reference.constrain([even])
80 | self.assertEqual(Product, type(target),
81 | "constrain(...) should return a Product.")
82 | self.assertTrue(all([-np.inf == s for s in target.log_score(odds, None)]),
83 | "a base distribution should be constrainable with a single pointwise feature to build an EBM.")
84 | self.assertTrue(all([-np.inf < s < 0 for s in target.log_score(evens, None)]),
85 | "a base distribution should be constrainable with a single pointwise feature to build an EBM.")
86 |
87 | def test_constrain_pointwisely_with_multiple_features(self):
88 | reference = DummyDistribution()
89 | target = reference.constrain([even, divisible_by_3])
90 | self.assertEqual(4, len([s for s in target.log_score(first_20_integers, None) if -np.inf < s < 0]),
91 | "a base distribution should be constrainable with multiple pointwise features to build an EBM.")
92 |
93 | def test_constrain_is_like_sugar_when_no_moment(self):
94 | reference = DummyDistribution()
95 | target_constrain = reference.constrain([even, divisible_by_3])
96 | target_sugar = reference * even * divisible_by_3
97 | some_integers, _ = reference.sample(20, 16)
98 | self.assertTrue(
99 | all([c == s for c, s in zip(target_constrain.log_score(some_integers, None), target_sugar.log_score(some_integers, None))]),
100 | "an EBM built from constrain(…) should be equivalent to the one built with the product sign sugar.")
101 |
102 | def test_constrain_distributionally_with_a_feature(self):
103 | reference = DummyDistribution()
104 | target = reference.constrain([even], [0.8], context_distribution=SingleContextDistribution(None))
105 | self.assertEqual(Product, type(target),
106 | "constrain(...) should return a Product.")
107 |
108 | def test_constrain_distributionally_with_multiple_features(self):
109 | reference = DummyDistribution()
110 | target = reference.constrain([even, divisible_by_3], [0.8, 0.5], context_distribution=SingleContextDistribution(None))
111 | self.assertEqual(Product, type(target),
112 | "constrain(...) should return a Product.")
113 |
114 | def test_constrain_without_moments(self):
115 | bs = BooleanScorer(lambda x, c: True)
116 | distribution = BaseDistribution(lambda x, c: random.random())
117 | self.assertEqual(BooleanScorer, type(distribution.constrain([bs]).scorers[-1]),
118 | "when no moment is specified a feature should be composed directly in a product.")
119 | for s in distribution.constrain([bs, bs, bs]).scorers[1:]:
120 | self.assertEqual(BooleanScorer, type(s),
121 | "when no moments are specified multiple features should appear as is in a product.")
122 |
123 | def test_constrain_with_boolean_scorers_and_moments_set_to_one(self):
124 | bs = BooleanScorer(lambda x, c: True)
125 | distribution = BaseDistribution(lambda x, c: 1)
126 | self.assertEqual(BooleanScorer, type(distribution.constrain([bs], [1]).scorers[-1]),
127 | "when only a boolean scorer is specified, and its moment is one, this feature should appear as is in a product.")
128 | bs = BooleanScorer(lambda x, c: True)
129 | for s in distribution.constrain([bs, bs, bs], [1, 1.0, "1"]).scorers[1:]:
130 | self.assertEqual(BooleanScorer, type(s),
131 | "when only boolean scorer are specified, and their moments are one, these feature should appear as is in a product.")
132 |
133 | def test_constrain_with_boolean_scorers_but_with_moments_not_set_to_one_returns_an_exponential_scorer(self):
134 | bs = BooleanScorer(lambda x, c: True)
135 | distribution = DummyDistribution()
136 | self.assertEqual(ExponentialScorer, type(distribution.constrain([bs, bs], [1, 0.5], context_distribution=SingleContextDistribution(None)).scorers[-1]),
137 | "when only boolean scorers are specified, but their moments are not one, features should be composed in an exponential scorer.")
138 |
139 | def test_constrain_returns_an_exponential_scorer(self):
140 | bs = BooleanScorer(lambda x, c: True)
141 | ps = PositiveScorer(lambda x, c: random.random())
142 | distribution = DummyDistribution()
143 | self.assertEqual(ExponentialScorer, type(distribution.constrain([bs, ps], [0.2, 0.5], context_distribution=SingleContextDistribution(None)).scorers[-1]),
144 | "features should be composed by default in an exponential scorer.")
145 |
# Run the test cases of this module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()
--------------------------------------------------------------------------------
/disco/samplers/quasi_rejection_sampler.py:
--------------------------------------------------------------------------------
1 | # disco
2 | # Copyright (C) 2022-present NAVER Corp.
3 | # Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license
4 |
5 | import torch
6 |
7 | from . import Sampler
8 | from .accumulation_sampler import AccumulationSampler
9 | from disco.metrics import KL
10 | from disco.utils.device import get_device
11 | from disco.utils.helpers import batchify
12 |
13 |
class QuasiRejectionSampler(Sampler):
    """
    Quasi Rejection-Sampling class

    Samples are drawn from the proposal and each one is accepted with
    probability min(1, P(x) / (beta * q(x))), where P is the target score
    and q the proposal probability. The numbers of drawn and accepted
    samples are accumulated across calls to sample().
    """

    def __init__(self, target, proposal, beta=1):
        """
        Parameters
        ----------
        target: distribution
            Energy-based model to (log-)score the samples
        proposal: distribution
            distribution to generate the samples
        beta: float
            coefficient to control the sampling
        """

        super(QuasiRejectionSampler, self).__init__(target, proposal)
        self.beta = beta
        self.n_samples = 0
        self.n_accepted_samples = 0

    def sample(self, sampling_size=32, context=''):
        """Generates samples according to the QRS algorithm

        Parameters
        ----------
        sampling_size: int
            number of requested samples when sampling
        context: text
            contextual text for which to sample

        Returns
        -------
        tuple of accepted samples and their (proposal) log-scores;
        fewer than sampling_size samples are returned when some are rejected
        """

        samples, proposal_log_scores = self.proposal.sample(sampling_size=sampling_size, context=context)
        self.n_samples += len(samples)

        device = get_device(proposal_log_scores)

        target_log_scores = self.target.log_score(samples=samples, context=context).to(device)

        # acceptance probabilities, clipped to [0, 1]
        rs = torch.clamp(
            torch.exp(target_log_scores - proposal_log_scores) / self.beta,
            min=0.0, max=1.0
        )

        us = torch.rand(len(rs)).to(device)
        # compute the acceptance mask once and reuse it for samples and scores
        accepted = us < rs
        accepted_samples = [x for keep, x in zip(accepted.tolist(), samples) if keep]
        self.n_accepted_samples += len(accepted_samples)
        # mask indexing stays on the device; detached so that, as before,
        # the returned scores carry no autograd history
        accepted_log_scores = proposal_log_scores[accepted].detach()

        return accepted_samples, accepted_log_scores

    def get_acceptance_rate(self):
        """Computes the currently observed acceptance rate, that is the number
        of accepted samples over the total sampled ones

        Returns
        -------
        acceptance rate as float between 0 and 1, 0.0 when nothing
        has been sampled yet"""

        # avoid a ZeroDivisionError before the first call to sample()
        if 0 == self.n_samples:
            return 0.0

        return self.n_accepted_samples / self.n_samples
79 |
80 |
class QuasiRejectionSamplerEstimator:
    """
    Provides routines to compute estimates of useful metrics related to
    QuasiRejectionSampler (QRS)

    A pool of samples is drawn once from the proposal, and scored by the
    target, at construction time; every estimate is then computed from
    that same pool.
    """

    def __init__(self, target, proposal, n_estimation_samples=10000,
            sampling_size=32, context=''):
        """
        Parameters
        ----------
        target: distribution
            Energy-based model to (log-)score the samples
        proposal: distribution
            distribution to generate the samples
        n_estimation_samples: integer
            number of samples to use for computing estimates
        sampling_size: integer
            number of samples that are concurrently obtained from the proposal
        context: text
            context to condition the proposal and target
        """
        # kept so that features can later be scored on the same samples,
        # with the same batch size and context
        self.sampling_size = sampling_size
        self.context = context

        sampler = AccumulationSampler(proposal, total_size=n_estimation_samples)

        self.samples, self.proposal_log_scores = \
            sampler.sample(sampling_size=sampling_size, context=context)

        self.target_log_scores = batchify(target.log_score,
                sampling_size, samples=self.samples, context=context)

    def acceptance_rate_at_beta(self, beta):
        """
        Estimate the acceptance rate that QRS has for a given beta parameter.

        Parameters
        ----------
        beta: float
            the value of beta for which we want the a.r. estimated;
            it is used as a divisor so it should not be zero

        Returns
        -------
        the estimated acceptance rate for the given beta
        """
        vals = self._compute_intermediary_values(beta)

        return vals['qrs_Z'].item() / beta

    def divergence_at_beta(self, beta, divergence=KL):
        """
        Estimate the divergence to the target distribution that QRS has
        for a given beta parameter.

        Parameters
        ----------
        beta: float
            the value of beta for which we want the divergence estimated
        divergence: divergence
            the divergence that we want to compute (see :py:mod:`metrics`)

        Returns
        -------
        the estimated value of the chosen divergence for the given beta
        """
        vals = self._compute_intermediary_values(beta)

        return divergence.divergence(
                vals['target_log_scores'],
                vals['normalized_qrs_log_scores'],
                vals['Z'],
                vals['proposal_log_scores']).item()

    def feature_moment_at_beta(self, beta, feature):
        """
        Estimate the first moment (expected value) of a feature when using
        QRS with a given beta parameter.

        Parameters
        ----------
        beta: float
            the value of beta for which we want the moment estimated
        feature: scorer
            the feature whose moment we want to compute

        Returns
        -------
        the estimated feature moment at the given beta
        """
        vals = self._compute_intermediary_values(beta)

        # score the samples drawn at construction time, with the batch size
        # and context given at that time
        feature_scores = batchify(feature.score,
                self.sampling_size, samples=self.samples, context=self.context)

        return torch.mean(vals['qrs_importance_ratios'] * feature_scores).item()

    def _compute_intermediary_values(self, beta):
        """
        Computes intermediary values used in QRS estimates

        Parameters
        ----------
        beta: float
            the value of beta for which the values are computed

        Returns
        -------
        dictionary with the target and proposal log-scores, the unnormalized
        and normalized QRS log-scores, the importance ratios of the target
        and of QRS with respect to the proposal, and the estimated partition
        functions of the target (Z) and of QRS (qrs_Z)
        """
        with torch.no_grad():
            target_importance_ratios = torch.exp(self.target_log_scores -
                    self.proposal_log_scores)
            Z = torch.mean(target_importance_ratios)
            # a non-positive beta is mapped to log(beta) = -inf
            log_beta = torch.log(torch.tensor(beta)) if beta > 0 else float("-inf")
            # unnormalized QRS score: min(P(x), beta * q(x)), in log-space
            qrs_log_scores = torch.minimum(self.target_log_scores,
                    self.proposal_log_scores + log_beta)
            qrs_Z = torch.mean(torch.exp(qrs_log_scores - self.proposal_log_scores))
            normalized_qrs_log_scores = qrs_log_scores - torch.log(qrs_Z)
            qrs_importance_ratios = torch.exp(normalized_qrs_log_scores -
                    self.proposal_log_scores)

        return {
                'target_log_scores': self.target_log_scores,
                'proposal_log_scores': self.proposal_log_scores,
                'qrs_log_scores': qrs_log_scores,
                'normalized_qrs_log_scores': normalized_qrs_log_scores,
                'target_importance_ratios': target_importance_ratios,
                'Z': Z,
                'qrs_Z': qrs_Z,
                'qrs_importance_ratios': qrs_importance_ratios}
200 |
--------------------------------------------------------------------------------
/tutorials/4.conditional_tuning.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "attachments": {},
5 | "cell_type": "markdown",
6 | "metadata": {},
7 | "source": [
8 | "disco \n",
9 | "Copyright (C) 2022-present NAVER Corp. \n",
10 | "Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license "
11 | ]
12 | },
13 | {
14 | "attachments": {},
15 | "cell_type": "markdown",
16 | "metadata": {},
17 | "source": [
18 | "# Conditional Tuning with CDPG"
19 | ]
20 | },
21 | {
22 | "attachments": {},
23 | "cell_type": "markdown",
24 | "metadata": {},
25 | "source": [
26 | "In the conditional case the context is no longer fixed and we rely on a more generic tuner, a `CDPGTuner`. In most cases we favor a seq2seq model, and our features make use of the context and the sample."
27 | ]
28 | },
29 | {
30 | "attachments": {},
31 | "cell_type": "markdown",
32 | "metadata": {},
33 | "source": [
34 | "For our experiment, we're going to summarize news articles with T5, making sure the model does not hallucinate organizations."
35 | ]
36 | },
37 | {
38 | "attachments": {},
39 | "cell_type": "markdown",
40 | "metadata": {},
41 | "source": [
42 | "## Expressing Preferences"
43 | ]
44 | },
45 | {
46 | "attachments": {},
47 | "cell_type": "markdown",
48 | "metadata": {},
49 | "source": [
50 | "Using spaCy, we can extract organization names from a text."
51 | ]
52 | },
53 | {
54 | "cell_type": "code",
55 | "execution_count": null,
56 | "metadata": {},
57 | "outputs": [],
58 | "source": [
59 | "import spacy"
60 | ]
61 | },
62 | {
63 | "cell_type": "code",
64 | "execution_count": null,
65 | "metadata": {},
66 | "outputs": [],
67 | "source": [
68 | "nlp = spacy.load(\"en_core_web_sm\")"
69 | ]
70 | },
71 | {
72 | "cell_type": "code",
73 | "execution_count": null,
74 | "metadata": {},
75 | "outputs": [],
76 | "source": [
77 | "def organizations(text):\n",
78 | " \"\"\"returns a set of organizations from a text\"\"\"\n",
79 | " doc = nlp(text)\n",
80 | " return set(ent.text for ent in doc.ents if \"ORG\" == ent.label_)"
81 | ]
82 | },
83 | {
84 | "attachments": {},
85 | "cell_type": "markdown",
86 | "metadata": {},
87 | "source": [
88 | "Now that we can obtain a set of organizations from a text, we can build a scorer: we want to make sure that a sample only includes the organizations mentioned in the context, that we're going to summarize —in other words we don't want to have hallucinated organizations."
89 | ]
90 | },
91 | {
92 | "cell_type": "code",
93 | "execution_count": null,
94 | "metadata": {},
95 | "outputs": [],
96 | "source": [
97 | "from disco.scorers.boolean_scorer import BooleanScorer"
98 | ]
99 | },
100 | {
101 | "cell_type": "code",
102 | "execution_count": null,
103 | "metadata": {},
104 | "outputs": [],
105 | "source": [
106 | "organization_scorer = BooleanScorer(lambda s, c: all({o in organizations(c) for o in organizations(s.text)}))"
107 | ]
108 | },
109 | {
110 | "attachments": {},
111 | "cell_type": "markdown",
112 | "metadata": {},
113 | "source": [
114 | "For this task, we're going to use a powerful seq2seq model from Transformers, in its \"base\" version: T5."
115 | ]
116 | },
117 | {
118 | "cell_type": "code",
119 | "execution_count": null,
120 | "metadata": {},
121 | "outputs": [],
122 | "source": [
123 | "from disco.distributions import LMDistribution\n",
124 | "from transformers import AutoModelForSeq2SeqLM"
125 | ]
126 | },
127 | {
128 | "cell_type": "code",
129 | "execution_count": null,
130 | "metadata": {},
131 | "outputs": [],
132 | "source": [
133 | "base = LMDistribution(model=\"t5-base\", tokenizer=\"t5-base\", auto=AutoModelForSeq2SeqLM)"
134 | ]
135 | },
136 | {
137 | "attachments": {},
138 | "cell_type": "markdown",
139 | "metadata": {},
140 | "source": [
141 | "And we simply state that we want all samples to respect our preferences."
142 | ]
143 | },
144 | {
145 | "cell_type": "code",
146 | "execution_count": null,
147 | "metadata": {},
148 | "outputs": [],
149 | "source": [
150 | "target = base * organization_scorer"
151 | ]
152 | },
153 | {
154 | "attachments": {},
155 | "cell_type": "markdown",
156 | "metadata": {},
157 | "source": [
158 | "## Tuning, Conditionally"
159 | ]
160 | },
161 | {
162 | "attachments": {},
163 | "cell_type": "markdown",
164 | "metadata": {},
165 | "source": [
166 | "We now want to tune a model in order to approximate this target distribution. For this we will need many contexts: we can use a DatasetContextDistribution to rely on a dataset from Hugging Face's Datasets repository, the CNN / Dailymail dataset. Let's see how this works."
167 | ]
168 | },
169 | {
170 | "cell_type": "code",
171 | "execution_count": null,
172 | "metadata": {},
173 | "outputs": [],
174 | "source": [
175 | "from disco.distributions.dataset_context_distribution import DatasetContextDistribution"
176 | ]
177 | },
178 | {
179 | "cell_type": "code",
180 | "execution_count": null,
181 | "metadata": {},
182 | "outputs": [],
183 | "source": [
184 | "dataset = DatasetContextDistribution(dataset=\"cnn_dailymail\", subset=\"1.0.0\", split=\"train\", key=\"article\")"
185 | ]
186 | },
187 | {
188 | "attachments": {},
189 | "cell_type": "markdown",
190 | "metadata": {},
191 | "source": [
192 | "Out of curiosity, we can sample a few articles and extract a set of organizations from the first one by doing:"
193 | ]
194 | },
195 | {
196 | "cell_type": "code",
197 | "execution_count": null,
198 | "metadata": {},
199 | "outputs": [],
200 | "source": [
201 | "articles, log_scores = dataset.sample(sampling_size=2**3)"
202 | ]
203 | },
204 | {
205 | "cell_type": "code",
206 | "execution_count": null,
207 | "metadata": {},
208 | "outputs": [],
209 | "source": [
210 | "articles[0]"
211 | ]
212 | },
213 | {
214 | "cell_type": "code",
215 | "execution_count": null,
216 | "metadata": {},
217 | "outputs": [],
218 | "source": [
219 | "organizations(articles[0])"
220 | ]
221 | },
222 | {
223 | "attachments": {},
224 | "cell_type": "markdown",
225 | "metadata": {},
226 | "source": [
227 | "We're using the online scheme, sampling directly from the model we'll be tuning —it's also very possible to rely on the offline scheme, see the [Tuning notebook](./3.tuning_DPG.ipynb)."
228 | ]
229 | },
230 | {
231 | "cell_type": "code",
232 | "execution_count": null,
233 | "metadata": {},
234 | "outputs": [],
235 | "source": [
236 | "model = LMDistribution(model=\"t5-base\", tokenizer=\"t5-base\", auto=AutoModelForSeq2SeqLM, length=256, freeze=False, )"
237 | ]
238 | },
239 | {
240 | "attachments": {},
241 | "cell_type": "markdown",
242 | "metadata": {},
243 | "source": [
244 | "\n",
245 | "\n",
246 | "We can now instantiate a tuner. We're going:\n",
247 | " * to tune model to approximate target getting our samples from the model itself;\n",
248 | " * to use a context distribution to fetch articles from the CNN / Dailymail dataset —all prepended with the task prefix \"summarize: \" to control T5."
249 | ]
250 | },
251 | {
252 | "cell_type": "code",
253 | "execution_count": null,
254 | "metadata": {},
255 | "outputs": [],
256 | "source": [
257 | "from disco.tuners import CDPGTuner"
258 | ]
259 | },
260 | {
261 | "cell_type": "code",
262 | "execution_count": null,
263 | "metadata": {},
264 | "outputs": [],
265 | "source": [
266 | "tuner = CDPGTuner(model, target,\n",
267 | " context_distribution=DatasetContextDistribution(\n",
268 | " dataset=\"cnn_dailymail\", subset=\"1.0.0\", split=\"train\", key=\"article\", prefix=\"summarize: \"),\n",
269 | " n_gradient_steps=1000,\n",
270 | " n_samples_per_context=2**8,\n",
271 | " sampling_size=2**5,\n",
272 | " scoring_size=2**5)"
273 | ]
274 | },
275 | {
276 | "attachments": {},
277 | "cell_type": "markdown",
278 | "metadata": {},
279 | "source": [
280 | "Of course we want to monitor the progress so we use a logger."
281 | ]
282 | },
283 | {
284 | "cell_type": "code",
285 | "execution_count": null,
286 | "metadata": {},
287 | "outputs": [],
288 | "source": [
289 | "from disco.tuners.loggers.console import ConsoleLogger"
290 | ]
291 | },
292 | {
293 | "cell_type": "code",
294 | "execution_count": null,
295 | "metadata": {},
296 | "outputs": [],
297 | "source": [
298 | "ConsoleLogger(tuner)"
299 | ]
300 | },
301 | {
302 | "attachments": {},
303 | "cell_type": "markdown",
304 | "metadata": {},
305 | "source": [
306 | "Note that to use a `NeptuneLogger` instead, or in addition, we can simply uncomment the following cell, assuming we've actually set up access to the service."
307 | ]
308 | },
309 | {
310 | "cell_type": "code",
311 | "execution_count": null,
312 | "metadata": {},
313 | "outputs": [],
314 | "source": [
315 | "# from disco.tuners.loggers.neptune import NeptuneLogger\n",
316 | "# import os\n",
317 | "# NEPTUNE_API_TOKEN = os.environ[\"NEPTUNE_API_TOKEN\"]\n",
318 | "# NeptuneLogger(tuner,\n",
319 | "# project=\"disco\", api_token=NEPTUNE_API_TOKEN\n",
320 | "# )"
321 | ]
322 | },
323 | {
324 | "attachments": {},
325 | "cell_type": "markdown",
326 | "metadata": {},
327 | "source": [
328 | "Let's dance!"
329 | ]
330 | },
331 | {
332 | "cell_type": "code",
333 | "execution_count": null,
334 | "metadata": {},
335 | "outputs": [],
336 | "source": [
337 | "tuner.tune()"
338 | ]
339 | },
340 | {
341 | "cell_type": "code",
342 | "execution_count": null,
343 | "metadata": {},
344 | "outputs": [],
345 | "source": []
346 | }
347 | ],
348 | "metadata": {
349 | "kernelspec": {
350 | "display_name": "Python 3.10.6 (conda)",
351 | "language": "python",
352 | "name": "python3"
353 | },
354 | "language_info": {
355 | "codemirror_mode": {
356 | "name": "ipython",
357 | "version": 3
358 | },
359 | "file_extension": ".py",
360 | "mimetype": "text/x-python",
361 | "name": "python",
362 | "nbconvert_exporter": "python",
363 | "pygments_lexer": "ipython3",
364 | "version": "3.10.6 | packaged by conda-forge | (main, Aug 22 2022, 20:38:29) [Clang 13.0.1 ]"
365 | },
366 | "vscode": {
367 | "interpreter": {
368 | "hash": "babb4baf4e80bd80b9852210fc5469c0783907e52a560ed7247caef52808358d"
369 | }
370 | }
371 | },
372 | "nbformat": 4,
373 | "nbformat_minor": 2
374 | }
375 |
--------------------------------------------------------------------------------
/tutorials/3.tuning_DPG.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "attachments": {},
5 | "cell_type": "markdown",
6 | "metadata": {},
7 | "source": [
8 | "disco \n",
9 | "Copyright (C) 2022-present NAVER Corp. \n",
10 | "Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license "
11 | ]
12 | },
13 | {
14 | "cell_type": "markdown",
15 | "metadata": {},
16 | "source": [
17 | "# Tuning with DPG"
18 | ]
19 | },
20 | {
21 | "cell_type": "markdown",
22 | "metadata": {},
23 | "source": [
24 | "Once we have expressed our preferences on the generated sequences, through an Energy-Based Model (EBM), we cannot directly sample from it. What we can do is approximate it by fine-tuning a model. \n",
25 | "Let's first look at the classic, unconditional case of DPG, i.e. with a fixed context."
26 | ]
27 | },
28 | {
29 | "cell_type": "markdown",
30 | "metadata": {},
31 | "source": [
32 | "## Expressing Preferences"
33 | ]
34 | },
35 | {
36 | "cell_type": "markdown",
37 | "metadata": {},
38 | "source": [
39 | "Let's stick with our _amazing_ use case: we want the word \"amazing\" to appear in our samples —see the [Expressing Preference](./2.expressing_preferences.ipynb) notebook for the detailed explanations."
40 | ]
41 | },
42 | {
43 | "cell_type": "code",
44 | "execution_count": null,
45 | "metadata": {},
46 | "outputs": [],
47 | "source": [
48 | "from disco.scorers import BooleanScorer"
49 | ]
50 | },
51 | {
52 | "cell_type": "code",
53 | "execution_count": null,
54 | "metadata": {},
55 | "outputs": [],
56 | "source": [
57 | "import re\n",
58 | "\n",
59 | "is_amazing = lambda s, c: bool(re.search(r\"\\bamazing\\b\", s.text))\n",
60 | "amazing_scorer = BooleanScorer(is_amazing)"
61 | ]
62 | },
63 | {
64 | "cell_type": "code",
65 | "execution_count": null,
66 | "metadata": {},
67 | "outputs": [],
68 | "source": [
69 | "from disco.distributions import LMDistribution"
70 | ]
71 | },
72 | {
73 | "cell_type": "markdown",
74 | "metadata": {},
75 | "source": [
76 | "for a pointwise constraint:"
77 | ]
78 | },
79 | {
80 | "cell_type": "code",
81 | "execution_count": null,
82 | "metadata": {},
83 | "outputs": [],
84 | "source": [
85 | "base = LMDistribution()\n",
86 | "pw_target = base * amazing_scorer"
87 | ]
88 | },
89 | {
90 | "cell_type": "markdown",
91 | "metadata": {},
92 | "source": [
93 | "for a distributional one:"
94 | ]
95 | },
96 | {
97 | "cell_type": "code",
98 | "execution_count": null,
99 | "metadata": {},
100 | "outputs": [],
101 | "source": [
102 | "from disco.distributions.single_context_distribution import SingleContextDistribution"
103 | ]
104 | },
105 | {
106 | "cell_type": "code",
107 | "execution_count": null,
108 | "metadata": {},
109 | "outputs": [],
110 | "source": [
111 | "incipit = \"It was a cold and stormy night\""
112 | ]
113 | },
114 | {
115 | "cell_type": "code",
116 | "execution_count": null,
117 | "metadata": {},
118 | "outputs": [],
119 | "source": [
120 | "dc_target = base.constrain([amazing_scorer], [1/2],\n",
121 | " n_samples=2**10,\n",
122 | " context_distribution=SingleContextDistribution(incipit))"
123 | ]
124 | },
125 | {
126 | "cell_type": "markdown",
127 | "metadata": {},
128 | "source": [
129 | "## Tuning"
130 | ]
131 | },
132 | {
133 | "cell_type": "markdown",
134 | "metadata": {},
135 | "source": [
136 | "We then instantiate the model we want to tune —we'll tune the \"network\" inside the distribution."
137 | ]
138 | },
139 | {
140 | "cell_type": "code",
141 | "execution_count": null,
142 | "metadata": {},
143 | "outputs": [],
144 | "source": [
145 | "model = LMDistribution(freeze=False)"
146 | ]
147 | },
148 | {
149 | "cell_type": "markdown",
150 | "metadata": {},
151 | "source": [
152 | "Let's check the initial rate for our constraint."
153 | ]
154 | },
155 | {
156 | "cell_type": "code",
157 | "execution_count": null,
158 | "metadata": {},
159 | "outputs": [],
160 | "source": [
161 | "from disco.samplers import AccumulationSampler"
162 | ]
163 | },
164 | {
165 | "cell_type": "code",
166 | "execution_count": null,
167 | "metadata": {},
168 | "outputs": [],
169 | "source": [
170 | "sampler = AccumulationSampler(model, total_size=2**9)\n",
171 | "samples, log_scores = sampler.sample(context=incipit)"
172 | ]
173 | },
174 | {
175 | "cell_type": "code",
176 | "execution_count": null,
177 | "metadata": {},
178 | "outputs": [],
179 | "source": [
180 | "sum([is_amazing(s, _) for s in samples]) / len(samples)"
181 | ]
182 | },
183 | {
184 | "cell_type": "markdown",
185 | "metadata": {},
186 | "source": [
187 | "### Offline"
188 | ]
189 | },
190 | {
191 | "cell_type": "markdown",
192 | "metadata": {},
193 | "source": [
194 | "In the offline scheme, we use a companion proposal distribution to sample from, and update that proposal, eventually, during the tuning."
195 | ]
196 | },
197 | {
198 | "cell_type": "code",
199 | "execution_count": null,
200 | "metadata": {},
201 | "outputs": [],
202 | "source": [
203 | "proposal = LMDistribution()"
204 | ]
205 | },
206 | {
207 | "cell_type": "markdown",
208 | "metadata": {},
209 | "source": [
210 | "We can now instantiate a tuner. We're going:\n",
211 | " * to tune model to approximate dc_target getting our samples from proposal;\n",
212 | " * to use a fixed incipit for the context;\n",
213 | " * to check the divergence every `divergence_evaluation_interval` gradient steps, when we'll also eventually update the proposal."
214 | ]
215 | },
216 | {
217 | "cell_type": "code",
218 | "execution_count": null,
219 | "metadata": {},
220 | "outputs": [],
221 | "source": [
222 | "from disco.tuners import DPGTuner"
223 | ]
224 | },
225 | {
226 | "cell_type": "code",
227 | "execution_count": null,
228 | "metadata": {},
229 | "outputs": [],
230 | "source": [
231 | "tuner = DPGTuner(model, dc_target, proposal,\n",
232 | " context=incipit,\n",
233 | " n_gradient_steps=1000,\n",
234 | " n_samples_per_context=2**8,\n",
235 | " sampling_size=2**5,\n",
236 | " scoring_size=2**5,\n",
237 | " divergence_evaluation_interval=2**2,\n",
238 | " n_kl_samples=2**10)"
239 | ]
240 | },
241 | {
242 | "cell_type": "markdown",
243 | "metadata": {},
244 | "source": [
245 | "There are loggers we can use to monitor the tuning. They are built on the observer pattern so it's easy to add more specific ones —although beyond the simple `ConsoleLogger` disco provides loggers for Neptune, Weights & Biases, ..."
246 | ]
247 | },
248 | {
249 | "cell_type": "code",
250 | "execution_count": null,
251 | "metadata": {},
252 | "outputs": [],
253 | "source": [
254 | "from disco.tuners.loggers.console import ConsoleLogger"
255 | ]
256 | },
257 | {
258 | "cell_type": "code",
259 | "execution_count": null,
260 | "metadata": {},
261 | "outputs": [],
262 | "source": [
263 | "ConsoleLogger(tuner)"
264 | ]
265 | },
266 | {
267 | "cell_type": "markdown",
268 | "metadata": {},
269 | "source": [
270 | "Let's dance!"
271 | ]
272 | },
273 | {
274 | "cell_type": "code",
275 | "execution_count": null,
276 | "metadata": {
277 | "scrolled": false
278 | },
279 | "outputs": [],
280 | "source": [
281 | "tuner.tune()"
282 | ]
283 | },
284 | {
285 | "cell_type": "markdown",
286 | "metadata": {},
287 | "source": [
288 | "Are we doing better?"
289 | ]
290 | },
291 | {
292 | "cell_type": "code",
293 | "execution_count": null,
294 | "metadata": {},
295 | "outputs": [],
296 | "source": [
297 | "sampler = AccumulationSampler(model, total_size=512)\n",
298 | "samples, log_scores = sampler.sample(context=incipit)"
299 | ]
300 | },
301 | {
302 | "cell_type": "code",
303 | "execution_count": null,
304 | "metadata": {},
305 | "outputs": [],
306 | "source": [
307 | "sum([is_amazing(s, _) for s in samples]) / len(samples)"
308 | ]
309 | },
310 | {
311 | "cell_type": "markdown",
312 | "metadata": {},
313 | "source": [
314 | "### Online tuning"
315 | ]
316 | },
317 | {
318 | "cell_type": "markdown",
319 | "metadata": {},
320 | "source": [
321 | "In the online scheme, the model being tuned is also the one providing the samples, so we don't need a proposal."
322 | ]
323 | },
324 | {
325 | "cell_type": "code",
326 | "execution_count": null,
327 | "metadata": {},
328 | "outputs": [],
329 | "source": [
330 | "model = LMDistribution(freeze=False)"
331 | ]
332 | },
333 | {
334 | "attachments": {},
335 | "cell_type": "markdown",
336 | "metadata": {},
337 | "source": [
338 | "_Note that, for an actual tuning, you might want to move the networks to GPUs first, for example with:_\n",
339 | "```\n",
340 | "model.to(\"cuda\")\n",
341 | "dc_target.scorers[0].to(\"cuda\")\n",
342 | "```"
343 | ]
344 | },
345 | {
346 | "cell_type": "code",
347 | "execution_count": null,
348 | "metadata": {},
349 | "outputs": [],
350 | "source": [
351 | "tuner = DPGTuner(model, dc_target,\n",
352 | " context=incipit,\n",
353 | " n_gradient_steps=100,\n",
354 | " n_samples_per_context=2**8,\n",
355 | " sampling_size=2**5,\n",
356 | " scoring_size=2**5,\n",
357 | " divergence_evaluation_interval=1)"
358 | ]
359 | },
360 | {
361 | "attachments": {},
362 | "cell_type": "markdown",
363 | "metadata": {},
364 | "source": [
365 | "_Again, for an actual tuning, you might want to initiate logging, for example with:_\n",
366 | "```\n",
367 | "from disco.tuners.loggers.wandb import WandBLogger\n",
368 | "logger = WandBLogger(tuner, \"my_project\", \"my_run\")\n",
369 | "```"
370 | ]
371 | },
372 | {
373 | "cell_type": "code",
374 | "execution_count": null,
375 | "metadata": {},
376 | "outputs": [],
377 | "source": [
378 | "tuner.tune()"
379 | ]
380 | },
381 | {
382 | "cell_type": "code",
383 | "execution_count": null,
384 | "metadata": {},
385 | "outputs": [],
386 | "source": [
387 | "sampler = AccumulationSampler(model, total_size=2**9)\n",
388 | "samples, log_scores = sampler.sample(context=incipit)"
389 | ]
390 | },
391 | {
392 | "cell_type": "code",
393 | "execution_count": null,
394 | "metadata": {},
395 | "outputs": [],
396 | "source": [
397 | "sum([is_amazing(s, _) for s in samples]) / len(samples)"
398 | ]
399 | },
400 | {
401 | "cell_type": "code",
402 | "execution_count": null,
403 | "metadata": {},
404 | "outputs": [],
405 | "source": []
406 | }
407 | ],
408 | "metadata": {
409 | "kernelspec": {
410 | "display_name": "Python 3",
411 | "language": "python",
412 | "name": "python3"
413 | },
414 | "language_info": {
415 | "codemirror_mode": {
416 | "name": "ipython",
417 | "version": 3
418 | },
419 | "file_extension": ".py",
420 | "mimetype": "text/x-python",
421 | "name": "python",
422 | "nbconvert_exporter": "python",
423 | "pygments_lexer": "ipython3",
424 | "version": "3.10.6 | packaged by conda-forge | (main, Aug 22 2022, 20:38:29) [Clang 13.0.1 ]"
425 | },
426 | "vscode": {
427 | "interpreter": {
428 | "hash": "babb4baf4e80bd80b9852210fc5469c0783907e52a560ed7247caef52808358d"
429 | }
430 | }
431 | },
432 | "nbformat": 4,
433 | "nbformat_minor": 2
434 | }
435 |
--------------------------------------------------------------------------------
/tests/lm_distribution_test.py:
--------------------------------------------------------------------------------
1 | # disco
2 | # Copyright (C) 2022-present NAVER Corp.
3 | # Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license
4 |
5 | import unittest
6 | import random, torch
7 | import numpy as np
8 | from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM
9 |
10 | from disco.distributions import LMDistribution
11 | from disco.distributions.lm_distribution import TextSample
12 |
# Shared context string that the tests below pass as `context=prefix`
# when sampling from, or scoring with, a distribution.
prefix = "It was a cold and stormy night"
14 |
class LMDistributionTest(unittest.TestCase):
    """Tests for sampling from, and scoring with, ``LMDistribution``."""

    def setUp(self):
        """Build one distribution per (model name, auto class) pair under test."""
        test_models = [("gpt2", AutoModelForCausalLM),
                       ("facebook/bart-base", AutoModelForSeq2SeqLM)]
        self.test_distributions = {model: LMDistribution(model, auto=auto)
                                   for (model, auto) in test_models}

    def test_instantiate_a_default_distribution(self):
        """A default distribution exposes length (40), params and scorable (True)."""
        distribution = LMDistribution()
        self.assertTrue(hasattr(distribution, "length"),
                        "the distribution should have a length attribute.")
        self.assertEqual(40, distribution.length,
                         "the default length should be 40.")
        self.assertTrue(hasattr(distribution, "params"),
                        "the distribution should have a params attribute.")
        self.assertTrue(hasattr(distribution, "scorable"),
                        "the distribution should have a scorable attribute.")
        self.assertEqual(True, distribution.scorable,
                         "the distribution should be scorable by default.")

    def test_sample_a_continuation_from_a_default_distribution(self):
        """Sampling returns a list of TextSample and a tensor of float log-scores."""
        distribution = LMDistribution()
        samples, log_scores = distribution.sample(context=prefix)
        self.assertIsInstance(samples, list, "the samples should be returned as a list.")
        for sample in samples:
            self.assertIsInstance(sample, TextSample, "each text should be a textSample.")
        self.assertIsInstance(log_scores, torch.Tensor, "the log-scores should be returned as a tensor.")
        for log_score in log_scores:
            self.assertIsInstance(log_score.item(), float, "each log-score should be a float.")

    def test_sample_multiple_continuations_for_a_prefix(self):
        """`sampling_size` controls how many samples and log-scores come back."""
        distribution = LMDistribution()
        sampling_size = 8
        samples, log_scores = distribution.sample(sampling_size=sampling_size)
        self.assertEqual(sampling_size, len(samples),
                         "the number of returned samples should be {}.".format(sampling_size))
        self.assertEqual(sampling_size, len(log_scores),
                         "the number of returned log_scores should be {}.".format(sampling_size))

    def test_sample_a_continuation_with_temperature_close_to_zero(self):
        """With a near-zero temperature, two samples and their scores coincide."""
        distribution = LMDistribution(temperature=0.001)
        samples, log_scores = distribution.sample(sampling_size=2, context=prefix)
        self.assertEqual(samples[0].text, samples[1].text,
                         "samples should not vary when temperature is (almost) zero.")
        self.assertEqual(log_scores[0], log_scores[1],
                         "log_scores should not vary when temperature is (almost) zero.")

    def test_score_sequences(self):
        """Scoring hand-built samples yields one negative float log-score each."""
        distribution = LMDistribution()
        texts = [
            " and the streets were empty and quiet.",
            "; the rain fell in torrents"
        ]
        tokenized_texts = distribution.tokenizer(texts, return_tensors="pt", add_special_tokens=True, padding=True)
        samples = [TextSample(s, t) for (s, t) in zip(tokenized_texts["input_ids"], texts)]
        log_scores = distribution.log_score(samples, context=prefix)
        self.assertIsInstance(log_scores, torch.Tensor, "the log-scores should be returned as a tensor.")
        self.assertEqual(len(samples), len(log_scores), "there should be as many log-scores as there are sequences.")
        for log_score in log_scores:
            self.assertIsInstance(log_score.item(), float, "each log-score should be a float.")
            self.assertLess(log_score, 0.0, "each log-score should be negative.")

    def test_score_sampled_sequences(self):
        """Log-scores returned at sampling time match a re-scoring of the samples.

        Runs once per entry of `self.test_distributions`; the loop variable is
        used as-is (it was previously overwritten by a default distribution,
        so only the default model was ever exercised).
        """
        for model, distribution in self.test_distributions.items():
            with self.subTest(model=model):
                samples, log_scores = distribution.sample(context=prefix)
                samples_log_scores_again = distribution.log_score(samples, context=prefix)
                self.assertLess((log_scores - samples_log_scores_again).abs().max(),
                                1e-3, "log-scores given at sampling time and the scoring of "
                                "the same samples should match")

    def test_score_with_a_non_default_top_k_specified(self):
        """Scoring raises AssertionError when top_k was set."""
        distribution = LMDistribution(top_k=20)
        samples, _ = distribution.sample(context=prefix)
        self.assertRaises(AssertionError, distribution.log_score, samples, context=prefix)

    def test_score_with_a_non_default_top_p_specified(self):
        """Scoring raises AssertionError when top_p was set."""
        distribution = LMDistribution(top_p=0.92)
        samples, _ = distribution.sample(context=prefix)
        self.assertRaises(AssertionError, distribution.log_score, samples, context=prefix)

    def test_score_with_a_non_default_typical_p_specified(self):
        """Scoring raises AssertionError when typical_p was set."""
        distribution = LMDistribution(typical_p=0.9)
        samples, _ = distribution.sample(context=prefix)
        self.assertRaises(AssertionError, distribution.log_score, samples, context=prefix)

    def test_score_with_a_non_default_temperature_specified(self):
        """Scoring raises AssertionError when a non-default temperature was set."""
        distribution = LMDistribution(temperature=0.7)
        samples, _ = distribution.sample(context=prefix)
        self.assertRaises(AssertionError, distribution.log_score, samples, context=prefix)

    def test_score_sequences_different_length(self):
        """Scoring samples tokenized separately (unpadded) raises AssertionError."""
        distribution = LMDistribution()
        texts = [
            " and the streets were empty and quiet.",
            " in Brest."
        ]

        samples = [
            TextSample(distribution.tokenizer(t, return_tensors="pt", add_special_tokens=True)["input_ids"], t)
            for t in texts
        ]
        self.assertRaises(AssertionError, distribution.log_score, samples, context=prefix)

    def test_ignore_padding_at_score(self):
        """A pad token following an eos token must not change the score."""
        for model, distribution in self.test_distributions.items():
            with self.subTest(model=model):
                text = "The streets were empty and quiet."
                context = "This is some context." if distribution.network.config.is_encoder_decoder else ""

                sample = TextSample(distribution.tokenizer(text, return_tensors="pt", add_special_tokens=False)["input_ids"][0], text)
                eos_token_id = distribution.tokenizer.eos_token_id
                pad_token_id = distribution.tokenizer.pad_token_id

                eos_token = distribution.tokenizer.eos_token
                pad_token = distribution.tokenizer.pad_token
                completed_sample = TextSample(
                    token_ids=torch.cat((sample.token_ids, torch.tensor([eos_token_id])), dim=0),
                    text=text + eos_token)

                # The text mirrors the token ids: an eos token, then a pad token.
                padded_sample = TextSample(
                    token_ids=torch.cat((sample.token_ids, torch.tensor([eos_token_id, pad_token_id])), dim=0),
                    text=text + eos_token + pad_token)

                log_scores = distribution.log_score([sample], context=context).item()
                completed_log_scores = distribution.log_score([completed_sample], context=context).item()
                padded_log_scores = distribution.log_score([padded_sample], context=context).item()

                self.assertNotEqual(0, log_scores)
                # Only compare the two scores when both of them are finite.
                if log_scores != float("-inf") and completed_log_scores != float("-inf"):
                    self.assertNotAlmostEqual(log_scores, completed_log_scores, 4,
                        f"The {model} scores must be different with and without and eos token.")
                self.assertAlmostEqual(completed_log_scores, padded_log_scores, 4,
                    f"The {model} scores must not be different with and without a pad token after an eos token.")

    def test_ignore_padding_at_sample(self):
        """A sampled, mostly-padding sequence re-scores to its sampling-time score.

        The seeds are fixed to 0 and the index of the padded sample (16) was
        found manually, so this test depends on the sampling being reproducible.
        """
        distribution = LMDistribution()
        seed = 0
        torch.manual_seed(seed)
        np.random.seed(seed)
        random.seed(seed)
        samples, log_scores = distribution.sample()
        padded_sequence_id = 16  # found manually
        padded_sequence = samples[padded_sequence_id]
        padded_sequence_log_score = log_scores[padded_sequence_id]
        pad_token_id = distribution.tokenizer.pad_token_id
        self.assertTrue((padded_sequence.token_ids[1:] == pad_token_id).all())
        padded_sequence_relog_score = distribution.log_score([padded_sequence])
        self.assertLess(torch.abs(padded_sequence_log_score - padded_sequence_relog_score), 1e-2)

    def test_empty_sequence_score(self):
        """One pad token and twenty pad tokens must receive the same score."""
        for model, distribution in self.test_distributions.items():
            with self.subTest(model=model):
                pad_token_id = distribution.tokenizer.pad_token_id
                empty_sequence1_tok = [pad_token_id] * 1
                empty_sequence1_str = distribution.tokenizer.decode(empty_sequence1_tok)
                empty_sequence1 = TextSample(token_ids=torch.tensor(empty_sequence1_tok),
                                             text=empty_sequence1_str)
                empty_sequence2_tok = [pad_token_id] * 20
                empty_sequence2_str = distribution.tokenizer.decode(empty_sequence2_tok)
                empty_sequence2 = TextSample(token_ids=torch.tensor(empty_sequence2_tok),
                                             text=empty_sequence2_str)
                if distribution.network.config.is_encoder_decoder:
                    log_score1 = distribution.log_score([empty_sequence1], context='yada yada')
                    log_score2 = distribution.log_score([empty_sequence2], context='yada yada')
                else:
                    log_score1 = distribution.log_score([empty_sequence1])
                    log_score2 = distribution.log_score([empty_sequence2])
                self.assertAlmostEqual(log_score1.item(), log_score2.item(), 4,
                    f"The {model} score of all pad tokens should correspond to a single one.")

    def test_scoring_consistent_with_loss(self):
        """The per-token negative log-score matches the network's reported loss."""
        for model, distribution in self.test_distributions.items():
            with self.subTest(model=model):
                text = "The streets were empty and quiet."
                context = "This is an example." if distribution.network.config.is_encoder_decoder else ""

                sample = TextSample(distribution.tokenizer(text, return_tensors="pt", add_special_tokens=True)["input_ids"][0], text)
                log_score = distribution.log_score([sample], context=context)
                if distribution.network.config.is_encoder_decoder:
                    ctxt = distribution.tokenizer(context, return_tensors="pt", add_special_tokens=True)
                    loss = distribution.network(input_ids=ctxt.input_ids,
                                                labels=sample.token_ids.unsqueeze(0)).loss
                else:
                    loss = distribution.network(sample.token_ids, labels=sample.token_ids).loss
                self.assertLess(torch.abs(-log_score / len(sample.token_ids) - loss), 1e-2,
                    f'The {model} score is not consistent with the loss reported by the forward function.')

    def test_sample_empty_sequence(self):
        """All-padding samples can be drawn and re-score consistently."""
        distribution = LMDistribution("gpt2")
        pad_token_id = distribution.tokenizer.pad_token_id
        epsilon = 0.05
        # increase the likelihood of sampling pad_token_id
        distribution.network.lm_head.weight.data[pad_token_id, :] += epsilon
        seed = 1
        torch.manual_seed(seed)
        np.random.seed(seed)
        random.seed(seed)
        samples, log_scores = distribution.sample()
        self.assertTrue(any((s.token_ids == pad_token_id).all() for s in samples))
        new_log_scores = distribution.log_score(samples)
        self.assertLess(torch.abs(log_scores - new_log_scores).max(), 1e-4)

    def test_freeze_parameters_by_default(self):
        """No network parameter requires a gradient by default."""
        distribution = LMDistribution()
        self.assertTrue(all(not p.requires_grad for p in distribution.network.parameters()))

    def test_unfreeze_parameters_on_demand(self):
        """`freeze(False)` makes every network parameter require a gradient."""
        distribution = LMDistribution()
        distribution.freeze(False)
        self.assertTrue(all(p.requires_grad for p in distribution.network.parameters()))

    def test_unfreeze_all_parameters_with_parameter(self):
        """`freeze=False` at construction leaves every parameter trainable."""
        distribution = LMDistribution(freeze=False)
        self.assertTrue(all(p.requires_grad for p in distribution.network.parameters()))

    def test_freeze_all_parameters_on_demand(self):
        """`freeze()` turns off gradients on every network parameter."""
        distribution = LMDistribution(freeze=False)
        distribution.freeze()
        self.assertTrue(all(not p.requires_grad for p in distribution.network.parameters()))

    def test_clone(self):
        """A clone has an equal config but a distinct network object."""
        distribution = LMDistribution()
        distribution2 = distribution.clone()
        self.assertEqual(distribution.network.config, distribution2.network.config)
        self.assertNotEqual(distribution.network, distribution2.network)

    def test_bart_consistency(self):
        """Per-token log-scores from sampling and from scoring sum to the same value."""
        model = LMDistribution(
            model="facebook/bart-large-cnn",
            auto=AutoModelForSeq2SeqLM,
            length=128)

        context="""Self-trained autonomous agents developed using machine learning are showing great promise in a variety of control settings,
perhaps most remarkably in applications involving autonomous vehicles. The main challenge associated with self-learned agents in the form
of deep neural networks, is their black-box nature: it is impossible for humans to interpret deep neural networks. Therefore, humans cannot
directly interpret the actions of deep neural network based agents, or foresee their robustness in different scenarios. In this work,
we demonstrate a method for probing which concepts self-learning agents internalise in the course of their training. For demonstration,
we use a chess playing agent in a fast and light environment developed specifically to be suitable for research groups without access to
enormous computational resources or machine learning models."""
        torch.manual_seed(1)
        samples, log_scores = model.sample(context=context, sampling_size=1, sum=False)
        ls_log_scores = model.log_score(samples, context=context, sum=False)
        self.assertAlmostEqual(log_scores.sum().item(), ls_log_scores.sum().item(), 4)
265 |
# Run the whole test module when it is executed as a script.
if __name__ == "__main__":
    unittest.main()
268 |
--------------------------------------------------------------------------------
/README.MD:
--------------------------------------------------------------------------------
1 | # 🕺🏽 disco: A Toolkit for Distributional Control of Generative Models
2 |
3 | The 🕺🏽 **disco** toolkit allows to control language models and other generative systems to match human preferences while avoiding catastrophic forgetting.
4 |
5 | To achieve this, **disco** decouples the problem of expressing _what_ properties the model should have from _how_ to actually get the desired model as separate steps.
6 |
7 | **Step 1: ⚓ We express how the target distribution *should* be**
8 |
9 | First, we define some feature over the generated samples that matters to us. It can be anything we can compute. For example, on a language model it can be as simple as whether the generated text contains a certain word or as complex as the compilability of some generated piece of code. Importantly, there is no need for the feature to be differentiable.
10 |
11 | Then, we can express our preferences on the target distribution by deciding how prevalent the feature should be. For example, we might want to ask that a certain word appears 50% of the time when sampling from the model; or that 100% of the generated code is compilable. The resulting target distribution is expressed as an energy-based model or EBM, which is an unnormalized probability distribution that respects the desired moments while avoiding catastrophic forgetting, as a result of having minimal KL divergence to the original model.
12 |
13 | The resulting representation of the target distribution can *score* samples, but cannot directly be used to *generate* them.
14 |
15 | **Step 2: 🎯 Approximate the target distribution**
16 |
17 | To generate samples from the target distribution we can tune a model to approximate it. We do this by minimizing the divergence to the target distribution. While techniques such as reinforcement learning from human feedback (RLHF) are restricted to using one kind of divergence only (specifically, reverse KL divergence), **disco** is more general, allowing the use of the full class of f-divergences, including both forward and reverse KL divergence, Jensen-Shannon, and total variation distance.
18 |
19 | **Step 3: 💬 Generate content that matches the preferences**
20 |
21 | The resulting model can generate samples directly from a close approximation of the target distribution. Furthermore, it can be used jointly with Quasi-Rejection Sampling (QRS), a Monte Carlo sampling technique that allows the generation of samples that are even more representative of the target distribution.
22 | Alternatively, it is then possible to use decoding methods such as nucleus sampling, top-k sampling, or beam search, which would return samples from a further updated target distribution.
23 |
24 | See the references below for more theoretical and technical details.
25 |
26 | ## Installation
27 |
28 | ### Standard installation
29 |
30 | The easiest way to install **disco** is to rely on pip, asking for the ```disco-generation``` package:
31 |
32 | ```
33 | pip install disco-generation
34 | ```
35 |
36 | Note that the toolkit:
37 | - depends on PyTorch,
38 | - uses HuggingFace's Transformers library for the generic handling of language models, as well as the generation of samples from them.
39 |
40 | ### Toolkit Developers
41 |
42 | If we plan to extend the toolkit we will need to clone and to install it as a local package.
43 | From the toolkit top folder, once we've git-cloned the repository and activated our development environment, we simply do:
44 | ```
45 | pip install -e .
46 | ```
47 |
48 | ## Quick introduction
49 |
50 |
51 | ### Distributions
52 |
53 | The generative model that we want to tune must be wrapped by a `Distribution` object. For example, for a (causal or seq2seq) language model compatible with the 🤗 Hugging Face interface use an `LMDistribution`.
54 |
55 | A valid `Distribution` must have the following two methods:
56 | - `.sample(context)` that given an optional `context` on which the distribution can be conditioned, returns a list of samples from the underlying distribution and a tensor with their corresponding log-probabilities;
57 | - `.log_score(samples, context)` that given a list of samples and the `context` on which to condition the distribution, returns their corresponding log-probabilities.
58 |
59 | ```python
60 | from disco.distributions import LMDistribution
61 | distribution = LMDistribution()
62 |
63 | incipit = "It was a cold and stormy night"
64 | samples, log_scores = distribution.sample(context=incipit)
65 |
66 | distribution.log_score(samples, context=incipit)
67 | ```
68 |
`LMDistribution` generates samples of the `TextSample` type, which are named tuples with both `text` and `token_ids` fields.
70 |
71 | From now on, after this initial example, imports will be skipped for clarity.
72 |
73 | ### Features
74 |
75 | Features are represented by an object with the method
76 |
- `.score(samples, context)` which, given a list of samples and an optional context, returns a tensor of real-valued scores.
78 |
A convenient way to define one is using the `Scorer` class, which accepts a function or a lambda abstraction that takes a sample and a context, and vectorizes it. For example, we can compute the effective length of a GPT-2 text sample by finding the eos token:
80 |
81 | ```python
82 | sequence_length = Scorer(lambda s, c: s.text.index("<|endoftext|>"))
83 | ```
84 |
Where `s` is the sample (assumed to be a `TextSample`) and `c` is an optional context.
86 |
87 | #### Boolean Features
88 |
89 | An important class of features are *boolean* features. While general features can only be used to define *distributional* constraints, boolean features can also be used to define *pointwise* constraints, see below. To define one, we can use the `BooleanScorer` helper class, which takes a function as an argument. For example, we can score the presence of the string "amazing", as follows:
90 |
91 | ```python
92 | amazing = BooleanScorer(lambda s, c: "amazing" in s.text)
93 | ```
94 |
The ```False```/```True``` results from the lambda are cast to `0.0`/`1.0` float values so that they can be used in the EBM definition.
96 |
`BooleanScorer` belongs to the more general `PositiveScorer` class, which can be used to construct EBMs. The main properties of a `PositiveScorer` are, first, that it returns positive scores and, second, that it provides the method:
98 |
99 | - `.log_score(samples, context)` that given a list of samples and the `context` on which to condition the distribution, returns their corresponding log-probabilities.
100 |
101 | As a consequence, we can see that a ```Distribution``` is also a ```PositiveScorer``` that is able to sample as well.
102 |
103 |
104 | ### Controlling Generation
105 |
106 | #### Expressing preferences through an EBM
107 |
108 | We express preferences over the distribution by defining target moments for specific features. This results in a target distribution that matches the desired moments while minimizing the KL divergence to the original distribution. In other words, it incorporates the preferences while avoiding catastrophic forgetting. This distribution is represented as an EBM, which can be used to score samples, in other words it is a `PositiveScorer`, but cannot be used to sample, we'll see how to sample below.
109 |
110 | We can express either *pointwise* or *distributional* constraints on a distribution and compose them at will. The former expresses a (boolean) property that must apply to *all* sequences, whereas the latter represents properties at the distributional level.
111 |
To obtain the target distribution that incorporates our constraints, we use the `constrain` method of the corresponding `Distribution`. This method takes a list of features and their corresponding target moments.
113 |
For example, we can define an EBM with a *pointwise* constraint requiring that all our samples must include "amazing" by setting the target moment to `1` on a `BooleanScorer`:
115 |
116 | ```python
117 | target_ebm = base.constrain([amazing], [1])
118 | ```
119 |
120 | Or we can ask for a _distributional_ constraint requiring that _half_ of the samples include "amazing":
121 |
122 | ```python
123 | target_ebm = base.constrain([amazing], [1/2])
124 | ```
125 |
126 |
127 | #### Approximating the target EBM
128 |
129 |
130 | Given an EBM target distribution, we now want to train a model to approximate it so that we can use it to generate samples. In the _unconditional_ case, namely when there is a single fixed context used in generation, then we can use a `Tuner`, more specifically a ```DPGTuner```, as follows.
131 |
132 |
133 | ```python
134 | target_ebm = base.constrain([amazing], [1])
135 |
136 | model = LMDistribution(freeze=False)
137 | incipit = "It was a cold and stormy night"
138 |
139 | tuner = DPGTuner(model, target_ebm, context=incipit)
140 | tuner.tune()
141 | ```
142 |
143 | And we can sample _amazing_ sequences from the tuned model.
144 | ```python
145 | samples, log_scores = model.sample(context=incipit)
146 | for s in samples:
147 | print(incipit + s.text)
148 | ```
149 |
150 | ##### Tuning parameters
151 |
152 | Important parameters of the `Tuner` include:
153 |
154 | - `n_gradient_steps`: number of total gradient steps in the full tuning process;
155 | - `n_samples_per_context`: total number of samples used in performing a gradient step (aka batch size);
156 | - `scoring_size`: number of samples sent in a batch to the `.score` function. This parameter affects training speed or helps solve GPU memory errors, but does not affect final results;
157 | - `sampling_size`: number of samples obtained from a single call to the `.sample` function. This parameter affects training speed or helps solve GPU memory errors, but does not affect final results;
158 | - `features`: list of pairs (`name`, `feature`) so that the `feature` moments will be computed by importance sampling (and reported using the key given by `name`);
- `track_divergence_from_base`: set to True to track the reverse KL divergence from the original model (this requires an additional round of sample scoring).
160 |
161 | #### Logging
162 |
The Tuner reports a number of metrics that are useful to monitor the training progress. A number of `Logger` classes are provided to keep track of these metrics. Basic logging is provided through the console, as follows:
164 |
165 | ```python
166 | console_logger = ConsoleLogger(tuner)
167 | ```
168 |
However, more detailed statistics can be kept through the JSON/WandB/Neptune loggers:
170 |
171 | ```python
172 | project = "example_project"
173 | name = "run_01"
174 | json_logger = JSONLogger(tuner, project, name)
175 | neptune_logger = NeptuneLogger(tuner, project, name)
176 | wandb_logger = WandBLogger(tuner, project, name)
177 | ```
178 |
179 | where `project` and `name` refer to the project and run name, respectively.
180 |
181 | ##### Logged Metrics
182 |
183 | Loggers store a number of metrics about the training process. Here we list a few of the most relevant ones:
184 |
185 | - `kl_target_model` and `kl_target_proposal`: estimates of the forward KL divergence to the target EBM from the tuned model and the proposal distribution, respectively. In the case of using online training, the two are equivalent with the only caveat that `kl_target_model` is computed —this is the metric being optimized, and not the value reported as `loss`;
186 | - `kl_model_base` and `kl_proposal_base`: estimates of the reverse KL divergence to the original model of the tuned model and the proposal distribution, respectively —only reported if `track_divergence_from_base` is set to True;
187 | - Feature moments: Estimate of the features' moments for those features specified with the `features` parameter at the Tuner's construction time.
188 |
189 | ### Controlled Conditional Generation
190 |
191 | The _conditional_ case is superficially very similar, with an extra step needed to instantiate a `ContextDistribution`, which allows to sample contexts that can then be used to condition the model. Furthermore, we use the more general ```CDPGTuner``` class.
192 |
193 | Assuming we have a file of incipits, one per line, in a `data/incipits.txt` file, we could do:
194 |
195 | ```python
196 | target_ebm = base.constrain([amazing], [1])
197 |
198 | model = LMDistribution(freeze=False)
199 |
200 | tuner = CDPGTuner(model, target_ebm,
201 | context_distribution=ContextDistribution("data/incipits.txt"),
202 | n_contexts_per_step=2**3)
203 | tuner.tune()
204 | ```
205 |
206 | Note that while we have used a decoder-only model here for illustrative purposes, the real power of the CDPGTuner is that it allows to control _seq2seq models_ such as those used in NMT, summarization, etc... Please refer to the dedicated [tutorial notebook](tutorials/4.conditional_tuning.ipynb) for an example of how to control an actual conditional model.
207 |
208 |
209 | ### Improving the approximation through minimizing other f-divergences
210 |
211 | The Tuner classes train the model by minimizing its divergence from the distribution induced by the target ebm. While in the original DPG and CDPG algorithms this divergence was always the KL divergence, [recent work](https://arxiv.org/abs/2302.08215) has generalized them to the wider class of [f-divergences](https://en.wikipedia.org/wiki/F-divergence). To pick the loss to minimize, use the corresponding `FDPGTuner` or `FCDPGTuner` class, depending on whether you are in the unconditional or conditional case, and choose the divergence to minimize through the `loss` parameter. Some of the possible losses are:
212 |
213 | - `KLLoss()`: KL divergence
214 | - `ReverseKLLoss()`: KL divergence reversing the order of the arguments
215 | - `TVLoss()`: Total Variation Distance
216 | - `JSLoss()`: Jensen-Shannon divergence.
217 |
Each of these divergences strikes a different balance between level of alignment and diversity. KL exhibits lower alignment than other losses, but higher diversity. On the other hand, ReverseKL tends to produce high alignment at the cost of lower diversity. Jensen-Shannon strikes a good balance between the two, making it a good default choice.
219 |
220 | As an example,
221 |
222 | ```python
223 | target_ebm = base.constrain([amazing], [1])
224 |
225 | model = LMDistribution(freeze=False)
226 |
227 | tuner = FCDPGTuner(model, target_ebm, loss=JSLoss(),
228 | context_distribution=ContextDistribution("data/incipits.txt"),
229 | n_contexts_per_step=2**3)
230 | tuner.tune()
231 | ```
232 |
233 | ### Reinforcement Learning from Human Feedback (RLHF)
234 |
235 | RLHF is a popular paradigm for aligning language models to preferences. While RLHF is commonly known in the form of a reward maximization algorithm, [recent work](https://arxiv.org/abs/2206.00761) has shown that it is equivalent to a distribution approximation problem which can be easily handled by **disco**. Specifically, given a reward function `r(x)` and the regularization parameter `beta`, the following code optimizes the same objective as RLHF:
236 |
237 | ```python
238 | target_ebm = base * ExponentialScorer([r], [1./beta])
239 |
240 | model = base.clone()
241 |
242 |
243 | tuner = FCDPGTuner(model, target_ebm, loss=ReverseKLLoss())
244 | ```
245 |
246 | In other words, RLHF optimizes the *reverse* KL to the above-defined target EBM. Interestingly, this opens new opportunities as the divergence to be minimized could be now chosen to be any other, as explored [here](https://arxiv.org/abs/2302.08215).
247 |
248 |
249 | #### Monte-Carlo sampling to improve the approximation
250 |
251 | After the tuning is done, `model` is now a better approximation to the target EBM, but it is not guaranteed to perfectly match this distribution. While further training can improve the situation, another alternative is using [quasi-rejection sampling (QRS)](https://disco.europe.naverlabs.com/QRS/), a Monte-Carlo sampling technique that allows to trade-off sampling efficiency for a higher fidelity to the target distribution —a higher value of `beta` yields a better fidelity although at a higher computational cost.
252 |
253 | ```python
254 | beta=0.5
255 | sampler = QuasiRejectionSampler(target_ebm, model, beta=beta)
256 | samples, log_scores = sampler.sample(sampling_size=2**7)
257 | ```
258 |
259 | #### In summary
260 |
261 | To put some of this (distributional constraint, tuning in the unconditional case and using QRS) together:
262 |
263 | ```python
264 | base = LMDistribution()
265 | target_ebm = base.constrain([amazing], [1/2])
266 |
267 | model = LMDistribution(freeze=False)
268 |
269 | tuner = DPGTuner(model, target_ebm)
270 | tuner.tune()
271 |
272 | beta=0.5
273 | sampler = QuasiRejectionSampler(target_ebm, model, beta=beta)
274 | samples, log_scores = sampler.sample(context=incipit, sampling_size=2**7)
275 | ```
276 |
277 | ### Going further
278 |
279 | A few things to keep in mind while reading the following paragraphs showing the principles of **disco**:
280 | 1. this is only an introduction, skipping some details and relying on toyish use cases;
281 | 1. the notebooks in the tutorials folder go in more depth, on more use cases;
282 | 1. the focus here and in most notebooks is on natural language, but again the toolkit can be used to control distributions over sequences such as code or chess moves, or even other data types, as long as they respect the basic assumptions of a disco `Distribution` object.
283 |
284 | ## References
285 |
286 | The **disco** toolkit implements the theoretical framework presented in the following works:
287 | - A Distributional Approach to Controlled Text Generation, Khalifa et al., 2021, https://arxiv.org/abs/2012.11635, ICLR;
288 | - An approximate sampler for energy-based models with divergence diagnostics, Eikema et al., 2022, https://arxiv.org/abs/2112.05702, TMLR;
289 | - Energy-Based Models for Code Generation under Compilability Constraints, Korbak et al., 2021, https://arxiv.org/abs/2106.04985, ACL (Workshop on Natural Language Processing for Programming);
290 | - Controlling Conditional Language Models without Catastrophic Forgetting, Korbak et al., 2022, https://arxiv.org/abs/2112.00791, ICML;
291 | - On Reinforcement Learning and Distribution Matching for Fine-Tuning Language Models with no Catastrophic Forgetting, Korbak et al., 2022, https://arxiv.org/abs/2206.00761, NeurIPS;
292 | - Aligning Language Models with Preferences through f-divergence Minimization, Go et al., 2023, https://arxiv.org/abs/2302.08215, ICML.
293 |
294 | To cite **disco**, please use:
295 | ```
296 | @inproceedings{kruszewski-etal-2023-disco,
297 | title = "disco: a toolkit for Distributional Control of Generative Models",
298 | author = "Kruszewski, Germ{\'a}n and
299 | Rozen, Jos and
300 | Dymetman, Marc",
301 | booktitle = "Proceedings of the 61st Annual Meeting of the Association for Computational Linguistics (Volume 3: System Demonstrations)",
302 | month = jul,
303 | year = "2023",
304 | address = "Toronto, Canada",
305 | publisher = "Association for Computational Linguistics",
306 | url = "https://aclanthology.org/2023.acl-demo.14",
307 | pages = "144--160",
308 | abstract = "Pre-trained language models and other generative models have revolutionized NLP and beyond. However, these models tend to reproduce undesirable biases present in their training data. Also, they may overlook patterns that are important but challenging to capture. To address these limitations, researchers have introduced distributional control techniques. These techniques, not limited to language, allow controlling the prevalence (i.e. expectations) of any features of interest in the model{'}s outputs. Despite their potential, the widespread adoption of these techniques has been hindered by the difficulty in adapting the complex, disconnected code. Here, we present disco, an open-source Python library that brings these techniques to the broader public",
309 | }
310 | ```
311 |
312 | ## License
313 |
314 | See [LICENSE](LICENSE) file.
315 |
--------------------------------------------------------------------------------
/tutorials/1.quick_introduction.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "attachments": {},
5 | "cell_type": "markdown",
6 | "metadata": {},
7 | "source": [
8 | "disco \n",
9 | "Copyright (C) 2022-present NAVER Corp. \n",
10 | "Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license "
11 | ]
12 | },
13 | {
14 | "attachments": {},
15 | "cell_type": "markdown",
16 | "metadata": {},
17 | "source": [
18 | "# Quick introduction"
19 | ]
20 | },
21 | {
22 | "attachments": {},
23 | "cell_type": "markdown",
24 | "metadata": {},
25 | "source": [
26 | "Note that this is only an introduction, from the [README](../README.MD), skipping some details and relying on toyish use cases: \n",
27 | "see the other notebooks in the tutorial folder for more depth, and more use cases."
28 | ]
29 | },
30 | {
31 | "attachments": {},
32 | "cell_type": "markdown",
33 | "metadata": {},
34 | "source": [
35 | "## Distributions"
36 | ]
37 | },
38 | {
39 | "attachments": {},
40 | "cell_type": "markdown",
41 | "metadata": {},
42 | "source": [
43 | "The generative model that you want to tune must be wrapped by a `Distribution` object. For example, for a (causal or seq2seq) language model compatible with the 🤗 Hugging Face interface use an `LMDistribution`."
44 | ]
45 | },
46 | {
47 | "attachments": {},
48 | "cell_type": "markdown",
49 | "metadata": {},
50 | "source": [
51 | "A valid `Distribution` must have the following two methods:\n",
52 | "- `.sample(context)` that given an optional `context` on which the distribution can be conditioned, returns a list of samples from the underlying distribution and a tensor with their corresponding log-probabilities. \n",
53 | "- `.log_score(samples, context)` that given a list of samples and the `context` on which to condition the distribution, returns their corresponding log-probabilities."
54 | ]
55 | },
56 | {
57 | "cell_type": "code",
58 | "execution_count": null,
59 | "metadata": {},
60 | "outputs": [],
61 | "source": [
62 | "from disco.distributions import LMDistribution\n",
63 | "\n",
64 | "distribution = LMDistribution()"
65 | ]
66 | },
67 | {
68 | "cell_type": "code",
69 | "execution_count": null,
70 | "metadata": {},
71 | "outputs": [],
72 | "source": [
73 | "incipit = \"It was a cold and stormy night\"\n",
74 | "samples, log_scores = distribution.sample(context=incipit)"
75 | ]
76 | },
77 | {
78 | "cell_type": "code",
79 | "execution_count": null,
80 | "metadata": {},
81 | "outputs": [],
82 | "source": [
83 | "distribution.log_score(samples, context=incipit)"
84 | ]
85 | },
86 | {
87 | "attachments": {},
88 | "cell_type": "markdown",
89 | "metadata": {},
90 | "source": [
91 | "`LMDistribution` generates samples of the `TextSample` type, which are named tuples with both a `text` and a `token_ids` field."
92 | ]
93 | },
94 | {
95 | "attachments": {},
96 | "cell_type": "markdown",
97 | "metadata": {},
98 | "source": [
99 | "### Features"
100 | ]
101 | },
102 | {
103 | "attachments": {},
104 | "cell_type": "markdown",
105 | "metadata": {},
106 | "source": [
107 | "Features are represented by an object with the method\n",
108 | "- `.score(samples, context)` which, given a list of samples and an optional context, returns a tensor of real-valued scores."
109 | ]
110 | },
111 | {
112 | "attachments": {},
113 | "cell_type": "markdown",
114 | "metadata": {},
115 | "source": [
116 | "A convenient way to define one is using the `Scorer` class, which accepts a function, or a lambda definition, that takes a sample and a context, and vectorizes it. For example, we can compute the effective length of a GPT-2 text sample by finding the eos token:"
117 | ]
118 | },
119 | {
120 | "cell_type": "code",
121 | "execution_count": null,
122 | "metadata": {},
123 | "outputs": [],
124 | "source": [
125 | "from disco.scorers.scorer import Scorer\n",
126 | "\n",
127 | "sequence_length = Scorer(lambda s, c: s.text.index(\"<|endoftext|>\"))"
128 | ]
129 | },
130 | {
131 | "attachments": {},
132 | "cell_type": "markdown",
133 | "metadata": {},
134 | "source": [
135 | "where `s` is the sample (assumed to be a `TextSample`) and `c` is an optional context."
136 | ]
137 | },
138 | {
139 | "attachments": {},
140 | "cell_type": "markdown",
141 | "metadata": {},
142 | "source": [
143 | "#### Boolean Features"
144 | ]
145 | },
146 | {
147 | "attachments": {},
148 | "cell_type": "markdown",
149 | "metadata": {},
150 | "source": [
151 | "An important class of features are *boolean* features. While general features can only be used to define *distributional* constraints, boolean features can also be used to define *pointwise* constraints, see below. To define one, we can use the `BooleanScorer` helper class, which takes a function as an argument. \n",
152 | "For example, we can score the presence of the string \"amazing\", as follows:"
153 | ]
154 | },
155 | {
156 | "cell_type": "code",
157 | "execution_count": null,
158 | "metadata": {},
159 | "outputs": [],
160 | "source": [
161 | "from disco.scorers.boolean_scorer import BooleanScorer"
162 | ]
163 | },
164 | {
165 | "cell_type": "code",
166 | "execution_count": null,
167 | "metadata": {},
168 | "outputs": [],
169 | "source": [
170 | "amazing = BooleanScorer(lambda s, c: \"amazing\" in s.text)"
171 | ]
172 | },
173 | {
174 | "attachments": {},
175 | "cell_type": "markdown",
176 | "metadata": {},
177 | "source": [
178 | "The ```False```/```True``` results from the lambda are cast to `0.0`/`1.0` float values so that they can be used in the EBM definition. "
179 | ]
180 | },
181 | {
182 | "attachments": {},
183 | "cell_type": "markdown",
184 | "metadata": {},
185 | "source": [
186 | "`BooleanScorer` belongs to the more general class of `PositiveScorer`s, which can be used to construct EBMs. The main properties of a `PositiveScorer` are that, first, it returns positive scores, and second, it provides the method\n",
187 | " \n",
188 | " - `.log_score(samples, context)` that given a list of samples and the `context` on which to condition the distribution, returns their corresponding log-probabilities."
189 | ]
190 | },
191 | {
192 | "attachments": {},
193 | "cell_type": "markdown",
194 | "metadata": {},
195 | "source": [
196 | "As a consequence, we can see that a ```Distribution``` is also a ```PositiveScorer``` that is able to sample as well."
197 | ]
198 | },
199 | {
200 | "attachments": {},
201 | "cell_type": "markdown",
202 | "metadata": {},
203 | "source": [
204 | "## Controlling Generation"
205 | ]
206 | },
207 | {
208 | "attachments": {},
209 | "cell_type": "markdown",
210 | "metadata": {},
211 | "source": [
212 | "### Expressing preferences in an EBM"
213 | ]
214 | },
215 | {
216 | "attachments": {},
217 | "cell_type": "markdown",
218 | "metadata": {},
219 | "source": [
220 | "We express preferences over the distribution by defining target moments for specific features. This results in a target distribution that matches the desired moments while minimizing the KL divergence to the original distribution. In other words, it incorporates the preferences while avoiding catastrophic forgetting. This distribution is represented as an EBM, which can be used to score samples (in other words, it is a `PositiveScorer`) but cannot be used to sample directly; we'll see how to obtain samples below."
221 | ]
222 | },
223 | {
224 | "attachments": {},
225 | "cell_type": "markdown",
226 | "metadata": {},
227 | "source": [
228 | "We can express either *pointwise* or *distributional* constraints on a distribution and compose them at will. The former expresses a (boolean) property that must apply to *all* sequences, whereas the latter represents properties at the distributional level. "
229 | ]
230 | },
231 | {
232 | "attachments": {},
233 | "cell_type": "markdown",
234 | "metadata": {},
235 | "source": [
236 | "To obtain the target distribution that incorporates our constraints, we use the `constrain` method of the corresponding `Distribution`. This method takes a list of features and their corresponding target moments."
237 | ]
238 | },
239 | {
240 | "attachments": {},
241 | "cell_type": "markdown",
242 | "metadata": {},
243 | "source": [
244 | "For example, we can define an EBM with a *pointwise* constraint requiring that all our samples must include \"amazing\" by setting the target moment to `1` on a `BooleanScorer`:"
245 | ]
246 | },
247 | {
248 | "cell_type": "code",
249 | "execution_count": null,
250 | "metadata": {},
251 | "outputs": [],
252 | "source": [
253 | "from disco.distributions.lm_distribution import LMDistribution\n",
254 | "\n",
255 | "base = LMDistribution()\n",
256 | "ebm = base.constrain([amazing], [1])\n",
257 | "# ebm = base.constrain([amazing]) # would also work\n",
258 | "# ebm = base * amazing # as well, using a product notation "
259 | ]
260 | },
261 | {
262 | "attachments": {},
263 | "cell_type": "markdown",
264 | "metadata": {},
265 | "source": [
266 | "Or we can ask for a _distributional_ constraint requiring that _half_ of the samples include \"amazing\":"
267 | ]
268 | },
269 | {
270 | "cell_type": "code",
271 | "execution_count": null,
272 | "metadata": {},
273 | "outputs": [],
274 | "source": [
275 | "import os\n",
276 | "# disabling parallelism to avoid deadlocks (see warning from HF's tokenizer)\n",
277 | "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\""
278 | ]
279 | },
280 | {
281 | "cell_type": "code",
282 | "execution_count": null,
283 | "metadata": {},
284 | "outputs": [],
285 | "source": [
286 | "ebm = base.constrain([amazing], [1/2])"
287 | ]
288 | },
289 | {
290 | "attachments": {},
291 | "cell_type": "markdown",
292 | "metadata": {},
293 | "source": [
294 | "### Approximating the target EBM"
295 | ]
296 | },
297 | {
298 | "attachments": {},
299 | "cell_type": "markdown",
300 | "metadata": {},
301 | "source": [
302 | "Given an EBM target distribution, we now want to train a model to approximate it so that we can use it to generate samples. In the _unconditional_ case, namely when there is a single fixed context used in generation, then we can use a `Tuner`, more specifically a ```DPGTuner```, as follows."
303 | ]
304 | },
305 | {
306 | "cell_type": "code",
307 | "execution_count": null,
308 | "metadata": {},
309 | "outputs": [],
310 | "source": [
311 | "from disco.tuners.dpg_tuner import DPGTuner"
312 | ]
313 | },
314 | {
315 | "cell_type": "code",
316 | "execution_count": null,
317 | "metadata": {},
318 | "outputs": [],
319 | "source": [
320 | "target_ebm = base.constrain([amazing], [1])\n",
321 | "\n",
322 | "model = LMDistribution(freeze=False)\n",
323 | "incipit = \"It was a cold and stormy night\"\n",
324 | "\n",
325 | "tuner = DPGTuner(model, target_ebm, context=incipit, n_gradient_steps=4)\n",
326 | "tuner.tune()"
327 | ]
328 | },
329 | {
330 | "attachments": {},
331 | "cell_type": "markdown",
332 | "metadata": {},
333 | "source": [
334 | "And we can sample _amazing_ sequences from the tuned model."
335 | ]
336 | },
337 | {
338 | "cell_type": "code",
339 | "execution_count": null,
340 | "metadata": {},
341 | "outputs": [],
342 | "source": [
343 | "samples, log_scores = model.sample(context=incipit)\n",
344 | "for s in samples:\n",
345 | " print(incipit + s.text)"
346 | ]
347 | },
348 | {
349 | "attachments": {},
350 | "cell_type": "markdown",
351 | "metadata": {},
352 | "source": [
353 | "#### Tuning parameters"
354 | ]
355 | },
356 | {
357 | "attachments": {},
358 | "cell_type": "markdown",
359 | "metadata": {},
360 | "source": [
361 | "Important parameters of the `Tuner` include:\n",
362 | "- `n_gradient_steps`: number of total gradient steps in the full tuning process;\n",
363 | "- `n_samples_per_context`: total number of samples used in performing a gradient step (aka batch size);\n",
364 | "- `scoring_size`: number of samples sent in a batch to the `.score` function. This parameter affects training speed or helps solve GPU memory errors, but does not affect final results;\n",
365 | "- `sampling_size`: number of samples obtained from a single call to the `.sample` function. This parameter affects training speed or helps solve GPU memory errors, but does not affect final results;\n",
366 | "- `features`: list of pairs (`name`, `feature`) so that the `feature` moments will be computed by importance sampling (and reported using the key given by `name`);\n",
367 | "- `track_divergence_from_base`: set to True to track the reverse KL divergence from the original model —this requires an additional round of samples' scoring."
368 | ]
369 | },
370 | {
371 | "attachments": {},
372 | "cell_type": "markdown",
373 | "metadata": {},
374 | "source": [
375 | "### Logging"
376 | ]
377 | },
378 | {
379 | "attachments": {},
380 | "cell_type": "markdown",
381 | "metadata": {},
382 | "source": [
383 | "The Tuner reports a number of metrics that are useful to monitor the training progress. A number of `Logger` classes are provided to keep track of these metrics. Basic logging is provided through the console, as follows:"
384 | ]
385 | },
386 | {
387 | "cell_type": "code",
388 | "execution_count": null,
389 | "metadata": {},
390 | "outputs": [],
391 | "source": [
392 | "from disco.tuners.loggers.console import ConsoleLogger"
393 | ]
394 | },
395 | {
396 | "cell_type": "code",
397 | "execution_count": null,
398 | "metadata": {}, "outputs": [],
399 | "source": [
400 | "console_logger = ConsoleLogger(tuner)"
401 | ]
402 | },
403 | {
404 | "attachments": {},
405 | "cell_type": "markdown",
406 | "metadata": {},
407 | "source": [
408 | "However, more detailed statistics can be kept through the JSON/WandB/Neptune loggers:"
409 | ]
410 | },
411 | {
412 | "cell_type": "code",
413 | "execution_count": null,
414 | "metadata": {},
415 | "outputs": [],
416 | "source": [
417 | "from disco.tuners.loggers.json import JSONLogger\n",
418 | "from disco.tuners.loggers.neptune import NeptuneLogger\n",
419 | "from disco.tuners.loggers.wandb import WandBLogger\n"
420 | ]
421 | },
422 | {
423 | "cell_type": "code",
424 | "execution_count": null,
425 | "metadata": {},
426 | "outputs": [],
427 | "source": [
428 | "project = \"example_project\"\n",
429 | "name = \"run_01\"\n",
430 | "json_logger = JSONLogger(tuner, project, name)\n",
431 | "neptune_logger = NeptuneLogger(tuner, project, name)\n",
432 | "wandb_logger = WandBLogger(tuner, project, name)"
433 | ]
434 | },
435 | {
436 | "attachments": {},
437 | "cell_type": "markdown",
438 | "metadata": {},
439 | "source": [
440 | "where `project` and `name` refer to the project and run name, respectively."
441 | ]
442 | },
443 | {
444 | "attachments": {},
445 | "cell_type": "markdown",
446 | "metadata": {},
447 | "source": [
448 | "#### Logged Metrics"
449 | ]
450 | },
451 | {
452 | "attachments": {},
453 | "cell_type": "markdown",
454 | "metadata": {},
455 | "source": [
456 | "Loggers store a number of metrics about the training process. Here we list a few of the most relevant ones:\n",
457 | "\n",
458 | "- `kl_target_model` and `kl_target_proposal`: estimates of the forward KL divergence to the target EBM from the tuned model and the proposal distribution, respectively. In the case of using online training, the two are equivalent with the only caveat that `kl_target_model` is computed —this is the metric being optimized, and not the value reported as `loss`;\n",
459 | "- `kl_model_base` and `kl_proposal_base`: estimates of the reverse KL divergence to the original model of the tuned model and the proposal distribution, respectively —only reported if `track_divergence_from_base` is set to True;\n",
460 | "- Feature moments: estimate of the features' moments for those features specified with the `features` parameter at the Tuner's construction time."
461 | ]
462 | },
463 | {
464 | "attachments": {},
465 | "cell_type": "markdown",
466 | "metadata": {},
467 | "source": [
468 | "## Controlled Conditional Generation"
469 | ]
470 | },
471 | {
472 | "attachments": {},
473 | "cell_type": "markdown",
474 | "metadata": {},
475 | "source": [
476 | "The _conditional_ case is superficially very similar, with an extra step needed to instantiate a `ContextDistribution`, which allows sampling contexts that can then be used to condition the model. Furthermore, we use the more general ```CDPGTuner``` class."
477 | ]
478 | },
479 | {
480 | "attachments": {},
481 | "cell_type": "markdown",
482 | "metadata": {},
483 | "source": [
484 | "Assuming we have a file of incipits, one per line, in a `data/incipits.txt` file, we could do:"
485 | ]
486 | },
487 | {
488 | "cell_type": "code",
489 | "execution_count": null,
490 | "metadata": {},
491 | "outputs": [],
492 | "source": [
493 | "from disco.tuners.cdpg_tuner import CDPGTuner\n",
494 | "from disco.distributions.context_distribution import ContextDistribution"
495 | ]
496 | },
497 | {
498 | "cell_type": "code",
499 | "execution_count": null,
500 | "metadata": {},
501 | "outputs": [],
502 | "source": [
503 | "target_ebm = base.constrain([amazing], [1])\n",
504 | "\n",
505 | "model = LMDistribution(freeze=False)\n",
506 | "\n",
507 | "tuner = CDPGTuner(model, target_ebm, n_gradient_steps=4, # use a much higher value for actual tuning\n",
508 | " context_distribution=ContextDistribution(\"data/incipits.txt\"), n_contexts_per_step=2**3)\n",
509 | "tuner.tune()"
510 | ]
511 | },
512 | {
513 | "attachments": {},
514 | "cell_type": "markdown",
515 | "metadata": {},
516 | "source": [
517 | "Note that while we have used a decoder-only model here for illustrative purposes, the real power of the CDPGTuner is that it allows controlling _seq2seq models_ such as those used in NMT, summarization, etc. Please refer to the dedicated [tutorial notebook](4.conditional_tuning.ipynb) for an example of how to control an actual conditional model."
518 | ]
519 | },
520 | {
521 | "attachments": {},
522 | "cell_type": "markdown",
523 | "metadata": {},
524 | "source": [
525 | "### Monte-Carlo sampling to improve the approximation"
526 | ]
527 | },
528 | {
529 | "attachments": {},
530 | "cell_type": "markdown",
531 | "metadata": {},
532 | "source": [
533 | "After the tuning is done, `model` is now a better approximation to the target EBM, but it is not guaranteed to perfectly match this distribution. While further training can improve the situation, another alternative is using [quasi-rejection sampling (QRS)](https://disco.europe.naverlabs.com/QRS/), a Monte-Carlo sampling technique that allows trading off sampling efficiency for a higher fidelity to the target distribution —a higher value of `beta` yields a better fidelity although at a higher computational cost."
534 | ]
535 | },
536 | {
537 | "cell_type": "code",
538 | "execution_count": null,
539 | "metadata": {},
540 | "outputs": [],
541 | "source": [
542 | "from disco.samplers.quasi_rejection_sampler import QuasiRejectionSampler"
543 | ]
544 | },
545 | {
546 | "cell_type": "code",
547 | "execution_count": null,
548 | "metadata": {},
549 | "outputs": [],
550 | "source": [
551 | "beta=0.5\n",
552 | "sampler = QuasiRejectionSampler(target_ebm, model, beta=beta)\n",
553 | "samples, log_scores = sampler.sample(sampling_size=2**7)"
554 | ]
555 | },
556 | {
557 | "attachments": {},
558 | "cell_type": "markdown",
559 | "metadata": {},
560 | "source": [
561 | "### In summary"
562 | ]
563 | },
564 | {
565 | "attachments": {},
566 | "cell_type": "markdown",
567 | "metadata": {},
568 | "source": [
569 | "To put some of this (distributional constraint, tuning in the unconditional case and using QRS) together:"
570 | ]
571 | },
572 | {
573 | "cell_type": "code",
574 | "execution_count": null,
575 | "metadata": {},
576 | "outputs": [],
577 | "source": [
578 | "base = LMDistribution()\n",
579 | "target_ebm = base.constrain([amazing], [1/2])\n",
580 | "\n",
581 | "model = LMDistribution(freeze=False)\n",
582 | "\n",
583 | "tuner = DPGTuner(model, target_ebm)\n",
584 | "tuner.tune()\n",
585 | "\n",
586 | "beta=0.5\n",
587 | "sampler = QuasiRejectionSampler(target_ebm, model, beta=beta)\n",
588 | "samples, log_scores = sampler.sample(context=incipit, sampling_size=2**7)"
589 | ]
590 | },
591 | {
592 | "cell_type": "code",
593 | "execution_count": null,
594 | "metadata": {},
595 | "outputs": [],
596 | "source": []
597 | }
598 | ],
599 | "metadata": {
600 | "kernelspec": {
601 | "display_name": "Python 3.10.6 (conda)",
602 | "language": "python",
603 | "name": "python3"
604 | },
605 | "language_info": {
606 | "codemirror_mode": {
607 | "name": "ipython",
608 | "version": 3
609 | },
610 | "file_extension": ".py",
611 | "mimetype": "text/x-python",
612 | "name": "python",
613 | "nbconvert_exporter": "python",
614 | "pygments_lexer": "ipython3",
615 | "version": "3.10.6 | packaged by conda-forge | (main, Aug 22 2022, 20:38:29) [Clang 13.0.1 ]"
616 | },
617 | "orig_nbformat": 4,
618 | "vscode": {
619 | "interpreter": {
620 | "hash": "babb4baf4e80bd80b9852210fc5469c0783907e52a560ed7247caef52808358d"
621 | }
622 | }
623 | },
624 | "nbformat": 4,
625 | "nbformat_minor": 2
626 | }
627 |
--------------------------------------------------------------------------------