├── dataset
├── __init__.py
├── data
│   ├── __init__.py
│   ├── dataset
│   │   ├── __init__.py
│   │   ├── base_dataset.py
│   │   ├── prompt_dataset.py
│   │   └── stream_prompt_dataset.py
│   └── json_data.py
├── utils
│   ├── __init__.py
│   ├── accelerate_utils.py
│   ├── multipreprocess_utils.py
│   ├── deepspeed_utils.py
│   ├── print_utils.py
│   └── io_utils.py
├── prompt_maker
│   ├── __init__.py
│   ├── base_prompt_maker.py
│   ├── translation_prompt.py
│   ├── custom_prompt_maker.py
│   ├── openorca_prompt_maker.py
│   ├── translate_prompt_maker.py
│   ├── scienceqa_prompt_maker.py
│   ├── custom_prompt_maker_inference.py
│   ├── alpaca_prompt_maker.py
│   └── contrastive_translate_prompt_maker.py
├── debug.ipynb
├── utils.py
└── alpaca.py
├── nngeometry
├── __init__.py
├── generator
│   ├── __init__.py
│   ├── dummy.py
│   ├── jacobian
│   │   ├── grads_conv.py
│   │   └── grads.py
│   ├── lm_jacobian
│   │   ├── grads_conv.py
│   │   └── grads.py
│   └── para_lm_jacobian
│   │   ├── grads_conv.py
│   │   └── grads.py
├── maths.py
├── object
│   ├── __init__.py
│   ├── map.py
│   ├── fspace.py
│   ├── vector.py
│   └── lm_vector.py
├── llama_layercollection.py
├── utils.py
├── layers.py
├── metrics.py
├── lm_metrics_para.py
├── lm_metrics.py
└── layercollection.py
├── requirements.txt
├── .gitignore
├── hessian.sh
├── if_score.sh
├── README.md
├── kfac_mapper.py
├── kfac_launcher.py
├── query_loss_launcher.py
└── query_loss_mapper.py
/dataset/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/dataset/data/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/dataset/utils/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/dataset/prompt_maker/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/nngeometry/__init__.py:
--------------------------------------------------------------------------------
1 | print('ok')
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch
2 | transformers
3 | datasets
4 | deepspeed
5 | numpy
6 | pandas
7 | tqdm
--------------------------------------------------------------------------------
/nngeometry/generator/__init__.py:
--------------------------------------------------------------------------------
1 | from .jacobian import Jacobian
2 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 | 
3 | **/__pycache__
4 | 
5 | *.pyc
6 | 
7 | .DS_Store
8 | 
9 | **/.DS_Store
10 | 
--------------------------------------------------------------------------------
/dataset/data/dataset/__init__.py:
--------------------------------------------------------------------------------
1 | from .prompt_dataset import DynamicPromptDataset, COAIDynamicPromptDataset
2 | from .stream_prompt_dataset import StreamDynamicPromptDataset, COAIStreamDynamicPromptDataset
--------------------------------------------------------------------------------
/nngeometry/maths.py:
--------------------------------------------------------------------------------
1 | def kronecker(A, B):  # Kronecker product of two 2-D tensors A and B
2 |     sA = A.size()
3 |     sB = B.size()
4 |     return (A.view(sA[0], 1, sA[1], 1) * B.view(1, sB[0], 1, sB[1])) \
5 |         .contiguous().view(sA[0] * sB[0], sA[1] * sB[1])
6 | 
--------------------------------------------------------------------------------
/dataset/utils/accelerate_utils.py:
--------------------------------------------------------------------------------
1 | from tqdm.auto import tqdm as original_tqdm
2 | from functools import partial
3 | 
4 | 
5 | def make_tqdm(accelerator, list_data):
6 |     tqdm = partial(original_tqdm, disable=not accelerator.is_local_main_process, position=0)
7 |     return tqdm(list_data)
--------------------------------------------------------------------------------
/dataset/data/dataset/base_dataset.py:
--------------------------------------------------------------------------------
1 | 
2 | import logging
3 | from torch.utils.data import Dataset
4 | 
5 | 
6 | class BaseDataset(Dataset):
7 | 
8 |     def __init__(self, *args, **kargs):
9 |         logging.info(f"initializing dataset: {BaseDataset.__name__}")
10 | 
11 | 
12 | 
13 | if __name__ == '__main__':
14 |     BaseDataset()
15 | 
--------------------------------------------------------------------------------
/nngeometry/generator/dummy.py:
--------------------------------------------------------------------------------
1 | class DummyGenerator:
2 |     """
3 |     This dummy generator is used for pickled objects
4 |     """
5 | 
6 |     def __init__(self, layer_collection, device):
7 |         self.layer_collection = layer_collection
8 |         self.device = device
9 | 
10 |     def get_device(self):
11 |         return self.device
12 | 
--------------------------------------------------------------------------------
/hessian.sh:
--------------------------------------------------------------------------------
1 | # -g: GPU device for reduce
2 | # -n: Number of GPUs (Parallelism)
3 | # -d: Dataset
4 | # -m: Model
5 | # -t: Number of trials (used as n_output when computing the FIM)
6 | # -k: Tokenizer
7 | # -o: Output file
8 | 
9 | python3 kfac_launcher.py \
10 |     -g cuda:0 \
11 |     -n 8 \
12 |     -d demo.json \
13 |     -m ./checkpoint \
14 |     -t 1 \
15 |     -k bigscience/bloom-560m \
16 |     -o hessian.pkl \
17 | 
18 | 
--------------------------------------------------------------------------------
/nngeometry/object/__init__.py:
--------------------------------------------------------------------------------
1 | from .pspace import (PMatDense, PMatBlockDiag, PMatDiag,
2 |                      PMatLowRank, PMatImplicit,
3 |                      PMatKFAC, PMatEKFAC, PMatQuasiDiag, PMatAbstract)
4 | from .vector import (PVector, FVector)
5 | # from .lm_vector import (LMPVector)
6 | from .fspace import (FMatDense,)
7 | from .map import (PushForwardDense, PushForwardImplicit,
8 |                   PullBackDense)
9 | 
--------------------------------------------------------------------------------
/dataset/utils/multipreprocess_utils.py:
--------------------------------------------------------------------------------
1 | from itertools import repeat
2 | 
3 | 
4 | def starmap_with_kwargs(pool, fn, args_iter=None, kwargs_iter=None):
5 |     args_for_starmap = zip(repeat(fn), args_iter if args_iter is not None else repeat(None), kwargs_iter)  # repeat(None) keeps the zip alive when only kwargs are given
6 |     return pool.starmap(apply_args_and_kwargs, args_for_starmap)
7 | 
8 | 
9 | def apply_args_and_kwargs(fn, args, kwargs):
10 |     if args is None:
11 |         return fn(**kwargs)
12 |     else:
13 |         return fn(*args, **kwargs)
--------------------------------------------------------------------------------
/if_score.sh:
--------------------------------------------------------------------------------
1 | # -k: Hessian file
2 | # -n: Number of GPUs (Parallelism)
3 | # -m: Model 4 | # -t: Tokenizer 5 | # -o: Output file 6 | # -d: Candidate dataset 7 | # -q: Seed dataset 8 | # -bq: Batch size of seed dataset 9 | # -lmd: Lambda in Hessian inverse 10 | 11 | 12 | 13 | python3 query_loss_launcher.py \ 14 | -n 8 -k hessian.pkl \ 15 | -m ./checkpoint \ 16 | -t bigscience/bloom-560m \ 17 | -o score_results.json \ 18 | -d demo.json \ 19 | -q demo.json \ 20 | -bq 2 -lmd 0.5 --full-score 1 -------------------------------------------------------------------------------- /dataset/prompt_maker/base_prompt_maker.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | class BasePromptMaker(): 4 | def __init__(self, *args, **kargs): 5 | pass 6 | 7 | def get_input(self, data_point: Dict[str, str], **kargs) -> str: 8 | raise NotImplementedError() 9 | 10 | def get_target(self, data_point: Dict[str, str], **kargs) -> str: 11 | raise NotImplementedError() 12 | 13 | def get_full(self, data_point: Dict[str, str], **kargs) -> str: 14 | raise NotImplementedError() 15 | -------------------------------------------------------------------------------- /dataset/prompt_maker/translation_prompt.py: -------------------------------------------------------------------------------- 1 | from .base_prompt_maker import BasePromptMaker 2 | from typing import Dict, List 3 | import random 4 | 5 | class TranslationPromptMaker(BasePromptMaker): 6 | def __init__(self, data_point): 7 | self.data_point = data_point 8 | 9 | def get_input(self) -> str: 10 | prompts = self.data_point['translation'] 11 | if len(prompts) > 1: 12 | prompt = random.choice(prompts) 13 | elif len(prompts) == 1: 14 | prompt = prompts[0] 15 | else: 16 | raise Exception 17 | return prompt 18 | 19 | -------------------------------------------------------------------------------- /dataset/prompt_maker/custom_prompt_maker.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | from .base_prompt_maker import BasePromptMaker 4 | import random 5 | 6 | 7 | class PromptMaker(BasePromptMaker): 8 | 9 | def get_input(self, data_point, **kargs) -> str: 10 | return f''' 11 | {data_point['system_prompt']} 12 | 13 | {data_point['question']} 14 | 15 | ''' 16 | 17 | def get_target(self, data_point, **kargs) -> str: 18 | target = data_point["response"] 19 | return target 20 | 21 | def get_full(self, data_point, **kargs) -> str: 22 | text = self.get_input(data_point) + self.get_target(data_point) 23 | return text -------------------------------------------------------------------------------- /dataset/prompt_maker/openorca_prompt_maker.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | from .base_prompt_maker import BasePromptMaker 4 | import random 5 | 6 | 7 | class PromptMaker(BasePromptMaker): 8 | 9 | def get_input(self, data_point, **kargs) -> str: 10 | return f''' 11 | {data_point['system_prompt']} 12 | 13 | {data_point['question']} 14 | 15 | ''' 16 | 17 | def get_target(self, data_point, **kargs) -> str: 18 | target = data_point["response"] 19 | return target 20 | 21 | def get_full(self, data_point, **kargs) -> str: 22 | text = self.get_input(data_point) + self.get_target(data_point) 23 | return text -------------------------------------------------------------------------------- /dataset/prompt_maker/translate_prompt_maker.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 
3 | from .base_prompt_maker import BasePromptMaker 4 | import random 5 | 6 | 7 | class PromptMaker(BasePromptMaker): 8 | 9 | def get_input(self, data_point, **kargs) -> str: 10 | if data_point['trg_lang']: 11 | trg_lang=data_point['trg_lang'] 12 | res = f"""Translate the following text into {trg_lang} 13 | 14 | Text: 15 | \"{data_point["src_text"]}\" 16 | 17 | """ 18 | # print(res) 19 | # res='test' 20 | return res 21 | 22 | def get_target(self, data_point, **kargs) -> str: 23 | # target = 'test' 24 | target = data_point["trg_text"] 25 | return target 26 | 27 | def get_full(self, data_point, **kargs) -> str: 28 | text = self.get_input(data_point) + self.get_target(data_point) 29 | return text -------------------------------------------------------------------------------- /dataset/utils/deepspeed_utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | deepspeed_filedict = { 4 | "ds_fp16_zero3_offload": "/opt/tiger/llama/finetune/alpaca-lora/codes/config/ds_fp16_zero3_offload.json", 5 | "ds_bf16_zero3_offload": "/opt/tiger/llama/finetune/alpaca-lora/codes/config/ds_bf16_zero3_offload.json", 6 | "ds_int8_zero3_offload": "/opt/tiger/llama/finetune/alpaca-lora/codes/config/ds_int8_zero3_offload.json", 7 | "ds_fp16_zero2": "/opt/tiger/llama/finetune/alpaca-lora/codes/config/ds_fp16_zero2.json", 8 | "ds_int8_zero2": "/opt/tiger/llama/finetune/alpaca-lora/codes/config/ds_int8_zero2.json", 9 | } 10 | 11 | 12 | def get_deepspeed_config_or_file(deepspeed_file_or_dict: str): 13 | deepspeed_config = deepspeed_filedict.get(deepspeed_file_or_dict, deepspeed_file_or_dict) 14 | logging.info(f"load deepspeed file from {deepspeed_config}") 15 | return deepspeed_config -------------------------------------------------------------------------------- /dataset/utils/print_utils.py: -------------------------------------------------------------------------------- 1 | from deepspeed.utils import logger, log_dist, instrument_w_nvtx 2 | import torch 3 | 4 | 5 | def print_rank_0(message): 6 | """If distributed is initialized, print only on rank 0.""" 7 | log_dist(message, ranks=[0]) 8 | 9 | 10 | def print_model_parameters(model): 11 | total_trainable_param, total_nontrainable_param = 0, 0 12 | for name, param in model.named_parameters(): 13 | if param.requires_grad: 14 | total_trainable_param += param.numel() 15 | print_rank_0(f"{name},\t{param.data.shape}") 16 | else: 17 | total_nontrainable_param += param.numel() 18 | print_rank_0(f"total_nontrainable_param = {total_nontrainable_param}") 19 | print_rank_0(f"total_trainable_param = {total_trainable_param}") 20 | 21 | 22 | def print_model(model): 23 | print_rank_0(f"\nmodel is = \n\n{model}") 24 | -------------------------------------------------------------------------------- /dataset/prompt_maker/scienceqa_prompt_maker.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | from .base_prompt_maker import BasePromptMaker 4 | import random 5 | 6 | 7 | class PromptMaker(BasePromptMaker): 8 | 9 | def get_input(self, data_point, **kargs) -> str: 10 | choise='' 11 | if data_point['choices']: 12 | choise=data_point['choices'] 13 | choise='\n'.join([c+'.' for c in choise]) 14 | res = f"""Below is an instruction that describes a task. Write a response that appropriately completes the request. 
15 | 16 | ### Instruction: 17 | {data_point["question"]} 18 | 19 | {choise} 20 | 21 | ### Response: 22 | """ 23 | # print(res) 24 | # res='test' 25 | return res 26 | 27 | def get_target(self, data_point, **kargs) -> str: 28 | # target = 'test' 29 | target = data_point["solution"] 30 | return target 31 | 32 | def get_full(self, data_point, **kargs) -> str: 33 | text = self.get_input(data_point) + self.get_target(data_point) 34 | return text -------------------------------------------------------------------------------- /dataset/prompt_maker/custom_prompt_maker_inference.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | from .base_prompt_maker import BasePromptMaker 4 | import random 5 | 6 | 7 | class PromptMaker(BasePromptMaker): 8 | 9 | def get_input(self, data_point, **kargs) -> str: 10 | choise='' 11 | if data_point['choices']: 12 | choise=data_point['choices'] 13 | choise='\n'.join([c+'.' for c in choise]) 14 | res = f"""Below is an instruction that describes a task. Write a response that appropriately completes the request. 15 | 16 | ### Instruction: 17 | {data_point["question"]} 18 | 19 | {choise} 20 | 21 | ### Response: 22 | """ 23 | # print(res) 24 | # res='test' 25 | return res 26 | 27 | def get_target(self, data_point, **kargs) -> str: 28 | # target = 'test' 29 | target = data_point["solution"] 30 | return target 31 | 32 | def get_full(self, data_point, **kargs) -> str: 33 | text = self.get_input(data_point) + self.get_target(data_point) 34 | return text -------------------------------------------------------------------------------- /dataset/prompt_maker/alpaca_prompt_maker.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | from .base_prompt_maker import BasePromptMaker 4 | 5 | class PromptMaker(BasePromptMaker): 6 | 7 | def get_input(self, data_point, **kargs) -> str: 8 | if data_point["input"]: 9 | return f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request. 10 | 11 | ### Instruction: 12 | {data_point["instruction"]} 13 | 14 | ### Input: 15 | {data_point["input"]} 16 | 17 | ### Response: 18 | """ 19 | else: 20 | return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request. 
21 | 22 | ### Instruction: 23 | {data_point["instruction"]} 24 | 25 | ### Response: 26 | """ 27 | 28 | def get_target(self, data_point, **kargs) -> str: 29 | target = data_point["output"] 30 | return target 31 | 32 | def get_full(self, data_point, **kargs) -> str: 33 | text = self.get_input(data_point) + self.get_target(data_point) 34 | return text -------------------------------------------------------------------------------- /dataset/utils/io_utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | from typing import Union, List, Iterable 3 | from glob import glob 4 | from itertools import chain 5 | 6 | 7 | def load_json( 8 | filenames: Union[str, List[str]], 9 | # return_iter: bool = False 10 | ) -> Union[Iterable, List[str]]: 11 | if isinstance(filenames, str): 12 | return json.load(open(filenames, 'r')) 13 | else: 14 | return list(chain(*[json.load(open(filename, 'r')) for filename in filenames])) 15 | 16 | 17 | def grob_paths( 18 | paths: str 19 | ) -> List[str]: 20 | if paths.startswith("\"") and paths.endswith("\""): 21 | paths = paths[1:-1] 22 | if isinstance(paths, List): 23 | pass 24 | elif isinstance(paths, str): 25 | paths = paths.split(",") 26 | else: 27 | raise ValueError(f"paths should be str or list of str, paths = {paths}") 28 | 29 | gather_paths = [] 30 | for p in paths: 31 | gather_paths.extend(glob(p)) 32 | return gather_paths 33 | 34 | 35 | if __name__ == '__main__': 36 | print(grob_paths("./*")) 37 | print(load_json(grob_paths("../../data/*.json"))) -------------------------------------------------------------------------------- /dataset/prompt_maker/contrastive_translate_prompt_maker.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | from .base_prompt_maker import BasePromptMaker 4 | import random 5 | import json 6 | 7 | class PromptMaker(BasePromptMaker): 8 | 9 | def get_input(self, data_point, **kargs) -> str: 10 | if data_point['trg_lang']: 11 | trg_lang=data_point['trg_lang'] 12 | res = f"""Translate the following text into {trg_lang} 13 | 14 | Text: 15 | \"{data_point["src_text"]}\" 16 | 17 | """ 18 | # print(res) 19 | # res='test' 20 | return res 21 | 22 | def get_target(self, data_point, **kargs) -> str: 23 | # target = 'test' 24 | target = data_point["trg_text"] 25 | return target 26 | 27 | def get_constrastive_target(self, data_point, path_contrastive_label, **kargs) -> str: 28 | # target = 'test' 29 | target = data_point["trg_text"] 30 | # print(path_contrastive_label) 31 | c_text = json.loads(open(path_contrastive_label).read()) 32 | c_target_text=None 33 | for c in c_text: 34 | if c['src_text'] == data_point['src_text']: 35 | c_target_text = c['trg_text'] 36 | break 37 | if c_target_text is not None: 38 | return c_target_text 39 | else: 40 | raise ValueError(f"Cannot match src_text: {data_point['src_text']}") 41 | 42 | def get_full(self, data_point, **kargs) -> str: 43 | text = self.get_input(data_point) + self.get_target(data_point) 44 | return text -------------------------------------------------------------------------------- /nngeometry/llama_layercollection.py: -------------------------------------------------------------------------------- 1 | from abc import ABC 2 | from collections import OrderedDict 3 | from functools import reduce 4 | import operator 5 | from .layercollection import LayerCollection 6 | from transformers import LlamaModel, LlamaForCausalLM, LlamaTokenizer, LlamaPreTrainedModel 7 | 8 | class 
LLamaLayerCollection(LayerCollection):
9 |     def __init__(self, layers=None):
10 |         super().__init__(layers)
11 | 
12 |     # override
13 |     def from_model(model, ignore_unsupported_layers=False, ignore_layers=[]):
14 |         # print('test')
15 |         lc = LayerCollection()
16 |         for layer, mod in model.named_modules():
17 |             # print(layer, type(layer))
18 |             flag=False
19 |             for l in ignore_layers:
20 |                 if l in layer:
21 |                     flag=True
22 |                     break
23 |             if flag: continue
24 |             mod_class = mod.__class__.__name__
25 |             if mod_class in LayerCollection._known_modules:
26 |                 lc.add_layer('%s.%s' % (layer, str(mod)),
27 |                              LayerCollection._module_to_layer(mod))
28 |             elif not ignore_unsupported_layers:
29 |                 if len(list(mod.children())) == 0 and len(list(mod.parameters())) > 0:
30 |                     raise Exception('I do not know what to do with layer ' + str(mod))
31 | 
32 |         return lc
33 | 
34 | if __name__=='__main__':
35 |     model = LlamaForCausalLM.from_pretrained("/home/yarn/Influence/model/llama_2_7b/llama-2-7b-hf")
36 |     lc=LLamaLayerCollection.from_model(model, True)
37 |     print(lc.layers)
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # G-DIG: Towards Gradient-based DIverse and hiGh-quality Instruction Data Selection for Machine Translation
2 | 
3 | In this implementation, we use the model bigscience/bloom-560m and the demo dataset demo.json for demonstration. Our proposed high-quality data selection method in G-DIG consists of three steps:
4 | 
5 | 1. Fine-tune the target model (any Hugging Face-compatible model) on the candidate data $\mathcal{D}_{raw}$.
6 | 
7 | ```
8 | Save the checkpoint to ./checkpoint; it will be used to calculate the Hessian matrix and to compute scores.
9 | ```
10 | 
11 | 2. Compute the Hessian matrix.
12 | 
13 | ```
14 | ./hessian.sh
15 | 
16 | # In this script, demo.json should be replaced by the training data you used to finetune the LLM.
17 | ```
18 | 
19 | 3. Run the influence function to compute data scores.
20 | 
21 | ```
22 | ./if_score.sh
23 | 
24 | # In this script, -d demo.json corresponds to the candidate dataset and -q demo.json corresponds to the seed dataset.
25 | ```
26 | 
27 | Finally, use the data scores according to Equation (4) in the paper to select high-quality data; a minimal selection sketch is given below, after the Data section.
28 | 
29 | 
30 | ## Data
31 | 
32 | We release our selected data (EN->ZH and DE->EN):
33 | 
34 | 1. DE->EN training data of sizes from 1k to 64k, available [here](https://drive.google.com/file/d/1aklM0Q7BV14tVZF8isQdqe_PbpEwHEv9/view)
35 | 2. ZH->EN training data: to be released.
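36 | 
37 | As referenced in the pipeline above, here is a minimal sketch of the final selection step. The score file format is taken from this repo (query_loss_launcher.py writes a JSON list sorted in ascending order of the score field); the budget k, the output name selected_data.json, and the rule "lower score = more beneficial candidate" are illustrative assumptions, so apply the exact criterion from Equation (4) in the paper.
38 | 
39 | ```
40 | # Hypothetical post-processing of score_results.json (written by if_score.sh).
41 | # Assumes lower scores mark more beneficial candidates; adjust to Equation (4).
42 | import json
43 | 
44 | k = 1000  # illustrative selection budget
45 | 
46 | with open('score_results.json') as f:
47 |     scored = json.load(f)  # list of records, already sorted ascending by 'score'
48 | 
49 | selected = scored[:k]  # keep the k candidates with the lowest scores
50 | 
51 | with open('selected_data.json', 'w') as f:
52 |     json.dump(selected, f, indent=4, ensure_ascii=False)
53 | ```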
54 | 
55 | ## Citation
56 | 
57 | If this repo was useful to you, please consider citing:
58 | 
59 | ```
60 | @article{pan2024g,
61 | title={G-DIG: Towards Gradient-based DIverse and hiGh-quality Instruction Data Selection for Machine Translation},
62 | author={Pan, Xingyuan and Huang, Luyang and Kang, Liyan and Liu, Zhicheng and Lu, Yu and Cheng, Shanbo},
63 | journal={arXiv preprint arXiv:2405.12915},
64 | year={2024}
65 | }
66 | ```
67 | 
--------------------------------------------------------------------------------
/nngeometry/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from nngeometry.object.vector import PVector
4 | 
5 | 
6 | def display_correl(M, axis):
7 | 
8 |     M = M.get_dense_tensor()
9 |     diag = torch.diag(M)
10 |     dM = (diag + diag.mean() / 100) **.5
11 |     correl = torch.abs(M) / dM.unsqueeze(0) / dM.unsqueeze(1)
12 | 
13 |     axis.imshow(correl.cpu())
14 | 
15 | 
16 | def grad(output, vec, *args, **kwargs):
17 |     """
18 |     Computes the gradient of `output` with respect to the `PVector` `vec`
19 | 
20 |     .. warning:: This function only works when `vec` has been
21 |         created from leaf nodes in the graph (e.g. model parameters)
22 | 
23 |     :param output: The scalar quantity to be differentiated
24 |     :param vec: a `PVector`
25 |     :return: a `PVector` of gradients of `output` w.r.t `vec`
26 |     """
27 |     if vec.dict_repr is not None:
28 |         # map all parameters to a list
29 |         params = []
30 |         pos = []
31 |         lengths = []
32 |         current_pos = 0
33 |         for k in vec.dict_repr.keys():
34 |             p = vec.dict_repr[k]
35 |             params += list(p)
36 |             pos.append(current_pos)
37 |             lengths.append(len(p))
38 |             current_pos = current_pos + len(p)
39 | 
40 |         grad_list = torch.autograd.grad(output, params, *args, **kwargs)
41 |         dict_repr_grad = dict()
42 | 
43 |         for k, p, l in zip(vec.dict_repr.keys(), pos, lengths):
44 |             if l == 1:
45 |                 dict_repr_grad[k] = (grad_list[p],)
46 |             elif l == 2:
47 |                 dict_repr_grad[k] = (grad_list[p], grad_list[p+1])
48 | 
49 |         return PVector(vec.layer_collection,
50 |                        dict_repr=dict_repr_grad)
51 |     else:
52 |         raise RuntimeError('grad only works when the vector is created ' +
53 |                            'from leaf nodes in the computation graph')
54 | 
--------------------------------------------------------------------------------
/nngeometry/object/map.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from abc import ABC, abstractmethod
3 | from .vector import FVector, PVector
4 | 
5 | 
6 | class AbstractPushForward(ABC):
7 | 
8 |     @abstractmethod
9 |     def __init__(self, generator):
10 |         return NotImplementedError
11 | 
12 | 
13 | class PushForwardDense(AbstractPushForward):
14 |     def __init__(self, generator, data=None, examples=None):
15 |         self.generator = generator
16 |         if data is not None:
17 |             self.data = data
18 |         else:
19 |             self.data = generator.get_jacobian(examples)
20 | 
21 |     def get_dense_tensor(self):
22 |         return self.data
23 | 
24 |     def mv(self, v):
25 |         v_flat = torch.mv(self.data.view(-1, self.data.size(-1)),
26 |                           v.get_flat_representation())
27 |         v_flat = v_flat.view(self.data.size(0), self.data.size(1))
28 |         return FVector(vector_repr=v_flat)
29 | 
30 | 
31 | class PushForwardImplicit(AbstractPushForward):
32 |     def __init__(self, generator, data=None, examples=None):
33 |         self.generator = generator
34 |         self.examples = examples
35 |         assert data is None
36 | 
37 |     def mv(self, v):
38 |         return self.generator.implicit_Jv(v, self.examples)
39 | 
40 | 
41 | class
PullBackAbstract(ABC): 42 | 43 | @abstractmethod 44 | def __init__(self, generator): 45 | return NotImplementedError 46 | 47 | 48 | class PullBackDense(PullBackAbstract): 49 | def __init__(self, generator, data=None, examples=None): 50 | self.generator = generator 51 | if data is not None: 52 | self.data = data 53 | else: 54 | self.data = generator.get_jacobian(examples) 55 | 56 | def get_dense_tensor(self): 57 | return self.data 58 | 59 | def mv(self, v): 60 | v_flat = torch.mv(self.data.view(-1, self.data.size(-1)).t(), 61 | v.get_flat_representation().view(-1)) 62 | return PVector(self.generator.layer_collection, vector_repr=v_flat) 63 | -------------------------------------------------------------------------------- /dataset/debug.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 6, 6 | "metadata": { 7 | "tags": [] 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "import pandas as pd\n", 12 | "import json" 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": null, 18 | "metadata": { 19 | "tags": [] 20 | }, 21 | "outputs": [], 22 | "source": [] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": 7, 27 | "metadata": { 28 | "tags": [] 29 | }, 30 | "outputs": [], 31 | "source": [ 32 | "df = pd.read_parquet('/home/yarn/Influence/datasets/raw/alpaca.parquet')\n", 33 | "records = json.loads(df.to_json(orient = \"records\"))" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": 11, 39 | "metadata": { 40 | "tags": [] 41 | }, 42 | "outputs": [ 43 | { 44 | "data": { 45 | "text/plain": [ 46 | "{'instruction': 'Give three tips for staying healthy.',\n", 47 | " 'input': '',\n", 48 | " 'output': '1.Eat a balanced diet and make sure to include plenty of fruits and vegetables. \\n2. Exercise regularly to keep your body active and strong. \\n3. Get enough sleep and maintain a consistent sleep schedule.',\n", 49 | " 'text': 'Below is an instruction that describes a task. Write a response that appropriately completes the request.\\n\\n### Instruction:\\nGive three tips for staying healthy.\\n\\n### Response:\\n1.Eat a balanced diet and make sure to include plenty of fruits and vegetables. \\n2. Exercise regularly to keep your body active and strong. \\n3. 
Get enough sleep and maintain a consistent sleep schedule.'}" 50 | ] 51 | }, 52 | "execution_count": 11, 53 | "metadata": {}, 54 | "output_type": "execute_result" 55 | } 56 | ], 57 | "source": [ 58 | "records[0]" 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": null, 64 | "metadata": {}, 65 | "outputs": [], 66 | "source": [] 67 | } 68 | ], 69 | "metadata": { 70 | "kernelspec": { 71 | "name": "1526533" 72 | }, 73 | "language_info": { 74 | "codemirror_mode": { 75 | "name": "ipython", 76 | "version": 3 77 | }, 78 | "file_extension": ".py", 79 | "mimetype": "text/x-python", 80 | "name": "python", 81 | "nbconvert_exporter": "python", 82 | "pygments_lexer": "ipython3", 83 | "version": "3.7.3" 84 | } 85 | }, 86 | "nbformat": 4, 87 | "nbformat_minor": 4 88 | } 89 | -------------------------------------------------------------------------------- /nngeometry/object/fspace.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from abc import ABC, abstractmethod 3 | from .vector import FVector, PVector 4 | 5 | 6 | class FMatAbstract(ABC): 7 | 8 | @abstractmethod 9 | def __init__(self, generator): 10 | return NotImplementedError 11 | 12 | 13 | class FMatDense(FMatAbstract): 14 | def __init__(self, generator, data=None, examples=None): 15 | self.generator = generator 16 | if data is not None: 17 | self.data = data 18 | else: 19 | self.data = generator.get_gram_matrix(examples) 20 | 21 | def compute_eigendecomposition(self, impl='eigh'): 22 | s = self.data.size() 23 | M = self.data.view(s[0] * s[1], s[2] * s[3]) 24 | if impl == 'eigh': 25 | self.evals, self.evecs = torch.linalg.eigh(M) 26 | elif impl == 'svd': 27 | _, self.evals, self.evecs = torch.svd(M, some=False) 28 | else: 29 | raise NotImplementedError 30 | 31 | def mv(self, v): 32 | # TODO: test 33 | v_flat = torch.mv(self.data, v.get_flat_representation()) 34 | return FVector(vector_repr=v_flat) 35 | 36 | def vTMv(self, v): 37 | v_flat = v.get_flat_representation().view(-1) 38 | sd = self.data.size() 39 | return torch.dot(v_flat, 40 | torch.mv(self.data.view(sd[0]*sd[1], sd[2]*sd[3]), 41 | v_flat)) 42 | 43 | def frobenius_norm(self): 44 | return torch.norm(self.data) 45 | 46 | def project_to_diag(self, v): 47 | # TODO: test 48 | return PVector(model=v.model, 49 | vector_repr=torch.mv(self.evecs.t(), 50 | v.get_flat_representation())) 51 | 52 | def project_from_diag(self, v): 53 | # TODO: test 54 | return PVector(model=v.model, 55 | vector_repr=torch.mv(self.evecs, 56 | v.get_flat_representation())) 57 | 58 | def get_eigendecomposition(self): 59 | # TODO: test 60 | return self.evals, self.evecs 61 | 62 | def size(self, *args): 63 | # TODO: test 64 | return self.data.size(*args) 65 | 66 | def trace(self): 67 | # TODO: test 68 | return torch.trace(self.data) 69 | 70 | def get_dense_tensor(self): 71 | return self.data 72 | 73 | def __add__(self, other): 74 | # TODO: test 75 | sum_data = self.data + other.data 76 | return FMatDense(generator=self.generator, 77 | data=sum_data) 78 | 79 | def __sub__(self, other): 80 | # TODO: test 81 | sub_data = self.data - other.data 82 | return FMatDense(generator=self.generator, 83 | data=sub_data) 84 | -------------------------------------------------------------------------------- /dataset/data/dataset/prompt_dataset.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union 2 | from collections.abc import Mapping 3 | from 
torch.utils.data import Dataset 4 | import torch 5 | import logging 6 | import json 7 | from copy import deepcopy 8 | import random 9 | from dataset.data.dataset.base_dataset import BaseDataset 10 | from dataset.utils.io_utils import load_json 11 | import os 12 | from dataset.utils.io_utils import grob_paths 13 | import torch 14 | 15 | 16 | InputDataClass = NewType("InputDataClass", Any) 17 | 18 | 19 | class DynamicPromptDataset(BaseDataset): 20 | """Dynamic prompt making dataset.""" 21 | 22 | def __init__(self, args, 23 | json_data: Union[os.PathLike, List[Dict]], 24 | static_transform: Callable = None, 25 | dynamic_transform: Callable = None, 26 | shuffle: bool = False, 27 | from_file: bool = False, 28 | ): 29 | """ 30 | Arguments: 31 | json_data (List): Path to the csv file with annotations. 32 | root_dir (string): Directory with all the images. 33 | static_transform (callable): Optional transform to be applied on a sample only once. 34 | dynamic_transform (callable, optional): Optional transform to be applied on a sample on the fly. 35 | """ 36 | if from_file: 37 | data = load_json(grob_paths(json_data)) 38 | else: 39 | data = deepcopy(json_data) 40 | 41 | if shuffle: 42 | random.shuffle(data) 43 | self.data = data 44 | 45 | if static_transform is not None: 46 | self.data = list(map(lambda t: static_transform(t), self.data)) 47 | 48 | self.dynamic_transform = dynamic_transform 49 | logging.info(f"data = {self.data}") 50 | 51 | def __len__(self): 52 | return len(self.data) 53 | 54 | def __getitem__(self, idx, to_tensor=True): 55 | if torch.is_tensor(idx): 56 | idx = idx.tolist() 57 | sample = self.data[idx] 58 | if self.dynamic_transform: 59 | sample = self.dynamic_transform(sample) 60 | 61 | if to_tensor: 62 | for k, v in sample.items(): 63 | sample[k] = torch.tensor(v) 64 | # print(sample) 65 | return sample 66 | 67 | 68 | class COAIDynamicPromptDataset(DynamicPromptDataset): 69 | def __getitem__(self, idx, to_tensor=True): 70 | sample = super().__getitem__(idx=idx, to_tensor=to_tensor) 71 | input_ids = sample["input_ids"][:-1] 72 | labels = sample["labels"][1:] 73 | attention_mask = sample["attention_mask"][1:] 74 | return {'input_ids': input_ids, 'attention_mask': attention_mask}, {'labels': labels} 75 | 76 | 77 | if __name__ == '__main__': 78 | data_file = "/opt/tiger/llama/finetune/alpaca-lora/codes/data/alpaca_data_cleaned.json" 79 | train_data = DynamicPromptDataset( 80 | json_data=data_file, 81 | dynamic_transform=lambda t: t, 82 | shuffle=True, 83 | from_file=True 84 | ) 85 | 86 | for data in train_data: 87 | print(data) -------------------------------------------------------------------------------- /dataset/utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | import logging 3 | import copy 4 | 5 | logger=logging.getLogger() 6 | 7 | 8 | def generate_and_tokenize_prompt( 9 | data_point, 10 | args = None, 11 | tokenizer = None, 12 | prompt_maker = None, 13 | use_prompt_labels = True, 14 | padding: bool = False, 15 | truncation: bool = True, 16 | verbose: bool = True, 17 | ignore_loss_idx: int = -100, 18 | ): 19 | assert prompt_maker is not None, "please provide prompt_maker" 20 | input_text = prompt_maker.get_input(data_point) 21 | target_text = prompt_maker.get_target(data_point) 22 | full_text = input_text + target_text 23 | 24 | user_prompt = tokenizer( 25 | input_text, 26 | truncation=True, 27 | max_length=args.max_length + 1, 28 | )["input_ids"][:-1] # no eos token 29 | 30 | # -------- 31 | user_prompt = 
tokenizer( 32 | input_text, 33 | truncation=True, 34 | max_length=args.max_length + 1, 35 | )["input_ids"] 36 | if user_prompt[-1]==tokenizer.eos_token_id: 37 | user_prompt=user_prompt[:-1] 38 | else: 39 | user_prompt=user_prompt[1:] 40 | # -------- 41 | len_user_prompt_tokens = len(user_prompt) 42 | len_user_prompt_tokens = min(len_user_prompt_tokens, args.max_length) 43 | 44 | full_tokens = tokenizer( 45 | full_text, 46 | truncation=truncation, 47 | max_length=args.max_length 48 | )["input_ids"] 49 | # --------- 50 | if full_tokens[-1] != tokenizer.eos_token_id: 51 | full_tokens=full_tokens+[tokenizer.eos_token_id] 52 | # --------- 53 | attention_mask = [1] * len(full_tokens) 54 | 55 | if args.use_prompt_loss: 56 | labels = copy.deepcopy(full_tokens) 57 | else: 58 | labels = [ignore_loss_idx] * len_user_prompt_tokens + full_tokens[len_user_prompt_tokens:] 59 | 60 | ## deal with padding 61 | if padding: 62 | padded_length = args.max_length - len(full_tokens) 63 | full_tokens.extend([tokenizer.pad_token_id] * padded_length) 64 | labels.extend([ignore_loss_idx] * padded_length) 65 | attention_mask = attention_mask + [0] * padded_length 66 | 67 | if verbose and (random.random() <= args.prob_data_display): 68 | logger.info(f"""### random data case: 69 | batch length = {len(full_tokens)} 70 | (P) prompt = {[input_text]} 71 | (PT) prompt_and_target = {[full_text]} 72 | (PT) tokenized = {full_tokens} 73 | (PT) attention_mask = {attention_mask} 74 | (PT) labels = {labels} 75 | """) 76 | # print(len(full_tokens)) 77 | ## deal with prompt or not (w.r.t. pretrain and instruction tuning) 78 | if use_prompt_labels: 79 | # This function masks out the labels for the input, 80 | # so that our loss is computed only on the response. 81 | return { 82 | "input_ids": full_tokens, 83 | "attention_mask": attention_mask, 84 | "labels": labels, 85 | } 86 | else: 87 | return { 88 | "input_ids": full_tokens, 89 | "attention_mask": attention_mask, 90 | } -------------------------------------------------------------------------------- /nngeometry/layers.py: -------------------------------------------------------------------------------- 1 | from torch import Tensor 2 | from torch.nn import Linear, Conv2d, Module, init 3 | from torch.nn import functional as F 4 | from torch.nn.parameter import Parameter 5 | import torch 6 | 7 | class Cosine1d(Linear): 8 | """Computes the cosine similarity between rows of the weight matrix 9 | and the incoming data 10 | """ 11 | 12 | def __init__(self, in_features: int, out_features: int, eps=1e-05) -> None: 13 | super(Cosine1d, self).__init__(in_features=in_features, 14 | out_features=out_features, 15 | bias=False) 16 | self.eps = eps 17 | 18 | def forward(self, input: Tensor) -> Tensor: 19 | norm2_w = (self.weight**2).sum(dim=1, keepdim=True) + self.eps 20 | norm2_x = (input**2).sum(dim=1, keepdim=True) + self.eps 21 | return F.linear(input / torch.sqrt(norm2_x), 22 | self.weight / torch.sqrt(norm2_w)) 23 | 24 | 25 | class WeightNorm1d(Linear): 26 | """Computes an affine mapping of the incoming data using a weight matrix 27 | with rows normalized with norm 1 28 | """ 29 | 30 | def __init__(self, in_features: int, out_features: int, eps=1e-05) -> None: 31 | super(WeightNorm1d, self).__init__(in_features=in_features, 32 | out_features=out_features, 33 | bias=False) 34 | self.eps = eps 35 | 36 | def forward(self, input: Tensor) -> Tensor: 37 | norm2 = (self.weight**2).sum(dim=1, keepdim=True) + self.eps 38 | return F.linear(input, 39 | self.weight / torch.sqrt(norm2)) 40 | 41 | 42 | 
class WeightNorm2d(Conv2d): 43 | """Computes a 2d convolution using a kernel weight matrix 44 | with rows normalized with norm 1 45 | """ 46 | 47 | def __init__(self, *args, eps=1e-05, **kwargs) -> None: 48 | assert 'bias' not in kwargs or kwargs['bias'] is False 49 | super(WeightNorm2d, self).__init__(*args, bias=False, **kwargs) 50 | self.eps = eps 51 | 52 | def forward(self, input: Tensor) -> Tensor: 53 | norm2 = (self.weight**2).sum(dim=(1, 2, 3), keepdim=True) + self.eps 54 | return self._conv_forward(input, self.weight / torch.sqrt(norm2), 55 | None) 56 | 57 | 58 | class Affine1d(Module): 59 | """Computes the transformation out = weight * input + bias 60 | where * is the elementwise multiplication. This is similar to the 61 | scaling and translation given by parameters gamma and beta in batch norm 62 | 63 | """ 64 | def __init__(self, num_features: int, bias: bool = True, 65 | device=None, dtype=None) -> None: 66 | factory_kwargs = {'device': device, 'dtype': dtype} 67 | super(Affine1d, self).__init__() 68 | self.num_features = num_features 69 | self.weight = Parameter(torch.empty(num_features, **factory_kwargs)) 70 | if bias: 71 | self.bias = Parameter(torch.empty(num_features, **factory_kwargs)) 72 | else: 73 | self.register_parameter('bias', None) 74 | self.reset_parameters() 75 | 76 | def reset_parameters(self) -> None: 77 | init.ones_(self.weight) 78 | if self.bias is not None: 79 | init.zeros_(self.bias) 80 | 81 | def forward(self, input: Tensor) -> Tensor: 82 | if self.bias is not None: 83 | return input * self.weight.unsqueeze(0) + self.bias 84 | else: 85 | return input * self.weight.unsqueeze(0) 86 | 87 | def extra_repr(self) -> str: 88 | return 'num_features={}, bias={}'.format( 89 | self.num_features, self.bias is not None 90 | ) -------------------------------------------------------------------------------- /dataset/data/dataset/stream_prompt_dataset.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union 2 | from torch.utils.data import Dataset, IterableDataset 3 | from datasets import load_dataset 4 | import logging 5 | import os 6 | from dataset.utils.io_utils import grob_paths 7 | 8 | 9 | class StreamDynamicPromptDataset(IterableDataset): 10 | """ 11 | stream for large scale data 12 | """ 13 | 14 | def __init__(self, args, 15 | json_data: Union[os.PathLike, List[Dict]], 16 | static_transform: Callable = None, 17 | dynamic_transform: Callable = None, 18 | shuffle: bool = True, 19 | from_file: bool = True, 20 | ): 21 | """ 22 | json_data: json filenames splited by ",": i.e., /opt/tiger/json_data1/*,/opt/tiger/json_data2/* 23 | For shard/iterable dataset, static transformation is not supported. 
24 | """ 25 | assert from_file == True, "for StreamDynamicPromptDataset, the json data should be from file" 26 | logging.warning(f"static_transform is deprecated for our StreamDynamicPromptDataset.") 27 | 28 | json_filenames = json_data 29 | if from_file: 30 | self.data_files = grob_paths(json_filenames) 31 | else: 32 | self.data_files = json_filenames 33 | 34 | self.dataiter = load_dataset( 35 | "json", 36 | data_files=self.data_files, 37 | split="train", 38 | streaming=True, 39 | keep_in_memory=True 40 | ) 41 | 42 | if shuffle: 43 | self.dataiter = self.dataiter.shuffle(buffer_size=vars(args).get("buffer_size", -1), seed=args.seed) 44 | 45 | if dynamic_transform: 46 | self.dataiter = self.dataiter.map(lambda t: dynamic_transform(t)) 47 | 48 | logging.info(f"loading from {self.n_files} file: {self.data_files}") 49 | 50 | @property 51 | def n_files(self): 52 | return len(self.data_files) 53 | 54 | def __iter__(self): 55 | return iter(self.dataiter) 56 | 57 | 58 | class COAIStreamDynamicPromptDataset(StreamDynamicPromptDataset): 59 | 60 | def __init__(self, args, 61 | json_data: Union[os.PathLike, List[Dict]], 62 | static_transform: Callable = None, 63 | dynamic_transform: Callable = None, 64 | shuffle: bool = True, 65 | from_file: bool = True, 66 | ): 67 | super().__init__( 68 | args=args, 69 | json_data=json_data, 70 | static_transform=static_transform, 71 | dynamic_transform=None, 72 | shuffle=shuffle, 73 | from_file=from_file 74 | ) 75 | if dynamic_transform: 76 | def coai_transform(t): 77 | sample = dynamic_transform(t) 78 | input_ids = sample["input_ids"][:-1] 79 | labels = sample["labels"][1:] 80 | attention_mask = sample["attention_mask"][1:] 81 | return {'input_ids': input_ids, 'attention_mask': attention_mask}, {'labels': labels} 82 | self.dataiter = self.dataiter.map(lambda t: coai_transform(t)) 83 | 84 | 85 | if __name__ == '__main__': 86 | data_file = "/opt/tiger/llama/finetune/alpaca-lora/codes/data/alpaca_data_cleaned.json" 87 | data_files = [data_file] 88 | for i in range(1000): 89 | os.system(f"cp {data_file} /opt/tiger/json_data/{i}.json") 90 | data_files.append(f"/opt/tiger/json_data/{i}.json") 91 | 92 | shuffled_iterable_dataset = StreamDynamicPromptDataset( 93 | data_files, 94 | shuffle=True, 95 | buffer_size=100, 96 | dynamic_transform=lambda t: {"instruction": "1" + t["instruction"]}, 97 | ) 98 | for i, example in enumerate(shuffled_iterable_dataset): # as fast as before 99 | print(example) 100 | -------------------------------------------------------------------------------- /kfac_mapper.py: -------------------------------------------------------------------------------- 1 | import transformers 2 | import torch 3 | from nngeometry.object import PMatKFAC 4 | import nngeometry 5 | from nngeometry import lm_metrics, llama_layercollection 6 | from torch.utils.data import Subset, DataLoader 7 | from transformers import AutoTokenizer, AutoModelForCausalLM 8 | import argparse 9 | import pickle 10 | from dataset.data.json_data import get_json_train_valid_data, generate_and_tokenize_prompt 11 | from functools import partial 12 | from dataset.prompt_maker.contrastive_translate_prompt_maker import PromptMaker 13 | 14 | # bloom_ignore_lc=['atte', 'lm_head', '0', '1', '2', '3', '21', '22', '23'] 15 | bloom_ignore_lc=['atte', 'lm_head', 'dense_4h_to_h'] 16 | # llama_ignore_lc=['attn', 'lm_head', 'up', 'gate', '0', '1', '.2.', '3', '4', '9', '21', '22', '23'] 17 | # baichuan_ignore_lc=['attn', 'lm_head', 'up', 'gate', '0', '1', '.2.', '3', '4', '9', '21', '22', '23'] 18 | 
baichuan_ignore_lc=['.0.', '.1.', '.2.', '.4.', '.5.', '.7.', '.8.', '.10.', '.11.', '.13.', '.14.', '.16.', '.17.', '.19.', '.20.', '.22.', '.23.', '.25.', '.26.', '.28.', '.29.', '30', 'attn', 'up', 'gate', 'lm_head'] 19 | llama_ignore_lc=['.0.', '.1.', '.2.', '.4.', '.5.', '.7.', '.8.', '.10.', '.11.', '.13.', '.14.', '.16.', '.17.', '.19.', '.20.', '.22.', '.23.', '.25.', '.26.', '.28.', '.29.', '30', 'attn', 'up', 'gate', 'lm_head'] 20 | 21 | # Specify here which layer you do not want to compute the Fisher information matrix 22 | CUR_LC=bloom_ignore_lc 23 | 24 | def main(): 25 | parser = argparse.ArgumentParser( 26 | prog='ProgramName', 27 | description='What the program does', 28 | epilog='Text at the bottom of help') 29 | 30 | parser.add_argument('-v', '--device') # device e.g. 'cuda:1' 31 | parser.add_argument('-idstart', '--indexestart', type=int) 32 | parser.add_argument('-idend', '--indexesend', type=int) 33 | parser.add_argument('-d', '--data_path') 34 | parser.add_argument('-o', '--output') 35 | parser.add_argument('-m', '--model') 36 | parser.add_argument('-t', '--trials', type=int, default=100) 37 | parser.add_argument('-k', '--tokenizer', type=str) 38 | 39 | 40 | args = parser.parse_args() 41 | 42 | tokenizer = AutoTokenizer.from_pretrained(args.tokenizer, trust_remote_code=True) 43 | model = AutoModelForCausalLM.from_pretrained(args.model, trust_remote_code=True) 44 | 45 | if CUR_LC==baichuan_ignore_lc: 46 | tokenizer.pad_token = tokenizer.eos_token 47 | 48 | device=torch.device(args.device) 49 | model.to(device) 50 | 51 | class tmpargs: 52 | max_length=256 53 | use_prompt_loss=False 54 | prob_data_display=0.1 55 | data_path=args.data_path 56 | valid_data_path=None 57 | use_large_data=False 58 | val_set_size=None 59 | micro_batch_size=1 60 | tokenizer=args.tokenizer 61 | seed=1 62 | train_data, val_data = get_json_train_valid_data( 63 | args=tmpargs, 64 | data_file=tmpargs.data_path, 65 | valid_data_file=tmpargs.valid_data_path, 66 | val_set_size=tmpargs.val_set_size, 67 | prompt_fn=partial(generate_and_tokenize_prompt, args=tmpargs, verbose=False, tokenizer=tokenizer, prompt_maker=PromptMaker(args=tmpargs)), 68 | ) 69 | trainset=train_data 70 | 71 | ids=list(range(args.indexestart, args.indexesend+1)) 72 | print(len(ids)) 73 | sub_dataset=torch.utils.data.Subset(trainset, ids) 74 | 75 | trainloader = DataLoader( 76 | sub_dataset, 77 | shuffle=True, 78 | collate_fn=transformers.DataCollatorForSeq2Seq(tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True), 79 | batch_size=1, 80 | pin_memory=False, 81 | drop_last=True 82 | ) 83 | 84 | lc=llama_layercollection.LLamaLayerCollection.from_model(model, True, CUR_LC) 85 | print(f'parameter num: {lc.numel()}') 86 | F_kfac = lm_metrics.FIM(model=model, 87 | loader=trainloader, 88 | representation=PMatKFAC, 89 | n_output=args.trials, 90 | variant='empirical_fisher', 91 | device=device, layer_collection=lc) 92 | 93 | with open(args.output, 'wb') as f: 94 | pickle.dump(F_kfac, f) 95 | 96 | if __name__=='__main__': 97 | main() 98 | -------------------------------------------------------------------------------- /kfac_launcher.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import AutoTokenizer 3 | import argparse 4 | import pickle 5 | import numpy as np 6 | import subprocess 7 | import datetime 8 | import time 9 | from dataset.data.json_data import get_json_train_valid_data, generate_and_tokenize_prompt 10 | from functools import partial 11 | from 
dataset.prompt_maker.contrastive_translate_prompt_maker import PromptMaker 12 | 13 | def to_device(kfac, device='cpu'): 14 | for key in kfac.data.keys(): 15 | kfac.data[key]=(kfac.data[key][0].to(device), kfac.data[key][1].to(device)) 16 | return kfac 17 | 18 | def main(): 19 | parser = argparse.ArgumentParser( 20 | prog='ProgramName', 21 | description='What the program does', 22 | epilog='Text at the bottom of help') 23 | 24 | parser.add_argument('-g', '--device') # device e.g. 'cuda:1' 25 | parser.add_argument('-n', '--nsubsets', type=int) # option that takes a value 26 | parser.add_argument('-d', '--data_path') # on/off flag 27 | parser.add_argument('-o', '--output') # on/off flag 28 | parser.add_argument('-m', '--model', type=str, default='bigscience/bloom-560m') # on/off flag 29 | parser.add_argument('-t', '--trials', type=int, default=1) # on/off flag 30 | parser.add_argument('-k', '--tokenizer', type=str) # on/off flag 31 | 32 | args = parser.parse_args() 33 | 34 | tokenizer = AutoTokenizer.from_pretrained(args.tokenizer, trust_remote_code=True) 35 | 36 | class tmpargs: 37 | max_length=256 38 | use_prompt_loss=False 39 | prob_data_display=0.1 40 | data_path=args.data_path 41 | valid_data_path=None 42 | use_large_data=False 43 | val_set_size=None 44 | micro_batch_size=4 45 | tokenizer=args.tokenizer 46 | seed=1 47 | train_data, val_data = get_json_train_valid_data( 48 | args=tmpargs, 49 | data_file=tmpargs.data_path, 50 | valid_data_file=tmpargs.valid_data_path, 51 | val_set_size=tmpargs.val_set_size, 52 | prompt_fn=partial(generate_and_tokenize_prompt, args=tmpargs, verbose=False, tokenizer=tokenizer, prompt_maker=PromptMaker(args=tmpargs)), 53 | ) 54 | trainset=train_data 55 | 56 | def split_array(arr, k): 57 | chunks = np.array_split(arr, k) 58 | indices = [] 59 | for chunk in chunks: 60 | start_index = arr.tolist().index(chunk[0]) 61 | end_index = start_index + len(chunk) - 1 62 | indices.append((start_index, end_index)) 63 | return indices 64 | indices=split_array(np.arange(len(trainset)), args.nsubsets) 65 | 66 | filenames=[str(datetime.datetime.timestamp(datetime.datetime.now()))+'.kfac' for i in range(args.nsubsets)] 67 | 68 | st=time.time() 69 | childlist=[] 70 | for idx, (fn, subset) in enumerate(zip(filenames, indices)): 71 | child = subprocess.Popen(['python3', 'kfac_mapper.py', '-v', f'cuda:{idx}', \ 72 | '-idstart', f'{subset[0]}', '-idend', f'{subset[1]}', \ 73 | '-o', f'{fn}', '-m', f'{args.model}', \ 74 | '-t', f'{args.trials}', '-d', f'{args.data_path}', '-k', f'{args.tokenizer}']) 75 | childlist.append(child) 76 | 77 | for child in childlist: 78 | print(f'Mapper process {child.pid} is running') 79 | 80 | while True: 81 | flag=True 82 | print('Checking status...') 83 | for child in childlist: 84 | if child.poll() is None: 85 | flag=False 86 | else: 87 | print(f'Mapper process {child.pid} finished') 88 | if flag: 89 | break 90 | time.sleep(5) 91 | 92 | elapse=time.time()-st 93 | print(f'Mapper processes all finished. 
Use {elapse}s.') 94 | print('Reducing...') 95 | 96 | 97 | device=torch.device(args.device) 98 | 99 | kfac_list=[] 100 | for fn in filenames: 101 | with open(fn, 'rb') as f: 102 | pF = pickle.load(f) 103 | kfac_list.append(pF) 104 | reduce=kfac_list[0] 105 | for kfac in kfac_list[1:]: 106 | for key in kfac.data.keys(): 107 | reduce.data[key][0].add_(kfac.data[key][0].to(device)) 108 | reduce.data[key][1].add_(kfac.data[key][1].to(device)) 109 | 110 | for key in reduce.data.keys(): 111 | reduce.data[key][0].div_(args.nsubsets) 112 | reduce.data[key][1].div_(args.nsubsets) 113 | 114 | reduce=to_device(reduce, 'cpu') 115 | with open(args.output, 'wb') as f: 116 | pickle.dump(reduce, f) 117 | 118 | print('Finish reduce') 119 | 120 | 121 | if __name__=='__main__': 122 | main() -------------------------------------------------------------------------------- /query_loss_launcher.py: -------------------------------------------------------------------------------- 1 | import json 2 | import torch 3 | from transformers import AutoTokenizer 4 | import argparse 5 | import numpy as np 6 | import subprocess 7 | import datetime 8 | import time 9 | from dataset.data.json_data import get_json_train_valid_data, generate_and_tokenize_prompt 10 | from functools import partial 11 | from dataset.prompt_maker.contrastive_translate_prompt_maker import PromptMaker 12 | 13 | def main(): 14 | parser = argparse.ArgumentParser( 15 | prog='ProgramName', 16 | description='What the program does', 17 | epilog='Text at the bottom of help') 18 | 19 | parser.add_argument('-n', '--nsubsets', type=int) 20 | parser.add_argument('-k', '--kfac') 21 | parser.add_argument('-m', '--model', type=str, default='bigscience/bloom-560m') 22 | parser.add_argument('-t', '--tokenizer') 23 | parser.add_argument('-l', '--limit', type=int, default=-1) 24 | parser.add_argument('-lq', '--limit_query', type=int, default=-1) 25 | parser.add_argument('-o', '--output', type=str, default='res.jsonl') 26 | parser.add_argument('-d', '--data_path') 27 | parser.add_argument('-q', '--query_path') 28 | parser.add_argument('-bq', '--batch_query', type=int, default=16) 29 | parser.add_argument('-lmd', '--lambdaa', type=float, default=0.5) 30 | parser.add_argument('--full-score', default=0, type=int) 31 | parser.add_argument('--ekfac', default=0, type=int) 32 | parser.add_argument('--start', default=-1, type=int) 33 | parser.add_argument('--end', default=-1, type=int) 34 | parser.add_argument('--layer', default='b', type=str) 35 | 36 | args = parser.parse_args() 37 | 38 | tokenizer = AutoTokenizer.from_pretrained(args.tokenizer, trust_remote_code=True) 39 | 40 | class tmpargs: 41 | max_length=256 42 | use_prompt_loss=False 43 | prob_data_display=0.1 44 | data_path=args.data_path 45 | valid_data_path=None 46 | use_large_data=False 47 | val_set_size=None 48 | micro_batch_size=4 49 | tokenizer=args.tokenizer 50 | seed=1 51 | train_data, val_data = get_json_train_valid_data( 52 | args=tmpargs, 53 | data_file=tmpargs.data_path, 54 | valid_data_file=tmpargs.valid_data_path, 55 | val_set_size=tmpargs.val_set_size, 56 | prompt_fn=partial(generate_and_tokenize_prompt, args=tmpargs, verbose=False, tokenizer=tokenizer, prompt_maker=PromptMaker(args=tmpargs)), 57 | ) 58 | if args.limit>0: 59 | sub_dataset=torch.utils.data.Subset(train_data, range(args.limit)) 60 | elif args.start >= 0 and args.end >= 0: 61 | ids=list(range(args.start, args.end+1)) 62 | print(len(ids)) 63 | sub_dataset=torch.utils.data.Subset(train_data, ids) 64 | else: 65 | sub_dataset=train_data 66 | 67 | 
trainset=sub_dataset 68 | 69 | def split_array(arr, k): 70 | chunks = np.array_split(arr, k) 71 | indices = [] 72 | for chunk in chunks: 73 | start_index = arr.tolist().index(chunk[0]) 74 | end_index = start_index + len(chunk) - 1 75 | indices.append((start_index, end_index)) 76 | return indices 77 | indices=split_array(np.arange(len(trainset)), args.nsubsets) 78 | 79 | filenames=[str(datetime.datetime.timestamp(datetime.datetime.now()))+'.json' for i in range(args.nsubsets)] 80 | 81 | st=time.time() 82 | childlist=[] 83 | for idx, (fn, subset) in enumerate(zip(filenames, indices)): 84 | child = subprocess.Popen(['python3', 'query_loss_mapper.py', '-v', f'cuda:{idx}', \ 85 | '-idstart', f'{subset[0]}', '-idend', f'{subset[1]}', \ 86 | '-o', f'{fn}', '-m', f'{args.model}', '-l', f'{args.limit}', \ 87 | '-t', f'{args.tokenizer}', '-d', f'{args.data_path}', \ 88 | '-bq', f'{args.batch_query}', '-q', f'{args.query_path}', \ 89 | '-k', f'{args.kfac}', '-lq', f'{args.limit_query}', '-lmd', f'{args.lambdaa}', \ 90 | '--full-score', f'{args.full_score}', '--ekfac', f'{args.ekfac}', \ 91 | '--layer', f'{args.layer}']) 92 | childlist.append(child) 93 | 94 | for child in childlist: 95 | print(f'Mapper process {child.pid} is running') 96 | 97 | while True: 98 | flag=True 99 | print('Checking status...') 100 | for child in childlist: 101 | if child.poll() is None: 102 | flag=False 103 | else: 104 | print(f'Mapper process {child.pid} finished') 105 | if flag: 106 | break 107 | time.sleep(5) 108 | 109 | elapse=time.time()-st 110 | print(f'Mapper processes all finished. Use {elapse}s.') 111 | print('Reducing...') 112 | 113 | 114 | json_list=[] 115 | for fn in filenames: 116 | json_list.extend(json.loads(open(fn).read())) 117 | 118 | res_json=sorted(json_list, key=lambda x: x['score']) 119 | open(args.output, 'w').write(json.dumps(res_json, indent=4, ensure_ascii=False)) 120 | 121 | print('Finish reduce') 122 | 123 | 124 | if __name__=='__main__': 125 | main() -------------------------------------------------------------------------------- /dataset/alpaca.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | from torch.utils.data import Dataset, DataLoader 3 | import os 4 | import json 5 | import pandas as pd 6 | from transformers import LlamaModel, LlamaForCausalLM, LlamaTokenizer, AutoTokenizer 7 | import torch 8 | from dataset.prompt_maker import alpaca_prompt_maker 9 | import transformers 10 | from dataset.data.json_data import get_json_train_valid_data, generate_and_tokenize_prompt 11 | from functools import partial 12 | 13 | 14 | class AlpacaDataset(): 15 | def __new__(self, max_length=256, use_prompt_loss=False, limit=64, 16 | prob_data_display=0.1, data_path='/mnt/bn/pxy/data/alpaca_data.json', valid_data_path=None, 17 | use_large_data=False, val_set_size=0, micro_batch_size=4, tokenizer="bigscience/bloom-560m"): 18 | class args: 19 | max_length=256 20 | use_prompt_loss=False 21 | prob_data_display=0.1 22 | data_path='/mnt/bn/pxy/data/alpaca_data.json' 23 | valid_data_path=None 24 | use_large_data=False 25 | val_set_size=None 26 | micro_batch_size=4 27 | tokenizer="YeungNLP/bloomz-396m-zh" 28 | seed=1 29 | 30 | 31 | prompt_maker=alpaca_prompt_maker.PromptMaker() 32 | tokenizer = AutoTokenizer.from_pretrained(tokenizer) 33 | train_data, val_data = get_json_train_valid_data( 34 | args=args, 35 | data_file=args.data_path, 36 | valid_data_file=args.valid_data_path, 37 | val_set_size=args.val_set_size, 38 | 
prompt_fn=partial(generate_and_tokenize_prompt, args=args, tokenizer=tokenizer, prompt_maker=prompt_maker, verbose=False, use_prompt_labels=True), 39 | ) 40 | return train_data #[:limit] 41 | 42 | 43 | class AlpacaDataset_(Dataset): 44 | def __init__(self, root = 'raw/', tokenizer=None, limit=None, embed_layer=None): 45 | super().__init__() 46 | 47 | path = os.path.join(root, 'alpaca.parquet') 48 | 49 | self.data = [] 50 | self.end_of_text_token = "<|endoftext|>" 51 | 52 | df = pd.read_parquet(path) 53 | records = json.loads(df.to_json(orient = "records")) 54 | 55 | NUM=len(records) 56 | if limit: 57 | NUM=limit 58 | class args: 59 | max_length=256 60 | use_prompt_loss=False 61 | prob_data_display=0.1 62 | 63 | for r in records[:NUM]: 64 | dp=generate_and_tokenize_prompt(r, args=args, prompt_maker=alpaca_prompt_maker.PromptMaker(), tokenizer=tokenizer, padding=True) 65 | for key in dp.keys(): dp[key]=torch.tensor(dp[key]).unsqueeze(0) 66 | self.data.append(dp) 67 | 68 | # if tokenizer: 69 | # self.tokenizer=tokenizer 70 | # tokenizer.pad_token = tokenizer.eos_token 71 | # self.data=self.tokenizer(self.data, return_tensors="pt", padding='longest')['input_ids'] 72 | # # print(self.tokenizer) 73 | # self.tokenized=self.data.to(next(embed_layer.parameters()).device) 74 | 75 | # if embed_layer: 76 | # self.data=embed_layer(self.data.to(next(embed_layer.parameters()).device)) 77 | 78 | def __len__(self): 79 | return len(self.data) 80 | 81 | def __getitem__(self, item): 82 | # if self.tokenizer: 83 | # generated = self.tokenizer.encode(self.data[item]) 84 | # context = torch.tensor([generated]) 85 | # else: 86 | # context=self.data[item] 87 | # return embed, tokens 88 | if hasattr(self, 'tokenized'): 89 | return self.data[item], self.tokenized[item] 90 | else: 91 | print(self.data[item]['input_ids'].shape) 92 | return self.data[item]['input_ids'], self.data[item]['labels'] 93 | 94 | if __name__=='__main__': 95 | tokenizer = AutoTokenizer.from_pretrained("YeungNLP/bloomz-396m-zh") 96 | # ds=AlpacaDataset(root='/mnt/bn/pxy/yarn/Influence/dataset/raw', tokenizer=tokenizer) 97 | # print(ds[0]) 98 | # print(len(ds)) 99 | td=AlpacaDataset() 100 | print(td) 101 | # print(td[0]) 102 | res=torch.utils.data.Subset(td, [1,2,3]) 103 | print(res) 104 | exit() 105 | class args: 106 | max_length=256 107 | use_prompt_loss=False 108 | prob_data_display=0.1 109 | data_path='/mnt/bn/pxy/data/alpaca_data.json' 110 | valid_data_path=None 111 | use_large_data=False 112 | val_set_size=0 113 | micro_batch_size=4 114 | 115 | prompt_maker=alpaca_prompt_maker.PromptMaker() 116 | train_data, val_data = get_json_train_valid_data( 117 | args=args, 118 | data_file=args.data_path, 119 | valid_data_file=args.valid_data_path, 120 | val_set_size=args.val_set_size, 121 | prompt_fn=partial(generate_and_tokenize_prompt, args=args, tokenizer=tokenizer, prompt_maker=prompt_maker, use_prompt_labels=True), 122 | ) 123 | print(type(train_data)) 124 | 125 | train_dataloader = DataLoader( 126 | train_data, 127 | shuffle=True, 128 | collate_fn=transformers.DataCollatorForSeq2Seq(tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True), 129 | batch_size=args.micro_batch_size, 130 | pin_memory=False, 131 | drop_last=True 132 | ) 133 | for d in train_dataloader: 134 | print(d) 135 | break -------------------------------------------------------------------------------- /nngeometry/generator/jacobian/grads_conv.py: -------------------------------------------------------------------------------- 1 | # Author(s): Gaspar Rochette 2 | # 
License: BSD 3 clause 3 | # These functions are borrowed from https://github.com/owkin/grad-cnns 4 | 5 | import numpy as np 6 | import torch 7 | from torch._C import unify_type_list 8 | import torch.nn.functional as F 9 | 10 | def conv_backward(input, grad_output, in_channels, out_channels, kernel_size, 11 | stride=1, dilation=1, padding=0, groups=1, nd=1): 12 | '''Computes per-example gradients for nn.Conv1d and nn.Conv2d layers. 13 | 14 | This function is used in the internal behaviour of bnn.Linear. 15 | ''' 16 | 17 | # Change format of stride from int to tuple if necessary. 18 | if isinstance(kernel_size, int): 19 | kernel_size = (kernel_size,) * nd 20 | if isinstance(stride, int): 21 | stride = (stride,) * nd 22 | if isinstance(dilation, int): 23 | dilation = (dilation,) * nd 24 | if isinstance(padding, int): 25 | padding = (padding,) * nd 26 | 27 | # Get some useful sizes 28 | batch_size = input.size(0) 29 | input_shape = input.size()[-nd:] 30 | output_shape = grad_output.size()[-nd:] 31 | 32 | # Reshape to extract groups from the convolutional layer 33 | # Channels are seen as an extra spatial dimension with kernel size 1 34 | input_conv = input.view(1, batch_size * groups, in_channels // groups, *input_shape) 35 | 36 | # Compute convolution between input and output; the batchsize is seen 37 | # as channels, taking advantage of the `groups` argument 38 | grad_output_conv = grad_output.view(-1, 1, 1, *output_shape) 39 | 40 | stride = (1, *stride) 41 | dilation = (1, *dilation) 42 | padding = (0, *padding) 43 | 44 | if nd == 1: 45 | convnd = F.conv2d 46 | s_ = np.s_[..., :kernel_size[0]] 47 | elif nd == 2: 48 | convnd = F.conv3d 49 | s_ = np.s_[..., :kernel_size[0], :kernel_size[1]] 50 | elif nd == 3: 51 | raise NotImplementedError('3d convolution is not available with current per-example gradient computation') 52 | 53 | conv = convnd( 54 | input_conv, grad_output_conv, 55 | groups=batch_size * groups, 56 | stride=dilation, 57 | dilation=stride, 58 | padding=padding 59 | ) 60 | 61 | # Because of rounding shapes when using non-default stride or dilation, 62 | # convolution result must be truncated to convolution kernel size 63 | conv = conv[s_] 64 | 65 | # Reshape weight gradient to correct shape 66 | new_shape = [batch_size, out_channels, in_channels // groups, *kernel_size] 67 | weight_bgrad = conv.view(*new_shape).contiguous() 68 | 69 | return weight_bgrad 70 | 71 | 72 | def conv1d_backward(*args, **kwargs): 73 | '''Computes per-example gradients for nn.Conv1d layers.''' 74 | return conv_backward(*args, nd=1, **kwargs) 75 | 76 | 77 | def conv2d_backward_using_conv(mod, x, gy): 78 | '''Computes per-example gradients for nn.Conv2d layers.''' 79 | return conv_backward(x, gy, nd=2, 80 | in_channels=mod.in_channels, 81 | out_channels=mod.out_channels, 82 | kernel_size=mod.kernel_size, 83 | stride=mod.stride, 84 | dilation=mod.dilation, 85 | padding=mod.padding, 86 | groups=mod.groups) 87 | 88 | 89 | def conv2d_backward_using_unfold(mod, x, gy): 90 | '''Computes per-example gradients for nn.Conv2d layers.''' 91 | ks = (mod.weight.size(2), mod.weight.size(3)) 92 | gy_s = gy.size() 93 | bs = gy_s[0] 94 | x_unfold = F.unfold(x, kernel_size=ks, stride=mod.stride, 95 | padding=mod.padding, dilation=mod.dilation) 96 | x_unfold_s = x_unfold.size() 97 | return torch.bmm(gy.view(bs, gy_s[1], -1), 98 | x_unfold.view(bs, x_unfold_s[1], -1).permute(0, 2, 1)) 99 | 100 | 101 | def conv2d_backward(*args, **kwargs): 102 | return _conv_grad_impl.get_impl()(*args, **kwargs) 103 | 104 | 105 | class 
ConvGradImplManager: 106 | 107 | def __init__(self): 108 | self._use_unfold = True 109 | 110 | def use_unfold(self, choice=True): 111 | self._use_unfold = choice 112 | 113 | def get_impl(self): 114 | if self._use_unfold: 115 | return conv2d_backward_using_unfold 116 | else: 117 | return conv2d_backward_using_conv 118 | 119 | 120 | _conv_grad_impl = ConvGradImplManager() 121 | 122 | class use_unfold_impl_for_convs: 123 | 124 | def __enter__(self): 125 | self.prev = _conv_grad_impl._use_unfold 126 | _conv_grad_impl.use_unfold(True) 127 | 128 | def __exit__(self, exc_type, exc_value, traceback): 129 | _conv_grad_impl._use_unfold = self.prev 130 | 131 | class use_conv_impl_for_convs: 132 | 133 | def __enter__(self): 134 | self.prev = _conv_grad_impl._use_unfold 135 | _conv_grad_impl.use_unfold(False) 136 | 137 | def __exit__(self, exc_type, exc_value, traceback): 138 | _conv_grad_impl._use_unfold = self.prev 139 | 140 | 141 | def convtranspose2d_backward(mod, x, gy): 142 | '''Computes per-example gradients for nn.ConvTranspose2d layers.''' 143 | bs = gy.size(0) 144 | s_i, s_o, k_h, k_w = mod.weight.size() 145 | x_unfold = unfold_transpose_conv2d(mod, x) 146 | 147 | x_perm = x_unfold.view(bs, s_i*k_w*k_h, -1).permute(0, 2, 1) 148 | o = torch.bmm(gy.view(bs, s_o, -1), x_perm) 149 | o = o.view(bs, s_o, s_i, k_h, k_w).permute(0, 2, 1, 3, 4) 150 | o = o.contiguous() 151 | return o 152 | 153 | 154 | def unfold_transpose_conv2d(mod, x): 155 | unfold_filter = _filter_bank.get(mod) 156 | return F.conv_transpose2d(x, unfold_filter, stride=mod.stride, padding=mod.padding, 157 | output_padding=mod.output_padding, groups=mod.in_channels, 158 | dilation=mod.dilation) 159 | 160 | class TransposeConv_Unfold_Filter_Bank: 161 | 162 | def __init__(self): 163 | self.filters = dict() 164 | 165 | def get(self, mod): 166 | if mod not in self.filters: 167 | self.filters[mod] = self._create_unfold_filter(mod) 168 | return self.filters[mod] 169 | 170 | def _create_unfold_filter(self, mod): 171 | kw, kh = mod.kernel_size 172 | unfold_filter = mod.weight.data.new(mod.in_channels, kw * kh, kw, kh) 173 | unfold_filter.fill_(0) 174 | for i in range(mod.in_channels): 175 | for j in range(kw): 176 | for k in range(kh): 177 | unfold_filter[i, k + kh*j, j, k] = 1 178 | return unfold_filter 179 | 180 | _filter_bank = TransposeConv_Unfold_Filter_Bank() -------------------------------------------------------------------------------- /nngeometry/generator/lm_jacobian/grads_conv.py: -------------------------------------------------------------------------------- 1 | # Author(s): Gaspar Rochette 2 | # License: BSD 3 clause 3 | # These functions are borrowed from https://github.com/owkin/grad-cnns 4 | 5 | import numpy as np 6 | import torch 7 | from torch._C import unify_type_list 8 | import torch.nn.functional as F 9 | 10 | def conv_backward(input, grad_output, in_channels, out_channels, kernel_size, 11 | stride=1, dilation=1, padding=0, groups=1, nd=1): 12 | '''Computes per-example gradients for nn.Conv1d and nn.Conv2d layers. 13 | 14 | This function is used in the internal behaviour of bnn.Linear. 15 | ''' 16 | 17 | # Change format of stride from int to tuple if necessary. 
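# A scalar hyperparameter such as stride=1 is equivalent to the tuple (1,) * nd;
# broadcasting kernel_size, stride, dilation and padding to nd-tuples up front
# lets the generic convnd call further down unpack them uniformly for the
# Conv1d (nd=1) and Conv2d (nd=2) cases.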
18 | if isinstance(kernel_size, int): 19 | kernel_size = (kernel_size,) * nd 20 | if isinstance(stride, int): 21 | stride = (stride,) * nd 22 | if isinstance(dilation, int): 23 | dilation = (dilation,) * nd 24 | if isinstance(padding, int): 25 | padding = (padding,) * nd 26 | 27 | # Get some useful sizes 28 | batch_size = input.size(0) 29 | input_shape = input.size()[-nd:] 30 | output_shape = grad_output.size()[-nd:] 31 | 32 | # Reshape to extract groups from the convolutional layer 33 | # Channels are seen as an extra spatial dimension with kernel size 1 34 | input_conv = input.view(1, batch_size * groups, in_channels // groups, *input_shape) 35 | 36 | # Compute convolution between input and output; the batchsize is seen 37 | # as channels, taking advantage of the `groups` argument 38 | grad_output_conv = grad_output.view(-1, 1, 1, *output_shape) 39 | 40 | stride = (1, *stride) 41 | dilation = (1, *dilation) 42 | padding = (0, *padding) 43 | 44 | if nd == 1: 45 | convnd = F.conv2d 46 | s_ = np.s_[..., :kernel_size[0]] 47 | elif nd == 2: 48 | convnd = F.conv3d 49 | s_ = np.s_[..., :kernel_size[0], :kernel_size[1]] 50 | elif nd == 3: 51 | raise NotImplementedError('3d convolution is not available with current per-example gradient computation') 52 | 53 | conv = convnd( 54 | input_conv, grad_output_conv, 55 | groups=batch_size * groups, 56 | stride=dilation, 57 | dilation=stride, 58 | padding=padding 59 | ) 60 | 61 | # Because of rounding shapes when using non-default stride or dilation, 62 | # convolution result must be truncated to convolution kernel size 63 | conv = conv[s_] 64 | 65 | # Reshape weight gradient to correct shape 66 | new_shape = [batch_size, out_channels, in_channels // groups, *kernel_size] 67 | weight_bgrad = conv.view(*new_shape).contiguous() 68 | 69 | return weight_bgrad 70 | 71 | 72 | def conv1d_backward(*args, **kwargs): 73 | '''Computes per-example gradients for nn.Conv1d layers.''' 74 | return conv_backward(*args, nd=1, **kwargs) 75 | 76 | 77 | def conv2d_backward_using_conv(mod, x, gy): 78 | '''Computes per-example gradients for nn.Conv2d layers.''' 79 | return conv_backward(x, gy, nd=2, 80 | in_channels=mod.in_channels, 81 | out_channels=mod.out_channels, 82 | kernel_size=mod.kernel_size, 83 | stride=mod.stride, 84 | dilation=mod.dilation, 85 | padding=mod.padding, 86 | groups=mod.groups) 87 | 88 | 89 | def conv2d_backward_using_unfold(mod, x, gy): 90 | '''Computes per-example gradients for nn.Conv2d layers.''' 91 | ks = (mod.weight.size(2), mod.weight.size(3)) 92 | gy_s = gy.size() 93 | bs = gy_s[0] 94 | x_unfold = F.unfold(x, kernel_size=ks, stride=mod.stride, 95 | padding=mod.padding, dilation=mod.dilation) 96 | x_unfold_s = x_unfold.size() 97 | return torch.bmm(gy.view(bs, gy_s[1], -1), 98 | x_unfold.view(bs, x_unfold_s[1], -1).permute(0, 2, 1)) 99 | 100 | 101 | def conv2d_backward(*args, **kwargs): 102 | return _conv_grad_impl.get_impl()(*args, **kwargs) 103 | 104 | 105 | class ConvGradImplManager: 106 | 107 | def __init__(self): 108 | self._use_unfold = True 109 | 110 | def use_unfold(self, choice=True): 111 | self._use_unfold = choice 112 | 113 | def get_impl(self): 114 | if self._use_unfold: 115 | return conv2d_backward_using_unfold 116 | else: 117 | return conv2d_backward_using_conv 118 | 119 | 120 | _conv_grad_impl = ConvGradImplManager() 121 | 122 | class use_unfold_impl_for_convs: 123 | 124 | def __enter__(self): 125 | self.prev = _conv_grad_impl._use_unfold 126 | _conv_grad_impl.use_unfold(True) 127 | 128 | def __exit__(self, exc_type, exc_value, 
traceback): 129 | _conv_grad_impl._use_unfold = self.prev 130 | 131 | class use_conv_impl_for_convs: 132 | 133 | def __enter__(self): 134 | self.prev = _conv_grad_impl._use_unfold 135 | _conv_grad_impl.use_unfold(False) 136 | 137 | def __exit__(self, exc_type, exc_value, traceback): 138 | _conv_grad_impl._use_unfold = self.prev 139 | 140 | 141 | def convtranspose2d_backward(mod, x, gy): 142 | '''Computes per-example gradients for nn.ConvTranspose2d layers.''' 143 | bs = gy.size(0) 144 | s_i, s_o, k_h, k_w = mod.weight.size() 145 | x_unfold = unfold_transpose_conv2d(mod, x) 146 | 147 | x_perm = x_unfold.view(bs, s_i*k_w*k_h, -1).permute(0, 2, 1) 148 | o = torch.bmm(gy.view(bs, s_o, -1), x_perm) 149 | o = o.view(bs, s_o, s_i, k_h, k_w).permute(0, 2, 1, 3, 4) 150 | o = o.contiguous() 151 | return o 152 | 153 | 154 | def unfold_transpose_conv2d(mod, x): 155 | unfold_filter = _filter_bank.get(mod) 156 | return F.conv_transpose2d(x, unfold_filter, stride=mod.stride, padding=mod.padding, 157 | output_padding=mod.output_padding, groups=mod.in_channels, 158 | dilation=mod.dilation) 159 | 160 | class TransposeConv_Unfold_Filter_Bank: 161 | 162 | def __init__(self): 163 | self.filters = dict() 164 | 165 | def get(self, mod): 166 | if mod not in self.filters: 167 | self.filters[mod] = self._create_unfold_filter(mod) 168 | return self.filters[mod] 169 | 170 | def _create_unfold_filter(self, mod): 171 | kw, kh = mod.kernel_size 172 | unfold_filter = mod.weight.data.new(mod.in_channels, kw * kh, kw, kh) 173 | unfold_filter.fill_(0) 174 | for i in range(mod.in_channels): 175 | for j in range(kw): 176 | for k in range(kh): 177 | unfold_filter[i, k + kh*j, j, k] = 1 178 | return unfold_filter 179 | 180 | _filter_bank = TransposeConv_Unfold_Filter_Bank() -------------------------------------------------------------------------------- /nngeometry/generator/para_lm_jacobian/grads_conv.py: -------------------------------------------------------------------------------- 1 | # Author(s): Gaspar Rochette 2 | # License: BSD 3 clause 3 | # These functions are borrowed from https://github.com/owkin/grad-cnns 4 | 5 | import numpy as np 6 | import torch 7 | from torch._C import unify_type_list 8 | import torch.nn.functional as F 9 | 10 | def conv_backward(input, grad_output, in_channels, out_channels, kernel_size, 11 | stride=1, dilation=1, padding=0, groups=1, nd=1): 12 | '''Computes per-example gradients for nn.Conv1d and nn.Conv2d layers. 13 | 14 | This function is used in the internal behaviour of bnn.Linear. 15 | ''' 16 | 17 | # Change format of stride from int to tuple if necessary. 
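# The body below follows the grad-cnns recipe: the batch dimension is folded
# into the `groups` argument of a single convolution, so one convnd call returns
# a separate weight gradient per example instead of the gradient summed over the
# batch (note that stride and dilation deliberately swap roles in that call).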
18 | if isinstance(kernel_size, int): 19 | kernel_size = (kernel_size,) * nd 20 | if isinstance(stride, int): 21 | stride = (stride,) * nd 22 | if isinstance(dilation, int): 23 | dilation = (dilation,) * nd 24 | if isinstance(padding, int): 25 | padding = (padding,) * nd 26 | 27 | # Get some useful sizes 28 | batch_size = input.size(0) 29 | input_shape = input.size()[-nd:] 30 | output_shape = grad_output.size()[-nd:] 31 | 32 | # Reshape to extract groups from the convolutional layer 33 | # Channels are seen as an extra spatial dimension with kernel size 1 34 | input_conv = input.view(1, batch_size * groups, in_channels // groups, *input_shape) 35 | 36 | # Compute convolution between input and output; the batchsize is seen 37 | # as channels, taking advantage of the `groups` argument 38 | grad_output_conv = grad_output.view(-1, 1, 1, *output_shape) 39 | 40 | stride = (1, *stride) 41 | dilation = (1, *dilation) 42 | padding = (0, *padding) 43 | 44 | if nd == 1: 45 | convnd = F.conv2d 46 | s_ = np.s_[..., :kernel_size[0]] 47 | elif nd == 2: 48 | convnd = F.conv3d 49 | s_ = np.s_[..., :kernel_size[0], :kernel_size[1]] 50 | elif nd == 3: 51 | raise NotImplementedError('3d convolution is not available with current per-example gradient computation') 52 | 53 | conv = convnd( 54 | input_conv, grad_output_conv, 55 | groups=batch_size * groups, 56 | stride=dilation, 57 | dilation=stride, 58 | padding=padding 59 | ) 60 | 61 | # Because of rounding shapes when using non-default stride or dilation, 62 | # convolution result must be truncated to convolution kernel size 63 | conv = conv[s_] 64 | 65 | # Reshape weight gradient to correct shape 66 | new_shape = [batch_size, out_channels, in_channels // groups, *kernel_size] 67 | weight_bgrad = conv.view(*new_shape).contiguous() 68 | 69 | return weight_bgrad 70 | 71 | 72 | def conv1d_backward(*args, **kwargs): 73 | '''Computes per-example gradients for nn.Conv1d layers.''' 74 | return conv_backward(*args, nd=1, **kwargs) 75 | 76 | 77 | def conv2d_backward_using_conv(mod, x, gy): 78 | '''Computes per-example gradients for nn.Conv2d layers.''' 79 | return conv_backward(x, gy, nd=2, 80 | in_channels=mod.in_channels, 81 | out_channels=mod.out_channels, 82 | kernel_size=mod.kernel_size, 83 | stride=mod.stride, 84 | dilation=mod.dilation, 85 | padding=mod.padding, 86 | groups=mod.groups) 87 | 88 | 89 | def conv2d_backward_using_unfold(mod, x, gy): 90 | '''Computes per-example gradients for nn.Conv2d layers.''' 91 | ks = (mod.weight.size(2), mod.weight.size(3)) 92 | gy_s = gy.size() 93 | bs = gy_s[0] 94 | x_unfold = F.unfold(x, kernel_size=ks, stride=mod.stride, 95 | padding=mod.padding, dilation=mod.dilation) 96 | x_unfold_s = x_unfold.size() 97 | return torch.bmm(gy.view(bs, gy_s[1], -1), 98 | x_unfold.view(bs, x_unfold_s[1], -1).permute(0, 2, 1)) 99 | 100 | 101 | def conv2d_backward(*args, **kwargs): 102 | return _conv_grad_impl.get_impl()(*args, **kwargs) 103 | 104 | 105 | class ConvGradImplManager: 106 | 107 | def __init__(self): 108 | self._use_unfold = True 109 | 110 | def use_unfold(self, choice=True): 111 | self._use_unfold = choice 112 | 113 | def get_impl(self): 114 | if self._use_unfold: 115 | return conv2d_backward_using_unfold 116 | else: 117 | return conv2d_backward_using_conv 118 | 119 | 120 | _conv_grad_impl = ConvGradImplManager() 121 | 122 | class use_unfold_impl_for_convs: 123 | 124 | def __enter__(self): 125 | self.prev = _conv_grad_impl._use_unfold 126 | _conv_grad_impl.use_unfold(True) 127 | 128 | def __exit__(self, exc_type, exc_value, 
traceback): 129 | _conv_grad_impl._use_unfold = self.prev 130 | 131 | class use_conv_impl_for_convs: 132 | 133 | def __enter__(self): 134 | self.prev = _conv_grad_impl._use_unfold 135 | _conv_grad_impl.use_unfold(False) 136 | 137 | def __exit__(self, exc_type, exc_value, traceback): 138 | _conv_grad_impl._use_unfold = self.prev 139 | 140 | 141 | def convtranspose2d_backward(mod, x, gy): 142 | '''Computes per-example gradients for nn.ConvTranspose2d layers.''' 143 | bs = gy.size(0) 144 | s_i, s_o, k_h, k_w = mod.weight.size() 145 | x_unfold = unfold_transpose_conv2d(mod, x) 146 | 147 | x_perm = x_unfold.view(bs, s_i*k_w*k_h, -1).permute(0, 2, 1) 148 | o = torch.bmm(gy.view(bs, s_o, -1), x_perm) 149 | o = o.view(bs, s_o, s_i, k_h, k_w).permute(0, 2, 1, 3, 4) 150 | o = o.contiguous() 151 | return o 152 | 153 | 154 | def unfold_transpose_conv2d(mod, x): 155 | unfold_filter = _filter_bank.get(mod) 156 | return F.conv_transpose2d(x, unfold_filter, stride=mod.stride, padding=mod.padding, 157 | output_padding=mod.output_padding, groups=mod.in_channels, 158 | dilation=mod.dilation) 159 | 160 | class TransposeConv_Unfold_Filter_Bank: 161 | 162 | def __init__(self): 163 | self.filters = dict() 164 | 165 | def get(self, mod): 166 | if mod not in self.filters: 167 | self.filters[mod] = self._create_unfold_filter(mod) 168 | return self.filters[mod] 169 | 170 | def _create_unfold_filter(self, mod): 171 | kw, kh = mod.kernel_size 172 | unfold_filter = mod.weight.data.new(mod.in_channels, kw * kh, kw, kh) 173 | unfold_filter.fill_(0) 174 | for i in range(mod.in_channels): 175 | for j in range(kw): 176 | for k in range(kh): 177 | unfold_filter[i, k + kh*j, j, k] = 1 178 | return unfold_filter 179 | 180 | _filter_bank = TransposeConv_Unfold_Filter_Bank() -------------------------------------------------------------------------------- /nngeometry/metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn.functional import softmax 3 | from .generator.jacobian import Jacobian 4 | from .layercollection import LayerCollection 5 | 6 | 7 | def FIM_MonteCarlo(model, 8 | loader, 9 | representation, 10 | variant='classif_logits', 11 | trials=1, 12 | device='cpu', 13 | function=None, 14 | layer_collection=None): 15 | """ 16 | Helper that creates a matrix computing the Fisher Information 17 | Matrix using a Monte-Carlo estimate of y|x with `trials` samples per 18 | example 19 | 20 | Parameters 21 | ---------- 22 | model : torch.nn.Module 23 | The model that contains all parameters of the function 24 | loader : torch.utils.data.DataLoader 25 | DataLoader for computing expectation over the input space 26 | representation : class 27 | The parameter matrix representation that will be used to store 28 | the matrix 29 | variants : string 'classif_logits' or 'regression', optional 30 | (default='classif_logits') 31 | Variant to use depending on how you interpret your function. 32 | Possible choices are: 33 | - 'classif_logits' when using logits for classification 34 | - 'classif_logsoftmax' when using log_softmax values for classification 35 | - 'segmentation_logits' when using logits in a segmentation task 36 | trials : int, optional (default=1) 37 | Number of trials for Monte Carlo sampling 38 | device : string, optional (default='cpu') 39 | Target device for the returned matrix 40 | function : function, optional (default=None) 41 | An optional function if different from `model(input)`. 
If 42 | it is different from None, it will override the device 43 | parameter. 44 | layer_collection : layercollection.LayerCollection, optional 45 | (default=None) 46 | An optional layer collection 47 | 48 | """ 49 | 50 | if function is None: 51 | def function(*d): 52 | return model(d[0].to(device)) 53 | 54 | if layer_collection is None: 55 | layer_collection = LayerCollection.from_model(model) 56 | 57 | if variant == 'classif_logits': 58 | 59 | def fim_function(*d): 60 | log_softmax = torch.log_softmax(function(*d), dim=1) 61 | probabilities = torch.exp(log_softmax) 62 | sampled_targets = torch.multinomial(probabilities, trials, 63 | replacement=True) 64 | return trials ** -.5 * torch.gather(log_softmax, 1, 65 | sampled_targets) 66 | elif variant == 'classif_logsoftmax': 67 | 68 | def fim_function(*d): 69 | log_softmax = function(*d) 70 | probabilities = torch.exp(log_softmax) 71 | sampled_targets = torch.multinomial(probabilities, trials, 72 | replacement=True) 73 | return trials ** -.5 * torch.gather(log_softmax, 1, 74 | sampled_targets) 75 | elif variant == 'segmentation_logits': 76 | 77 | def fim_function(*d): 78 | log_softmax = torch.log_softmax(function(*d), dim=1) 79 | s_mb, s_c, s_h, s_w = log_softmax.size() 80 | log_softmax = log_softmax.permute(0, 2, 3, 1).contiguous() \ 81 | .view(s_mb * s_h * s_w, s_c) 82 | probabilities = torch.exp(log_softmax) 83 | sampled_indices = torch.multinomial(probabilities, trials, 84 | replacement=True) 85 | sampled_targets = torch.gather(log_softmax, 1, 86 | sampled_indices) 87 | sampled_targets = sampled_targets.view(s_mb, s_h * s_w, trials) \ 88 | .sum(dim=1) 89 | return trials ** -.5 * sampled_targets 90 | 91 | else: 92 | raise NotImplementedError 93 | 94 | generator = Jacobian(layer_collection=layer_collection, 95 | model=model, 96 | function=fim_function, 97 | n_output=trials) 98 | return representation(generator=generator, examples=loader) 99 | 100 | 101 | def FIM(model, 102 | loader, 103 | representation, 104 | n_output, 105 | variant='classif_logits', 106 | device='cpu', 107 | function=None, 108 | layer_collection=None): 109 | """ 110 | Helper that creates a matrix computing the Fisher Information 111 | Matrix using closed form expressions for the expectation y|x 112 | as described in (Pascanu and Bengio, 2013) 113 | 114 | Parameters 115 | ---------- 116 | model : torch.nn.Module 117 | The model that contains all parameters of the function 118 | loader : torch.utils.data.DataLoader 119 | DataLoader for computing expectation over the input space 120 | representation : class 121 | The parameter matrix representation that will be used to store 122 | the matrix 123 | n_output : int 124 | Number of outputs of the model 125 | variants : string 'classif_logits' or 'regression', optional 126 | (default='classif_logits') 127 | Variant to use depending on how you interpret your function. 128 | Possible choices are: 129 | - 'classif_logits' when using logits for classification 130 | - 'regression' when using a gaussian regression model 131 | device : string, optional (default='cpu') 132 | Target device for the returned matrix 133 | function : function, optional (default=None) 134 | An optional function if different from `model(input)`. If 135 | it is different from None, it will override the device 136 | parameter. 
137 | layer_collection : layercollection.LayerCollection, optional 138 | (default=None) 139 | An optional layer collection 140 | """ 141 | 142 | if function is None: 143 | def function(*d): 144 | return model(d[0].to(device)) 145 | 146 | if layer_collection is None: 147 | layer_collection = LayerCollection.from_model(model) 148 | 149 | if variant == 'classif_logits': 150 | 151 | def function_fim(*d): 152 | log_probs = torch.log_softmax(function(*d), dim=1) 153 | probs = torch.exp(log_probs).detach() 154 | return (log_probs * probs**.5) 155 | 156 | elif variant == 'regression': 157 | 158 | def function_fim(*d): 159 | estimates = function(*d) 160 | return estimates 161 | else: 162 | raise NotImplementedError 163 | 164 | generator = Jacobian(layer_collection=layer_collection, 165 | model=model, 166 | function=function_fim, 167 | n_output=n_output) 168 | return representation(generator=generator, examples=loader) 169 | -------------------------------------------------------------------------------- /nngeometry/lm_metrics_para.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn.functional import softmax 3 | # from .generator.jacobian import Jacobian 4 | from .generator.para_lm_jacobian import Jacobian 5 | from .layercollection import LayerCollection 6 | from torch import multiprocessing 7 | 8 | def FIM_MonteCarlo(model, 9 | loader, 10 | representation, 11 | variant='classif_logits', 12 | trials=1, 13 | device='cpu', 14 | function=None, 15 | layer_collection=None): 16 | """ 17 | Helper that creates a matrix computing the Fisher Information 18 | Matrix using a Monte-Carlo estimate of y|x with `trials` samples per 19 | example 20 | 21 | Parameters 22 | ---------- 23 | model : torch.nn.Module 24 | The model that contains all parameters of the function 25 | loader : torch.utils.data.DataLoader 26 | DataLoader for computing expectation over the input space 27 | representation : class 28 | The parameter matrix representation that will be used to store 29 | the matrix 30 | variants : string 'classif_logits' or 'regression', optional 31 | (default='classif_logits') 32 | Variant to use depending on how you interpret your function. 33 | Possible choices are: 34 | - 'classif_logits' when using logits for classification 35 | - 'classif_logsoftmax' when using log_softmax values for classification 36 | - 'segmentation_logits' when using logits in a segmentation task 37 | trials : int, optional (default=1) 38 | Number of trials for Monte Carlo sampling 39 | device : string, optional (default='cpu') 40 | Target device for the returned matrix 41 | function : function, optional (default=None) 42 | An optional function if different from `model(input)`. If 43 | it is different from None, it will override the device 44 | parameter. 
45 | layer_collection : layercollection.LayerCollection, optional 46 | (default=None) 47 | An optional layer collection 48 | 49 | """ 50 | 51 | if function is None: 52 | def function(*d): 53 | return model(inputs_embeds=d[0].to(device)) 54 | 55 | if layer_collection is None: 56 | layer_collection = LayerCollection.from_model(model) 57 | 58 | if variant == 'classif_logits': 59 | 60 | def fim_function(*d): 61 | out=function(*d) 62 | # print(out.keys()) 63 | lgt=out['logits'].mean(1) 64 | # lgt=out.logits.mean(1) 65 | # print(f'lgt: {lgt}') 66 | log_softmax = torch.log_softmax(lgt, dim=1) 67 | probabilities = torch.exp(log_softmax) 68 | # print(f'log_softmax: {log_softmax}') 69 | # print(f'prob: {probabilities}') 70 | sampled_targets = torch.multinomial(probabilities, trials, 71 | replacement=True) 72 | # print(probabilities.shape) 73 | # st=[] 74 | # for i in range(probabilities.shape[1]): 75 | # sampled_targets = torch.multinomial(probabilities[:,i,:], trials, 76 | # replacement=True) 77 | # st.append(sampled_targets) 78 | # sampled_targets=torch.stack(st, dim=1) 79 | # print(sampled_targets.shape) 80 | # print(f'sampled_targets: {sampled_targets}') 81 | 82 | return trials ** -.5 * torch.gather(log_softmax, 1, 83 | sampled_targets) 84 | elif variant == 'classif_logsoftmax': 85 | 86 | def fim_function(*d): 87 | log_softmax = function(*d) 88 | probabilities = torch.exp(log_softmax) 89 | sampled_targets = torch.multinomial(probabilities, trials, 90 | replacement=True) 91 | return trials ** -.5 * torch.gather(log_softmax, 1, 92 | sampled_targets) 93 | elif variant == 'segmentation_logits': 94 | 95 | def fim_function(*d): 96 | log_softmax = torch.log_softmax(function(*d), dim=1) 97 | s_mb, s_c, s_h, s_w = log_softmax.size() 98 | log_softmax = log_softmax.permute(0, 2, 3, 1).contiguous() \ 99 | .view(s_mb * s_h * s_w, s_c) 100 | probabilities = torch.exp(log_softmax) 101 | sampled_indices = torch.multinomial(probabilities, trials, 102 | replacement=True) 103 | sampled_targets = torch.gather(log_softmax, 1, 104 | sampled_indices) 105 | sampled_targets = sampled_targets.view(s_mb, s_h * s_w, trials) \ 106 | .sum(dim=1) 107 | return trials ** -.5 * sampled_targets 108 | 109 | else: 110 | raise NotImplementedError 111 | 112 | generator = Jacobian(layer_collection=layer_collection, 113 | model=model, 114 | function=fim_function, 115 | n_output=trials) 116 | return representation(generator=generator, examples=loader) 117 | 118 | 119 | def FIM(model, 120 | loader, 121 | representation, 122 | n_output, 123 | variant='classif_logits', 124 | device='cpu', 125 | function=None, 126 | layer_collection=None): 127 | """ 128 | Helper that creates a matrix computing the Fisher Information 129 | Matrix using closed form expressions for the expectation y|x 130 | as described in (Pascanu and Bengio, 2013) 131 | 132 | Parameters 133 | ---------- 134 | model : torch.nn.Module 135 | The model that contains all parameters of the function 136 | loader : torch.utils.data.DataLoader 137 | DataLoader for computing expectation over the input space 138 | representation : class 139 | The parameter matrix representation that will be used to store 140 | the matrix 141 | n_output : int 142 | Number of outputs of the model 143 | variants : string 'classif_logits' or 'regression', optional 144 | (default='classif_logits') 145 | Variant to use depending on how you interpret your function. 
146 | Possible choices are: 147 | - 'classif_logits' when using logits for classification 148 | - 'regression' when using a gaussian regression model 149 | device : string, optional (default='cpu') 150 | Target device for the returned matrix 151 | function : function, optional (default=None) 152 | An optional function if different from `model(input)`. If 153 | it is different from None, it will override the device 154 | parameter. 155 | layer_collection : layercollection.LayerCollection, optional 156 | (default=None) 157 | An optional layer collection 158 | """ 159 | 160 | if function is None: 161 | def function(*d): 162 | return model(inputs_embeds=d[0].to(device)) 163 | 164 | if layer_collection is None: 165 | layer_collection = LayerCollection.from_model(model) 166 | 167 | if variant == 'classif_logits': 168 | 169 | def function_fim(*d): 170 | lgt=function(*d).logits 171 | # print(lgt.shape) 172 | log_probs = torch.log_softmax(lgt, dim=2) 173 | probs = torch.exp(log_probs).detach() 174 | return (log_probs * probs**.5) 175 | 176 | elif variant == 'regression': 177 | 178 | def function_fim(*d): 179 | estimates = function(*d) 180 | return estimates 181 | else: 182 | raise NotImplementedError 183 | 184 | generator = Jacobian(layer_collection=layer_collection, 185 | model=model, 186 | function=function_fim, 187 | n_output=n_output) 188 | return representation(generator=generator, examples=loader) 189 | -------------------------------------------------------------------------------- /query_loss_mapper.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from nngeometry.object import PMatEKFAC 3 | from torch.utils.data import DataLoader 4 | from transformers import AutoTokenizer, AutoModelForCausalLM 5 | import argparse 6 | import pickle 7 | import json 8 | from nngeometry.object import lm_vector 9 | import transformers 10 | import time 11 | from dataset.data.json_data import get_json_train_valid_data, generate_and_tokenize_prompt 12 | from functools import partial 13 | from dataset.prompt_maker.translate_prompt_maker import PromptMaker 14 | from tqdm import tqdm 15 | from nngeometry.maths import kronecker 16 | 17 | def to_device(kfac, device): 18 | for key in kfac.data.keys(): 19 | kfac.data[key]=(kfac.data[key][0].to(device), kfac.data[key][1].to(device)) 20 | torch.cuda.empty_cache() 21 | torch.cuda.empty_cache() 22 | torch.cuda.empty_cache() 23 | torch.cuda.empty_cache() 24 | torch.cuda.empty_cache() 25 | return kfac 26 | 27 | def to_ekfac_and_device(kfac, device): 28 | for key in kfac.data.keys(): 29 | kfac.data[key]=(kfac.data[key][0].to(device), kfac.data[key][1].to(device)) 30 | # --------- 31 | evecs = dict() 32 | diags = dict() 33 | 34 | kfac_blocks = kfac.data 35 | for layer_id, layer in \ 36 | kfac.generator.layer_collection.layers.items(): 37 | a, g = kfac_blocks[layer_id] 38 | evals_a, evecs_a = torch.linalg.eigh(a) 39 | evals_g, evecs_g = torch.linalg.eigh(g) 40 | evecs[layer_id] = (evecs_a, evecs_g) 41 | diags[layer_id] = kronecker(evals_g.view(-1, 1), 42 | evals_a.view(-1, 1)) 43 | del a, g, kfac_blocks[layer_id] 44 | data = (evecs, diags) 45 | ekfac=PMatEKFAC(generator=kfac.generator, data=data) 46 | return ekfac 47 | 48 | # bloom_ignored_layers=['atte', 'lm_head', '0', '1', '2', '3', '21', '22', '23'] 49 | bloom_ignore_lc=['atte', 'lm_head', 'dense_4h_to_h'] 50 | # baichuan_ignore_lc=['attn', 'lm_head', 'up', 'gate', '0', '1', '.2.', '3', '4', '9', '21', '22', '23'] 51 | baichuan_ignore_lc=['.0.', '.1.', '.2.', '.4.', 
'.5.', '.7.', '.8.', '.10.', '.11.', '.13.', '.14.', '.16.', '.17.', '.19.', '.20.', '.22.', '.23.', '.25.', '.26.', '.28.', '.29.', '30', 'attn', 'up', 'gate', 'lm_head'] 52 | llama_ignore_lc=['.0.', '.1.', '.2.', '.4.', '.5.', '.7.', '.8.', '.10.', '.11.', '.13.', '.14.', '.16.', '.17.', '.19.', '.20.', '.22.', '.23.', '.25.', '.26.', '.28.', '.29.', '30', 'attn', 'up', 'gate', 'lm_head'] 53 | # llama_ignore_lc=['attn', 'lm_head', 'up', 'gate', '0', '1', '.2.', '3', '4', '9', '21', '22', '23'] 54 | CUR_LC=bloom_ignore_lc 55 | 56 | def main(): 57 | parser = argparse.ArgumentParser( 58 | prog='ProgramName', 59 | description='What the program does', 60 | epilog='Text at the bottom of help') 61 | 62 | parser.add_argument('-k', '--kfac') # kfac_dir 63 | parser.add_argument('-m', '--model') 64 | parser.add_argument('-t', '--tokenizer') 65 | parser.add_argument('-v', '--device') # device e.g. 'cuda:1' 66 | parser.add_argument('-l', '--limit', type=int, default=-1) 67 | parser.add_argument('-lq', '--limit_query', type=int, default=-1) 68 | parser.add_argument('-o', '--output', type=str, default='res.jsonl') 69 | parser.add_argument('-d', '--data_path') 70 | parser.add_argument('-q', '--query_path') 71 | parser.add_argument('-bq', '--batch_query', type=int, default=16) 72 | parser.add_argument('-idstart', '--indexestart', type=int) 73 | parser.add_argument('-idend', '--indexesend', type=int) 74 | parser.add_argument('-lmd', '--lambdaa', type=float, default=0.5) 75 | 76 | parser.add_argument('--full-score', default=0, type=int) 77 | parser.add_argument('--ekfac', default=0, type=int) 78 | parser.add_argument('--layer', default='b', type=str) 79 | 80 | 81 | args = parser.parse_args() 82 | print(args) 83 | 84 | tokenizer = AutoTokenizer.from_pretrained(args.tokenizer, trust_remote_code=True) 85 | model = AutoModelForCausalLM.from_pretrained(args.model, trust_remote_code=True) 86 | 87 | 88 | model.to(args.device) 89 | model.zero_grad() 90 | with open(args.kfac, 'rb') as f: 91 | 92 | pF = pickle.load(f) 93 | if args.ekfac: 94 | pF=to_ekfac_and_device(pF, args.device) 95 | else: 96 | pF=to_device(pF, args.device) 97 | 98 | class tmpQueryargs: 99 | max_length=256 100 | use_prompt_loss=False 101 | prob_data_display=0.1 102 | data_path=args.query_path 103 | valid_data_path=None 104 | use_large_data=False 105 | val_set_size=None 106 | micro_batch_size=4 107 | tokenizer=args.tokenizer 108 | seed=1 109 | query_data, val_data = get_json_train_valid_data( 110 | args=tmpQueryargs, 111 | data_file=tmpQueryargs.data_path, 112 | valid_data_file=tmpQueryargs.valid_data_path, 113 | val_set_size=tmpQueryargs.val_set_size, 114 | prompt_fn=partial(generate_and_tokenize_prompt, args=tmpQueryargs, verbose=False, tokenizer=tokenizer, prompt_maker=PromptMaker(args=tmpQueryargs)), 115 | ) 116 | 117 | if args.limit_query>0: 118 | query_data=torch.utils.data.Subset(query_data, range(args.limit_query)) 119 | else: 120 | pass 121 | queryloader = DataLoader( 122 | query_data, 123 | shuffle=False, 124 | collate_fn=transformers.DataCollatorForSeq2Seq(tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True), 125 | batch_size=args.batch_query, 126 | pin_memory=False, 127 | drop_last=True 128 | ) 129 | 130 | d_ihvp=[] 131 | for q in tqdm(queryloader): 132 | model.zero_grad() 133 | inp=q['input_ids'].to(args.device) 134 | labels=q['labels'].to(args.device) 135 | 136 | loss=model(input_ids=inp, labels=labels).loss 137 | loss.backward() 138 | 139 | vec_query=lm_vector.PVector.from_model_grad(model, ignore_layers=CUR_LC) 140 | 
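# pF holds the K-FAC (or EK-FAC) approximation of the Fisher loaded from
# args.kfac; inverse(regul=args.lambdaa) damps it, roughly (F + lambda*I)^-1,
# and .mv applies the damped inverse to the query gradient, so ihvp is the
# inverse-Hessian-vector product that influence functions need per query batch.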
141 | 142 | ihvp=pF.inverse(regul=args.lambdaa).mv(vec_query) 143 | ihvp=lm_vector.PVector(layer_collection=ihvp.layer_collection, 144 | vector_repr=ihvp.vector_repr, dict_repr=ihvp.dict_repr) 145 | ihvp.svd() 146 | d_ihvp.append(ihvp) 147 | 148 | 149 | class tmpargs: 150 | max_length=256 151 | use_prompt_loss=False 152 | prob_data_display=0.1 153 | data_path=args.data_path 154 | valid_data_path=None 155 | use_large_data=False 156 | val_set_size=None 157 | micro_batch_size=4 158 | tokenizer=args.tokenizer 159 | seed=1 160 | train_data, val_data = get_json_train_valid_data( 161 | args=tmpargs, 162 | data_file=tmpargs.data_path, 163 | valid_data_file=tmpargs.valid_data_path, 164 | val_set_size=tmpargs.val_set_size, 165 | prompt_fn=partial(generate_and_tokenize_prompt, args=tmpargs, verbose=False, tokenizer=tokenizer, prompt_maker=PromptMaker(args=tmpargs)), 166 | ) 167 | candidate_set=train_data 168 | 169 | ids=list(range(args.indexestart, args.indexesend+1)) 170 | print(len(ids)) 171 | sub_dataset=torch.utils.data.Subset(candidate_set, ids) 172 | 173 | trainloader = DataLoader( 174 | sub_dataset, 175 | shuffle=False, 176 | collate_fn=transformers.DataCollatorForSeq2Seq(tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True), 177 | batch_size=1, 178 | pin_memory=False, 179 | drop_last=True 180 | ) 181 | 182 | 183 | st=time.time() 184 | res_json=[] 185 | for data in tqdm(trainloader): 186 | inp=data['input_ids'].to(args.device) 187 | labels=data['labels'].to(args.device) 188 | 189 | model.zero_grad() 190 | loss=model(input_ids=inp, labels=labels).loss 191 | loss.backward() 192 | 193 | vec_candi=lm_vector.PVector.from_model_grad(model, CUR_LC) 194 | vec_candi.svd() 195 | 196 | score=0 197 | score_list=[] 198 | 199 | for ihvp in d_ihvp: 200 | tmp=-(vec_candi.dot_svd(ihvp)) 201 | score=score+tmp 202 | score_list.append(tmp.item()) 203 | 204 | 205 | score=score/len(d_ihvp) 206 | text=tokenizer.batch_decode(data['input_ids'], skip_special_tokens=True) 207 | if args.full_score: 208 | res_json.append({'score': score.item(), 'score_list': score_list, 'loss': loss.item(), 'text': text}) 209 | else: 210 | res_json.append({'score': score.item(), 'loss': loss.item(), 'text': text}) 211 | 212 | 213 | el=time.time()-st 214 | print(f'Elapse: {el}s.') 215 | 216 | res_json=sorted(res_json, key=lambda x: x['score']) 217 | 218 | open(args.output, 'w').write(json.dumps(res_json, indent=4, ensure_ascii=False)) 219 | 220 | 221 | if __name__=='__main__': 222 | main() 223 | -------------------------------------------------------------------------------- /nngeometry/lm_metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn.functional import softmax 3 | # from .generator.jacobian import Jacobian 4 | from .generator.lm_jacobian import Jacobian 5 | from .layercollection import LayerCollection 6 | 7 | 8 | def FIM_MonteCarlo(model, 9 | loader, 10 | representation, 11 | variant='classif_logits', 12 | trials=1, 13 | device='cpu', 14 | function=None, 15 | layer_collection=None): 16 | """ 17 | Helper that creates a matrix computing the Fisher Information 18 | Matrix using a Monte-Carlo estimate of y|x with `trials` samples per 19 | example 20 | 21 | Parameters 22 | ---------- 23 | model : torch.nn.Module 24 | The model that contains all parameters of the function 25 | loader : torch.utils.data.DataLoader 26 | DataLoader for computing expectation over the input space 27 | representation : class 28 | The parameter matrix representation that 
will be used to store 29 | the matrix 30 | variants : string 'classif_logits' or 'regression', optional 31 | (default='classif_logits') 32 | Variant to use depending on how you interpret your function. 33 | Possible choices are: 34 | - 'classif_logits' when using logits for classification 35 | - 'classif_logsoftmax' when using log_softmax values for classification 36 | - 'segmentation_logits' when using logits in a segmentation task 37 | trials : int, optional (default=1) 38 | Number of trials for Monte Carlo sampling 39 | device : string, optional (default='cpu') 40 | Target device for the returned matrix 41 | function : function, optional (default=None) 42 | An optional function if different from `model(input)`. If 43 | it is different from None, it will override the device 44 | parameter. 45 | layer_collection : layercollection.LayerCollection, optional 46 | (default=None) 47 | An optional layer collection 48 | 49 | """ 50 | 51 | if function is None: 52 | def function(*d): 53 | return model(inputs_embeds=d[0].to(device)) 54 | 55 | if layer_collection is None: 56 | layer_collection = LayerCollection.from_model(model) 57 | 58 | if variant == 'classif_logits': 59 | 60 | def fim_function(*d): 61 | out=function(*d) 62 | # print(out.keys()) 63 | lgt=out['logits'] 64 | # lgt=out.logits.mean(1) 65 | # print(f'lgt: {lgt}') 66 | log_softmax = torch.log_softmax(lgt, dim=2) 67 | probabilities = torch.exp(log_softmax) 68 | # print(f'log_softmax: {log_softmax}') 69 | # print(f'prob: {probabilities}') 70 | # sampled_targets = torch.multinomial(probabilities, trials, 71 | # replacement=True) 72 | # print(probabilities.shape) 73 | st=[] 74 | for i in range(probabilities.shape[1]): 75 | sampled_targets = torch.multinomial(probabilities[:,i,:], trials, 76 | replacement=True) 77 | tmp = torch.gather(log_softmax[:, i, :], 1, sampled_targets) 78 | # print(f'tmp: {tmp.shape}') 79 | st.append(tmp) 80 | sampled_targets=torch.stack(st, dim=1) 81 | # print(f'sampled_targets: {sampled_targets.shape}') 82 | # print(f'sampled_targets: {sampled_targets}') 83 | res=trials ** -.5 * sampled_targets 84 | # print(res) 85 | return res 86 | 87 | elif variant == 'classif_logsoftmax': 88 | 89 | def fim_function(*d): 90 | log_softmax = function(*d) 91 | probabilities = torch.exp(log_softmax) 92 | sampled_targets = torch.multinomial(probabilities, trials, 93 | replacement=True) 94 | return trials ** -.5 * torch.gather(log_softmax, 1, 95 | sampled_targets) 96 | elif variant == 'segmentation_logits': 97 | 98 | def fim_function(*d): 99 | log_softmax = torch.log_softmax(function(*d), dim=1) 100 | s_mb, s_c, s_h, s_w = log_softmax.size() 101 | log_softmax = log_softmax.permute(0, 2, 3, 1).contiguous() \ 102 | .view(s_mb * s_h * s_w, s_c) 103 | probabilities = torch.exp(log_softmax) 104 | sampled_indices = torch.multinomial(probabilities, trials, 105 | replacement=True) 106 | sampled_targets = torch.gather(log_softmax, 1, 107 | sampled_indices) 108 | sampled_targets = sampled_targets.view(s_mb, s_h * s_w, trials) \ 109 | .sum(dim=1) 110 | return trials ** -.5 * sampled_targets 111 | 112 | else: 113 | raise NotImplementedError 114 | 115 | generator = Jacobian(layer_collection=layer_collection, 116 | model=model, 117 | function=fim_function, 118 | n_output=trials) 119 | return representation(generator=generator, examples=loader) 120 | 121 | 122 | def FIM(model, 123 | loader, 124 | representation, 125 | n_output, 126 | variant='classif_logits', 127 | device='cpu', 128 | function=None, 129 | layer_collection=None): 130 | """ 131 | Helper 
that creates a matrix computing the Fisher Information 132 | Matrix using closed form expressions for the expectation y|x 133 | as described in (Pascanu and Bengio, 2013) 134 | 135 | Parameters 136 | ---------- 137 | model : torch.nn.Module 138 | The model that contains all parameters of the function 139 | loader : torch.utils.data.DataLoader 140 | DataLoader for computing expectation over the input space 141 | representation : class 142 | The parameter matrix representation that will be used to store 143 | the matrix 144 | n_output : int 145 | Number of outputs of the model 146 | variants : string 'classif_logits' or 'regression', optional 147 | (default='classif_logits') 148 | Variant to use depending on how you interpret your function. 149 | Possible choices are: 150 | - 'classif_logits' when using logits for classification 151 | - 'regression' when using a gaussian regression model 152 | device : string, optional (default='cpu') 153 | Target device for the returned matrix 154 | function : function, optional (default=None) 155 | An optional function if different from `model(input)`. If 156 | it is different from None, it will override the device 157 | parameter. 158 | layer_collection : layercollection.LayerCollection, optional 159 | (default=None) 160 | An optional layer collection 161 | """ 162 | 163 | if function is None: 164 | def function(d): 165 | return model(input_ids=d.to(device)) 166 | # return model(inputs_embeds=d[0].to(device)) 167 | 168 | if layer_collection is None: 169 | layer_collection = LayerCollection.from_model(model) 170 | 171 | if variant == 'classif_logits': 172 | 173 | def function_fim(*d): 174 | lgt=function(*d).logits 175 | # print(lgt.shape) 176 | log_probs = torch.log_softmax(lgt, dim=2) 177 | probs = torch.exp(log_probs).detach() 178 | return (log_probs * probs**.5) 179 | 180 | elif variant == 'empirical_fisher': 181 | def function_fim(d): 182 | d=d.to(device) 183 | inp=d 184 | out=function(d) 185 | # print(out.keys()) 186 | lgt=out['logits'] 187 | # lgt=out.logits.mean(1) 188 | # print(f'lgt: {lgt}') 189 | log_softmax = torch.log_softmax(lgt, dim=2) 190 | probabilities = torch.exp(log_softmax) 191 | inp = inp.unsqueeze(2) 192 | # print(f'inp: {inp.shape}') 193 | # print(f'probabilities: {probabilities.shape}') 194 | sampled_targets=torch.gather(probabilities, 2, inp) 195 | # res=trials ** -.5 * sampled_targets 196 | res=sampled_targets 197 | # print(res.shape) 198 | return res 199 | 200 | elif variant == 'regression': 201 | 202 | def function_fim(*d): 203 | estimates = function(*d) 204 | return estimates 205 | else: 206 | raise NotImplementedError 207 | 208 | generator = Jacobian(layer_collection=layer_collection, 209 | model=model, 210 | function=function_fim, 211 | n_output=n_output) 212 | return representation(generator=generator, examples=loader) 213 | -------------------------------------------------------------------------------- /dataset/data/json_data.py: -------------------------------------------------------------------------------- 1 | 2 | from dataset.data.dataset import ( 3 | DynamicPromptDataset, 4 | StreamDynamicPromptDataset, 5 | COAIDynamicPromptDataset, 6 | COAIStreamDynamicPromptDataset 7 | ) 8 | import json 9 | import random 10 | from dataset.utils.io_utils import load_json, grob_paths 11 | import random 12 | import os 13 | import logging 14 | import copy 15 | 16 | logger = logging.getLogger(__name__) 17 | logger.setLevel(logging.INFO) 18 | 19 | 20 | def get_json_train_valid_data( 21 | args, 22 | val_set_size: int, 23 | data_file: str = 
"alpaca_data_cleaned.json", 24 | train_val_provider = None, 25 | prompt_fn = None, 26 | valid_data_file: str = None, 27 | for_coai: bool = False, 28 | ): 29 | assert prompt_fn is not None, "please provide prompt_fn" 30 | 31 | if train_val_provider is not None: 32 | train_val_provider(data_file, valid_data_file, val_set_size) 33 | 34 | if args.use_large_data: 35 | assert valid_data_file is not None, \ 36 | """ 37 | For large data, training data is iterable among large mounts of files, 38 | It is not convinient to fetch data as validation set that won't be 39 | resampled again as training data (from an iterable pipeline). 40 | So it is recommended to provide validation data yourself. 41 | """ 42 | train_dataloader_class = COAIStreamDynamicPromptDataset if for_coai else StreamDynamicPromptDataset 43 | else: 44 | train_dataloader_class = COAIDynamicPromptDataset if for_coai else DynamicPromptDataset 45 | valid_dataloader_class = COAIDynamicPromptDataset if for_coai else DynamicPromptDataset 46 | 47 | 48 | # if validation data is provided, just use it. 49 | if valid_data_file: 50 | train_data = train_dataloader_class( 51 | args=args, 52 | json_data=data_file, 53 | dynamic_transform=prompt_fn, 54 | shuffle=False, 55 | from_file=True) 56 | 57 | valid_data = valid_dataloader_class( 58 | args=args, 59 | json_data=valid_data_file, 60 | static_transform=prompt_fn, 61 | from_file=True, 62 | shuffle=False) 63 | 64 | # if validation data is not provided, produce pseudo validation data from training data. 65 | else: 66 | if val_set_size: 67 | raw_train_data = load_json(grob_paths(data_file)) 68 | random.seed(args.seed) 69 | random.shuffle(raw_train_data) 70 | 71 | train_data = train_dataloader_class( 72 | args=args, 73 | json_data=raw_train_data[:-val_set_size], 74 | dynamic_transform=prompt_fn, 75 | shuffle=True, 76 | from_file=False) 77 | 78 | valid_data = valid_dataloader_class( 79 | args=args, 80 | json_data=raw_train_data[-val_set_size:], 81 | static_transform=prompt_fn, 82 | from_file=False, 83 | shuffle=False) 84 | else: 85 | train_data = train_dataloader_class( 86 | args=args, 87 | json_data=data_file, 88 | dynamic_transform=prompt_fn, 89 | shuffle=False, 90 | from_file=True) 91 | valid_data = None 92 | 93 | return train_data, valid_data 94 | 95 | 96 | def generate_and_tokenize_prompt( 97 | data_point, 98 | args = None, 99 | tokenizer = None, 100 | prompt_maker = None, 101 | use_prompt_labels = True, 102 | padding: bool = False, 103 | truncation: bool = True, 104 | verbose: bool = True, 105 | ignore_loss_idx: int = -100, 106 | ): 107 | assert prompt_maker is not None, "please provide prompt_maker" 108 | input_text = prompt_maker.get_input(data_point) 109 | target_text = prompt_maker.get_target(data_point) 110 | full_text = input_text + target_text 111 | 112 | user_prompt = tokenizer( 113 | input_text, 114 | truncation=True, 115 | max_length=args.max_length + 1, 116 | )["input_ids"][:-1] # no eos token 117 | 118 | # -------- 119 | user_prompt = tokenizer( 120 | input_text, 121 | truncation=True, 122 | max_length=args.max_length + 1, 123 | )["input_ids"] 124 | if user_prompt[-1]==tokenizer.eos_token_id: 125 | user_prompt=user_prompt[:-1] 126 | else: 127 | # user_prompt=user_prompt 128 | pass 129 | # -------- 130 | len_user_prompt_tokens = len(user_prompt) 131 | len_user_prompt_tokens = min(len_user_prompt_tokens, args.max_length) 132 | 133 | full_tokens = tokenizer( 134 | full_text, 135 | truncation=truncation, 136 | max_length=args.max_length # 137 | )["input_ids"] 138 | # --------- 139 | if 
full_tokens[-1] != tokenizer.eos_token_id: 140 | full_tokens=full_tokens+[tokenizer.eos_token_id] 141 | else: 142 | full_tokens=full_tokens 143 | # --------- 144 | attention_mask = [1] * len(full_tokens) 145 | 146 | if args.use_prompt_loss: 147 | labels = copy.deepcopy(full_tokens) 148 | else: 149 | labels = [ignore_loss_idx] * len_user_prompt_tokens + full_tokens[len_user_prompt_tokens:] 150 | 151 | ## deal with padding 152 | if padding: 153 | padded_length = args.max_length - len(full_tokens) 154 | full_tokens.extend([tokenizer.pad_token_id] * padded_length) 155 | labels.extend([ignore_loss_idx] * padded_length) 156 | attention_mask = attention_mask + [0] * padded_length 157 | 158 | if verbose and (random.random() <= args.prob_data_display): 159 | logger.info(f"""### random data case: 160 | batch length = {len(full_tokens)} 161 | (P) prompt = {[input_text]} 162 | (PT) prompt_and_target = {[full_text]} 163 | (PT) tokenized = {full_tokens} 164 | (PT) attention_mask = {attention_mask} 165 | (PT) labels = {labels} 166 | """) 167 | 168 | ## deal with prompt or not (w.r.t. pretrain and instruction tuning) 169 | if use_prompt_labels: 170 | # This function masks out the labels for the input, 171 | # so that our loss is computed only on the response. 172 | return { 173 | "input_ids": full_tokens, 174 | "attention_mask": attention_mask, 175 | "labels": labels, 176 | } 177 | else: 178 | return { 179 | "input_ids": full_tokens, 180 | "attention_mask": attention_mask, 181 | } 182 | 183 | 184 | def generate_and_tokenize_prompt_with_contrastive_label( 185 | data_point, 186 | args = None, 187 | tokenizer = None, 188 | prompt_maker = None, 189 | use_prompt_labels = True, 190 | padding: bool = False, 191 | truncation: bool = True, 192 | verbose: bool = True, 193 | ignore_loss_idx: int = -100, 194 | path_contrastive_label = None, 195 | ): 196 | assert prompt_maker is not None, "please provide prompt_maker" 197 | assert path_contrastive_label is not None 198 | input_text = prompt_maker.get_input(data_point) 199 | target_text = prompt_maker.get_target(data_point) 200 | full_text = input_text + target_text 201 | # print(prompt_maker) 202 | c_target_text = prompt_maker.get_constrastive_target(data_point, path_contrastive_label) 203 | c_full_text = input_text + c_target_text 204 | 205 | user_prompt = tokenizer( 206 | input_text, 207 | truncation=True, 208 | max_length=args.max_length + 1, 209 | )["input_ids"][:-1] # no eos token 210 | 211 | # -------- 212 | # c_target_text_t = tokenizer( 213 | # c_target_text, 214 | # truncation=True, 215 | # max_length=args.max_length + 1, 216 | # )["input_ids"] 217 | # # print(c_target_text, c_target_text_t) 218 | # if not c_target_text_t: c_target_text_t=[tokenizer.eos_token_id] 219 | # if c_target_text_t[-1] != tokenizer.eos_token_id: 220 | # c_target_text_t=c_target_text_t+[tokenizer.eos_token_id] 221 | # else: 222 | # # c_target_text_t=c_target_text_t 223 | # pass 224 | # -------- 225 | user_prompt = tokenizer( 226 | input_text, 227 | truncation=True, 228 | max_length=args.max_length + 1, 229 | )["input_ids"] 230 | if user_prompt[-1]==tokenizer.eos_token_id: 231 | user_prompt=user_prompt[:-1] 232 | else: 233 | # user_prompt=user_prompt 234 | pass 235 | #--------- 236 | len_user_prompt_tokens = len(user_prompt) 237 | # len_c_target_text_tokens = len(c_target_text_t) 238 | len_user_prompt_tokens = min(len_user_prompt_tokens, args.max_length) 239 | 240 | full_tokens = tokenizer( 241 | full_text, 242 | truncation=truncation, 243 | max_length=args.max_length # 244 | 
)["input_ids"] 245 | # --------- 246 | if full_tokens[-1] != tokenizer.eos_token_id: 247 | full_tokens=full_tokens+[tokenizer.eos_token_id] 248 | else: 249 | full_tokens=full_tokens 250 | # --------- 251 | attention_mask = [1] * len(full_tokens) 252 | 253 | if args.use_prompt_loss: 254 | labels = copy.deepcopy(full_tokens) 255 | else: 256 | labels = [ignore_loss_idx] * len_user_prompt_tokens + full_tokens[len_user_prompt_tokens:] 257 | 258 | ## deal with padding 259 | if padding: 260 | padded_length = args.max_length - len(full_tokens) 261 | full_tokens.extend([tokenizer.pad_token_id] * padded_length) 262 | labels.extend([ignore_loss_idx] * padded_length) 263 | attention_mask = attention_mask + [0] * padded_length 264 | 265 | # if verbose and (random.random() <= args.prob_data_display): 266 | # logger.info(f"""### random data case: 267 | # batch length = {len(full_tokens)} 268 | # (P) prompt = {[input_text]} 269 | # (PT) prompt_and_target = {[full_text]} 270 | # (PT) tokenized = {full_tokens} 271 | # (PT) attention_mask = {attention_mask} 272 | # (PT) labels = {labels} 273 | # """) 274 | 275 | 276 | #-----------deal with c_labels------- 277 | 278 | c_full_tokens = tokenizer( 279 | c_full_text, 280 | truncation=truncation, 281 | max_length=args.max_length # 282 | )["input_ids"] 283 | if len(c_full_tokens)==args.max_length: 284 | pass 285 | else: 286 | if c_full_tokens[-1] != tokenizer.eos_token_id: 287 | c_full_tokens=c_full_tokens+[tokenizer.eos_token_id] 288 | else: 289 | pass 290 | # --------- 291 | # if c_full_tokens[-1] != tokenizer.eos_token_id: 292 | # c_full_tokens=c_full_tokens+[tokenizer.eos_token_id] 293 | # else: 294 | # c_full_tokens=c_full_tokens 295 | # --------- 296 | # attention_mask = [1] * len(c_full_tokens) 297 | 298 | if args.use_prompt_loss: 299 | c_labels = copy.deepcopy(c_full_tokens) 300 | else: 301 | # c_labels = [ignore_loss_idx] * (len(full_tokens)-len_c_target_text_tokens) + c_full_tokens[len_user_prompt_tokens:] 302 | c_labels = [ignore_loss_idx] * len_user_prompt_tokens + c_full_tokens[len_user_prompt_tokens:] 303 | 304 | ## deal with padding 305 | if padding: 306 | padded_length = args.max_length - len(c_full_tokens) 307 | c_full_tokens.extend([tokenizer.pad_token_id] * padded_length) 308 | c_labels.extend([ignore_loss_idx] * padded_length) 309 | # attention_mask = attention_mask + [0] * padded_length 310 | 311 | if verbose and (random.random() <= args.prob_data_display): 312 | logger.info(f"""### random data case: 313 | batch length = {len(full_tokens)} 314 | (P) prompt = {[input_text]} 315 | (PT) prompt_and_target = {[full_text]} 316 | (PT) cons_prompt_and_target = {[c_full_text]} 317 | (PT) tokenized = {full_tokens} 318 | (PT) attention_mask = {attention_mask} 319 | (PT) labels = {labels} 320 | (PT) contrastive labels = {c_labels} 321 | """) 322 | 323 | #------------------------------------ 324 | 325 | ## deal with prompt or not (w.r.t. pretrain and instruction tuning) 326 | if use_prompt_labels: 327 | # This function masks out the labels for the input, 328 | # so that our loss is computed only on the response. 
329 | return { 330 | "input_ids": full_tokens, 331 | "attention_mask": attention_mask, 332 | "labels": labels, 333 | "c_labels": c_labels, 334 | } 335 | else: 336 | return { 337 | "input_ids": full_tokens, 338 | "attention_mask": attention_mask, 339 | } 340 | 341 | if __name__ == '__main__': 342 | from ..utils.prompt_maker.custum_prompt_maker import PromptMaker 343 | from transformers import LlamaTokenizer 344 | from functools import partial 345 | 346 | class temp_args(): 347 | max_length = 256 348 | model_id = "/mnt/bn/multilingual-translation/public/hf_models/llama-7b-hf/" 349 | data_file = "/opt/tiger/llama/finetune/alpaca-lora/codes/data/alpaca_data_cleaned.json" 350 | 351 | args = temp_args() 352 | tokenizer = LlamaTokenizer.from_pretrained(args.model_id, add_eos_token=True) 353 | tokenizer.pad_token_id = 0 # unk. we want this to be different from the eos token 354 | 355 | 356 | train_data = DynamicPromptDataset( 357 | json_data=args.data_file, 358 | dynamic_transform=partial(generate_and_tokenize_prompt, args=args, tokenizer=tokenizer, prompt_maker=PromptMaker()), 359 | shuffle=True, 360 | from_file=True 361 | ) 362 | 363 | for data in train_data: 364 | print(data) 365 | -------------------------------------------------------------------------------- /nngeometry/layercollection.py: -------------------------------------------------------------------------------- 1 | from abc import ABC 2 | from collections import OrderedDict 3 | from functools import reduce 4 | import operator 5 | 6 | 7 | class LayerCollection: 8 | """ 9 | This class describes a set or subset of layers, that can be used 10 | in order to instantiate :class:`nngeometry.object.PVector` or 11 | :class:`nngeometry.object.PSpaceDense` objects 12 | 13 | :param layers: 14 | """ 15 | 16 | _known_modules = ['Linear', 'Conv2d', 'BatchNorm1d', 17 | 'BatchNorm2d', 'GroupNorm', 'WeightNorm1d', 18 | 'WeightNorm2d', 'Cosine1d', 'Affine1d', 19 | 'ConvTranspose2d'] 20 | 21 | def __init__(self, layers=None): 22 | if layers is None: 23 | self.layers = OrderedDict() 24 | self._numel = 0 25 | self.p_pos = dict() 26 | else: 27 | self.layers = layers 28 | raise NotImplementedError 29 | 30 | def from_model(model, ignore_unsupported_layers=False): 31 | """ 32 | Constructs a new LayerCollection object by using all parameters 33 | of the model passed as argument. 34 | 35 | :param model: The PyTorch model 36 | :type model: `nn.Module` 37 | :param ignore_unsupported_layers: If false, will raise an error 38 | when model contains layers that are not supported yet. 
If true, will 39 | silently ignore the layer 40 | :type ignore_unsupported_layers: bool 41 | """ 42 | lc = LayerCollection() 43 | for layer, mod in model.named_modules(): 44 | mod_class = mod.__class__.__name__ 45 | if mod_class in LayerCollection._known_modules: 46 | lc.add_layer('%s.%s' % (layer, str(mod)), 47 | LayerCollection._module_to_layer(mod)) 48 | elif not ignore_unsupported_layers: 49 | if len(list(mod.children())) == 0 and len(list(mod.parameters())) > 0: 50 | raise Exception('I do not know what to do with layer ' + str(mod)) 51 | 52 | return lc 53 | 54 | def get_layerid_module_maps(self, model): 55 | layerid_to_module = OrderedDict() 56 | module_to_layerid = OrderedDict() 57 | named_modules = {'%s.%s' % (l, str(m)): m 58 | for l, m in model.named_modules()} 59 | # print(f'name: {named_modules}') 60 | # print(f'key: {self.layers.keys()}') 61 | for layer in self.layers.keys(): 62 | # TODO 63 | # FSDP case 64 | # layer='_fsdp_wrapped_module._fpw_module.'+layer 65 | # layerid_to_module[layer] = named_modules['_fsdp_wrapped_module._fpw_module.'+layer] 66 | layerid_to_module[layer] = named_modules[layer] 67 | # module_to_layerid[named_modules['_fsdp_wrapped_module._fpw_module.'+layer]] = layer 68 | module_to_layerid[named_modules[layer]] = layer 69 | return layerid_to_module, module_to_layerid 70 | 71 | def add_layer(self, name, layer): 72 | self.layers[name] = layer 73 | self.p_pos[name] = self._numel 74 | self._numel += layer.numel() 75 | 76 | def add_layer_from_model(self, model, module): 77 | """ 78 | Add a layer by specifying the module corresponding 79 | to this layer (e.g. torch.nn.Linear or torch.nn.BatchNorm1d) 80 | 81 | :param model: The model defining the neural network 82 | :param module: The layer to be added 83 | """ 84 | if module.__class__.__name__ not in LayerCollection._known_modules: 85 | raise NotImplementedError 86 | for layer, mod in model.named_modules(): 87 | if mod is module: 88 | self.add_layer('%s.%s' % (layer, str(mod)), 89 | LayerCollection._module_to_layer(mod)) 90 | 91 | def _module_to_layer(mod): 92 | mod_class = mod.__class__.__name__ 93 | if mod_class == 'Linear': 94 | return LinearLayer(in_features=mod.in_features, 95 | out_features=mod.out_features, 96 | bias=(mod.bias is not None)) 97 | elif mod_class == 'Conv2d': 98 | return Conv2dLayer(in_channels=mod.in_channels, 99 | out_channels=mod.out_channels, 100 | kernel_size=mod.kernel_size, 101 | bias=(mod.bias is not None)) 102 | elif mod_class == 'ConvTranspose2d': 103 | return ConvTranspose2dLayer(in_channels=mod.in_channels, 104 | out_channels=mod.out_channels, 105 | kernel_size=mod.kernel_size, 106 | bias=(mod.bias is not None)) 107 | elif mod_class == 'BatchNorm1d': 108 | return BatchNorm1dLayer(num_features=mod.num_features) 109 | elif mod_class == 'BatchNorm2d': 110 | return BatchNorm2dLayer(num_features=mod.num_features) 111 | elif mod_class == 'GroupNorm': 112 | return GroupNormLayer(num_groups=mod.num_groups, 113 | num_channels=mod.num_channels) 114 | elif mod_class == 'WeightNorm1d': 115 | return WeightNorm1dLayer(in_features=mod.in_features, 116 | out_features=mod.out_features) 117 | elif mod_class == 'WeightNorm2d': 118 | return WeightNorm2dLayer(in_channels=mod.in_channels, 119 | out_channels=mod.out_channels, 120 | kernel_size=mod.kernel_size) 121 | elif mod_class == 'Cosine1d': 122 | return Cosine1dLayer(in_features=mod.in_features, 123 | out_features=mod.out_features) 124 | elif mod_class == 'Affine1d': 125 | return Affine1dLayer(num_features=mod.num_features, 126 | bias=(mod.bias is 
not None)) 127 | 128 | def numel(self): 129 | """ 130 | Total number of scalar parameters in this LayerCollection object 131 | 132 | :return: number of scalar parameters 133 | :rtype: int 134 | """ 135 | return self._numel 136 | 137 | def __getitem__(self, layer_id): 138 | return self.layers[layer_id] 139 | 140 | def parameters(self, layerid_to_module): 141 | for layer_id, layer in self.layers.items(): 142 | yield layerid_to_module[layer_id].weight 143 | if (isinstance(layer, BatchNorm1dLayer) or 144 | isinstance(layer, BatchNorm2dLayer)): 145 | yield layerid_to_module[layer_id].bias 146 | # otherwise it is a Linear or Conv2d with optional bias 147 | elif layer.bias: 148 | yield layerid_to_module[layer_id].bias 149 | 150 | def __eq__(self, other): 151 | for layer_id in set(self.layers.keys()).union(set(other.layers.keys())): 152 | if (layer_id not in other.layers.keys() 153 | or layer_id not in self.layers.keys() 154 | or self.layers[layer_id] != other.layers[layer_id]): 155 | return False 156 | return True 157 | 158 | 159 | class AbstractLayer(ABC): 160 | pass 161 | 162 | 163 | class Conv2dLayer(AbstractLayer): 164 | 165 | def __init__(self, in_channels, out_channels, kernel_size, bias=True): 166 | self.in_channels = in_channels 167 | self.out_channels = out_channels 168 | self.kernel_size = kernel_size 169 | self.weight = Parameter(out_channels, in_channels, kernel_size[0], 170 | kernel_size[1]) 171 | if bias: 172 | self.bias = Parameter(out_channels) 173 | else: 174 | self.bias = None 175 | 176 | def numel(self): 177 | if self.bias is not None: 178 | return self.weight.numel() + self.bias.numel() 179 | else: 180 | return self.weight.numel() 181 | 182 | def __eq__(self, other): 183 | return (self.in_channels == other.in_channels and 184 | self.out_channels == other.out_channels and 185 | self.kernel_size == other.kernel_size) 186 | 187 | 188 | class ConvTranspose2dLayer(AbstractLayer): 189 | 190 | def __init__(self, in_channels, out_channels, kernel_size, bias=True): 191 | self.in_channels = in_channels 192 | self.out_channels = out_channels 193 | self.kernel_size = kernel_size 194 | self.weight = Parameter(out_channels, in_channels, kernel_size[0], 195 | kernel_size[1]) 196 | if bias: 197 | self.bias = Parameter(out_channels) 198 | else: 199 | self.bias = None 200 | 201 | def numel(self): 202 | if self.bias is not None: 203 | return self.weight.numel() + self.bias.numel() 204 | else: 205 | return self.weight.numel() 206 | 207 | def __eq__(self, other): 208 | return (self.in_channels == other.in_channels and 209 | self.out_channels == other.out_channels and 210 | self.kernel_size == other.kernel_size) 211 | 212 | 213 | class LinearLayer(AbstractLayer): 214 | 215 | def __init__(self, in_features, out_features, bias=True): 216 | self.in_features = in_features 217 | self.out_features = out_features 218 | self.weight = Parameter(out_features, in_features) 219 | if bias: 220 | self.bias = Parameter(out_features) 221 | else: 222 | self.bias = None 223 | 224 | def numel(self): 225 | if self.bias is not None: 226 | return self.weight.numel() + self.bias.numel() 227 | else: 228 | return self.weight.numel() 229 | 230 | def __eq__(self, other): 231 | return (self.in_features == other.in_features and 232 | self.out_features == other.out_features) 233 | 234 | 235 | class BatchNorm1dLayer(AbstractLayer): 236 | 237 | def __init__(self, num_features): 238 | self.num_features = num_features 239 | self.weight = Parameter(num_features) 240 | self.bias = Parameter(num_features) 241 | 242 | def numel(self): 
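        """Total number of scalar parameters in this layer (weight + bias)."""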
243 | return self.weight.numel() + self.bias.numel() 244 | 245 | def __eq__(self, other): 246 | return self.num_features == other.num_features 247 | 248 | 249 | class BatchNorm2dLayer(AbstractLayer): 250 | 251 | def __init__(self, num_features): 252 | self.num_features = num_features 253 | self.weight = Parameter(num_features) 254 | self.bias = Parameter(num_features) 255 | 256 | def numel(self): 257 | return self.weight.numel() + self.bias.numel() 258 | 259 | def __eq__(self, other): 260 | return self.num_features == other.num_features 261 | 262 | 263 | class GroupNormLayer(AbstractLayer): 264 | 265 | def __init__(self, num_groups, num_channels): 266 | self.num_channels = num_channels 267 | self.weight = Parameter(num_channels) 268 | self.bias = Parameter(num_channels) 269 | 270 | def numel(self): 271 | return self.weight.numel() + self.bias.numel() 272 | 273 | def __eq__(self, other): 274 | return self.num_channels == other.num_channels 275 | 276 | 277 | class WeightNorm1dLayer(AbstractLayer): 278 | 279 | def __init__(self, in_features, out_features): 280 | self.in_features = in_features 281 | self.out_features = out_features 282 | self.weight = Parameter(out_features, in_features) 283 | self.bias = None 284 | 285 | def numel(self): 286 | return self.weight.numel() 287 | 288 | def __eq__(self, other): 289 | return (self.in_features == other.in_features and 290 | self.out_features == other.out_features) 291 | 292 | 293 | class WeightNorm2dLayer(AbstractLayer): 294 | 295 | def __init__(self, in_channels, out_channels, kernel_size): 296 | self.in_channels = in_channels 297 | self.out_channels = out_channels 298 | self.kernel_size = kernel_size 299 | self.weight = Parameter(out_channels, in_channels, kernel_size[0], 300 | kernel_size[1]) 301 | self.bias = None 302 | 303 | def numel(self): 304 | return self.weight.numel() 305 | 306 | def __eq__(self, other): 307 | return (self.in_channels == other.in_channels and 308 | self.out_channels == other.out_channels and 309 | self.kernel_size == other.kernel_size) 310 | 311 | 312 | class Cosine1dLayer(AbstractLayer): 313 | 314 | def __init__(self, in_features, out_features): 315 | self.in_features = in_features 316 | self.out_features = out_features 317 | self.weight = Parameter(out_features, in_features) 318 | self.bias = None 319 | 320 | def numel(self): 321 | return self.weight.numel() 322 | 323 | def __eq__(self, other): 324 | return (self.in_features == other.in_features and 325 | self.out_features == other.out_features) 326 | 327 | 328 | class Affine1dLayer(AbstractLayer): 329 | 330 | def __init__(self, num_features, bias=True): 331 | self.num_features = num_features 332 | self.weight = Parameter(num_features) 333 | if bias: 334 | self.bias = Parameter(num_features) 335 | else: 336 | self.bias = None 337 | 338 | def numel(self): 339 | if self.bias is not None: 340 | return self.weight.numel() + self.bias.numel() 341 | else: 342 | return self.weight.numel() 343 | 344 | def __eq__(self, other): 345 | return self.num_features == other.num_features 346 | 347 | 348 | class Parameter(object): 349 | 350 | def __init__(self, *size): 351 | self.size = size 352 | 353 | def numel(self): 354 | return reduce(operator.mul, self.size, 1) 355 | -------------------------------------------------------------------------------- /nngeometry/object/vector.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from ..layercollection import LayerCollection 3 | 4 | 5 | def random_pvector_dict(layer_collection, 
device=None): 6 | """ 7 | Returns a random :class:`nngeometry.object.PVector` object using 8 | the structure defined by the `layer_collection` parameter, with 9 | each components drawn from a normal distribution with mean 0 and standard 10 | deviation 1. 11 | 12 | The returned `PVector` will internally use a dict representation. 13 | 14 | :param layer_collection: The :class:`nngeometry.layercollection.LayerCollection` 15 | describing the structure of the random pvector 16 | """ 17 | v_dict = dict() 18 | for layer_id, layer in layer_collection.layers.items(): 19 | if layer.bias is not None: 20 | v_dict[layer_id] = (torch.normal(0, 1, layer.weight.size, device=device), 21 | torch.normal(0, 1, layer.bias.size, device=device)) 22 | else: 23 | v_dict[layer_id] = (torch.normal(0, 1, layer.weight.size, device=device),) 24 | return PVector(layer_collection, dict_repr=v_dict) 25 | 26 | 27 | def random_pvector(layer_collection, device=None): 28 | """ 29 | Returns a random :class:`nngeometry.object.PVector` object using 30 | the structure defined by the `layer_collection` parameter, with 31 | each components drawn from a normal distribution with mean 0 and standard 32 | deviation 1. 33 | 34 | The returned `PVector` will internally use a flat representation. 35 | 36 | :param layer_collection: The :class:`nngeometry.layercollection.LayerCollection` 37 | describing the structure of the random pvector 38 | """ 39 | n_parameters = layer_collection.numel() 40 | random_v_flat = torch.normal(0, 1, (n_parameters,), 41 | device=device) 42 | return PVector(layer_collection=layer_collection, 43 | vector_repr=random_v_flat) 44 | 45 | 46 | def random_fvector(n_samples, n_output=1, device=None): 47 | random_v_flat = torch.normal(0, 1, (n_output, n_samples,), 48 | device=device) 49 | return FVector(vector_repr=random_v_flat) 50 | 51 | 52 | class PVector: 53 | """ 54 | A vector in parameter space 55 | 56 | :param: 57 | """ 58 | def __init__(self, layer_collection, vector_repr=None, 59 | dict_repr=None): 60 | self.layer_collection = layer_collection 61 | self.vector_repr = vector_repr 62 | self.dict_repr = dict_repr 63 | 64 | @staticmethod 65 | def from_model(model): 66 | """ 67 | Creates a PVector using the current values of the given 68 | model 69 | """ 70 | dict_repr = dict() 71 | layer_collection = LayerCollection.from_model(model) 72 | l_to_m, _ = layer_collection.get_layerid_module_maps(model) 73 | for layer_id, layer in layer_collection.layers.items(): 74 | mod = l_to_m[layer_id] 75 | if layer.bias is not None: 76 | dict_repr[layer_id] = (mod.weight, mod.bias) 77 | else: 78 | dict_repr[layer_id] = (mod.weight,) 79 | return PVector(layer_collection, dict_repr=dict_repr) 80 | 81 | def copy_to_model(self, model): 82 | """ 83 | Updates `model` parameter values with the current PVector 84 | 85 | Note. This is an inplace operation 86 | """ 87 | dict_repr = self.get_dict_representation() 88 | layer_collection = LayerCollection.from_model(model) 89 | l_to_m, _ = layer_collection.get_layerid_module_maps(model) 90 | for layer_id, layer in layer_collection.layers.items(): 91 | mod = l_to_m[layer_id] 92 | if layer.bias is not None: 93 | mod.bias.data.copy_(dict_repr[layer_id][1]) 94 | mod.weight.data.copy_(dict_repr[layer_id][0]) 95 | 96 | def add_to_model(self, model): 97 | """ 98 | Updates `model` parameter values by adding the current PVector 99 | 100 | Note. 
This is an inplace operation
101 |         """
102 |         dict_repr = self.get_dict_representation()
103 |         layer_collection = LayerCollection.from_model(model)
104 |         l_to_m, _ = layer_collection.get_layerid_module_maps(model)
105 |         for layer_id, layer in layer_collection.layers.items():
106 |             mod = l_to_m[layer_id]
107 |             if layer.bias is not None:
108 |                 mod.bias.data.add_(dict_repr[layer_id][1])
109 |             mod.weight.data.add_(dict_repr[layer_id][0])
110 | 
111 |     @staticmethod
112 |     def from_model_grad(model):
113 |         """
114 |         Creates a PVector using the current values of the `.grad`
115 |         fields of parameters of the given model
116 |         """
117 |         dict_repr = dict()
118 |         layer_collection = LayerCollection.from_model(model)
119 |         l_to_m, _ = layer_collection.get_layerid_module_maps(model)
120 |         for layer_id, layer in layer_collection.layers.items():
121 |             mod = l_to_m[layer_id]
122 |             if layer.bias is not None:
123 |                 dict_repr[layer_id] = (mod.weight.grad, mod.bias.grad)
124 |             else:
125 |                 dict_repr[layer_id] = (mod.weight.grad,)
126 |         return PVector(layer_collection, dict_repr=dict_repr)
127 | 
128 |     def clone(self):
129 |         """
130 |         Returns a clone of the current object
131 |         """
132 |         if self.dict_repr is not None:
133 |             dict_clone = dict()
134 |             for k, v in self.dict_repr.items():
135 |                 if len(v) == 2:
136 |                     dict_clone[k] = (v[0].clone(), v[1].clone())
137 |                 else:
138 |                     dict_clone[k] = (v[0].clone(),)
139 |             return PVector(self.layer_collection, dict_repr=dict_clone)
140 |         if self.vector_repr is not None:
141 |             return PVector(self.layer_collection,
142 |                            vector_repr=self.vector_repr.clone())
143 | 
144 |     def detach(self):
145 |         """
146 |         Detaches the current PVector from the computation graph
147 |         """
148 |         if self.dict_repr is not None:
149 |             dict_detach = dict()
150 |             for k, v in self.dict_repr.items():
151 |                 if len(v) == 2:
152 |                     dict_detach[k] = (v[0].detach(), v[1].detach())
153 |                 else:
154 |                     dict_detach[k] = (v[0].detach(),)
155 |             return PVector(self.layer_collection, dict_repr=dict_detach)
156 |         if self.vector_repr is not None:
157 |             return PVector(self.layer_collection,
158 |                            vector_repr=self.vector_repr.detach())
159 | 
160 |     def get_flat_representation(self):
161 |         """
162 |         Returns a PyTorch 1d tensor of the flattened vector.
163 | 
164 |         .. warning::
165 |             The ordering in which the parameters are
166 |             flattened can seem to be arbitrary. It is in fact
167 |             the same ordering as specified by the ``layercollection.LayerCollection``
168 |             object.
169 | 
170 |         :return: a PyTorch Tensor
171 |         """
172 |         if self.vector_repr is not None:
173 |             return self.vector_repr
174 |         elif self.dict_repr is not None:
175 |             return self._dict_to_flat()
176 |         else:
177 |             raise NotImplementedError
178 | 
179 |     def get_dict_representation(self):
180 |         if self.dict_repr is not None:
181 |             return self.dict_repr
182 |         elif self.vector_repr is not None:
183 |             return self._flat_to_dict()
184 |         else:
185 |             raise NotImplementedError
186 | 
187 |     def _dict_to_flat(self):
188 |         parts = []
189 |         for layer_id, layer in self.layer_collection.layers.items():
190 |             parts.append(self.dict_repr[layer_id][0].view(-1))
191 |             if len(self.dict_repr[layer_id]) > 1:
192 |                 parts.append(self.dict_repr[layer_id][1].view(-1))
193 |         return torch.cat(parts)
194 | 
195 |     def _flat_to_dict(self):
196 |         dict_repr = dict()
197 |         for layer_id, layer in self.layer_collection.layers.items():
198 |             start = self.layer_collection.p_pos[layer_id]
199 |             w = self.vector_repr[start:start+layer.weight.numel()] \
200 |                 .view(*layer.weight.size)
201 |             start += layer.weight.numel()
202 |             if layer.bias is not None:
203 |                 b = self.vector_repr[start:start+layer.bias.numel()] \
204 |                     .view(*layer.bias.size)
205 |                 start += layer.bias.numel()
206 |                 dict_repr[layer_id] = (w, b)
207 |             else:
208 |                 dict_repr[layer_id] = (w,)
209 |         return dict_repr
210 | 
211 |     def norm(self, p=2):
212 |         """
213 |         Computes the Lp norm of the PVector
214 |         """
215 |         if self.dict_repr is not None:
216 |             sum_p = 0
217 |             for l_id, l in self.layer_collection.layers.items():
218 |                 sum_p += (self.dict_repr[l_id][0]**p).sum()
219 |                 if l.bias is not None:
220 |                     sum_p += (self.dict_repr[l_id][1]**p).sum()
221 |             return sum_p ** (1/p)
222 |         else:
223 |             return torch.norm(self.vector_repr, p=p)
224 | 
225 |     def __rmul__(self, x):
226 |         # TODO: test
227 |         # scalar multiplication
228 |         if self.dict_repr is not None:
229 |             v_dict = dict()
230 |             for l_id, l in self.layer_collection.layers.items():
231 |                 if l.bias is not None:
232 |                     v_dict[l_id] = (x * self.dict_repr[l_id][0],
233 |                                     x * self.dict_repr[l_id][1])
234 |                 else:
235 |                     v_dict[l_id] = (x * self.dict_repr[l_id][0],)
236 |             return PVector(self.layer_collection, dict_repr=v_dict)
237 |         else:
238 |             return PVector(self.layer_collection,
239 |                            vector_repr=x * self.vector_repr)
240 | 
241 |     def __add__(self, other):
242 |         if self.dict_repr is not None and other.dict_repr is not None:
243 |             v_dict = dict()
244 |             for l_id, l in self.layer_collection.layers.items():
245 |                 if l.bias is not None:
246 |                     v_dict[l_id] = (self.dict_repr[l_id][0] +
247 |                                     other.dict_repr[l_id][0],
248 |                                     self.dict_repr[l_id][1] +
249 |                                     other.dict_repr[l_id][1])
250 |                 else:
251 |                     v_dict[l_id] = (self.dict_repr[l_id][0] +
252 |                                     other.dict_repr[l_id][0],)
253 |             return PVector(self.layer_collection, dict_repr=v_dict)
254 |         elif self.vector_repr is not None and other.vector_repr is not None:
255 |             return PVector(self.layer_collection,
256 |                            vector_repr=self.vector_repr+other.vector_repr)
257 |         else:
258 |             return PVector(self.layer_collection,
259 |                            vector_repr=(self.get_flat_representation() +
260 |                                         other.get_flat_representation()))
261 | 
262 |     def __sub__(self, other):
263 |         if self.dict_repr is not None and other.dict_repr is not None:
264 |             v_dict = dict()
265 |             for l_id, l in self.layer_collection.layers.items():
266 |                 if l.bias is not None:
267 |                     v_dict[l_id] = (self.dict_repr[l_id][0] -
268 |                                     other.dict_repr[l_id][0],
269 |                                     self.dict_repr[l_id][1] -
270 |                                     other.dict_repr[l_id][1])
271 |                 else:
272 |                     v_dict[l_id] = (self.dict_repr[l_id][0]
- 273 | other.dict_repr[l_id][0],) 274 | return PVector(self.layer_collection, dict_repr=v_dict) 275 | elif self.vector_repr is not None and other.vector_repr is not None: 276 | return PVector(self.layer_collection, 277 | vector_repr=self.vector_repr-other.vector_repr) 278 | else: 279 | return PVector(self.layer_collection, 280 | vector_repr=(self.get_flat_representation() - 281 | other.get_flat_representation())) 282 | 283 | def dot(self, other): 284 | """ 285 | Computes the dot product between `self` and `other` 286 | 287 | :param other: The other `PVector` 288 | """ 289 | if self.vector_repr is not None or other.vector_repr is not None: 290 | return torch.dot(self.get_flat_representation(), 291 | other.get_flat_representation()) 292 | else: 293 | dot_ = 0 294 | for l_id, l in self.layer_collection.layers.items(): 295 | if l.bias is not None: 296 | dot_ += torch.dot(self.dict_repr[l_id][1], 297 | other.dict_repr[l_id][1]) 298 | dot_ += torch.dot(self.dict_repr[l_id][0].view(-1), 299 | other.dict_repr[l_id][0].view(-1)) 300 | return dot_ 301 | 302 | def size(self): 303 | """ 304 | The size of the PVector, or equivalently the number of 305 | parameters of the layer collection 306 | """ 307 | return (self.layer_collection.numel(), ) 308 | 309 | 310 | class FVector: 311 | """ 312 | A vector in function space 313 | """ 314 | def __init__(self, vector_repr=None): 315 | self.vector_repr = vector_repr 316 | 317 | def get_flat_representation(self): 318 | if self.vector_repr is not None: 319 | return self.vector_repr 320 | else: 321 | return NotImplementedError 322 | -------------------------------------------------------------------------------- /nngeometry/generator/jacobian/grads.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from nngeometry.layercollection import (Affine1dLayer, Cosine1dLayer, LinearLayer, Conv2dLayer, BatchNorm1dLayer, 3 | BatchNorm2dLayer, GroupNormLayer, WeightNorm1dLayer, 4 | WeightNorm2dLayer, ConvTranspose2dLayer) 5 | from .grads_conv import conv2d_backward, convtranspose2d_backward, unfold_transpose_conv2d 6 | 7 | import torch.nn.functional as F 8 | 9 | 10 | class JacobianFactory: 11 | @classmethod 12 | def diag(cls, buffer, mod, layer, x, gy): 13 | bs = x.size(0) 14 | buffer_flat = torch.zeros(bs, layer.numel(), device=buffer.device) 15 | cls.flat_grad(buffer_flat, mod, layer, x, gy) 16 | buffer.add_((buffer_flat**2).sum(dim=0)) 17 | 18 | @classmethod 19 | def trace(cls, buffer, mod, layer, x, gy): 20 | buffer_diag = torch.zeros(layer.numel(), device=buffer.device) 21 | cls.diag(buffer_diag, mod, layer, x, gy) 22 | buffer.add_(buffer_diag.sum()) 23 | 24 | @classmethod 25 | def layer_block(cls, buffer, mod, layer, x, gy): 26 | bs = x.size(0) 27 | buffer_flat = torch.zeros(bs, layer.numel(), device=buffer.device) 28 | cls.flat_grad(buffer_flat, mod, layer, x, gy) 29 | buffer.add_(torch.mm(buffer_flat.t(), buffer_flat)) 30 | 31 | @classmethod 32 | def kxy(cls, buffer, mod, layer, x_i, gy_i, x_o, gy_o): 33 | bs_i = x_i.size(0) 34 | bs_o = x_o.size(0) 35 | buffer_flat_i = torch.zeros(bs_i, layer.numel(), device=buffer.device) 36 | buffer_flat_o = torch.zeros(bs_o, layer.numel(), device=buffer.device) 37 | cls.flat_grad(buffer_flat_i, mod, layer, x_i, gy_i) 38 | cls.flat_grad(buffer_flat_o, mod, layer, x_o, gy_o) 39 | buffer.add_(torch.mm(buffer_flat_i, buffer_flat_o.t())) 40 | 41 | @classmethod 42 | def Jv(cls, buffer, mod, layer, x, gy, v, v_bias): 43 | bs = x.size(0) 44 | buffer_flat = torch.zeros(bs, 
layer.numel(), device=buffer.device) 45 | cls.flat_grad(buffer_flat, mod, layer, x, gy) 46 | v = v.view(-1) 47 | if v_bias is not None: 48 | v = torch.cat((v, v_bias)) 49 | buffer.add_(torch.mv(buffer_flat, v)) 50 | 51 | 52 | class LinearJacobianFactory(JacobianFactory): 53 | @classmethod 54 | def flat_grad(cls, buffer, mod, layer, x, gy): 55 | bs = x.size(0) 56 | w_numel = layer.weight.numel() 57 | buffer[:, :w_numel] \ 58 | .add_(torch.bmm(gy.unsqueeze(2), x.unsqueeze(1)).view(bs, -1)) 59 | if layer.bias is not None: 60 | buffer[:, w_numel:].add_(gy) 61 | 62 | @classmethod 63 | def diag(cls, buffer, mod, layer, x, gy): 64 | w_numel = layer.weight.numel() 65 | buffer[:w_numel].add_(torch.mm(gy.t()**2, x**2).view(-1)) 66 | if layer.bias is not None: 67 | buffer[w_numel:].add_((gy**2).sum(dim=0)) 68 | 69 | @classmethod 70 | def kxy(cls, buffer, mod, layer, x_i, gy_i, x_o, gy_o): 71 | buffer.add_(torch.mm(x_i, x_o.t()) * 72 | torch.mm(gy_i, gy_o.t())) 73 | if layer.bias is not None: 74 | buffer.add_(torch.mm(gy_i, gy_o.t())) 75 | 76 | @classmethod 77 | def Jv(cls, buffer, mod, layer, x, gy, v, v_bias): 78 | buffer.add_((torch.mm(x, v.t()) * gy).sum(dim=1)) 79 | if layer.bias is not None: 80 | buffer.add_(torch.mv(gy.contiguous(), v_bias)) 81 | 82 | @classmethod 83 | def trace(cls, buffer, mod, layer, x, gy): 84 | buffer.add_(torch.mm(gy.t()**2, x**2).sum()) 85 | if layer.bias is not None: 86 | buffer.add_((gy**2).sum()) 87 | 88 | @classmethod 89 | def kfac_xx(cls, buffer, mod, layer, x, gy): 90 | if layer.bias is not None: 91 | x = torch.cat([x, torch.ones_like(x[:, :1])], dim=1) 92 | buffer.add_(torch.mm(x.t(), x)) 93 | 94 | @classmethod 95 | def kfac_gg(cls, buffer, mod, layer, x, gy): 96 | # print(f'mod: {mod}, layer: {layer}') 97 | # print(f'gy: {gy.shape}') 98 | # print(f'x: {x.shape}') 99 | buffer.add_(torch.mm(gy.t(), gy)) 100 | 101 | @classmethod 102 | def kfe_diag(cls, buffer, mod, layer, x, gy, evecs_a, evecs_g): 103 | if layer.bias is not None: 104 | x = torch.cat([x, torch.ones_like(x[:, :1])], dim=1) 105 | gy_kfe = torch.mm(gy, evecs_g) 106 | x_kfe = torch.mm(x, evecs_a) 107 | buffer.add_(torch.mm(gy_kfe.t()**2, x_kfe**2).view(-1)) 108 | 109 | @classmethod 110 | def quasidiag(cls, buffer_diag, buffer_cross, mod, layer, x, gy): 111 | w_numel = layer.weight.numel() 112 | buffer_diag[:w_numel].add_(torch.mm(gy.t()**2, x**2).view(-1)) 113 | if layer.bias is not None: 114 | buffer_diag[w_numel:].add_((gy**2).sum(dim=0)) 115 | buffer_cross.add_(torch.mm(gy.t()**2, x)) 116 | 117 | 118 | class Conv2dJacobianFactory(JacobianFactory): 119 | @classmethod 120 | def flat_grad(cls, buffer, mod, layer, x, gy): 121 | bs = x.size(0) 122 | w_numel = layer.weight.numel() 123 | indiv_gw = conv2d_backward(mod, x, gy) 124 | buffer[:, :w_numel].add_(indiv_gw.view(bs, -1)) 125 | if layer.bias is not None: 126 | buffer[:, w_numel:].add_(gy.sum(dim=(2, 3))) 127 | 128 | @classmethod 129 | def Jv(cls, buffer, mod, layer, x, gy, v, v_bias): 130 | bs = x.size(0) 131 | gy2 = F.conv2d(x, v, stride=mod.stride, 132 | padding=mod.padding, dilation=mod.dilation) 133 | buffer.add_((gy * gy2).view(bs, -1).sum(dim=1)) 134 | if layer.bias is not None: 135 | buffer.add_(torch.mv(gy.sum(dim=(2, 3)), v_bias)) 136 | 137 | @classmethod 138 | def kfac_xx(cls, buffer, mod, layer, x, gy): 139 | ks = (mod.weight.size(2), mod.weight.size(3)) 140 | # A_tilda in KFC 141 | A_tilda = F.unfold(x, kernel_size=ks, stride=mod.stride, 142 | padding=mod.padding, dilation=mod.dilation) 143 | # A_tilda is bs * #locations x #parameters 144 
| A_tilda = A_tilda.permute(0, 2, 1).contiguous() \ 145 | .view(-1, A_tilda.size(1)) 146 | if layer.bias is not None: 147 | A_tilda = torch.cat([A_tilda, 148 | torch.ones_like(A_tilda[:, :1])], 149 | dim=1) 150 | # Omega_hat in KFC 151 | buffer.add_(torch.mm(A_tilda.t(), A_tilda)) 152 | 153 | @classmethod 154 | def kfac_gg(cls, buffer, mod, layer, x, gy): 155 | spatial_locations = gy.size(2) * gy.size(3) 156 | os = gy.size(1) 157 | # DS_tilda in KFC 158 | DS_tilda = gy.permute(0, 2, 3, 1).contiguous().view(-1, os) 159 | buffer.add_(torch.mm(DS_tilda.t(), DS_tilda) / spatial_locations) 160 | 161 | @classmethod 162 | def kfe_diag(cls, buffer, mod, layer, x, gy, evecs_a, evecs_g): 163 | ks = (mod.weight.size(2), mod.weight.size(3)) 164 | gy_s = gy.size() 165 | bs = gy_s[0] 166 | # project x to kfe 167 | x_unfold = F.unfold(x, kernel_size=ks, stride=mod.stride, 168 | padding=mod.padding, dilation=mod.dilation) 169 | x_unfold_s = x_unfold.size() 170 | x_unfold = x_unfold.view(bs, x_unfold_s[1], -1).permute(0, 2, 1)\ 171 | .contiguous().view(-1, x_unfold_s[1]) 172 | if mod.bias is not None: 173 | x_unfold = torch.cat([x_unfold, 174 | torch.ones_like(x_unfold[:, :1])], dim=1) 175 | x_kfe = torch.mm(x_unfold, evecs_a) 176 | 177 | # project gy to kfe 178 | gy = gy.view(bs, gy_s[1], -1).permute(0, 2, 1).contiguous() 179 | gy_kfe = torch.mm(gy.view(-1, gy_s[1]), evecs_g) 180 | gy_kfe = gy_kfe.view(bs, -1, gy_s[1]).permute(0, 2, 1).contiguous() 181 | 182 | indiv_gw = torch.bmm(gy_kfe.view(bs, gy_s[1], -1), 183 | x_kfe.view(bs, -1, x_kfe.size(1))) 184 | buffer.add_((indiv_gw**2).sum(dim=0).view(-1)) 185 | 186 | @classmethod 187 | def quasidiag(cls, buffer_diag, buffer_cross, mod, layer, x, gy): 188 | w_numel = layer.weight.numel() 189 | indiv_gw = conv2d_backward(mod, x, gy) 190 | buffer_diag[:w_numel].add_((indiv_gw**2).sum(dim=0).view(-1)) 191 | if layer.bias is not None: 192 | gb_per_example = gy.sum(dim=(2, 3)) 193 | buffer_diag[w_numel:].add_((gb_per_example**2).sum(dim=0)) 194 | y = (gy * gb_per_example.unsqueeze(2).unsqueeze(3)) 195 | cross_this = F.conv2d(x.transpose(0, 1), 196 | y.transpose(0, 1), 197 | stride=mod.dilation, 198 | padding=mod.padding, 199 | dilation=mod.stride).transpose(0, 1) 200 | cross_this = cross_this[:, :, :mod.kernel_size[0], :mod.kernel_size[1]] 201 | buffer_cross.add_(cross_this) 202 | 203 | 204 | class ConvTranspose2dJacobianFactory(JacobianFactory): 205 | @classmethod 206 | def flat_grad(cls, buffer, mod, layer, x, gy): 207 | bs = x.size(0) 208 | w_numel = layer.weight.numel() 209 | indiv_gw = convtranspose2d_backward(mod, x, gy) 210 | buffer[:, :w_numel].add_(indiv_gw.view(bs, -1)) 211 | if layer.bias is not None: 212 | buffer[:, w_numel:].add_(gy.sum(dim=(2, 3))) 213 | 214 | 215 | def check_bn_training(mod): 216 | # check that BN layers are in eval mode 217 | if mod.training: 218 | raise NotImplementedError('NNGeometry\'s Jacobian generator can' + 219 | ' only handle BatchNorm in evaluation mode') 220 | 221 | 222 | class BatchNorm1dJacobianFactory(JacobianFactory): 223 | @classmethod 224 | def flat_grad(cls, buffer, mod, layer, x, gy): 225 | check_bn_training(mod) 226 | w_numel = layer.weight.numel() 227 | x_normalized = F.batch_norm(x, mod.running_mean, 228 | mod.running_var, 229 | None, None, mod.training, 230 | momentum=0.) 
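        # For BatchNorm1d, y = weight * x_hat + bias, where x_hat is the input
        # normalized with the (frozen) running statistics computed above, so the
        # per-sample gradient w.r.t. the scale is gy * x_hat elementwise; the
        # next line writes one flattened gradient row per sample into the buffer.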
231 | buffer[:, :w_numel].add_(gy * x_normalized) 232 | if layer.bias is not None: 233 | buffer[:, w_numel:].add_(gy) 234 | 235 | 236 | class BatchNorm2dJacobianFactory(JacobianFactory): 237 | @classmethod 238 | def flat_grad(cls, buffer, mod, layer, x, gy): 239 | check_bn_training(mod) 240 | w_numel = layer.weight.numel() 241 | x_normalized = F.batch_norm(x, mod.running_mean, 242 | mod.running_var, 243 | None, None, mod.training, 244 | momentum=0.) 245 | buffer[:, :w_numel].add_((gy * x_normalized).sum(dim=(2, 3))) 246 | if layer.bias is not None: 247 | buffer[:, w_numel:].add_(gy.sum(dim=(2, 3))) 248 | 249 | 250 | class GroupNormJacobianFactory(JacobianFactory): 251 | @classmethod 252 | def flat_grad(cls, buffer, mod, layer, x, gy): 253 | w_numel = layer.weight.numel() 254 | x_normalized = F.group_norm(x, mod.num_groups, 255 | eps=mod.eps) 256 | buffer[:, :w_numel].add_((gy * x_normalized).sum(dim=(2, 3))) 257 | buffer[:, w_numel:].add_(gy.sum(dim=(2, 3))) 258 | 259 | 260 | class WeightNorm1dJacobianFactory(JacobianFactory): 261 | @classmethod 262 | def flat_grad(cls, buffer, mod, layer, x, gy): 263 | bs = x.size(0) 264 | norm2 = (mod.weight**2).sum(dim=1, keepdim=True) + mod.eps 265 | gw = torch.bmm(gy.unsqueeze(2) / torch.sqrt(norm2), 266 | x.unsqueeze(1)) 267 | wn2_out = F.linear(x, mod.weight / norm2**1.5) 268 | gw -= (gy * wn2_out).unsqueeze(2) * mod.weight.unsqueeze(0) 269 | buffer.add_(gw.view(bs, -1)) 270 | 271 | 272 | class WeightNorm2dJacobianFactory(JacobianFactory): 273 | @classmethod 274 | def flat_grad(cls, buffer, mod, layer, x, gy): 275 | bs = x.size(0) 276 | out_dim = mod.weight.size(0) 277 | norm2 = (mod.weight**2).sum(dim=(1, 2, 3)) + mod.eps 278 | gw = conv2d_backward(mod, x, gy / torch.sqrt(norm2).view(1, out_dim, 1, 1)) 279 | gw = gw.view(bs, out_dim, -1) 280 | wn2_out = F.conv2d(x, mod.weight / norm2.view(out_dim, 1, 1, 1)**1.5, None, 281 | stride=mod.stride, padding=mod.padding, dilation=mod.dilation) 282 | t = (gy * wn2_out).sum(dim=(2, 3)).view(bs, out_dim, 1) * mod.weight.view(1, out_dim, -1) 283 | gw -= (gy * wn2_out).sum(dim=(2, 3)).view(bs, out_dim, 1) * mod.weight.view(1, out_dim, -1) 284 | buffer.add_(gw.view(bs, -1)) 285 | 286 | 287 | class Cosine1dJacobianFactory(JacobianFactory): 288 | @classmethod 289 | def flat_grad(cls, buffer, mod, layer, x, gy): 290 | bs = x.size(0) 291 | norm2_w = (mod.weight**2).sum(dim=1, keepdim=True) + mod.eps 292 | norm2_x = (x**2).sum(dim=1, keepdim=True) + mod.eps 293 | x = x / torch.sqrt(norm2_x) 294 | gw = torch.bmm(gy.unsqueeze(2) / torch.sqrt(norm2_w), 295 | x.unsqueeze(1)) 296 | wn2_out = F.linear(x, mod.weight / norm2_w**1.5) 297 | gw -= (gy * wn2_out).unsqueeze(2) * mod.weight.unsqueeze(0) 298 | buffer.add_(gw.view(bs, -1)) 299 | 300 | 301 | class Affine1dJacobianFactory(JacobianFactory): 302 | @classmethod 303 | def flat_grad(cls, buffer, mod, layer, x, gy): 304 | w_numel = layer.weight.numel() 305 | buffer[:, :w_numel].add_(gy * x) 306 | if layer.bias is not None: 307 | buffer[:, w_numel:].add_(gy) 308 | 309 | 310 | FactoryMap = { 311 | LinearLayer: LinearJacobianFactory, 312 | Conv2dLayer: Conv2dJacobianFactory, 313 | ConvTranspose2dLayer: ConvTranspose2dJacobianFactory, 314 | BatchNorm1dLayer: BatchNorm1dJacobianFactory, 315 | BatchNorm2dLayer: BatchNorm2dJacobianFactory, 316 | GroupNormLayer: GroupNormJacobianFactory, 317 | WeightNorm1dLayer: WeightNorm1dJacobianFactory, 318 | WeightNorm2dLayer: WeightNorm2dJacobianFactory, 319 | Cosine1dLayer: Cosine1dJacobianFactory, 320 | Affine1dLayer: 
Affine1dJacobianFactory, 321 | } -------------------------------------------------------------------------------- /nngeometry/generator/lm_jacobian/grads.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from nngeometry.layercollection import (Affine1dLayer, Cosine1dLayer, LinearLayer, Conv2dLayer, BatchNorm1dLayer, 3 | BatchNorm2dLayer, GroupNormLayer, WeightNorm1dLayer, 4 | WeightNorm2dLayer, ConvTranspose2dLayer) 5 | from .grads_conv import conv2d_backward, convtranspose2d_backward, unfold_transpose_conv2d 6 | 7 | import torch.nn.functional as F 8 | 9 | 10 | class JacobianFactory: 11 | @classmethod 12 | def diag(cls, buffer, mod, layer, x, gy): 13 | bs = x.size(0) 14 | buffer_flat = torch.zeros(bs, layer.numel(), device=buffer.device) 15 | cls.flat_grad(buffer_flat, mod, layer, x, gy) 16 | buffer.add_((buffer_flat**2).sum(dim=0)) 17 | 18 | @classmethod 19 | def trace(cls, buffer, mod, layer, x, gy): 20 | buffer_diag = torch.zeros(layer.numel(), device=buffer.device) 21 | cls.diag(buffer_diag, mod, layer, x, gy) 22 | buffer.add_(buffer_diag.sum()) 23 | 24 | @classmethod 25 | def layer_block(cls, buffer, mod, layer, x, gy): 26 | bs = x.size(0) 27 | buffer_flat = torch.zeros(bs, layer.numel(), device=buffer.device) 28 | cls.flat_grad(buffer_flat, mod, layer, x, gy) 29 | buffer.add_(torch.mm(buffer_flat.t(), buffer_flat)) 30 | 31 | @classmethod 32 | def kxy(cls, buffer, mod, layer, x_i, gy_i, x_o, gy_o): 33 | bs_i = x_i.size(0) 34 | bs_o = x_o.size(0) 35 | buffer_flat_i = torch.zeros(bs_i, layer.numel(), device=buffer.device) 36 | buffer_flat_o = torch.zeros(bs_o, layer.numel(), device=buffer.device) 37 | cls.flat_grad(buffer_flat_i, mod, layer, x_i, gy_i) 38 | cls.flat_grad(buffer_flat_o, mod, layer, x_o, gy_o) 39 | buffer.add_(torch.mm(buffer_flat_i, buffer_flat_o.t())) 40 | 41 | @classmethod 42 | def Jv(cls, buffer, mod, layer, x, gy, v, v_bias): 43 | bs = x.size(0) 44 | buffer_flat = torch.zeros(bs, layer.numel(), device=buffer.device) 45 | cls.flat_grad(buffer_flat, mod, layer, x, gy) 46 | v = v.view(-1) 47 | if v_bias is not None: 48 | v = torch.cat((v, v_bias)) 49 | buffer.add_(torch.mv(buffer_flat, v)) 50 | 51 | 52 | class LinearJacobianFactory(JacobianFactory): 53 | @classmethod 54 | def flat_grad(cls, buffer, mod, layer, x, gy): 55 | bs = x.size(0) 56 | w_numel = layer.weight.numel() 57 | buffer[:, :w_numel] \ 58 | .add_(torch.bmm(gy.unsqueeze(2), x.unsqueeze(1)).view(bs, -1)) 59 | if layer.bias is not None: 60 | buffer[:, w_numel:].add_(gy) 61 | 62 | @classmethod 63 | def diag(cls, buffer, mod, layer, x, gy): 64 | w_numel = layer.weight.numel() 65 | buffer[:w_numel].add_(torch.mm(gy.t()**2, x**2).view(-1)) 66 | if layer.bias is not None: 67 | buffer[w_numel:].add_((gy**2).sum(dim=0)) 68 | 69 | @classmethod 70 | def kxy(cls, buffer, mod, layer, x_i, gy_i, x_o, gy_o): 71 | buffer.add_(torch.mm(x_i, x_o.t()) * 72 | torch.mm(gy_i, gy_o.t())) 73 | if layer.bias is not None: 74 | buffer.add_(torch.mm(gy_i, gy_o.t())) 75 | 76 | @classmethod 77 | def Jv(cls, buffer, mod, layer, x, gy, v, v_bias): 78 | buffer.add_((torch.mm(x, v.t()) * gy).sum(dim=1)) 79 | if layer.bias is not None: 80 | buffer.add_(torch.mv(gy.contiguous(), v_bias)) 81 | 82 | @classmethod 83 | def trace(cls, buffer, mod, layer, x, gy): 84 | buffer.add_(torch.mm(gy.t()**2, x**2).sum()) 85 | if layer.bias is not None: 86 | buffer.add_((gy**2).sum()) 87 | 88 | @classmethod 89 | def kfac_xx(cls, buffer, mod, layer, x, gy): 90 | if layer.bias is not None: 91 | x = 
torch.cat([x, torch.ones_like(x[:, :1])], dim=1) 92 | buffer.add_(torch.mm(x.t(), x)) 93 | 94 | @classmethod 95 | def kfac_gg(cls, buffer, mod, layer, x, gy): 96 | # print(f'mod: {mod}, layer: {layer}') 97 | # print(f'gy: {gy.shape}') 98 | # print(f'x: {x.shape}') 99 | buffer.add_(torch.mm(gy.t(), gy)) 100 | 101 | @classmethod 102 | def kfe_diag(cls, buffer, mod, layer, x, gy, evecs_a, evecs_g): 103 | if layer.bias is not None: 104 | x = torch.cat([x, torch.ones_like(x[:, :1])], dim=1) 105 | gy_kfe = torch.mm(gy, evecs_g) 106 | x_kfe = torch.mm(x, evecs_a) 107 | buffer.add_(torch.mm(gy_kfe.t()**2, x_kfe**2).view(-1)) 108 | 109 | @classmethod 110 | def quasidiag(cls, buffer_diag, buffer_cross, mod, layer, x, gy): 111 | w_numel = layer.weight.numel() 112 | buffer_diag[:w_numel].add_(torch.mm(gy.t()**2, x**2).view(-1)) 113 | if layer.bias is not None: 114 | buffer_diag[w_numel:].add_((gy**2).sum(dim=0)) 115 | buffer_cross.add_(torch.mm(gy.t()**2, x)) 116 | 117 | 118 | class Conv2dJacobianFactory(JacobianFactory): 119 | @classmethod 120 | def flat_grad(cls, buffer, mod, layer, x, gy): 121 | bs = x.size(0) 122 | w_numel = layer.weight.numel() 123 | indiv_gw = conv2d_backward(mod, x, gy) 124 | buffer[:, :w_numel].add_(indiv_gw.view(bs, -1)) 125 | if layer.bias is not None: 126 | buffer[:, w_numel:].add_(gy.sum(dim=(2, 3))) 127 | 128 | @classmethod 129 | def Jv(cls, buffer, mod, layer, x, gy, v, v_bias): 130 | bs = x.size(0) 131 | gy2 = F.conv2d(x, v, stride=mod.stride, 132 | padding=mod.padding, dilation=mod.dilation) 133 | buffer.add_((gy * gy2).view(bs, -1).sum(dim=1)) 134 | if layer.bias is not None: 135 | buffer.add_(torch.mv(gy.sum(dim=(2, 3)), v_bias)) 136 | 137 | @classmethod 138 | def kfac_xx(cls, buffer, mod, layer, x, gy): 139 | ks = (mod.weight.size(2), mod.weight.size(3)) 140 | # A_tilda in KFC 141 | A_tilda = F.unfold(x, kernel_size=ks, stride=mod.stride, 142 | padding=mod.padding, dilation=mod.dilation) 143 | # A_tilda is bs * #locations x #parameters 144 | A_tilda = A_tilda.permute(0, 2, 1).contiguous() \ 145 | .view(-1, A_tilda.size(1)) 146 | if layer.bias is not None: 147 | A_tilda = torch.cat([A_tilda, 148 | torch.ones_like(A_tilda[:, :1])], 149 | dim=1) 150 | # Omega_hat in KFC 151 | buffer.add_(torch.mm(A_tilda.t(), A_tilda)) 152 | 153 | @classmethod 154 | def kfac_gg(cls, buffer, mod, layer, x, gy): 155 | spatial_locations = gy.size(2) * gy.size(3) 156 | os = gy.size(1) 157 | # DS_tilda in KFC 158 | DS_tilda = gy.permute(0, 2, 3, 1).contiguous().view(-1, os) 159 | buffer.add_(torch.mm(DS_tilda.t(), DS_tilda) / spatial_locations) 160 | 161 | @classmethod 162 | def kfe_diag(cls, buffer, mod, layer, x, gy, evecs_a, evecs_g): 163 | ks = (mod.weight.size(2), mod.weight.size(3)) 164 | gy_s = gy.size() 165 | bs = gy_s[0] 166 | # project x to kfe 167 | x_unfold = F.unfold(x, kernel_size=ks, stride=mod.stride, 168 | padding=mod.padding, dilation=mod.dilation) 169 | x_unfold_s = x_unfold.size() 170 | x_unfold = x_unfold.view(bs, x_unfold_s[1], -1).permute(0, 2, 1)\ 171 | .contiguous().view(-1, x_unfold_s[1]) 172 | if mod.bias is not None: 173 | x_unfold = torch.cat([x_unfold, 174 | torch.ones_like(x_unfold[:, :1])], dim=1) 175 | x_kfe = torch.mm(x_unfold, evecs_a) 176 | 177 | # project gy to kfe 178 | gy = gy.view(bs, gy_s[1], -1).permute(0, 2, 1).contiguous() 179 | gy_kfe = torch.mm(gy.view(-1, gy_s[1]), evecs_g) 180 | gy_kfe = gy_kfe.view(bs, -1, gy_s[1]).permute(0, 2, 1).contiguous() 181 | 182 | indiv_gw = torch.bmm(gy_kfe.view(bs, gy_s[1], -1), 183 | x_kfe.view(bs, -1, 
x_kfe.size(1))) 184 | buffer.add_((indiv_gw**2).sum(dim=0).view(-1)) 185 | 186 | @classmethod 187 | def quasidiag(cls, buffer_diag, buffer_cross, mod, layer, x, gy): 188 | w_numel = layer.weight.numel() 189 | indiv_gw = conv2d_backward(mod, x, gy) 190 | buffer_diag[:w_numel].add_((indiv_gw**2).sum(dim=0).view(-1)) 191 | if layer.bias is not None: 192 | gb_per_example = gy.sum(dim=(2, 3)) 193 | buffer_diag[w_numel:].add_((gb_per_example**2).sum(dim=0)) 194 | y = (gy * gb_per_example.unsqueeze(2).unsqueeze(3)) 195 | cross_this = F.conv2d(x.transpose(0, 1), 196 | y.transpose(0, 1), 197 | stride=mod.dilation, 198 | padding=mod.padding, 199 | dilation=mod.stride).transpose(0, 1) 200 | cross_this = cross_this[:, :, :mod.kernel_size[0], :mod.kernel_size[1]] 201 | buffer_cross.add_(cross_this) 202 | 203 | 204 | class ConvTranspose2dJacobianFactory(JacobianFactory): 205 | @classmethod 206 | def flat_grad(cls, buffer, mod, layer, x, gy): 207 | bs = x.size(0) 208 | w_numel = layer.weight.numel() 209 | indiv_gw = convtranspose2d_backward(mod, x, gy) 210 | buffer[:, :w_numel].add_(indiv_gw.view(bs, -1)) 211 | if layer.bias is not None: 212 | buffer[:, w_numel:].add_(gy.sum(dim=(2, 3))) 213 | 214 | 215 | def check_bn_training(mod): 216 | # check that BN layers are in eval mode 217 | if mod.training: 218 | raise NotImplementedError('NNGeometry\'s Jacobian generator can' + 219 | ' only handle BatchNorm in evaluation mode') 220 | 221 | 222 | class BatchNorm1dJacobianFactory(JacobianFactory): 223 | @classmethod 224 | def flat_grad(cls, buffer, mod, layer, x, gy): 225 | check_bn_training(mod) 226 | w_numel = layer.weight.numel() 227 | x_normalized = F.batch_norm(x, mod.running_mean, 228 | mod.running_var, 229 | None, None, mod.training, 230 | momentum=0.) 231 | buffer[:, :w_numel].add_(gy * x_normalized) 232 | if layer.bias is not None: 233 | buffer[:, w_numel:].add_(gy) 234 | 235 | 236 | class BatchNorm2dJacobianFactory(JacobianFactory): 237 | @classmethod 238 | def flat_grad(cls, buffer, mod, layer, x, gy): 239 | check_bn_training(mod) 240 | w_numel = layer.weight.numel() 241 | x_normalized = F.batch_norm(x, mod.running_mean, 242 | mod.running_var, 243 | None, None, mod.training, 244 | momentum=0.) 
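        # BatchNorm2d shares its scale and shift across spatial positions, so the
        # per-sample gradients gy * x_hat are summed over the H and W dimensions
        # (dim=(2, 3)) before being accumulated into the flat buffer below.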
245 | buffer[:, :w_numel].add_((gy * x_normalized).sum(dim=(2, 3))) 246 | if layer.bias is not None: 247 | buffer[:, w_numel:].add_(gy.sum(dim=(2, 3))) 248 | 249 | 250 | class GroupNormJacobianFactory(JacobianFactory): 251 | @classmethod 252 | def flat_grad(cls, buffer, mod, layer, x, gy): 253 | w_numel = layer.weight.numel() 254 | x_normalized = F.group_norm(x, mod.num_groups, 255 | eps=mod.eps) 256 | buffer[:, :w_numel].add_((gy * x_normalized).sum(dim=(2, 3))) 257 | buffer[:, w_numel:].add_(gy.sum(dim=(2, 3))) 258 | 259 | 260 | class WeightNorm1dJacobianFactory(JacobianFactory): 261 | @classmethod 262 | def flat_grad(cls, buffer, mod, layer, x, gy): 263 | bs = x.size(0) 264 | norm2 = (mod.weight**2).sum(dim=1, keepdim=True) + mod.eps 265 | gw = torch.bmm(gy.unsqueeze(2) / torch.sqrt(norm2), 266 | x.unsqueeze(1)) 267 | wn2_out = F.linear(x, mod.weight / norm2**1.5) 268 | gw -= (gy * wn2_out).unsqueeze(2) * mod.weight.unsqueeze(0) 269 | buffer.add_(gw.view(bs, -1)) 270 | 271 | 272 | class WeightNorm2dJacobianFactory(JacobianFactory): 273 | @classmethod 274 | def flat_grad(cls, buffer, mod, layer, x, gy): 275 | bs = x.size(0) 276 | out_dim = mod.weight.size(0) 277 | norm2 = (mod.weight**2).sum(dim=(1, 2, 3)) + mod.eps 278 | gw = conv2d_backward(mod, x, gy / torch.sqrt(norm2).view(1, out_dim, 1, 1)) 279 | gw = gw.view(bs, out_dim, -1) 280 | wn2_out = F.conv2d(x, mod.weight / norm2.view(out_dim, 1, 1, 1)**1.5, None, 281 | stride=mod.stride, padding=mod.padding, dilation=mod.dilation) 282 | t = (gy * wn2_out).sum(dim=(2, 3)).view(bs, out_dim, 1) * mod.weight.view(1, out_dim, -1) 283 | gw -= (gy * wn2_out).sum(dim=(2, 3)).view(bs, out_dim, 1) * mod.weight.view(1, out_dim, -1) 284 | buffer.add_(gw.view(bs, -1)) 285 | 286 | 287 | class Cosine1dJacobianFactory(JacobianFactory): 288 | @classmethod 289 | def flat_grad(cls, buffer, mod, layer, x, gy): 290 | bs = x.size(0) 291 | norm2_w = (mod.weight**2).sum(dim=1, keepdim=True) + mod.eps 292 | norm2_x = (x**2).sum(dim=1, keepdim=True) + mod.eps 293 | x = x / torch.sqrt(norm2_x) 294 | gw = torch.bmm(gy.unsqueeze(2) / torch.sqrt(norm2_w), 295 | x.unsqueeze(1)) 296 | wn2_out = F.linear(x, mod.weight / norm2_w**1.5) 297 | gw -= (gy * wn2_out).unsqueeze(2) * mod.weight.unsqueeze(0) 298 | buffer.add_(gw.view(bs, -1)) 299 | 300 | 301 | class Affine1dJacobianFactory(JacobianFactory): 302 | @classmethod 303 | def flat_grad(cls, buffer, mod, layer, x, gy): 304 | w_numel = layer.weight.numel() 305 | buffer[:, :w_numel].add_(gy * x) 306 | if layer.bias is not None: 307 | buffer[:, w_numel:].add_(gy) 308 | 309 | 310 | FactoryMap = { 311 | LinearLayer: LinearJacobianFactory, 312 | Conv2dLayer: Conv2dJacobianFactory, 313 | ConvTranspose2dLayer: ConvTranspose2dJacobianFactory, 314 | BatchNorm1dLayer: BatchNorm1dJacobianFactory, 315 | BatchNorm2dLayer: BatchNorm2dJacobianFactory, 316 | GroupNormLayer: GroupNormJacobianFactory, 317 | WeightNorm1dLayer: WeightNorm1dJacobianFactory, 318 | WeightNorm2dLayer: WeightNorm2dJacobianFactory, 319 | Cosine1dLayer: Cosine1dJacobianFactory, 320 | Affine1dLayer: Affine1dJacobianFactory, 321 | } -------------------------------------------------------------------------------- /nngeometry/generator/para_lm_jacobian/grads.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from nngeometry.layercollection import (Affine1dLayer, Cosine1dLayer, LinearLayer, Conv2dLayer, BatchNorm1dLayer, 3 | BatchNorm2dLayer, GroupNormLayer, WeightNorm1dLayer, 4 | WeightNorm2dLayer, 
ConvTranspose2dLayer) 5 | from .grads_conv import conv2d_backward, convtranspose2d_backward, unfold_transpose_conv2d 6 | 7 | import torch.nn.functional as F 8 | 9 | 10 | class JacobianFactory: 11 | @classmethod 12 | def diag(cls, buffer, mod, layer, x, gy): 13 | bs = x.size(0) 14 | buffer_flat = torch.zeros(bs, layer.numel(), device=buffer.device) 15 | cls.flat_grad(buffer_flat, mod, layer, x, gy) 16 | buffer.add_((buffer_flat**2).sum(dim=0)) 17 | 18 | @classmethod 19 | def trace(cls, buffer, mod, layer, x, gy): 20 | buffer_diag = torch.zeros(layer.numel(), device=buffer.device) 21 | cls.diag(buffer_diag, mod, layer, x, gy) 22 | buffer.add_(buffer_diag.sum()) 23 | 24 | @classmethod 25 | def layer_block(cls, buffer, mod, layer, x, gy): 26 | bs = x.size(0) 27 | buffer_flat = torch.zeros(bs, layer.numel(), device=buffer.device) 28 | cls.flat_grad(buffer_flat, mod, layer, x, gy) 29 | buffer.add_(torch.mm(buffer_flat.t(), buffer_flat)) 30 | 31 | @classmethod 32 | def kxy(cls, buffer, mod, layer, x_i, gy_i, x_o, gy_o): 33 | bs_i = x_i.size(0) 34 | bs_o = x_o.size(0) 35 | buffer_flat_i = torch.zeros(bs_i, layer.numel(), device=buffer.device) 36 | buffer_flat_o = torch.zeros(bs_o, layer.numel(), device=buffer.device) 37 | cls.flat_grad(buffer_flat_i, mod, layer, x_i, gy_i) 38 | cls.flat_grad(buffer_flat_o, mod, layer, x_o, gy_o) 39 | buffer.add_(torch.mm(buffer_flat_i, buffer_flat_o.t())) 40 | 41 | @classmethod 42 | def Jv(cls, buffer, mod, layer, x, gy, v, v_bias): 43 | bs = x.size(0) 44 | buffer_flat = torch.zeros(bs, layer.numel(), device=buffer.device) 45 | cls.flat_grad(buffer_flat, mod, layer, x, gy) 46 | v = v.view(-1) 47 | if v_bias is not None: 48 | v = torch.cat((v, v_bias)) 49 | buffer.add_(torch.mv(buffer_flat, v)) 50 | 51 | 52 | class LinearJacobianFactory(JacobianFactory): 53 | @classmethod 54 | def flat_grad(cls, buffer, mod, layer, x, gy): 55 | bs = x.size(0) 56 | w_numel = layer.weight.numel() 57 | buffer[:, :w_numel] \ 58 | .add_(torch.bmm(gy.unsqueeze(2), x.unsqueeze(1)).view(bs, -1)) 59 | if layer.bias is not None: 60 | buffer[:, w_numel:].add_(gy) 61 | 62 | @classmethod 63 | def diag(cls, buffer, mod, layer, x, gy): 64 | w_numel = layer.weight.numel() 65 | buffer[:w_numel].add_(torch.mm(gy.t()**2, x**2).view(-1)) 66 | if layer.bias is not None: 67 | buffer[w_numel:].add_((gy**2).sum(dim=0)) 68 | 69 | @classmethod 70 | def kxy(cls, buffer, mod, layer, x_i, gy_i, x_o, gy_o): 71 | buffer.add_(torch.mm(x_i, x_o.t()) * 72 | torch.mm(gy_i, gy_o.t())) 73 | if layer.bias is not None: 74 | buffer.add_(torch.mm(gy_i, gy_o.t())) 75 | 76 | @classmethod 77 | def Jv(cls, buffer, mod, layer, x, gy, v, v_bias): 78 | buffer.add_((torch.mm(x, v.t()) * gy).sum(dim=1)) 79 | if layer.bias is not None: 80 | buffer.add_(torch.mv(gy.contiguous(), v_bias)) 81 | 82 | @classmethod 83 | def trace(cls, buffer, mod, layer, x, gy): 84 | buffer.add_(torch.mm(gy.t()**2, x**2).sum()) 85 | if layer.bias is not None: 86 | buffer.add_((gy**2).sum()) 87 | 88 | @classmethod 89 | def kfac_xx(cls, buffer, mod, layer, x, gy): 90 | if layer.bias is not None: 91 | x = torch.cat([x, torch.ones_like(x[:, :1])], dim=1) 92 | buffer.add_(torch.mm(x.t(), x)) 93 | 94 | @classmethod 95 | def kfac_gg(cls, buffer, mod, layer, x, gy): 96 | # print(f'mod: {mod}, layer: {layer}') 97 | # print(f'gy: {gy.shape}') 98 | # print(f'x: {x.shape}') 99 | buffer.add_(torch.mm(gy.t(), gy)) 100 | 101 | @classmethod 102 | def kfe_diag(cls, buffer, mod, layer, x, gy, evecs_a, evecs_g): 103 | if layer.bias is not None: 104 | x = torch.cat([x, 
torch.ones_like(x[:, :1])], dim=1) 105 | gy_kfe = torch.mm(gy, evecs_g) 106 | x_kfe = torch.mm(x, evecs_a) 107 | buffer.add_(torch.mm(gy_kfe.t()**2, x_kfe**2).view(-1)) 108 | 109 | @classmethod 110 | def quasidiag(cls, buffer_diag, buffer_cross, mod, layer, x, gy): 111 | w_numel = layer.weight.numel() 112 | buffer_diag[:w_numel].add_(torch.mm(gy.t()**2, x**2).view(-1)) 113 | if layer.bias is not None: 114 | buffer_diag[w_numel:].add_((gy**2).sum(dim=0)) 115 | buffer_cross.add_(torch.mm(gy.t()**2, x)) 116 | 117 | 118 | class Conv2dJacobianFactory(JacobianFactory): 119 | @classmethod 120 | def flat_grad(cls, buffer, mod, layer, x, gy): 121 | bs = x.size(0) 122 | w_numel = layer.weight.numel() 123 | indiv_gw = conv2d_backward(mod, x, gy) 124 | buffer[:, :w_numel].add_(indiv_gw.view(bs, -1)) 125 | if layer.bias is not None: 126 | buffer[:, w_numel:].add_(gy.sum(dim=(2, 3))) 127 | 128 | @classmethod 129 | def Jv(cls, buffer, mod, layer, x, gy, v, v_bias): 130 | bs = x.size(0) 131 | gy2 = F.conv2d(x, v, stride=mod.stride, 132 | padding=mod.padding, dilation=mod.dilation) 133 | buffer.add_((gy * gy2).view(bs, -1).sum(dim=1)) 134 | if layer.bias is not None: 135 | buffer.add_(torch.mv(gy.sum(dim=(2, 3)), v_bias)) 136 | 137 | @classmethod 138 | def kfac_xx(cls, buffer, mod, layer, x, gy): 139 | ks = (mod.weight.size(2), mod.weight.size(3)) 140 | # A_tilda in KFC 141 | A_tilda = F.unfold(x, kernel_size=ks, stride=mod.stride, 142 | padding=mod.padding, dilation=mod.dilation) 143 | # A_tilda is bs * #locations x #parameters 144 | A_tilda = A_tilda.permute(0, 2, 1).contiguous() \ 145 | .view(-1, A_tilda.size(1)) 146 | if layer.bias is not None: 147 | A_tilda = torch.cat([A_tilda, 148 | torch.ones_like(A_tilda[:, :1])], 149 | dim=1) 150 | # Omega_hat in KFC 151 | buffer.add_(torch.mm(A_tilda.t(), A_tilda)) 152 | 153 | @classmethod 154 | def kfac_gg(cls, buffer, mod, layer, x, gy): 155 | spatial_locations = gy.size(2) * gy.size(3) 156 | os = gy.size(1) 157 | # DS_tilda in KFC 158 | DS_tilda = gy.permute(0, 2, 3, 1).contiguous().view(-1, os) 159 | buffer.add_(torch.mm(DS_tilda.t(), DS_tilda) / spatial_locations) 160 | 161 | @classmethod 162 | def kfe_diag(cls, buffer, mod, layer, x, gy, evecs_a, evecs_g): 163 | ks = (mod.weight.size(2), mod.weight.size(3)) 164 | gy_s = gy.size() 165 | bs = gy_s[0] 166 | # project x to kfe 167 | x_unfold = F.unfold(x, kernel_size=ks, stride=mod.stride, 168 | padding=mod.padding, dilation=mod.dilation) 169 | x_unfold_s = x_unfold.size() 170 | x_unfold = x_unfold.view(bs, x_unfold_s[1], -1).permute(0, 2, 1)\ 171 | .contiguous().view(-1, x_unfold_s[1]) 172 | if mod.bias is not None: 173 | x_unfold = torch.cat([x_unfold, 174 | torch.ones_like(x_unfold[:, :1])], dim=1) 175 | x_kfe = torch.mm(x_unfold, evecs_a) 176 | 177 | # project gy to kfe 178 | gy = gy.view(bs, gy_s[1], -1).permute(0, 2, 1).contiguous() 179 | gy_kfe = torch.mm(gy.view(-1, gy_s[1]), evecs_g) 180 | gy_kfe = gy_kfe.view(bs, -1, gy_s[1]).permute(0, 2, 1).contiguous() 181 | 182 | indiv_gw = torch.bmm(gy_kfe.view(bs, gy_s[1], -1), 183 | x_kfe.view(bs, -1, x_kfe.size(1))) 184 | buffer.add_((indiv_gw**2).sum(dim=0).view(-1)) 185 | 186 | @classmethod 187 | def quasidiag(cls, buffer_diag, buffer_cross, mod, layer, x, gy): 188 | w_numel = layer.weight.numel() 189 | indiv_gw = conv2d_backward(mod, x, gy) 190 | buffer_diag[:w_numel].add_((indiv_gw**2).sum(dim=0).view(-1)) 191 | if layer.bias is not None: 192 | gb_per_example = gy.sum(dim=(2, 3)) 193 | buffer_diag[w_numel:].add_((gb_per_example**2).sum(dim=0)) 194 | y = 
204 | class ConvTranspose2dJacobianFactory(JacobianFactory):
205 |     @classmethod
206 |     def flat_grad(cls, buffer, mod, layer, x, gy):
207 |         bs = x.size(0)
208 |         w_numel = layer.weight.numel()
209 |         indiv_gw = convtranspose2d_backward(mod, x, gy)
210 |         buffer[:, :w_numel].add_(indiv_gw.view(bs, -1))
211 |         if layer.bias is not None:
212 |             buffer[:, w_numel:].add_(gy.sum(dim=(2, 3)))
213 | 
214 | 
215 | def check_bn_training(mod):
216 |     # check that BN layers are in eval mode
217 |     if mod.training:
218 |         raise NotImplementedError('NNGeometry\'s Jacobian generator can' +
219 |                                   ' only handle BatchNorm in evaluation mode')
220 | 
221 | 
222 | class BatchNorm1dJacobianFactory(JacobianFactory):
223 |     @classmethod
224 |     def flat_grad(cls, buffer, mod, layer, x, gy):
225 |         check_bn_training(mod)
226 |         w_numel = layer.weight.numel()
227 |         x_normalized = F.batch_norm(x, mod.running_mean,
228 |                                     mod.running_var,
229 |                                     None, None, mod.training,
230 |                                     momentum=0.)
231 |         buffer[:, :w_numel].add_(gy * x_normalized)
232 |         if layer.bias is not None:
233 |             buffer[:, w_numel:].add_(gy)
234 | 
235 | 
236 | class BatchNorm2dJacobianFactory(JacobianFactory):
237 |     @classmethod
238 |     def flat_grad(cls, buffer, mod, layer, x, gy):
239 |         check_bn_training(mod)
240 |         w_numel = layer.weight.numel()
241 |         x_normalized = F.batch_norm(x, mod.running_mean,
242 |                                     mod.running_var,
243 |                                     None, None, mod.training,
244 |                                     momentum=0.)
245 |         buffer[:, :w_numel].add_((gy * x_normalized).sum(dim=(2, 3)))
246 |         if layer.bias is not None:
247 |             buffer[:, w_numel:].add_(gy.sum(dim=(2, 3)))
248 | 
249 | 
250 | class GroupNormJacobianFactory(JacobianFactory):
251 |     @classmethod
252 |     def flat_grad(cls, buffer, mod, layer, x, gy):
253 |         w_numel = layer.weight.numel()
254 |         x_normalized = F.group_norm(x, mod.num_groups,
255 |                                     eps=mod.eps)
256 |         buffer[:, :w_numel].add_((gy * x_normalized).sum(dim=(2, 3)))
257 |         buffer[:, w_numel:].add_(gy.sum(dim=(2, 3)))
258 | 
259 | 
260 | class WeightNorm1dJacobianFactory(JacobianFactory):
261 |     @classmethod
262 |     def flat_grad(cls, buffer, mod, layer, x, gy):
263 |         bs = x.size(0)
264 |         norm2 = (mod.weight**2).sum(dim=1, keepdim=True) + mod.eps
265 |         gw = torch.bmm(gy.unsqueeze(2) / torch.sqrt(norm2),
266 |                        x.unsqueeze(1))
267 |         wn2_out = F.linear(x, mod.weight / norm2**1.5)
268 |         gw -= (gy * wn2_out).unsqueeze(2) * mod.weight.unsqueeze(0)
269 |         buffer.add_(gw.view(bs, -1))
270 | 
271 | 
272 | class WeightNorm2dJacobianFactory(JacobianFactory):
273 |     @classmethod
274 |     def flat_grad(cls, buffer, mod, layer, x, gy):
275 |         bs = x.size(0)
276 |         out_dim = mod.weight.size(0)
277 |         norm2 = (mod.weight**2).sum(dim=(1, 2, 3)) + mod.eps
278 |         gw = conv2d_backward(mod, x, gy / torch.sqrt(norm2).view(1, out_dim, 1, 1))
279 |         gw = gw.view(bs, out_dim, -1)
280 |         wn2_out = F.conv2d(x, mod.weight / norm2.view(out_dim, 1, 1, 1)**1.5, None,
281 |                            stride=mod.stride, padding=mod.padding, dilation=mod.dilation)
282 |         # subtract the component of the per-example gradient along each filter
283 |         gw -= (gy * wn2_out).sum(dim=(2, 3)).view(bs, out_dim, 1) * mod.weight.view(1, out_dim, -1)
284 |         buffer.add_(gw.view(bs, -1))
285 | 
286 | 
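# Gradient structure shared by the weight-normalized layers above and the
# cosine layer below: each output is y_j = <w_j, x> / ||w_j|| (Cosine1d also
# normalizes x), so the per-example weight gradient splits into
# gy_j * x / ||w_j|| minus a projection along the norm direction,
# gy_j * (<w_j, x> / ||w_j||^3) * w_j; `wn2_out` carries that second factor.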
287 | class Cosine1dJacobianFactory(JacobianFactory):
288 |     @classmethod
289 |     def flat_grad(cls, buffer, mod, layer, x, gy):
290 |         bs = x.size(0)
291 |         norm2_w = (mod.weight**2).sum(dim=1, keepdim=True) + mod.eps
292 |         norm2_x = (x**2).sum(dim=1, keepdim=True) + mod.eps
293 |         x = x / torch.sqrt(norm2_x)
294 |         gw = torch.bmm(gy.unsqueeze(2) / torch.sqrt(norm2_w),
295 |                        x.unsqueeze(1))
296 |         wn2_out = F.linear(x, mod.weight / norm2_w**1.5)
297 |         gw -= (gy * wn2_out).unsqueeze(2) * mod.weight.unsqueeze(0)
298 |         buffer.add_(gw.view(bs, -1))
299 | 
300 | 
301 | class Affine1dJacobianFactory(JacobianFactory):
302 |     @classmethod
303 |     def flat_grad(cls, buffer, mod, layer, x, gy):
304 |         w_numel = layer.weight.numel()
305 |         buffer[:, :w_numel].add_(gy * x)
306 |         if layer.bias is not None:
307 |             buffer[:, w_numel:].add_(gy)
308 | 
309 | 
310 | FactoryMap = {
311 |     LinearLayer: LinearJacobianFactory,
312 |     Conv2dLayer: Conv2dJacobianFactory,
313 |     ConvTranspose2dLayer: ConvTranspose2dJacobianFactory,
314 |     BatchNorm1dLayer: BatchNorm1dJacobianFactory,
315 |     BatchNorm2dLayer: BatchNorm2dJacobianFactory,
316 |     GroupNormLayer: GroupNormJacobianFactory,
317 |     WeightNorm1dLayer: WeightNorm1dJacobianFactory,
318 |     WeightNorm2dLayer: WeightNorm2dJacobianFactory,
319 |     Cosine1dLayer: Cosine1dJacobianFactory,
320 |     Affine1dLayer: Affine1dJacobianFactory,
321 | }
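A minimal, self-contained sketch (plain PyTorch, not part of the file above; shapes are arbitrary) of the identity that connects `flat_grad`, `layer_block`, and the `kfac_xx`/`kfac_gg` factors for a linear layer. With a single example the exact Fisher block equals the Kronecker product of the two K-FAC factors:

import torch

bs, d_in, d_out = 1, 3, 2                  # single example: identity is exact
x = torch.randn(bs, d_in)                  # layer input (saved on forward)
gy = torch.randn(bs, d_out)                # gradient w.r.t. the layer output

# what flat_grad computes: row i is the per-example gradient gy_i x_i^T, flattened
flat = torch.bmm(gy.unsqueeze(2), x.unsqueeze(1)).view(bs, -1)

# what kfac_xx / kfac_gg accumulate (bias terms omitted here)
A = x.t() @ x                              # input second moment
G = gy.t() @ gy                            # output-gradient second moment

# layer_block's exact Fisher block vs. the Kronecker-factored form
exact = flat.t() @ flat
assert torch.allclose(exact, torch.kron(G, A), atol=1e-6)

For a batch, K-FAC replaces the sum of these per-example Kronecker products with a single product of the summed factors, which is what makes it an approximation rather than an exact representation.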
--------------------------------------------------------------------------------
/nngeometry/object/lm_vector.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from ..llama_layercollection import LayerCollection
3 | from ..llama_layercollection import LLamaLayerCollection
4 | 
5 | def random_pvector_dict(layer_collection, device=None):
6 |     """
7 |     Returns a random :class:`nngeometry.object.PVector` object using
8 |     the structure defined by the `layer_collection` parameter, with
9 |     each component drawn from a normal distribution with mean 0 and standard
10 |     deviation 1.
11 | 
12 |     The returned `PVector` will internally use a dict representation.
13 | 
14 |     :param layer_collection: The :class:`nngeometry.layercollection.LayerCollection`
15 |         describing the structure of the random pvector
16 |     """
17 |     v_dict = dict()
18 |     for layer_id, layer in layer_collection.layers.items():
19 |         if layer.bias is not None:
20 |             v_dict[layer_id] = (torch.normal(0, 1, layer.weight.size, device=device),
21 |                                 torch.normal(0, 1, layer.bias.size, device=device))
22 |         else:
23 |             v_dict[layer_id] = (torch.normal(0, 1, layer.weight.size, device=device),)
24 |     return PVector(layer_collection, dict_repr=v_dict)
25 | 
26 | 
27 | def random_pvector(layer_collection, device=None):
28 |     """
29 |     Returns a random :class:`nngeometry.object.PVector` object using
30 |     the structure defined by the `layer_collection` parameter, with
31 |     each component drawn from a normal distribution with mean 0 and standard
32 |     deviation 1.
33 | 
34 |     The returned `PVector` will internally use a flat representation.
35 | 
36 |     :param layer_collection: The :class:`nngeometry.layercollection.LayerCollection`
37 |         describing the structure of the random pvector
38 |     """
39 |     n_parameters = layer_collection.numel()
40 |     random_v_flat = torch.normal(0, 1, (n_parameters,),
41 |                                  device=device)
42 |     return PVector(layer_collection=layer_collection,
43 |                    vector_repr=random_v_flat)
44 | 
45 | 
46 | def random_fvector(n_samples, n_output=1, device=None):
47 |     random_v_flat = torch.normal(0, 1, (n_output, n_samples,),
48 |                                  device=device)
49 |     return FVector(vector_repr=random_v_flat)
50 | 
51 | 
52 | class PVector:
53 |     """
54 |     A vector in parameter space
55 | 
56 |     :param layer_collection: a :class:`LayerCollection` describing the parameter structure; the data lives in `vector_repr` (a flat tensor) or `dict_repr` (per-layer tuples)
57 |     """
58 |     def __init__(self, layer_collection, vector_repr=None,
59 |                  dict_repr=None):
60 |         self.layer_collection = layer_collection
61 |         self.vector_repr = vector_repr
62 |         self.dict_repr = dict_repr
63 |         self.is_svd = False  # set by svd(): weights become low-rank (u, s, v) triples
64 | 
65 |     @staticmethod
66 |     def from_model(model):
67 |         """
68 |         Creates a PVector using the current values of the given
69 |         model
70 |         """
71 |         dict_repr = dict()
72 |         layer_collection = LayerCollection.from_model(model)
73 |         l_to_m, _ = layer_collection.get_layerid_module_maps(model)
74 |         for layer_id, layer in layer_collection.layers.items():
75 |             mod = l_to_m[layer_id]
76 |             if layer.bias is not None:
77 |                 dict_repr[layer_id] = (mod.weight, mod.bias)
78 |             else:
79 |                 dict_repr[layer_id] = (mod.weight,)
80 |         return PVector(layer_collection, dict_repr=dict_repr)
81 | 
82 |     def copy_to_model(self, model):
83 |         """
84 |         Updates `model` parameter values with the current PVector
85 | 
86 |         Note: this is an in-place operation
87 |         """
88 |         dict_repr = self.get_dict_representation()
89 |         layer_collection = LayerCollection.from_model(model)
90 |         l_to_m, _ = layer_collection.get_layerid_module_maps(model)
91 |         for layer_id, layer in layer_collection.layers.items():
92 |             mod = l_to_m[layer_id]
93 |             if layer.bias is not None:
94 |                 mod.bias.data.copy_(dict_repr[layer_id][1])
95 |             mod.weight.data.copy_(dict_repr[layer_id][0])
96 | 
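    # from_model above aliases the live parameter tensors (no copy), so the
    # resulting PVector changes together with the model; use clone() further
    # down to take an actual snapshot.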
97 |     def add_to_model(self, model):
98 |         """
99 |         Updates `model` parameter values by adding the current PVector
100 | 
101 |         Note: this is an in-place operation
102 |         """
103 |         dict_repr = self.get_dict_representation()
104 |         layer_collection = LayerCollection.from_model(model)
105 |         l_to_m, _ = layer_collection.get_layerid_module_maps(model)
106 |         for layer_id, layer in layer_collection.layers.items():
107 |             mod = l_to_m[layer_id]
108 |             if layer.bias is not None:
109 |                 mod.bias.data.add_(dict_repr[layer_id][1])
110 |             mod.weight.data.add_(dict_repr[layer_id][0])
111 | 
112 |     @staticmethod
113 |     def from_model_grad(model, ignore_layers=[]):
114 |         """
115 |         Creates a PVector using the current values of the `.grad`
116 |         fields of parameters of the given model
117 |         """
118 |         dict_repr = dict()
119 |         layer_collection = LLamaLayerCollection.from_model(model, True, ignore_layers)
120 |         l_to_m, _ = layer_collection.get_layerid_module_maps(model)
121 |         for layer_id, layer in layer_collection.layers.items():
122 |             mod = l_to_m[layer_id]
123 |             if layer.bias is not None:
124 |                 dict_repr[layer_id] = (mod.weight.grad, mod.bias.grad)
125 |             else:
126 |                 dict_repr[layer_id] = (mod.weight.grad,)
127 |         return PVector(layer_collection, dict_repr=dict_repr)
128 | 
129 |     # @staticmethod
130 |     # def from_lc_grad(layer_collection):
131 |     #     """
132 |     #     Creates a PVector using the current values of the `.grad`
133 |     #     fields of parameters of the given model
134 |     #     """
135 |     #     dict_repr = dict()
136 |     #     layer_collection = LayerCollection.from_model(model)
137 |     #     l_to_m, _ = layer_collection.get_layerid_module_maps(model)
138 |     #     for layer_id, layer in layer_collection.layers.items():
139 |     #         mod = l_to_m[layer_id]
140 |     #         if layer.bias is not None:
141 |     #             dict_repr[layer_id] = (mod.weight.grad, mod.bias.grad)
142 |     #         else:
143 |     #             dict_repr[layer_id] = (mod.weight.grad,)
144 |     #     return PVector(layer_collection, dict_repr=dict_repr)
145 | 
146 |     def clone(self):
147 |         """
148 |         Returns a clone of the current object
149 |         """
150 |         if self.is_svd:
151 |             raise NotImplementedError
152 |         # if self.dict_repr is not None:
153 |         #     dict_clone = dict()
154 |         #     for k, v in self.dict_repr.items():
155 |         #         if len(v) == 2:
156 |         #             dict_clone[k] = ((v[0][0].clone(), v[0][1].clone(), v[0][2].clone()), v[1].clone())
157 |         #         else:
158 |         #             dict_clone[k] = ((v[0][0].clone(), v[0][1].clone(), v[0][2].clone()), )
159 |         #     return PVector(self.layer_collection, dict_repr=dict_clone)
160 | 
161 |         if self.dict_repr is not None:
162 |             dict_clone = dict()
163 |             for k, v in self.dict_repr.items():
164 |                 if len(v) == 2:
165 |                     dict_clone[k] = (v[0].clone(), v[1].clone())
166 |                 else:
167 |                     dict_clone[k] = (v[0].clone(),)
168 |             return PVector(self.layer_collection, dict_repr=dict_clone)
169 |         if self.vector_repr is not None:
170 |             return PVector(self.layer_collection,
171 |                            vector_repr=self.vector_repr.clone())
172 | 
173 |     def detach(self):
174 |         """
175 |         Detaches the current PVector from the computation graph
176 |         """
177 |         if self.dict_repr is not None:
178 |             dict_detach = dict()
179 |             for k, v in self.dict_repr.items():
180 |                 if len(v) == 2:
181 |                     dict_detach[k] = (v[0].detach(), v[1].detach())
182 |                 else:
183 |                     dict_detach[k] = (v[0].detach(),)
184 |             return PVector(self.layer_collection, dict_repr=dict_detach)
185 |         if self.vector_repr is not None:
186 |             return PVector(self.layer_collection,
187 |                            vector_repr=self.vector_repr.detach())
188 | 
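    # A PVector stores its data in one of two interchangeable forms: a dict
    # mapping layer_id -> (weight, bias) tensors, or a single flat 1d tensor.
    # The two accessors below convert lazily, following the parameter
    # ordering fixed by the LayerCollection.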
189 |     def get_flat_representation(self):
190 |         """
191 |         Returns a Pytorch 1d tensor of the flatten vector.
192 | 
193 |         .. warning::
194 |             The ordering in which the parameters are
195 |             flattened can seem to be arbitrary. It is in fact
196 |             the same ordering as specified by the ``layercollection.LayerCollection``
197 |             object.
198 | 
199 |         :return: a Pytorch Tensor
200 |         """
201 |         if self.vector_repr is not None:
202 |             return self.vector_repr
203 |         elif self.dict_repr is not None:
204 |             return self._dict_to_flat()
205 |         else:
206 |             raise NotImplementedError
207 | 
208 |     def get_dict_representation(self):
209 |         if self.dict_repr is not None:
210 |             return self.dict_repr
211 |         elif self.vector_repr is not None:
212 |             return self._flat_to_dict()
213 |         else:
214 |             raise NotImplementedError
215 | 
216 |     def _dict_to_flat(self):
217 |         parts = []
218 |         for layer_id, layer in self.layer_collection.layers.items():
219 |             parts.append(self.dict_repr[layer_id][0].view(-1))
220 |             if len(self.dict_repr[layer_id]) > 1:
221 |                 parts.append(self.dict_repr[layer_id][1].view(-1))
222 |         return torch.cat(parts)
223 | 
224 |     def _flat_to_dict(self):
225 |         dict_repr = dict()
226 |         for layer_id, layer in self.layer_collection.layers.items():
227 |             start = self.layer_collection.p_pos[layer_id]
228 |             w = self.vector_repr[start:start+layer.weight.numel()] \
229 |                 .view(*layer.weight.size)
230 |             start += layer.weight.numel()
231 |             if layer.bias is not None:
232 |                 b = self.vector_repr[start:start+layer.bias.numel()] \
233 |                     .view(*layer.bias.size)
234 |                 start += layer.bias.numel()
235 |                 dict_repr[layer_id] = (w, b)
236 |             else:
237 |                 dict_repr[layer_id] = (w,)
238 |         return dict_repr
239 | 
240 |     def norm(self, p=2):
241 |         """
242 |         Computes the Lp norm of the PVector
243 |         """
244 |         if self.is_svd:
245 |             if self.dict_repr is not None:
246 |                 sum_p = 0
247 |                 for l_id, l in self.layer_collection.layers.items():
248 |                     sum_p += (self.dict_repr[l_id][0][1]**p).sum()  # singular values s of (u, s, v)
249 |                     if l.bias is not None:
250 |                         sum_p += (self.dict_repr[l_id][1]**p).sum()
251 |                 return sum_p ** (1/p)
252 |         if self.dict_repr is not None:
253 |             sum_p = 0
254 |             for l_id, l in self.layer_collection.layers.items():
255 |                 sum_p += (self.dict_repr[l_id][0]**p).sum()
256 |                 if l.bias is not None:
257 |                     sum_p += (self.dict_repr[l_id][1]**p).sum()
258 |             return sum_p ** (1/p)
259 |         else:
260 |             return torch.norm(self.vector_repr, p=p)
261 | 
262 |     def square(self):
263 |         if self.is_svd:
264 |             raise NotImplementedError
265 |         for l_id, l in self.layer_collection.layers.items():
266 |             if l.bias is not None:
267 |                 self.dict_repr[l_id] = (self.dict_repr[l_id][0]**2, self.dict_repr[l_id][1]**2)
268 |             else:
269 |                 self.dict_repr[l_id] = (self.dict_repr[l_id][0]**2, )
270 | 
271 |     def __rmul__(self, x):
272 |         # TODO: test
273 |         # scalar multiplication
274 |         if self.dict_repr is not None:
275 |             v_dict = dict()
276 |             for l_id, l in self.layer_collection.layers.items():
277 |                 if l.bias is not None:
278 |                     v_dict[l_id] = (x * self.dict_repr[l_id][0],
279 |                                     x * self.dict_repr[l_id][1])
280 |                 else:
281 |                     v_dict[l_id] = (x * self.dict_repr[l_id][0],)
282 |             return PVector(self.layer_collection, dict_repr=v_dict)
283 |         else:
284 |             return PVector(self.layer_collection,
285 |                            vector_repr=x * self.vector_repr)
286 | 
287 |     def __add__(self, other):
288 |         if self.dict_repr is not None and other.dict_repr is not None:
289 |             v_dict = dict()
290 |             for l_id, l in self.layer_collection.layers.items():
291 |                 if l.bias is not None:
292 |                     v_dict[l_id] = (self.dict_repr[l_id][0] +
293 |                                     other.dict_repr[l_id][0],
294 |                                     self.dict_repr[l_id][1] +
295 |                                     other.dict_repr[l_id][1])
296 |                 else:
297 |                     v_dict[l_id] = (self.dict_repr[l_id][0] +
298 |                                     other.dict_repr[l_id][0],)
299 |             return PVector(self.layer_collection, dict_repr=v_dict)
300 |         elif self.vector_repr is not None and other.vector_repr is not None:
301 |             return PVector(self.layer_collection,
302 |                            vector_repr=self.vector_repr+other.vector_repr)
303 |         else:
304 |             return PVector(self.layer_collection,
305 |                            vector_repr=(self.get_flat_representation() +
306 |                                         other.get_flat_representation()))
307 | 
308 |     def __sub__(self, other):
309 | 
310 |         if self.dict_repr is not None and other.dict_repr is not None:
311 |             v_dict = dict()
312 |             # densify any SVD-compressed operand before subtracting
313 |             for l_id, l in self.layer_collection.layers.items():
314 |                 if self.is_svd:
315 |                     u, s, v = self.dict_repr[l_id][0]
316 |                     vec1 = u @ torch.diag(s) @ v.T
317 |                 else:
318 |                     vec1 = self.dict_repr[l_id][0]
319 |                 if other.is_svd:
320 |                     u, s, v = other.dict_repr[l_id][0]
321 |                     vec2 = u @ torch.diag(s) @ v.T
322 |                 else:
323 |                     vec2 = other.dict_repr[l_id][0]
324 | 
325 |                 if l.bias is not None:
326 |                     v_dict[l_id] = (vec1 - vec2,
327 |                                     self.dict_repr[l_id][1] -
328 |                                     other.dict_repr[l_id][1])
329 |                 else:
330 |                     v_dict[l_id] = (vec1 - vec2, )
331 |             return PVector(self.layer_collection, dict_repr=v_dict)
332 | 
333 |         elif self.vector_repr is not None and other.vector_repr is not None:
334 |             return PVector(self.layer_collection,
335 |                            vector_repr=self.vector_repr-other.vector_repr)
336 |         else:
337 |             return PVector(self.layer_collection,
338 |                            vector_repr=(self.get_flat_representation() -
339 |                                         other.get_flat_representation()))
340 |     def svd(self, q=32):  # compress weights in place to rank-q (u, s, v) factors
341 |         self.is_svd = True
342 |         for l_id, l in self.layer_collection.layers.items():
343 |             if l.bias is not None:
344 |                 self.dict_repr[l_id] = (torch.svd_lowrank(self.dict_repr[l_id][0], q=q), self.dict_repr[l_id][1])
345 |             else:
346 |                 self.dict_repr[l_id] = (torch.svd_lowrank(self.dict_repr[l_id][0], q=q),)
347 | 
348 |     def to(self, device):
349 |         if self.is_svd:
350 |             for l_id, l in self.layer_collection.layers.items():
351 |                 if l.bias is not None:
352 |                     self.dict_repr[l_id] = ((self.dict_repr[l_id][0][0].to(device), self.dict_repr[l_id][0][1].to(device), self.dict_repr[l_id][0][2].to(device)), self.dict_repr[l_id][1].to(device))
353 |                 else:
354 |                     self.dict_repr[l_id] = ((self.dict_repr[l_id][0][0].to(device), self.dict_repr[l_id][0][1].to(device), self.dict_repr[l_id][0][2].to(device)), )
355 |         else:
356 |             for l_id, l in self.layer_collection.layers.items():
357 |                 if l.bias is not None:
358 |                     self.dict_repr[l_id] = (self.dict_repr[l_id][0].to(device), self.dict_repr[l_id][1].to(device))
359 |                 else:
360 |                     self.dict_repr[l_id] = (self.dict_repr[l_id][0].to(device), )
361 | 
362 |     def dot_svd(self, other):
363 |         """
364 |         Computes the dot product between `self` and `other`, where
365 |         both must hold SVD-compressed dict representations (see ``svd``)
366 | 
367 |         :param other: The other `PVector`
368 |         """
369 |         assert self.is_svd and other.is_svd
370 |         if self.vector_repr is not None or other.vector_repr is not None:
371 |             # flat representations cannot hold the low-rank factors
372 |             raise NotImplementedError
373 |         else:
374 |             dot_ = 0
375 |             for l_id, l in self.layer_collection.layers.items():
376 |                 if l.bias is not None:
377 |                     dot_ += torch.dot(self.dict_repr[l_id][1],
378 |                                       other.dict_repr[l_id][1])
379 |                 u1, s1, v1 = self.dict_repr[l_id][0]
380 |                 u2, s2, v2 = other.dict_repr[l_id][0]
381 |                 # equivalent dense computation, kept for reference:
382 |                 # dot_ += torch.dot((u1 @ torch.diag(s1) @ v1.T).view(-1),
383 |                 #                   (u2 @ torch.diag(s2) @ v2.T).view(-1))
384 |                 dot_ += torch.trace(u1@(torch.diag(s1)@(v1.T@v2)@torch.diag(s2))@u2.T)
385 |             return dot_
386 | 
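    # Why dot_svd's trace expression is correct: for A = u1 s1 v1^T and
    # B = u2 s2 v2^T, the Frobenius inner product is
    #   <A, B> = tr(A B^T) = tr(u1 s1 (v1^T v2) s2 u2^T),
    # so only q x q intermediates are formed instead of densifying the
    # full weight-sized matrices.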
387 |     def dot(self, other):
388 |         """
389 |         Computes the dot product between `self` and `other`
390 | 
391 |         :param other: The other `PVector`
392 |         """
393 |         if self.vector_repr is not None or other.vector_repr is not None:
394 |             return torch.dot(self.get_flat_representation(),
395 |                              other.get_flat_representation())
396 |         else:
397 |             dot_ = 0
398 |             for l_id, l in self.layer_collection.layers.items():
399 |                 if l.bias is not None:
400 |                     dot_ += torch.dot(self.dict_repr[l_id][1],
401 |                                       other.dict_repr[l_id][1])
402 |                 dot_ += torch.dot(self.dict_repr[l_id][0].view(-1),
403 |                                   other.dict_repr[l_id][0].view(-1))
404 |             return dot_
405 | 
406 |     def size(self):
407 |         """
408 |         The size of the PVector, or equivalently the number of
409 |         parameters of the layer collection
410 |         """
411 |         return (self.layer_collection.numel(), )
412 | 
413 | 
414 | class FVector:
415 |     """
416 |     A vector in function space
417 |     """
418 |     def __init__(self, vector_repr=None):
419 |         self.vector_repr = vector_repr
420 | 
421 |     def get_flat_representation(self):
422 |         if self.vector_repr is not None:
423 |             return self.vector_repr
424 |         else:
425 |             raise NotImplementedError
426 | 
--------------------------------------------------------------------------------
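A small end-to-end check of the low-rank machinery above (a sketch, not part of the repository; shapes are arbitrary): compressing two gradient-like matrices with torch.svd_lowrank and taking their inner product through the small factors, as PVector.svd and PVector.dot_svd do per layer, agrees with the dense dot product of the reconstructions.

import torch

g1, g2 = torch.randn(256, 128), torch.randn(256, 128)
q = 32  # same default rank as PVector.svd

u1, s1, v1 = torch.svd_lowrank(g1, q=q)
u2, s2, v2 = torch.svd_lowrank(g2, q=q)

# factored inner product (what dot_svd computes for one layer)
factored = torch.trace(u1 @ (torch.diag(s1) @ (v1.T @ v2) @ torch.diag(s2)) @ u2.T)

# dense inner product of the rank-q reconstructions
dense = torch.dot((u1 @ torch.diag(s1) @ v1.T).reshape(-1),
                  (u2 @ torch.diag(s2) @ v2.T).reshape(-1))

assert torch.allclose(factored, dense, atol=1e-4)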