├── loss_landscapes
├── contrib
│ ├── __init__.py
│ ├── connecting_paths.py
│ └── trajectories.py
├── model_interface
│ ├── __init__.py
│ ├── model_wrapper.py
│ └── model_parameters.py
├── metrics
│ ├── __init__.py
│ ├── metric.py
│ ├── rl_metrics.py
│ └── sl_metrics.py
├── __init__.py
└── main.py
├── requirements.txt
├── MANIFEST.in
├── img
├── loss-contour.png
├── loss-contour-3d.png
└── loss-landscape.png
├── .gitignore
├── setup.cfg
├── LICENCE.txt
├── setup.py
└── README.md
/loss_landscapes/contrib/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/loss_landscapes/model_interface/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy
2 | matplotlib
3 | tqdm
4 | torch
5 | torchvision
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
# Include the README
include *.md

# Include the license file (the file in this repository is spelled LICENCE.txt)
include LICENCE.txt
--------------------------------------------------------------------------------
/img/loss-contour.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/marcellodebernardi/loss-landscapes/HEAD/img/loss-contour.png
--------------------------------------------------------------------------------
/img/loss-contour-3d.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/marcellodebernardi/loss-landscapes/HEAD/img/loss-contour-3d.png
--------------------------------------------------------------------------------
/img/loss-landscape.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/marcellodebernardi/loss-landscapes/HEAD/img/loss-landscape.png
--------------------------------------------------------------------------------
/loss_landscapes/metrics/__init__.py:
--------------------------------------------------------------------------------
1 | from loss_landscapes.metrics.metric import Metric, MetricPipeline
2 | from loss_landscapes.metrics.rl_metrics import ExpectedReturnMetric
3 | from loss_landscapes.metrics.sl_metrics import Loss, LossGradient, LossPerturbations
4 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # environments
2 | env/
3 | ENV/
4 | Env/
5 |
6 | # data
7 | data/
8 |
9 | # debugging files
10 | tests/paste.txt
11 |
12 | # jupyter notebook checkpoints
13 | .ipynb_checkpoints/
14 |
15 | # pip
16 | loss_landscapes.egg-info/
17 |
18 | # dist
19 | dist/
20 | build/
21 |
22 | # tests
23 | tests/
24 |
--------------------------------------------------------------------------------
/loss_landscapes/__init__.py:
--------------------------------------------------------------------------------
1 | from loss_landscapes.main import point
2 | from loss_landscapes.main import linear_interpolation
3 | from loss_landscapes.main import random_line
4 | from loss_landscapes.main import planar_interpolation
5 | from loss_landscapes.main import random_plane
6 | from loss_landscapes.model_interface.model_wrapper import ModelWrapper, GeneralModelWrapper
7 |
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
[metadata]
# This includes the license file(s) in the wheel.
# The file in this repository is spelled LICENCE.txt.
license_files = LICENCE.txt

[bdist_wheel]
# Removing this line (or setting universal to 0) will prevent
# bdist_wheel from trying to make a universal wheel. For more see:
# https://packaging.python.org/guides/distributing-packages-using-setuptools/#wheels
universal=0
--------------------------------------------------------------------------------
/loss_landscapes/metrics/metric.py:
--------------------------------------------------------------------------------
1 | """ Base classes for model evaluation metrics. """
2 |
3 | from abc import ABC, abstractmethod
4 | from loss_landscapes.model_interface.model_wrapper import ModelWrapper
5 |
6 |
class Metric(ABC):
    """
    Abstract base for any quantity that can be evaluated on a model or an agent.

    Concrete metrics implement ``__call__``; the class cannot be instantiated
    until that method is overridden.
    """

    def __init__(self):
        super().__init__()

    @abstractmethod
    def __call__(self, model_wrapper: ModelWrapper):
        """
        Evaluate the metric.

        :param model_wrapper: wrapped model or agent to evaluate
        :return: the computed quantity; its type is chosen by the subclass
        """
16 |
17 |
class MetricPipeline(Metric):
    """
    A metric made of several metrics, evaluated one after another on the same
    model or agent.
    """

    def __init__(self, metrics: list):
        """
        :param metrics: metrics to evaluate, in evaluation order
        """
        super().__init__()
        self.metrics = metrics

    def __call__(self, model_wrapper: ModelWrapper) -> tuple:
        """
        :param model_wrapper: wrapped model or agent handed to every metric
        :return: tuple holding one result per metric, in the order of ``self.metrics``
        """
        results = []
        for metric in self.metrics:
            results.append(metric(model_wrapper))
        return tuple(results)
27 |
--------------------------------------------------------------------------------
/LICENCE.txt:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) [2019] [Marcello De Bernardi]
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/loss_landscapes/metrics/rl_metrics.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.autograd
3 | from loss_landscapes.metrics.metric import Metric
4 |
5 |
class ExpectedReturnMetric(Metric):
    """
    Estimates the expected return of an agent as the mean total reward over a
    fixed number of episodes in a gym-style environment.
    """

    def __init__(self, gym_environment, n_episodes):
        """
        :param gym_environment: environment whose ``reset()`` returns a numpy observation and
            whose ``step(action)`` returns an ``(obs, reward, done, info)`` tuple
        :param n_episodes: number of episodes to average over; zero episodes make
            ``__call__`` fail with a ZeroDivisionError
        """
        super().__init__()
        self.gym_environment = gym_environment
        self.n_episodes = n_episodes

    def __call__(self, agent):
        """
        :param agent: callable mapping a float observation tensor to an action
        :return: mean of the per-episode total rewards
        """
        env = self.gym_environment
        episode_returns = []

        for _ in range(self.n_episodes):
            total = 0
            obs = env.reset()
            # every episode takes at least one step before ``done`` is inspected
            while True:
                obs, reward, done, _info = env.step(agent(torch.from_numpy(obs).float()))
                total += reward
                if done:
                    break
            episode_returns.append(total)

        # average of the episode returns
        return sum(episode_returns) / len(episode_returns)
32 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 | from os import path
3 |
4 | # Get the long description from the README file
# Get the long description from the README file
with open(path.join(path.abspath(path.dirname(__file__)), 'README.md'), encoding='utf-8') as f:
    long_description = f.read()

setup(
    name='loss_landscapes',
    version='3.0.7',
    # `exclude` takes a sequence of glob patterns; a bare string would be unpacked
    # character by character and would not exclude the tests package at all
    packages=find_packages(exclude=['tests', 'tests.*']),
    url='https://github.com/marcellodebernardi/loss-landscapes',
    license='MIT',
    author='Marcello De Bernardi',
    author_email='marcello.debernardi@stcatz.ox.ac.uk',
    description='A library for approximating loss landscapes in low-dimensional parameter subspaces',
    long_description=long_description,
    long_description_content_type='text/markdown',
    python_requires='>=3.5',
    # NOTE(review): the package imports torch, which is not listed here — confirm
    # whether leaving its installation to the user is intentional
    install_requires=['numpy'],
    classifiers=[
        'Development Status :: 4 - Beta',
        'Intended Audience :: Developers',
        'Topic :: Scientific/Engineering :: Artificial Intelligence',
        'License :: OSI Approved :: MIT License',
        'Programming Language :: Python :: 3.5',
        'Programming Language :: Python :: 3.6',
        'Programming Language :: Python :: 3.7',
    ],
)
31 |
--------------------------------------------------------------------------------
/loss_landscapes/model_interface/model_wrapper.py:
--------------------------------------------------------------------------------
1 | """ Class used to define interface to complex models """
2 |
3 | import abc
4 | import itertools
5 | import torch.nn
6 | from loss_landscapes.model_interface.model_parameters import ModelParameters
7 |
8 |
class ModelWrapper(abc.ABC):
    """
    Uniform interface to a model made of one or more modules.

    Subclasses define ``forward``; every other method fans out over the wrapped
    modules and returns the wrapper itself where chaining is useful.
    """

    def __init__(self, modules: list):
        """
        :param modules: modules whose parameters make up the model
        """
        self.modules = modules

    def get_modules(self) -> list:
        """ Returns the list of wrapped modules (the stored list, not a copy). """
        return self.modules

    def get_module_parameters(self) -> ModelParameters:
        """ Returns the parameters of all wrapped modules, in module order, as ModelParameters. """
        return ModelParameters([p for module in self.modules for p in module.parameters()])

    def train(self, mode=True) -> 'ModelWrapper':
        """
        Calls ``train(mode)`` on every wrapped module.

        :param mode: passed through to each module's ``train``
        :return: this wrapper
        """
        for module in self.modules:
            module.train(mode)
        return self

    def eval(self) -> 'ModelWrapper':
        """ Shorthand for ``train(False)``. """
        return self.train(False)

    def requires_grad_(self, requires_grad=True) -> 'ModelWrapper':
        """
        Sets ``requires_grad`` on every parameter of every wrapped module.

        :param requires_grad: value assigned to each parameter's ``requires_grad``
        :return: this wrapper
        """
        for module in self.modules:
            for p in module.parameters():
                p.requires_grad = requires_grad
        return self

    def zero_grad(self) -> 'ModelWrapper':
        """
        Detaches and zeroes the gradient of every parameter that has one.

        :return: this wrapper
        """
        for module in self.modules:
            for p in module.parameters():
                if p.grad is not None:
                    p.grad.detach_()
                    p.grad.zero_()
        return self

    def parameters(self):
        """
        Returns an iterator over the parameters of all wrapped modules, in module order.
        """
        # from_iterable flattens the per-module iterators; chain() over a single list
        # would hand back the generator objects themselves instead of the parameters
        return itertools.chain.from_iterable(module.parameters() for module in self.modules)

    def named_parameters(self):
        """
        Returns an iterator over the ``(name, parameter)`` pairs of all wrapped modules,
        in module order. Names are reported per module, so two modules may yield the
        same name.
        """
        return itertools.chain.from_iterable(module.named_parameters() for module in self.modules)

    @abc.abstractmethod
    def forward(self, x):
        """
        Runs the model on ``x``.

        :param x: model input
        :return: model output, as defined by the subclass
        """
        pass
