├── minigpt4
    ├── common
    │   ├── __init__.py
    │   ├── gradcam.py
    │   ├── optims.py
    │   ├── dist_utils.py
    │   ├── registry.py
    │   ├── logger.py
    │   ├── utils.py
    │   └── config.py
    ├── datasets
    │   ├── __init__.py
    │   ├── datasets
    │   │   ├── __init__.py
    │   │   ├── base_dataset.py
    │   │   ├── rec_base_dataset.py
    │   │   ├── dataloader_utils.py
    │   │   └── rec_gnndataset.py
    │   ├── builders
    │   │   ├── __init__.py
    │   │   ├── rec_base_dataset_builder.py
    │   │   ├── base_dataset_builder.py
    │   │   └── rec_pair_builder.py
    │   └── data_utils.py
    ├── configs
    │   ├── datasets
    │   │   ├── default.yaml
    │   │   └── movielens
    │   │       └── default.yaml
    │   ├── models
    │   │   ├── minigpt4.yaml
    │   │   ├── minigpt4rec.yaml
    │   │   └── minigpt4rec_lora.yaml
    │   └── default.yaml
    ├── runners
    │   └── __init__.py
    ├── models
    │   ├── readme.md
    │   ├── __init__.py
    │   ├── base_model.py
    │   └── rec_model.py
    ├── tasks
    │   ├── rec_pretrain.py
    │   ├── __init__.py
    │   └── base_task.py
    ├── processors
    │   ├── base_processor.py
    │   ├── __init__.py
    │   ├── rec_processors.py
    │   ├── blip_processors.py
    │   └── randaugment.py
    └── __init__.py
├── pull_model_from_hf.py
├── requirements.txt
├── search_result.py
├── prompts
    └── binllm_text.txt
├── environment.yml
├── README.md
├── view_token.py
├── train_configs
    └── hash_CF_ml.yaml
├── used_metrics.py
└── train_binllm.py

/minigpt4/common/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/minigpt4/datasets/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/minigpt4/datasets/datasets/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/minigpt4/configs/datasets/default.yaml:
--------------------------------------------------------------------------------
1 | datasets:
--------------------------------------------------------------------------------
/minigpt4/configs/models/minigpt4.yaml:
--------------------------------------------------------------------------------
1 | model:
2 |
--------------------------------------------------------------------------------
/minigpt4/configs/models/minigpt4rec.yaml:
--------------------------------------------------------------------------------
1 | model:
2 |
--------------------------------------------------------------------------------
/minigpt4/configs/models/minigpt4rec_lora.yaml:
--------------------------------------------------------------------------------
1 | model:
--------------------------------------------------------------------------------
/minigpt4/configs/datasets/movielens/default.yaml:
--------------------------------------------------------------------------------
1 | datasets:
--------------------------------------------------------------------------------
/minigpt4/configs/default.yaml:
--------------------------------------------------------------------------------
1 | env:
2 |   # For default users
3 |   # cache_root: "cache"
4 |   # For internal use with persistent storage
5 |   # cache_root: "/export/home/.cache/minigpt4"
6 |   # cache_root: "/home/////.cache/minigpt4"
7 |   cache_root: "/data///.cache/minigpt4"
8 |
--------------------------------------------------------------------------------
/minigpt4/runners/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2022, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 |
8 | from minigpt4.runners.runner_base import RunnerBase
9 | from minigpt4.runners.runner_base_rec import RecRunnerBase
10 |
11 | __all__ = ["RunnerBase", "RecRunnerBase"]
12 |
--------------------------------------------------------------------------------
/minigpt4/models/readme.md:
--------------------------------------------------------------------------------
1 | + base_model.py: Contains the fundamental class "BaseModel".
2 | + rec_model.py: Defines "Rec2Base," a subclass of "BaseModel" focused on setting up recommender models.
3 | + minigpt4rec_v2.py: Introduces "CoLLM," a subclass of "Rec2Base" that implements the LLM-based recommendation model itself (registered as the "mini_gpt4rec_v2" architecture used in the training configs).
4 | + rec_base_models.py: Houses various collaborative models for recommendation systems.
5 | + modeling_llama.py: Holds the LLaMA modeling code.
--------------------------------------------------------------------------------
/minigpt4/tasks/rec_pretrain.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2022, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 |
8 | from minigpt4.common.registry import registry
9 | from minigpt4.tasks.base_task import BaseTask
10 | from minigpt4.tasks.rec_base_task import RecBaseTask
11 |
12 |
13 | @registry.register_task("rec_pretrain")
14 | class RecPretrainTask(RecBaseTask):
15 |     def __init__(self):
16 |         super().__init__()
17 |
18 |     # def evaluation(self, model, data_loader, cuda_enabled=True):
19 |     #     pass
--------------------------------------------------------------------------------
/minigpt4/processors/base_processor.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2022, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | from omegaconf import OmegaConf 9 | 10 | 11 | class BaseProcessor: 12 | def __init__(self): 13 | self.transform = lambda x: x 14 | return 15 | 16 | def __call__(self, item): 17 | return self.transform(item) 18 | 19 | @classmethod 20 | def from_config(cls, cfg=None): 21 | return cls() 22 | 23 | def build(self, **kwargs): 24 | cfg = OmegaConf.create(kwargs) 25 | 26 | return self.from_config(cfg) 27 | -------------------------------------------------------------------------------- /pull_model_from_hf.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import LlamaTokenizer, LlamaForCausalLM 3 | 4 | ## v2 models 5 | model_path = 'openlm-research/open_llama_7b_v2' 6 | 7 | ## v1 models 8 | # model_path = 'openlm-research/open_llama_3b' 9 | # model_path = 'openlm-research/open_llama_7b' 10 | # model_path = 'openlm-research/open_llama_13b' 11 | 12 | tokenizer = LlamaTokenizer.from_pretrained(model_path) 13 | model = LlamaForCausalLM.from_pretrained( 14 | model_path, torch_dtype=torch.float16 15 | ).cuda() 16 | 17 | prompt = 'Q: What is the largest animal?\nA:' 18 | input_ids = tokenizer(prompt, return_tensors="pt").input_ids 19 | 20 | generation_output = model.generate( 21 | input_ids=input_ids.cuda(), max_new_tokens=32 22 | ) 23 | print(tokenizer.decode(generation_output[0])) -------------------------------------------------------------------------------- /minigpt4/common/gradcam.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from matplotlib import pyplot as plt 3 | from scipy.ndimage import filters 4 | from skimage import transform as skimage_transform 5 | 6 | 7 | def getAttMap(img, attMap, blur=True, overlap=True): 8 | attMap -= attMap.min() 9 | if attMap.max() > 0: 10 | attMap /= attMap.max() 11 | attMap = skimage_transform.resize(attMap, (img.shape[:2]), order=3, mode="constant") 12 | if blur: 13 | attMap = filters.gaussian_filter(attMap, 0.02 * max(img.shape[:2])) 14 | attMap -= attMap.min() 15 | attMap /= attMap.max() 16 | cmap = plt.get_cmap("jet") 17 | attMapV = cmap(attMap) 18 | attMapV = np.delete(attMapV, 3, 2) 19 | if overlap: 20 | attMap = ( 21 | 1 * (1 - attMap**0.7).reshape(attMap.shape + (1,)) * img 22 | + (attMap**0.7).reshape(attMap.shape + (1,)) * attMapV 23 | ) 24 | return attMap 25 | -------------------------------------------------------------------------------- /minigpt4/processors/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 
4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | from minigpt4.processors.base_processor import BaseProcessor 9 | from minigpt4.processors.blip_processors import ( 10 | Blip2ImageTrainProcessor, 11 | Blip2ImageEvalProcessor, 12 | BlipCaptionProcessor, 13 | ) 14 | 15 | from minigpt4.common.registry import registry 16 | 17 | __all__ = [ 18 | "BaseProcessor", 19 | "Blip2ImageTrainProcessor", 20 | "Blip2ImageEvalProcessor", 21 | "BlipCaptionProcessor", 22 | ] 23 | 24 | 25 | def load_processor(name, cfg=None): 26 | """ 27 | Example 28 | 29 | >>> processor = load_processor("alpro_video_train", cfg=None) 30 | """ 31 | processor = registry.get_processor_class(name).from_config(cfg) 32 | 33 | return processor 34 | -------------------------------------------------------------------------------- /minigpt4/tasks/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | from minigpt4.common.registry import registry 9 | from minigpt4.tasks.base_task import BaseTask 10 | from minigpt4.tasks.rec_base_task import RecBaseTask 11 | # from minigpt4.tasks.image_text_pretrain import ImageTextPretrainTask 12 | from minigpt4.tasks.rec_pretrain import RecPretrainTask 13 | 14 | 15 | def setup_task(cfg): 16 | assert "task" in cfg.run_cfg, "Task name must be provided." 17 | 18 | task_name = cfg.run_cfg.task 19 | task = registry.get_task_class(task_name).setup_task(cfg=cfg) 20 | assert task is not None, "Task {} not properly registered.".format(task_name) 21 | 22 | return task 23 | 24 | 25 | __all__ = [ 26 | "BaseTask", 27 | # "ImageTextPretrainTask", 28 | "RecPretrainTask" 29 | ] 30 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==0.16.0 2 | aiohttp==3.8.4 3 | aiosignal==1.3.1 4 | async-timeout==4.0.2 5 | attrs==22.2.0 6 | bitsandbytes==0.37.0 7 | cchardet==2.1.7 8 | chardet==5.1.0 9 | contourpy==1.0.7 10 | cycler==0.11.0 11 | filelock==3.9.0 12 | fonttools==4.38.0 13 | frozenlist==1.3.3 14 | huggingface-hub==0.13.4 15 | importlib-resources==5.12.0 16 | kiwisolver==1.4.4 17 | matplotlib==3.7.0 18 | multidict==6.0.4 19 | openai==0.27.0 20 | packaging==23.0 21 | psutil==5.9.4 22 | pycocotools==2.0.6 23 | pyparsing==3.0.9 24 | python-dateutil==2.8.2 25 | pyyaml==6.0 26 | regex==2022.10.31 27 | tokenizers==0.13.2 28 | tqdm==4.64.1 29 | transformers==4.28.0 30 | timm==0.6.13 31 | spacy==3.5.1 32 | webdataset==0.2.48 33 | scikit-learn==1.2.2 34 | scipy==1.10.1 35 | yarl==1.8.2 36 | zipp==3.14.0 37 | omegaconf==2.3.0 38 | opencv-python==4.7.0.72 39 | iopath==0.1.10 40 | decord==0.6.0 41 | tenacity==8.2.2 42 | peft 43 | pycocoevalcap 44 | sentence-transformers 45 | umap-learn 46 | notebook 47 | gradio==3.24.1 48 | gradio-client==0.0.8 49 | wandb 50 | -------------------------------------------------------------------------------- /minigpt4/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 |
8 | import os
9 | import sys
10 |
11 | from omegaconf import OmegaConf
12 |
13 | from minigpt4.common.registry import registry
14 |
15 | from minigpt4.datasets.builders import *
16 | from minigpt4.models import *
17 | from minigpt4.processors import *
18 | from minigpt4.tasks import *
19 |
20 |
21 | root_dir = os.path.dirname(os.path.abspath(__file__))
22 | default_cfg = OmegaConf.load(os.path.join(root_dir, "configs/default.yaml"))
23 |
24 | registry.register_path("library_root", root_dir)
25 | repo_root = os.path.join(root_dir, "..")
26 | registry.register_path("repo_root", repo_root)
27 | cache_root = os.path.join(repo_root, default_cfg.env.cache_root)
28 | registry.register_path("cache_root", cache_root)
29 |
30 | registry.register("MAX_INT", sys.maxsize)
31 | registry.register("SPLIT_NAMES", ["train", "val", "test"])
32 |
--------------------------------------------------------------------------------
/search_result.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | import ast
4 |
5 | def find_max(log_path):
6 |     with open(log_path, 'r') as f:
7 |         lines = f.readlines()
8 |     configs = []
9 |     results = []
10 |     i = 0
11 |     for one_line in lines:
12 |         # print(one_line)
13 |         i += 1
14 |         # if i % 2 != 0:
15 |         #     continue
16 |         # if "'embedding_size': 32" not in one_line:
17 |         #     continue
18 |         # str.strip() removes a *character set*, not a prefix, so drop the
19 |         # literal "train_config: " prefix explicitly before splitting
20 |         if one_line.startswith("train_config: "):
21 |             one_line = one_line[len("train_config: "):]
22 |         one_line = one_line.split(" best result: ")
23 |         try:
24 |             configs.append(ast.literal_eval(one_line[0]))
25 |             results.append(ast.literal_eval(one_line[1]))
26 |         except (IndexError, ValueError, SyntaxError):
27 |             print(one_line)
28 |             raise RuntimeError("could not parse log line: {}".format(one_line))
29 |
30 |     max_k = 0
31 |     max_valid_auc = 0
32 |     for k in range(len(results)):
33 |         if results[k]['valid_auc'] > max_valid_auc:
34 |             max_k = k
35 |             max_valid_auc = results[k]['valid_auc']
36 |     return configs[max_k], results[max_k]
37 |
38 |
39 | print(find_max("log/xxxxx.log"), '\n')
--------------------------------------------------------------------------------
/prompts/binllm_text.txt:
--------------------------------------------------------------------------------
1 | #Question: A user has given high ratings to the following books: . Additionally, we have information about the user's preferences encoded in the feature . Using all available information, make a prediction about whether the user would enjoy the book titled with the feature ? Answer with "Yes" or "No". \n#Answer:
2 | #Question: A user has given high ratings to the following books: . Additionally, we have information about the user's preferences encoded in the feature . Using all available information, make a prediction about whether the user would enjoy the book titled with the feature ? Answer with "Yes" or "No". \n#Answer:
3 | #Question: A user has given high ratings to the following books: . Additionally, we have information about the user's preferences encoded in the feature . Using all available information, make a prediction about whether the user would enjoy the book titled with the feature ? Answer with "Yes" or "No". \n#Answer:
4 | #Question: A user has given high ratings to the following books: . Additionally, we have information about the user's preferences encoded in the feature . Using all available information, make a prediction about whether the user would enjoy the book titled with the feature ? Answer with "Yes" or "No". \n#Answer:
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: minigpt4
2 | channels:
3 |   - pytorch
4 |   - defaults
5 |   - anaconda
6 | dependencies:
7 |   - python=3.9
8 |   - cudatoolkit
9 |   - pip
10 |   - pytorch=1.12.1
11 |   - pytorch-mutex=1.0=cuda
12 |   - torchaudio=0.12.1
13 |   - torchvision=0.13.1
14 |   - pip:
15 |     - accelerate==0.16.0
16 |     - aiohttp==3.8.4
17 |     - aiosignal==1.3.1
18 |     - async-timeout==4.0.2
19 |     - attrs==22.2.0
20 |     - bitsandbytes==0.37.0
21 |     - cchardet==2.1.7
22 |     - chardet==5.1.0
23 |     - contourpy==1.0.7
24 |     - cycler==0.11.0
25 |     - filelock==3.9.0
26 |     - fonttools==4.38.0
27 |     - frozenlist==1.3.3
28 |     - huggingface-hub==0.13.4
29 |     - importlib-resources==5.12.0
30 |     - kiwisolver==1.4.4
31 |     - matplotlib==3.7.0
32 |     - multidict==6.0.4
33 |     - openai==0.27.0
34 |     - packaging==23.0
35 |     - psutil==5.9.4
36 |     - pycocotools==2.0.6
37 |     - pyparsing==3.0.9
38 |     - python-dateutil==2.8.2
39 |     - pyyaml==6.0
40 |     - regex==2022.10.31
41 |     - tokenizers==0.13.2
42 |     - tqdm==4.64.1
43 |     - transformers==4.28.0
44 |     - timm==0.6.13
45 |     - spacy==3.5.1
46 |     - webdataset==0.2.48
47 |     - scikit-learn==1.2.2
48 |     - scipy==1.10.1
49 |     - yarl==1.8.2
50 |     - zipp==3.14.0
51 |     - omegaconf==2.3.0
52 |     - opencv-python==4.7.0.72
53 |     - iopath==0.1.10
54 |     - decord==0.6.0
55 |     - tenacity==8.2.2
56 |     - peft
57 |     - pycocoevalcap
58 |     - sentence-transformers
59 |     - umap-learn
60 |     - notebook
61 |     - gradio==3.24.1
62 |     - gradio-client==0.0.8
63 |     - wandb
64 |
--------------------------------------------------------------------------------
/minigpt4/processors/rec_processors.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2022, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import re 9 | 10 | from minigpt4.common.registry import registry 11 | from minigpt4.processors.base_processor import BaseProcessor 12 | from minigpt4.processors.randaugment import RandomAugment 13 | from omegaconf import OmegaConf 14 | from torchvision import transforms 15 | from torchvision.transforms.functional import InterpolationMode 16 | 17 | 18 | class BlipImageBaseProcessor(BaseProcessor): 19 | def __init__(self, mean=None, std=None): 20 | if mean is None: 21 | mean = (0.48145466, 0.4578275, 0.40821073) 22 | if std is None: 23 | std = (0.26862954, 0.26130258, 0.27577711) 24 | 25 | self.normalize = transforms.Normalize(mean, std) 26 | 27 | 28 | @registry.register_processor("rec_response") 29 | class RecResponseProcessor(BaseProcessor): 30 | def __init__(self, prompt="", max_words=50): 31 | self.prompt = prompt 32 | self.max_words = max_words 33 | 34 | def __call__(self, caption): 35 | caption = self.prompt + self.pre_caption(caption) 36 | 37 | return caption 38 | 39 | @classmethod 40 | def from_config(cls, cfg=None): 41 | if cfg is None: 42 | cfg = OmegaConf.create() 43 | 44 | prompt = cfg.get("prompt", "") 45 | max_words = cfg.get("max_words", 50) 46 | 47 | return cls(prompt=prompt, max_words=max_words) 48 | 49 | def pre_caption(self, caption): 50 | caption = re.sub( 51 | r"([.!\"()*#:;~])", 52 | " ", 53 | caption.lower(), 54 | ) 55 | caption = re.sub( 56 | r"\s{2,}", 57 | " ", 58 | caption, 59 | ) 60 | caption = caption.rstrip("\n") 61 | caption = caption.strip(" ") 62 | 63 | # truncate caption 64 | caption_words = caption.split(" ") 65 | if len(caption_words) > self.max_words: 66 | caption = " ".join(caption_words[: self.max_words]) 67 | 68 | return caption -------------------------------------------------------------------------------- /minigpt4/datasets/datasets/base_dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import json 9 | from typing import Iterable 10 | 11 | from torch.utils.data import Dataset, ConcatDataset 12 | from torch.utils.data.dataloader import default_collate 13 | 14 | 15 | class BaseDataset(Dataset): 16 | def __init__( 17 | self, vis_processor=None, text_processor=None, vis_root=None, ann_paths=[] 18 | ): 19 | """ 20 | vis_root (string): Root directory of images (e.g. 
coco/images/)
21 |         ann_root (string): directory to store the annotation file
22 |         """
23 |         self.vis_root = vis_root
24 |
25 |         self.annotation = []
26 |         for ann_path in ann_paths:
27 |             self.annotation.extend(json.load(open(ann_path, "r"))['annotations'])
28 |
29 |         self.vis_processor = vis_processor
30 |         self.text_processor = text_processor
31 |
32 |         self._add_instance_ids()
33 |
34 |     def __len__(self):
35 |         return len(self.annotation)
36 |
37 |     def collater(self, samples):
38 |         return default_collate(samples)
39 |
40 |     def set_processors(self, vis_processor, text_processor):
41 |         self.vis_processor = vis_processor
42 |         self.text_processor = text_processor
43 |
44 |     def _add_instance_ids(self, key="instance_id"):
45 |         for idx, ann in enumerate(self.annotation):
46 |             ann[key] = str(idx)
47 |
48 |
49 | class ConcatDataset(ConcatDataset):
50 |     def __init__(self, datasets: Iterable[Dataset]) -> None:
51 |         super().__init__(datasets)
52 |
53 |     def collater(self, samples):
54 |         # TODO For now only supports datasets with same underlying collater implementations
55 |
56 |         all_keys = set()
57 |         for s in samples:
58 |             all_keys.update(s)
59 |
60 |         shared_keys = all_keys
61 |         for s in samples:
62 |             shared_keys = shared_keys & set(s.keys())
63 |
64 |         samples_shared_keys = []
65 |         for s in samples:
66 |             samples_shared_keys.append({k: s[k] for k in s.keys() if k in shared_keys})
67 |
68 |         return self.datasets[0].collater(samples_shared_keys)
69 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Text-like Encoding of Collaborative Information in Large Language Models for Recommendation
2 |
3 |
4 | **This repository is built on [CoLLM](https://github.com/zyang1580/CoLLM)! Read the CoLLM "readme.md" to understand the code structure!**
5 |
6 | **Our trained models can be found [here](https://rec.ustc.edu.cn/share/ddf0ccf0-5fb3-11ef-93eb-23d2eed3b4d2).**
7 |
8 |
9 |
10 |
11 |
12 | ## Step 1: Follow CoLLM to create the environment and prepare Vicuna.
13 |
14 | ## Step 2: Pre-training for text-like encoding:
15 | ```bash
16 | CUDA_VISIBLE_DEVICES=6,7 WORLD_SIZE=2 nohup torchrun --nproc-per-node 2 --master_port=11139 train_collm_mf_din.py --cfg-path=train_configs/collm_pretrain_mf_ood.yaml > /log.out &
17 | ```
18 |
19 | ## Step 3: LoRA Tuning
20 |
21 |
22 | Step 3.1: train without collaborative information.
23 | ```bash
24 | CUDA_VISIBLE_DEVICES=0,1 WORLD_SIZE=2 nohup torchrun --nproc-per-node 2 --master_port=11139 train_collm_mf_din.py --cfg-path=train_configs/collm_pretrain_mf_ood.yaml > /log.out &
25 | ```
26 | Note: please download "train_collm_mf_din.py" and "collm_pretrain_mf_ood.yaml" from the CoLLM repository.
27 |
28 | Step 3.2: train with collaborative information (a sketch of the text-like encoding is given below, followed by the command).
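29 | The config `train_configs/hash_CF_ml.yaml` sets `code_mode: 'binary'` (with `'ipv4'` kept as a commented-out alternative), and `view_token.py` tokenizes strings such as `'192.168.122.234'`. As a rough illustration of what these text-like collaborative codes can look like, here is a minimal sketch assuming a sign-based binarization of a pretrained collaborative embedding; the real codes are produced by the pretrained hash model loaded via `pretrained_path`, which is not shown here.
30 | ```python
31 | import numpy as np
32 |
33 | def encode_binary(emb: np.ndarray) -> str:
34 |     # Sign-binarize an embedding into a 0/1 string (cf. code_mode: 'binary').
35 |     return "".join(str(int(v > 0)) for v in emb)
36 |
37 | def encode_ipv4_like(emb: np.ndarray) -> str:
38 |     # Pack each group of 8 bits into a decimal octet (cf. code_mode: 'ipv4'),
39 |     # giving dot-decimal strings like the '192.168.122.234' example in view_token.py.
40 |     bits = encode_binary(emb)
41 |     assert len(bits) % 8 == 0, "embedding_size should be a multiple of 8"
42 |     return ".".join(str(int(bits[i:i + 8], 2)) for i in range(0, len(bits), 8))
43 |
44 | emb = np.random.randn(32)     # embedding_size: 32, as in train_configs/hash_CF_ml.yaml
45 | print(encode_binary(emb))     # e.g. '01101000111010010001001011001101'
46 | print(encode_ipv4_like(emb))  # e.g. '104.233.18.205'
47 | ```
48 | Strings of this form are the kind of values that fill the user/item feature slots in `prompts/binllm_text.txt`.
49 |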
50 | ```bash
51 | CUDA_VISIBLE_DEVICES=0,1 WORLD_SIZE=2 nohup torchrun --nproc-per-node 2 --master_port=11139 train_binllm.py --cfg-path=train_configs/hash_CF_ml.yaml > /log.out &
52 | ```
53 |
54 | ## Citation
55 | If you're using this code (built on CoLLM) in your research or applications, please cite our papers:
56 | ```bibtex
57 | @inproceedings{zhang-etal-2024-text,
58 |     title = "Text-like Encoding of Collaborative Information in Large Language Models for Recommendation",
59 |     author = "Zhang, Yang and Bao, Keqin and Yan, Ming and Wang, Wenjie and Feng, Fuli and He, Xiangnan",
60 |     booktitle = "Proceedings of the 62nd Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)",
61 |     year = "2024",
62 |     url = "https://aclanthology.org/2024.acl-long.497",
63 |     pages = "9181--9191"
64 | }
65 | ```
66 |
67 | ```bibtex
68 | @article{zhang2023collm,
69 |   title={CoLLM: Integrating Collaborative Embeddings into Large Language Models for Recommendation},
70 |   author={Zhang, Yang and Feng, Fuli and Zhang, Jizhi and Bao, Keqin and Wang, Qifan and He, Xiangnan},
71 |   journal={arXiv preprint arXiv:2310.19488},
72 |   year={2023}
73 | }
74 | ```
75 | You may also need to cite the [MiniGPT-4 paper](https://arxiv.org/abs/2304.10592).
76 |
--------------------------------------------------------------------------------
/minigpt4/datasets/datasets/rec_base_dataset.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2022, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 |
8 | import json
9 | from typing import Iterable
10 |
11 | from torch.utils.data import Dataset, ConcatDataset
12 | from torch.utils.data.dataloader import default_collate
13 | import pandas as pd
14 |
15 |
16 | class RecBaseDataset(Dataset):
17 |     def __init__(
18 |         self, text_processor=None, ann_paths=None
19 |     ):
20 |         """
21 |         vis_root (string): Root directory of images (e.g.
coco/images/) 22 | ann_root (string): directory to store the annotation file 23 | """ 24 | # self.vis_root = vis_root 25 | # self.annotation = pd.read_csv(ann_paths[0]+"",sep='\t', index_col=None,header=0).values 26 | if ann_paths is not None: 27 | self.annotation = pd.read_pickle(ann_paths[0]+".pkl").values 28 | # self.annotation = [] 29 | # for ann_path in ann_paths: 30 | # self.annotation.extend(json.load(open(ann_path, "r"))['annotations']) 31 | self.text_processor = text_processor 32 | 33 | # self._add_instance_ids() 34 | 35 | def __len__(self): 36 | return len(self.annotation) 37 | 38 | def collater(self, samples): 39 | return default_collate(samples) 40 | 41 | def set_processors(self, text_processor): 42 | # self.vis_processor = vis_processor 43 | self.text_processor = text_processor 44 | 45 | def _add_instance_ids(self, key="instance_id"): 46 | for idx, ann in enumerate(self.annotation): 47 | ann[key] = str(idx) 48 | 49 | 50 | class ConcatDataset(ConcatDataset): 51 | def __init__(self, datasets: Iterable[Dataset]) -> None: 52 | super().__init__(datasets) 53 | 54 | def collater(self, samples): 55 | # TODO For now only supports datasets with same underlying collater implementations 56 | 57 | all_keys = set() 58 | for s in samples: 59 | all_keys.update(s) 60 | 61 | shared_keys = all_keys 62 | for s in samples: 63 | shared_keys = shared_keys & set(s.keys()) 64 | 65 | samples_shared_keys = [] 66 | for s in samples: 67 | samples_shared_keys.append({k: s[k] for k in s.keys() if k in shared_keys}) 68 | 69 | return self.datasets[0].collater(samples_shared_keys) 70 | -------------------------------------------------------------------------------- /minigpt4/datasets/builders/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | from minigpt4.datasets.builders.base_dataset_builder import load_dataset_config 9 | # from minigpt4.datasets.builders.image_text_pair_builder import ( 10 | # CCSBUBuilder, 11 | # LaionBuilder, 12 | # CCSBUAlignBuilder 13 | # ) 14 | from minigpt4.datasets.builders.rec_pair_builder import MoiveOODBuilder, MoiveOODBuilder_sasrec,AmazonOODBuilder, AmazonOODBuilder_sasrec 15 | 16 | from minigpt4.common.registry import registry 17 | 18 | __all__ = [ 19 | # "CCSBUBuilder", 20 | # "LaionBuilder", 21 | # "CCSBUAlignBuilder", 22 | # "MovielensBuilder", 23 | # "AmazonBuilder", 24 | 'MoiveOODBuilder', 25 | "MoiveOODBuilder_sasrec", 26 | "AmazonOODBuilder", 27 | "AmazonOODBuilder_sasrec" 28 | ] 29 | 30 | 31 | def load_dataset(name, cfg_path=None, vis_path=None, data_type=None): 32 | """ 33 | Example 34 | 35 | >>> dataset = load_dataset("coco_caption", cfg=None) 36 | >>> splits = dataset.keys() 37 | >>> print([len(dataset[split]) for split in splits]) 38 | 39 | """ 40 | if cfg_path is None: 41 | cfg = None 42 | else: 43 | cfg = load_dataset_config(cfg_path) 44 | 45 | try: 46 | builder = registry.get_builder_class(name)(cfg) 47 | except TypeError: 48 | print( 49 | f"Dataset {name} not found. 
Available datasets:\n" 50 | + ", ".join([str(k) for k in dataset_zoo.get_names()]) 51 | ) 52 | exit(1) 53 | 54 | if vis_path is not None: 55 | if data_type is None: 56 | # use default data type in the config 57 | data_type = builder.config.data_type 58 | 59 | assert ( 60 | data_type in builder.config.build_info 61 | ), f"Invalid data_type {data_type} for {name}." 62 | 63 | builder.config.build_info.get(data_type).storage = vis_path 64 | 65 | dataset = builder.build_datasets() 66 | return dataset 67 | 68 | 69 | class DatasetZoo: 70 | def __init__(self) -> None: 71 | self.dataset_zoo = { 72 | k: list(v.DATASET_CONFIG_DICT.keys()) 73 | for k, v in sorted(registry.mapping["builder_name_mapping"].items()) 74 | } 75 | 76 | def get_names(self): 77 | return list(self.dataset_zoo.keys()) 78 | 79 | 80 | dataset_zoo = DatasetZoo() 81 | -------------------------------------------------------------------------------- /view_token.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | # import os 4 | # os.environ['CURL_CA_BUNDLE'] = '' 5 | # os.environ["CUDA_VISIBLE_DEVICES"]="7" 6 | import random 7 | 8 | import numpy as np 9 | import pandas as pd 10 | import torch 11 | import torch.backends.cudnn as cudnn 12 | 13 | # import minigpt4.tasks as tasks 14 | # from minigpt4.common.config import Config 15 | # from minigpt4.common.dist_utils import get_rank, init_distributed_mode 16 | # from minigpt4.common.logger import setup_logger 17 | # from minigpt4.common.optims import ( 18 | # LinearWarmupCosineLRScheduler, 19 | # LinearWarmupStepLRScheduler, 20 | # ) 21 | # from minigpt4.common.registry import registry 22 | # from minigpt4.common.utils import now 23 | 24 | # # imports modules for registration 25 | # from minigpt4.datasets.builders import * 26 | # from minigpt4.models import * 27 | # from minigpt4.processors import * 28 | # from minigpt4.runners import * 29 | # from minigpt4.tasks import * 30 | # from torch.distributed.elastic.multiprocessing.errors import * 31 | 32 | 33 | import logging 34 | import random 35 | 36 | import torch 37 | from torch.cuda.amp import autocast as autocast 38 | import torch.nn as nn 39 | import os 40 | 41 | # from minigpt4.common.registry import registry 42 | # from minigpt4.models.rec_model import Rec2Base, disabled_train 43 | from minigpt4.models.modeling_llama import LlamaForCausalLM 44 | from transformers import LlamaTokenizer, GenerationConfig 45 | import re 46 | import numpy as np 47 | # from peft import LoraConfig, get_peft_model, get_peft_model_state_dict, prepare_model_for_int8_training, set_peft_model_state_dict 48 | 49 | 50 | llama_model = "/data/LLM/PretrainedModels/vicuna/working-v0/" 51 | 52 | llama_tokenizer = LlamaTokenizer.from_pretrained(llama_model, use_fast=False) 53 | llama_tokenizer.pad_token = llama_tokenizer.eos_token 54 | # llama_model = LlamaForCausalLM.from_pretrained( 55 | # llama_model, 56 | # torch_dtype=torch.float16, 57 | # load_in_8bit=True, 58 | # device_map={'': int(os.environ.get("LOCAL_RANK") or 0)} 59 | # ) 60 | 61 | # m = np.random.randn(32).astype(int) 62 | # m[m>0] = 1 63 | # m[m<=0] = 0 64 | # m = list(m) 65 | # m = [str(x) for x in m] 66 | # m = ''.join(m) 67 | m = '192.168.122.234' 68 | prompt_list = [m] 69 | llama_tokenizer.padding_side = "left" 70 | prompts_tokens = llama_tokenizer( 71 | prompt_list, 72 | return_tensors="pt", 73 | padding="longest", 74 | truncation=True, 75 | max_length=1024, 76 | add_special_tokens=False 77 | ) 78 | 79 | unk_token_id = 
llama_tokenizer.unk_token_id
80 | # prompt_embeds[replaced_idx[:,0],replaced_idx[:,1]] = samples['merged_embs']
81 |
--------------------------------------------------------------------------------
/train_configs/hash_CF_ml.yaml:
--------------------------------------------------------------------------------
1 | model:
2 |   arch: mini_gpt4rec_v2 # by default
3 |   model_type: pretrain_vicuna
4 |   freeze_rec: True #
5 |   freeze_proj: True #
6 |   freeze_lora: False #
7 |   max_txt_len: 1024 # by default
8 |   proj_token_num: 1 # default: 1; the number of text token embeddings that a single ID embedding is converted into
9 |   proj_drop: 0 # by default
10 |   proj_mid_times: 10 # proj_mid_times * rec embedding size = the middle layer size of the mapping module
11 |   end_sym: "###"
12 |   prompt_path: "prompts/hash_text.txt"
13 |   prompt_template: '{}'
14 |   llama_model: "/dataLLM/PretrainedModels/vicuna/working-v0/" # vicuna path
15 |   user_num: -100
16 |   item_num: -100
17 |   ans_type: 'v2' # by default
18 |   rec_model: "hash" # [MF, lightgcn, ...]; see the "Rec2Base" class in minigpt4/models/rec_model.py
19 |   lora_config:
20 |     use_lora: True
21 |     r: 8
22 |     alpha: 16
23 |     target_modules: ["q_proj", "v_proj"] # default: ["q_proj", "v_proj"]; others? ['lm_head'], ["q_proj", "v_proj",'k_proj','o_proj']
24 |     dropout: 0.05
25 |   rec_config: # recommender model config
26 |     user_num: -100
27 |     item_num: -100
28 |     embedding_size: 32 # embedding size
29 |     code_mode: 'binary' #'ipv4' #'binary'
30 |     use_hash: False
31 |     pretrained_path: /data1cf4reclog-cc/hash/0130hashGNN-ml-32lr-0.01wd0.001.pth # pretrained rec model
32 |
33 |
34 |   ckpt: /data1cf4reclog-cc/20240131205/checkpoint_best.pth # tallrec
35 |
36 |
37 |
38 |
39 | datasets:
40 |   amazon_ood: # the ml-1m data is served through the amazon_ood builder here (cf. the try/except in train_binllm.py)
41 |     path: "/data/datasets/ml-1m/" # data path
42 |     data_type: default
43 |     build_info:
44 |       storage: "/data/datasets/ml-1m/" # data path
45 |
46 | run:
47 |   task: rec_pretrain
48 |   lr_sched: "linear_warmup_cosine_lr"
49 |   init_lr: 1e-4
50 |   min_lr: 8e-5
51 |   warmup_lr: 1e-5
52 |   mode: 'v3' # hash-text encoding
53 |
54 |
55 |   weight_decay: 1e-3 # by default
56 |   max_epoch: 200
57 |   iters_per_epoch: 50 #100
58 |   batch_size_train: 4 #16 # 8
59 |   batch_size_eval: 64 # 32
60 |   num_workers: 4
61 |   warmup_steps: 200
62 |
63 |   seed: 42
64 |   output_dir: /data/cf4reclog/ # log and model saving path
65 |
66 |   amp: True
67 |   resume_ckpt_path: null #/home/sistLLM/cf4recLog/20240107182/checkpoint_best.pth #null
68 |
69 |   evaluate: False #True # False: training, True: evaluation only
70 |   train_splits: ["train"]
71 |   valid_splits: ["valid"] # validation set
72 |   test_splits: ["test","valid"] # used when evaluate=True, reporting both the testing and validation results
73 |
74 |   device: "cuda"
75 |   world_size: 1
76 |   dist_url: "env://"
77 |   distributed: True
78 |
--------------------------------------------------------------------------------
/used_metrics.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import multiprocessing
3 |
4 | def precision_at_k(r, k):
5 |     """Score is precision @ k
6 |     Relevance is binary (nonzero is relevant).
7 |     Returns:
8 |         Precision @ k
9 |     Raises:
10 |         ValueError: len(r) must be >= k
11 |     """
12 |     assert k >= 1
13 |     try:
14 |         r = r[:k]
15 |     except TypeError:
16 |         print(r)
17 |         raise ValueError('relevance scores r must be sliceable: {}'.format(r))
18 |     return np.mean(r)
19 |
20 |
21 | def dcg_at_k(r, k, method=1):
22 |     """Score is discounted cumulative gain (dcg)
23 |     Relevance is positive real values. Can use binary
24 |     as the previous methods.
25 | Returns: 26 | Discounted cumulative gain 27 | """ 28 | r = np.asfarray(r)[:k] 29 | if r.size: 30 | if method == 0: 31 | return r[0] + np.sum(r[1:] / np.log2(np.arange(2, r.size + 1))) 32 | elif method == 1: 33 | return np.sum(r / np.log2(np.arange(2, r.size + 2))) 34 | else: 35 | raise ValueError('method must be 0 or 1.') 36 | return 0. 37 | 38 | 39 | def ndcg_at_k(r, k, maxlen, method=1): 40 | """Score is normalized discounted cumulative gain (ndcg) 41 | Relevance is positive real values. Can use binary 42 | as the previous methods. 43 | Returns: 44 | Normalized discounted cumulative gain 45 | """ 46 | tp = 1. / np.log2(np.arange(2, k + 2)) 47 | dcg_max = (tp[:min(maxlen, k)]).sum() 48 | if not dcg_max: 49 | return 0. 50 | r_k = r[:k] 51 | dcg_at_k_ = (r_k * tp).sum() 52 | return dcg_at_k_ / dcg_max 53 | 54 | 55 | def recall_at_k(r, k, all_pos_num): 56 | r = r[:k] 57 | return np.sum(r) / all_pos_num 58 | 59 | 60 | def hit_at_k(r, k): 61 | r = r[:k] 62 | return min(1.,np.sum(r)) 63 | 64 | 65 | def get_r(user_pos_test, r): 66 | r_new = np.isin(r,user_pos_test).astype(float) 67 | return r_new 68 | 69 | def get_performance(user_pos_test, r, Ks): 70 | precision, recall, ndcg, hit_ratio = [], [], [], [] 71 | r = get_r(user_pos_test,r) 72 | for K in Ks: 73 | precision.append(precision_at_k(r, K))#P = TP/ (TP+FP) 74 | recall.append(recall_at_k(r, K, len(user_pos_test)))#R = TP/ (TP+FN) 75 | ndcg.append(ndcg_at_k(r, K, len(user_pos_test))) 76 | hit_ratio.append(hit_at_k(r, K))#HR = SIGMA(TP) / SIGMA(test_set) 77 | # print(hit_ratio) 78 | 79 | return {'recall': np.array(recall), 'precision': np.array(precision), 80 | 'ndcg': np.array(ndcg), 'hit_ratio': np.array(hit_ratio)} 81 | 82 | # def test_one_user(u): 83 | # # user u's ratings for user u 84 | # try: 85 | # user_pos_test = test_user_list[u] 86 | # except: 87 | # user_pos_test = [] 88 | # r = rec_user_list[u] 89 | # #print(len(r)) 90 | # return get_performance(user_pos_test, r, Ks) 91 | 92 | # cores=15 93 | # def test(): 94 | # pool = multiprocessing.Pool(cores) 95 | # test_user = list(rec_user_list.keys()) 96 | # print('test_user number:',len(test_user)) 97 | # bat_result = pool.map(test_one_user, test_user) 98 | # return bat_result -------------------------------------------------------------------------------- /minigpt4/common/optims.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 
4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import math 9 | 10 | from minigpt4.common.registry import registry 11 | 12 | 13 | @registry.register_lr_scheduler("linear_warmup_step_lr") 14 | class LinearWarmupStepLRScheduler: 15 | def __init__( 16 | self, 17 | optimizer, 18 | max_epoch, 19 | min_lr, 20 | init_lr, 21 | decay_rate=1, 22 | warmup_start_lr=-1, 23 | warmup_steps=0, 24 | iters_per_epoch=None, 25 | **kwargs 26 | ): 27 | self.optimizer = optimizer 28 | 29 | self.max_epoch = max_epoch 30 | self.min_lr = min_lr 31 | self.iters_per_epoch = iters_per_epoch 32 | 33 | self.decay_rate = decay_rate 34 | 35 | self.init_lr = init_lr 36 | self.warmup_steps = warmup_steps 37 | self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr 38 | 39 | def step(self, cur_epoch, cur_step): 40 | if cur_epoch == 0: 41 | warmup_lr_schedule( 42 | step=cur_step, 43 | optimizer=self.optimizer, 44 | max_step=self.warmup_steps, 45 | init_lr=self.warmup_start_lr, 46 | max_lr=self.init_lr, 47 | ) 48 | else: 49 | step_lr_schedule( 50 | epoch=cur_epoch, 51 | optimizer=self.optimizer, 52 | init_lr=self.init_lr, 53 | min_lr=self.min_lr, 54 | decay_rate=self.decay_rate, 55 | ) 56 | 57 | 58 | @registry.register_lr_scheduler("linear_warmup_cosine_lr") 59 | class LinearWarmupCosineLRScheduler: 60 | def __init__( 61 | self, 62 | optimizer, 63 | max_epoch, 64 | iters_per_epoch, 65 | min_lr, 66 | init_lr, 67 | warmup_steps=0, 68 | warmup_start_lr=-1, 69 | **kwargs 70 | ): 71 | self.optimizer = optimizer 72 | 73 | self.max_epoch = max_epoch 74 | self.iters_per_epoch = iters_per_epoch 75 | self.min_lr = min_lr 76 | 77 | self.init_lr = init_lr 78 | self.warmup_steps = warmup_steps 79 | self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr 80 | 81 | def step(self, cur_epoch, cur_step): 82 | total_cur_step = cur_epoch * self.iters_per_epoch + cur_step 83 | if total_cur_step < self.warmup_steps: 84 | warmup_lr_schedule( 85 | step=cur_step, 86 | optimizer=self.optimizer, 87 | max_step=self.warmup_steps, 88 | init_lr=self.warmup_start_lr, 89 | max_lr=self.init_lr, 90 | ) 91 | else: 92 | cosine_lr_schedule( 93 | epoch=total_cur_step, 94 | optimizer=self.optimizer, 95 | max_epoch=self.max_epoch * self.iters_per_epoch, 96 | init_lr=self.init_lr, 97 | min_lr=self.min_lr, 98 | ) 99 | 100 | 101 | def cosine_lr_schedule(optimizer, epoch, max_epoch, init_lr, min_lr): 102 | """Decay the learning rate""" 103 | lr = (init_lr - min_lr) * 0.5 * ( 104 | 1.0 + math.cos(math.pi * epoch / max_epoch) 105 | ) + min_lr 106 | for param_group in optimizer.param_groups: 107 | param_group["lr"] = lr 108 | 109 | 110 | def warmup_lr_schedule(optimizer, step, max_step, init_lr, max_lr): 111 | """Warmup the learning rate""" 112 | lr = min(max_lr, init_lr + (max_lr - init_lr) * step / max(max_step, 1)) 113 | for param_group in optimizer.param_groups: 114 | param_group["lr"] = lr 115 | 116 | 117 | def step_lr_schedule(optimizer, epoch, init_lr, min_lr, decay_rate): 118 | """Decay the learning rate""" 119 | lr = max(min_lr, init_lr * (decay_rate**epoch)) 120 | for param_group in optimizer.param_groups: 121 | param_group["lr"] = lr 122 | -------------------------------------------------------------------------------- /minigpt4/common/dist_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 
3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import datetime 9 | import functools 10 | import os 11 | 12 | import torch 13 | import torch.distributed as dist 14 | import timm.models.hub as timm_hub 15 | 16 | 17 | def setup_for_distributed(is_master): 18 | """ 19 | This function disables printing when not in master process 20 | """ 21 | import builtins as __builtin__ 22 | 23 | builtin_print = __builtin__.print 24 | 25 | def print(*args, **kwargs): 26 | force = kwargs.pop("force", False) 27 | if is_master or force: 28 | builtin_print(*args, **kwargs) 29 | 30 | __builtin__.print = print 31 | 32 | 33 | def is_dist_avail_and_initialized(): 34 | if not dist.is_available(): 35 | return False 36 | if not dist.is_initialized(): 37 | return False 38 | return True 39 | 40 | 41 | def get_world_size(): 42 | if not is_dist_avail_and_initialized(): 43 | return 1 44 | return dist.get_world_size() 45 | 46 | 47 | def get_rank(): 48 | if not is_dist_avail_and_initialized(): 49 | return 0 50 | return dist.get_rank() 51 | 52 | 53 | def is_main_process(): 54 | return get_rank() == 0 55 | 56 | 57 | def init_distributed_mode(args): 58 | if "RANK" in os.environ and "WORLD_SIZE" in os.environ: 59 | args.rank = int(os.environ["RANK"]) 60 | args.world_size = int(os.environ["WORLD_SIZE"]) 61 | args.gpu = int(os.environ["LOCAL_RANK"]) 62 | elif "SLURM_PROCID" in os.environ: 63 | args.rank = int(os.environ["SLURM_PROCID"]) 64 | args.gpu = args.rank % torch.cuda.device_count() 65 | else: 66 | print("Not using distributed mode") 67 | args.distributed = False 68 | return 69 | 70 | args.distributed = True 71 | 72 | torch.cuda.set_device(args.gpu) 73 | args.dist_backend = "nccl" 74 | print( 75 | "| distributed init (rank {}, world {}): {}".format( 76 | args.rank, args.world_size, args.dist_url 77 | ), 78 | flush=True, 79 | ) 80 | torch.distributed.init_process_group( 81 | backend=args.dist_backend, 82 | init_method=args.dist_url, 83 | world_size=args.world_size, 84 | rank=args.rank, 85 | timeout=datetime.timedelta( 86 | days=365 87 | ), # allow auto-downloading and de-compressing 88 | ) 89 | torch.distributed.barrier() 90 | setup_for_distributed(args.rank == 0) 91 | 92 | 93 | def get_dist_info(): 94 | if torch.__version__ < "1.0": 95 | initialized = dist._initialized 96 | else: 97 | initialized = dist.is_initialized() 98 | if initialized: 99 | rank = dist.get_rank() 100 | world_size = dist.get_world_size() 101 | else: # non-distributed training 102 | rank = 0 103 | world_size = 1 104 | return rank, world_size 105 | 106 | 107 | def main_process(func): 108 | @functools.wraps(func) 109 | def wrapper(*args, **kwargs): 110 | rank, _ = get_dist_info() 111 | if rank == 0: 112 | return func(*args, **kwargs) 113 | 114 | return wrapper 115 | 116 | 117 | def download_cached_file(url, check_hash=True, progress=False): 118 | """ 119 | Download a file from a URL and cache it locally. If the file already exists, it is not downloaded again. 120 | If distributed, only the main process downloads the file, and the other processes wait for the file to be downloaded. 
121 | """ 122 | 123 | def get_cached_file_path(): 124 | # a hack to sync the file path across processes 125 | parts = torch.hub.urlparse(url) 126 | filename = os.path.basename(parts.path) 127 | cached_file = os.path.join(timm_hub.get_cache_dir(), filename) 128 | 129 | return cached_file 130 | 131 | if is_main_process(): 132 | timm_hub.download_cached_file(url, check_hash, progress) 133 | 134 | if is_dist_avail_and_initialized(): 135 | dist.barrier() 136 | 137 | return get_cached_file_path() 138 | -------------------------------------------------------------------------------- /train_binllm.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | # import os 4 | # os.environ['CURL_CA_BUNDLE'] = '' 5 | # os.environ["CUDA_VISIBLE_DEVICES"]="4" 6 | import random 7 | 8 | import numpy as np 9 | import pandas as pd 10 | import torch 11 | import torch.backends.cudnn as cudnn 12 | 13 | import minigpt4.tasks as tasks 14 | from minigpt4.common.config import Config 15 | from minigpt4.common.dist_utils import get_rank, init_distributed_mode 16 | from minigpt4.common.logger import setup_logger 17 | from minigpt4.common.optims import ( 18 | LinearWarmupCosineLRScheduler, 19 | LinearWarmupStepLRScheduler, 20 | ) 21 | from minigpt4.common.registry import registry 22 | from minigpt4.common.utils import now 23 | 24 | # imports modules for registration 25 | from minigpt4.datasets.builders import * 26 | from minigpt4.models import * 27 | from minigpt4.processors import * 28 | from minigpt4.runners import * 29 | from minigpt4.tasks import * 30 | from torch.distributed.elastic.multiprocessing.errors import * 31 | 32 | 33 | 34 | 35 | def parse_args(): 36 | parser = argparse.ArgumentParser(description="Training") 37 | 38 | # parser.add_argument("--cfg-path", required=True, help="path to configuration file.") 39 | parser.add_argument("--cfg-path", default='train_configs/minigpt4rec_pretrain_ood_cc.yaml', help="path to configuration file.") 40 | parser.add_argument( 41 | "--options", 42 | nargs="+", 43 | help="override some settings in the used config, the key-value pair " 44 | "in xxx=yyy format will be merged into config file (deprecate), " 45 | "change to --cfg-options instead.", 46 | ) 47 | 48 | args = parser.parse_args() 49 | # if 'LOCAL_RANK' not in os.environ: 50 | # os.environ['LOCAL_RANK'] = str(args.local_rank) 51 | 52 | return args 53 | 54 | 55 | def setup_seeds(config): 56 | seed = config.run_cfg.seed + get_rank() 57 | 58 | random.seed(seed) 59 | np.random.seed(seed) 60 | torch.manual_seed(seed) 61 | 62 | cudnn.benchmark = False 63 | cudnn.deterministic = True 64 | 65 | 66 | def get_runner_class(cfg): 67 | """ 68 | Get runner class from config. Default to epoch-based runner. 69 | """ 70 | runner_cls = registry.get_runner_class(cfg.run_cfg.get("runner", "rec_runner_base")) 71 | 72 | return runner_cls 73 | 74 | @record 75 | def main(): 76 | # allow auto-dl completes on main process without timeout when using NCCL backend. 77 | # os.environ["NCCL_BLOCKING_WAIT"] = "1" 78 | 79 | # set before init_distributed_mode() to ensure the same job_id shared across all ranks. 80 | job_id = now() 81 | 82 | cfg = Config(parse_args()) 83 | 84 | init_distributed_mode(cfg.run_cfg) 85 | 86 | setup_seeds(cfg) 87 | 88 | # set after init_distributed_mode() to only log on master. 
89 | setup_logger() 90 | 91 | # cfg.pretty_print() 92 | 93 | task = tasks.setup_task(cfg) 94 | datasets = task.build_datasets(cfg) 95 | # cfg.model_cfg.get("user_num", "default") 96 | data_name = list(datasets.keys())[0] 97 | # data_dir = "/home/LLM/datasets/ml-1m/" 98 | try: # movie 99 | data_dir = cfg.datasets_cfg.movie_ood.path 100 | except: # amazon 101 | data_dir = cfg.datasets_cfg.amazon_ood.path 102 | print("data dir:", data_dir) 103 | # data_dir = "/data/datasets/ml-1m/" 104 | train_ = pd.read_pickle(data_dir+"train_ood2.pkl") 105 | valid_ = pd.read_pickle(data_dir+"valid_ood2.pkl") 106 | test_ = pd.read_pickle(data_dir+"test_ood2.pkl") 107 | user_num = max(train_.uid.max(),valid_.uid.max(),test_.uid.max())+1 108 | item_num = max(train_.iid.max(),valid_.iid.max(),test_.iid.max())+1 109 | 110 | cfg.model_cfg.rec_config.user_num = int(user_num) #int(datasets[data_name]['train'].user_num) #cfg.model_cfg.get("user_num",) 111 | cfg.model_cfg.rec_config.item_num = int(item_num) #int(datasets[data_name]['train'].item_num) #cfg.model_cfg.get("item_num", datasets[data_name]['train'].item_num) 112 | cfg.pretty_print() 113 | 114 | model = task.build_model(cfg) 115 | runner = get_runner_class(cfg)( 116 | cfg=cfg, job_id=job_id, task=task, model=model, datasets=datasets 117 | ) 118 | runner.train() 119 | 120 | 121 | if __name__ == "__main__": 122 | main() 123 | -------------------------------------------------------------------------------- /minigpt4/processors/blip_processors.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import re 9 | 10 | from minigpt4.common.registry import registry 11 | from minigpt4.processors.base_processor import BaseProcessor 12 | from minigpt4.processors.randaugment import RandomAugment 13 | from omegaconf import OmegaConf 14 | from torchvision import transforms 15 | from torchvision.transforms.functional import InterpolationMode 16 | 17 | 18 | class BlipImageBaseProcessor(BaseProcessor): 19 | def __init__(self, mean=None, std=None): 20 | if mean is None: 21 | mean = (0.48145466, 0.4578275, 0.40821073) 22 | if std is None: 23 | std = (0.26862954, 0.26130258, 0.27577711) 24 | 25 | self.normalize = transforms.Normalize(mean, std) 26 | 27 | 28 | @registry.register_processor("blip_caption") 29 | class BlipCaptionProcessor(BaseProcessor): 30 | def __init__(self, prompt="", max_words=50): 31 | self.prompt = prompt 32 | self.max_words = max_words 33 | 34 | def __call__(self, caption): 35 | caption = self.prompt + self.pre_caption(caption) 36 | 37 | return caption 38 | 39 | @classmethod 40 | def from_config(cls, cfg=None): 41 | if cfg is None: 42 | cfg = OmegaConf.create() 43 | 44 | prompt = cfg.get("prompt", "") 45 | max_words = cfg.get("max_words", 50) 46 | 47 | return cls(prompt=prompt, max_words=max_words) 48 | 49 | def pre_caption(self, caption): 50 | caption = re.sub( 51 | r"([.!\"()*#:;~])", 52 | " ", 53 | caption.lower(), 54 | ) 55 | caption = re.sub( 56 | r"\s{2,}", 57 | " ", 58 | caption, 59 | ) 60 | caption = caption.rstrip("\n") 61 | caption = caption.strip(" ") 62 | 63 | # truncate caption 64 | caption_words = caption.split(" ") 65 | if len(caption_words) > self.max_words: 66 | caption = " ".join(caption_words[: self.max_words]) 67 | 68 | return caption 69 | 70 | 71 | 
@registry.register_processor("blip2_image_train") 72 | class Blip2ImageTrainProcessor(BlipImageBaseProcessor): 73 | def __init__(self, image_size=224, mean=None, std=None, min_scale=0.5, max_scale=1.0): 74 | super().__init__(mean=mean, std=std) 75 | 76 | self.transform = transforms.Compose( 77 | [ 78 | transforms.RandomResizedCrop( 79 | image_size, 80 | scale=(min_scale, max_scale), 81 | interpolation=InterpolationMode.BICUBIC, 82 | ), 83 | transforms.ToTensor(), 84 | self.normalize, 85 | ] 86 | ) 87 | 88 | def __call__(self, item): 89 | return self.transform(item) 90 | 91 | @classmethod 92 | def from_config(cls, cfg=None): 93 | if cfg is None: 94 | cfg = OmegaConf.create() 95 | 96 | image_size = cfg.get("image_size", 224) 97 | 98 | mean = cfg.get("mean", None) 99 | std = cfg.get("std", None) 100 | 101 | min_scale = cfg.get("min_scale", 0.5) 102 | max_scale = cfg.get("max_scale", 1.0) 103 | 104 | return cls( 105 | image_size=image_size, 106 | mean=mean, 107 | std=std, 108 | min_scale=min_scale, 109 | max_scale=max_scale, 110 | ) 111 | 112 | 113 | @registry.register_processor("blip2_image_eval") 114 | class Blip2ImageEvalProcessor(BlipImageBaseProcessor): 115 | def __init__(self, image_size=224, mean=None, std=None): 116 | super().__init__(mean=mean, std=std) 117 | 118 | self.transform = transforms.Compose( 119 | [ 120 | transforms.Resize( 121 | (image_size, image_size), interpolation=InterpolationMode.BICUBIC 122 | ), 123 | transforms.ToTensor(), 124 | self.normalize, 125 | ] 126 | ) 127 | 128 | def __call__(self, item): 129 | return self.transform(item) 130 | 131 | @classmethod 132 | def from_config(cls, cfg=None): 133 | if cfg is None: 134 | cfg = OmegaConf.create() 135 | 136 | image_size = cfg.get("image_size", 224) 137 | 138 | mean = cfg.get("mean", None) 139 | std = cfg.get("std", None) 140 | 141 | return cls(image_size=image_size, mean=mean, std=std) -------------------------------------------------------------------------------- /minigpt4/models/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import logging 9 | import torch 10 | from omegaconf import OmegaConf 11 | 12 | from minigpt4.common.registry import registry 13 | from minigpt4.models.base_model import BaseModel 14 | # from minigpt4.models.blip2 import Blip2Base 15 | # from minigpt4.models.mini_gpt4 import MiniGPT4 16 | from minigpt4.processors.base_processor import BaseProcessor 17 | # from minigpt4.models.minigpt4rec import MiniGPT4Rec 18 | from minigpt4.models.minigpt4rec_v2 import MiniGPT4Rec_v2 19 | from minigpt4.models.rec_model import Rec2Base 20 | # from minigpt4.models.minigpt4rec_lora import MiniGPT4Rec_Lora 21 | 22 | 23 | __all__ = [ 24 | "load_model", 25 | "BaseModel", 26 | # "Blip2Base", 27 | # "MiniGPT4", 28 | "Rec2Base", 29 | # "MiniGPT4Rec", 30 | "MiniGPT4Rec_v2", 31 | # "MiniGPT4Rec_Lora" 32 | ] 33 | 34 | 35 | def load_model(name, model_type, is_eval=False, device="cpu", checkpoint=None): 36 | """ 37 | Load supported models. 38 | 39 | To list all available models and types in registry: 40 | >>> from minigpt4.models import model_zoo 41 | >>> print(model_zoo) 42 | 43 | Args: 44 | name (str): name of the model. 45 | model_type (str): type of the model. 46 | is_eval (bool): whether the model is in eval mode. 
Default: False.
47 |         device (str): device to use. Default: "cpu".
48 |         checkpoint (str): path to the checkpoint. Default: None.
49 |             Note that the checkpoint is expected to have the same state_dict keys as the model.
50 |
51 |     Returns:
52 |         model (torch.nn.Module): model.
53 |     """
54 |
55 |     model = registry.get_model_class(name).from_pretrained(model_type=model_type)
56 |
57 |     if checkpoint is not None:
58 |         model.load_checkpoint(checkpoint)
59 |
60 |     if is_eval:
61 |         model.eval()
62 |
63 |     if device == "cpu":
64 |         model = model.float()
65 |
66 |     return model.to(device)
67 |
68 |
69 | def load_preprocess(config):
70 |     """
71 |     Load preprocessor configs and construct preprocessors.
72 |
73 |     If no preprocessor is specified, return BaseProcessor, which does not do any preprocessing.
74 |
75 |     Args:
76 |         config (dict): preprocessor configs.
77 |
78 |     Returns:
79 |         vis_processors (dict): preprocessors for visual inputs.
80 |         txt_processors (dict): preprocessors for text inputs.
81 |
82 |         Key is "train" or "eval" for processors used in training and evaluation respectively.
83 |     """
84 |
85 |     def _build_proc_from_cfg(cfg):
86 |         return (
87 |             registry.get_processor_class(cfg.name).from_config(cfg)
88 |             if cfg is not None
89 |             else BaseProcessor()
90 |         )
91 |
92 |     vis_processors = dict()
93 |     txt_processors = dict()
94 |
95 |     vis_proc_cfg = config.get("vis_processor")
96 |     txt_proc_cfg = config.get("text_processor")
97 |
98 |     if vis_proc_cfg is not None:
99 |         vis_train_cfg = vis_proc_cfg.get("train")
100 |         vis_eval_cfg = vis_proc_cfg.get("eval")
101 |     else:
102 |         vis_train_cfg = None
103 |         vis_eval_cfg = None
104 |
105 |     vis_processors["train"] = _build_proc_from_cfg(vis_train_cfg)
106 |     vis_processors["eval"] = _build_proc_from_cfg(vis_eval_cfg)
107 |
108 |     if txt_proc_cfg is not None:
109 |         txt_train_cfg = txt_proc_cfg.get("train")
110 |         txt_eval_cfg = txt_proc_cfg.get("eval")
111 |     else:
112 |         txt_train_cfg = None
113 |         txt_eval_cfg = None
114 |
115 |     txt_processors["train"] = _build_proc_from_cfg(txt_train_cfg)
116 |     txt_processors["eval"] = _build_proc_from_cfg(txt_eval_cfg)
117 |
118 |     return vis_processors, txt_processors
119 |
120 |
121 | def load_model_and_preprocess(name, model_type, is_eval=False, device="cpu"):
122 |     """
123 |     Load model and its related preprocessors.
124 |
125 |     List all available models and types in registry:
126 |     >>> from minigpt4.models import model_zoo
127 |     >>> print(model_zoo)
128 |
129 |     Args:
130 |         name (str): name of the model.
131 |         model_type (str): type of the model.
132 |         is_eval (bool): whether the model is in eval mode. Default: False.
133 |         device (str): device to use. Default: "cpu".
134 |
135 |     Returns:
136 |         model (torch.nn.Module): model.
137 |         vis_processors (dict): preprocessors for visual inputs.
138 |         txt_processors (dict): preprocessors for text inputs.
139 |     """
140 |     model_cls = registry.get_model_class(name)
141 |
142 |     # load model
143 |     model = model_cls.from_pretrained(model_type=model_type)
144 |
145 |     if is_eval:
146 |         model.eval()
147 |
148 |     # load preprocess
149 |     cfg = OmegaConf.load(model_cls.default_config_path(model_type))
150 |     if cfg is not None:
151 |         preprocess_cfg = cfg.preprocess
152 |
153 |         vis_processors, txt_processors = load_preprocess(preprocess_cfg)
154 |     else:
155 |         vis_processors, txt_processors = None, None
156 |         logging.info(
157 |             f"""No default preprocess for model {name} ({model_type}).
158 | This can happen if the model is not finetuned on downstream datasets, 159 | or it is not intended for direct use without finetuning. 160 | """ 161 | ) 162 | 163 | if device == "cpu" or device == torch.device("cpu"): 164 | model = model.float() 165 | 166 | return model.to(device), vis_processors, txt_processors 167 | 168 | 169 | class ModelZoo: 170 | """ 171 | A utility class to create string representation of available model architectures and types. 172 | 173 | >>> from minigpt4.models import model_zoo 174 | >>> # list all available models 175 | >>> print(model_zoo) 176 | >>> # show total number of models 177 | >>> print(len(model_zoo)) 178 | """ 179 | 180 | def __init__(self) -> None: 181 | self.model_zoo = { 182 | k: list(v.PRETRAINED_MODEL_CONFIG_DICT.keys()) 183 | for k, v in registry.mapping["model_name_mapping"].items() 184 | } 185 | 186 | def __str__(self) -> str: 187 | return ( 188 | "=" * 50 189 | + "\n" 190 | + f"{'Architectures':<30} {'Types'}\n" 191 | + "=" * 50 192 | + "\n" 193 | + "\n".join( 194 | [ 195 | f"{name:<30} {', '.join(types)}" 196 | for name, types in self.model_zoo.items() 197 | ] 198 | ) 199 | ) 200 | 201 | def __iter__(self): 202 | return iter(self.model_zoo.items()) 203 | 204 | def __len__(self): 205 | return sum([len(v) for v in self.model_zoo.values()]) 206 | 207 | 208 | model_zoo = ModelZoo() 209 | -------------------------------------------------------------------------------- /minigpt4/datasets/data_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import gzip 9 | import logging 10 | import os 11 | import random as rnd 12 | import tarfile 13 | import zipfile 14 | import random 15 | from typing import List 16 | from tqdm import tqdm 17 | 18 | import decord 19 | from decord import VideoReader 20 | import webdataset as wds 21 | import numpy as np 22 | import torch 23 | from torch.utils.data.dataset import IterableDataset 24 | 25 | from minigpt4.common.registry import registry 26 | from minigpt4.datasets.datasets.base_dataset import ConcatDataset 27 | 28 | 29 | decord.bridge.set_bridge("torch") 30 | MAX_INT = registry.get("MAX_INT") 31 | 32 | 33 | class ChainDataset(wds.DataPipeline): 34 | r"""Dataset for chaining multiple :class:`DataPipeline` s. 35 | 36 | This class is useful to assemble different existing dataset streams. The 37 | chaining operation is done on-the-fly, so concatenating large-scale 38 | datasets with this class will be efficient. 
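    Sampling across the chained pipelines is weighted: each dataset's
    ``sample_ratio`` attribute is used as its selection weight in
    ``__iter__`` (datasets without one default to a weight of 1).

    Example (an illustrative sketch; ``pipe_a`` and ``pipe_b`` stand for
    hypothetical, already-constructed wds.DataPipeline objects):
        >>> pipe_a.sample_ratio = 3
        >>> chained = ChainDataset([pipe_a, pipe_b])
        >>> batch = next(iter(chained))  # drawn from pipe_a ~75% of the time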
39 | 
 40 |     Args:
 41 |         datasets (iterable of IterableDataset): datasets to be chained together
 42 |     """
 43 |     def __init__(self, datasets: List[wds.DataPipeline]) -> None:
 44 |         super().__init__()
 45 |         self.datasets = datasets
 46 |         self.prob = []
 47 |         self.names = []
 48 |         for dataset in self.datasets:
 49 |             if hasattr(dataset, 'name'):
 50 |                 self.names.append(dataset.name)
 51 |             else:
 52 |                 self.names.append('Unknown')
 53 |             if hasattr(dataset, 'sample_ratio'):
 54 |                 self.prob.append(dataset.sample_ratio)
 55 |             else:
 56 |                 self.prob.append(1)
 57 |                 logging.info("One of the datapipelines does not define a sample ratio; defaulting to 1.")
 58 | 
 59 |     def __iter__(self):
 60 |         datastreams = [iter(dataset) for dataset in self.datasets]
 61 |         while True:
 62 |             select_datastream = random.choices(datastreams, weights=self.prob, k=1)[0]
 63 |             yield next(select_datastream)
 64 | 
 65 | 
 66 | def apply_to_sample(f, sample):
 67 |     if len(sample) == 0:
 68 |         return {}
 69 | 
 70 |     def _apply(x):
 71 |         if torch.is_tensor(x):
 72 |             return f(x)
 73 |         elif isinstance(x, dict):
 74 |             return {key: _apply(value) for key, value in x.items()}
 75 |         elif isinstance(x, list):
 76 |             return [_apply(item) for item in x]
 77 |         else:
 78 |             return x
 79 | 
 80 |     return _apply(sample)
 81 | 
 82 | 
 83 | def move_to_cuda(sample):
 84 |     def _move_to_cuda(tensor):
 85 |         return tensor.cuda()
 86 | 
 87 |     return apply_to_sample(_move_to_cuda, sample)
 88 | 
 89 | 
 90 | def prepare_sample(samples, cuda_enabled=True):
 91 |     if cuda_enabled:
 92 |         samples = move_to_cuda(samples)
 93 | 
 94 |     # TODO fp16 support
 95 | 
 96 |     return samples
 97 | 
 98 | 
 99 | def reorg_datasets_by_split(datasets):
100 |     """
101 |     Organizes datasets by split.
102 | 
103 |     Args:
104 |         datasets: dict of torch.utils.data.Dataset objects by name.
105 | 
106 |     Returns:
107 |         Dict of datasets by split {split_name: List[Datasets]}.
108 |     """
109 |     # if len(datasets) == 1:
110 |     #     return datasets[list(datasets.keys())[0]]
111 |     # else:
112 |     reorg_datasets = dict()
113 | 
114 |     # reorganize by split
115 |     for _, dataset in datasets.items():
116 |         for split_name, dataset_split in dataset.items():
117 |             if split_name not in reorg_datasets:
118 |                 reorg_datasets[split_name] = [dataset_split]
119 |             else:
120 |                 reorg_datasets[split_name].append(dataset_split)
121 | 
122 |     return reorg_datasets
123 | 
124 | 
125 | def concat_datasets(datasets):
126 |     """
127 |     Concatenates multiple datasets into a single dataset.
128 | 
129 |     It supports map-style datasets and DataPipeline from WebDataset. Currently, it does not support
130 |     generic IterableDataset because it requires creating separate samplers.
131 | 
132 |     Now only supports concatenating training datasets, and assumes validation and testing
133 |     have only a single dataset. This is because metrics should not be computed on the concatenated
134 |     datasets.
135 | 
136 |     Args:
137 |         datasets: dict of torch.utils.data.Dataset objects by split.
138 | 
139 |     Returns:
140 |         Dict of concatenated datasets by split: "train" is the concatenation of multiple datasets,
141 |         while "val" and "test" remain the same.
142 | 
143 |         If the input training datasets contain both map-style and DataPipeline datasets, returns
144 |         a tuple, where the first element is a concatenated map-style dataset and the second
145 |         element is a chained DataPipeline dataset.
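        Example of the resulting "train" entry (illustrative sketch; ``ds1``,
        ``ds2`` and ``pipe`` are hypothetical datasets):
            >>> datasets = {"train": [ds1, ds2, pipe], "val": [ds_val], "test": [ds_test]}
            >>> out = concat_datasets(datasets)
            >>> # out["train"] == (ConcatDataset([ds1, ds2]), pipe); "val"/"test" are unwrapped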
146 | 147 | """ 148 | # concatenate datasets in the same split 149 | for split_name in datasets: 150 | if split_name != "train": 151 | assert ( 152 | len(datasets[split_name]) == 1 153 | ), "Do not support multiple {} datasets.".format(split_name) 154 | datasets[split_name] = datasets[split_name][0] 155 | else: 156 | iterable_datasets, map_datasets = [], [] 157 | for dataset in datasets[split_name]: 158 | if isinstance(dataset, wds.DataPipeline): 159 | logging.info( 160 | "Dataset {} is IterableDataset, can't be concatenated.".format( 161 | dataset 162 | ) 163 | ) 164 | iterable_datasets.append(dataset) 165 | elif isinstance(dataset, IterableDataset): 166 | raise NotImplementedError( 167 | "Do not support concatenation of generic IterableDataset." 168 | ) 169 | else: 170 | map_datasets.append(dataset) 171 | 172 | # if len(iterable_datasets) > 0: 173 | # concatenate map-style datasets and iterable-style datasets separately 174 | if len(iterable_datasets) > 1: 175 | chained_datasets = ( 176 | ChainDataset(iterable_datasets) 177 | ) 178 | elif len(iterable_datasets) == 1: 179 | chained_datasets = iterable_datasets[0] 180 | else: 181 | chained_datasets = None 182 | 183 | concat_datasets = ( 184 | ConcatDataset(map_datasets) if len(map_datasets) > 0 else None 185 | ) 186 | 187 | train_datasets = concat_datasets, chained_datasets 188 | train_datasets = tuple([x for x in train_datasets if x is not None]) 189 | train_datasets = ( 190 | train_datasets[0] if len(train_datasets) == 1 else train_datasets 191 | ) 192 | 193 | datasets[split_name] = train_datasets 194 | 195 | return datasets 196 | 197 | -------------------------------------------------------------------------------- /minigpt4/datasets/datasets/dataloader_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import time 9 | import random 10 | import torch 11 | from minigpt4.datasets.data_utils import move_to_cuda 12 | from torch.utils.data import DataLoader 13 | 14 | 15 | # class MultiIterLoader: 16 | # """ 17 | # A simple wrapper for iterating over multiple iterators. 18 | 19 | # Args: 20 | # loaders (List[Loader]): List of Iterator loaders. 21 | # ratios (List[float]): List of ratios to sample from each loader. If None, all loaders are sampled uniformly. 22 | # """ 23 | 24 | # def __init__(self, loaders, ratios=None): 25 | # # assert all loaders has __next__ method 26 | # for loader in loaders: 27 | # assert hasattr( 28 | # loader, "__next__" 29 | # ), "Loader {} has no __next__ method.".format(loader) 30 | 31 | # if ratios is None: 32 | # ratios = [1.0] * len(loaders) 33 | # else: 34 | # assert len(ratios) == len(loaders) 35 | # ratios = [float(ratio) / sum(ratios) for ratio in ratios] 36 | 37 | # self.loaders = loaders 38 | # self.ratios = ratios 39 | 40 | # def __next__(self): 41 | # # random sample from each loader by ratio 42 | # loader_idx = random.choices(range(len(self.loaders)), self.ratios, k=1)[0] 43 | # return next(self.loaders[loader_idx]) 44 | 45 | # # def __len__(self): 46 | # # return len(self.loaders) 47 | 48 | # # def __iter__(self): 49 | # # # for loader in self.loaders: 50 | # # # yield loader 51 | # # return self 52 | 53 | class MultiIterLoader: 54 | """ 55 | A simple wrapper for iterating over multiple iterators. 
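    Example (an illustrative sketch; ``loader_a`` and ``loader_b`` are
    hypothetical iterators that implement ``__next__``):
        >>> multi = MultiIterLoader([loader_a, loader_b], ratios=[3, 1])
        >>> batch = next(multi)  # sampled from loader_a with probability 0.75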
56 | 57 | Args: 58 | loaders (List[Loader]): List of Iterator loaders. 59 | ratios (List[float]): List of ratios to sample from each loader. If None, all loaders are sampled uniformly. 60 | """ 61 | 62 | def __init__(self, loaders, ratios=None): 63 | # assert all loaders has __next__ method 64 | # self.nums = [] 65 | for loader in loaders: 66 | assert hasattr( 67 | loader, "__next__" 68 | ), "Loader {} has no __next__ method.".format(loader) 69 | #self.nums.extend(len(loader)) 70 | 71 | if ratios is None: 72 | ratios = [1.0] * len(loaders) 73 | else: 74 | assert len(ratios) == len(loaders) 75 | ratios = [float(ratio) / sum(ratios) for ratio in ratios] 76 | 77 | self.loaders = loaders 78 | self.ratios = ratios 79 | 80 | def __next__(self): 81 | # random sample from each loader by ratio 82 | loader_idx = random.choices(range(len(self.loaders)), self.ratios, k=1)[0] 83 | return next(self.loaders[loader_idx]) 84 | 85 | # def __len__(self): 86 | # return len(self.loaders) 87 | 88 | # def __iter__(self): 89 | # # for loader in self.loaders: 90 | # # yield loader 91 | # return self 92 | 93 | 94 | class PrefetchLoader(object): 95 | """ 96 | Modified from https://github.com/ChenRocks/UNITER. 97 | 98 | overlap compute and cuda data transfer 99 | (copied and then modified from nvidia apex) 100 | """ 101 | 102 | def __init__(self, loader): 103 | self.loader = loader 104 | self.stream = torch.cuda.Stream() 105 | 106 | def __iter__(self): 107 | loader_it = iter(self.loader) 108 | self.preload(loader_it) 109 | batch = self.next(loader_it) 110 | while batch is not None: 111 | is_tuple = isinstance(batch, tuple) 112 | if is_tuple: 113 | task, batch = batch 114 | 115 | if is_tuple: 116 | yield task, batch 117 | else: 118 | yield batch 119 | batch = self.next(loader_it) 120 | 121 | def __len__(self): 122 | return len(self.loader) 123 | 124 | def preload(self, it): 125 | try: 126 | self.batch = next(it) 127 | except StopIteration: 128 | self.batch = None 129 | return 130 | # if record_stream() doesn't work, another option is to make sure 131 | # device inputs are created on the main stream. 132 | # self.next_input_gpu = torch.empty_like(self.next_input, 133 | # device='cuda') 134 | # self.next_target_gpu = torch.empty_like(self.next_target, 135 | # device='cuda') 136 | # Need to make sure the memory allocated for next_* is not still in use 137 | # by the main stream at the time we start copying to next_*: 138 | # self.stream.wait_stream(torch.cuda.current_stream()) 139 | with torch.cuda.stream(self.stream): 140 | self.batch = move_to_cuda(self.batch) 141 | # more code for the alternative if record_stream() doesn't work: 142 | # copy_ will record the use of the pinned source tensor in this 143 | # side stream. 
144 | # self.next_input_gpu.copy_(self.next_input, non_blocking=True) 145 | # self.next_target_gpu.copy_(self.next_target, non_blocking=True) 146 | # self.next_input = self.next_input_gpu 147 | # self.next_target = self.next_target_gpu 148 | 149 | def next(self, it): 150 | torch.cuda.current_stream().wait_stream(self.stream) 151 | batch = self.batch 152 | if batch is not None: 153 | record_cuda_stream(batch) 154 | self.preload(it) 155 | return batch 156 | 157 | def __next__(self): 158 | pass 159 | 160 | def __getattr__(self, name): 161 | method = self.loader.__getattribute__(name) 162 | return method 163 | 164 | 165 | def record_cuda_stream(batch): 166 | if isinstance(batch, torch.Tensor): 167 | batch.record_stream(torch.cuda.current_stream()) 168 | elif isinstance(batch, list) or isinstance(batch, tuple): 169 | for t in batch: 170 | record_cuda_stream(t) 171 | elif isinstance(batch, dict): 172 | for t in batch.values(): 173 | record_cuda_stream(t) 174 | else: 175 | pass 176 | 177 | 178 | class IterLoader: 179 | """ 180 | A wrapper to convert DataLoader as an infinite iterator. 181 | 182 | Modified from: 183 | https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/iter_based_runner.py 184 | """ 185 | 186 | def __init__(self, dataloader: DataLoader, use_distributed: bool = False): 187 | self._dataloader = dataloader 188 | self.iter_loader = iter(self._dataloader) 189 | self._use_distributed = use_distributed 190 | self._epoch = 0 191 | 192 | @property 193 | def epoch(self) -> int: 194 | return self._epoch 195 | 196 | def __next__(self): 197 | try: 198 | data = next(self.iter_loader) 199 | except StopIteration: 200 | self._epoch += 1 201 | if hasattr(self._dataloader.sampler, "set_epoch") and self._use_distributed: 202 | self._dataloader.sampler.set_epoch(self._epoch) 203 | time.sleep(2) # Prevent possible deadlock during epoch transition 204 | self.iter_loader = iter(self._dataloader) 205 | data = next(self.iter_loader) 206 | 207 | return data 208 | 209 | def __iter__(self): 210 | return self 211 | 212 | def __len__(self): 213 | return len(self._dataloader) 214 | -------------------------------------------------------------------------------- /minigpt4/datasets/builders/rec_base_dataset_builder.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file is from 3 | Copyright (c) 2022, salesforce.com, inc. 4 | All rights reserved. 5 | SPDX-License-Identifier: BSD-3-Clause 6 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause 7 | """ 8 | 9 | import logging 10 | import os 11 | import shutil 12 | import warnings 13 | 14 | from omegaconf import OmegaConf 15 | import torch.distributed as dist 16 | from torchvision.datasets.utils import download_url 17 | 18 | import minigpt4.common.utils as utils 19 | from minigpt4.common.dist_utils import is_dist_avail_and_initialized, is_main_process 20 | from minigpt4.common.registry import registry 21 | from minigpt4.processors.base_processor import BaseProcessor 22 | 23 | 24 | 25 | class RecBaseDatasetBuilder: 26 | train_dataset_cls, eval_dataset_cls = None, None 27 | 28 | def __init__(self, cfg=None): 29 | super().__init__() 30 | 31 | if cfg is None: 32 | # help to create datasets from default config. 
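            # (cfg may be None, a YAML path string, or an already-parsed config
            # node; all three branches below leave an OmegaConf config in
            # self.config.)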
33 | self.config = load_dataset_config(self.default_config_path()) 34 | elif isinstance(cfg, str): 35 | self.config = load_dataset_config(cfg) 36 | else: 37 | # when called from task.build_dataset() 38 | self.config = cfg 39 | 40 | self.data_type = self.config.data_type 41 | 42 | # self.vis_processors = {"train": BaseProcessor(), "eval": BaseProcessor()} 43 | self.text_processors = {"train": BaseProcessor(), "eval": BaseProcessor()} 44 | 45 | def build_datasets(self): 46 | # download, split, etc... 47 | # only called on 1 GPU/TPU in distributed 48 | 49 | if is_main_process(): 50 | self._download_data() 51 | 52 | if is_dist_avail_and_initialized(): 53 | dist.barrier() 54 | 55 | # at this point, all the annotations and image/videos should be all downloaded to the specified locations. 56 | logging.info("Building datasets...") 57 | datasets = self.build() # dataset['train'/'val'/'test'] 58 | 59 | return datasets 60 | 61 | def build_processors(self): 62 | # vis_proc_cfg = self.config.get("vis_processor") 63 | txt_proc_cfg = self.config.get("text_processor") 64 | 65 | 66 | if txt_proc_cfg is not None: 67 | txt_train_cfg = txt_proc_cfg.get("train") 68 | txt_eval_cfg = txt_proc_cfg.get("eval") 69 | 70 | self.text_processors["train"] = self._build_proc_from_cfg(txt_train_cfg) 71 | self.text_processors["eval"] = self._build_proc_from_cfg(txt_eval_cfg) 72 | 73 | @staticmethod 74 | def _build_proc_from_cfg(cfg): 75 | return ( 76 | registry.get_processor_class(cfg.name).from_config(cfg) 77 | if cfg is not None 78 | else None 79 | ) 80 | 81 | @classmethod 82 | def default_config_path(cls, type="default"): 83 | return utils.get_abs_path(cls.DATASET_CONFIG_DICT[type]) 84 | 85 | def _download_data(self): 86 | pass 87 | # self._download_ann() 88 | # self._download_vis() 89 | 90 | # def _download_ann(self): 91 | # """ 92 | # Download annotation files if necessary. 93 | # All the vision-language datasets should have annotations of unified format. 94 | 95 | # storage_path can be: 96 | # (1) relative/absolute: will be prefixed with env.cache_root to make full path if relative. 97 | # (2) basename/dirname: will be suffixed with base name of URL if dirname is provided. 98 | 99 | # Local annotation paths should be relative. 100 | # """ 101 | # anns = self.config.build_info.annotations 102 | 103 | # splits = anns.keys() 104 | 105 | # cache_root = registry.get_path("cache_root") 106 | 107 | # for split in splits: 108 | # info = anns[split] 109 | 110 | # urls, storage_paths = info.get("url", None), info.storage 111 | 112 | # if isinstance(urls, str): 113 | # urls = [urls] 114 | # if isinstance(storage_paths, str): 115 | # storage_paths = [storage_paths] 116 | 117 | # assert len(urls) == len(storage_paths) 118 | 119 | # for url_or_filename, storage_path in zip(urls, storage_paths): 120 | # # if storage_path is relative, make it full by prefixing with cache_root. 121 | # if not os.path.isabs(storage_path): 122 | # storage_path = os.path.join(cache_root, storage_path) 123 | 124 | # dirname = os.path.dirname(storage_path) 125 | # if not os.path.exists(dirname): 126 | # os.makedirs(dirname) 127 | 128 | # if os.path.isfile(url_or_filename): 129 | # src, dst = url_or_filename, storage_path 130 | # if not os.path.exists(dst): 131 | # shutil.copyfile(src=src, dst=dst) 132 | # else: 133 | # logging.info("Using existing file {}.".format(dst)) 134 | # else: 135 | # if os.path.isdir(storage_path): 136 | # # if only dirname is provided, suffix with basename of URL. 
137 | # raise ValueError( 138 | # "Expecting storage_path to be a file path, got directory {}".format( 139 | # storage_path 140 | # ) 141 | # ) 142 | # else: 143 | # filename = os.path.basename(storage_path) 144 | 145 | # download_url(url=url_or_filename, root=dirname, filename=filename) 146 | 147 | # def _download_vis(self): 148 | 149 | # storage_path = self.config.build_info.get(self.data_type).storage 150 | # storage_path = utils.get_cache_path(storage_path) 151 | 152 | # if not os.path.exists(storage_path): 153 | # warnings.warn( 154 | # f""" 155 | # The specified path {storage_path} for visual inputs does not exist. 156 | # Please provide a correct path to the visual inputs or 157 | # refer to datasets/download_scripts/README.md for downloading instructions. 158 | # """ 159 | # ) 160 | 161 | def build(self): 162 | """ 163 | Create by split datasets inheriting torch.utils.data.Datasets. 164 | 165 | # build() can be dataset-specific. Overwrite to customize. 166 | """ 167 | self.build_processors() 168 | 169 | build_info = self.config.build_info 170 | 171 | ann_info = build_info.annotations 172 | vis_info = build_info.get(self.data_type) 173 | 174 | datasets = dict() 175 | for split in ann_info.keys(): 176 | if split not in ["train", "val", "test"]: 177 | continue 178 | 179 | is_train = split == "train" 180 | 181 | # processors 182 | # vis_processor = ( 183 | # self.vis_processors["train"] 184 | # if is_train 185 | # else self.vis_processors["eval"] 186 | # ) 187 | text_processor = ( 188 | self.text_processors["train"] 189 | if is_train 190 | else self.text_processors["eval"] 191 | ) 192 | 193 | # annotation path 194 | ann_paths = ann_info.get(split).storage 195 | if isinstance(ann_paths, str): 196 | ann_paths = [ann_paths] 197 | 198 | abs_ann_paths = [] 199 | for ann_path in ann_paths: 200 | if not os.path.isabs(ann_path): 201 | ann_path = utils.get_cache_path(ann_path) 202 | abs_ann_paths.append(ann_path) 203 | ann_paths = abs_ann_paths 204 | 205 | # create datasets 206 | dataset_cls = self.train_dataset_cls if is_train else self.eval_dataset_cls 207 | datasets[split] = dataset_cls( 208 | text_processor=text_processor, 209 | ann_paths=ann_paths 210 | ) 211 | 212 | return datasets 213 | 214 | 215 | def load_dataset_config(cfg_path): 216 | cfg = OmegaConf.load(cfg_path).datasets 217 | cfg = cfg[list(cfg.keys())[0]] 218 | 219 | return cfg 220 | -------------------------------------------------------------------------------- /minigpt4/models/base_model.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import logging 9 | import os 10 | 11 | import numpy as np 12 | import torch 13 | import torch.nn as nn 14 | from minigpt4.common.dist_utils import download_cached_file, is_dist_avail_and_initialized 15 | from minigpt4.common.utils import get_abs_path, is_url 16 | from omegaconf import OmegaConf 17 | 18 | 19 | class BaseModel(nn.Module): 20 | """Base class for models.""" 21 | 22 | def __init__(self): 23 | super().__init__() 24 | 25 | @property 26 | def device(self): 27 | return list(self.parameters())[0].device 28 | 29 | def load_checkpoint(self, url_or_filename): 30 | """ 31 | Load from a finetuned checkpoint. 32 | 33 | This should expect no mismatch in the model keys and the checkpoint keys. 
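        Returns:
            msg (NamedTuple): the missing/unexpected keys reported by
            ``load_state_dict(strict=False)``.

        Example (illustrative; the checkpoint path is hypothetical):
            >>> msg = model.load_checkpoint("cache/checkpoint_best.pth")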
34 |         """
 35 | 
 36 |         if is_url(url_or_filename):
 37 |             cached_file = download_cached_file(
 38 |                 url_or_filename, check_hash=False, progress=True
 39 |             )
 40 |             checkpoint = torch.load(cached_file, map_location="cpu")
 41 |         elif os.path.isfile(url_or_filename):
 42 |             checkpoint = torch.load(url_or_filename, map_location="cpu")
 43 |         else:
 44 |             raise RuntimeError("checkpoint url or path is invalid")
 45 | 
 46 |         if "model" in checkpoint.keys():
 47 |             state_dict = checkpoint["model"]
 48 |         else:
 49 |             state_dict = checkpoint
 50 | 
 51 |         msg = self.load_state_dict(state_dict, strict=False)
 52 | 
 53 |         logging.info("Missing keys {}".format(msg.missing_keys))
 54 |         logging.info("load checkpoint from %s" % url_or_filename)
 55 | 
 56 |         return msg
 57 | 
 58 |     @classmethod
 59 |     def from_pretrained(cls, model_type):
 60 |         """
 61 |         Build a pretrained model from default configuration file, specified by model_type.
 62 | 
 63 |         Args:
 64 |             - model_type (str): model type, specifying architecture and checkpoints.
 65 | 
 66 |         Returns:
 67 |             - model (nn.Module): pretrained or finetuned model, depending on the configuration.
 68 |         """
 69 |         model_cfg = OmegaConf.load(cls.default_config_path(model_type)).model
 70 |         model = cls.from_config(model_cfg)
 71 | 
 72 |         return model
 73 | 
 74 |     @classmethod
 75 |     def default_config_path(cls, model_type):
 76 |         assert (
 77 |             model_type in cls.PRETRAINED_MODEL_CONFIG_DICT
 78 |         ), "Unknown model type {}".format(model_type)
 79 |         return get_abs_path(cls.PRETRAINED_MODEL_CONFIG_DICT[model_type])
 80 | 
 81 |     def load_checkpoint_from_config(self, cfg, **kwargs):
 82 |         """
 83 |         Load checkpoint as specified in the config file.
 84 | 
 85 |         If load_finetuned is True, load the finetuned model; otherwise, load the pretrained model.
 86 |         When loading the pretrained model, each task-specific architecture may define its
 87 |         own load_from_pretrained() method.
 88 |         """
 89 |         load_finetuned = cfg.get("load_finetuned", True)
 90 |         if load_finetuned:
 91 |             finetune_path = cfg.get("finetuned", None)
 92 |             assert (
 93 |                 finetune_path is not None
 94 |             ), "Found load_finetuned is True, but finetune_path is None."
 95 |             self.load_checkpoint(url_or_filename=finetune_path)
 96 |         else:
 97 |             # load pre-trained weights
 98 |             pretrain_path = cfg.get("pretrained", None)
 99 |             assert pretrain_path is not None, "Found load_finetuned is False, but pretrain_path is None."
100 |             self.load_from_pretrained(url_or_filename=pretrain_path, **kwargs)
101 | 
102 |     def before_evaluation(self, **kwargs):
103 |         pass
104 | 
105 |     def show_n_params(self, return_str=True):
106 |         tot = 0
107 |         for p in self.parameters():
108 |             w = 1
109 |             for x in p.shape:
110 |                 w *= x
111 |             tot += w
112 |         if return_str:
113 |             if tot >= 1e6:
114 |                 return "{:.1f}M".format(tot / 1e6)
115 |             else:
116 |                 return "{:.1f}K".format(tot / 1e3)
117 |         else:
118 |             return tot
119 | 
120 | 
121 | class BaseEncoder(nn.Module):
122 |     """
123 |     Base class for primitive encoders, such as ViT, TimeSformer, etc.
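    Subclasses are expected to override ``forward_features``. A minimal sketch
    (a hypothetical toy encoder, for illustration only):

        class ToyEncoder(BaseEncoder):
            def forward_features(self, samples, **kwargs):
                # flatten pixel values into a per-sample feature vector
                return samples["image"].flatten(1)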
124 | """ 125 | 126 | def __init__(self): 127 | super().__init__() 128 | 129 | def forward_features(self, samples, **kwargs): 130 | raise NotImplementedError 131 | 132 | @property 133 | def device(self): 134 | return list(self.parameters())[0].device 135 | 136 | 137 | class SharedQueueMixin: 138 | @torch.no_grad() 139 | def _dequeue_and_enqueue(self, image_feat, text_feat, idxs=None): 140 | # gather keys before updating queue 141 | image_feats = concat_all_gather(image_feat) 142 | text_feats = concat_all_gather(text_feat) 143 | 144 | batch_size = image_feats.shape[0] 145 | 146 | ptr = int(self.queue_ptr) 147 | assert self.queue_size % batch_size == 0 # for simplicity 148 | 149 | # replace the keys at ptr (dequeue and enqueue) 150 | self.image_queue[:, ptr : ptr + batch_size] = image_feats.T 151 | self.text_queue[:, ptr : ptr + batch_size] = text_feats.T 152 | 153 | if idxs is not None: 154 | idxs = concat_all_gather(idxs) 155 | self.idx_queue[:, ptr : ptr + batch_size] = idxs.T 156 | 157 | ptr = (ptr + batch_size) % self.queue_size # move pointer 158 | self.queue_ptr[0] = ptr 159 | 160 | 161 | class MomentumDistilationMixin: 162 | @torch.no_grad() 163 | def copy_params(self): 164 | for model_pair in self.model_pairs: 165 | for param, param_m in zip( 166 | model_pair[0].parameters(), model_pair[1].parameters() 167 | ): 168 | param_m.data.copy_(param.data) # initialize 169 | param_m.requires_grad = False # not update by gradient 170 | 171 | @torch.no_grad() 172 | def _momentum_update(self): 173 | for model_pair in self.model_pairs: 174 | for param, param_m in zip( 175 | model_pair[0].parameters(), model_pair[1].parameters() 176 | ): 177 | param_m.data = param_m.data * self.momentum + param.data * ( 178 | 1.0 - self.momentum 179 | ) 180 | 181 | 182 | class GatherLayer(torch.autograd.Function): 183 | """ 184 | Gather tensors from all workers with support for backward propagation: 185 | This implementation does not cut the gradients as torch.distributed.all_gather does. 186 | """ 187 | 188 | @staticmethod 189 | def forward(ctx, x): 190 | output = [ 191 | torch.zeros_like(x) for _ in range(torch.distributed.get_world_size()) 192 | ] 193 | torch.distributed.all_gather(output, x) 194 | return tuple(output) 195 | 196 | @staticmethod 197 | def backward(ctx, *grads): 198 | all_gradients = torch.stack(grads) 199 | torch.distributed.all_reduce(all_gradients) 200 | return all_gradients[torch.distributed.get_rank()] 201 | 202 | 203 | def all_gather_with_grad(tensors): 204 | """ 205 | Performs all_gather operation on the provided tensors. 206 | Graph remains connected for backward grad computation. 207 | """ 208 | # Queue the gathered tensors 209 | world_size = torch.distributed.get_world_size() 210 | # There is no need for reduction in the single-proc case 211 | if world_size == 1: 212 | return tensors 213 | 214 | # tensor_all = GatherLayer.apply(tensors) 215 | tensor_all = GatherLayer.apply(tensors) 216 | 217 | return torch.cat(tensor_all, dim=0) 218 | 219 | 220 | @torch.no_grad() 221 | def concat_all_gather(tensor): 222 | """ 223 | Performs all_gather operation on the provided tensors. 224 | *** Warning ***: torch.distributed.all_gather has no gradient. 
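    For a [B, D] tensor on each of W workers, the result is a [W*B, D] tensor.
    Use ``all_gather_with_grad`` above instead when gradients must flow
    through the gather.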
225 | """ 226 | # if use distributed training 227 | if not is_dist_avail_and_initialized(): 228 | return tensor 229 | 230 | tensors_gather = [ 231 | torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size()) 232 | ] 233 | torch.distributed.all_gather(tensors_gather, tensor, async_op=False) 234 | 235 | output = torch.cat(tensors_gather, dim=0) 236 | return output 237 | 238 | 239 | def tile(x, dim, n_tile): 240 | init_dim = x.size(dim) 241 | repeat_idx = [1] * x.dim() 242 | repeat_idx[dim] = n_tile 243 | x = x.repeat(*(repeat_idx)) 244 | order_index = torch.LongTensor( 245 | np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)]) 246 | ) 247 | return torch.index_select(x, dim, order_index.to(x.device)) 248 | -------------------------------------------------------------------------------- /minigpt4/datasets/builders/base_dataset_builder.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file is from 3 | Copyright (c) 2022, salesforce.com, inc. 4 | All rights reserved. 5 | SPDX-License-Identifier: BSD-3-Clause 6 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause 7 | """ 8 | 9 | import logging 10 | import os 11 | import shutil 12 | import warnings 13 | 14 | from omegaconf import OmegaConf 15 | import torch.distributed as dist 16 | from torchvision.datasets.utils import download_url 17 | 18 | import minigpt4.common.utils as utils 19 | from minigpt4.common.dist_utils import is_dist_avail_and_initialized, is_main_process 20 | from minigpt4.common.registry import registry 21 | from minigpt4.processors.base_processor import BaseProcessor 22 | 23 | 24 | 25 | class BaseDatasetBuilder: 26 | train_dataset_cls, eval_dataset_cls = None, None 27 | 28 | def __init__(self, cfg=None): 29 | super().__init__() 30 | 31 | if cfg is None: 32 | # help to create datasets from default config. 33 | self.config = load_dataset_config(self.default_config_path()) 34 | elif isinstance(cfg, str): 35 | self.config = load_dataset_config(cfg) 36 | else: 37 | # when called from task.build_dataset() 38 | self.config = cfg 39 | 40 | self.data_type = self.config.data_type 41 | 42 | self.vis_processors = {"train": BaseProcessor(), "eval": BaseProcessor()} 43 | self.text_processors = {"train": BaseProcessor(), "eval": BaseProcessor()} 44 | 45 | def build_datasets(self): 46 | # download, split, etc... 47 | # only called on 1 GPU/TPU in distributed 48 | 49 | if is_main_process(): 50 | self._download_data() 51 | 52 | if is_dist_avail_and_initialized(): 53 | dist.barrier() 54 | 55 | # at this point, all the annotations and image/videos should be all downloaded to the specified locations. 
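        # Illustrative call sequence (a sketch; the builder name "movielens"
        # is hypothetical):
        #   builder = registry.get_builder_class("movielens")(dataset_cfg)
        #   datasets = builder.build_datasets()  # -> {"train": ..., "val": ..., "test": ...}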
56 | logging.info("Building datasets...") 57 | datasets = self.build() # dataset['train'/'val'/'test'] 58 | 59 | return datasets 60 | 61 | def build_processors(self): 62 | vis_proc_cfg = self.config.get("vis_processor") 63 | txt_proc_cfg = self.config.get("text_processor") 64 | 65 | if vis_proc_cfg is not None: 66 | vis_train_cfg = vis_proc_cfg.get("train") 67 | vis_eval_cfg = vis_proc_cfg.get("eval") 68 | 69 | self.vis_processors["train"] = self._build_proc_from_cfg(vis_train_cfg) 70 | self.vis_processors["eval"] = self._build_proc_from_cfg(vis_eval_cfg) 71 | 72 | if txt_proc_cfg is not None: 73 | txt_train_cfg = txt_proc_cfg.get("train") 74 | txt_eval_cfg = txt_proc_cfg.get("eval") 75 | 76 | self.text_processors["train"] = self._build_proc_from_cfg(txt_train_cfg) 77 | self.text_processors["eval"] = self._build_proc_from_cfg(txt_eval_cfg) 78 | 79 | @staticmethod 80 | def _build_proc_from_cfg(cfg): 81 | return ( 82 | registry.get_processor_class(cfg.name).from_config(cfg) 83 | if cfg is not None 84 | else None 85 | ) 86 | 87 | @classmethod 88 | def default_config_path(cls, type="default"): 89 | return utils.get_abs_path(cls.DATASET_CONFIG_DICT[type]) 90 | 91 | def _download_data(self): 92 | self._download_ann() 93 | self._download_vis() 94 | 95 | def _download_ann(self): 96 | """ 97 | Download annotation files if necessary. 98 | All the vision-language datasets should have annotations of unified format. 99 | 100 | storage_path can be: 101 | (1) relative/absolute: will be prefixed with env.cache_root to make full path if relative. 102 | (2) basename/dirname: will be suffixed with base name of URL if dirname is provided. 103 | 104 | Local annotation paths should be relative. 105 | """ 106 | anns = self.config.build_info.annotations 107 | 108 | splits = anns.keys() 109 | 110 | cache_root = registry.get_path("cache_root") 111 | 112 | for split in splits: 113 | info = anns[split] 114 | 115 | urls, storage_paths = info.get("url", None), info.storage 116 | 117 | if isinstance(urls, str): 118 | urls = [urls] 119 | if isinstance(storage_paths, str): 120 | storage_paths = [storage_paths] 121 | 122 | assert len(urls) == len(storage_paths) 123 | 124 | for url_or_filename, storage_path in zip(urls, storage_paths): 125 | # if storage_path is relative, make it full by prefixing with cache_root. 126 | if not os.path.isabs(storage_path): 127 | storage_path = os.path.join(cache_root, storage_path) 128 | 129 | dirname = os.path.dirname(storage_path) 130 | if not os.path.exists(dirname): 131 | os.makedirs(dirname) 132 | 133 | if os.path.isfile(url_or_filename): 134 | src, dst = url_or_filename, storage_path 135 | if not os.path.exists(dst): 136 | shutil.copyfile(src=src, dst=dst) 137 | else: 138 | logging.info("Using existing file {}.".format(dst)) 139 | else: 140 | if os.path.isdir(storage_path): 141 | # if only dirname is provided, suffix with basename of URL. 142 | raise ValueError( 143 | "Expecting storage_path to be a file path, got directory {}".format( 144 | storage_path 145 | ) 146 | ) 147 | else: 148 | filename = os.path.basename(storage_path) 149 | 150 | download_url(url=url_or_filename, root=dirname, filename=filename) 151 | 152 | def _download_vis(self): 153 | 154 | storage_path = self.config.build_info.get(self.data_type).storage 155 | storage_path = utils.get_cache_path(storage_path) 156 | 157 | if not os.path.exists(storage_path): 158 | warnings.warn( 159 | f""" 160 | The specified path {storage_path} for visual inputs does not exist. 
161 | Please provide a correct path to the visual inputs or 162 | refer to datasets/download_scripts/README.md for downloading instructions. 163 | """ 164 | ) 165 | 166 | def build(self): 167 | """ 168 | Create by split datasets inheriting torch.utils.data.Datasets. 169 | 170 | # build() can be dataset-specific. Overwrite to customize. 171 | """ 172 | self.build_processors() 173 | 174 | build_info = self.config.build_info 175 | 176 | ann_info = build_info.annotations 177 | vis_info = build_info.get(self.data_type) 178 | 179 | datasets = dict() 180 | for split in ann_info.keys(): 181 | if split not in ["train", "val", "test"]: 182 | continue 183 | 184 | is_train = split == "train" 185 | 186 | # processors 187 | vis_processor = ( 188 | self.vis_processors["train"] 189 | if is_train 190 | else self.vis_processors["eval"] 191 | ) 192 | text_processor = ( 193 | self.text_processors["train"] 194 | if is_train 195 | else self.text_processors["eval"] 196 | ) 197 | 198 | # annotation path 199 | ann_paths = ann_info.get(split).storage 200 | if isinstance(ann_paths, str): 201 | ann_paths = [ann_paths] 202 | 203 | abs_ann_paths = [] 204 | for ann_path in ann_paths: 205 | if not os.path.isabs(ann_path): 206 | ann_path = utils.get_cache_path(ann_path) 207 | abs_ann_paths.append(ann_path) 208 | ann_paths = abs_ann_paths 209 | 210 | # visual data storage path 211 | vis_path = os.path.join(vis_info.storage, split) 212 | 213 | if not os.path.isabs(vis_path): 214 | # vis_path = os.path.join(utils.get_cache_path(), vis_path) 215 | vis_path = utils.get_cache_path(vis_path) 216 | 217 | if not os.path.exists(vis_path): 218 | warnings.warn("storage path {} does not exist.".format(vis_path)) 219 | 220 | # create datasets 221 | dataset_cls = self.train_dataset_cls if is_train else self.eval_dataset_cls 222 | datasets[split] = dataset_cls( 223 | vis_processor=vis_processor, 224 | text_processor=text_processor, 225 | ann_paths=ann_paths, 226 | vis_root=vis_path, 227 | ) 228 | 229 | return datasets 230 | 231 | 232 | def load_dataset_config(cfg_path): 233 | cfg = OmegaConf.load(cfg_path).datasets 234 | cfg = cfg[list(cfg.keys())[0]] 235 | 236 | return cfg 237 | -------------------------------------------------------------------------------- /minigpt4/datasets/datasets/rec_gnndataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | Created on Mar 1, 2020 3 | Pytorch Implementation of LightGCN in 4 | Xiangnan He et al. 
LightGCN: Simplifying and Powering Graph Convolution Network for Recommendation
 5 | 
 6 | @author: Shuxian Bi (stanbi@mail.ustc.edu.cn),Jianbai Ye (gusye@mail.ustc.edu.cn)
 7 | Design Dataset here
 8 | Every dataset's index has to start at 0
 9 | """
10 | import os
11 | from os.path import join
12 | import sys
13 | import torch
14 | import numpy as np
15 | import pandas as pd
16 | from torch.utils.data import Dataset, DataLoader
17 | from scipy.sparse import csr_matrix
18 | import scipy.sparse as sp
19 | from time import time
20 | import torch.utils.data
21 | 
22 | 
23 | class BasicDataset(Dataset):
24 |     def __init__(self):
25 |         print("init dataset")
26 | 
27 |     # @property
28 |     # def m_users(self):
29 |     #     raise NotImplementedError
30 | 
31 |     # @property
32 |     # def n_items(self):
33 |     #     raise NotImplementedError
34 | 
35 |     @property
36 |     def trainDataSize(self):
37 |         raise NotImplementedError
38 | 
39 |     @property
40 |     def testDict(self):
41 |         raise NotImplementedError
42 | 
43 |     @property
44 |     def allPos(self):
45 |         raise NotImplementedError
46 | 
47 |     def getUserItemFeedback(self, users, items):
48 |         raise NotImplementedError
49 | 
50 |     def getUserPosItems(self, users):
51 |         raise NotImplementedError
52 | 
53 |     def getUserNegItems(self, users):
54 |         """
55 |         not necessary for large datasets;
56 |         returning all negative items is infeasible at very large scale
57 |         """
58 |         raise NotImplementedError
59 | 
60 |     def getSparseGraph(self):
61 |         """
62 |         build a graph in torch.sparse.IntTensor.
63 |         Details in NGCF's matrix form
64 |         A =
65 |             |0,   R|
66 |             |R^T, 0|
67 |         """
68 |         raise NotImplementedError
69 | 
70 | 
71 | class GnnDataset(BasicDataset):
72 |     """
73 |     Dataset type for pytorch \n
74 |     Include graph information
75 |     gowalla dataset
76 |     """
77 |     def __init__(self, config, path="../data/gowalla"):
78 |         # train or test
79 |         # cprint(f'loading [{path}]')
80 |         print("loading: ", path)
81 |         self.split = config.A_split
82 |         self.folds = config.A_n_fold
83 |         self.mode_dict = {'train': 0, "test": 1}
84 |         self.mode = self.mode_dict['train']
85 | 
86 | 
87 |         train_file = path+"train_ood2.pkl"
88 |         # note: `path` must end with a path separator, since file names are appended by plain string concatenation
89 |         valid_file = path+"valid_ood2.pkl"
90 |         test_file = path + "test_ood2.pkl"
91 |         self.path = path
92 | 
93 |         self.traindataSize = 0
94 |         self.testDataSize = 0
95 | 
96 | 
97 |         self.train = pd.read_pickle(train_file)[['uid','iid','label']]
98 |         self.train.columns = ['user','item','label']
99 |         self.valid = pd.read_pickle(valid_file)[['uid','iid','label']]
100 |         self.valid.columns = ['user','item','label']
101 |         self.test = pd.read_pickle(test_file)[['uid','iid','label']]
102 |         self.test.columns = ['user','item','label']
103 | 
104 |         # self.train = pd.read_csv(train_file)[['user','item','lables']]
105 |         # self.valid = pd.read_csv(valid_file)[['user','item','lables']]
106 |         # self.test = pd.read_csv(test_file)[['user','item','lables']]
107 | 
108 |         self.m_users = 1 + max([self.train['user'].max(),self.valid['user'].max(),self.test['user'].max()])
109 |         self.n_items = 1 + max([self.train['item'].max(),self.valid['item'].max(),self.test['item'].max()] )
110 | 
111 |         self.testDataSize = self.test.shape[0]
112 |         self.validDataSize = self.valid.shape[0]
113 |         self.train_size = self.train.shape[0]
114 | 
115 | 
116 | 
117 | 
118 |         self.Graph = None
119 |         print(f"{self.train_size} interactions for normal training")
120 |         print(f"{self.validDataSize} interactions for validation")
121 |         print(f"{self.testDataSize} interactions for testing")
122 |         print(f"{self.m_users} users, {self.n_items} items")
123 |         print(f"{config.dataset} Sparsity : 
{(self.validDataSize + self.testDataSize+self.train_size) / self.m_users / self.n_items}")
124 | 
125 |         # (users,items), bipartite graph
126 |         # self.UserItemNet = csr_matrix((np.ones(len(self.trainUser)), (self.trainUser, self.trainItem)),
127 |         #                               shape=(self.m_users, self.n_items))
128 |         # self.users_D = np.array(self.UserItemNet.sum(axis=1)).squeeze()
129 |         # self.users_D[self.users_D == 0.] = 1
130 |         # self.items_D = np.array(self.UserItemNet.sum(axis=0)).squeeze()
131 |         # self.items_D[self.items_D == 0.] = 1.
132 |         # # pre-calculate
133 |         # self._allPos = self.getUserPosItems(list(range(self.n_user)))
134 |         # self.__testDict = self.__build_test()
135 |         self._register_graph()
136 | 
137 |         print("%s is ready to go" % config.dataset)
138 | 
139 |     def _register_graph(self):
140 |         self.getSparseGraph_mode_a2("graph")
141 | 
142 | 
143 | 
144 |     @property
145 |     def trainDataSize(self):
146 |         return self.traindataSize
147 | 
148 |     @property
149 |     def testDict(self):
150 |         return self.__testDict
151 | 
152 |     @property
153 |     def allPos(self):
154 |         return self._allPos
155 | 
156 |     def _split_A_hat(self,A):
157 |         A_fold = []
158 |         fold_len = (self.m_users + self.n_items) // self.folds
159 |         for i_fold in range(self.folds):
160 |             start = i_fold*fold_len
161 |             if i_fold == self.folds - 1:
162 |                 end = self.m_users + self.n_items
163 |             else:
164 |                 end = (i_fold + 1) * fold_len
165 |             A_fold.append(self._convert_sp_mat_to_sp_tensor(A[start:end]).coalesce().cuda())
166 |         return A_fold
167 | 
168 |     def _convert_sp_mat_to_sp_tensor(self, X):
169 |         coo = X.tocoo().astype(np.float32)
170 |         row = torch.Tensor(coo.row).long()
171 |         col = torch.Tensor(coo.col).long()
172 |         index = torch.stack([row, col])
173 |         data = torch.FloatTensor(coo.data)
174 |         return torch.sparse_coo_tensor(index,data,torch.Size(coo.shape))
175 | 
176 | 
177 | 
178 |     def getSparseGraph_mode_a2(self,mode):
179 |         pos_train = self.train[self.train['label']>0].values.copy()
180 |         pos_train[:,1] += self.m_users
181 |         self.trainUser = self.train['user'].values.squeeze()
182 |         self.trainItem = self.train['item']
183 |         print("loading adjacency matrix")
184 |         if self.Graph is None:
185 |             try:
186 |                 pre_adj_mat = sp.load_npz(self.path + '/s_pre_adj_mat_'+mode+'.npz')
187 |                 print("successfully loaded...")
188 |                 norm_adj = pre_adj_mat
189 |             except Exception:  # fall back to building the matrix from scratch
190 |                 print("generating adjacency matrix")
191 |                 s = time()
192 |                 pos_train_t = pos_train.copy()
193 |                 pos_train_t[:,0] = pos_train[:,1]
194 |                 pos_train_t[:,1] = pos_train[:,0]
195 |                 pos = np.concatenate([pos_train,pos_train_t],axis=0)
196 | 
197 |                 adj_mat = sp.csr_matrix((pos[:,2], (pos[:,0],pos[:,1])), shape=(self.m_users+self.n_items, self.m_users+self.n_items))
198 |                 adj_mat = adj_mat.todok()
199 |                 rowsum = np.array(adj_mat.sum(axis=1))
200 |                 d_inv = np.power(rowsum, -0.5).flatten()
201 |                 d_inv[np.isinf(d_inv)] = 0.
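                # Symmetric normalization, A_hat = D^{-1/2} A D^{-1/2} as in
                # LightGCN/NGCF: rowsum holds the node degrees, and isolated
                # nodes (inf after the -0.5 power) have just been zeroed out.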
202 |                 d_mat = sp.diags(d_inv)
203 | 
204 |                 norm_adj = d_mat.dot(adj_mat)
205 |                 norm_adj = norm_adj.dot(d_mat)
206 |                 norm_adj = norm_adj.tocsr()
207 |                 end = time()
208 |                 print(f"costing {end-s}s, saved norm_mat...")
209 |                 sp.save_npz(self.path + '/s_pre_adj_mat_'+mode+'.npz', norm_adj)
210 | 
211 |             if self.split:
212 |                 self.Graph = self._split_A_hat(norm_adj)
213 |                 print("done split matrix")
214 |             else:
215 |                 self.Graph = self._convert_sp_mat_to_sp_tensor(norm_adj)
216 |                 self.Graph = self.Graph.coalesce().cuda()
217 |                 print("didn't split the matrix")
218 |         return self.Graph
219 | 
220 | 
221 | 
222 | 
223 |     def __build_test(self):
224 |         """
225 |         return: dict {user: [items]}
226 |         (note: relies on self.testUser / self.testItem, which are only set by the commented-out code in __init__)
227 |         """
228 |         test_data = {}
229 |         for i, item in enumerate(self.testItem):
230 |             user = self.testUser[i]
231 |             if test_data.get(user):
232 |                 test_data[user].append(item)
233 |             else:
234 |                 test_data[user] = [item]
235 |         return test_data
236 | 
237 |     def getUserItemFeedback(self, users, items):
238 |         """
239 |         users:
240 |             shape [-1]
241 |         items:
242 |             shape [-1]
243 |         return:
244 |             feedback [-1]
245 |         """
246 |         # print(self.UserItemNet[users, items])
247 |         return np.array(self.UserItemNet[users, items]).astype('uint8').reshape((-1,))
248 | 
249 |     def getUserPosItems(self, users):
250 |         posItems = []
251 |         for user in users:
252 |             posItems.append(self.UserItemNet[user].nonzero()[1])
253 |         return posItems
254 | 
255 | 
256 |     def generate_train_dataloader(self, batch_size=1024):
257 |         '''
258 |         generate minibatch data for full training and retraining
259 |         '''
260 |         data = torch.from_numpy(self.train[['user','item','label']].values)
261 |         train_loader = torch.utils.data.DataLoader(data,shuffle=True,batch_size=batch_size,drop_last=False,num_workers=2)
262 |         return train_loader
--------------------------------------------------------------------------------
/minigpt4/tasks/base_task.py:
--------------------------------------------------------------------------------
  1 | """
  2 |  Copyright (c) 2022, salesforce.com, inc.
  3 |  All rights reserved.
  4 |  SPDX-License-Identifier: BSD-3-Clause
  5 |  For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
  6 | """
  7 | 
  8 | import logging
  9 | import os
 10 | 
 11 | import torch
 12 | import torch.distributed as dist
 13 | from minigpt4.common.dist_utils import get_rank, get_world_size, is_main_process, is_dist_avail_and_initialized
 14 | from minigpt4.common.logger import MetricLogger, SmoothedValue
 15 | from minigpt4.common.registry import registry
 16 | from minigpt4.datasets.data_utils import prepare_sample
 17 | 
 18 | 
 19 | class BaseTask:
 20 |     def __init__(self, **kwargs):
 21 |         super().__init__()
 22 | 
 23 |         self.inst_id_key = "instance_id"
 24 | 
 25 |     @classmethod
 26 |     def setup_task(cls, **kwargs):
 27 |         return cls()
 28 | 
 29 |     def build_model(self, cfg):
 30 |         model_config = cfg.model_cfg
 31 | 
 32 |         model_cls = registry.get_model_class(model_config.arch)
 33 |         return model_cls.from_config(model_config)
 34 | 
 35 |     def build_datasets(self, cfg):
 36 |         """
 37 |         Build a dictionary of datasets, keyed by split 'train', 'valid', 'test'.
 38 |         Download datasets and annotations automatically if they do not exist.
 39 | 
 40 |         Args:
 41 |             cfg (common.config.Config): run configuration; dataset configs are read from cfg.datasets_cfg.
 42 | 
 43 |         Returns:
 44 |             dict: Dictionary of torch.utils.data.Dataset objects by split.
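            Example of the returned structure (illustrative; the dataset name
            "movielens" is hypothetical):
                {"movielens": {"train": <Dataset>, "valid": <Dataset>, "test": <Dataset>}}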
45 | """ 46 | 47 | datasets = dict() 48 | 49 | datasets_config = cfg.datasets_cfg 50 | evaluate_only = cfg.run_cfg.evaluate 51 | 52 | assert len(datasets_config) > 0, "At least one dataset has to be specified." 53 | 54 | for name in datasets_config: 55 | dataset_config = datasets_config[name] 56 | 57 | builder = registry.get_builder_class(name)(dataset_config) 58 | dataset = builder.build_datasets(evaluate_only=evaluate_only) 59 | 60 | dataset['train'].name = name 61 | if 'sample_ratio' in dataset_config: 62 | dataset['train'].sample_ratio = dataset_config.sample_ratio 63 | 64 | datasets[name] = dataset 65 | 66 | return datasets 67 | 68 | def train_step(self, model, samples): 69 | loss = model(samples)["loss"] 70 | return loss 71 | 72 | def valid_step(self, model, samples): 73 | raise NotImplementedError 74 | 75 | def before_evaluation(self, model, dataset, **kwargs): 76 | model.before_evaluation(dataset=dataset, task_type=type(self)) 77 | 78 | def after_evaluation(self, **kwargs): 79 | pass 80 | 81 | def inference_step(self): 82 | raise NotImplementedError 83 | 84 | def evaluation(self, model, data_loader, cuda_enabled=True): 85 | metric_logger = MetricLogger(delimiter=" ") 86 | header = "Evaluation" 87 | # TODO make it configurable 88 | print_freq = 10 89 | 90 | results = [] 91 | 92 | for samples in metric_logger.log_every(data_loader, print_freq, header): 93 | samples = prepare_sample(samples, cuda_enabled=cuda_enabled) 94 | 95 | eval_output = self.valid_step(model=model, samples=samples) 96 | results.extend(eval_output) 97 | 98 | if is_dist_avail_and_initialized(): 99 | dist.barrier() 100 | 101 | return results 102 | 103 | def train_epoch( 104 | self, 105 | epoch, 106 | model, 107 | data_loader, 108 | optimizer, 109 | lr_scheduler, 110 | scaler=None, 111 | cuda_enabled=False, 112 | log_freq=50, 113 | accum_grad_iters=1, 114 | ): 115 | return self._train_inner_loop( 116 | epoch=epoch, 117 | iters_per_epoch=lr_scheduler.iters_per_epoch, 118 | model=model, 119 | data_loader=data_loader, 120 | optimizer=optimizer, 121 | scaler=scaler, 122 | lr_scheduler=lr_scheduler, 123 | log_freq=log_freq, 124 | cuda_enabled=cuda_enabled, 125 | accum_grad_iters=accum_grad_iters, 126 | ) 127 | 128 | def train_iters( 129 | self, 130 | epoch, 131 | start_iters, 132 | iters_per_inner_epoch, 133 | model, 134 | data_loader, 135 | optimizer, 136 | lr_scheduler, 137 | scaler=None, 138 | cuda_enabled=False, 139 | log_freq=50, 140 | accum_grad_iters=1, 141 | ): 142 | return self._train_inner_loop( 143 | epoch=epoch, 144 | start_iters=start_iters, 145 | iters_per_epoch=iters_per_inner_epoch, 146 | model=model, 147 | data_loader=data_loader, 148 | optimizer=optimizer, 149 | scaler=scaler, 150 | lr_scheduler=lr_scheduler, 151 | log_freq=log_freq, 152 | cuda_enabled=cuda_enabled, 153 | accum_grad_iters=accum_grad_iters, 154 | ) 155 | 156 | def _train_inner_loop( 157 | self, 158 | epoch, 159 | iters_per_epoch, 160 | model, 161 | data_loader, 162 | optimizer, 163 | lr_scheduler, 164 | scaler=None, 165 | start_iters=None, 166 | log_freq=50, 167 | cuda_enabled=False, 168 | accum_grad_iters=1, 169 | ): 170 | """ 171 | An inner training loop compatible with both epoch-based and iter-based training. 172 | 173 | When using epoch-based, training stops after one epoch; when using iter-based, 174 | training stops after #iters_per_epoch iterations. 
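        Mixed precision is enabled whenever a GradScaler is passed as
        ``scaler``, and gradients are accumulated for ``accum_grad_iters``
        steps before each optimizer update (see the loop body below).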
175 | """ 176 | use_amp = scaler is not None 177 | 178 | if not hasattr(data_loader, "__next__"): 179 | # convert to iterator if not already 180 | data_loader = iter(data_loader) 181 | 182 | metric_logger = MetricLogger(delimiter=" ") 183 | metric_logger.add_meter("lr", SmoothedValue(window_size=1, fmt="{value:.6f}")) 184 | metric_logger.add_meter("loss", SmoothedValue(window_size=1, fmt="{value:.4f}")) 185 | 186 | # if iter-based runner, schedule lr based on inner epoch. 187 | logging.info( 188 | "Start training epoch {}, {} iters per inner epoch.".format( 189 | epoch, iters_per_epoch 190 | ) 191 | ) 192 | header = "Train: data epoch: [{}]".format(epoch) 193 | if start_iters is None: 194 | # epoch-based runner 195 | inner_epoch = epoch 196 | else: 197 | # In iter-based runner, we schedule the learning rate based on iterations. 198 | inner_epoch = start_iters // iters_per_epoch 199 | header = header + "; inner epoch [{}]".format(inner_epoch) 200 | 201 | for i in metric_logger.log_every(range(iters_per_epoch), log_freq, header): 202 | # if using iter-based runner, we stop after iters_per_epoch iterations. 203 | if i >= iters_per_epoch: 204 | break 205 | 206 | samples = next(data_loader) 207 | 208 | samples = prepare_sample(samples, cuda_enabled=cuda_enabled) 209 | samples.update( 210 | { 211 | "epoch": inner_epoch, 212 | "num_iters_per_epoch": iters_per_epoch, 213 | "iters": i, 214 | } 215 | ) 216 | 217 | lr_scheduler.step(cur_epoch=inner_epoch, cur_step=i) 218 | 219 | with torch.cuda.amp.autocast(enabled=use_amp): 220 | loss = self.train_step(model=model, samples=samples) 221 | 222 | # after_train_step() 223 | if use_amp: 224 | scaler.scale(loss).backward() 225 | else: 226 | loss.backward() 227 | 228 | 229 | 230 | # update gradients every accum_grad_iters iterations 231 | if (i + 1) % accum_grad_iters == 0: 232 | if use_amp: 233 | scaler.step(optimizer) 234 | scaler.update() 235 | else: 236 | optimizer.step() 237 | optimizer.zero_grad() 238 | 239 | metric_logger.update(loss=loss.item()) 240 | metric_logger.update(lr=optimizer.param_groups[0]["lr"]) 241 | torch.cuda.empty_cache() 242 | 243 | # after train_epoch() 244 | # gather the stats from all processes 245 | metric_logger.synchronize_between_processes() 246 | logging.info("Averaged stats: " + str(metric_logger.global_avg())) 247 | return { 248 | k: "{:.3f}".format(meter.global_avg) 249 | for k, meter in metric_logger.meters.items() 250 | } 251 | 252 | @staticmethod 253 | def save_result(result, result_dir, filename, remove_duplicate=""): 254 | import json 255 | 256 | result_file = os.path.join( 257 | result_dir, "%s_rank%d.json" % (filename, get_rank()) 258 | ) 259 | final_result_file = os.path.join(result_dir, "%s.json" % filename) 260 | 261 | json.dump(result, open(result_file, "w")) 262 | 263 | if is_dist_avail_and_initialized(): 264 | dist.barrier() 265 | 266 | if is_main_process(): 267 | logging.warning("rank %d starts merging results." 
% get_rank())
268 |             # combine results from all processes
269 |             result = []
270 | 
271 |             for rank in range(get_world_size()):
272 |                 result_file = os.path.join(
273 |                     result_dir, "%s_rank%d.json" % (filename, rank)
274 |                 )
275 |                 res = json.load(open(result_file, "r"))
276 |                 result += res
277 | 
278 |             if remove_duplicate:
279 |                 result_new = []
280 |                 id_list = []
281 |                 for res in result:
282 |                     if res[remove_duplicate] not in id_list:
283 |                         id_list.append(res[remove_duplicate])
284 |                         result_new.append(res)
285 |                 result = result_new
286 | 
287 |             json.dump(result, open(final_result_file, "w"))
288 |             print("result file saved to %s" % final_result_file)
289 | 
290 |         return final_result_file
291 | 
--------------------------------------------------------------------------------
/minigpt4/models/rec_model.py:
--------------------------------------------------------------------------------
  1 | """
  2 |  Copyright (c) 2023, salesforce.com, inc.
  3 |  All rights reserved.
  4 |  SPDX-License-Identifier: BSD-3-Clause
  5 |  For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
  6 | """
  7 | import contextlib
  8 | import logging
  9 | import os
 10 | import time
 11 | import datetime
 12 | 
 13 | import torch
 14 | import torch.nn as nn
 15 | import torch.distributed as dist
 16 | import torch.nn.functional as F
 17 | 
 18 | import minigpt4.common.dist_utils as dist_utils
 19 | from minigpt4.common.dist_utils import download_cached_file
 20 | from minigpt4.common.utils import is_url
 21 | from minigpt4.common.logger import MetricLogger
 22 | from minigpt4.models.base_model import BaseModel
 23 | # from minigpt4.models.Qformer import BertConfig, BertLMHeadModel
 24 | # from minigpt4.models.eva_vit import create_eva_vit_g
 25 | from transformers import BertTokenizer
 26 | import warnings
 27 | 
 28 | from minigpt4.models.rec_base_models import MatrixFactorization, MF_linear, LightGCN, SASRec, Personlized_Prompt, random_mf, Soft_Prompt, RecEncoder_DIN, hashGNN
 29 | 
 30 | 
 31 | class Rec2Base(BaseModel):
 32 | 
 33 |     def to_be_trained(self):
 34 |         pass
 35 | 
 36 |     @classmethod
 37 |     def init_tokenizer(cls):
 38 |         tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
 39 |         tokenizer.add_special_tokens({"bos_token": "[DEC]"})
 40 |         return tokenizer
 41 | 
 42 |     def maybe_autocast(self, dtype=torch.float16):
 43 |         # if on cpu, don't use autocast
 44 |         # if on gpu, use autocast with dtype if provided, otherwise use torch.float16
 45 |         enable_autocast = self.device != torch.device("cpu")
 46 | 
 47 |         if enable_autocast:
 48 |             return torch.cuda.amp.autocast(dtype=dtype)
 49 |         else:
 50 |             return contextlib.nullcontext()
 51 | 
 52 |     # @classmethod
 53 |     # def init_Qformer(cls, num_query_token, vision_width, cross_attention_freq=2):
 54 |     #     encoder_config = BertConfig.from_pretrained("bert-base-uncased")
 55 |     #     encoder_config.encoder_width = vision_width
 56 |     #     # insert cross-attention layer every other block
 57 |     #     encoder_config.add_cross_attention = True
 58 |     #     encoder_config.cross_attention_freq = cross_attention_freq
 59 |     #     encoder_config.query_length = num_query_token
 60 |     #     Qformer = BertLMHeadModel(config=encoder_config)
 61 |     #     query_tokens = nn.Parameter(
 62 |     #         torch.zeros(1, num_query_token, encoder_config.hidden_size)
 63 |     #     )
 64 |     #     query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)
 65 |     #     return Qformer, query_tokens
 66 | 
 67 |     @classmethod
 68 |     def init_rec_encoder(cls, rec_model, config, precision):
 69 |         if rec_model == "MF":
 70 |             print("### rec_encoder:", "MF")
 71 |             rec_model = MatrixFactorization(config)
MatrixFactorization(config) 72 | elif rec_model == "hash": 73 | print("### rec_encoder:", "using hash encoder, projection layer will not be used") 74 | rec_model = hashGNN(config) 75 | rec_model.set_encode_mode(config.code_mode) 76 | elif rec_model == "lightgcn": 77 | print("### rec_encoder:", "lightgcn") 78 | rec_model = LightGCN(config) 79 | elif rec_model == "sasrec": 80 | print("### rec_encoder:", "sasrec") 81 | rec_model = SASRec(config) 82 | elif rec_model == "DIN": 83 | print("### rec_encoder:", "DIN") 84 | rec_model = RecEncoder_DIN(config) 85 | elif rec_model == "personlized_prompt": 86 | print("### rec_encoder:", "personlized_prompt") 87 | rec_model = Personlized_Prompt(config) 88 | elif rec_model == "random_mf": 89 | print("### rec_encoder:", "random_mf") 90 | rec_model = random_mf(config) 91 | elif rec_model == 'soft_prompt': 92 | print("### rec_encoder:", "soft_prompt") 93 | rec_model = Soft_Prompt(config) 94 | else: 95 | rec_model = None 96 | warnings.warn(" the input rec_model is not MF, LightGCN or sasrec, or DCN, we won't utilize the rec_encoder directly.") 97 | # raise NotImplementedError("the current version olny supports the following models: MF,...") 98 | return rec_model 99 | 100 | # @classmethod 101 | # def init_vision_encoder( 102 | # cls, model_name, img_size, drop_path_rate, use_grad_checkpoint, precision 103 | # ): 104 | # assert model_name == "eva_clip_g", "vit model must be eva_clip_g for current version of MiniGPT-4" 105 | # visual_encoder = create_eva_vit_g( 106 | # img_size, drop_path_rate, use_grad_checkpoint, precision 107 | # ) 108 | 109 | # ln_vision = LayerNorm(visual_encoder.num_features) 110 | # return visual_encoder, ln_vision 111 | 112 | def load_from_pretrained(self, url_or_filename): 113 | if is_url(url_or_filename): 114 | cached_file = download_cached_file( 115 | url_or_filename, check_hash=False, progress=True 116 | ) 117 | checkpoint = torch.load(cached_file, map_location="cpu") 118 | elif os.path.isfile(url_or_filename): 119 | checkpoint = torch.load(url_or_filename, map_location="cpu") 120 | else: 121 | raise RuntimeError("checkpoint url or path is invalid") 122 | 123 | state_dict = checkpoint["model"] 124 | 125 | msg = self.load_state_dict(state_dict, strict=False) 126 | 127 | # logging.info("Missing keys {}".format(msg.missing_keys)) 128 | logging.info("load checkpoint from %s" % url_or_filename) 129 | 130 | return msg 131 | 132 | def after_evaluation(self, **kwargs): 133 | pass 134 | 135 | 136 | def disabled_train(self, mode=True): 137 | """Overwrite model.train with this function to make sure train/eval mode 138 | does not change anymore.""" 139 | return self 140 | 141 | 142 | class LayerNorm(nn.LayerNorm): 143 | """Subclass torch's LayerNorm to handle fp16.""" 144 | 145 | def forward(self, x: torch.Tensor): 146 | orig_type = x.dtype 147 | ret = super().forward(x.type(torch.float32)) 148 | return ret.type(orig_type) 149 | 150 | 151 | def compute_sim_matrix(model, data_loader, **kwargs): 152 | k_test = kwargs.pop("k_test") 153 | 154 | metric_logger = MetricLogger(delimiter=" ") 155 | header = "Evaluation:" 156 | 157 | logging.info("Computing features for evaluation...") 158 | start_time = time.time() 159 | 160 | texts = data_loader.dataset.text 161 | num_text = len(texts) 162 | text_bs = 256 163 | text_ids = [] 164 | text_embeds = [] 165 | text_atts = [] 166 | for i in range(0, num_text, text_bs): 167 | text = texts[i : min(num_text, i + text_bs)] 168 | text_input = model.tokenizer( 169 | text, 170 | padding="max_length", 171 | 
truncation=True, 172 | max_length=35, 173 | return_tensors="pt", 174 | ).to(model.device) 175 | text_feat = model.forward_text(text_input) 176 | text_embed = F.normalize(model.text_proj(text_feat)) 177 | text_embeds.append(text_embed) 178 | text_ids.append(text_input.input_ids) 179 | text_atts.append(text_input.attention_mask) 180 | 181 | text_embeds = torch.cat(text_embeds, dim=0) 182 | text_ids = torch.cat(text_ids, dim=0) 183 | text_atts = torch.cat(text_atts, dim=0) 184 | 185 | vit_feats = [] 186 | image_embeds = [] 187 | for samples in data_loader: 188 | image = samples["image"] 189 | 190 | image = image.to(model.device) 191 | image_feat, vit_feat = model.forward_image(image) 192 | image_embed = model.vision_proj(image_feat) 193 | image_embed = F.normalize(image_embed, dim=-1) 194 | 195 | vit_feats.append(vit_feat.cpu()) 196 | image_embeds.append(image_embed) 197 | 198 | vit_feats = torch.cat(vit_feats, dim=0) 199 | image_embeds = torch.cat(image_embeds, dim=0) 200 | 201 | sims_matrix = [] 202 | for image_embed in image_embeds: 203 | sim_q2t = image_embed @ text_embeds.t() 204 | sim_i2t, _ = sim_q2t.max(0) 205 | sims_matrix.append(sim_i2t) 206 | sims_matrix = torch.stack(sims_matrix, dim=0) 207 | 208 | score_matrix_i2t = torch.full( 209 | (len(data_loader.dataset.image), len(texts)), -100.0 210 | ).to(model.device) 211 | 212 | num_tasks = dist_utils.get_world_size() 213 | rank = dist_utils.get_rank() 214 | step = sims_matrix.size(0) // num_tasks + 1 215 | start = rank * step 216 | end = min(sims_matrix.size(0), start + step) 217 | 218 | for i, sims in enumerate( 219 | metric_logger.log_every(sims_matrix[start:end], 50, header) 220 | ): 221 | topk_sim, topk_idx = sims.topk(k=k_test, dim=0) 222 | image_inputs = vit_feats[start + i].repeat(k_test, 1, 1).to(model.device) 223 | score = model.compute_itm( 224 | image_inputs=image_inputs, 225 | text_ids=text_ids[topk_idx], 226 | text_atts=text_atts[topk_idx], 227 | ).float() 228 | score_matrix_i2t[start + i, topk_idx] = score + topk_sim 229 | 230 | sims_matrix = sims_matrix.t() 231 | score_matrix_t2i = torch.full( 232 | (len(texts), len(data_loader.dataset.image)), -100.0 233 | ).to(model.device) 234 | 235 | step = sims_matrix.size(0) // num_tasks + 1 236 | start = rank * step 237 | end = min(sims_matrix.size(0), start + step) 238 | 239 | for i, sims in enumerate( 240 | metric_logger.log_every(sims_matrix[start:end], 50, header) 241 | ): 242 | topk_sim, topk_idx = sims.topk(k=k_test, dim=0) 243 | image_inputs = vit_feats[topk_idx.cpu()].to(model.device) 244 | score = model.compute_itm( 245 | image_inputs=image_inputs, 246 | text_ids=text_ids[start + i].repeat(k_test, 1), 247 | text_atts=text_atts[start + i].repeat(k_test, 1), 248 | ).float() 249 | score_matrix_t2i[start + i, topk_idx] = score + topk_sim 250 | 251 | if dist_utils.is_dist_avail_and_initialized(): 252 | dist.barrier() 253 | torch.distributed.all_reduce( 254 | score_matrix_i2t, op=torch.distributed.ReduceOp.SUM 255 | ) 256 | torch.distributed.all_reduce( 257 | score_matrix_t2i, op=torch.distributed.ReduceOp.SUM 258 | ) 259 | 260 | total_time = time.time() - start_time 261 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 262 | logging.info("Evaluation time {}".format(total_time_str)) 263 | 264 | return score_matrix_i2t.cpu().numpy(), score_matrix_t2i.cpu().numpy() 265 | -------------------------------------------------------------------------------- /minigpt4/datasets/builders/rec_pair_builder.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import warnings 4 | 5 | from minigpt4.common.registry import registry 6 | from minigpt4.datasets.builders.rec_base_dataset_builder import RecBaseDatasetBuilder 7 | # from minigpt4.datasets.datasets.laion_dataset import LaionDataset 8 | # from minigpt4.datasets.datasets.cc_sbu_dataset import CCSBUDataset, CCSBUAlignDataset 9 | 10 | from minigpt4.datasets.datasets.rec_datasets import MovielensDataset, MovielensDataset_stage1, AmazonDataset, MoiveOOData, MoiveOOData_sasrec, AmazonOOData, AmazonOOData_sasrec 11 | 12 | # @registry.register_builder("movielens") 13 | # class MovielensBuilder(RecBaseDatasetBuilder): 14 | # train_dataset_cls = MovielensDataset 15 | 16 | # DATASET_CONFIG_DICT = { 17 | # "default": "configs/datasets/movielens/default.yaml", 18 | # } 19 | 20 | # def build_datasets(self): 21 | # # at this point, all the annotations and image/videos should be all downloaded to the specified locations. 22 | # logging.info("Building datasets...") 23 | # self.build_processors() 24 | 25 | # build_info = self.config.build_info 26 | # storage_path = build_info.storage 27 | 28 | # datasets = dict() 29 | 30 | # if not os.path.exists(storage_path): 31 | # warnings.warn("storage path {} does not exist.".format(storage_path)) 32 | 33 | # # create datasets 34 | # dataset_cls = self.train_dataset_cls 35 | # datasets['train'] = dataset_cls( 36 | # text_processor=self.text_processors["train"], 37 | # ann_paths=[os.path.join(storage_path, 'train')], 38 | # ) 39 | # try: 40 | # datasets['valid'] = dataset_cls( 41 | # text_processor=self.text_processors["train"], 42 | # ann_paths=[os.path.join(storage_path, 'valid_small2')]) 43 | # datasets['test'] = dataset_cls( 44 | # text_processor=self.text_processors["train"], 45 | # ann_paths=[os.path.join(storage_path, 'test')]) 46 | # except: 47 | # pass 48 | 49 | 50 | 51 | # return datasets 52 | 53 | # @registry.register_builder("amazon") 54 | # class AmazonBuilder(RecBaseDatasetBuilder): 55 | # train_dataset_cls = AmazonDataset 56 | 57 | # DATASET_CONFIG_DICT = { 58 | # "default": "configs/datasets/amazon/default.yaml", 59 | # } 60 | 61 | # def build_datasets(self): 62 | # # at this point, all the annotations and image/videos should be all downloaded to the specified locations. 
63 | # logging.info("Building datasets...") 64 | # self.build_processors() 65 | 66 | # build_info = self.config.build_info 67 | # storage_path = build_info.storage 68 | 69 | # datasets = dict() 70 | 71 | # if not os.path.exists(storage_path): 72 | # warnings.warn("storage path {} does not exist.".format(storage_path)) 73 | 74 | # # create datasets 75 | # dataset_cls = self.train_dataset_cls 76 | # datasets['train'] = dataset_cls( 77 | # text_processor=self.text_processors["train"], 78 | # ann_paths=[os.path.join(storage_path, 'train')], 79 | # ) 80 | # try: 81 | # datasets['valid'] = dataset_cls( 82 | # text_processor=self.text_processors["train"], 83 | # ann_paths=[os.path.join(storage_path, 'valid_small')]) 84 | # #0915 85 | # datasets['test'] = dataset_cls( 86 | # text_processor=self.text_processors["train"], 87 | # ann_paths=[os.path.join(storage_path, 'test')]) 88 | # except: 89 | # print(os.path.join(storage_path, 'valid_small'), os.path.exists(os.path.join(storage_path, 'valid_small_seqs.pkl'))) 90 | # raise FileNotFoundError("file not found.") 91 | # return datasets 92 | 93 | 94 | @registry.register_builder("movie_ood") 95 | class MoiveOODBuilder(RecBaseDatasetBuilder): 96 | train_dataset_cls = MoiveOOData 97 | 98 | DATASET_CONFIG_DICT = { 99 | "default": "configs/datasets/default.yaml", 100 | } 101 | def build_datasets(self,evaluate_only=False): 102 | # at this point, all the annotations and image/videos should be all downloaded to the specified locations. 103 | logging.info("Building datasets...") 104 | self.build_processors() 105 | 106 | build_info = self.config.build_info 107 | storage_path = build_info.storage 108 | 109 | datasets = dict() 110 | 111 | if not os.path.exists(storage_path): 112 | warnings.warn("storage path {} does not exist.".format(storage_path)) 113 | 114 | # create datasets 115 | dataset_cls = self.train_dataset_cls 116 | datasets['train'] = dataset_cls( 117 | text_processor=self.text_processors["train"], 118 | ann_paths=[os.path.join(storage_path, 'train')], 119 | ) 120 | try: 121 | datasets['valid'] = dataset_cls( 122 | text_processor=self.text_processors["train"], 123 | ann_paths=[os.path.join(storage_path, 'valid_small')]) 124 | #0915 125 | datasets['test'] = dataset_cls( 126 | text_processor=self.text_processors["train"], 127 | ann_paths=[os.path.join(storage_path, 'test')]) 128 | if evaluate_only: 129 | datasets['test_warm'] = dataset_cls( 130 | text_processor=self.text_processors["train"], 131 | ann_paths=[os.path.join(storage_path, 'test_warm_cold=warm')]) 132 | 133 | datasets['test_cold'] = dataset_cls( 134 | text_processor=self.text_processors["train"], 135 | ann_paths=[os.path.join(storage_path, 'test_warm_cold=cold')]) 136 | except: 137 | print(os.path.join(storage_path, 'valid_small'), os.path.exists(os.path.join(storage_path, 'valid_small_seqs.pkl'))) 138 | raise FileNotFoundError("file not found.") 139 | return datasets 140 | 141 | 142 | @registry.register_builder("movie_ood_sasrec") 143 | class MoiveOODBuilder_sasrec(RecBaseDatasetBuilder): 144 | train_dataset_cls = MoiveOOData_sasrec 145 | 146 | DATASET_CONFIG_DICT = { 147 | "default": "configs/datasets/default.yaml", 148 | } 149 | def build_datasets(self,evaluate_only=False): 150 | # at this point, all the annotations and image/videos should be all downloaded to the specified locations. 
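        # NOTE: the exact on-disk layout is an assumption inferred from the
        # error handling further down, which probes for files such as
        # "valid_small_seqs.pkl". Each split is expected as a pickled
        # sequence file under the storage path, e.g. (illustrative only):
        #   <storage>/train_seqs.pkl
        #   <storage>/valid_small_seqs.pkl
        #   <storage>/test_seqs.pkl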
151 | logging.info("Building datasets...") 152 | self.build_processors() 153 | 154 | build_info = self.config.build_info 155 | storage_path = build_info.storage 156 | 157 | datasets = dict() 158 | 159 | if not os.path.exists(storage_path): 160 | warnings.warn("storage path {} does not exist.".format(storage_path)) 161 | 162 | # create datasets 163 | dataset_cls = self.train_dataset_cls 164 | datasets['train'] = dataset_cls( 165 | text_processor=self.text_processors["train"], 166 | ann_paths=[os.path.join(storage_path, 'train')], 167 | ) 168 | try: 169 | datasets['valid'] = dataset_cls( 170 | text_processor=self.text_processors["train"], 171 | ann_paths=[os.path.join(storage_path, 'valid_small')]) 172 | #0915 173 | datasets['test'] = dataset_cls( 174 | text_processor=self.text_processors["train"], 175 | ann_paths=[os.path.join(storage_path, 'test')]) 176 | except: 177 | print(os.path.join(storage_path, 'valid_small'), os.path.exists(os.path.join(storage_path, 'valid_small_seqs.pkl'))) 178 | raise FileNotFoundError("file not found.") 179 | return datasets 180 | 181 | 182 | 183 | @registry.register_builder("amazon_ood") 184 | class AmazonOODBuilder(RecBaseDatasetBuilder): 185 | train_dataset_cls = AmazonOOData 186 | 187 | DATASET_CONFIG_DICT = { 188 | "default": "configs/datasets/default.yaml", 189 | } 190 | def build_datasets(self, evaluate_only=False): 191 | # at this point, all the annotations and image/videos should be all downloaded to the specified locations. 192 | logging.info("Building datasets...") 193 | self.build_processors() 194 | 195 | build_info = self.config.build_info 196 | storage_path = build_info.storage 197 | 198 | datasets = dict() 199 | 200 | if not os.path.exists(storage_path): 201 | warnings.warn("storage path {} does not exist.".format(storage_path)) 202 | 203 | # create datasets 204 | dataset_cls = self.train_dataset_cls 205 | datasets['train'] = dataset_cls( 206 | text_processor=self.text_processors["train"], 207 | ann_paths=[os.path.join(storage_path, 'train')], 208 | ) 209 | try: 210 | datasets['valid'] = dataset_cls( 211 | text_processor=self.text_processors["train"], 212 | ann_paths=[os.path.join(storage_path, 'valid_small')]) 213 | #0915 214 | datasets['test'] = dataset_cls( 215 | text_processor=self.text_processors["train"], 216 | ann_paths=[os.path.join(storage_path, 'test')]) 217 | if evaluate_only: 218 | datasets['test_warm'] = dataset_cls( 219 | text_processor=self.text_processors["train"], 220 | ann_paths=[os.path.join(storage_path, 'test=warm')]) 221 | 222 | datasets['test_cold'] = dataset_cls( 223 | text_processor=self.text_processors["train"], 224 | ann_paths=[os.path.join(storage_path, 'test=cold')]) 225 | except: 226 | print(os.path.join(storage_path, 'valid_small'), os.path.exists(os.path.join(storage_path, 'valid_small_seqs.pkl'))) 227 | raise FileNotFoundError("file not found.") 228 | return datasets 229 | 230 | 231 | @registry.register_builder("amazon_ood_sasrec") 232 | class AmazonOODBuilder_sasrec(RecBaseDatasetBuilder): 233 | train_dataset_cls = AmazonOOData_sasrec 234 | 235 | DATASET_CONFIG_DICT = { 236 | "default": "configs/datasets/default.yaml", 237 | } 238 | def build_datasets(self,evaluate_only=False): 239 | # at this point, all the annotations and image/videos should be all downloaded to the specified locations. 
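        # Illustrative usage sketch (names assumed, not taken from this file):
        # builders registered via @registry.register_builder are normally
        # resolved through the registry rather than instantiated directly, e.g.
        #   builder_cls = registry.get_builder_class("amazon_ood_sasrec")
        #   datasets = builder_cls(dataset_config).build_datasets(evaluate_only=False)
        # where `dataset_config` stands in for whatever config object
        # RecBaseDatasetBuilder expects.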
240 | logging.info("Building datasets...") 241 | self.build_processors() 242 | 243 | build_info = self.config.build_info 244 | storage_path = build_info.storage 245 | 246 | datasets = dict() 247 | 248 | if not os.path.exists(storage_path): 249 | warnings.warn("storage path {} does not exist.".format(storage_path)) 250 | 251 | # create datasets 252 | dataset_cls = self.train_dataset_cls 253 | datasets['train'] = dataset_cls( 254 | text_processor=self.text_processors["train"], 255 | ann_paths=[os.path.join(storage_path, 'train')], 256 | ) 257 | try: 258 | datasets['valid'] = dataset_cls( 259 | text_processor=self.text_processors["train"], 260 | ann_paths=[os.path.join(storage_path, 'valid_small')]) 261 | #0915 262 | datasets['test'] = dataset_cls( 263 | text_processor=self.text_processors["train"], 264 | ann_paths=[os.path.join(storage_path, 'test')]) 265 | except: 266 | print(os.path.join(storage_path, 'valid_small'), os.path.exists(os.path.join(storage_path, 'valid_small_seqs.pkl'))) 267 | raise FileNotFoundError("file not found.") 268 | return datasets 269 | -------------------------------------------------------------------------------- /minigpt4/common/registry.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | 9 | class Registry: 10 | mapping = { 11 | "builder_name_mapping": {}, 12 | "task_name_mapping": {}, 13 | "processor_name_mapping": {}, 14 | "model_name_mapping": {}, 15 | "lr_scheduler_name_mapping": {}, 16 | "runner_name_mapping": {}, 17 | "state": {}, 18 | "paths": {}, 19 | } 20 | 21 | @classmethod 22 | def register_builder(cls, name): 23 | r"""Register a dataset builder to registry with key 'name' 24 | 25 | Args: 26 | name: Key with which the builder will be registered. 27 | 28 | Usage: 29 | 30 | from minigpt4.common.registry import registry 31 | from minigpt4.datasets.base_dataset_builder import BaseDatasetBuilder 32 | """ 33 | 34 | def wrap(builder_cls): 35 | try: 36 | 37 | from minigpt4.datasets.builders.base_dataset_builder import BaseDatasetBuilder 38 | 39 | assert issubclass( 40 | builder_cls, BaseDatasetBuilder 41 | ), "All builders must inherit BaseDatasetBuilder class, found {}".format( 42 | builder_cls 43 | ) 44 | if name in cls.mapping["builder_name_mapping"]: 45 | raise KeyError( 46 | "Name '{}' already registered for {}.".format( 47 | name, cls.mapping["builder_name_mapping"][name] 48 | ) 49 | ) 50 | cls.mapping["builder_name_mapping"][name] = builder_cls 51 | return builder_cls 52 | except: 53 | from minigpt4.datasets.builders.rec_base_dataset_builder import RecBaseDatasetBuilder 54 | 55 | assert issubclass( 56 | builder_cls, RecBaseDatasetBuilder 57 | ), "All builders must inherit BaseDatasetBuilder class, found {}".format( 58 | builder_cls 59 | ) 60 | if name in cls.mapping["builder_name_mapping"]: 61 | raise KeyError( 62 | "Name '{}' already registered for {}.".format( 63 | name, cls.mapping["builder_name_mapping"][name] 64 | ) 65 | ) 66 | cls.mapping["builder_name_mapping"][name] = builder_cls 67 | return builder_cls 68 | 69 | return wrap 70 | 71 | @classmethod 72 | def register_task(cls, name): 73 | r"""Register a task to registry with key 'name' 74 | 75 | Args: 76 | name: Key with which the task will be registered. 
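                The same key is later used for lookup via
                ``registry.get_task_class(name)``; a task class is typically
                registered with the decorator form, e.g. a hypothetical
                ``@registry.register_task("my_task")``.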
77 | 
78 |         Usage:
79 | 
80 |             from minigpt4.common.registry import registry
81 |         """
82 | 
83 |         def wrap(task_cls):
84 |             from minigpt4.tasks.base_task import BaseTask
85 | 
86 |             assert issubclass(
87 |                 task_cls, BaseTask
88 |             ), "All tasks must inherit BaseTask class"
89 |             if name in cls.mapping["task_name_mapping"]:
90 |                 raise KeyError(
91 |                     "Name '{}' already registered for {}.".format(
92 |                         name, cls.mapping["task_name_mapping"][name]
93 |                     )
94 |                 )
95 |             cls.mapping["task_name_mapping"][name] = task_cls
96 |             return task_cls
97 | 
98 |         return wrap
99 | 
100 |     @classmethod
101 |     def register_model(cls, name):
102 |         r"""Register a model to registry with key 'name'
103 | 
104 |         Args:
105 |             name: Key with which the model will be registered.
106 | 
107 |         Usage:
108 | 
109 |             from minigpt4.common.registry import registry
110 |         """
111 | 
112 |         def wrap(model_cls):
113 |             from minigpt4.models import BaseModel
114 | 
115 |             assert issubclass(
116 |                 model_cls, BaseModel
117 |             ), "All models must inherit BaseModel class"
118 |             if name in cls.mapping["model_name_mapping"]:
119 |                 raise KeyError(
120 |                     "Name '{}' already registered for {}.".format(
121 |                         name, cls.mapping["model_name_mapping"][name]
122 |                     )
123 |                 )
124 |             cls.mapping["model_name_mapping"][name] = model_cls
125 |             return model_cls
126 | 
127 |         return wrap
128 | 
129 |     @classmethod
130 |     def register_processor(cls, name):
131 |         r"""Register a processor to registry with key 'name'
132 | 
133 |         Args:
134 |             name: Key with which the processor will be registered.
135 | 
136 |         Usage:
137 | 
138 |             from minigpt4.common.registry import registry
139 |         """
140 | 
141 |         def wrap(processor_cls):
142 |             from minigpt4.processors import BaseProcessor
143 | 
144 |             assert issubclass(
145 |                 processor_cls, BaseProcessor
146 |             ), "All processors must inherit BaseProcessor class"
147 |             if name in cls.mapping["processor_name_mapping"]:
148 |                 raise KeyError(
149 |                     "Name '{}' already registered for {}.".format(
150 |                         name, cls.mapping["processor_name_mapping"][name]
151 |                     )
152 |                 )
153 |             cls.mapping["processor_name_mapping"][name] = processor_cls
154 |             return processor_cls
155 | 
156 |         return wrap
157 | 
158 |     @classmethod
159 |     def register_lr_scheduler(cls, name):
160 |         r"""Register a lr scheduler to registry with key 'name'
161 | 
162 |         Args:
163 |             name: Key with which the lr scheduler will be registered.
164 | 
165 |         Usage:
166 | 
167 |             from minigpt4.common.registry import registry
168 |         """
169 | 
170 |         def wrap(lr_sched_cls):
171 |             if name in cls.mapping["lr_scheduler_name_mapping"]:
172 |                 raise KeyError(
173 |                     "Name '{}' already registered for {}.".format(
174 |                         name, cls.mapping["lr_scheduler_name_mapping"][name]
175 |                     )
176 |                 )
177 |             cls.mapping["lr_scheduler_name_mapping"][name] = lr_sched_cls
178 |             return lr_sched_cls
179 | 
180 |         return wrap
181 | 
182 |     @classmethod
183 |     def register_runner(cls, name):
184 |         r"""Register a runner to registry with key 'name'
185 | 
186 |         Args:
187 |             name: Key with which the runner will be registered.
188 | 189 | Usage: 190 | 191 | from minigpt4.common.registry import registry 192 | """ 193 | 194 | def wrap(runner_cls): 195 | if name in cls.mapping["runner_name_mapping"]: 196 | raise KeyError( 197 | "Name '{}' already registered for {}.".format( 198 | name, cls.mapping["runner_name_mapping"][name] 199 | ) 200 | ) 201 | cls.mapping["runner_name_mapping"][name] = runner_cls 202 | return runner_cls 203 | 204 | return wrap 205 | 206 | @classmethod 207 | def register_path(cls, name, path): 208 | r"""Register a path to registry with key 'name' 209 | 210 | Args: 211 | name: Key with which the path will be registered. 212 | 213 | Usage: 214 | 215 | from minigpt4.common.registry import registry 216 | """ 217 | assert isinstance(path, str), "All path must be str." 218 | if name in cls.mapping["paths"]: 219 | raise KeyError("Name '{}' already registered.".format(name)) 220 | cls.mapping["paths"][name] = path 221 | 222 | @classmethod 223 | def register(cls, name, obj): 224 | r"""Register an item to registry with key 'name' 225 | 226 | Args: 227 | name: Key with which the item will be registered. 228 | 229 | Usage:: 230 | 231 | from minigpt4.common.registry import registry 232 | 233 | registry.register("config", {}) 234 | """ 235 | path = name.split(".") 236 | current = cls.mapping["state"] 237 | 238 | for part in path[:-1]: 239 | if part not in current: 240 | current[part] = {} 241 | current = current[part] 242 | 243 | current[path[-1]] = obj 244 | 245 | # @classmethod 246 | # def get_trainer_class(cls, name): 247 | # return cls.mapping["trainer_name_mapping"].get(name, None) 248 | 249 | @classmethod 250 | def get_builder_class(cls, name): 251 | return cls.mapping["builder_name_mapping"].get(name, None) 252 | 253 | @classmethod 254 | def get_model_class(cls, name): 255 | return cls.mapping["model_name_mapping"].get(name, None) 256 | 257 | @classmethod 258 | def get_task_class(cls, name): 259 | return cls.mapping["task_name_mapping"].get(name, None) 260 | 261 | @classmethod 262 | def get_processor_class(cls, name): 263 | return cls.mapping["processor_name_mapping"].get(name, None) 264 | 265 | @classmethod 266 | def get_lr_scheduler_class(cls, name): 267 | return cls.mapping["lr_scheduler_name_mapping"].get(name, None) 268 | 269 | @classmethod 270 | def get_runner_class(cls, name): 271 | return cls.mapping["runner_name_mapping"].get(name, None) 272 | 273 | @classmethod 274 | def list_runners(cls): 275 | return sorted(cls.mapping["runner_name_mapping"].keys()) 276 | 277 | @classmethod 278 | def list_models(cls): 279 | return sorted(cls.mapping["model_name_mapping"].keys()) 280 | 281 | @classmethod 282 | def list_tasks(cls): 283 | return sorted(cls.mapping["task_name_mapping"].keys()) 284 | 285 | @classmethod 286 | def list_processors(cls): 287 | return sorted(cls.mapping["processor_name_mapping"].keys()) 288 | 289 | @classmethod 290 | def list_lr_schedulers(cls): 291 | return sorted(cls.mapping["lr_scheduler_name_mapping"].keys()) 292 | 293 | @classmethod 294 | def list_datasets(cls): 295 | return sorted(cls.mapping["builder_name_mapping"].keys()) 296 | 297 | @classmethod 298 | def get_path(cls, name): 299 | return cls.mapping["paths"].get(name, None) 300 | 301 | @classmethod 302 | def get(cls, name, default=None, no_warning=False): 303 | r"""Get an item from registry with key 'name' 304 | 305 | Args: 306 | name (string): Key whose value needs to be retrieved. 307 | default: If passed and key is not in registry, default value will 308 | be returned with a warning. 
Default: None 309 | no_warning (bool): If passed as True, warning when key doesn't exist 310 | will not be generated. Useful for MMF's 311 | internal operations. Default: False 312 | """ 313 | original_name = name 314 | name = name.split(".") 315 | value = cls.mapping["state"] 316 | for subname in name: 317 | value = value.get(subname, default) 318 | if value is default: 319 | break 320 | 321 | if ( 322 | "writer" in cls.mapping["state"] 323 | and value == default 324 | and no_warning is False 325 | ): 326 | cls.mapping["state"]["writer"].warning( 327 | "Key {} is not present in registry, returning default value " 328 | "of {}".format(original_name, default) 329 | ) 330 | return value 331 | 332 | @classmethod 333 | def unregister(cls, name): 334 | r"""Remove an item from registry with key 'name' 335 | 336 | Args: 337 | name: Key which needs to be removed. 338 | Usage:: 339 | 340 | from mmf.common.registry import registry 341 | 342 | config = registry.unregister("config") 343 | """ 344 | return cls.mapping["state"].pop(name, None) 345 | 346 | 347 | registry = Registry() 348 | -------------------------------------------------------------------------------- /minigpt4/processors/randaugment.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import cv2 9 | import numpy as np 10 | 11 | import torch 12 | 13 | 14 | ## aug functions 15 | def identity_func(img): 16 | return img 17 | 18 | 19 | def autocontrast_func(img, cutoff=0): 20 | """ 21 | same output as PIL.ImageOps.autocontrast 22 | """ 23 | n_bins = 256 24 | 25 | def tune_channel(ch): 26 | n = ch.size 27 | cut = cutoff * n // 100 28 | if cut == 0: 29 | high, low = ch.max(), ch.min() 30 | else: 31 | hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins]) 32 | low = np.argwhere(np.cumsum(hist) > cut) 33 | low = 0 if low.shape[0] == 0 else low[0] 34 | high = np.argwhere(np.cumsum(hist[::-1]) > cut) 35 | high = n_bins - 1 if high.shape[0] == 0 else n_bins - 1 - high[0] 36 | if high <= low: 37 | table = np.arange(n_bins) 38 | else: 39 | scale = (n_bins - 1) / (high - low) 40 | offset = -low * scale 41 | table = np.arange(n_bins) * scale + offset 42 | table[table < 0] = 0 43 | table[table > n_bins - 1] = n_bins - 1 44 | table = table.clip(0, 255).astype(np.uint8) 45 | return table[ch] 46 | 47 | channels = [tune_channel(ch) for ch in cv2.split(img)] 48 | out = cv2.merge(channels) 49 | return out 50 | 51 | 52 | def equalize_func(img): 53 | """ 54 | same output as PIL.ImageOps.equalize 55 | PIL's implementation is different from cv2.equalize 56 | """ 57 | n_bins = 256 58 | 59 | def tune_channel(ch): 60 | hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins]) 61 | non_zero_hist = hist[hist != 0].reshape(-1) 62 | step = np.sum(non_zero_hist[:-1]) // (n_bins - 1) 63 | if step == 0: 64 | return ch 65 | n = np.empty_like(hist) 66 | n[0] = step // 2 67 | n[1:] = hist[:-1] 68 | table = (np.cumsum(n) // step).clip(0, 255).astype(np.uint8) 69 | return table[ch] 70 | 71 | channels = [tune_channel(ch) for ch in cv2.split(img)] 72 | out = cv2.merge(channels) 73 | return out 74 | 75 | 76 | def rotate_func(img, degree, fill=(0, 0, 0)): 77 | """ 78 | like PIL, rotate by degree, not radians 79 | """ 80 | H, W = img.shape[0], img.shape[1] 81 | center = W / 2, H / 2 82 | M = 
cv2.getRotationMatrix2D(center, degree, 1)
83 |     out = cv2.warpAffine(img, M, (W, H), borderValue=fill)
84 |     return out
85 | 
86 | 
87 | def solarize_func(img, thresh=128):
88 |     """
89 |     same output as PIL.ImageOps.solarize
90 |     """
91 |     table = np.array([el if el < thresh else 255 - el for el in range(256)])
92 |     table = table.clip(0, 255).astype(np.uint8)
93 |     out = table[img]
94 |     return out
95 | 
96 | 
97 | def color_func(img, factor):
98 |     """
99 |     same output as PIL.ImageEnhance.Color
100 |     """
101 |     ## implementation according to PIL definition, quite slow
102 |     # degenerate = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)[:, :, np.newaxis]
103 |     # out = blend(degenerate, img, factor)
104 |     # M = (
105 |     #     np.eye(3) * factor
106 |     #     + np.float32([0.114, 0.587, 0.299]).reshape(3, 1) * (1. - factor)
107 |     # )[np.newaxis, np.newaxis, :]
108 |     M = np.float32(
109 |         [[0.886, -0.114, -0.114], [-0.587, 0.413, -0.587], [-0.299, -0.299, 0.701]]
110 |     ) * factor + np.float32([[0.114], [0.587], [0.299]])
111 |     out = np.matmul(img, M).clip(0, 255).astype(np.uint8)
112 |     return out
113 | 
114 | 
115 | def contrast_func(img, factor):
116 |     """
117 |     same output as PIL.ImageEnhance.Contrast
118 |     """
119 |     mean = np.sum(np.mean(img, axis=(0, 1)) * np.array([0.114, 0.587, 0.299]))
120 |     table = (
121 |         np.array([(el - mean) * factor + mean for el in range(256)])
122 |         .clip(0, 255)
123 |         .astype(np.uint8)
124 |     )
125 |     out = table[img]
126 |     return out
127 | 
128 | 
129 | def brightness_func(img, factor):
130 |     """
131 |     same output as PIL.ImageEnhance.Brightness
132 |     """
133 |     table = (np.arange(256, dtype=np.float32) * factor).clip(0, 255).astype(np.uint8)
134 |     out = table[img]
135 |     return out
136 | 
137 | 
138 | def sharpness_func(img, factor):
139 |     """
140 |     The differences between this result and PIL are all on the 4 boundaries; the center
141 |     areas are the same
142 |     """
143 |     kernel = np.ones((3, 3), dtype=np.float32)
144 |     kernel[1][1] = 5
145 |     kernel /= 13
146 |     degenerate = cv2.filter2D(img, -1, kernel)
147 |     if factor == 0.0:
148 |         out = degenerate
149 |     elif factor == 1.0:
150 |         out = img
151 |     else:
152 |         out = img.astype(np.float32)
153 |         degenerate = degenerate.astype(np.float32)[1:-1, 1:-1, :]
154 |         out[1:-1, 1:-1, :] = degenerate + factor * (out[1:-1, 1:-1, :] - degenerate)
155 |         out = out.astype(np.uint8)
156 |     return out
157 | 
158 | 
159 | def shear_x_func(img, factor, fill=(0, 0, 0)):
160 |     H, W = img.shape[0], img.shape[1]
161 |     M = np.float32([[1, factor, 0], [0, 1, 0]])
162 |     out = cv2.warpAffine(
163 |         img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR
164 |     ).astype(np.uint8)
165 |     return out
166 | 
167 | 
168 | def translate_x_func(img, offset, fill=(0, 0, 0)):
169 |     """
170 |     same output as PIL.Image.transform
171 |     """
172 |     H, W = img.shape[0], img.shape[1]
173 |     M = np.float32([[1, 0, -offset], [0, 1, 0]])
174 |     out = cv2.warpAffine(
175 |         img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR
176 |     ).astype(np.uint8)
177 |     return out
178 | 
179 | 
180 | def translate_y_func(img, offset, fill=(0, 0, 0)):
181 |     """
182 |     same output as PIL.Image.transform
183 |     """
184 |     H, W = img.shape[0], img.shape[1]
185 |     M = np.float32([[1, 0, 0], [0, 1, -offset]])
186 |     out = cv2.warpAffine(
187 |         img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR
188 |     ).astype(np.uint8)
189 |     return out
190 | 
191 | 
192 | def posterize_func(img, bits):
193 |     """
194 |     same output as PIL.ImageOps.posterize
195 |     """
196 |     out = np.bitwise_and(img, np.uint8(255 << (8 - bits)))
197 |     return out
198 | 
199 | 
200
| def shear_y_func(img, factor, fill=(0, 0, 0)): 201 | H, W = img.shape[0], img.shape[1] 202 | M = np.float32([[1, 0, 0], [factor, 1, 0]]) 203 | out = cv2.warpAffine( 204 | img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR 205 | ).astype(np.uint8) 206 | return out 207 | 208 | 209 | def cutout_func(img, pad_size, replace=(0, 0, 0)): 210 | replace = np.array(replace, dtype=np.uint8) 211 | H, W = img.shape[0], img.shape[1] 212 | rh, rw = np.random.random(2) 213 | pad_size = pad_size // 2 214 | ch, cw = int(rh * H), int(rw * W) 215 | x1, x2 = max(ch - pad_size, 0), min(ch + pad_size, H) 216 | y1, y2 = max(cw - pad_size, 0), min(cw + pad_size, W) 217 | out = img.copy() 218 | out[x1:x2, y1:y2, :] = replace 219 | return out 220 | 221 | 222 | ### level to args 223 | def enhance_level_to_args(MAX_LEVEL): 224 | def level_to_args(level): 225 | return ((level / MAX_LEVEL) * 1.8 + 0.1,) 226 | 227 | return level_to_args 228 | 229 | 230 | def shear_level_to_args(MAX_LEVEL, replace_value): 231 | def level_to_args(level): 232 | level = (level / MAX_LEVEL) * 0.3 233 | if np.random.random() > 0.5: 234 | level = -level 235 | return (level, replace_value) 236 | 237 | return level_to_args 238 | 239 | 240 | def translate_level_to_args(translate_const, MAX_LEVEL, replace_value): 241 | def level_to_args(level): 242 | level = (level / MAX_LEVEL) * float(translate_const) 243 | if np.random.random() > 0.5: 244 | level = -level 245 | return (level, replace_value) 246 | 247 | return level_to_args 248 | 249 | 250 | def cutout_level_to_args(cutout_const, MAX_LEVEL, replace_value): 251 | def level_to_args(level): 252 | level = int((level / MAX_LEVEL) * cutout_const) 253 | return (level, replace_value) 254 | 255 | return level_to_args 256 | 257 | 258 | def solarize_level_to_args(MAX_LEVEL): 259 | def level_to_args(level): 260 | level = int((level / MAX_LEVEL) * 256) 261 | return (level,) 262 | 263 | return level_to_args 264 | 265 | 266 | def none_level_to_args(level): 267 | return () 268 | 269 | 270 | def posterize_level_to_args(MAX_LEVEL): 271 | def level_to_args(level): 272 | level = int((level / MAX_LEVEL) * 4) 273 | return (level,) 274 | 275 | return level_to_args 276 | 277 | 278 | def rotate_level_to_args(MAX_LEVEL, replace_value): 279 | def level_to_args(level): 280 | level = (level / MAX_LEVEL) * 30 281 | if np.random.random() < 0.5: 282 | level = -level 283 | return (level, replace_value) 284 | 285 | return level_to_args 286 | 287 | 288 | func_dict = { 289 | "Identity": identity_func, 290 | "AutoContrast": autocontrast_func, 291 | "Equalize": equalize_func, 292 | "Rotate": rotate_func, 293 | "Solarize": solarize_func, 294 | "Color": color_func, 295 | "Contrast": contrast_func, 296 | "Brightness": brightness_func, 297 | "Sharpness": sharpness_func, 298 | "ShearX": shear_x_func, 299 | "TranslateX": translate_x_func, 300 | "TranslateY": translate_y_func, 301 | "Posterize": posterize_func, 302 | "ShearY": shear_y_func, 303 | } 304 | 305 | translate_const = 10 306 | MAX_LEVEL = 10 307 | replace_value = (128, 128, 128) 308 | arg_dict = { 309 | "Identity": none_level_to_args, 310 | "AutoContrast": none_level_to_args, 311 | "Equalize": none_level_to_args, 312 | "Rotate": rotate_level_to_args(MAX_LEVEL, replace_value), 313 | "Solarize": solarize_level_to_args(MAX_LEVEL), 314 | "Color": enhance_level_to_args(MAX_LEVEL), 315 | "Contrast": enhance_level_to_args(MAX_LEVEL), 316 | "Brightness": enhance_level_to_args(MAX_LEVEL), 317 | "Sharpness": enhance_level_to_args(MAX_LEVEL), 318 | "ShearX": 
shear_level_to_args(MAX_LEVEL, replace_value), 319 | "TranslateX": translate_level_to_args(translate_const, MAX_LEVEL, replace_value), 320 | "TranslateY": translate_level_to_args(translate_const, MAX_LEVEL, replace_value), 321 | "Posterize": posterize_level_to_args(MAX_LEVEL), 322 | "ShearY": shear_level_to_args(MAX_LEVEL, replace_value), 323 | } 324 | 325 | 326 | class RandomAugment(object): 327 | def __init__(self, N=2, M=10, isPIL=False, augs=[]): 328 | self.N = N 329 | self.M = M 330 | self.isPIL = isPIL 331 | if augs: 332 | self.augs = augs 333 | else: 334 | self.augs = list(arg_dict.keys()) 335 | 336 | def get_random_ops(self): 337 | sampled_ops = np.random.choice(self.augs, self.N) 338 | return [(op, 0.5, self.M) for op in sampled_ops] 339 | 340 | def __call__(self, img): 341 | if self.isPIL: 342 | img = np.array(img) 343 | ops = self.get_random_ops() 344 | for name, prob, level in ops: 345 | if np.random.random() > prob: 346 | continue 347 | args = arg_dict[name](level) 348 | img = func_dict[name](img, *args) 349 | return img 350 | 351 | 352 | class VideoRandomAugment(object): 353 | def __init__(self, N=2, M=10, p=0.0, tensor_in_tensor_out=True, augs=[]): 354 | self.N = N 355 | self.M = M 356 | self.p = p 357 | self.tensor_in_tensor_out = tensor_in_tensor_out 358 | if augs: 359 | self.augs = augs 360 | else: 361 | self.augs = list(arg_dict.keys()) 362 | 363 | def get_random_ops(self): 364 | sampled_ops = np.random.choice(self.augs, self.N, replace=False) 365 | return [(op, self.M) for op in sampled_ops] 366 | 367 | def __call__(self, frames): 368 | assert ( 369 | frames.shape[-1] == 3 370 | ), "Expecting last dimension for 3-channels RGB (b, h, w, c)." 371 | 372 | if self.tensor_in_tensor_out: 373 | frames = frames.numpy().astype(np.uint8) 374 | 375 | num_frames = frames.shape[0] 376 | 377 | ops = num_frames * [self.get_random_ops()] 378 | apply_or_not = num_frames * [np.random.random(size=self.N) > self.p] 379 | 380 | frames = torch.stack( 381 | list(map(self._aug, frames, ops, apply_or_not)), dim=0 382 | ).float() 383 | 384 | return frames 385 | 386 | def _aug(self, img, ops, apply_or_not): 387 | for i, (name, level) in enumerate(ops): 388 | if not apply_or_not[i]: 389 | continue 390 | args = arg_dict[name](level) 391 | img = func_dict[name](img, *args) 392 | return torch.from_numpy(img) 393 | 394 | 395 | if __name__ == "__main__": 396 | a = RandomAugment() 397 | img = np.random.randn(32, 32, 3) 398 | a(img) 399 | -------------------------------------------------------------------------------- /minigpt4/common/logger.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import datetime 9 | import logging 10 | import time 11 | from collections import defaultdict, deque 12 | 13 | import torch 14 | import torch.distributed as dist 15 | 16 | from minigpt4.common import dist_utils 17 | from sklearn.metrics import roc_auc_score 18 | 19 | 20 | class SmoothedValue(object): 21 | """Track a series of values and provide access to smoothed values over a 22 | window or the global series average. 
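    Illustrative example (assumed usage, consistent with the methods below):

        v = SmoothedValue(window_size=2, fmt="{avg:.2f}")
        v.update(1.0)
        v.update(3.0)  # windowed avg -> 2.00, global_avg -> 2.0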
23 | """ 24 | 25 | def __init__(self, window_size=20, fmt=None): 26 | if fmt is None: 27 | fmt = "{median:.4f} ({global_avg:.4f})" 28 | self.deque = deque(maxlen=window_size) 29 | self.total = 0.0 30 | self.count = 0 31 | self.fmt = fmt 32 | 33 | def update(self, value, n=1): 34 | self.deque.append(value) 35 | self.count += n 36 | self.total += value * n 37 | 38 | def synchronize_between_processes(self): 39 | """ 40 | Warning: does not synchronize the deque! 41 | """ 42 | if not dist_utils.is_dist_avail_and_initialized(): 43 | return 44 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda") 45 | dist.barrier() 46 | dist.all_reduce(t) 47 | t = t.tolist() 48 | self.count = int(t[0]) 49 | self.total = t[1] 50 | 51 | @property 52 | def median(self): 53 | d = torch.tensor(list(self.deque)) 54 | return d.median().item() 55 | 56 | @property 57 | def avg(self): 58 | d = torch.tensor(list(self.deque), dtype=torch.float32) 59 | return d.mean().item() 60 | 61 | @property 62 | def global_avg(self): 63 | return self.total / self.count 64 | 65 | @property 66 | def max(self): 67 | return max(self.deque) 68 | 69 | @property 70 | def value(self): 71 | return self.deque[-1] 72 | 73 | def __str__(self): 74 | return self.fmt.format( 75 | median=self.median, 76 | avg=self.avg, 77 | global_avg=self.global_avg, 78 | max=self.max, 79 | value=self.value, 80 | ) 81 | 82 | 83 | class MetricLogger(object): 84 | def __init__(self, delimiter="\t"): 85 | self.meters = defaultdict(SmoothedValue) 86 | self.delimiter = delimiter 87 | 88 | def update(self, **kwargs): 89 | for k, v in kwargs.items(): 90 | if isinstance(v, torch.Tensor): 91 | v = v.item() 92 | assert isinstance(v, (float, int)) 93 | self.meters[k].update(v) 94 | 95 | def __getattr__(self, attr): 96 | if attr in self.meters: 97 | return self.meters[attr] 98 | if attr in self.__dict__: 99 | return self.__dict__[attr] 100 | raise AttributeError( 101 | "'{}' object has no attribute '{}'".format(type(self).__name__, attr) 102 | ) 103 | 104 | def __str__(self): 105 | loss_str = [] 106 | for name, meter in self.meters.items(): 107 | loss_str.append("{}: {}".format(name, str(meter))) 108 | return self.delimiter.join(loss_str) 109 | 110 | def global_avg(self): 111 | loss_str = [] 112 | for name, meter in self.meters.items(): 113 | loss_str.append("{}: {:.6f}".format(name, meter.global_avg)) 114 | return self.delimiter.join(loss_str) 115 | 116 | def synchronize_between_processes(self): 117 | for meter in self.meters.values(): 118 | meter.synchronize_between_processes() 119 | 120 | def add_meter(self, name, meter): 121 | self.meters[name] = meter 122 | 123 | def log_every(self, iterable, print_freq, header=None): 124 | i = 0 125 | if not header: 126 | header = "" 127 | start_time = time.time() 128 | end = time.time() 129 | iter_time = SmoothedValue(fmt="{avg:.4f}") 130 | data_time = SmoothedValue(fmt="{avg:.4f}") 131 | space_fmt = ":" + str(len(str(len(iterable)))) + "d" 132 | log_msg = [ 133 | header, 134 | "[{0" + space_fmt + "}/{1}]", 135 | "eta: {eta}", 136 | "{meters}", 137 | "time: {time}", 138 | "data: {data}", 139 | ] 140 | if torch.cuda.is_available(): 141 | log_msg.append("max mem: {memory:.0f}") 142 | log_msg = self.delimiter.join(log_msg) 143 | MB = 1024.0 * 1024.0 144 | for obj in iterable: 145 | data_time.update(time.time() - end) 146 | yield obj 147 | iter_time.update(time.time() - end) 148 | if i % print_freq == 0 or i == len(iterable) - 1: 149 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 150 | eta_string = 
str(datetime.timedelta(seconds=int(eta_seconds))) 151 | if torch.cuda.is_available(): 152 | print( 153 | log_msg.format( 154 | i, 155 | len(iterable), 156 | eta=eta_string, 157 | meters=str(self), 158 | time=str(iter_time), 159 | data=str(data_time), 160 | memory=torch.cuda.max_memory_allocated() / MB, 161 | ) 162 | ) 163 | else: 164 | print( 165 | log_msg.format( 166 | i, 167 | len(iterable), 168 | eta=eta_string, 169 | meters=str(self), 170 | time=str(iter_time), 171 | data=str(data_time), 172 | ) 173 | ) 174 | i += 1 175 | end = time.time() 176 | total_time = time.time() - start_time 177 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 178 | print( 179 | "{} Total time: {} ({:.4f} s / it)".format( 180 | header, total_time_str, total_time / len(iterable) 181 | ) 182 | ) 183 | 184 | 185 | 186 | class AttrDict(dict): 187 | def __init__(self, *args, **kwargs): 188 | super(AttrDict, self).__init__(*args, **kwargs) 189 | self.__dict__ = self 190 | 191 | 192 | def setup_logger(): 193 | logging.basicConfig( 194 | level=logging.INFO if dist_utils.is_main_process() else logging.WARN, 195 | format="%(asctime)s [%(levelname)s] %(message)s", 196 | handlers=[logging.StreamHandler()], 197 | ) 198 | 199 | 200 | class SmoothedValue_v2(object): 201 | """Track a series of values and provide access to smoothed values over a 202 | window or the global series average. 203 | """ 204 | 205 | def __init__(self, window_size=20, fmt=None): 206 | if fmt is None: 207 | fmt = "{median:.4f} ({global_avg:.4f})" 208 | self.deque = deque(maxlen=window_size) 209 | self.total = [] 210 | self.count = 0 211 | self.fmt = fmt 212 | 213 | def update(self, value, n=1): 214 | self.deque.append(value) 215 | self.count += n 216 | self.total.extend(value * n) 217 | 218 | def synchronize_between_processes(self): 219 | """ 220 | Warning: does not synchronize the deque! 
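        Only ``count`` and the accumulated ``total`` list take part in the
        all-reduce; the windowed statistics (median/avg/max/value) remain
        local to each rank.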
221 |         """
222 |         if not dist_utils.is_dist_avail_and_initialized():
223 |             return
224 |         t1 = torch.tensor([self.count], dtype=torch.float64, device="cuda")
225 |         t2 = torch.tensor(self.total, dtype=torch.float64, device="cuda")
226 |         t = torch.cat([t1, t2])
227 |         dist.barrier()
228 |         dist.all_reduce(t)
229 |         t = t.tolist()
230 |         self.count = int(t[0])
231 |         self.total = t[1:]
232 | 
233 |     @property
234 |     def median(self):
235 |         d = torch.tensor(list(self.deque))
236 |         return d.median().item()
237 | 
238 |     @property
239 |     def avg(self):
240 |         d = torch.tensor(list(self.deque), dtype=torch.float32)
241 |         return d.mean().item()
242 | 
243 |     @property
244 |     def global_avg(self):
245 |         return sum(self.total) / self.count
246 | 
247 |     @property
248 |     def max(self):
249 |         return max(self.deque)
250 | 
251 |     @property
252 |     def value(self):
253 |         return self.deque[-1]
254 | 
255 |     def __str__(self):
256 |         return self.fmt.format(
257 |             median=self.median,
258 |             avg=self.avg,
259 |             global_avg=self.global_avg,
260 |             max=self.max,
261 |             value=self.value,
262 |         )
263 | 
264 | 
265 | class MetricLogger_auc(object):
266 |     def __init__(self, delimiter="\t"):
267 |         self.meters = defaultdict(SmoothedValue_v2)
268 |         self.delimiter = delimiter
269 | 
270 |     def update(self, **kwargs):
271 |         for k, v in kwargs.items():
272 |             if isinstance(v, torch.Tensor):
273 |                 try:
274 |                     v = v.item()
275 |                     assert isinstance(v, (float, int))
276 |                 except Exception:
277 |                     v = v.detach().cpu().numpy()
278 |             self.meters[k].update(v)
279 | 
280 |     def __getattr__(self, attr):
281 |         if attr in self.meters:
282 |             return self.meters[attr]
283 |         if attr in self.__dict__:
284 |             return self.__dict__[attr]
285 |         raise AttributeError(
286 |             "'{}' object has no attribute '{}'".format(type(self).__name__, attr)
287 |         )
288 | 
289 |     def __str__(self):
290 |         loss_str = []
291 |         for name, meter in self.meters.items():
292 |             loss_str.append("{}: {}".format(name, str(meter)))
293 |         return self.delimiter.join(loss_str)
294 | 
295 |     def global_avg(self):
296 |         loss_str = []
297 |         for name, meter in self.meters.items():
298 |             loss_str.append("{}: {:.4f}".format(name, meter.global_avg))
299 |         return self.delimiter.join(loss_str)
300 | 
301 |     def global_report(self):
302 |         loss_str = []
303 |         for name, meter in self.meters.items():
304 |             if name not in ['logits', 'labels']:
305 |                 loss_str.append("{}: {:.4f}".format(name, meter.global_avg))
306 |         auc = roc_auc_score(self.meters['labels'].total, self.meters['logits'].total)
307 |         loss_str.append("{}: {:.4f}".format('auc', auc))
308 |         return self.delimiter.join(loss_str)
309 | 
310 |     def synchronize_between_processes(self):
311 |         for meter in self.meters.values():
312 |             meter.synchronize_between_processes()
313 | 
314 |     def add_meter(self, name, meter):
315 |         self.meters[name] = meter
316 | 
317 |     def log_every(self, iterable, print_freq, header=None):
318 |         i = 0
319 |         if not header:
320 |             header = ""
321 |         start_time = time.time()
322 |         end = time.time()
323 |         iter_time = SmoothedValue(fmt="{avg:.4f}")
324 |         data_time = SmoothedValue(fmt="{avg:.4f}")
325 |         space_fmt = ":" + str(len(str(len(iterable)))) + "d"
326 |         log_msg = [
327 |             header,
328 |             "[{0" + space_fmt + "}/{1}]",
329 |             "eta: {eta}",
330 |             "{meters}",
331 |             "time: {time}",
332 |             "data: {data}",
333 |         ]
334 |         if torch.cuda.is_available():
335 |             log_msg.append("max mem: {memory:.0f}")
336 |         log_msg = self.delimiter.join(log_msg)
337 |         MB = 1024.0 * 1024.0
338 |         for obj in iterable:
339 |             data_time.update(time.time() - end)
340 |             yield obj
341 |             iter_time.update(time.time() - end)
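            # Print a progress line every `print_freq` iterations and on the
            # final iteration (same cadence as MetricLogger.log_every above).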
342 | if i % print_freq == 0 or i == len(iterable) - 1: 343 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 344 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 345 | if torch.cuda.is_available(): 346 | print( 347 | log_msg.format( 348 | i, 349 | len(iterable), 350 | eta=eta_string, 351 | meters=str(self), 352 | time=str(iter_time), 353 | data=str(data_time), 354 | memory=torch.cuda.max_memory_allocated() / MB, 355 | ) 356 | ) 357 | else: 358 | print( 359 | log_msg.format( 360 | i, 361 | len(iterable), 362 | eta=eta_string, 363 | meters=str(self), 364 | time=str(iter_time), 365 | data=str(data_time), 366 | ) 367 | ) 368 | i += 1 369 | end = time.time() 370 | total_time = time.time() - start_time 371 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 372 | print( 373 | "{} Total time: {} ({:.4f} s / it)".format( 374 | header, total_time_str, total_time / len(iterable) 375 | ) 376 | ) -------------------------------------------------------------------------------- /minigpt4/common/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import io 9 | import json 10 | import logging 11 | import os 12 | import pickle 13 | import re 14 | import shutil 15 | import urllib 16 | import urllib.error 17 | import urllib.request 18 | from typing import Optional 19 | from urllib.parse import urlparse 20 | 21 | import numpy as np 22 | import pandas as pd 23 | import yaml 24 | from iopath.common.download import download 25 | from iopath.common.file_io import file_lock, g_pathmgr 26 | from minigpt4.common.registry import registry 27 | from torch.utils.model_zoo import tqdm 28 | from torchvision.datasets.utils import ( 29 | check_integrity, 30 | download_file_from_google_drive, 31 | extract_archive, 32 | ) 33 | 34 | 35 | def now(): 36 | from datetime import datetime 37 | 38 | return datetime.now().strftime("%Y%m%d%H%M")[:-1] 39 | 40 | 41 | def is_url(url_or_filename): 42 | parsed = urlparse(url_or_filename) 43 | return parsed.scheme in ("http", "https") 44 | 45 | 46 | def get_cache_path(rel_path): 47 | return os.path.expanduser(os.path.join(registry.get_path("cache_root"), rel_path)) 48 | 49 | 50 | def get_abs_path(rel_path): 51 | return os.path.join(registry.get_path("library_root"), rel_path) 52 | 53 | 54 | def load_json(filename): 55 | with open(filename, "r") as f: 56 | return json.load(f) 57 | 58 | 59 | # The following are adapted from torchvision and vissl 60 | # torchvision: https://github.com/pytorch/vision 61 | # vissl: https://github.com/facebookresearch/vissl/blob/main/vissl/utils/download.py 62 | 63 | 64 | def makedir(dir_path): 65 | """ 66 | Create the directory if it does not exist. 
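    Returns True if the directory exists or was created successfully,
    False if an exception was raised while creating it.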
67 | """ 68 | is_success = False 69 | try: 70 | if not g_pathmgr.exists(dir_path): 71 | g_pathmgr.mkdirs(dir_path) 72 | is_success = True 73 | except BaseException: 74 | print(f"Error creating directory: {dir_path}") 75 | return is_success 76 | 77 | 78 | def get_redirected_url(url: str): 79 | """ 80 | Given a URL, returns the URL it redirects to or the 81 | original URL in case of no indirection 82 | """ 83 | import requests 84 | 85 | with requests.Session() as session: 86 | with session.get(url, stream=True, allow_redirects=True) as response: 87 | if response.history: 88 | return response.url 89 | else: 90 | return url 91 | 92 | 93 | def to_google_drive_download_url(view_url: str) -> str: 94 | """ 95 | Utility function to transform a view URL of google drive 96 | to a download URL for google drive 97 | Example input: 98 | https://drive.google.com/file/d/137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp/view 99 | Example output: 100 | https://drive.google.com/uc?export=download&id=137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp 101 | """ 102 | splits = view_url.split("/") 103 | assert splits[-1] == "view" 104 | file_id = splits[-2] 105 | return f"https://drive.google.com/uc?export=download&id={file_id}" 106 | 107 | 108 | def download_google_drive_url(url: str, output_path: str, output_file_name: str): 109 | """ 110 | Download a file from google drive 111 | Downloading an URL from google drive requires confirmation when 112 | the file of the size is too big (google drive notifies that 113 | anti-viral checks cannot be performed on such files) 114 | """ 115 | import requests 116 | 117 | with requests.Session() as session: 118 | 119 | # First get the confirmation token and append it to the URL 120 | with session.get(url, stream=True, allow_redirects=True) as response: 121 | for k, v in response.cookies.items(): 122 | if k.startswith("download_warning"): 123 | url = url + "&confirm=" + v 124 | 125 | # Then download the content of the file 126 | with session.get(url, stream=True, verify=True) as response: 127 | makedir(output_path) 128 | path = os.path.join(output_path, output_file_name) 129 | total_size = int(response.headers.get("Content-length", 0)) 130 | with open(path, "wb") as file: 131 | from tqdm import tqdm 132 | 133 | with tqdm(total=total_size) as progress_bar: 134 | for block in response.iter_content( 135 | chunk_size=io.DEFAULT_BUFFER_SIZE 136 | ): 137 | file.write(block) 138 | progress_bar.update(len(block)) 139 | 140 | 141 | def _get_google_drive_file_id(url: str) -> Optional[str]: 142 | parts = urlparse(url) 143 | 144 | if re.match(r"(drive|docs)[.]google[.]com", parts.netloc) is None: 145 | return None 146 | 147 | match = re.match(r"/file/d/(?P[^/]*)", parts.path) 148 | if match is None: 149 | return None 150 | 151 | return match.group("id") 152 | 153 | 154 | def _urlretrieve(url: str, filename: str, chunk_size: int = 1024) -> None: 155 | with open(filename, "wb") as fh: 156 | with urllib.request.urlopen( 157 | urllib.request.Request(url, headers={"User-Agent": "vissl"}) 158 | ) as response: 159 | with tqdm(total=response.length) as pbar: 160 | for chunk in iter(lambda: response.read(chunk_size), ""): 161 | if not chunk: 162 | break 163 | pbar.update(chunk_size) 164 | fh.write(chunk) 165 | 166 | 167 | def download_url( 168 | url: str, 169 | root: str, 170 | filename: Optional[str] = None, 171 | md5: Optional[str] = None, 172 | ) -> None: 173 | """Download a file from a url and place it in root. 
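    Redirects are expanded first; Google Drive links are delegated to
    download_file_from_google_drive, and a failed https download is retried
    once over plain http.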
174 | Args: 175 | url (str): URL to download file from 176 | root (str): Directory to place downloaded file in 177 | filename (str, optional): Name to save the file under. 178 | If None, use the basename of the URL. 179 | md5 (str, optional): MD5 checksum of the download. If None, do not check 180 | """ 181 | root = os.path.expanduser(root) 182 | if not filename: 183 | filename = os.path.basename(url) 184 | fpath = os.path.join(root, filename) 185 | 186 | makedir(root) 187 | 188 | # check if file is already present locally 189 | if check_integrity(fpath, md5): 190 | print("Using downloaded and verified file: " + fpath) 191 | return 192 | 193 | # expand redirect chain if needed 194 | url = get_redirected_url(url) 195 | 196 | # check if file is located on Google Drive 197 | file_id = _get_google_drive_file_id(url) 198 | if file_id is not None: 199 | return download_file_from_google_drive(file_id, root, filename, md5) 200 | 201 | # download the file 202 | try: 203 | print("Downloading " + url + " to " + fpath) 204 | _urlretrieve(url, fpath) 205 | except (urllib.error.URLError, IOError) as e: # type: ignore[attr-defined] 206 | if url[:5] == "https": 207 | url = url.replace("https:", "http:") 208 | print( 209 | "Failed download. Trying https -> http instead." 210 | " Downloading " + url + " to " + fpath 211 | ) 212 | _urlretrieve(url, fpath) 213 | else: 214 | raise e 215 | 216 | # check integrity of downloaded file 217 | if not check_integrity(fpath, md5): 218 | raise RuntimeError("File not found or corrupted.") 219 | 220 | 221 | def download_and_extract_archive( 222 | url: str, 223 | download_root: str, 224 | extract_root: Optional[str] = None, 225 | filename: Optional[str] = None, 226 | md5: Optional[str] = None, 227 | remove_finished: bool = False, 228 | ) -> None: 229 | download_root = os.path.expanduser(download_root) 230 | if extract_root is None: 231 | extract_root = download_root 232 | if not filename: 233 | filename = os.path.basename(url) 234 | 235 | download_url(url, download_root, filename, md5) 236 | 237 | archive = os.path.join(download_root, filename) 238 | print("Extracting {} to {}".format(archive, extract_root)) 239 | extract_archive(archive, extract_root, remove_finished) 240 | 241 | 242 | def cache_url(url: str, cache_dir: str) -> str: 243 | """ 244 | This implementation downloads the remote resource and caches it locally. 245 | The resource will only be downloaded if not previously requested. 246 | """ 247 | parsed_url = urlparse(url) 248 | dirname = os.path.join(cache_dir, os.path.dirname(parsed_url.path.lstrip("/"))) 249 | makedir(dirname) 250 | filename = url.split("/")[-1] 251 | cached = os.path.join(dirname, filename) 252 | with file_lock(cached): 253 | if not os.path.isfile(cached): 254 | logging.info(f"Downloading {url} to {cached} ...") 255 | cached = download(url, dirname, filename=filename) 256 | logging.info(f"URL {url} cached in {cached}") 257 | return cached 258 | 259 | 260 | # TODO (prigoyal): convert this into RAII-style API 261 | def create_file_symlink(file1, file2): 262 | """ 263 | Simply create the symlinks for a given file1 to file2. 264 | Useful during model checkpointing to symlinks to the 265 | latest successful checkpoint. 266 | """ 267 | try: 268 | if g_pathmgr.exists(file2): 269 | g_pathmgr.rm(file2) 270 | g_pathmgr.symlink(file1, file2) 271 | except Exception as e: 272 | logging.info(f"Could NOT create symlink. 
# TODO (prigoyal): convert this into RAII-style API
def create_file_symlink(file1, file2):
    """
    Simply create a symlink from file1 to file2. Useful during model
    checkpointing for pointing a stable name at the latest successful
    checkpoint.
    """
    try:
        if g_pathmgr.exists(file2):
            g_pathmgr.rm(file2)
        g_pathmgr.symlink(file1, file2)
    except Exception as e:
        logging.info(f"Could NOT create symlink. Error: {e}")


def save_file(data, filename, append_to_json=True, verbose=True):
    """
    Common i/o utility to handle saving data to various file formats.
    Supported:
        .pkl, .pickle, .npy, .json, .yaml
    Specifically for .json, users have the option to either append (default)
    or rewrite by passing in a Boolean value for append_to_json.
    """
    if verbose:
        logging.info(f"Saving data to file: {filename}")
    file_ext = os.path.splitext(filename)[1]
    if file_ext in [".pkl", ".pickle"]:
        with g_pathmgr.open(filename, "wb") as fopen:
            pickle.dump(data, fopen, pickle.HIGHEST_PROTOCOL)
    elif file_ext == ".npy":
        with g_pathmgr.open(filename, "wb") as fopen:
            np.save(fopen, data)
    elif file_ext == ".json":
        if append_to_json:
            with g_pathmgr.open(filename, "a") as fopen:
                fopen.write(json.dumps(data, sort_keys=True) + "\n")
                fopen.flush()
        else:
            with g_pathmgr.open(filename, "w") as fopen:
                fopen.write(json.dumps(data, sort_keys=True) + "\n")
                fopen.flush()
    elif file_ext == ".yaml":
        with g_pathmgr.open(filename, "w") as fopen:
            dump = yaml.dump(data)
            fopen.write(dump)
            fopen.flush()
    else:
        raise Exception(f"Saving {file_ext} is not supported yet")

    if verbose:
        logging.info(f"Saved data to file: {filename}")
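# Illustrative sketch (not part of the original file): a save/load round trip
# through the helpers above. The filename and values are hypothetical.
#
#   save_file({"epoch": 3, "loss": 0.21}, "stats.json", append_to_json=False)
#   load_file("stats.json")
#   # -> {"epoch": 3, "loss": 0.21}
#
# Note that with append_to_json=True (the default), save_file appends one JSON
# object per line, which the json.load call in load_file can only parse while
# the file holds a single record.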
Trying without mmap") 355 | with g_pathmgr.open(filename, "rb") as fopen: 356 | data = np.load(fopen, allow_pickle=allow_pickle, encoding="latin1") 357 | else: 358 | with g_pathmgr.open(filename, "rb") as fopen: 359 | data = np.load(fopen, allow_pickle=allow_pickle, encoding="latin1") 360 | elif file_ext == ".json": 361 | with g_pathmgr.open(filename, "r") as fopen: 362 | data = json.load(fopen) 363 | elif file_ext == ".yaml": 364 | with g_pathmgr.open(filename, "r") as fopen: 365 | data = yaml.load(fopen, Loader=yaml.FullLoader) 366 | elif file_ext == ".csv": 367 | with g_pathmgr.open(filename, "r") as fopen: 368 | data = pd.read_csv(fopen) 369 | else: 370 | raise Exception(f"Reading from {file_ext} is not supported yet") 371 | return data 372 | 373 | 374 | def abspath(resource_path: str): 375 | """ 376 | Make a path absolute, but take into account prefixes like 377 | "http://" or "manifold://" 378 | """ 379 | regex = re.compile(r"^\w+://") 380 | if regex.match(resource_path) is None: 381 | return os.path.abspath(resource_path) 382 | else: 383 | return resource_path 384 | 385 | 386 | def makedir(dir_path): 387 | """ 388 | Create the directory if it does not exist. 389 | """ 390 | is_success = False 391 | try: 392 | if not g_pathmgr.exists(dir_path): 393 | g_pathmgr.mkdirs(dir_path) 394 | is_success = True 395 | except BaseException: 396 | logging.info(f"Error creating directory: {dir_path}") 397 | return is_success 398 | 399 | 400 | def is_url(input_url): 401 | """ 402 | Check if an input string is a url. look for http(s):// and ignoring the case 403 | """ 404 | is_url = re.match(r"^(?:http)s?://", input_url, re.IGNORECASE) is not None 405 | return is_url 406 | 407 | 408 | def cleanup_dir(dir): 409 | """ 410 | Utility for deleting a directory. Useful for cleaning the storage space 411 | that contains various training artifacts like checkpoints, data etc. 412 | """ 413 | if os.path.exists(dir): 414 | logging.info(f"Deleting directory: {dir}") 415 | shutil.rmtree(dir) 416 | logging.info(f"Deleted contents of directory: {dir}") 417 | 418 | 419 | def get_file_size(filename): 420 | """ 421 | Given a file, get the size of file in MB 422 | """ 423 | size_in_mb = os.path.getsize(filename) / float(1024**2) 424 | return size_in_mb 425 | -------------------------------------------------------------------------------- /minigpt4/common/config.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 
--------------------------------------------------------------------------------
/minigpt4/common/config.py:
--------------------------------------------------------------------------------
"""
 Copyright (c) 2022, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import logging
import json
from typing import Dict

from omegaconf import OmegaConf
from minigpt4.common.registry import registry


class Config:
    def __init__(self, args):
        self.config = {}

        self.args = args

        # Register this configuration object so other components can retrieve
        # it through the registry.
        registry.register("configuration", self)

        user_config = self._build_opt_list(self.args.options)

        config = OmegaConf.load(self.args.cfg_path)

        runner_config = self.build_runner_config(config)
        model_config = self.build_model_config(config, **user_config)
        dataset_config = self.build_dataset_config(config)

        # Validate the user-provided runner configuration.
        # Model and dataset configurations are supposed to be validated by the
        # respective classes.
        # [TODO] validate the model/dataset configuration
        # self._validate_runner_config(runner_config)

        # Override the default configuration with user options.
        self.config = OmegaConf.merge(
            runner_config, model_config, dataset_config, user_config
        )

    def _validate_runner_config(self, runner_config):
        """
        This method validates the configuration, such that
            1) all the user-specified options are valid;
            2) there are no type mismatches between the user-specified options and the config.
        """
        runner_config_validator = create_runner_config_validator()
        runner_config_validator.validate(runner_config)

    def _build_opt_list(self, opts):
        opts_dot_list = self._convert_to_dot_list(opts)
        return OmegaConf.from_dotlist(opts_dot_list)

    @staticmethod
    def build_model_config(config, **kwargs):
        model = config.get("model", None)
        assert model is not None, "Missing model configuration file."

        model_cls = registry.get_model_class(model.arch)
        assert model_cls is not None, f"Model '{model.arch}' has not been registered."

        model_type = kwargs.get("model.model_type", None)
        if not model_type:
            model_type = model.get("model_type", None)
        # else use the model type selected by the user.

        assert model_type is not None, "Missing model_type."

        model_config_path = model_cls.default_config_path(model_type=model_type)

        model_config = OmegaConf.create()
        # hierarchy override: customized config > default config
        model_config = OmegaConf.merge(
            model_config,
            OmegaConf.load(model_config_path),
            {"model": config["model"]},
        )

        return model_config

    @staticmethod
    def build_runner_config(config):
        return {"run": config.run}
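    # Illustrative sketch (not part of the original file): the minimal YAML
    # shape the builders above expect -- a "model" node with arch/model_type,
    # a "run" node, and a "datasets" node keyed by registered builder names.
    # The concrete arch and run values are hypothetical.
    #
    #   model:
    #     arch: mini_gpt4rec
    #     model_type: default
    #   run:
    #     task: rec_pretrain
    #     max_epoch: 10
    #   datasets:
    #     movielens:
    #       type: default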
    @staticmethod
    def build_dataset_config(config):
        datasets = config.get("datasets", None)
        if datasets is None:
            raise KeyError(
                "Expecting 'datasets' as the root key for dataset configuration."
            )

        dataset_config = OmegaConf.create()

        for dataset_name in datasets:
            builder_cls = registry.get_builder_class(dataset_name)

            dataset_config_type = datasets[dataset_name].get("type", "default")
            dataset_config_path = builder_cls.default_config_path(
                type=dataset_config_type
            )

            # hierarchy override: customized config > default config
            dataset_config = OmegaConf.merge(
                dataset_config,
                OmegaConf.load(dataset_config_path),
                {"datasets": {dataset_name: config["datasets"][dataset_name]}},
            )

        return dataset_config

    def _convert_to_dot_list(self, opts):
        if opts is None:
            opts = []

        if len(opts) == 0:
            return opts

        has_equal = opts[0].find("=") != -1

        if has_equal:
            return opts

        return [(opt + "=" + value) for opt, value in zip(opts[0::2], opts[1::2])]

    def get_config(self):
        return self.config

    @property
    def run_cfg(self):
        return self.config.run

    @property
    def datasets_cfg(self):
        return self.config.datasets

    @property
    def model_cfg(self):
        return self.config.model

    def pretty_print(self):
        logging.info("\n=====  Running Parameters    =====")
        logging.info(self._convert_node_to_json(self.config.run))

        logging.info("\n======  Dataset Attributes  ======")
        datasets = self.config.datasets

        for dataset in datasets:
            if dataset in self.config.datasets:
                logging.info(f"\n======== {dataset} =======")
                dataset_config = self.config.datasets[dataset]
                logging.info(self._convert_node_to_json(dataset_config))
            else:
                logging.warning(f"No dataset named '{dataset}' in config. Skipping")

        logging.info("\n======  Model Attributes  ======")
        logging.info(self._convert_node_to_json(self.config.model))

    def _convert_node_to_json(self, node):
        container = OmegaConf.to_container(node, resolve=True)
        return json.dumps(container, indent=4, sort_keys=True)

    def to_dict(self):
        return OmegaConf.to_container(self.config)


def node_to_dict(node):
    return OmegaConf.to_container(node)
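# Illustrative sketch (not part of the original file): _convert_to_dot_list
# pairs up space-separated CLI options so OmegaConf.from_dotlist can parse
# them. The keys and values are hypothetical.
#
#   ["run.max_epoch", "10", "model.model_type", "default"]
#   # -> ["run.max_epoch=10", "model.model_type=default"]
#
# Options already in "key=value" form are passed through unchanged.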
class ConfigValidator:
    """
    This is a preliminary implementation to centralize and validate the configuration.
    May be altered in the future.

    A helper class to validate configurations from yaml file.

    This serves the following purposes:
        1. Ensure all the options in the yaml are defined; raise an error if not.
        2. When type mismatches are found, the validator will raise an error.
        3. A central place to store and display helpful messages for supported configurations.
    """

    class _Argument:
        def __init__(self, name, choices=None, type=None, help=None):
            self.name = name
            self.val = None
            self.choices = choices
            self.type = type
            self.help = help

        def __str__(self):
            s = f"{self.name}={self.val}"
            if self.type is not None:
                s += f", ({self.type})"
            if self.choices is not None:
                s += f", choices: {self.choices}"
            if self.help is not None:
                s += f", ({self.help})"
            return s

    def __init__(self, description):
        self.description = description

        self.arguments = dict()

        self.parsed_args = None

    def __getitem__(self, key):
        assert self.parsed_args is not None, "No arguments parsed yet."

        return self.parsed_args[key]

    def __str__(self) -> str:
        return self.format_help()

    def add_argument(self, *args, **kwargs):
        """
        Assume the first argument is the name of the argument.
        """
        self.arguments[args[0]] = self._Argument(*args, **kwargs)

    def validate(self, config=None):
        """
        Validate the yaml config (dict-like) against the registered arguments.
        """
        for k, v in config.items():
            assert (
                k in self.arguments
            ), f"""{k} is not a valid argument. Supported arguments are {self.format_arguments()}."""

            if self.arguments[k].type is not None:
                try:
                    self.arguments[k].val = self.arguments[k].type(v)
                except ValueError:
                    raise ValueError(f"{k} is not a valid {self.arguments[k].type}.")

            if self.arguments[k].choices is not None:
                assert (
                    v in self.arguments[k].choices
                ), f"""{k} must be one of {self.arguments[k].choices}."""

        return config

    def format_arguments(self):
        return str([f"{k}" for k in sorted(self.arguments.keys())])

    def format_help(self):
        # description + key-value pair string for each argument
        help_msg = str(self.description)
        return help_msg + ", available arguments: " + self.format_arguments()

    def print_help(self):
        # display help message
        print(self.format_help())


def create_runner_config_validator():
    validator = ConfigValidator(description="Runner configurations")

    validator.add_argument(
        "runner",
        type=str,
        choices=["runner_base", "runner_iter"],
        help="""Runner to use. The "runner_base" uses epoch-based training while iter-based
            runner runs based on iters. Default: runner_base""",
    )
    # add arguments for training dataset ratios
    validator.add_argument(
        "train_dataset_ratios",
        type=Dict[str, float],
        help="""Ratios of training dataset. This is used in iteration-based runner.
            Not supported for the epoch-based runner, because defining an epoch becomes tricky.
            Default: None""",
    )
    validator.add_argument(
        "max_iters",
        type=float,
        help="Maximum number of iterations to run.",
    )
    validator.add_argument(
        "max_epoch",
        type=int,
        help="Maximum number of epochs to run.",
    )
    # add arguments for iters_per_inner_epoch
    validator.add_argument(
        "iters_per_inner_epoch",
        type=float,
        help="Number of iterations per inner epoch. This is required when runner is runner_iter.",
    )
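    # Illustrative sketch (not part of the original file): how ConfigValidator
    # behaves once arguments are registered. The values are hypothetical.
    #
    #   v = ConfigValidator(description="demo")
    #   v.add_argument("max_epoch", type=int, help="Maximum number of epochs.")
    #   v.validate({"max_epoch": 10})   # ok, returns the config
    #   v.validate({"unknown_key": 1})  # raises AssertionError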
    lr_scheds_choices = registry.list_lr_schedulers()
    validator.add_argument(
        "lr_sched",
        type=str,
        choices=lr_scheds_choices,
        help="Learning rate scheduler to use, from {}".format(lr_scheds_choices),
    )
    task_choices = registry.list_tasks()
    validator.add_argument(
        "task",
        type=str,
        choices=task_choices,
        help="Task to use, from {}".format(task_choices),
    )
    # add arguments for init_lr
    validator.add_argument(
        "init_lr",
        type=float,
        help="Initial learning rate. This will be the learning rate after warmup and before decay.",
    )
    # add arguments for min_lr
    validator.add_argument(
        "min_lr",
        type=float,
        help="Minimum learning rate (after decay).",
    )
    # add arguments for warmup_lr
    validator.add_argument(
        "warmup_lr",
        type=float,
        help="Starting learning rate for warmup.",
    )
    # add arguments for learning rate decay rate
    validator.add_argument(
        "lr_decay_rate",
        type=float,
        help="Learning rate decay rate. Required if using a decaying learning rate scheduler.",
    )
    # add arguments for weight decay
    validator.add_argument(
        "weight_decay",
        type=float,
        help="Weight decay rate.",
    )
    # add arguments for training batch size
    validator.add_argument(
        "batch_size_train",
        type=int,
        help="Training batch size.",
    )
    # add arguments for evaluation batch size
    validator.add_argument(
        "batch_size_eval",
        type=int,
        help="Evaluation batch size, including validation and testing.",
    )
    # add arguments for number of workers for data loading
    validator.add_argument(
        "num_workers",
        help="Number of workers for data loading.",
    )
    # add arguments for warmup steps
    validator.add_argument(
        "warmup_steps",
        type=int,
        help="Number of warmup steps. Required if a warmup schedule is used.",
    )
    # add arguments for random seed
    validator.add_argument(
        "seed",
        type=int,
        help="Random seed.",
    )
    # add arguments for output directory
    validator.add_argument(
        "output_dir",
        type=str,
        help="Output directory to save checkpoints and logs.",
    )
    # add arguments for whether to only run evaluation
    validator.add_argument(
        "evaluate",
        help="Whether to only evaluate the model. If true, training will not be performed.",
    )
    # add arguments for splits used for training, e.g. ["train", "val"]
    validator.add_argument(
        "train_splits",
        type=list,
        help="Splits to use for training.",
    )
    # add arguments for splits used for validation, e.g. ["val"]
    validator.add_argument(
        "valid_splits",
        type=list,
        help="Splits to use for validation. If not provided, will skip the validation.",
    )
    # add arguments for splits used for testing, e.g. ["test"]
    validator.add_argument(
        "test_splits",
        type=list,
        help="Splits to use for testing. If not provided, will skip the testing.",
    )
    # add arguments for accumulating gradients over iterations
    validator.add_argument(
        "accum_grad_iters",
        type=int,
        help="Number of iterations to accumulate gradient for.",
    )

    # ====== distributed training ======
    validator.add_argument(
        "device",
        type=str,
        choices=["cpu", "cuda"],
        help="Device to use. Only 'cuda' and 'cpu' are supported for now.",
    )
    validator.add_argument(
        "world_size",
        type=int,
        help="Number of processes participating in the job.",
    )
    validator.add_argument("dist_url", type=str)
    validator.add_argument("distributed", type=bool)
    # add arguments for whether to use the distributed sampler during evaluation
    validator.add_argument(
        "use_dist_eval_sampler",
        type=bool,
        help="Whether to use distributed sampler during evaluation or not.",
    )

    # ====== task specific ======
    # generation task specific arguments
    # add arguments for maximal length of text output
    validator.add_argument(
        "max_len",
        type=int,
        help="Maximal length of text output.",
    )
    # add arguments for minimal length of text output
    validator.add_argument(
        "min_len",
        type=int,
        help="Minimal length of text output.",
    )
    # add arguments for number of beams
    validator.add_argument(
        "num_beams",
        type=int,
        help="Number of beams used for beam search.",
    )

    # vqa task specific arguments
    # add arguments for number of answer candidates
    validator.add_argument(
        "num_ans_candidates",
        type=int,
        help="""For ALBEF and BLIP, these models first rank answers according to likelihood
            to select answer candidates.""",
    )
    # add arguments for inference method
    validator.add_argument(
        "inference_method",
        type=str,
        choices=["generate", "rank"],
        help="""Inference method to use for question answering. If rank, requires an answer list.""",
    )

    # ====== model specific ======
    validator.add_argument(
        "k_test",
        type=int,
        help="Number of top k most similar samples from ITC/VTC selection to be tested.",
    )

    return validator
--------------------------------------------------------------------------------