├── configs ├── opt │ ├── cvx.toml │ ├── SGD.toml │ └── grouprmsprop.toml ├── mat-cmpl │ ├── 2000.toml │ ├── 5000.toml │ ├── gen_gt.toml │ ├── gen_obs.toml │ └── run.toml ├── mat-sensing │ ├── 2000.toml │ ├── 5000.toml │ ├── gen_obs.toml │ └── run.toml ├── dynamics.toml └── ml-100k.toml ├── .gitignore ├── datasets └── ml-100k │ └── ml-100k.pkl ├── lunzi ├── requirements.txt ├── typing.py ├── serialization.py ├── dummy.py ├── README.md ├── __init__.py ├── sampler.py ├── debug.py ├── file_storage.py ├── base_flags.py ├── injector.py ├── experiment.py └── dataset.py ├── requirements.txt ├── gen_gt.py ├── gen_obs.py ├── scripts └── run.rb ├── opt.py ├── README.md └── main.py /configs/opt/cvx.toml: -------------------------------------------------------------------------------- 1 | optimizer = "cvxpy" -------------------------------------------------------------------------------- /configs/opt/SGD.toml: -------------------------------------------------------------------------------- 1 | optimizer = "SGD" 2 | lr = 0.01 3 | -------------------------------------------------------------------------------- /configs/mat-cmpl/2000.toml: -------------------------------------------------------------------------------- 1 | obs_path = "datasets/mat-cmpl/2000.pt" 2 | -------------------------------------------------------------------------------- /configs/mat-cmpl/5000.toml: -------------------------------------------------------------------------------- 1 | obs_path = "datasets/mat-cmpl/5000.pt" 2 | -------------------------------------------------------------------------------- /configs/mat-sensing/2000.toml: -------------------------------------------------------------------------------- 1 | obs_path = "datasets/mat-sensing/2000.pt" 2 | -------------------------------------------------------------------------------- /configs/mat-sensing/5000.toml: -------------------------------------------------------------------------------- 1 | obs_path = "datasets/mat-sensing/5000.pt" 2 | -------------------------------------------------------------------------------- /configs/opt/grouprmsprop.toml: -------------------------------------------------------------------------------- 1 | optimizer = "GroupRMSprop" 2 | lr = 0.001 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | __pycache__ 3 | 4 | .DS_Store 5 | datasets/* 6 | !datasets/ml-100k 7 | -------------------------------------------------------------------------------- /configs/mat-cmpl/gen_gt.toml: -------------------------------------------------------------------------------- 1 | n = 100 2 | rank = 5 3 | gt_path = "datasets/mat-cmpl/gt.pt" 4 | -------------------------------------------------------------------------------- /configs/mat-cmpl/gen_obs.toml: -------------------------------------------------------------------------------- 1 | n = 100 2 | gt_path = 'datasets/mat-cmpl/gt.pt' 3 | problem = 'matrix-completion' 4 | -------------------------------------------------------------------------------- /configs/mat-sensing/gen_obs.toml: -------------------------------------------------------------------------------- 1 | n = 100 2 | gt_path = 'datasets/mat-sensing/gt.pt' 3 | problem = 'matrix-sensing' 4 | -------------------------------------------------------------------------------- /configs/dynamics.toml: -------------------------------------------------------------------------------- 1 | n_iters = 200000 2 | 
n_singulars_save = 10 3 | n_dev_iters = 100 4 | initialization = "identity" 5 | -------------------------------------------------------------------------------- /configs/mat-cmpl/run.toml: -------------------------------------------------------------------------------- 1 | problem = "matrix-completion" 2 | gt_path = "datasets/mat-cmpl/gt.pt" 3 | shape = [100, 100] 4 | -------------------------------------------------------------------------------- /configs/mat-sensing/run.toml: -------------------------------------------------------------------------------- 1 | problem = "matrix-sensing" 2 | gt_path = "datasets/mat-sensing/gt.pt" 3 | shape = [100, 100] 4 | -------------------------------------------------------------------------------- /datasets/ml-100k/ml-100k.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/roosephu/deep_matrix_factorization/HEAD/datasets/ml-100k/ml-100k.pkl -------------------------------------------------------------------------------- /lunzi/requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | wrapt 3 | coloredlogs 4 | toml 5 | fasteners 6 | ipdb # optional 7 | tensorboardX # optional 8 | -------------------------------------------------------------------------------- /configs/ml-100k.toml: -------------------------------------------------------------------------------- 1 | problem = "ml-100k" 2 | obs_path = "datasets/ml-100k/ml-100k.pkl" 3 | n_train_samples = 100_000 4 | shape = [1682, 943] 5 | -------------------------------------------------------------------------------- /lunzi/typing.py: -------------------------------------------------------------------------------- 1 | 2 | from .base_flags import BaseFLAGS 3 | from .experiment import Logger, SummaryWriter 4 | from .dataset import BaseDataset, Dataset, ExtendableDataset 5 | from .file_storage import FileStorage 6 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.1.0 2 | numpy 3 | ipdb 4 | gitpython 5 | cvxpy 6 | cvxopt 7 | 8 | coloredlogs 9 | toml==0.10.0 10 | tensorboard==1.14.0 11 | tensorboardX==1.7 12 | wrapt 13 | 14 | tinydb 15 | matplotlib 16 | 17 | fasteners 18 | -------------------------------------------------------------------------------- /lunzi/serialization.py: -------------------------------------------------------------------------------- 1 | from typing import Union, IO, Any 2 | import numpy as np 3 | 4 | 5 | def save(obj: Any, file: Union[str, IO]): 6 | np.save(file, obj) 7 | 8 | 9 | def load(file: Union[str, IO]): 10 | return np.load(file)[()] 11 | 12 | -------------------------------------------------------------------------------- /lunzi/dummy.py: -------------------------------------------------------------------------------- 1 | class DummyClass: 2 | def __getattr__(self, item): 3 | if item in self.__dict__: 4 | return self.__dict__[item] 5 | return self 6 | 7 | def __call__(self, *args, **kwargs): 8 | pass 9 | 10 | 11 | dummy = DummyClass() 12 | -------------------------------------------------------------------------------- /lunzi/README.md: -------------------------------------------------------------------------------- 1 | Writing a library is like SGD: Every time you'd like to write some code 2 | according to current requirements. 
However, after some time you probably don't 3 | need it any more and the code gets deprecated. After along time, finally you've 4 | found the central need and therefore maintain its core. -------------------------------------------------------------------------------- /lunzi/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = '0.0.1' 2 | 3 | from typing import Union 4 | 5 | from .dummy import dummy 6 | from .file_storage import FileStorage 7 | from .experiment import init, close, main, get_logger, SummaryWriter, Logger 8 | from .injector import inject 9 | 10 | 11 | log: Logger = get_logger('lunzi') 12 | fs: FileStorage = FileStorage() 13 | writer: Union[dummy, SummaryWriter] = dummy 14 | features = [] 15 | info = { 16 | 'lunzi': { 17 | 'features': features, 18 | '__version__': __version__, 19 | } 20 | } 21 | 22 | from .serialization import save, load 23 | from .dataset import BaseDataset, Dataset, ExtendableDataset 24 | from .sampler import BaseSampler, BatchSampler 25 | from . import debug 26 | from .typing import * 27 | -------------------------------------------------------------------------------- /gen_gt.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import lunzi as lz 4 | import torch 5 | 6 | 7 | class FLAGS(lz.BaseFLAGS): 8 | n = 100 9 | rank = 5 10 | symmetric = False 11 | gt_path = '' 12 | 13 | 14 | @lz.main(FLAGS) 15 | @FLAGS.inject 16 | def main(n, rank, gt_path, symmetric): 17 | r = rank 18 | U = np.random.randn(n, r).astype(np.float32) 19 | if symmetric: 20 | V = U 21 | else: 22 | V = np.random.randn(n, r).astype(np.float32) 23 | w_gt = U.dot(V.T) / np.sqrt(r) 24 | w_gt = w_gt / np.linalg.norm(w_gt, 'fro') * n 25 | 26 | oracle_sv = np.linalg.svd(w_gt, compute_uv=False) 27 | lz.log.info("singular values = %s, Fro(w) = %.3f", oracle_sv[:r], np.linalg.norm(w_gt, ord='fro')) 28 | torch.save(torch.from_numpy(w_gt), gt_path) 29 | 30 | 31 | if __name__ == '__main__': 32 | main() 33 | -------------------------------------------------------------------------------- /lunzi/sampler.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import numpy as np 3 | 4 | 5 | class BaseSampler(abc.ABC): 6 | @abc.abstractmethod 7 | def __iter__(self): 8 | pass 9 | 10 | def __len__(self): 11 | raise NotImplementedError 12 | 13 | 14 | class BatchSampler(BaseSampler): 15 | def __init__(self, dataset, batch_size, shuffle=True): 16 | self.dataset = dataset 17 | self.batch_size = batch_size 18 | self.shuffle = shuffle 19 | 20 | def _iterator(self): 21 | indices = np.arange(self.dataset.size, dtype=np.int32) 22 | if self.shuffle: 23 | np.random.shuffle(indices) 24 | index = 0 25 | while index < self.dataset.size: 26 | end = min(index + self.batch_size, self.dataset.size) 27 | yield self.dataset[indices[index:end]] 28 | index = end 29 | 30 | def __iter__(self): 31 | return self._iterator() 32 | 33 | def __len__(self): 34 | return self.dataset.size // self.batch_size 35 | 36 | -------------------------------------------------------------------------------- /lunzi/debug.py: -------------------------------------------------------------------------------- 1 | from lunzi import log 2 | 3 | skips = { 4 | # 'lunzi', 5 | # 'lunzi.*', 6 | 'lunzi.injector', 7 | 'lunzi.base_flags', 8 | 'lunzi.experiment', 9 | 'ipdb.*', 'pdb', 10 | 'numpy', 'numpy.*', 11 | 'torch', 'torch.*', 12 | 'tensorflow', 'tensorflow.*', 13 | } 14 | 15 | 16 | def 
_monkey_patch(): 17 | try: 18 | import ipdb 19 | except ImportError: 20 | log.critical(f'skip patching `ipdb`: `ipdb` not found.') 21 | return 22 | 23 | import os 24 | env_var = 'PYTHONBREAKPOINT' 25 | if env_var in os.environ: 26 | log.critical(f'skip patching `ipdb`: environment variable `{env_var}` has been set.') 27 | return 28 | os.environ[env_var] = 'ipdb.set_trace' 29 | 30 | old_init_pdb = ipdb.__main__._init_pdb 31 | 32 | def _init_pdb(*args, **kwargs): 33 | p = old_init_pdb(*args, **kwargs) 34 | p.skip = skips 35 | return p 36 | 37 | ipdb.__main__._init_pdb = _init_pdb 38 | log.critical(f'`ipdb` patched...') 39 | 40 | 41 | _monkey_patch() 42 | -------------------------------------------------------------------------------- /gen_obs.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | import lunzi as lz 5 | 6 | 7 | class FLAGS(lz.BaseFLAGS): 8 | n = 100 9 | gt_path = '' 10 | obs_path = '' 11 | problem = '' 12 | n_train_samples = 0 13 | 14 | @classmethod 15 | def finalize(cls): 16 | if cls.problem == 'matrix-sensing': 17 | cls.obs_path = f'datasets/mat-sensing/{cls.n_train_samples}.pt' 18 | elif cls.problem == 'matrix-completion': 19 | cls.obs_path = f'datasets/mat-cmpl/{cls.n_train_samples}.pt' 20 | 21 | 22 | @lz.main(FLAGS) 23 | @FLAGS.inject 24 | def main(n, problem, n_train_samples, gt_path, obs_path, _log): 25 | w_gt = torch.load(gt_path) 26 | 27 | with torch.no_grad(): 28 | if problem == 'matrix-completion': 29 | indices = torch.multinomial(torch.ones(n * n), n_train_samples, replacement=False) 30 | us, vs = indices // n, indices % n 31 | ys_ = w_gt[us, vs] 32 | assert 0.8 <= ys_.pow(2).mean().sqrt() <= 1.2 33 | torch.save([(us, vs), ys_], obs_path) 34 | elif problem == 'matrix-sensing': 35 | xs = torch.randn(n_train_samples, n, n) / n 36 | ys_ = (xs * w_gt).sum(dim=-1).sum(dim=-1) 37 | assert 0.8 <= ys_.pow(2).mean().sqrt() <= 1.2 38 | torch.save([xs, ys_], obs_path) 39 | else: 40 | raise ValueError(f'unexpected problem: {problem}') 41 | _log.warning('[%s] Saved %d samples to %s', problem, n_train_samples, obs_path) 42 | 43 | 44 | if __name__ == '__main__': 45 | main() 46 | -------------------------------------------------------------------------------- /scripts/run.rb: -------------------------------------------------------------------------------- 1 | #!/usr/bin/ruby 2 | # gem install colorize --user 3 | require 'optparse' 4 | require 'date' 5 | require 'colorize' 6 | 7 | options = { dry: false, template: '', name: '', n_jobs: 1, base_log_dir: '~/logs' } 8 | grid = [] 9 | OptionParser.new do |opts| 10 | opts.on("--name NAME", "name") { |name| options[:name] = name } 11 | opts.on("--n_jobs N_JOBS", "number of jobs") { |n_jobs| options[:n_jobs] = n_jobs } 12 | opts.on("--dry") { options[:dry] = true } 13 | opts.on("--base_log_dir BASE_LOG_DIR") { |base_log_dir| options[:base_log_dir] = base_log_dir } 14 | opts.on("--replace PARAM", "param") { |param| 15 | key, values = param.split '=', 2 16 | values = values.split ',' 17 | grid << values.map { |value| [key, value] } 18 | } 19 | opts.on("--template TEMPLATE", "template") { |template| options[:template] = template } 20 | end.parse! 
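# Each `--replace KEY=v1,v2` above contributes one axis to the search grid; the
# commands executed below are the Cartesian product of all axes (via `product`).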
21 | 22 | puts options 23 | 24 | Dir.instance_eval do 25 | timestamp = DateTime.now.strftime '%Y%m%d-%H%M%S' 26 | raise "name can't be empty" if options[:name] == '' 27 | log_dir = "#{options[:base_log_dir]}/#{options[:name]}-#{timestamp}" 28 | puts log_dir.yellow 29 | grid = [['LOGDIR', log_dir]].product(*grid) 30 | 31 | grid.each { |args| 32 | cmd = options[:template].dup 33 | args.each do |(old, new)| 34 | cmd.gsub! /\b#{old}\b/, new # an unsafe regexp constructor... 35 | end 36 | puts cmd 37 | if not options[:dry] then 38 | (1..options[:n_jobs].to_i).each do |i| 39 | pid = spawn cmd, :out => "/dev/null", :err => "/dev/null" 40 | Process.wait pid 41 | sleep 0.05 42 | end 43 | end 44 | } 45 | 46 | end 47 |
-------------------------------------------------------------------------------- /opt.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.optim.optimizer import Optimizer 3 | 4 | 5 | class GroupRMSprop(Optimizer): 6 | """A variant of RMSprop that adjusts a single, global learning rate for the whole parameter group. 7 | """ 8 | 9 | def __init__(self, params, lr=1e-2, alpha=0.99, eps=1e-6): 10 | if not 0.0 <= lr: 11 | raise ValueError("Invalid learning rate: {}".format(lr)) 12 | if not 0.0 <= eps: 13 | raise ValueError("Invalid epsilon value: {}".format(eps)) 14 | if not 0.0 <= alpha: 15 | raise ValueError("Invalid alpha value: {}".format(alpha)) 16 | 17 | defaults = dict(lr=lr, alpha=alpha, eps=eps, adjusted_lr=lr) 18 | super().__init__(params, defaults) 19 | 20 | def __setstate__(self, state): 21 | super().__setstate__(state) 22 | 23 | def step(self, closure=None): 24 | """Performs a single optimization step. 25 | 26 | Arguments: 27 | closure (callable, optional): A closure that reevaluates the model 28 | and returns the loss. 29 | """ 30 | loss = None 31 | if closure is not None: 32 | loss = closure() 33 | 34 | for group in self.param_groups: 35 | state = self.state 36 | 37 | # State initialization 38 | if len(state) == 0: 39 | state['step'] = 0 40 | state['square_avg'] = torch.tensor(0.) 
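# Unlike torch's stock RMSprop, which keeps one accumulator per parameter,
# `square_avg` is a single scalar shared by the whole parameter group: it tracks
# the total squared-gradient norm and yields one global (adjusted) step size.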
41 | 42 | square_avg = state['square_avg'] 43 | alpha = group['alpha'] 44 | square_avg.mul_(alpha) 45 | 46 | state['step'] += 1 47 | 48 | for p in group['params']: 49 | if p.grad is None: 50 | continue 51 | grad = p.grad.data 52 | if grad.is_sparse: 53 | raise RuntimeError('GroupRMSprop does not support sparse gradients') 54 | 55 | square_avg.add_((1 - alpha) * grad.pow(2).sum().cpu().float()) 56 | 57 | avg = square_avg.div(1 - alpha**state['step']).sqrt_().add_(group['eps']) 58 | lr = group['lr'] / avg 59 | group['adjusted_lr'] = lr 60 | 61 | for p in group['params']: 62 | if p.grad is None: 63 | continue 64 | grad = p.grad.data 65 | p.data.add_(-lr.to(grad.device) * grad) 66 | 67 | return loss 68 |
-------------------------------------------------------------------------------- /lunzi/file_storage.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | from pathlib import Path 3 | from zipfile import ZipFile 4 | 5 | import numpy as np 6 | import fasteners 7 | import toml 8 | 9 | 10 | class FileStorage: 11 | log_dir: Optional[Path] 12 | exp_dir: Optional[Path] 13 | 14 | def __init__(self): 15 | self.log_dir = None 16 | self.exp_dir = None 17 | self.run_id = None 18 | self._exp = {} 19 | 20 | def init(self, exp_dir): 21 | self.exp_dir = Path(exp_dir) 22 | with fasteners.InterProcessLock(self.exp_dir / '.lock'): 23 | self._exp = self._read_exp_status() 24 | self.run_id = str(self._exp['id']) 25 | self._exp['id'] += 1 26 | 27 | self.log_dir = Path(exp_dir).expanduser() / self.run_id 28 | self.log_dir.mkdir(parents=True, exist_ok=True) 29 | self.git_backup() 30 | 31 | toml.dump(self._exp, open(self.exp_dir / '.status.toml', 'w')) 32 | 33 | def _read_exp_status(self): 34 | status_path = self.exp_dir / '.status.toml' 35 | if status_path.exists(): 36 | return toml.load(status_path) 37 | else: 38 | run_id = 0 39 | while (self.exp_dir / str(run_id)).exists(): 40 | run_id += 1 41 | return {'id': run_id} 42 | 43 | def git_backup(self): 44 | try: 45 | from git import Repo 46 | from git.exc import InvalidGitRepositoryError 47 | except ImportError as e: 48 | print(f"Can't import `git`: {e}") 49 | return 50 | 51 | try: 52 | repo = Repo('.') 53 | pkg = ZipFile(self.log_dir / 'source.zip', 'w') 54 | 55 | for file_name in repo.git.ls_files().split(): 56 | pkg.write(file_name) 57 | 58 | except InvalidGitRepositoryError as e: 59 | print(f"Can't use git to backup files: {e}") 60 | except FileNotFoundError as e: 61 | print(f"Can't find file {e}. Did you delete a file and forget to `git add .`?") 62 | 63 | def resolve(self, file_name: str): 64 | if '$LOGDIR' in file_name: 65 | file_name = file_name.replace('$LOGDIR', str(self.log_dir)) 66 | return Path(file_name).expanduser() 67 | 68 | def save(self, file_name: str, array: np.ndarray): 69 | resolved = self.resolve(file_name) 70 | if resolved: 71 | np.save(str(resolved), array) 72 | 73 | def load(self, file_name: str): 74 | return np.load(self.resolve(file_name)) 75 |
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Implicit Regularization in Deep Matrix Factorization 2 | 3 | Code for [ 4 | Implicit Regularization in Deep Matrix Factorization](https://arxiv.org/abs/1905.13655). 5 | 6 | ## Installation 7 | 8 | Please use Python 3.7 to run this code. 9 | 10 | ```bash 11 | pip install -r requirements.txt 12 | ``` 13 | 14 | ## Dataset Generation 15 | 16 | Here is an example of generating the inputs for matrix completion with n = 100, rank = 5, and 2k samples. 17 | 18 | ```bash 19 | mkdir -p datasets/mat-cmpl 20 | python gen_gt.py --config configs/mat-cmpl/gen_gt.toml 21 | python gen_obs.py --config configs/mat-cmpl/gen_obs.toml --set n_train_samples 2000 22 | ``` 23 | 24 | ## Experiments 25 | 26 | If you just want to run one experiment, use the following command as an example. 27 | 28 | ```bash 29 | python main.py --print_config --log_dir /tmp/exp1 \ 30 | --config configs/mat-cmpl/run.toml \ 31 | --config configs/mat-cmpl/2000.toml \ 32 | --config configs/opt/grouprmsprop.toml \ 33 | --set depth 2 34 | ``` 35 | 36 | For nuclear norm minimization: 37 | 38 | ```bash 39 | python main.py --print_config --log_dir /tmp/exp2 \ 40 | --config configs/mat-cmpl/run.toml \ 41 | --config configs/mat-cmpl/2000.toml \ 42 | --config configs/opt/cvx.toml 43 | ``` 44 | 45 | For dynamics of gradient descent (Figure 3): 46 | 47 | ```bash 48 | python main.py --log_dir /tmp --print_config \ 49 | --config configs/ml-100k.toml \ 50 | --config configs/opt/SGD.toml \ 51 | --config configs/dynamics.toml \ 52 | --set depth 2 53 | ``` 54 | 55 | 56 | The results will be saved at `/tmp/ID`, where `ID` is a different number for each run and starts from 0. 57 | 58 | To run multiple experiments sequentially, you can use `./scripts/run.rb` (please make sure Ruby is installed and run `gem install colorize --user`). The script logs to `~/logs` by default. 59 | 60 | ```bash 61 | ./scripts/run.rb --n_jobs 3 --name mat-cmpl \ 62 | --template 'python main.py --print_config --log_dir LOGDIR --config configs/mat-cmpl/run.toml --config configs/mat-cmpl/SAMPLES.toml --config configs/opt/grouprmsprop.toml --set depth DEPTH --set lr LR --set init_scale SCALE' \ 63 | --replace LR=0.001,0.0003 \ 64 | --replace DEPTH=2,3,4 \ 65 | --replace SCALE=1.e-3,1.e-4,1.e-5,1.e-6 \ 66 | --replace SAMPLES=2000,5000 67 | ``` 68 | 69 | For multiple experiments on nuclear norm minimization: 70 | 71 | ```bash 72 | ./scripts/run.rb --n_jobs 1 --name mat-cmpl-cvx \ 73 | --template 'python main.py --print_config --log_dir LOGDIR --config configs/mat-cmpl/run.toml --config configs/mat-cmpl/SAMPLES.toml --config configs/opt/cvx.toml' \ 74 | --replace SAMPLES=2000,5000 75 | ``` 76 | 77 | ## Plotting 78 | 79 | We use the Jupyter notebook `plot.ipynb` to generate our figures. 80 | 81 | Please modify the 4th cell to load all results. The directories are the corresponding `--log_dir` options, e.g., `/tmp/exp1` in the first example. 82 |
-------------------------------------------------------------------------------- /lunzi/base_flags.py: -------------------------------------------------------------------------------- 1 | from typing import Union, List, Any 2 | 3 | from .injector import inject, ParamInjector 4 | 5 | 6 | class MetaFLAGS(type): 7 | _frozen = False 8 | seed: int 9 | 10 | def __setattr__(cls, key: str, value: Any): 11 | # assert not cls._frozen, 'Modifying frozen FLAGS.' 
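# Note: the freeze check above is left disabled, so FLAGS attributes stay
# writable even after `freeze()` is called.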
12 | super().__setattr__(key, value) 13 | 14 | def __getitem__(cls, item: str): 15 | return cls.__dict__[item] 16 | 17 | def add(cls, key: str, value: Any, overwrite=False, overwrite_false=False): 18 | if key not in cls or overwrite or not getattr(cls, key) and overwrite_false: 19 | setattr(cls, key, value) 20 | 21 | def __iter__(cls): 22 | for key, value in cls.__dict__.items(): 23 | if not key.startswith('_') and not isinstance(value, classmethod): 24 | if isinstance(value, MetaFLAGS): 25 | value = dict(value) 26 | yield key, value 27 | 28 | def as_dict(cls): 29 | return dict(cls) 30 | 31 | def freeze(cls): 32 | for key, value in cls.__dict__.items(): 33 | if not key.startswith('_') and isinstance(value, MetaFLAGS): 34 | value.freeze() 35 | cls.finalize() 36 | cls._frozen = True 37 | 38 | def _injector(cls, _, parameters): 39 | claims = [] 40 | for param in parameters.keys() & cls.__dict__.keys(): 41 | annotation = cls.__annotations__.get(param, None) 42 | claims.append(ParamInjector(param, lambda *_, param_=param: cls.__dict__[param_], annotation)) 43 | return claims 44 | 45 | def finalize(cls): 46 | pass 47 | 48 | @property 49 | def inject(cls): 50 | """ 51 | Generate a new `inject` instance, in case `fn.__injectors__` 52 | is changed. 53 | """ 54 | return inject(cls._injector) 55 | 56 | 57 | class BaseFLAGS(metaclass=MetaFLAGS): 58 | pass 59 | 60 | 61 | def merge(lhs: Union[MetaFLAGS, dict], rhs: dict): 62 | # import ipdb; ipdb.set_trace() 63 | for key in rhs: 64 | keys = lhs if isinstance(lhs, dict) else lhs.__dict__ 65 | assert key in keys, f"Can't find key `{key}`" 66 | if isinstance(lhs[key], (MetaFLAGS, dict)) and isinstance(rhs[key], dict): 67 | merge(lhs[key], rhs[key]) 68 | else: 69 | if isinstance(lhs, dict): 70 | lhs[key] = rhs[key] 71 | else: 72 | setattr(lhs, key, rhs[key]) 73 | 74 | 75 | def set_value(cls: Union[MetaFLAGS, dict], path: List[str], value: Any): 76 | key, *rest = path 77 | keys = cls if isinstance(cls, dict) else cls.__dict__ 78 | assert key in keys, f"Can't find key `{key}`" 79 | if not rest: 80 | if isinstance(cls, dict): 81 | cls[key] = value 82 | else: 83 | setattr(cls, key, value) 84 | else: 85 | assert isinstance(cls[key], (MetaFLAGS, dict)) 86 | set_value(cls[key], rest, value) 87 | -------------------------------------------------------------------------------- /lunzi/injector.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Callable, Any, Tuple, List, Optional 2 | from inspect import signature, Parameter 3 | from dataclasses import dataclass 4 | 5 | import wrapt 6 | 7 | Injector = Callable[[Callable, dict], List[Tuple[str, Callable]]] 8 | 9 | 10 | @dataclass 11 | class ParamInjector: 12 | key: str 13 | getter: Callable 14 | annotation: Optional[type] 15 | cache: bool = False 16 | 17 | 18 | class DefaultInjector: 19 | _params: Dict[str, ParamInjector] = {} 20 | _cache: Dict[Tuple[str, Callable], Any] = {} 21 | 22 | @staticmethod 23 | def register(tag: str, injector: ParamInjector): 24 | DefaultInjector._params[tag] = injector 25 | 26 | @staticmethod 27 | def inject(fn, parameters: dict) -> List[ParamInjector]: 28 | claims = [] 29 | for param in DefaultInjector._params.keys() & parameters.keys(): 30 | injector = DefaultInjector._params[param] 31 | 32 | key = (param, fn) 33 | if injector.cache: 34 | if key not in DefaultInjector._cache: 35 | injection = injector.getter(fn) 36 | DefaultInjector._cache[key] = injection 37 | else: 38 | injection = DefaultInjector._cache[key] 39 | getter = 
lambda *_, injection_=injection: injection_ 40 | else: 41 | getter = injector.getter 42 | claims.append(ParamInjector(param, getter, injector.annotation)) 43 | return claims 44 | 45 | 46 | # By default, all injections starting with _ are ignored, unless we run it from ours. 47 | _default_injectors = [DefaultInjector.inject] 48 | 49 | 50 | def inject(*injectors: Callable): 51 | injectors = list(injectors) 52 | 53 | def decorate(fn: Callable): 54 | sig = signature(fn) 55 | parameters = sig.parameters 56 | assigners = {} 57 | annotations = fn.__annotations__.copy() 58 | 59 | new_params: Dict[str, Parameter] = parameters.copy() 60 | for injector in injectors + _default_injectors: 61 | for param_injector in injector(fn, parameters): 62 | key = param_injector.key 63 | if key not in assigners: 64 | assigners[key] = param_injector.getter 65 | if param_injector.annotation: 66 | annotations[key] = param_injector.annotation 67 | # we enforce these parameters to be keyword-only parameters. 68 | # to change the default value, you have to explicitly specify it. 69 | new_params[key] = new_params[key].replace(kind=Parameter.KEYWORD_ONLY) 70 | 71 | def adapter(): pass 72 | adapter.__annotations__ = annotations 73 | adapter.__signature__ = sig.replace(parameters=new_params.values()) 74 | 75 | @wrapt.decorator(adapter=adapter) 76 | def injecting(wrapped, instance, args, kwargs): 77 | # easier for pdb... only need to step over one line 78 | new_kwargs = {key: getter() for key, getter in assigners.items()} 79 | new_kwargs.update(kwargs) 80 | return wrapped(*args, **new_kwargs) 81 | 82 | injected_fn = injecting(fn) 83 | injected_fn.__unwrapped__ = fn 84 | injected_fn.more = lambda *extra: inject(*extra, *injectors) 85 | injected_fn.injectors = injectors 86 | return injected_fn 87 | return decorate 88 | -------------------------------------------------------------------------------- /lunzi/experiment.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from logging import getLogger, FileHandler, Logger 3 | 4 | try: 5 | from tensorboardX import SummaryWriter 6 | except ImportError: 7 | SummaryWriter = None 8 | 9 | import lunzi as lz 10 | from .base_flags import MetaFLAGS, merge, set_value 11 | from .file_storage import FileStorage 12 | 13 | 14 | def add_file_handler(logger: Logger, file_path: str): 15 | import coloredlogs 16 | file_handler = FileHandler(file_path) 17 | file_handler.setFormatter(coloredlogs.BasicFormatter(fmt='%(asctime)s - %(filename)s:%(lineno)d - %(message)s')) 18 | logger.addHandler(file_handler) 19 | 20 | 21 | def get_logger(name: str) -> Logger: 22 | import coloredlogs 23 | 24 | logger = getLogger(name) 25 | coloredlogs.install( 26 | logger=logger, 27 | milliseconds=True, 28 | fmt='%(asctime)s - %(filename)s:%(lineno)d - %(message)s', 29 | field_styles={**coloredlogs.DEFAULT_FIELD_STYLES, 'filename': {'color': 'cyan'}}, 30 | ) 31 | 32 | return logger 33 | 34 | 35 | def set_random_seed(seed: int): 36 | import random 37 | random.seed(seed) 38 | 39 | np.random.seed(seed) 40 | 41 | try: 42 | import tensorflow 43 | tensorflow.set_random_seed(seed) 44 | except ImportError: 45 | pass 46 | 47 | try: 48 | import torch 49 | torch.manual_seed(seed) 50 | if torch.cuda.is_available(): 51 | torch.cuda.manual_seed_all(seed) 52 | except ImportError: 53 | pass 54 | 55 | 56 | def parse_string(s: str): 57 | try: 58 | import ast 59 | return ast.literal_eval(s) 60 | except (ValueError, SyntaxError): 61 | return s 62 | 63 | 64 | def init(root: MetaFLAGS, 
doc: str = ''): 65 | if 'seed' not in root: 66 | import os 67 | root.add('seed', int.from_bytes(os.urandom(3), 'little')) 68 | args, unknown = parse(root, doc) 69 | 70 | seed = root.seed 71 | root.freeze() 72 | log_dir = args.log_dir 73 | 74 | set_random_seed(seed) 75 | 76 | if log_dir is not None: 77 | lz.fs.init(log_dir) 78 | lz.writer = SummaryWriter(logdir=str(lz.fs.log_dir)) 79 | dump(root) 80 | add_file_handler(lz.log, lz.fs.resolve('$LOGDIR/out.log')) 81 | lz.log.warning(f'log_dir = {str(lz.fs.log_dir)}') 82 | else: 83 | lz.log.critical('no log_dir provided') 84 | if unknown: 85 | lz.log.critical('unknown arguments: %s', unknown) 86 | 87 | if args.print_config: 88 | import toml 89 | print('----- FLAGS begin -----') 90 | print(toml.dumps(root.as_dict())) 91 | print('----- FLAGS end -----') 92 | 93 | 94 | def dump(root: MetaFLAGS): 95 | import toml 96 | 97 | with open(lz.fs.resolve('$LOGDIR/config.toml'), 'w') as f: 98 | toml.dump(root.as_dict(), f) 99 | 100 | lz.info['log_dir'] = lz.fs.log_dir 101 | with open(lz.fs.resolve('$LOGDIR/meta.toml'), 'w') as f: 102 | toml.dump(lz.info, f) 103 | 104 | 105 | def set_default_injector(): 106 | from .injector import ParamInjector, DefaultInjector 107 | injectors = [ 108 | ParamInjector('_seed', lambda *_: np.random.randint(1, 10 ** 9), int, cache=True), 109 | ParamInjector('_rng', lambda *_: np.random.RandomState(np.random.randint(0, 2 ** 32 - 1)), 110 | np.random.RandomState, cache=True), 111 | ParamInjector('_fs', lambda *_: lz.fs, FileStorage), 112 | ParamInjector('_writer', lambda *_: lz.writer, SummaryWriter), 113 | ParamInjector('_log', lambda *_: lz.log, Logger), 114 | ParamInjector('_info', lambda *_: lz.info, dict), 115 | ] 116 | 117 | for injector in injectors: 118 | DefaultInjector.register(injector.key, injector) 119 | 120 | 121 | set_default_injector() 122 | 123 | 124 | def parse(root: MetaFLAGS, doc=''): 125 | import toml 126 | import argparse 127 | from pathlib import Path 128 | 129 | parser = argparse.ArgumentParser(description=doc) 130 | parser.add_argument('-c', '--config', help='configuration file (TOML)', action='append', metavar='FILE') 131 | parser.add_argument('-s', '--set', help='additional options', nargs=2, action='append', metavar=('PATH', 'VALUE')) 132 | parser.add_argument('--print_config', help='print configs', action='store_true') 133 | parser.add_argument('--log_dir', help='the directory to logs', default='/tmp') 134 | 135 | args, unknown = parser.parse_known_args() 136 | if args.config: 137 | for config in args.config: 138 | merge(root, toml.load(open(Path(config).expanduser()))) 139 | if args.set: 140 | for path, value in args.set: 141 | set_value(root, path.split('.'), parse_string(value)) 142 | 143 | return args, unknown 144 | 145 | 146 | def close(): 147 | lz.writer.close() 148 | 149 | 150 | def main(root: MetaFLAGS, doc: str = ''): 151 | def decorate(fn): 152 | def decorated(): 153 | init(root, doc) 154 | fn() 155 | close() 156 | return decorated 157 | return decorate 158 | -------------------------------------------------------------------------------- /lunzi/dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import numpy.lib.recfunctions 3 | from collections import namedtuple 4 | 5 | 6 | class BaseDataset(object): 7 | def __getitem__(self, item): 8 | raise NotImplementedError 9 | 10 | def apply(self, transforms): 11 | return TransformedDataset(self, transforms) 12 | 13 | @property 14 | def size(self): 15 | raise NotImplementedError 16 | 
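# A minimal usage sketch for the `Dataset` class below (illustrative only, not
# part of the library):
#     ds = Dataset.fromdict(x=np.zeros((10, 3)), y=np.arange(10))
#     for batch in ds.iterator(batch_size=4):
#         print(batch.x.shape, batch.y.shape)
# `fromdict` infers the shared leading (batch) dimensions and stores each
# keyword argument as a named field of a single record array.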
17 | 18 | class Dataset(np.recarray, BaseDataset): 19 | def __init__(self, dtype, size): 20 | super().__init__() 21 | self.resize(size) 22 | 23 | @staticmethod 24 | def fromdict(**arrays): 25 | array_list = [np.asarray(x) for x in arrays.values()] 26 | rank = 0 27 | for rank in range(len(array_list[0].shape) + 1): 28 | ok = True 29 | for array in array_list: 30 | if len(array.shape) <= rank or array.shape[rank] != array_list[0].shape[rank]: 31 | ok = False 32 | break 33 | if not ok: 34 | break 35 | 36 | dtypes = [] 37 | for name, array in zip(arrays.keys(), array_list): 38 | dtypes.append((name, (array.dtype, array.shape[rank:]))) 39 | return np.rec.fromarrays(array_list, dtypes).view(Dataset) 40 | 41 | def __new__(cls, dtype, size): 42 | return np.recarray.__new__(cls, size, dtype=dtype) 43 | 44 | def to_dict(self): 45 | return {name: self[name] for name in self.dtype.names} 46 | 47 | def to_torch(self): 48 | import torch 49 | return {name: torch.tensor(self[name].copy()) for name in self.dtype.names} 50 | 51 | @property 52 | def size(self): 53 | return len(self) 54 | 55 | def sample(self, size, indices=None): 56 | if indices is None: 57 | indices = np.random.randint(0, self.size, size=size) 58 | return self[indices] 59 | 60 | def append_fields(self, names, data, dtype=None): 61 | if isinstance(names, str) and names in self.dtype.names: 62 | self[names] = data 63 | return self 64 | return np.lib.recfunctions.append_fields(self, names, data, dtype, usemask=False).view(Dataset) 65 | 66 | def drop_fields(self, names): 67 | return np.lib.recfunctions.drop_fields(self, names, usemask=False).view(Dataset) 68 | 69 | def iterator(self, batch_size, shuffle=True): 70 | indices = np.arange(self.size, dtype=np.int32) 71 | if shuffle: 72 | np.random.shuffle(indices) 73 | index = 0 74 | while index < self.size: 75 | end = min(index + batch_size, self.size) 76 | yield self[indices[index:end]] 77 | index = end 78 | 79 | def sample_iterator(self, batch_size, n_iters=0): 80 | while True: 81 | yield self.sample(batch_size) 82 | n_iters = max(n_iters - 1, -1) 83 | if n_iters == 0: 84 | break 85 | 86 | 87 | class TransformedDataset(BaseDataset): 88 | def __init__(self, dataset: BaseDataset, transform): 89 | if isinstance(dataset, TransformedDataset): 90 | self._dataset = dataset._dataset 91 | self._transforms = dataset._transforms + [transform] 92 | else: 93 | self._dataset = dataset 94 | self._transforms = [transform] 95 | 96 | def __getitem__(self, items): 97 | items = self._dataset[items] 98 | for transform in self._transforms: 99 | items = transform(items) 100 | return items 101 | 102 | @property 103 | def size(self): 104 | return self._dataset.size 105 | 106 | 107 | class ExtendableDataset(Dataset): 108 | """ 109 | Overallocation could be supported by checking the remaining capacity 110 | before each `append` and `extend`. 
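As implemented, it acts as a fixed-capacity ring buffer: `append` and `extend`
write at a moving index modulo `max_size`, so once the buffer is full the
oldest entries are overwritten.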
111 | """ 112 | 113 | def __init__(self, dtype, max_size, verbose=False): 114 | super().__init__(dtype, max_size) 115 | self.max_size = max_size 116 | self._index = 0 117 | self._buf_size = 0 118 | self._len = 0 119 | self._buf_size = max_size 120 | 121 | def __new__(cls, dtype, max_size): 122 | return np.recarray.__new__(cls, max_size, dtype=dtype) 123 | 124 | @property 125 | def size(self): 126 | return self._len 127 | 128 | def reserve(self, size): 129 | cur_size = max(self._buf_size, 1) 130 | while cur_size < size: 131 | cur_size *= 2 132 | if cur_size != self._buf_size: 133 | self.resize(cur_size) 134 | 135 | def clear(self): 136 | self._index = 0 137 | self._len = 0 138 | return self 139 | 140 | def append(self, item): 141 | self[self._index] = item 142 | self._index = (self._index + 1) % self.max_size 143 | self._len = min(self._len + 1, self.max_size) 144 | return self 145 | 146 | def extend(self, items): 147 | n_new = len(items) 148 | if n_new > self.max_size: 149 | items = items[-self.max_size:] 150 | n_new = self.max_size 151 | 152 | n_tail = self.max_size - self._index 153 | if n_new <= n_tail: 154 | self[self._index:self._index + n_new] = items 155 | else: 156 | n_head = n_new - n_tail 157 | self[self._index:] = items[:n_tail] 158 | self[:n_head] = items[n_tail:] 159 | 160 | self._index = (self._index + n_new) % self.max_size 161 | self._len = min(self._len + n_new, self.max_size) 162 | return self 163 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.optim as optim 5 | import cvxpy as cvx 6 | 7 | import lunzi as lz 8 | from lunzi.typing import * 9 | from opt import GroupRMSprop 10 | 11 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 12 | 13 | 14 | class FLAGS(lz.BaseFLAGS): 15 | problem = '' 16 | gt_path = '' 17 | obs_path = '' 18 | 19 | depth = 1 20 | n_train_samples = 0 21 | n_iters = 1000000 22 | n_dev_iters = max(1, n_iters // 1000) 23 | init_scale = 0.001 # average magnitude of entries 24 | shape = [0, 0] 25 | n_singulars_save = 0 26 | 27 | optimizer = 'GroupRMSprop' 28 | initialization = 'gaussian' # `orthogonal` or `identity` or `gaussian` 29 | lr = 0.01 30 | train_thres = 1.e-6 31 | 32 | hidden_sizes = [] 33 | 34 | @classmethod 35 | def finalize(cls): 36 | assert cls.problem 37 | cls.add('hidden_sizes', [cls.shape[0]] + [cls.shape[1]] * cls.depth, overwrite_false=True) 38 | 39 | 40 | def get_e2e(model): 41 | weight = None 42 | for fc in model.children(): 43 | assert isinstance(fc, nn.Linear) and fc.bias is None 44 | if weight is None: 45 | weight = fc.weight.t() 46 | else: 47 | weight = fc(weight) 48 | 49 | return weight 50 | 51 | 52 | @FLAGS.inject 53 | def init_model(model, *, hidden_sizes, initialization, init_scale, _log): 54 | depth = len(hidden_sizes) - 1 55 | 56 | if initialization == 'orthogonal': 57 | scale = (init_scale * np.sqrt(hidden_sizes[0]))**(1. / depth) 58 | matrices = [] 59 | for param in model.parameters(): 60 | nn.init.orthogonal_(param) 61 | param.data.mul_(scale) 62 | matrices.append(param.data.cpu().numpy()) 63 | for a, b in zip(matrices, matrices[1:]): 64 | assert np.allclose(a.dot(a.T), b.T.dot(b), atol=1e-6) 65 | elif initialization == 'identity': 66 | scale = init_scale**(1. 
/ depth) 67 | for param in model.parameters(): 68 | nn.init.eye_(param) 69 | param.data.mul_(scale) 70 | elif initialization == 'gaussian': 71 | n = hidden_sizes[0] 72 | assert hidden_sizes[0] == hidden_sizes[-1] 73 | scale = init_scale**(1. / depth) * n**(-0.5) 74 | for param in model.parameters(): 75 | nn.init.normal_(param, std=scale) 76 | e2e = get_e2e(model).detach().cpu().numpy() 77 | e2e_fro = np.linalg.norm(e2e, 'fro') 78 | desired_fro = FLAGS.init_scale * np.sqrt(n) 79 | _log.info(f"[check] e2e fro norm: {e2e_fro:.6e}, desired = {desired_fro:.6e}") 80 | assert 0.8 <= e2e_fro / desired_fro <= 1.2 81 | elif initialization == 'uniform': 82 | n = hidden_sizes[0] 83 | assert hidden_sizes[0] == hidden_sizes[-1] 84 | scale = np.sqrt(3.) * init_scale**(1. / depth) * n**(-0.5) 85 | for param in model.parameters(): 86 | nn.init.uniform_(param, a=-scale, b=scale) 87 | e2e = get_e2e(model).detach().cpu().numpy() 88 | e2e_fro = np.linalg.norm(e2e, 'fro') 89 | desired_fro = FLAGS.init_scale * np.sqrt(n) 90 | _log.info(f"[check] e2e fro norm: {e2e_fro:.6e}, desired = {desired_fro:.6e}") 91 | assert 0.8 <= e2e_fro / desired_fro <= 1.2 92 | else: 93 | assert 0 94 | 95 | 96 | class BaseProblem: 97 | def get_d_e2e(self, e2e): 98 | pass 99 | 100 | def get_train_loss(self, e2e): 101 | pass 102 | 103 | def get_test_loss(self, e2e): 104 | pass 105 | 106 | def get_cvx_opt_constraints(self, x) -> list: 107 | pass 108 | 109 | 110 | @FLAGS.inject 111 | def cvx_opt(prob: BaseProblem, *, shape, _log: Logger, _writer: SummaryWriter, _fs: FileStorage): 112 | x = cvx.Variable(shape=shape) 113 | 114 | objective = cvx.Minimize(cvx.norm(x, 'nuc')) 115 | constraints = prob.get_cvx_opt_constraints(x) 116 | 117 | problem = cvx.Problem(objective, constraints) 118 | problem.solve(solver=cvx.SCS, verbose=True, use_indirect=False) 119 | e2e = torch.from_numpy(x.value).float() 120 | 121 | train_loss = prob.get_train_loss(e2e) 122 | test_loss = prob.get_test_loss(e2e) 123 | 124 | nuc_norm = e2e.norm('nuc') 125 | _log.info(f"train loss = {train_loss.item():.3e}, " 126 | f"test error = {test_loss.item():.3e}, " 127 | f"nuc_norm = {nuc_norm.item():.3f}") 128 | _writer.add_scalar('loss/train', train_loss.item()) 129 | _writer.add_scalar('loss/test', test_loss.item()) 130 | _writer.add_scalar('nuc_norm', nuc_norm.item()) 131 | 132 | torch.save(e2e, _fs.resolve('$LOGDIR/nuclear.npy')) 133 | 134 | 135 | class MatrixCompletion(BaseProblem): 136 | ys: torch.Tensor 137 | 138 | @FLAGS.inject 139 | def __init__(self, *, gt_path, obs_path): 140 | self.w_gt = torch.load(gt_path, map_location=device) 141 | (self.us, self.vs), self.ys_ = torch.load(obs_path, map_location=device) 142 | 143 | def get_train_loss(self, e2e): 144 | self.ys = e2e[self.us, self.vs] 145 | return (self.ys - self.ys_).pow(2).mean() 146 | 147 | def get_test_loss(self, e2e): 148 | return (self.w_gt - e2e).view(-1).pow(2).mean() 149 | 150 | @FLAGS.inject 151 | def get_d_e2e(self, e2e, shape): 152 | d_e2e = torch.zeros(shape, device=device) 153 | d_e2e[self.us, self.vs] = self.ys - self.ys_ 154 | d_e2e = d_e2e / len(self.ys_) 155 | return d_e2e 156 | 157 | @FLAGS.inject 158 | def get_cvx_opt_constraints(self, x, shape): 159 | A = np.zeros(shape) 160 | mask = np.zeros(shape) 161 | A[self.us, self.vs] = self.ys_ 162 | mask[self.us, self.vs] = 1 163 | eps = 1.e-3 164 | constraints = [cvx.abs(cvx.multiply(x - A, mask)) <= eps] 165 | return constraints 166 | 167 | 168 | class MatrixCompletionOld(MatrixCompletion): 169 | @FLAGS.inject 170 | def __init__(self, *, obs_path): 
171 | self.w_gt, (self.us, self.vs), self.ys_ = torch.load(obs_path, map_location=device) 172 | 173 | 174 | class MatrixSensing(BaseProblem): 175 | ys: torch.Tensor 176 | 177 | @FLAGS.inject 178 | def __init__(self, *, gt_path, obs_path): 179 | self.w_gt = torch.load(gt_path, map_location=device) 180 | self.xs, self.ys_ = torch.load(obs_path, map_location=device) 181 | 182 | def get_train_loss(self, e2e): 183 | self.ys = (self.xs * e2e).sum(dim=-1).sum(dim=-1) 184 | return (self.ys - self.ys_).pow(2).mean() 185 | 186 | def get_test_loss(self, e2e): 187 | return (self.w_gt - e2e).view(-1).pow(2).mean() 188 | 189 | @FLAGS.inject 190 | def get_d_e2e(self, e2e, shape): 191 | d_e2e = self.xs.view(-1, *shape) * (self.ys - self.ys_).view(len(self.xs), 1, 1) 192 | d_e2e = d_e2e.sum(0) 193 | return d_e2e 194 | 195 | def get_cvx_opt_constraints(self, X): 196 | eps = 1.e-3 197 | constraints = [] 198 | for x, y_ in zip(self.xs, self.ys_): 199 | constraints.append(cvx.abs(cvx.sum(cvx.multiply(X, x)) - y_) <= eps) 200 | return constraints 201 | 202 | 203 | class MovieLens100k(BaseProblem): 204 | ys: torch.Tensor 205 | 206 | @FLAGS.inject 207 | def __init__(self, *, obs_path, n_train_samples): 208 | (self.us, self.vs), ys_ = torch.load(obs_path, map_location=device) 209 | # self.ys_ = (ys_ - ys_.mean()) / ys_.std() 210 | self.ys_ = ys_ 211 | self.n_train_samples = n_train_samples 212 | 213 | def get_train_loss(self, e2e): 214 | self.ys = e2e[self.us[:self.n_train_samples], self.vs[:self.n_train_samples]] 215 | return (self.ys - self.ys_[:self.n_train_samples]).pow(2).mean() 216 | 217 | def get_test_loss(self, e2e): 218 | ys = e2e[self.us[self.n_train_samples:], self.vs[self.n_train_samples:]] 219 | return (ys - self.ys_[self.n_train_samples:]).pow(2).mean() 220 | 221 | @FLAGS.inject 222 | def get_d_e2e(self, e2e, *, shape): 223 | d_e2e = torch.zeros(shape, device=device) 224 | d_e2e[self.us[:self.n_train_samples], self.vs[:self.n_train_samples]] = \ 225 | self.ys - self.ys_[:self.n_train_samples] 226 | d_e2e = d_e2e / len(self.ys_) 227 | return d_e2e 228 | 229 | @FLAGS.inject 230 | def get_cvx_opt_constraints(self, x, *, shape): 231 | A = np.zeros(shape) 232 | mask = np.zeros(shape) 233 | A[self.us[:self.n_train_samples], self.vs[:self.n_train_samples]] = self.ys_[:self.n_train_samples] 234 | mask[self.us[:self.n_train_samples], self.vs[:self.n_train_samples]] = 1 235 | eps = 1.e-3 236 | constraints = [cvx.abs(cvx.multiply(x - A, mask)) <= eps] 237 | return constraints 238 | 239 | 240 | @lz.main(FLAGS) 241 | @FLAGS.inject 242 | def main(*, depth, hidden_sizes, n_iters, problem, train_thres, _seed, _log, _writer, _info, _fs): 243 | prob: BaseProblem 244 | if problem == 'matrix-completion': 245 | prob = MatrixCompletion() 246 | elif problem == 'matrix-sensing': 247 | prob = MatrixSensing() 248 | elif problem == 'ml-100k': 249 | prob = MovieLens100k() 250 | else: 251 | raise ValueError 252 | 253 | layers = zip(hidden_sizes, hidden_sizes[1:]) 254 | model = nn.Sequential(*[nn.Linear(f_in, f_out, bias=False) for (f_in, f_out) in layers]).to(device) 255 | _log.info(model) 256 | 257 | if FLAGS.optimizer == 'SGD': 258 | optimizer = optim.SGD(model.parameters(), FLAGS.lr) 259 | elif FLAGS.optimizer == 'GroupRMSprop': 260 | optimizer = GroupRMSprop(model.parameters(), FLAGS.lr, eps=1e-4) 261 | elif FLAGS.optimizer == 'Adam': 262 | optimizer = optim.Adam(model.parameters(), FLAGS.lr) 263 | elif FLAGS.optimizer == 'cvxpy': 264 | cvx_opt(prob) 265 | return 266 | else: 267 | raise ValueError 268 | 269 | 
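# Re-initialize the weights according to FLAGS.initialization ('orthogonal',
# 'identity', 'gaussian', or 'uniform'); the cvxpy branch returns before this.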
init_model(model) 270 | 271 | loss = None 272 | for T in range(n_iters): 273 | e2e = get_e2e(model) 274 | 275 | loss = prob.get_train_loss(e2e) 276 | 277 | params_norm = 0 278 | for param in model.parameters(): 279 | params_norm = params_norm + param.pow(2).sum() 280 | optimizer.zero_grad() 281 | loss.backward() 282 | 283 | with torch.no_grad(): 284 | test_loss = prob.get_test_loss(e2e) 285 | 286 | if T % FLAGS.n_dev_iters == 0 or loss.item() <= train_thres: 287 | 288 | U, singular_values, V = e2e.svd() # U D V^T = e2e 289 | schatten_norm = singular_values.pow(2. / depth).sum() 290 | 291 | d_e2e = prob.get_d_e2e(e2e) 292 | full = U.t().mm(d_e2e).mm(V).abs() # we only need the magnitude. 293 | n, m = full.shape 294 | 295 | diag = full.diag() 296 | mask = torch.ones_like(full, dtype=torch.int) 297 | mask[np.arange(min(n, m)), np.arange(min(n, m))] = 0 298 | off_diag = full.masked_select(mask > 0) 299 | _writer.add_scalar('diag/mean', diag.mean().item(), global_step=T) 300 | _writer.add_scalar('diag/std', diag.std().item(), global_step=T) 301 | _writer.add_scalar('off_diag/mean', off_diag.mean().item(), global_step=T) 302 | _writer.add_scalar('off_diag/std', off_diag.std().item(), global_step=T) 303 | 304 | grads = [param.grad.cpu().data.numpy().reshape(-1) for param in model.parameters()] 305 | grads = np.concatenate(grads) 306 | avg_grads_norm = np.sqrt(np.mean(grads**2)) 307 | avg_param_norm = np.sqrt(params_norm.item() / len(grads)) 308 | 309 | if isinstance(optimizer, GroupRMSprop): 310 | adjusted_lr = optimizer.param_groups[0]['adjusted_lr'] 311 | else: 312 | adjusted_lr = optimizer.param_groups[0]['lr'] 313 | _log.info(f"Iter #{T}: train = {loss.item():.3e}, test = {test_loss.item():.3e}, " 314 | f"Schatten norm = {schatten_norm:.3e}, " 315 | f"grad: {avg_grads_norm:.3e}, " 316 | f"lr = {adjusted_lr:.3f}") 317 | 318 | _writer.add_scalar('loss/train', loss.item(), global_step=T) 319 | _writer.add_scalar('loss/test', test_loss, global_step=T) 320 | _writer.add_scalar('Schatten_norm', schatten_norm, global_step=T) 321 | _writer.add_scalar('norm/grads', avg_grads_norm, global_step=T) 322 | _writer.add_scalar('norm/params', avg_param_norm, global_step=T) 323 | 324 | for i in range(FLAGS.n_singulars_save): 325 | _writer.add_scalar(f'singular_values/{i}', singular_values[i], global_step=T) 326 | 327 | torch.save(e2e, _fs.resolve("$LOGDIR/final.npy")) 328 | if loss.item() <= train_thres: 329 | break 330 | optimizer.step() 331 | 332 | _log.info(f"train loss = {loss.item()}. test loss = {test_loss.item()}") 333 | 334 | 335 | if __name__ == '__main__': 336 | main() 337 | --------------------------------------------------------------------------------