50 |
51 |
class SimpleModelWrapper(ModelWrapper):
    """ Wrapper for a model that consists of a single torch.nn.Module. """

    def __init__(self, model: torch.nn.Module):
        """
        :param model: the module to wrap; it becomes the wrapper's only module
        """
        super().__init__([model])

    def forward(self, x):
        """ Calls the first (and only) wrapped module on ``x`` and returns its output. """
        model = self.modules[0]
        return model(x)
58 |
59 |
class GeneralModelWrapper(ModelWrapper):
    """
    Wrapper for an arbitrary model or agent object whose forward pass is supplied
    by the caller as a function.
    """

    def __init__(self, model, modules: list, forward_fn):
        """
        :param model: the underlying model or agent object
        :param modules: modules whose parameters make up the model
        :param forward_fn: callable invoked as ``forward_fn(model, x)`` to run the model
        """
        super().__init__(modules)
        self.model = model
        self.forward_fn = forward_fn

    def forward(self, x):
        """ Returns ``forward_fn(model, x)``. """
        run = self.forward_fn
        return run(self.model, x)
68 |
69 |
def wrap_model(model):
    """
    Returns a ModelWrapper for the given model with gradient tracking disabled.

    An existing ModelWrapper is returned as-is (after disabling gradients); a plain
    torch.nn.Module is wrapped in a SimpleModelWrapper. Note that ``requires_grad``
    is switched off on the parameters of the object that was passed in.

    :param model: a ModelWrapper or a torch.nn.Module
    :return: wrapper around the model
    :raises ValueError: if the model is neither a ModelWrapper nor a torch.nn.Module
    """
    if isinstance(model, ModelWrapper):
        wrapper = model
    elif isinstance(model, torch.nn.Module):
        wrapper = SimpleModelWrapper(model)
    else:
        raise ValueError('Only models of type torch.nn.modules.module.Module can be passed without a wrapper.')
    return wrapper.requires_grad_(False)
77 |
--------------------------------------------------------------------------------
/loss_landscapes/contrib/connecting_paths.py:
--------------------------------------------------------------------------------
1 | """
2 | This module exposes functions for loss landscape operations which are more complex than simply
3 | computing the loss at different points in parameter space. This includes things such as Kolsbjerg
4 | et al.'s Automated Nudged Elastic Band algorithm.
5 | """
6 |
7 |
8 | import abc
9 | import copy
10 | import numpy as np
11 | from loss_landscapes.model_interface.model_interface import wrap_model
12 |
13 |
class _ParametricCurve(abc.ABC):
    """ Base class for the parametric curves used by the Garipov path search algorithm. """
    # todo: no behaviour is defined yet
17 |
18 |
class _PolygonChain(_ParametricCurve):
    """ A parametric curve made of consecutive straight line segments. """
    # todo: not implemented yet
    pass
23 |
24 |
class _BezierCurve(_ParametricCurve):
    """
    A Bezier curve is a parametric curve defined by a set of control points, including
    a start point and an end point. The order of the curve is the number of control
    points excluding the start point: an order 1 (linear) Bezier curve is defined by
    2 points, an order 2 (quadratic) Bezier curve by 3 points, and so on.

    In this library, each point is a neural network model with a specific value
    assignment to the model parameters.
    """
    def __init__(self, model_start, model_end, order=2):
        """
        Define a Bezier curve between a start point and an end point. Deep copies of
        both models are wrapped and stored, and ``order - 1`` intermediate control
        points are placed at even steps on the straight line from start to end.
        Only the default order of 2 is currently accepted.

        :param model_start: point defining start of curve
        :param model_end: point defining end of curve
        :param order: number of control points, excluding start point
        :raises NotImplementedError: if order is not 2
        """
        super().__init__()
        if order != 2:
            raise NotImplementedError('Currently only order 2 bezier curves are supported.')

        self.model_start_wrapper = wrap_model(copy.deepcopy(model_start))
        self.model_end_wrapper = wrap_model(copy.deepcopy(model_end))
        self.order = order
        self.control_points = []

        # a linear curve has no intermediate control points
        if order <= 1:
            return

        # NOTE(review): relies on the wrapper providing get_parameter_tensor() and
        # set_parameter_tensor() — confirm against the current ModelWrapper API
        start = self.model_start_wrapper.get_parameter_tensor()
        end = self.model_end_wrapper.get_parameter_tensor()
        step = (end - start) / order

        for k in range(1, order):
            control_point = copy.deepcopy(self.model_start_wrapper)
            control_point.set_parameter_tensor(start + (step * k))
            self.control_points.append(control_point)

    def fit(self):
        """ Not implemented yet; always raises NotImplementedError. """
        # todo
        raise NotImplementedError()
70 |
71 |
def auto_neb() -> np.ndarray:
    """
    Automatic Nudged Elastic Band algorithm, as used in https://arxiv.org/abs/1803.00885

    Not implemented yet; calling it always raises NotImplementedError.
    """
    # todo return list of points in parameter space to represent trajectory
    # todo figure out how to return points as coordinates in 2D
    raise NotImplementedError()
77 |
78 |
def garipov_curve_search(model_a, model_b, curve_type='polygon_chain') -> np.ndarray:
    """
    'Garipov curve search' refers to the algorithm proposed by Garipov et al (2018) for
    finding low-loss paths between two arbitrary minima in a loss landscape. The method
    defines a parametric curve in the model's parameter space connecting one minimum to
    the other, and then minimizes the expected loss along this curve by modifying its
    parameterization. For details, see https://arxiv.org/abs/1802.10026

    This is an alternative to the auto_neb algorithm. The search itself is not
    implemented yet: both models are wrapped and their parameters read, after which
    an exception is always raised.

    :param model_a: model at one end of the path
    :param model_b: model at the other end of the path
    :param curve_type: 'polygon_chain' or 'bezier_curve'
    :raises NotImplementedError: for both supported curve types
    :raises AttributeError: for any other curve type
    """
    model_a_wrapper = wrap_model(model_a)
    model_b_wrapper = wrap_model(model_b)

    # todo: the end points are read but not used yet
    point_a = model_a_wrapper.get_parameter_tensor()
    point_b = model_b_wrapper.get_parameter_tensor()

    if curve_type not in ('polygon_chain', 'bezier_curve'):
        raise AttributeError('Curve type is not polygon_chain or bezier_curve.')
    raise NotImplementedError('Not implemented yet.')
102 |
--------------------------------------------------------------------------------
/loss_landscapes/contrib/trajectories.py:
--------------------------------------------------------------------------------
1 | """
2 | Classes and functions for tracking a model's optimization trajectory and computing
3 | a low-dimensional approximation of the trajectory.
4 | """
5 |
6 |
7 | from abc import ABC, abstractmethod
8 | from datetime import datetime
9 | import numpy as np
10 | from loss_landscapes.model_interface.model_interface import wrap_model
11 |
12 |
class TrajectoryTracker(ABC):
    """
    Interface for objects that record the optimization trajectory of a DL/RL model.
    A tracker stores model parameterizations and gives access to the stored ones.
    """

    @abstractmethod
    def __getitem__(self, timestep) -> np.ndarray:
        """
        Looks up the stored position of the model for one training timestep.
        :param timestep: training step of parameters to retrieve
        :return: numpy array
        """

    @abstractmethod
    def get_item(self, timestep) -> np.ndarray:
        """
        Named-method counterpart of ``__getitem__``.
        :param timestep: training step of parameters to retrieve
        :return: numpy array
        """

    @abstractmethod
    def get_trajectory(self) -> list:
        """
        Gives access to the trajectory stored so far.
        :return: list of stored positions
        """

    @abstractmethod
    def save_position(self, model):
        """
        Adds the model's current parameterization to the end of the stored trajectory.
        :param model: model object with current state of interest
        :return: N/A
        """
54 |
55 |
class FullTrajectoryTracker(TrajectoryTracker):
    """
    A FullTrajectoryTracker is a tracker which stores a history of points in the tracked
    model's original parameter space, and can be used to perform a variety of computations
    on the trajectory. Every position is written to its own ``.npy`` file rather than
    being kept in main memory.
    """
    def __init__(self, model, agent_interface=None, directory='./', experiment_name=None):
        """
        Creates the storage directory and saves the initial position of the model.

        :param model: model to track; its current parameters become timestep 0
        :param agent_interface: passed through to wrap_model on every save
        :param directory: parent directory of the experiment directory
        :param experiment_name: name of the experiment directory; defaults to the
            current date and time
        """
        import os  # local import: os is not part of this module's import block

        super().__init__()
        run_name = experiment_name if experiment_name is not None else str(datetime.now())
        # file names are appended to self.dir directly, so it must end with a
        # separator whether or not an experiment name was given
        self.dir = os.path.join(directory, run_name) + '/'
        # np.save does not create missing directories
        os.makedirs(self.dir, exist_ok=True)
        self.next_idx = 0
        # must be assigned before the first save_position call, which reads it
        self.agent_interface = agent_interface
        self.save_position(model)

    def __getitem__(self, timestep) -> np.ndarray:
        """
        Loads the position saved at the given timestep from storage.

        :param timestep: index of the saved position, starting at 0
        :return: numpy array
        :raises IndexError: if no position was saved for the timestep
        """
        # timestep 0 is the position saved at construction time
        if not (0 <= timestep < self.next_idx):
            raise IndexError('Given timestep does not exist.')
        return np.load(self.dir + str(timestep) + '.npy')

    def get_item(self, timestep) -> np.ndarray:
        """ Same as ``self[timestep]``. """
        return self.__getitem__(timestep)

    def save_position(self, model):
        """
        Writes the model's current parameters to the next ``.npy`` file.

        :param model: model object with current state of interest
        :return: N/A
        """
        # NOTE(review): assumes wrap_model accepts an agent interface and that the wrapper
        # provides get_parameter_tensor(deepcopy=...) — confirm against the current API
        np.save(self.dir + str(self.next_idx) + '.npy', wrap_model(model, self.agent_interface).get_parameter_tensor(deepcopy=True).as_numpy())
        self.next_idx += 1

    def get_trajectory(self) -> list:
        """
        WARNING: be aware that full trajectory tracking requires N * M memory, where N is the
        number of iterations tracked and M is the size of the model. The amount of memory used
        by the trajectory tracker can easily become very large.
        :return: list of numpy arrays
        """
        return [self[idx] for idx in range(self.next_idx)]
90 |
91 |
class ProjectingTrajectoryTracker(TrajectoryTracker):
    """
    A ProjectingTrajectoryTracker reduces every model parameterization to a handful of
    coefficients before storing it. This is particularly appropriate for large models,
    where storing a history of points in the model's parameter space would be
    unfeasible in terms of memory.
    """
    def __init__(self, model, agent_interface=None, n_bases=2):
        """
        Draws the random basis onto which all later positions are projected.

        :param model: model to track; only its parameter count is read here
        :param agent_interface: passed through to wrap_model
        :param n_bases: number of random basis vectors, i.e. size of each stored point
        """
        super().__init__()
        self.trajectory = []
        self.agent_interface = agent_interface

        # NOTE(review): assumes wrap_model accepts an agent interface and that the wrapper
        # provides get_parameter_tensor() — confirm against the current API
        n_params = wrap_model(model, agent_interface).get_parameter_tensor().numel()
        basis_vectors = []
        for _ in range(n_bases):
            basis_vectors.append(np.random.normal(size=n_params))
        # one normally distributed basis vector per column
        self.A = np.column_stack(basis_vectors)

    def __getitem__(self, timestep) -> np.ndarray:
        """ Returns the projected position stored at index ``timestep``. """
        return self.trajectory[timestep]

    def get_item(self, timestep) -> np.ndarray:
        """ Same as ``self[timestep]``. """
        return self.__getitem__(timestep)

    def get_trajectory(self) -> list:
        """ Returns the stored list of projected positions (the list itself, not a copy). """
        return self.trajectory

    def save_position(self, model):
        """
        Projects the model's current parameters onto the random basis and stores the result.

        :param model: model object with current state of interest
        :return: N/A
        """
        # least-squares solution x of A x = b, where the columns of A are the basis vectors
        b = wrap_model(model, self.agent_interface).get_parameter_tensor().as_numpy()
        coefficients, *_ = np.linalg.lstsq(self.A, b, rcond=None)
        self.trajectory.append(coefficients)
122 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # loss-landscapes
2 |
3 | `loss-landscapes` is a PyTorch library for approximating neural network loss functions, and other related metrics,
4 | in low-dimensional subspaces of the model's parameter space. The library makes the production of visualizations
5 | such as those seen in [Visualizing the Loss Landscape of Neural Nets](https://arxiv.org/abs/1712.09913v3) much
6 | easier, aiding the analysis of the geometry of neural network loss landscapes.
7 |
8 | This library does not provide plotting facilities, letting the user define how the data should be plotted. Other
9 | deep learning frameworks are not supported, though a TensorFlow version, `loss-landscapes-tf`, is planned for
10 | a future release.
11 |
12 | **NOTE: this library is in early development. Bugs are virtually a certainty, and the API is volatile. Do not use
13 | this library in production code. For prototyping and research, always use the newest version of the library.**
14 |
15 |
16 | ## 1. What is a Loss Landscape?
17 | Let `L : Parameters -> Real Numbers` be a loss function, which maps a point in the model parameter space to a
18 | real number. For a neural network with `n` parameters, the loss function `L` takes an `n`-dimensional input. We
19 | can define the loss landscape as the set of all `n+1`-dimensional points `(param, L(param))`, for all points
20 | `param` in the parameter space. For example, the image below, reproduced from the paper by Li et al (2018), link
21 | above, provides a visual representation of what a loss function over a two-dimensional parameter space might look
22 | like:
23 |
24 |

25 |
26 | Of course, real machine learning models have a number of parameters much greater than 2, so the parameter space of
27 | the model is virtually never two-dimensional. Because we can't print visualizations in more than two dimensions,
28 | we cannot hope to visualize the "true" shape of the loss landscape. Instead, a number of techniques
29 | exist for reducing the parameter space to one or two dimensions, ranging from dimensionality reduction techniques
30 | like PCA, to restricting ourselves to a particular subspace of the overall parameter space. For more details,
31 | read Li et al's paper.
32 |
33 |
34 | ## 2. Base Example: Supervised Loss in Parameter Subspaces
35 | The simplest use case for `loss-landscapes` is to estimate the value of a supervised loss function in a subspace
36 | of a neural network's parameter space. The subspace in question may be a point, a line, or a plane (these subspaces
37 | can be meaningfully visualized). Suppose the user has trained a supervised learning model, of type `torch.nn.Module`,
38 | on a dataset consisting of samples `X` and labels `y`, by minimizing some loss function. The user now wishes to
39 | produce a surface plot similar to the one in section 1.
40 |
41 | This is accomplished as follows:
42 |
43 | ````python
44 | metric = Loss(loss_function, X, y)
45 | landscape = random_plane(model, metric, normalize='filter')
46 | ````
47 |
48 | As seen in the example above, the two core concepts in `loss-landscapes` are _metrics_ and _parameter subspaces_. The
49 | latter define the section of parameter space to be considered, while the former define what quantity is evaluated at
50 | each considered point in parameter space, and how it is computed. In the example above, we define a `Loss` metric
51 | over data `X` and labels `y`, and instruct `loss_landscapes` to evaluate it in a randomly generated planar subspace.
52 |
53 | This would return a 2-dimensional array of loss values, which the user can plot in any desirable way. Example
54 | visualizations the user might use this type of data for are shown below.
55 |
56 | 
57 |
58 | 
59 |
60 | Check the `examples` directory for `jupyter` notebooks with more in-depth examples of what is possible.
61 |
62 |
63 | ## 3. Metrics and Custom Metrics
64 | The `loss-landscapes` library can compute any quantity of interest at a collection of points in a parameter subspace,
65 | not just loss. This is accomplished using a `Metric`: a callable object which applies a pre-determined function,
66 | such as a cross entropy loss with a specific set of inputs and outputs, to the model. The `loss_landscapes.metrics`
67 | package contains a number of metrics that cover common use cases, such as `Loss` (evaluates a loss
68 | function), `LossGradient` (evaluates the gradient of the loss w.r.t. the model parameters),
69 | `LossPerturbations` (evaluates changes in the loss value along random directions), and more.
70 |
71 | Furthermore, the user can add custom metrics by subclassing `Metric`. As an example, consider the library
72 | implementation of `Loss`, for `torch` models:
73 |
74 | ````python
75 | class Metric(abc.ABC):
76 | """ A quantity that can be computed given a model or an agent. """
77 |
78 | def __init__(self):
79 | super().__init__()
80 |
81 | @abc.abstractmethod
82 | def __call__(self, model_wrapper: ModelWrapper):
83 | pass
84 |
85 |
86 | class Loss(Metric):
87 | """ Computes a specified loss function over specified input-output pairs. """
88 | def __init__(self, loss_fn, inputs: torch.Tensor, target: torch.Tensor):
89 | super().__init__()
90 | self.loss_fn = loss_fn
91 | self.inputs = inputs
92 | self.target = target
93 |
94 | def __call__(self, model_wrapper: ModelWrapper) -> float:
95 | return self.loss_fn(model_wrapper.forward(self.inputs), self.target).item()
96 | ````
97 |
98 | The user may create custom `Metric`s in a similar manner. One complication is that the `Metric` class'
99 | `__call__` method is designed to take as input a `ModelWrapper` rather than a model. This class is internal
100 | to the library and exists to facilitate the handling of the myriad of different models a user may pass as
101 | inputs to a function such as `loss_landscapes.planar_interpolation()`. It is sufficient for the user to know
102 | that a `ModelWrapper` exposes a `forward()` method that can be used to call the model on a given input (see the `forward_fn`
103 | argument of the `GeneralModelWrapper` class in the next section). The class also provides a `get_modules()` method
104 | that exposes references to the underlying modules, should the user wish to carry out more complicated operations
105 | on them.
106 |
107 | In summary, the `Metric` abstraction adds a great degree of flexibility. A metric defines what quantity
108 | dependent on model parameters the user is interested in evaluating, and how to evaluate it. The user could define,
109 | for example, a metric that computes an estimate of the expected return of a reinforcement learning agent.
110 |
111 |
112 | ## 4. More Complex Models
113 | In the general case of a simple supervised learning model, as in the sections above, client code calls functions
114 | such as `loss_landscapes.linear_interpolation` and passes as argument a PyTorch module of type `torch.nn.Module`.
115 |
116 | For more complex cases, such as when the user wants to evaluate the loss landscape as a function of a subset of
117 | the model parameters, or the expected return landscape for a RL agent, the user must specify to the `loss-landscapes`
118 | library how to interface with the model (or the agent, on a more general level). This is accomplished using a
119 | `ModelWrapper` object, which hides the implementation details of the model or agent. For general use, the library
120 | supplies the `GeneralModelWrapper` in the `loss_landscapes.model_interface.model_wrapper` module.
121 |
122 | Assume the user wishes to estimate the expected return of some RL agent which provides an `agent.act(observation)`
123 | method for action selection. Then, the example from section 2 becomes as follows:
124 |
125 | ````python
126 | metric = ExpectedReturnMetric(env, n_samples)
127 | agent_wrapper = GeneralModelWrapper(agent, [agent.q_function, agent.policy], lambda agent, x: agent.act(x))
128 | landscape = random_plane(agent_wrapper, metric, normalize='filter')
129 | ````
130 |
131 |
132 |
133 | ## 5. WIP: Connecting Paths, Saddle Points, and Trajectory Tracking
134 | A number of features are currently under development, but as of yet incomplete.
135 |
136 | A number of papers in recent years have shown that loss landscapes of neural networks are dominated by a
137 | proliferation of saddle points, that good solutions are better described as large low-loss plateaus than as
138 | "well-bottom" points, and that for sufficiently high-dimensional networks, a low-loss path in parameter space can
139 | be found between almost any arbitrary pair of minima. In the future, the `loss-landscapes` library will feature
140 | implementations of algorithms for finding such low-loss connecting paths in the loss landscape, as well as tools to
141 | facilitate the study of saddle points.
142 |
143 | Some trajectory tracking features are also under consideration, though at this time it's unclear what this
144 | should actually mean, as the optimization trajectory is implicitly tracked by the user's training loop. Any metric
145 | along the optimization trajectory can be tracked with libraries such as [ignite](https://github.com/pytorch/ignite)
146 | for PyTorch.
147 |
148 |
149 | ## 6. Support for Other DL Libraries
150 | The `loss-landscapes` library was initially designed to be agnostic to the DL framework in use. However, with the
151 | increasing number of use cases to cover it became obvious that maintaining the original library-agnostic design
152 | was adding too much complexity to the code.
153 |
154 | A TensorFlow version, `loss-landscapes-tf`, is planned for the future.
155 |
156 |
157 | ## 7. Installation and Use
158 | The package is available on PyPI. Install using `pip install loss-landscapes`. To use the library, import as follows:
159 |
160 | ````python
161 | import loss_landscapes
162 | import loss_landscapes.metrics
163 | ````
--------------------------------------------------------------------------------
/loss_landscapes/metrics/sl_metrics.py:
--------------------------------------------------------------------------------
1 | """
2 | A library of pre-written evaluation functions for PyTorch loss functions.
3 |
4 | The classes and functions in this module cover common loss landscape evaluations. In particular,
5 | computing the loss, the gradient of the loss (w.r.t. model parameters) and Hessian of the loss
6 | (w.r.t. model parameters) for some supervised learning loss is easily accomplished.
7 | """
8 |
9 |
10 | import numpy as np
11 | import torch
12 | import torch.autograd
13 | from loss_landscapes.metrics.metric import Metric
14 | from loss_landscapes.model_interface.model_parameters import rand_u_like
15 | from loss_landscapes.model_interface.model_wrapper import ModelWrapper
16 |
17 |
class Loss(Metric):
    """ Evaluates a given loss function on a fixed set of inputs and targets. """
    def __init__(self, loss_fn, inputs: torch.Tensor, target: torch.Tensor):
        """
        :param loss_fn: callable invoked as ``loss_fn(model_output, target)``
        :param inputs: inputs fed to the model
        :param target: targets the model output is compared against
        """
        super().__init__()
        self.loss_fn = loss_fn
        self.inputs = inputs
        self.target = target

    def __call__(self, model_wrapper: ModelWrapper) -> float:
        """
        :param model_wrapper: wrapped model to evaluate
        :return: the loss as a Python number, obtained with ``.item()``
        """
        prediction = model_wrapper.forward(self.inputs)
        loss = self.loss_fn(prediction, self.target)
        return loss.item()
28 |
29 |
class LossGradient(Metric):
    """
    Computes the gradient of a specified loss function w.r.t. the model parameters
    over specified input-output pairs.
    """
    def __init__(self, loss_fn, inputs: torch.Tensor, target: torch.Tensor):
        super().__init__()
        self.loss_fn = loss_fn
        self.inputs = inputs
        self.target = target

    def __call__(self, model_wrapper: ModelWrapper) -> np.ndarray:
        """
        Evaluates the loss on the stored inputs and differentiates it w.r.t. the parameters
        exposed by the wrapper's get_module_parameters().
        :param model_wrapper: wrapper of the model to evaluate
        :return: flat numpy array holding the gradients of all parameters, concatenated in
            parameter order
        """
        loss = self.loss_fn(model_wrapper.forward(self.inputs), self.target)

        # torch.autograd.grad expects a sequence of tensors and returns a tuple of tensors,
        # one per input, so the parameters are collected into a plain list first.
        # NOTE(review): assumes the wrapped parameters have requires_grad=True - confirm
        # against wrap_model(), which is not visible from this module.
        parameters = model_wrapper.get_module_parameters()
        tensors = [parameters[idx] for idx in range(len(parameters))]
        gradients = torch.autograd.grad(loss, tensors)

        model_wrapper.zero_grad()
        return np.concatenate([g.detach().cpu().numpy().flatten() for g in gradients])
44 |
45 |
class LossPerturbations(Metric):
    """
    Measures how the loss changes when the model parameters are displaced along a sample
    of random directions. Such perturbations can be used to reason probabilistically about
    the curvature of a point on the loss landscape, as demonstrated in the paper by
    Schuurmans et al (https://arxiv.org/abs/1811.11214).

    NOTE(review): 'alpha' is stored on the instance but is not read by __call__.
    """
    def __init__(self, loss_fn, inputs: torch.Tensor, target: torch.Tensor, n_directions, alpha):
        super().__init__()
        self.loss_fn, self.inputs, self.target = loss_fn, inputs, target
        self.n_directions, self.alpha = n_directions, alpha

    def __call__(self, model_wrapper: ModelWrapper) -> np.ndarray:
        """
        Computes the loss difference for each of n_directions random displacements.
        :param model_wrapper: wrapper of the model to perturb; its parameters are modified
            in place and moved back after every direction
        :return: array with one (perturbed loss - unperturbed loss) entry per direction
        """
        # reference point and the loss measured there
        parameters = model_wrapper.get_module_parameters()
        baseline = self.loss_fn(model_wrapper.forward(self.inputs), self.target).item()

        deltas = []
        for _ in range(self.n_directions):
            # displace the parameters, measure the loss, then undo the displacement
            offset = rand_u_like(parameters)
            parameters.add_(offset)

            perturbed = self.loss_fn(model_wrapper.forward(self.inputs), self.target).item()
            deltas.append(perturbed - baseline)

            parameters.sub_(offset)

        return np.array(deltas)
76 |
77 |
78 | # noinspection DuplicatedCode
79 | # class GradientPredictivenessEvaluator(Metric):
80 | # """
81 | # Computes the L2 norm of the distance between loss gradients at consecutive
82 | # iterations. We consider a gradient to be predictive if a move in the direction
83 | # of the gradient results in a similar gradient at the next step; that is, the
84 | # gradients of the loss change smoothly along the optimization trajectory.
85 | #
86 | # This evaluator is inspired by experiments ran by Santurkar et al (2018), for
87 | # details see https://arxiv.org/abs/1805.11604
88 | # """
89 | # def __init__(self, supervised_loss_fn, inputs, target):
90 | # super().__init__(None, None, None)
91 | # self.gradient_evaluator = GradientEvaluator(supervised_loss_fn, inputs, target)
92 | # self.previous_gradient = None
93 | #
94 | # def __call__(self, model) -> float:
95 | # if self.previous_gradient is None:
96 | # self.previous_gradient = self.gradient_evaluator(model)
97 | # return 0.0
98 | # else:
99 | # current_grad = self.gradient_evaluator(model)
100 | # previous_grad = self.previous_gradient
101 | # self.previous_gradient = current_grad
102 | # # return l2 distance of current and previous gradients
103 | # return np.linalg.norm(current_grad - previous_grad, ord=2)
104 | #
105 | #
106 | # class BetaSmoothnessEvaluator(Metric):
107 | # """
108 | # Computes the "beta-smoothness" of the gradients, as characterized by
109 | # Santurkar et al (2018). The beta-smoothness of a function at any given point
110 | # is the ratio of the magnitude of the change in its gradients, over the magnitude
111 | # of the change in input. In the case of loss landscapes, it is the ratio of the
112 | # magnitude of the change in loss gradients over the magnitude of the change in
113 | # parameters. In general, we call a function f beta-smooth if
114 | #
115 | # |f'(x) - f'(y)| < beta|x - y|
116 | #
117 | # i.e. if there exists an upper bound beta on the ratio between change in gradients
118 | # and change in input. Santurkar et al call "effective beta-smoothness" the maximum
119 | # encountered ratio along some optimization trajectory.
120 | #
121 | # This evaluator is inspired by experiments ran by Santurkar et al (2018), for
122 | # details see https://arxiv.org/abs/1805.11604
123 | # """
124 | #
125 | # def __init__(self, supervised_loss_fn, inputs, target):
126 | # super().__init__(None, None, None)
127 | # self.gradient_evaluator = GradientEvaluator(supervised_loss_fn, inputs, target)
128 | # self.previous_gradient = None
129 | # self.previous_parameters = None
130 | #
131 | # def __call__(self, model):
132 | # if self.previous_parameters is None:
133 | # self.previous_gradient = self.gradient_evaluator(model)
134 | # self.previous_parameters = TorchModelWrapper(model).get_parameter_tensor().numpy()
135 | # return 0.0
136 | # else:
137 | # current_grad = self.gradient_evaluator(model)
138 | # current_p = TorchModelWrapper(model).get_parameter_tensor().numpy()
139 | # previous_grad = self.previous_gradient
140 | # previous_p = self.previous_parameters
141 | #
142 | # self.previous_gradient = current_grad
143 | # self.previous_parameters = current_p
144 | # # return l2 distance of current and previous gradients
145 | # return np.linalg.norm(current_grad - previous_grad, ord=2) / np.linalg.norm(current_p - previous_p, ord=2)
146 |
147 |
148 | # todo - these are complicated by the fact that hessian matrix is of size O(n^2) in the number of NN params
149 | # ideally there would be a way to compute the eigenvalues incrementally, without computing the whole hessian
150 | # matrix first.
151 |
152 | # class HessianEvaluator(SupervisedTorchEvaluator):
153 | # """
154 | # Computes the Hessian of a specified loss function w.r.t. the model
155 | # parameters over specified input-output pairs.
156 | # """
157 | # def __init__(self, supervised_loss_fn, inputs, target):
158 | # super().__init__(supervised_loss_fn, inputs, target)
159 | #
160 | # def __call__(self, model) -> np.ndarray:
161 | # loss = self.loss_fn(model(self.inputs), self.target)
162 | # gradient = torch.autograd.grad(loss, [p for _, p in model.named_parameters()], create_graph=True)
163 | # gradient = torch.cat(tuple([p.view(-1) for p in gradient]))
164 | # numel = sum([param.numel() for param in gradient])
165 | #
166 | # # for computing higher-order gradients, see https://github.com/pytorch/pytorch/releases/tag/v0.2.0
167 | # hessian = torch.zeros(size=(numel, numel))
168 | #
169 | # for derivative, idx in enumerate(gradient, 0):
170 | # hessian[idx] = torch.autograd.grad(torch.tensor(derivative), [p.view(-1) for _, p in model.named_parameters()])
171 | #
172 | # return hessian.detach().numpy()
173 | #
174 | #
175 | # class PrincipalCurvaturesEvaluator(SupervisedTorchEvaluator):
176 | # """
177 | # Computes the principal curvatures of a specified loss function over
178 | # specified input-output pairs. The principal curvatures are the
179 | # eigenvalues of the Hessian matrix.
180 | # """
181 | # def __init__(self, supervised_loss_fn, inputs, target):
182 | # super().__init__(None, None, None)
183 | # self.hessian_evaluator = HessianEvaluator(supervised_loss_fn, inputs, target)
184 | #
185 | # def __call__(self, model) -> np.ndarray:
186 | # return np.linalg.eigvals(self.hessian_evaluator(model))
187 | #
188 | #
189 | # class CurvaturePositivityEvaluator(SupervisedTorchEvaluator):
190 | # """
191 | # Computes the extent of the positivity of a loss function's curvature at a
192 | # specific point in parameter space. The extent of positivity is measured as
193 | # the fraction of dimensions with positive curvature. Optionally, dimensions
194 | # can be weighted by the magnitude of their curvature.
195 | #
196 | # Inspired by a related metric in the paper by Li et al,
197 | # http://papers.nips.cc/paper/7875-visualizing-the-loss-landscape-of-neural-nets.
198 | # """
199 | # def __init__(self, supervised_loss_fn, inputs, target, weighted=False):
200 | # super().__init__(None, None, None)
201 | # self.curvatures_evaluator = PrincipalCurvaturesEvaluator(supervised_loss_fn, inputs, target)
202 | # self.weighted = weighted
203 | #
204 | # def __call__(self, model) -> np.ndarray:
205 | # curvatures = self.curvatures_evaluator(model)
206 | # # ratio of sum of all positive curvatures over sum of all negative curvatures
207 | # if self.weighted:
208 | # positive_total = curvatures[(curvatures >= 0)].sum()
209 | # negative_total = np.abs(curvatures[(curvatures < 0)].sum())
210 | # return positive_total / negative_total
211 | # # fraction of dimensions with positive curvature
212 | # else:
213 | # return np.array((curvatures >= 0).sum() / curvatures.size())
214 |
--------------------------------------------------------------------------------
/loss_landscapes/model_interface/model_parameters.py:
--------------------------------------------------------------------------------
1 | """
2 | Basic linear algebra operations as defined on the parameter sets of entire models.
3 |
4 | We can think of these lists as single vectors consisting of all the individual
5 | parameter values. The functions in this module implement basic linear algebra
6 | operations on such lists.
7 |
8 | The operations defined in the module follow the PyTorch convention of appending
9 | the '_' suffix to the name of in-place operations.
10 | """
11 |
12 | import copy
13 | import math
14 | import numpy as np
15 | import torch
16 | import torch.nn
17 |
18 |
class ModelParameters:
    """
    A ModelParameters object is an abstract view of a model's optimizable parameters as a tensor. This class
    enables the parameters of models of the same 'shape' (architecture) to be operated on as if they were 'real'
    tensors. A ModelParameters object cannot be converted to a true tensor as it is potentially irregularly
    shaped.

    The wrapped list is stored by reference, so the in-place operations (names ending in '_')
    modify the tensors that were passed to the constructor.
    """

    def __init__(self, parameters: list):
        """
        :param parameters: list of torch.Tensor objects, one per parameter tensor
        :raises AttributeError: if the argument is not a list, or contains a non-tensor
        """
        # reject anything that is not a list made up exclusively of tensors
        if not isinstance(parameters, list) or not all(isinstance(p, torch.Tensor) for p in parameters):
            raise AttributeError('Argument to ModelParameter is not a list of torch.Tensor objects.')

        self.parameters = parameters

    def __len__(self) -> int:
        """
        Returns the number of model layers within the parameter tensor.
        :return: number of layer tensors
        """
        return len(self.parameters)

    def numel(self) -> int:
        """
        Returns the number of elements (i.e. individual parameters) within the tensor.
        Note that this refers to individual parameters, not layers.
        :return: number of elements in tensor
        """
        return sum(p.numel() for p in self.parameters)

    def __getitem__(self, index) -> torch.nn.Parameter:
        """
        Returns the tensor of the layer at the given index.
        :param index: layer index
        :return: tensor of layer
        """
        return self.parameters[index]

    def __eq__(self, other: 'ModelParameters') -> bool:
        """
        Compares this parameter tensor for equality with the argument tensor, layer by layer.
        :param other: the object to compare to
        :return: true if other is a ModelParameters with the same number of layers and
            every pair of layers satisfies torch.equal
        """
        if not isinstance(other, ModelParameters) or len(self) != len(other):
            return False
        return all(torch.equal(p_self, p_other) for p_self, p_other in zip(self.parameters, other.parameters))

    def __add__(self, other: 'ModelParameters') -> 'ModelParameters':
        """
        Constructively returns the result of addition between this tensor and another.
        :param other: other to add
        :return: self + other
        """
        return ModelParameters([p + other[idx] for idx, p in enumerate(self.parameters)])

    def __radd__(self, other: 'ModelParameters') -> 'ModelParameters':
        """
        Constructively returns the result of addition between this tensor and another.
        :param other: model parameters to add
        :return: other + self
        """
        return self.__add__(other)

    def add_(self, other: 'ModelParameters'):
        """
        In-place addition between this tensor and another.
        :param other: model parameters to add
        :return: none
        """
        for idx in range(len(self)):
            self.parameters[idx] += other[idx]

    def __sub__(self, other: 'ModelParameters') -> 'ModelParameters':
        """
        Constructively returns the result of subtracting another tensor from this one.
        :param other: model parameters to subtract
        :return: self - other
        """
        return ModelParameters([p - other[idx] for idx, p in enumerate(self.parameters)])

    def __rsub__(self, other: 'ModelParameters') -> 'ModelParameters':
        """
        Constructively returns the result of subtracting this tensor from another one.
        :param other: other to subtract from
        :return: other - self
        """
        # subtraction is not commutative: the reflected operand goes on the left
        return ModelParameters([other[idx] - p for idx, p in enumerate(self.parameters)])

    def sub_(self, vector: 'ModelParameters'):
        """
        In-place subtraction of another tensor from this one.
        :param vector: other to subtract
        :return: none
        """
        for idx in range(len(self)):
            self.parameters[idx] -= vector[idx]

    def __mul__(self, scalar) -> 'ModelParameters':
        """
        Constructively returns the result of multiplying this tensor by a scalar.
        :param scalar: scalar to multiply by
        :return: self * scalar
        """
        return ModelParameters([p * scalar for p in self.parameters])

    def __rmul__(self, scalar) -> 'ModelParameters':
        """
        Constructively returns the result of multiplying this tensor by a scalar.
        :param scalar: scalar to multiply by
        :return: scalar * self
        """
        return self.__mul__(scalar)

    def mul_(self, scalar):
        """
        In-place multiplication of this tensor by a scalar.
        :param scalar: scalar to multiply by
        :return: none
        """
        for idx in range(len(self)):
            self.parameters[idx] *= scalar

    def __truediv__(self, scalar) -> 'ModelParameters':
        """
        Constructively returns the result of true-dividing this tensor by a scalar.
        :param scalar: scalar to divide by
        :return: self / scalar
        """
        return ModelParameters([p / scalar for p in self.parameters])

    def truediv_(self, scalar):
        """
        In-place true-division of this tensor by a scalar.
        :param scalar: scalar to divide by
        :return: none
        """
        for idx in range(len(self)):
            self.parameters[idx] /= scalar

    def __floordiv__(self, scalar) -> 'ModelParameters':
        """
        Constructively returns the result of floor-dividing this tensor by a scalar.
        :param scalar: scalar to divide by
        :return: self // scalar
        """
        return ModelParameters([p // scalar for p in self.parameters])

    def floordiv_(self, scalar):
        """
        In-place floor-division of this tensor by a scalar.
        :param scalar: scalar to divide by
        :return: none
        """
        for idx in range(len(self)):
            self.parameters[idx] //= scalar

    def __matmul__(self, other: 'ModelParameters') -> 'ModelParameters':
        """
        Constructively returns the result of tensor-multiplication of this tensor by another tensor.
        :param other: other tensor
        :return: self @ tensor
        :raises NotImplementedError: always, the operation is not supported
        """
        raise NotImplementedError()

    def dot(self, other: 'ModelParameters') -> float:
        """
        Returns the vector dot product of this ModelParameters vector and the given other vector.
        :param other: other ModelParameters vector
        :return: dot product of self and other
        """
        return sum(
            (p_self * p_other).sum().item()
            for p_self, p_other in zip(self.parameters, other.parameters)
        )

    def model_normalize_(self, ref_point: 'ModelParameters', order=2):
        """
        In-place model-wise normalization of the tensor: the whole vector is rescaled so
        that its model norm equals the model norm of the reference point.
        :param ref_point: model parameters whose model norm is the target norm
        :param order: norm order, e.g. 2 for L2 norm
        :return: none
        """
        # The factor must be computed once, before any layer is touched: the model norm
        # depends on every layer, so recomputing it mid-loop would change the factor.
        scale = ref_point.model_norm(order) / self.model_norm(order)
        for parameter in self.parameters:
            parameter *= scale

    def layer_normalize_(self, ref_point: 'ModelParameters', order=2):
        """
        In-place layer-wise normalization of the tensor: each layer is rescaled so that its
        norm equals the norm of the corresponding layer of the reference point.
        :param ref_point: model parameters whose layer norms are the target norms
        :param order: norm order, e.g. 2 for L2 norm
        :return: none
        """
        for layer_idx, parameter in enumerate(self.parameters):
            parameter *= (ref_point.layer_norm(layer_idx, order) / self.layer_norm(layer_idx, order))

    def filter_normalize_(self, ref_point: 'ModelParameters', order=2):
        """
        In-place filter-wise normalization of the tensor: each filter (slice along the first
        dimension of a layer) is rescaled so that its norm equals the norm of the
        corresponding filter of the reference point. Layers with at most one dimension
        (e.g. bias vectors) are rescaled as a whole instead.
        :param ref_point: model parameters whose filter norms are the target norms
        :param order: norm order, e.g. 2 for L2 norm
        :return: none
        """
        for layer_idx, parameter in enumerate(self.parameters):
            if parameter.dim() <= 1:
                # bias-like parameter: normalize the whole vector and move on, so that the
                # per-filter loop below does not rescale its individual entries again
                parameter *= (ref_point.parameters[layer_idx].norm(order) / parameter.norm(order))
                continue

            # weight tensor: normalize every filter separately
            for filter_idx in range(len(parameter)):
                position = (layer_idx, filter_idx)
                parameter[filter_idx] *= ref_point.filter_norm(position, order) / self.filter_norm(position, order)

    @staticmethod
    def _lp_norm(tensor: torch.Tensor, order) -> float:
        """
        Returns the L-norm of a single tensor treated as a flat vector.
        :param tensor: tensor to measure
        :param order: norm order, e.g. 2 for L2 norm
        :return: L-norm of the tensor
        """
        # absolute values keep the result correct for odd orders and negative entries
        return math.pow(torch.pow(tensor.abs(), order).sum().item(), 1.0 / order)

    def model_norm(self, order=2) -> float:
        """
        Returns the model-wise L-norm of the tensor.
        :param order: norm order, e.g. 2 for L2 norm
        :return: L-norm of tensor
        """
        # L-n norm of model where we treat the model as a flat vector
        return math.pow(
            sum(torch.pow(layer.abs(), order).sum().item() for layer in self.parameters),
            1.0 / order
        )

    def layer_norm(self, index, order=2) -> float:
        """
        Returns the L-norm of a single layer of the tensor.
        :param order: norm order, e.g. 2 for L2 norm
        :param index: layer index
        :return: L-norm of the layer at the given index
        """
        # L-n norm of layer where we treat the layer as a flat vector
        return self._lp_norm(self.parameters[index], order)

    def filter_norm(self, index, order=2) -> float:
        """
        Returns the L-norm of a single filter of the tensor.
        :param order: norm order, e.g. 2 for L2 norm
        :param index: tuple with layer index and filter index
        :return: L-norm of the filter at the given position
        """
        # L-n norm of filter where we treat the filter as a flat vector
        return self._lp_norm(self.parameters[index[0]][index[1]], order)

    def as_numpy(self) -> np.ndarray:
        """
        Returns the tensor as a flat numpy array.
        :return: a numpy array
        """
        # detach()/cpu() make the conversion work for tensors that require grad or that
        # do not live in host memory
        return np.concatenate([p.detach().cpu().numpy().flatten() for p in self.parameters])

    def _get_parameters(self) -> list:
        """
        Returns a reference to the internal parameter data in whatever format used by the source model.
        :return: reference to internal parameter data
        """
        return self.parameters
276 |
277 |
def rand_u_like(example_vector: ModelParameters) -> ModelParameters:
    """
    Create a new ModelParameters object of size and shape compatible with the given
    example vector, such that the values in the ModelParameter are uniformly distributed
    in the range [0,1).
    :param example_vector: defines by example the size and shape the new vector will have
    :return: new vector with uniformly distributed values
    """
    # Each new layer takes the dtype and device of the layer it mirrors, so that the result
    # can be combined with the example vector even when layers differ in dtype or the
    # model does not live on the default device.
    return ModelParameters([
        torch.rand(size=param.size(), dtype=param.dtype, device=param.device)
        for param in example_vector
    ])
292 |
293 |
def rand_n_like(example_vector: ModelParameters) -> ModelParameters:
    """
    Create a new ModelParameters object of size and shape compatible with the given
    example vector, such that the values in the ModelParameter are normally distributed
    as N(0,1).
    :param example_vector: defines by example the size and shape the new vector will have
    :return: new vector with normally distributed values
    """
    # Each new layer takes the dtype and device of the layer it mirrors, so that the result
    # can be combined with the example vector even when layers differ in dtype or the
    # model does not live on the default device.
    return ModelParameters([
        torch.randn(size=param.size(), dtype=param.dtype, device=param.device)
        for param in example_vector
    ])
308 |
309 |
def orthogonal_to(vector: ModelParameters) -> ModelParameters:
    """
    Build a ModelParameters object with the same size and shape as the given vector that is
    very nearly orthogonal to it. A uniformly sampled vector is drawn and its projection
    onto the given vector is subtracted from it.
    :param vector: original vector
    :return: new vector that is very nearly orthogonal to original vector
    """
    candidate = rand_u_like(vector)
    # component of the random candidate that lies along 'vector'
    squared_norm = math.pow(vector.model_norm(2), 2)
    projection = candidate.dot(vector) * vector / squared_norm
    return candidate - projection
320 |
321 |
def add(vector_a: ModelParameters, vector_b: ModelParameters) -> ModelParameters:
    """
    Functional form of the + operator.
    :param vector_a: left operand
    :param vector_b: right operand
    :return: vector_a + vector_b
    """
    total = vector_a + vector_b
    return total
324 |
325 |
def sub(vector_a: ModelParameters, vector_b: ModelParameters) -> ModelParameters:
    """
    Functional form of the - operator.
    :param vector_a: vector to subtract from
    :param vector_b: vector to subtract
    :return: vector_a - vector_b
    """
    difference = vector_a - vector_b
    return difference
328 |
329 |
def mul(vector: ModelParameters, scalar) -> ModelParameters:
    """
    Functional form of the * operator.
    :param vector: vector to scale
    :param scalar: scalar to multiply by
    :return: vector * scalar
    """
    scaled = vector * scalar
    return scaled
332 |
333 |
def truediv(vector: ModelParameters, scalar) -> ModelParameters:
    """
    Functional form of the / operator.
    :param vector: vector to divide
    :param scalar: scalar to divide by
    :return: vector / scalar
    """
    quotient = vector / scalar
    return quotient
336 |
337 |
def floordiv(vector: ModelParameters, scalar) -> ModelParameters:
    """
    Functional form of the // operator.
    :param vector: vector to divide
    :param scalar: scalar to divide by
    :return: vector // scalar
    """
    quotient = vector // scalar
    return quotient
340 |
341 |
def filter_normalize(tensor, order=2, ref_point=None) -> ModelParameters:
    """
    Constructively returns a filter-wise normalized copy of the given vector. The argument
    itself is not modified.
    :param tensor: ModelParameters to normalize
    :param order: norm order, e.g. 2 for L2 norm
    :param ref_point: optional ModelParameters whose filter norms are the target norms; if
        omitted, every filter (or whole layer, for layers with at most one dimension) is
        scaled to unit norm
    :return: normalized copy of tensor
    """
    # Older call sites could only make this function work by passing the reference point
    # in the 'order' position; keep accepting that form.
    if isinstance(order, ModelParameters):
        ref_point, order = order, 2

    new_tensor = copy.deepcopy(tensor)
    if ref_point is not None:
        new_tensor.filter_normalize_(ref_point, order)
        return new_tensor

    # no reference: scale to unit norm; no_grad allows in-place edits of copied parameters
    with torch.no_grad():
        for layer_idx in range(len(new_tensor)):
            layer = new_tensor[layer_idx]
            if layer.dim() <= 1:
                layer.div_(new_tensor.layer_norm(layer_idx, order))
                continue
            for filter_idx in range(len(layer)):
                layer[filter_idx].div_(new_tensor.filter_norm((layer_idx, filter_idx), order))
    return new_tensor
346 |
347 |
def layer_normalize(tensor, order=2, ref_point=None) -> ModelParameters:
    """
    Constructively returns a layer-wise normalized copy of the given vector. The argument
    itself is not modified.
    :param tensor: ModelParameters to normalize
    :param order: norm order, e.g. 2 for L2 norm
    :param ref_point: optional ModelParameters whose layer norms are the target norms; if
        omitted, every layer is scaled to unit norm
    :return: normalized copy of tensor
    """
    # Older call sites could only make this function work by passing the reference point
    # in the 'order' position; keep accepting that form.
    if isinstance(order, ModelParameters):
        ref_point, order = order, 2

    new_tensor = copy.deepcopy(tensor)
    if ref_point is not None:
        new_tensor.layer_normalize_(ref_point, order)
        return new_tensor

    # no reference: scale to unit norm; no_grad allows in-place edits of copied parameters
    with torch.no_grad():
        for layer_idx in range(len(new_tensor)):
            new_tensor[layer_idx].div_(new_tensor.layer_norm(layer_idx, order))
    return new_tensor
352 |
353 |
def model_normalize(tensor, order=2, ref_point=None) -> ModelParameters:
    """
    Constructively returns a model-wise normalized copy of the given vector. The argument
    itself is not modified.
    :param tensor: ModelParameters to normalize
    :param order: norm order, e.g. 2 for L2 norm
    :param ref_point: optional ModelParameters whose model norm is the target norm; if
        omitted, the vector is scaled to unit model norm
    :return: normalized copy of tensor
    """
    # Older call sites could only make this function work by passing the reference point
    # in the 'order' position; keep accepting that form.
    if isinstance(order, ModelParameters):
        ref_point, order = order, 2

    new_tensor = copy.deepcopy(tensor)
    if ref_point is not None:
        new_tensor.model_normalize_(ref_point, order)
        return new_tensor

    # no reference: scale to unit norm; no_grad allows in-place edits of copied parameters
    with torch.no_grad():
        new_tensor.truediv_(new_tensor.model_norm(order))
    return new_tensor
358 |
--------------------------------------------------------------------------------
/loss_landscapes/main.py:
--------------------------------------------------------------------------------
1 | """
2 | Functions for approximating loss/return landscapes in one and two dimensions.
3 | """
4 |
5 | import copy
6 | import typing
7 | import torch.nn
8 | import numpy as np
9 | from loss_landscapes.model_interface.model_wrapper import ModelWrapper, wrap_model
10 | from loss_landscapes.model_interface.model_parameters import rand_u_like, orthogonal_to
11 | from loss_landscapes.metrics.metric import Metric
12 |
13 |
14 | # noinspection DuplicatedCode
def point(model: typing.Union[torch.nn.Module, ModelWrapper], metric: Metric) -> tuple:
    """
    Evaluates a metric at the single point in parameter space given by the current
    parameters of a model or agent.

    The metric must be a subclass of loss_landscapes.metrics.Metric that specifies how the
    model passed to it is evaluated on the task of interest, returning the resulting
    quantity (such as loss, loss gradient, etc).

    The model may be a plain torch.nn.Module, or a ModelWrapper from the loss_landscapes
    library for more complex cases; it is passed through wrap_model() before evaluation.

    :param model: the model or model wrapper defining the point in parameter space
    :param metric: Metric object used to evaluate model
    :return: whatever the metric returns for the wrapped model
    """
    wrapped = wrap_model(model)
    return metric(wrapped)
32 |
33 |
34 | # noinspection DuplicatedCode
def linear_interpolation(model_start: typing.Union[torch.nn.Module, ModelWrapper],
                         model_end: typing.Union[torch.nn.Module, ModelWrapper],
                         metric: Metric, steps=100, deepcopy_model=False) -> np.ndarray:
    """
    Evaluates a metric along the straight line in parameter space that joins two models.
    The models may be torch.nn.Module objects, or ModelWrapper objects from the
    loss_landscapes library for more complex cases.

    The parameters of each model define a vertex in parameter space. The difference
    between the two vertices is split into 'steps' equal increments; the start model is
    moved by one increment at a time and the metric is evaluated after every move, so the
    last evaluation is taken after 'steps' increments. A common choice is to use the
    weights before training and the weights after convergence as the two vertices, giving
    a view of the "straight line" from the initialization to some minimum. There is no
    guarantee that the model followed this path during optimization; in fact it is highly
    unlikely to have done so, unless the optimization problem is convex.

    A simple linear interpolation can produce misleading approximations of the loss
    landscape due to the scale invariance of neural networks: the sharpness/flatness of
    minima or maxima is affected by the scale of the network weights. For more details,
    see `https://arxiv.org/abs/1712.09913v3`. It is recommended to use random_line() with
    filter normalization instead.

    The metric must be a subclass of loss_landscapes.metrics.Metric that specifies how the
    model passed to it is evaluated on the task of interest, returning the resulting
    quantity (such as loss, loss gradient, etc).

    Unless deepcopy_model is set, the parameters of model_start are modified in place.

    :param model_start: the model defining the start point of the line in parameter space
    :param model_end: the model defining the end point of the line in parameter space
    :param metric: Metric object used to evaluate the model at each step
    :param steps: at how many steps from start to end the model is evaluated
    :param deepcopy_model: indicates whether the method will deepcopy the model(s) to avoid aliasing
    :return: array of metric values along the line connecting start and end models
    """
    # wrap the two models, working on deep copies when aliasing must be avoided
    wrappers = []
    for model in (model_start, model_end):
        if deepcopy_model:
            model = copy.deepcopy(model)
        wrappers.append(wrap_model(model))
    start_wrapper, end_wrapper = wrappers

    current = start_wrapper.get_module_parameters()
    target = end_wrapper.get_module_parameters()
    increment = (target - current) / steps

    evaluations = []
    for _ in range(steps):
        # move the start model one increment along the line, then evaluate it
        current.add_(increment)
        evaluations.append(metric(start_wrapper))

    return np.array(evaluations)
85 |
86 |
87 | # noinspection DuplicatedCode
def random_line(model_start: typing.Union[torch.nn.Module, ModelWrapper], metric: Metric, distance=0.1, steps=100,
                normalization='filter', deepcopy_model=False) -> np.ndarray:
    """
    Evaluates a metric along a line in parameter space that starts at the given model and
    follows a randomly sampled direction. The model may be a torch.nn.Module, or a
    ModelWrapper object from the loss_landscapes library for more complex cases.

    The model's current parameters define the start point. A direction is sampled with
    rand_u_like(), optionally normalized against the start point, and rescaled so that
    one step has length (norm of start point * distance / steps). The model is moved by
    one step at a time and the metric is evaluated after every move.

    Note that the dimensionality of the model parameters has an impact on the expected
    length of a uniformly sampled vector in parameter space. That is, the more parameters
    a model has, the longer the distance in the random vector's direction should be, in
    order to see meaningful change in individual parameters. Normalizing the direction
    vector according to the model's current parameter values, which is supported through
    the 'normalization' parameter, helps reduce the impact of the distance parameter. In
    future releases, the distance parameter will refer to the maximum change in an
    individual parameter, rather than the length of the random direction vector.

    Note also that a simple line approximation can produce misleading views of the loss
    landscape due to the scale invariance of neural networks: the sharpness or flatness
    of minima or maxima is affected by the scale of the network weights. For more details,
    see `https://arxiv.org/abs/1712.09913v3`. It is recommended to normalize the direction,
    preferably with the 'filter' option.

    The metric must be a subclass of loss_landscapes.metrics.Metric that specifies how the
    model passed to it is evaluated on the task of interest, returning the resulting
    quantity (such as loss, loss gradient, etc).

    Unless deepcopy_model is set, the parameters of model_start are modified in place.

    :param model_start: model to be evaluated, whose current parameters represent the start point
    :param metric: Metric object used to evaluate the model at each step
    :param distance: maximum distance in parameter space from the start point
    :param steps: at how many steps from start to end the model is evaluated
    :param normalization: normalization of direction vector, one of 'filter', 'layer', 'model', or None
    :param deepcopy_model: indicates whether the method will deepcopy the model(s) to avoid aliasing
    :return: array of metric values along the randomly sampled direction
    :raises AttributeError: if normalization is not one of the supported values
    """
    # wrap the model, working on a deep copy when aliasing must be avoided
    if deepcopy_model:
        model_start = copy.deepcopy(model_start)
    wrapper = wrap_model(model_start)

    # start point in parameter space and a freshly sampled direction
    origin = wrapper.get_module_parameters()
    direction = rand_u_like(origin)

    # apply the requested normalization of the direction against the start point
    normalizers = (
        ('model', direction.model_normalize_),
        ('layer', direction.layer_normalize_),
        ('filter', direction.filter_normalize_),
    )
    for option, normalize in normalizers:
        if normalization == option:
            normalize(origin)
            break
    else:
        if normalization is not None:
            raise AttributeError('Unsupported normalization argument. Supported values are model, layer, and filter')

    # rescale the direction so that it spans a single step of the line
    step_length = (origin.model_norm() * distance) / steps
    direction.mul_(step_length / direction.model_norm())

    evaluations = []
    for _ in range(steps):
        # move the model one step along the line, then evaluate it
        origin.add_(direction)
        evaluations.append(metric(wrapper))

    return np.array(evaluations)
155 |
156 |
157 | # noinspection DuplicatedCode
def planar_interpolation(model_start: typing.Union[torch.nn.Module, ModelWrapper],
                         model_end_one: typing.Union[torch.nn.Module, ModelWrapper],
                         model_end_two: typing.Union[torch.nn.Module, ModelWrapper],
                         metric: Metric, steps=20, deepcopy_model=False) -> np.ndarray:
    """
    Returns the computed value of the evaluation function applied to the model or agent along
    a planar subspace of the parameter space defined by a start point and two end points.
    The models supplied can be either torch.nn.Module models, or ModelWrapper objects
    from the loss_landscapes library for more complex cases.

    That is, given three neural network models, 'model_start', 'model_end_one', and
    'model_end_two', each of which defines a point in parameter space, the loss is
    computed at 'steps' * 'steps' points on the plane defined by the start vertex
    and the two vectors (end_one - start) and (end_two - start). A common choice would
    be for two of the points to be the model after initialization, and the model after
    convergence. The third point could be another randomly initialized model, since in
    a high-dimensional space randomly sampled directions are most likely to be orthogonal.

    With dir_one = (end_one - start) / steps and dir_two = (end_two - start) / steps,
    entry [i][j] of the result is the metric evaluated at
    start + i * dir_one + j * dir_two, for i and j in range(steps). The grid therefore
    begins at the start model and stops one step short of each end model.

    The Metric supplied has to be a subclass of the loss_landscapes.metrics.Metric class,
    and must specify a procedure whereby the model passed to it is evaluated on the
    task of interest, returning the resulting quantity (such as loss, loss gradient, etc).

    Note that the grid is walked by moving the start model's parameters in place, and
    they are not moved back before returning. Pass deepcopy_model=True if the supplied
    models must be left untouched.

    :param model_start: the model defining the origin point of the plane in parameter space
    :param model_end_one: the model representing the end point of the first direction defining the plane
    :param model_end_two: the model representing the end point of the second direction defining the plane
    :param metric: function of form evaluation_f(model), used to evaluate model loss
    :param steps: number of grid points evaluated along each of the two directions
    :param deepcopy_model: indicates whether the method will deepcopy the model(s) to avoid aliasing
    :return: array of metric values with one row per step along the first direction and
        one entry per step along the second direction
    """
    model_start_wrapper = wrap_model(copy.deepcopy(model_start) if deepcopy_model else model_start)
    model_end_one_wrapper = wrap_model(copy.deepcopy(model_end_one) if deepcopy_model else model_end_one)
    model_end_two_wrapper = wrap_model(copy.deepcopy(model_end_two) if deepcopy_model else model_end_two)

    # compute per-step direction vectors
    start_point = model_start_wrapper.get_module_parameters()
    dir_one = (model_end_one_wrapper.get_module_parameters() - start_point) / steps
    dir_two = (model_end_two_wrapper.get_module_parameters() - start_point) / steps

    # Evaluate the metric on the grid
    # [[start_point + (dir_one * i) + (dir_two * j) for j in range(steps)] for i in range(steps)].
    # To avoid constructing new parameter objects, the grid is walked in a zig-zag using
    # in-place operations: even columns move forward along dir_two, odd columns move back.
    # Each point is evaluated BEFORE stepping, and no step is taken after the last point of
    # a column, so every column covers exactly the offsets 0 .. steps - 1 along dir_two.
    last_index = steps - 1
    data_matrix = []
    for i in range(steps):
        moving_forward = i % 2 == 0
        data_column = []

        for j in range(steps):
            data_column.append(metric(model_start_wrapper))
            if j < last_index:
                if moving_forward:
                    start_point.add_(dir_two)
                else:
                    start_point.sub_(dir_two)

        # a backward column was collected from offset steps - 1 down to 0; flip it so that
        # index j always corresponds to offset j along dir_two
        if not moving_forward:
            data_column.reverse()

        data_matrix.append(data_column)
        start_point.add_(dir_one)

    return np.array(data_matrix)
229 |
230 |
231 | # noinspection DuplicatedCode
def random_plane(model: typing.Union[torch.nn.Module, ModelWrapper], metric: Metric, distance=1, steps=20,
                 normalization='filter', deepcopy_model=False) -> np.ndarray:
    """
    Returns the computed value of the evaluation function applied to the model or agent along a planar
    subspace of the parameter space defined by a start point and two randomly sampled directions.
    The models supplied can be either torch.nn.Module models, or ModelWrapper objects
    from the loss_landscapes library for more complex cases.

    That is, given a neural network model, whose parameters define a point in parameter
    space, and a distance, the loss is computed at 'steps' * 'steps' points on the
    plane defined by the two random directions. The grid is shifted back by half its
    extent in both directions so that the model's original parameters sit in the middle
    of the grid: for an even number of steps they are exactly the grid point at index
    [steps / 2][steps / 2]; for an odd number of steps they fall between grid points.

    Note that the dimensionality of the model parameters has an impact on the expected
    length of a uniformly sampled vector in parameter space. That is, the more parameters
    a model has, the longer the distance in the random vector's direction should be,
    in order to see meaningful change in individual parameters. Normalizing the
    direction vector according to the model's current parameter values, which is supported
    through the 'normalization' parameter, helps reduce the impact of the distance
    parameter. In future releases, the distance parameter will refer to the maximum change
    in an individual parameter, rather than the length of the random direction vector.

    Note also that a simple planar approximation with randomly sampled directions can produce
    misleading approximations of the loss landscape due to the scale invariance of neural
    networks. The sharpness/flatness of minima or maxima is affected by the scale of the neural
    network weights. For more details, see `https://arxiv.org/abs/1712.09913v3`. It is
    recommended to normalize the directions, preferably with the 'filter' option.

    The Metric supplied has to be a subclass of the loss_landscapes.metrics.Metric class,
    and must specify a procedure whereby the model passed to it is evaluated on the
    task of interest, returning the resulting quantity (such as loss, loss gradient, etc).

    Note that the grid is walked by moving the model's parameters in place, and they are
    not moved back before returning. Pass deepcopy_model=True if the supplied model must
    be left untouched.

    :param model: the model defining the origin point of the plane in parameter space
    :param metric: function of form evaluation_f(model), used to evaluate model loss
    :param distance: extent of the grid along each direction, as a multiple of the
        norm of the model's parameters
    :param steps: number of grid points evaluated along each of the two directions
    :param normalization: normalization of direction vectors, must be one of 'filter', 'layer',
        'model', or None for no normalization
    :param deepcopy_model: indicates whether the method will deepcopy the model(s) to avoid aliasing
    :return: array of metric values with one row per step along the first random direction
        and one entry per step along the second random direction
    :raises AttributeError: if 'normalization' is not one of the supported values
    """
    model_start_wrapper = wrap_model(copy.deepcopy(model) if deepcopy_model else model)

    start_point = model_start_wrapper.get_module_parameters()
    dir_one = rand_u_like(start_point)
    dir_two = orthogonal_to(dir_one)

    if normalization == 'model':
        dir_one.model_normalize_(start_point)
        dir_two.model_normalize_(start_point)
    elif normalization == 'layer':
        dir_one.layer_normalize_(start_point)
        dir_two.layer_normalize_(start_point)
    elif normalization == 'filter':
        dir_one.filter_normalize_(start_point)
        dir_two.filter_normalize_(start_point)
    elif normalization is None:
        pass
    else:
        raise AttributeError('Unsupported normalization argument. Supported values are model, layer, and filter')

    # scale both directions so that one step has length (|start_point| * distance) / steps;
    # the target length is the same for both directions, so compute it once
    step_length = (start_point.model_norm() * distance) / steps
    dir_one.mul_(step_length / dir_one.model_norm())
    dir_two.mul_(step_length / dir_two.model_norm())

    # Move start point so that original start params will be in the center of the plot.
    # The directions are temporarily stretched to half the grid extent, subtracted from the
    # start point, and then shrunk back to single-step length, all in place.
    half_steps = steps / 2
    dir_one.mul_(half_steps)
    dir_two.mul_(half_steps)
    start_point.sub_(dir_one)
    start_point.sub_(dir_two)
    dir_one.truediv_(half_steps)
    dir_two.truediv_(half_steps)

    # Evaluate the metric on the grid
    # [[start_point + (dir_one * i) + (dir_two * j) for j in range(steps)] for i in range(steps)].
    # To avoid constructing new parameter objects, the grid is walked in a zig-zag using
    # in-place operations: even columns move forward along dir_two, odd columns move back.
    # Each point is evaluated BEFORE stepping, and no step is taken after the last point of
    # a column, so every column covers exactly the offsets 0 .. steps - 1 along dir_two.
    last_index = steps - 1
    data_matrix = []
    for i in range(steps):
        moving_forward = i % 2 == 0
        data_column = []

        for j in range(steps):
            data_column.append(metric(model_start_wrapper))
            if j < last_index:
                if moving_forward:
                    start_point.add_(dir_two)
                else:
                    start_point.sub_(dir_two)

        # a backward column was collected from offset steps - 1 down to 0; flip it so that
        # index j always corresponds to offset j along dir_two
        if not moving_forward:
            data_column.reverse()

        data_matrix.append(data_column)
        start_point.add_(dir_one)

    return np.array(data_matrix)
325 |
326 |
327 | # todo add hypersphere function
328 |
--------------------------------------------------------------------------------